Files
starlane-router/app/main.py

69 lines
2.2 KiB
Python

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import asyncio
from .config import (
get_gradient_endpoint_url,
get_gradient_api_key,
get_gradient_auth_scheme,
)
from .gradient_client import call_gradient_inference
from .router import decide_route
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(
    title="a2a-router",
    version="0.1.0",
)
# Request body accepted by POST /route.
# Documented with comments rather than a class docstring so the generated
# schema description stays exactly as it was.
class RouteRequest(BaseModel):
    # Text to be routed; the only required field.
    message: str = Field(
        ...,
        description="User message to route",
    )
    # Optional caller-supplied routing override.
    route_hint: Optional[str] = Field(
        default=None,
        description="Optional explicit route: 'gradient' or 'local'",
    )
    # Optional free-form JSON object supplied by the caller.
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional JSON metadata",
    )
# Response body returned by POST /route.
# Documented with comments rather than a class docstring so the generated
# schema description stays exactly as it was.
class RouteResponse(BaseModel):
    # Name of the route that handled the message.
    route: str

    # Route-specific result payload.
    output: Dict[str, Any]
# Health-check endpoint: always answers with the same constant payload.
@app.get("/healthz")
async def healthz() -> Dict[str, str]:
    status_payload = {"status": "ok"}
    return status_payload
# Route a single user message and return the chosen route plus its output.
#
# The route name comes from ``decide_route`` (message text plus the optional
# explicit hint):
#   * "local"    -> echo the message and metadata back; no outbound call.
#   * "gradient" -> forward the message to the configured Gradient endpoint.
#   * otherwise  -> HTTP 400 "Unknown route".
#
# Raises HTTPException:
#   400 - ``decide_route`` returned a route this handler does not know.
#   500 - the Gradient endpoint URL or API key is not configured.
#   502 - the Gradient call (or building the response from its result) failed.
#
# Kept as a comment block rather than a docstring so the endpoint's generated
# OpenAPI description is unchanged.
@app.post("/route", response_model=RouteResponse)
async def route_message(req: RouteRequest) -> RouteResponse:
    route = decide_route(message=req.message, explicit_hint=req.route_hint)

    if route == "local":
        return RouteResponse(
            route=route,
            output={"echo": req.message, "metadata": req.metadata or {}},
        )

    # Guard clause: anything other than "gradient" is unsupported from here on.
    if route != "gradient":
        raise HTTPException(status_code=400, detail=f"Unknown route: {route}")

    # Read all settings first, then validate, matching the original call order.
    endpoint_url = get_gradient_endpoint_url()
    api_key = get_gradient_api_key()
    auth_scheme = get_gradient_auth_scheme()

    if not endpoint_url:
        raise HTTPException(status_code=500, detail="Missing GRADIENT_ENDPOINT_URL")
    if not api_key:
        raise HTTPException(status_code=500, detail="Missing GRADIENT_API_KEY")

    # Metadata is only forwarded when it is non-empty.
    extra_payload = {"metadata": req.metadata} if req.metadata else None

    try:
        result = await call_gradient_inference(
            endpoint_url=endpoint_url,
            api_key=api_key,
            message=req.message,
            auth_scheme=auth_scheme,
            extra_payload=extra_payload,
        )
        # Built inside the try on purpose: a result that does not fit the
        # response model is reported as a 502, exactly as before.
        return RouteResponse(route=route, output=result)
    except Exception as exc:
        # Boundary handler: surface any upstream failure as 502 and chain the
        # original exception so the root cause is kept as __cause__.
        raise HTTPException(
            status_code=502,
            detail=f"Gradient call failed: {exc}",
        ) from exc
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve the app on every interface, port 8080, with auto-reload off.
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8080,
        reload=False,
    )