"""FastAPI application for the ``a2a-router`` service.

Exposes a health check and a ``/route`` endpoint that either echoes the
incoming message back ("local" route) or forwards it to a Gradient
inference endpoint ("gradient" route).
"""

import asyncio
from typing import Optional, Dict, Any

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from .config import (
    get_gradient_endpoint_url,
    get_gradient_api_key,
    get_gradient_auth_scheme,
)
from .gradient_client import call_gradient_inference
from .router import decide_route

app = FastAPI(title="a2a-router", version="0.1.0")


class RouteRequest(BaseModel):
    """Request body accepted by ``POST /route``."""

    message: str = Field(..., description="User message to route")
    route_hint: Optional[str] = Field(
        None,
        description="Optional explicit route: 'gradient' or 'local'",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional JSON metadata",
    )


class RouteResponse(BaseModel):
    """Response body returned by ``POST /route``."""

    # Name of the route that handled the message.
    route: str
    # Route-specific payload: an echo dict for "local", or whatever
    # ``call_gradient_inference`` returned for "gradient".
    output: Dict[str, Any]


@app.get("/healthz")
async def healthz() -> Dict[str, str]:
    """Return a static ``{"status": "ok"}`` payload."""
    return {"status": "ok"}


@app.post("/route", response_model=RouteResponse)
async def route_message(req: RouteRequest) -> RouteResponse:
    """Route a message to the local echo handler or to Gradient.

    The route name is obtained from ``decide_route`` using the message and
    the optional ``route_hint``.

    Args:
        req: Validated request body.

    Returns:
        A ``RouteResponse`` naming the route taken and carrying its output.

    Raises:
        HTTPException: 500 when the Gradient endpoint URL or API key is not
            configured; 502 when the Gradient call (or building the response
            from its result) raises; 400 when ``decide_route`` returns a
            route other than "local" or "gradient".
    """
    route = decide_route(message=req.message, explicit_hint=req.route_hint)

    if route == "local":
        return RouteResponse(
            route=route,
            output={"echo": req.message, "metadata": req.metadata or {}},
        )

    if route == "gradient":
        endpoint_url = get_gradient_endpoint_url()
        api_key = get_gradient_api_key()
        auth_scheme = get_gradient_auth_scheme()

        # Missing configuration is a server-side problem, hence 500.
        if not endpoint_url:
            raise HTTPException(status_code=500, detail="Missing GRADIENT_ENDPOINT_URL")
        if not api_key:
            raise HTTPException(status_code=500, detail="Missing GRADIENT_API_KEY")

        try:
            result = await call_gradient_inference(
                endpoint_url=endpoint_url,
                api_key=api_key,
                message=req.message,
                auth_scheme=auth_scheme,
                # Metadata is forwarded only when non-empty.
                extra_payload={"metadata": req.metadata} if req.metadata else None,
            )
            # Kept inside the try so a result that does not fit the response
            # model is also reported as an upstream (502) failure.
            return RouteResponse(route=route, output=result)
        except Exception as exc:
            # Chain the original error so its traceback stays attached to
            # the HTTPException as the explicit cause.
            # NOTE(review): the exception text is sent to the client; confirm
            # it cannot contain sensitive details.
            raise HTTPException(
                status_code=502,
                detail=f"Gradient call failed: {exc}",
            ) from exc

    raise HTTPException(status_code=400, detail=f"Unknown route: {route}")


if __name__ == "__main__":
    # NOTE(review): this module uses relative imports, so it presumably has
    # to be launched as part of its package (e.g. ``python -m app.main``)
    # rather than as a standalone script - confirm the intended entry point.
    import uvicorn

    uvicorn.run("app.main:app", host="0.0.0.0", port=8080, reload=False)