mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
Distribution server now functioning
This commit is contained in:
parent
041cafbee3
commit
2cf9915806
21 changed files with 635 additions and 266 deletions
202
llama_toolchain/distribution/server.py
Normal file
202
llama_toolchain/distribution/server.py
Normal file
|
@ -0,0 +1,202 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from .datatypes import PassthroughApiAdapter
|
||||
from .distribution import api_surface_endpoints
|
||||
from .dynamic import instantiate_adapter
|
||||
|
||||
from .registry import resolve_distribution
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
else:
|
||||
data = json.dumps(data)
|
||||
|
||||
return f"data: {data}\n\n"
|
||||
|
||||
|
||||
async def passthrough(
|
||||
request: Request,
|
||||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
client = httpx.AsyncClient()
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
||||
body = await request.body()
|
||||
|
||||
try:
|
||||
response = await client.request(
|
||||
method=request.method,
|
||||
url=downstream_url,
|
||||
headers=headers,
|
||||
content=body,
|
||||
params=request.query_params,
|
||||
)
|
||||
return StreamingResponse(
|
||||
response.iter_bytes(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
)
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully", args)
|
||||
loop = asyncio.get_event_loop()
|
||||
for task in asyncio.all_tasks(loop):
|
||||
task.cancel()
|
||||
loop.stop()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
print("Starting up")
|
||||
yield
|
||||
print("Shutting down")
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
||||
):
|
||||
async def endpoint(request: Request):
|
||||
return await passthrough(request, downstream_url, downstream_headers)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any):
|
||||
hints = get_type_hints(func)
|
||||
request_model = next(iter(hints.values()))
|
||||
response_model = hints["return"]
|
||||
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
async def event_generator():
|
||||
async for item in func(request):
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: request_model):
|
||||
return func(request)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def main(
|
||||
dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
|
||||
):
|
||||
dist = resolve_distribution(dist_name)
|
||||
if dist is None:
|
||||
raise ValueError(f"Could not find distribution {dist_name}")
|
||||
|
||||
with open(yaml_config, "r") as fp:
|
||||
config = yaml.safe_load(fp)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
all_endpoints = api_surface_endpoints()
|
||||
|
||||
adapter_configs = config["adapters"]
|
||||
for surface, adapter in dist.adapters.items():
|
||||
if surface.value not in adapter_configs:
|
||||
raise ValueError(
|
||||
f"Could not find adapter config for {surface}. Please add it to the config"
|
||||
)
|
||||
|
||||
adapter_config = adapter_configs[surface.value]
|
||||
endpoints = all_endpoints[surface]
|
||||
if isinstance(adapter, PassthroughApiAdapter):
|
||||
for endpoint in endpoints:
|
||||
url = adapter.base_url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
else:
|
||||
impl = instantiate_adapter(adapter, adapter_config)
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(
|
||||
f"Could not find method {endpoint.name} on {impl}!!"
|
||||
)
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
cprint(
|
||||
f"Serving {next(iter(route.methods))} {route.path}",
|
||||
"white",
|
||||
attrs=["bold"],
|
||||
)
|
||||
|
||||
signal.signal(signal.SIGINT, handle_sigint)
|
||||
|
||||
import uvicorn
|
||||
|
||||
# FYI this does not do hot-reloads
|
||||
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
|
||||
print(f"Listening on {listen_host}:{port}")
|
||||
uvicorn.run(app, host=listen_host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
Loading…
Add table
Add a link
Reference in a new issue