Distribution server now functioning

This commit is contained in:
Ashwin Bharambe 2024-08-02 13:37:40 -07:00
parent 041cafbee3
commit 2cf9915806
21 changed files with 635 additions and 266 deletions

View file

@ -5,54 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Union
from typing import Dict, List
from pydantic import BaseModel, Field
from strong_typing.schema import json_schema_type
from typing_extensions import Annotated
@json_schema_type
class AdapterType(Enum):
    """Tags identifying the kind of adapter configuration.

    Each member's value is used as the literal `type` field of one of the
    adapter config models defined in this module.
    """

    # tag of PassthroughApiAdapterConfig
    passthrough_api = "passthrough_api"
    # tag of PythonImplAdapterConfig
    python_impl = "python_impl"
    # tag of NotImplementedAdapterConfig
    not_implemented = "not_implemented"
@json_schema_type
class PassthroughApiAdapterConfig(BaseModel):
    """Adapter config tagged `passthrough_api`: a provider base URL plus optional request headers."""

    # discriminator tag; fixed to the `passthrough_api` enum value
    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
    # required (no default)
    base_url: str = Field(..., description="The base URL for the llama stack provider")
    # defaults to a fresh empty dict per instance
    headers: Dict[str, str] = Field(
        default_factory=dict,
        description="Headers (e.g., authorization) to send with the request",
    )
@json_schema_type
class PythonImplAdapterConfig(BaseModel):
    """Adapter config tagged `python_impl`: an adapter id plus free-form kwargs for its entrypoint."""

    # discriminator tag; fixed to the `python_impl` enum value
    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
    # required; presumably names a registered Python adapter -- TODO confirm against the registry
    adapter_id: str
    # defaults to a fresh empty dict per instance
    kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="kwargs to pass to the entrypoint"
    )
@json_schema_type
class NotImplementedAdapterConfig(BaseModel):
    """Adapter config tagged `not_implemented`; carries no fields beyond the tag."""

    # discriminator tag; fixed to the `not_implemented` enum value
    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
# Open question: should we define very granular typed classes for each of the
# PythonImplAdapters we will have? e.g., OllamaInference / vLLMInference / etc.
# might need very specific parameters.
#
# AdapterConfig is a tagged union of the three config models above; the `type`
# field of the incoming data is the discriminator that selects which model applies.
AdapterConfig = Annotated[
    Union[
        PassthroughApiAdapterConfig,
        NotImplementedAdapterConfig,
        PythonImplAdapterConfig,
    ],
    Field(discriminator="type"),
]
@json_schema_type
@ -61,6 +17,13 @@ class ApiSurface(Enum):
safety = "safety"
@json_schema_type
class ApiSurfaceEndpoint(BaseModel):
    """One servable endpoint of an API surface."""

    # URL path of the endpoint
    route: str
    # name of the routing method to register under (the server looks this
    # attribute up on the FastAPI app, e.g. "post")
    method: str
    # name of the protocol method that implements the endpoint
    name: str
@json_schema_type
class Adapter(BaseModel):
api_surface: ApiSurface
@ -108,13 +71,3 @@ class Distribution(BaseModel):
default_factory=list,
description="Additional pip packages beyond those required by the adapters",
)
def distribution_dependencies(distribution: Distribution) -> List[str]:
    """Collect the pip packages needed by a distribution.

    Gathers `pip_packages` from every adapter that is a `SourceAdapter`
    (other adapter kinds contribute nothing) and appends the distribution's
    own `additional_pip_packages`.
    """
    packages = []
    for adapter in distribution.adapters.values():
        # only SourceAdapters carry dependencies of their own
        if not isinstance(adapter, SourceAdapter):
            continue
        packages.extend(adapter.pip_packages)
    return packages + distribution.additional_pip_packages

View file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import inspect
from typing import Dict, List
from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety
from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter
def distribution_dependencies(distribution: Distribution) -> List[str]:
    """Collect the pip packages needed by a distribution.

    Gathers `pip_packages` from every adapter that is a `SourceAdapter`
    (other adapter kinds contribute nothing) and appends the distribution's
    own `additional_pip_packages`.
    """
    packages = []
    for adapter in distribution.adapters.values():
        # only SourceAdapters carry dependencies of their own
        if not isinstance(adapter, SourceAdapter):
            continue
        packages.extend(adapter.pip_packages)
    return packages + distribution.additional_pip_packages
def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
    """Map each API surface to the endpoints declared on its protocol class.

    A protocol function counts as an endpoint when it carries a
    `__webmethod__` attribute; that attribute's `route` becomes the endpoint
    route and the function's name becomes the endpoint name.
    """
    protocol_by_surface = {
        ApiSurface.inference: Inference,
        ApiSurface.safety: Safety,
    }

    def _collect(protocol):
        members = inspect.getmembers(protocol, predicate=inspect.isfunction)
        return [
            # use `post` for all methods right now until we fix up the `webmethod`
            # openapi annotation and write our own openapi generator
            ApiSurfaceEndpoint(
                route=member.__webmethod__.route, method="post", name=member_name
            )
            for member_name, member in members
            if hasattr(member, "__webmethod__")
        ]

    return {
        surface: _collect(protocol)
        for surface, protocol in protocol_by_surface.items()
    }

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib
from typing import Any, Dict
from .datatypes import SourceAdapter
def instantiate_class_type(fully_qualified_name: str) -> Any:
    """Resolve a dotted path such as `package.module.ClassName` to the named attribute.

    Despite the name, nothing is instantiated: the module part (everything
    before the last dot) is imported and the attribute named by the last
    component is returned as-is.

    Raises:
        ValueError: if `fully_qualified_name` has no module part (no dot, or
            a leading dot only).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    module_name, _, class_name = fully_qualified_name.rpartition(".")
    if not module_name:
        # fail with an actionable message instead of an unpacking error
        raise ValueError(
            f"Expected a dotted `module.ClassName` path, got {fully_qualified_name!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
# returns an object implementing the protocol corresponding to the ApiSurface
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
    """Build the implementation object for a source adapter.

    Imports `adapter.module`, resolves `adapter.config_class` to a config
    type, constructs it from `adapter_config` as keyword arguments, and runs
    the module's `get_adapter_impl(config)` coroutine to completion.
    """
    adapter_module = importlib.import_module(adapter.module)
    config_cls = instantiate_class_type(adapter.config_class)
    typed_config = config_cls(**adapter_config)
    impl_coro = adapter_module.get_adapter_impl(typed_config)
    # blocks until the coroutine finishes; uses a fresh event loop
    return asyncio.run(impl_coro)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from typing import List, Optional
from llama_toolchain.inference.adapters import available_inference_adapters
@ -63,3 +63,10 @@ def available_distributions() -> List[Distribution]:
},
),
]
def resolve_distribution(name: str) -> Optional[Distribution]:
    """Return the first available distribution whose `name` equals `name`, or None."""
    candidates = (
        candidate for candidate in available_distributions() if candidate.name == name
    )
    return next(candidates, None)

View file

@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import signal
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
import fire
import httpx
import yaml
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel
from termcolor import cprint
from .datatypes import PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter
from .registry import resolve_distribution
# Runs at import time, before anything else in this module; python-dotenv's
# load_dotenv presumably populates the environment from a `.env` file -- see its docs.
load_dotenv()
def is_async_iterator_type(typ):
    """Report whether `typ` denotes an async iterator or async generator.

    Handles three shapes of input:
      * a generic alias with an `__origin__` (e.g. `AsyncIterator[int]`):
        true when the origin is a class deriving from an async-iterator type,
        false otherwise (including non-class origins such as `Union`);
      * a plain class (e.g. `collections.abc.AsyncGenerator` or a user
        subclass): true when it derives from an async-iterator type;
      * anything else: true when it is an instance of an async-iterator type.
    """
    async_iter_types = (
        AsyncIterator,
        AsyncGenerator,
        AsyncIteratorABC,
        AsyncGeneratorABC,
    )

    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        return isinstance(origin, type) and issubclass(origin, async_iter_types)

    # A bare class is never an *instance* of an async iterator, so it must be
    # tested with issubclass; isinstance here would always answer False.
    if isinstance(typ, type):
        return issubclass(typ, async_iter_types)

    return isinstance(typ, async_iter_types)
def create_sse_event(data: Any) -> str:
    """Format `data` as a single server-sent-event frame.

    Pydantic models are serialized with their own `.json()`; any other value
    goes through `json.dumps`. The result is the `data: ...` line followed by
    the blank line that terminates the event.
    """
    payload = data.json() if isinstance(data, BaseModel) else json.dumps(data)
    return f"data: {payload}\n\n"
async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    """Forward an incoming request to `downstream_url` and relay the response.

    Args:
        request: the incoming request; its method, headers (minus `host`),
            query parameters and raw body are forwarded.
        downstream_url: full URL of the downstream endpoint.
        downstream_headers: extra headers that override the forwarded ones.

    Returns:
        A `StreamingResponse` carrying the downstream status code, body and
        headers (except the framing headers noted below).
    """
    # httpx hands back an already-decoded body, so the downstream framing
    # headers no longer describe the bytes we relay and must not be forwarded.
    stale_response_headers = {"content-encoding", "content-length", "transfer-encoding"}

    headers = dict(request.headers)
    # let httpx derive the Host header from the downstream URL
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    body = await request.body()

    # `async with` guarantees the client is closed on every path, including
    # failures while sending the request.
    async with httpx.AsyncClient() as client:
        response = await client.request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=body,
            params=request.query_params,
        )

    # The body was fully read by `client.request`, so iterating it after the
    # client is closed is safe.
    relayed_headers = {
        key: value
        for key, value in response.headers.items()
        if key.lower() not in stale_response_headers
    }
    return StreamingResponse(
        response.iter_bytes(),
        status_code=response.status_code,
        headers=relayed_headers,
    )
def handle_sigint(*args, **kwargs):
    """Signal handler: announce shutdown, cancel every task on the event loop, then stop the loop.

    Accepts and ignores whatever the signal machinery passes in, apart from
    echoing the positional arguments in the printed message.
    """
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)

    event_loop = asyncio.get_event_loop()
    pending_tasks = asyncio.all_tasks(event_loop)
    for pending in pending_tasks:
        pending.cancel()

    event_loop.stop()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Async context manager that prints a message on entry and another on normal exit.

    `app` is accepted but not used.
    NOTE(review): `main` in this file builds `FastAPI()` without passing this
    in -- confirm whether it is meant to be wired up as the app's lifespan.
    """
    print("Starting up")
    yield
    # only reached when the managed block exits without raising
    print("Shutting down")
def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    """Build a request handler that proxies every call to `downstream_url`.

    Args:
        downstream_url: URL the returned handler forwards requests to.
        downstream_headers: optional extra headers handed to `passthrough`.

    Returns:
        An async function taking a `Request` and returning whatever
        `passthrough` returns for it.
    """

    # closes over downstream_url / downstream_headers, so each route gets its own target
    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint
def create_dynamic_typed_route(func: Any):
    """Wrap an implementation method in an async handler typed from its annotations.

    The handler's single `request` parameter is annotated with the first type
    hint of `func`, and the `return` hint decides the response style:

      * async iterator / generator return type: the handler streams each
        yielded item as a server-sent event;
      * anything else: the handler returns the result of `func(request)`,
        awaiting it first when `func` is asynchronous.

    Raises:
        KeyError: if `func` has no return annotation.
    """
    # local import: keeps this block self-contained without touching file-level imports
    import inspect

    hints = get_type_hints(func)
    # hints preserve declaration order, so the first entry is the first annotated parameter
    request_model = next(iter(hints.values()))
    response_model = hints["return"]

    is_streaming = is_async_iterator_type(response_model)

    if is_streaming:

        async def endpoint(request: request_model):
            async def event_generator():
                async for item in func(request):
                    yield create_sse_event(item)
                    # brief pause after every event; also hands control back to the loop
                    await asyncio.sleep(0.001)

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    else:

        async def endpoint(request: request_model):
            result = func(request)
            # `async def` implementations return a coroutine here; it must be
            # awaited, otherwise the un-awaited coroutine itself would be
            # returned as the response value.
            if inspect.isawaitable(result):
                result = await result
            return result

    return endpoint
def main(
    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
    """Serve every API surface of a distribution over HTTP.

    Args:
        dist_name: name of the distribution to resolve.
        yaml_config: path to a YAML file whose `adapters` mapping holds one
            entry per API surface, keyed by the surface's value.
        port: TCP port to listen on.
        disable_ipv6: listen on `0.0.0.0` instead of `::`.

    Raises:
        ValueError: if the distribution is unknown, if the config has no
            entry for one of its surfaces, or if an adapter implementation
            lacks a method for one of the surface's endpoints.
    """
    dist = resolve_distribution(dist_name)
    if dist is None:
        raise ValueError(f"Could not find distribution {dist_name}")

    with open(yaml_config, "r") as fp:
        # an empty YAML document parses to None
        config = yaml.safe_load(fp) or {}

    app = FastAPI()

    all_endpoints = api_surface_endpoints()

    # A config without an `adapters` section is reported through the
    # per-surface ValueError below rather than a bare KeyError/TypeError.
    adapter_configs = config.get("adapters") or {}
    for surface, adapter in dist.adapters.items():
        if surface.value not in adapter_configs:
            raise ValueError(
                f"Could not find adapter config for {surface}. Please add it to the config"
            )

        adapter_config = adapter_configs[surface.value]
        endpoints = all_endpoints[surface]
        if isinstance(adapter, PassthroughApiAdapter):
            # proxy each endpoint to the same route on the adapter's base URL
            for endpoint in endpoints:
                url = adapter.base_url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            impl = instantiate_adapter(adapter, adapter_config)
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)
                # response_model=None: the handler's return value is passed
                # through without FastAPI response-model validation
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(impl_method)
                )

    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )

    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)
# Script entry point: `fire` exposes `main`'s parameters as command-line arguments.
if __name__ == "__main__":
    fire.Fire(main)