# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import inspect
import json
import signal
from collections.abc import (
    AsyncGenerator as AsyncGeneratorABC,
    AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional

import fire
import httpx
import yaml
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel
from termcolor import cprint

from .datatypes import PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter
from .registry import resolve_distribution

load_dotenv()

# Every flavour of "async iterator" type we recognise as a streaming return.
_ASYNC_ITERATOR_TYPES = (
    AsyncIterator,
    AsyncGenerator,
    AsyncIteratorABC,
    AsyncGeneratorABC,
)

# Hop-by-hop headers (RFC 7230 section 6.1) describe a single transport
# connection and must not be forwarded by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)


def is_async_iterator_type(typ):
    """Return True if ``typ`` denotes an async iterator / async generator type.

    Handles parameterised generics (``AsyncIterator[X]``, whose ``__origin__``
    is the runtime ABC), bare classes that subclass one of the async-iterator
    ABCs, and (as before) objects that are themselves async iterators.
    """
    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        if isinstance(origin, type):
            return issubclass(origin, _ASYNC_ITERATOR_TYPES)
        # e.g. Union[...] -- origin is not a class, so not a streaming type.
        return False
    if isinstance(typ, type):
        # A bare class such as ``collections.abc.AsyncIterator`` itself.
        return issubclass(typ, _ASYNC_ITERATOR_TYPES)
    return isinstance(typ, _ASYNC_ITERATOR_TYPES)


def create_sse_event(data: Any) -> str:
    """Serialise ``data`` to JSON and frame it as one Server-Sent Event."""
    if isinstance(data, BaseModel):
        # pydantic v2 renamed ``.json()`` to ``.model_dump_json()``; prefer the
        # new name when it exists and fall back to the v1 API otherwise.
        serializer = getattr(data, "model_dump_json", None) or data.json
        payload = serializer()
    else:
        payload = json.dumps(data)
    return f"data: {payload}\n\n"


def _strip_hop_by_hop(headers) -> Dict[str, str]:
    """Copy a header mapping into a dict without hop-by-hop headers."""
    return {
        name: value
        for name, value in dict(headers).items()
        if name.lower() not in _HOP_BY_HOP_HEADERS
    }


async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    """Proxy ``request`` to ``downstream_url`` and stream the reply back.

    The incoming method, body, headers (minus ``host`` and hop-by-hop headers)
    and query string are forwarded. ``downstream_headers`` override incoming
    headers of the same name.

    The downstream body is relayed as raw (still content-encoded) bytes, so the
    forwarded ``content-encoding`` / ``content-length`` headers stay accurate.
    It is streamed chunk by chunk, so SSE responses are not buffered. The HTTP
    client stays open until the relay finishes and is closed afterwards.
    """
    headers = _strip_hop_by_hop(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})
    body = await request.body()

    # Create the client only after the body has been read, so a failed read
    # cannot leak an open client.
    client = httpx.AsyncClient()
    try:
        downstream_request = client.build_request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=body,
            # multi_items() keeps repeated query parameters (?a=1&a=2).
            params=request.query_params.multi_items(),
        )
        response = await client.send(downstream_request, stream=True)
    except BaseException:
        await client.aclose()
        raise

    async def relay_body():
        # Release the connection and client once the body has been relayed,
        # including when the consumer stops early.
        try:
            async for chunk in response.aiter_raw():
                yield chunk
        finally:
            await response.aclose()
            await client.aclose()

    return StreamingResponse(
        relay_body(),
        status_code=response.status_code,
        headers=_strip_hop_by_hop(response.headers),
    )


def handle_sigint(*args, **kwargs):
    """SIGINT handler: cancel every task on the current loop, then stop it.

    NOTE(review): ``uvicorn.run`` may install its own SIGINT handler and
    replace this one -- confirm this handler actually fires.
    """
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    loop.stop()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log startup and shutdown."""
    print("Starting up")
    yield
    print("Shutting down")


def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    """Build an endpoint that proxies every request to ``downstream_url``."""

    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint


def create_dynamic_typed_route(func: Any):
    """Build a FastAPI endpoint from an annotated adapter method.

    The first annotated parameter of ``func`` becomes the request body model.
    If the return annotation is an async-iterator type, each yielded item is
    sent as a Server-Sent Event. Otherwise the (possibly awaited) return
    value is returned as-is.

    Raises:
        ValueError: if ``func`` has no annotated parameter or no return
            annotation, so its route cannot be typed.
    """
    hints = get_type_hints(func)
    # Skip the "return" entry so it is never mistaken for the request model.
    param_hints = [hint for name, hint in hints.items() if name != "return"]
    if not param_hints:
        raise ValueError(f"Cannot infer request model for {func}: no annotated parameters")
    if "return" not in hints:
        raise ValueError(f"Cannot infer response type for {func}: no return annotation")
    request_model = param_hints[0]
    response_model = hints["return"]

    is_streaming = is_async_iterator_type(response_model)

    if is_streaming:

        async def endpoint(request: request_model):
            async def event_generator():
                async for item in func(request):
                    yield create_sse_event(item)
                    # Yield control briefly between events.
                    await asyncio.sleep(0.001)

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    else:

        async def endpoint(request: request_model):
            result = func(request)
            # Adapter methods may be ``async def``; await them so FastAPI
            # receives the value rather than a coroutine object.
            if inspect.isawaitable(result):
                result = await result
            return result

    return endpoint


def main(
    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
    """Resolve a distribution, register its API routes and serve them.

    Args:
        dist_name: name passed to ``resolve_distribution``.
        yaml_config: path to a YAML file with an ``adapters`` mapping keyed by
            API surface value.
        port: TCP port to listen on.
        disable_ipv6: bind to ``0.0.0.0`` instead of ``::``.

    Raises:
        ValueError: unknown distribution, missing adapter config, or an
            adapter missing an endpoint method.
    """
    dist = resolve_distribution(dist_name)
    if dist is None:
        raise ValueError(f"Could not find distribution {dist_name}")

    with open(yaml_config, "r") as fp:
        config = yaml.safe_load(fp)

    # Wire the lifespan hook defined above; it was previously never attached.
    app = FastAPI(lifespan=lifespan)

    all_endpoints = api_surface_endpoints()

    adapter_configs = config["adapters"]
    for surface, adapter in dist.adapters.items():
        if surface.value not in adapter_configs:
            raise ValueError(
                f"Could not find adapter config for {surface}. Please add it to the config"
            )

        adapter_config = adapter_configs[surface.value]
        endpoints = all_endpoints[surface]
        if isinstance(adapter, PassthroughApiAdapter):
            # Forward each endpoint verbatim to the adapter's base URL.
            for endpoint in endpoints:
                url = adapter.base_url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            impl = instantiate_adapter(adapter, adapter_config)
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(impl_method)
                )

    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )

    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)


if __name__ == "__main__":
    fire.Fire(main)