mirror of
https://github.com/ghndrx/starlane-router.git
synced 2026-02-10 06:45:01 +00:00
feat: initial commit for starlane-router (FastAPI + Gradient)
This commit is contained in:
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Package marker for app
|
||||
24
app/config.py
Normal file
24
app/config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
|
||||
def _get_env(name: str, default: str = "") -> str:
|
||||
return os.getenv(name, default).strip()
|
||||
|
||||
|
||||
def get_gradient_endpoint_url() -> str:
|
||||
return _get_env("GRADIENT_ENDPOINT_URL")
|
||||
|
||||
|
||||
def get_gradient_api_key() -> str:
|
||||
return _get_env("GRADIENT_API_KEY")
|
||||
|
||||
|
||||
def get_gradient_auth_scheme() -> str:
|
||||
# 'authorization_bearer' or 'x_api_key'
|
||||
return _get_env("GRADIENT_AUTH_SCHEME", "authorization_bearer").lower()
|
||||
|
||||
|
||||
def get_route_keywords() -> List[str]:
|
||||
raw = _get_env("ROUTE_KEYWORDS", "ai,model,ml,gpt,router,gradient")
|
||||
return [kw.strip().lower() for kw in raw.split(",") if kw.strip()]
|
||||
34
app/gradient_client.py
Normal file
34
app/gradient_client.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
|
||||
|
||||
async def call_gradient_inference(
|
||||
*,
|
||||
endpoint_url: str,
|
||||
api_key: str,
|
||||
message: str,
|
||||
auth_scheme: str = "authorization_bearer",
|
||||
timeout_seconds: float = 30.0,
|
||||
extra_payload: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
headers = _build_headers(api_key=api_key, auth_scheme=auth_scheme)
|
||||
|
||||
payload: Dict[str, Any] = {"input": message}
|
||||
if extra_payload:
|
||||
payload.update(extra_payload)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout_seconds) as client:
|
||||
resp = await client.post(endpoint_url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Normalize output shape
|
||||
return {"raw": data}
|
||||
|
||||
|
||||
def _build_headers(*, api_key: str, auth_scheme: str) -> Dict[str, str]:
|
||||
scheme = (auth_scheme or "authorization_bearer").lower()
|
||||
if scheme == "x_api_key":
|
||||
return {"X-API-Key": api_key}
|
||||
# default
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
69
app/main.py
Normal file
69
app/main.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Dict, Any
|
||||
import asyncio
|
||||
|
||||
from .config import (
|
||||
get_gradient_endpoint_url,
|
||||
get_gradient_api_key,
|
||||
get_gradient_auth_scheme,
|
||||
)
|
||||
from .gradient_client import call_gradient_inference
|
||||
from .router import decide_route
|
||||
|
||||
|
||||
app = FastAPI(title="a2a-router", version="0.1.0")
|
||||
|
||||
|
||||
class RouteRequest(BaseModel):
|
||||
message: str = Field(..., description="User message to route")
|
||||
route_hint: Optional[str] = Field(None, description="Optional explicit route: 'gradient' or 'local'")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional JSON metadata")
|
||||
|
||||
|
||||
class RouteResponse(BaseModel):
|
||||
route: str
|
||||
output: Dict[str, Any]
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz() -> Dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/route", response_model=RouteResponse)
|
||||
async def route_message(req: RouteRequest) -> RouteResponse:
|
||||
route = decide_route(message=req.message, explicit_hint=req.route_hint)
|
||||
|
||||
if route == "local":
|
||||
return RouteResponse(route=route, output={"echo": req.message, "metadata": req.metadata or {}})
|
||||
|
||||
if route == "gradient":
|
||||
endpoint_url = get_gradient_endpoint_url()
|
||||
api_key = get_gradient_api_key()
|
||||
auth_scheme = get_gradient_auth_scheme()
|
||||
|
||||
if not endpoint_url:
|
||||
raise HTTPException(status_code=500, detail="Missing GRADIENT_ENDPOINT_URL")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="Missing GRADIENT_API_KEY")
|
||||
|
||||
try:
|
||||
result = await call_gradient_inference(
|
||||
endpoint_url=endpoint_url,
|
||||
api_key=api_key,
|
||||
message=req.message,
|
||||
auth_scheme=auth_scheme,
|
||||
extra_payload={"metadata": req.metadata} if req.metadata else None,
|
||||
)
|
||||
return RouteResponse(route=route, output=result)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Gradient call failed: {exc}")
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"Unknown route: {route}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=8080, reload=False)
|
||||
21
app/router.py
Normal file
21
app/router.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
from .config import get_route_keywords
|
||||
|
||||
|
||||
def decide_route(*, message: str, explicit_hint: Optional[str] = None) -> str:
|
||||
if explicit_hint:
|
||||
hint = explicit_hint.strip().lower()
|
||||
if hint in {"gradient", "local"}:
|
||||
return hint
|
||||
|
||||
text = (message or "").strip().lower()
|
||||
|
||||
# Simple heuristic: if message is long or has keywords, use gradient
|
||||
if len(text) > 120:
|
||||
return "gradient"
|
||||
|
||||
for kw in get_route_keywords():
|
||||
if kw in text:
|
||||
return "gradient"
|
||||
|
||||
return "local"
|
||||
Reference in New Issue
Block a user