mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
202 lines
5.6 KiB
Python
202 lines
5.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import json
|
|
import signal
|
|
from collections.abc import (
|
|
AsyncGenerator as AsyncGeneratorABC,
|
|
AsyncIterator as AsyncIteratorABC,
|
|
)
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
|
|
|
import fire
|
|
import httpx
|
|
import yaml
|
|
from dotenv import load_dotenv
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.routing import APIRoute
|
|
|
|
from pydantic import BaseModel
|
|
from termcolor import cprint
|
|
|
|
from .datatypes import PassthroughApiAdapter
|
|
from .distribution import api_surface_endpoints
|
|
from .dynamic import instantiate_adapter
|
|
|
|
from .registry import resolve_distribution
|
|
|
|
# Populate os.environ from a ``.env`` file at import time (python-dotenv locates
# the file itself). Variables already present in the environment win, because
# ``override`` is False — this is python-dotenv's default, spelled out here.
load_dotenv(override=False)
|
|
|
|
|
|
def is_async_iterator_type(typ):
    """Return True if *typ* denotes an async-iterator / async-generator type.

    Accepted forms:

    * parameterised hints such as ``typing.AsyncIterator[X]`` or
      ``collections.abc.AsyncGenerator[X, None]``;
    * bare forms such as ``typing.AsyncIterator`` or
      ``collections.abc.AsyncIterator``;
    * for backward compatibility, an async-iterator *instance*.
    """
    # ``collections.abc.AsyncGenerator`` subclasses ``collections.abc.AsyncIterator``,
    # and the ``typing`` aliases expose these ABCs through ``__origin__``, so one
    # ABC check covers all four names.
    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        # Non-class origins (e.g. ``typing.Union``) can never be async iterators.
        return isinstance(origin, type) and issubclass(origin, AsyncIteratorABC)
    if isinstance(typ, type):
        # Bare classes carry no ``__origin__``; they must be checked with
        # issubclass — isinstance would test the class object itself.
        return issubclass(typ, AsyncIteratorABC)
    return isinstance(typ, AsyncIteratorABC)
|
|
|
|
|
|
def create_sse_event(data: Any) -> str:
    """Format *data* as a single Server-Sent Events ``data:`` frame.

    Pydantic models are serialized with the model's own JSON encoder; any other
    value is passed to :func:`json.dumps`. The frame ends with a blank line.
    """
    if isinstance(data, BaseModel):
        # Pydantic v2 deprecates ``.json()`` (it emits PydanticDeprecatedSince20)
        # in favour of ``.model_dump_json()``. Prefer the new API when present
        # and fall back to ``.json()`` on Pydantic v1.
        serialize = getattr(data, "model_dump_json", None) or data.json
        payload = serialize()
    else:
        payload = json.dumps(data)

    return f"data: {payload}\n\n"
|
|
|
|
|
|
# Hop-by-hop header fields (RFC 7230, section 6.1) describe one connection only
# and must not be relayed by a proxy.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _forwardable_headers(headers: Any, drop: tuple = ()) -> Dict[str, str]:
    """Return a plain-dict copy of *headers* without hop-by-hop fields or *drop*.

    Names are compared case-insensitively. Field names listed in a
    ``Connection`` header are treated as hop-by-hop as well.
    """
    source = dict(headers)
    blocked = set(_HOP_BY_HOP_HEADERS)
    blocked.update(name.lower() for name in drop)
    for key, value in source.items():
        if key.lower() == "connection":
            blocked.update(
                token.strip().lower() for token in value.split(",") if token.strip()
            )
    return {key: value for key, value in source.items() if key.lower() not in blocked}


async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    """Forward *request* to *downstream_url* and relay the downstream response.

    The method, query parameters and body are forwarded unchanged. Headers are
    forwarded as follows:

    * ``host``, ``content-length`` and hop-by-hop headers are dropped;
      httpx recomputes ``content-length`` from the body.
    * *downstream_headers* are merged on top.

    httpx reads the downstream body in full and transparently decodes any
    ``Content-Encoding`` (e.g. gzip). The relayed response therefore omits the
    downstream ``content-encoding`` / ``content-length`` headers, which describe
    the encoded body and would corrupt the reply.
    """
    headers = _forwardable_headers(request.headers, drop=("host", "content-length"))
    headers.update(downstream_headers or {})

    body = await request.body()

    # ``async with`` closes the client even when the request raises.
    # NOTE(review): httpx's default timeout (5s) applies here; slow downstream
    # services may need an explicit ``timeout=`` — confirm.
    async with httpx.AsyncClient() as client:
        response = await client.request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=body,
            params=request.query_params,
        )

    # The non-streaming request above has already buffered the body, so
    # iterating it after the client is closed is safe.
    return StreamingResponse(
        response.iter_bytes(),
        status_code=response.status_code,
        headers=_forwardable_headers(
            response.headers, drop=("content-encoding", "content-length")
        ),
    )
|
|
|
|
|
|
def handle_sigint(*args, **kwargs):
    """Best-effort SIGINT handler: cancel every task on the event loop, then stop it.

    Installed via ``signal.signal``. The positional arguments (signal number
    and frame) are echoed in the log line and otherwise ignored.

    NOTE(review): uvicorn may install its own signal handlers when it starts,
    possibly replacing this one — confirm it actually fires.
    """
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
    event_loop = asyncio.get_event_loop()
    # cancel() only requests termination; each task sees CancelledError on its
    # next step.
    outstanding = asyncio.all_tasks(event_loop)
    for pending_task in outstanding:
        pending_task.cancel()
    event_loop.stop()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that logs application startup and shutdown.

    NOTE(review): ``main`` in this file builds ``FastAPI()`` without passing
    ``lifespan=``, so this hook is currently unused — confirm whether it should
    be wired up.
    """
    startup_message, shutdown_message = "Starting up", "Shutting down"
    print(startup_message)
    yield
    print(shutdown_message)
|
|
|
|
|
|
def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    """Build a FastAPI endpoint that proxies every call to *downstream_url*.

    *downstream_headers*, if given, are merged into the forwarded request headers.
    """
    target_url = downstream_url
    extra_headers = downstream_headers

    # The inner function keeps the name ``endpoint``; FastAPI derives default
    # route names from the endpoint function's name, so keep it stable.
    async def endpoint(request: Request):
        forwarded = passthrough(request, target_url, extra_headers)
        return await forwarded

    return endpoint
|
|
|
|
|
|
def create_dynamic_typed_route(func: Any):
    """Wrap the adapter method *func* as a typed FastAPI endpoint.

    The first annotated parameter of *func* becomes the endpoint's request
    model. The ``return`` annotation decides the response style:

    * async-iterator return types are streamed as Server-Sent Events
      (``text/event-stream``);
    * anything else, including a missing annotation, is returned directly for
      FastAPI to serialize.

    *func* may be a plain function, an ``async def`` coroutine function, or an
    async generator function; awaitable results are awaited before use.
    Functions with no annotated parameters get an endpoint that takes no body.
    """
    import inspect

    hints = get_type_hints(func)
    response_model = hints.get("return")
    # Exclude "return" so a parameterless func never uses its return type as
    # the request model.
    param_models = [hint for name, hint in hints.items() if name != "return"]
    is_streaming = is_async_iterator_type(response_model)

    async def invoke(*args):
        # ``async def`` methods return a coroutine; await it so callers get the
        # actual result (or the async iterator it produces).
        result = func(*args)
        if inspect.isawaitable(result):
            result = await result
        return result

    if is_streaming:

        async def respond(*args):
            stream = await invoke(*args)

            async def event_generator():
                async for item in stream:
                    yield create_sse_event(item)
                    # NOTE(review): short sleep presumably yields control so
                    # each event is flushed promptly — confirm before removing.
                    await asyncio.sleep(0.001)

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    else:
        respond = invoke

    if param_models:
        request_model = param_models[0]

        async def endpoint(request: request_model):
            return await respond(request)

    else:

        async def endpoint():
            return await respond()

    return endpoint
|
|
|
|
|
|
def _load_config(yaml_config: str) -> Dict[str, Any]:
    """Read the YAML run config at *yaml_config*; require an ``adapters`` mapping.

    Raises:
        ValueError: if the file is empty or lacks a top-level ``adapters`` mapping.
    """
    with open(yaml_config, "r", encoding="utf-8") as fp:
        config = yaml.safe_load(fp)
    # safe_load returns None for an empty file. Report that, and a missing or
    # non-mapping ``adapters`` key, clearly instead of failing later with a
    # TypeError/KeyError.
    if not isinstance(config, dict) or not isinstance(config.get("adapters"), dict):
        raise ValueError(
            f"Config file {yaml_config} must contain a top-level 'adapters' mapping"
        )
    return config


def _register_adapter_routes(
    app: FastAPI, dist: Any, adapter_configs: Dict[str, Any]
) -> None:
    """Add a route on *app* for every API-surface endpoint of each adapter in *dist*.

    Passthrough adapters get proxy routes to ``adapter.base_url``. Other adapters
    are instantiated with their config and their methods are exposed as typed
    routes.

    Raises:
        ValueError: if an adapter has no config entry, or an instantiated
            adapter lacks a method required by its API surface.
    """
    all_endpoints = api_surface_endpoints()

    for surface, adapter in dist.adapters.items():
        if surface.value not in adapter_configs:
            raise ValueError(
                f"Could not find adapter config for {surface}. Please add it to the config"
            )

        adapter_config = adapter_configs[surface.value]
        endpoints = all_endpoints[surface]

        if isinstance(adapter, PassthroughApiAdapter):
            for endpoint in endpoints:
                url = adapter.base_url.rstrip("/") + endpoint.route
                # e.g. app.post(route)(handler); the method name comes from the
                # endpoint spec.
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
            continue

        impl = instantiate_adapter(adapter, adapter_config)
        for endpoint in endpoints:
            if not hasattr(impl, endpoint.name):
                # ideally this should be a typing violation already
                raise ValueError(
                    f"Could not find method {endpoint.name} on {impl}!!"
                )

            impl_method = getattr(impl, endpoint.name)
            # response_model=None: responses are returned as-is rather than
            # validated against an inferred model.
            getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                create_dynamic_typed_route(impl_method)
            )


def _log_routes(app: FastAPI) -> None:
    """Print the HTTP method and path of every API route registered on *app*."""
    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )


def main(
    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
    """Serve distribution *dist_name* over HTTP with the adapters in *yaml_config*.

    Args:
        dist_name: name passed to ``resolve_distribution``.
        yaml_config: path to a YAML file with a top-level ``adapters`` mapping
            keyed by API-surface value.
        port: TCP port to listen on.
        disable_ipv6: listen on ``0.0.0.0`` instead of ``::``.

    Raises:
        ValueError: unknown distribution, invalid config, missing adapter
            config, or an adapter missing a required method.
    """
    dist = resolve_distribution(dist_name)
    if dist is None:
        raise ValueError(f"Could not find distribution {dist_name}")

    config = _load_config(yaml_config)

    app = FastAPI()
    _register_adapter_routes(app, dist, config["adapters"])
    _log_routes(app)

    # NOTE(review): uvicorn may install its own signal handlers when it starts,
    # possibly replacing this one — confirm it still fires.
    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)
|
|
|
|
|
|
if __name__ == "__main__":
    # python-fire turns main's parameters into command-line arguments/flags.
    fire.Fire(component=main)
|