mirror of
https://github.com/ghndrx/starlane-router.git
synced 2026-02-10 06:45:01 +00:00
feat: initial commit for starlane-router (FastAPI + Gradient)
This commit is contained in:
69
app/main.py
Normal file
69
app/main.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any
|
||||
import asyncio
|
||||
|
||||
from .config import (
|
||||
get_gradient_endpoint_url,
|
||||
get_gradient_api_key,
|
||||
get_gradient_auth_scheme,
|
||||
)
|
||||
from .gradient_client import call_gradient_inference
|
||||
from .router import decide_route
|
||||
|
||||
|
||||
app = FastAPI(title="a2a-router", version="0.1.0")
|
||||
|
||||
|
||||
class RouteRequest(BaseModel):
|
||||
message: str = Field(..., description="User message to route")
|
||||
route_hint: Optional[str] = Field(None, description="Optional explicit route: 'gradient' or 'local'")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional JSON metadata")
|
||||
|
||||
|
||||
class RouteResponse(BaseModel):
|
||||
route: str
|
||||
output: Dict[str, Any]
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz() -> Dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/route", response_model=RouteResponse)
|
||||
async def route_message(req: RouteRequest) -> RouteResponse:
|
||||
route = decide_route(message=req.message, explicit_hint=req.route_hint)
|
||||
|
||||
if route == "local":
|
||||
return RouteResponse(route=route, output={"echo": req.message, "metadata": req.metadata or {}})
|
||||
|
||||
if route == "gradient":
|
||||
endpoint_url = get_gradient_endpoint_url()
|
||||
api_key = get_gradient_api_key()
|
||||
auth_scheme = get_gradient_auth_scheme()
|
||||
|
||||
if not endpoint_url:
|
||||
raise HTTPException(status_code=500, detail="Missing GRADIENT_ENDPOINT_URL")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="Missing GRADIENT_API_KEY")
|
||||
|
||||
try:
|
||||
result = await call_gradient_inference(
|
||||
endpoint_url=endpoint_url,
|
||||
api_key=api_key,
|
||||
message=req.message,
|
||||
auth_scheme=auth_scheme,
|
||||
extra_payload={"metadata": req.metadata} if req.metadata else None,
|
||||
)
|
||||
return RouteResponse(route=route, output=result)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Gradient call failed: {exc}")
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"Unknown route: {route}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=8080, reload=False)
|
||||
Reference in New Issue
Block a user