Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)

Commit 2cf9915806: Distribution server now functioning
Parent: 041cafbee3
21 changed files with 635 additions and 266 deletions
@@ -5,54 +5,10 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Any, Dict, List, Literal, Union
+from typing import Dict, List
 
 from pydantic import BaseModel, Field
 from strong_typing.schema import json_schema_type
-from typing_extensions import Annotated
-
-
-@json_schema_type
-class AdapterType(Enum):
-    passthrough_api = "passthrough_api"
-    python_impl = "python_impl"
-    not_implemented = "not_implemented"
-
-
-@json_schema_type
-class PassthroughApiAdapterConfig(BaseModel):
-    type: Literal[AdapterType.passthrough_api.value] = AdapterType.passthrough_api.value
-    base_url: str = Field(..., description="The base URL for the llama stack provider")
-    headers: Dict[str, str] = Field(
-        default_factory=dict,
-        description="Headers (e.g., authorization) to send with the request",
-    )
-
-
-@json_schema_type
-class PythonImplAdapterConfig(BaseModel):
-    type: Literal[AdapterType.python_impl.value] = AdapterType.python_impl.value
-    adapter_id: str
-    kwargs: Dict[str, Any] = Field(
-        default_factory=dict, description="kwargs to pass to the entrypoint"
-    )
-
-
-@json_schema_type
-class NotImplementedAdapterConfig(BaseModel):
-    type: Literal[AdapterType.not_implemented.value] = AdapterType.not_implemented.value
-
-
-# should we define very granular typed classes for each of the PythonImplAdapters we will have?
-# e.g., OllamaInference / vLLMInference / etc. might need very specific parameters
-AdapterConfig = Annotated[
-    Union[
-        PassthroughApiAdapterConfig,
-        NotImplementedAdapterConfig,
-        PythonImplAdapterConfig,
-    ],
-    Field(discriminator="type"),
-]
 
 
 @json_schema_type
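Note (illustration, not part of this commit): the AdapterConfig union removed above is a pydantic discriminated union, so the `type` literal in a raw config dict selects which adapter config class gets constructed. A minimal sketch, assuming pydantic v1-style parse_obj_as and a hypothetical base_url, with the classes above in scope:

# Illustration only: resolving an adapter config via the `type` discriminator.
from pydantic import parse_obj_as

raw = {
    "type": "passthrough_api",
    "base_url": "http://localhost:8000",  # hypothetical downstream URL
    "headers": {},
}
cfg = parse_obj_as(AdapterConfig, raw)
assert isinstance(cfg, PassthroughApiAdapterConfig)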
@@ -61,6 +17,13 @@ class ApiSurface(Enum):
     safety = "safety"
 
 
+@json_schema_type
+class ApiSurfaceEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+
+
 @json_schema_type
 class Adapter(BaseModel):
     api_surface: ApiSurface
@@ -108,13 +71,3 @@ class Distribution(BaseModel):
         default_factory=list,
         description="Additional pip packages beyond those required by the adapters",
     )
-
-
-def distribution_dependencies(distribution: Distribution) -> List[str]:
-    # only consider SourceAdapters when calculating dependencies
-    return [
-        dep
-        for adapter in distribution.adapters.values()
-        if isinstance(adapter, SourceAdapter)
-        for dep in adapter.pip_packages
-    ] + distribution.additional_pip_packages
llama_toolchain/distribution/distribution.py (new file, 51 additions)
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import inspect
from typing import Dict, List

from llama_toolchain.inference.api.endpoints import Inference
from llama_toolchain.safety.api.endpoints import Safety

from .datatypes import ApiSurface, ApiSurfaceEndpoint, Distribution, SourceAdapter


def distribution_dependencies(distribution: Distribution) -> List[str]:
    # only consider SourceAdapters when calculating dependencies
    return [
        dep
        for adapter in distribution.adapters.values()
        if isinstance(adapter, SourceAdapter)
        for dep in adapter.pip_packages
    ] + distribution.additional_pip_packages


def api_surface_endpoints() -> Dict[ApiSurface, List[ApiSurfaceEndpoint]]:
    surfaces = {}

    protocols = {
        ApiSurface.inference: Inference,
        ApiSurface.safety: Safety,
    }

    for surface, protocol in protocols.items():
        endpoints = []
        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)

        for name, method in protocol_methods:
            if not hasattr(method, "__webmethod__"):
                continue

            webmethod = method.__webmethod__
            route = webmethod.route

            # use `post` for all methods right now until we fix up the `webmethod` openapi
            # annotation and write our own openapi generator
            endpoints.append(ApiSurfaceEndpoint(route=route, method="post", name=name))

        surfaces[surface] = endpoints

    return surfaces
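Note (illustration, not part of this commit): api_surface_endpoints() relies on protocol methods carrying a `__webmethod__` attribute with a `route`. A standalone sketch of that discovery mechanism, using a hypothetical tag_route decorator as a stand-in for the repo's own webmethod annotation:

# Illustration only: how inspect.getmembers() plus __webmethod__ yields routes.
import inspect
from dataclasses import dataclass


@dataclass
class _WebMethod:  # stand-in for the real webmethod metadata
    route: str


def tag_route(route: str):  # hypothetical decorator, not the repo's API
    def wrap(fn):
        fn.__webmethod__ = _WebMethod(route=route)
        return fn

    return wrap


class ToyProtocol:
    @tag_route("/toy/run")
    def run(self, request) -> str: ...


for name, method in inspect.getmembers(ToyProtocol, predicate=inspect.isfunction):
    if hasattr(method, "__webmethod__"):
        print(name, method.__webmethod__.route)  # -> run /toy/run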
llama_toolchain/distribution/dynamic.py (new file, 26 additions)
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import importlib
from typing import Any, Dict

from .datatypes import SourceAdapter


def instantiate_class_type(fully_qualified_name):
    module_name, class_name = fully_qualified_name.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# returns a class implementing the protocol corresponding to the ApiSurface
def instantiate_adapter(adapter: SourceAdapter, adapter_config: Dict[str, Any]):
    module = importlib.import_module(adapter.module)

    config_type = instantiate_class_type(adapter.config_class)
    config = config_type(**adapter_config)
    return asyncio.run(module.get_adapter_impl(config))
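Note (illustration, not part of this commit): instantiate_class_type() is just "import the module, fetch the attribute", and instantiate_adapter() applies it to adapter.config_class before awaiting the module's get_adapter_impl(). A small sketch using a standard-library name purely for illustration:

# Illustration only: dynamic lookup by fully qualified name.
cls = instantiate_class_type("collections.OrderedDict")
print(cls is __import__("collections").OrderedDict)  # True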
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List
+from typing import List, Optional
 
 from llama_toolchain.inference.adapters import available_inference_adapters
 
@@ -63,3 +63,10 @@ def available_distributions() -> List[Distribution]:
             },
         ),
     ]
+
+
+def resolve_distribution(name: str) -> Optional[Distribution]:
+    for dist in available_distributions():
+        if dist.name == name:
+            return dist
+    return None
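Note (illustration, not part of this commit): the server resolves a distribution by name before wiring up routes. A minimal sketch, with a hypothetical distribution name:

# Illustration only: "example-dist" is a placeholder; use a name from available_distributions().
dist = resolve_distribution("example-dist")
if dist is None:
    raise ValueError("Unknown distribution")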
llama_toolchain/distribution/server.py (new file, 202 additions)
@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import signal
from collections.abc import (
    AsyncGenerator as AsyncGeneratorABC,
    AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional

import fire
import httpx
import yaml
from dotenv import load_dotenv

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute

from pydantic import BaseModel
from termcolor import cprint

from .datatypes import PassthroughApiAdapter
from .distribution import api_surface_endpoints
from .dynamic import instantiate_adapter

from .registry import resolve_distribution

load_dotenv()


def is_async_iterator_type(typ):
    if hasattr(typ, "__origin__"):
        origin = typ.__origin__
        if isinstance(origin, type):
            return issubclass(
                origin,
                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
            )
        return False
    return isinstance(
        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
    )


def create_sse_event(data: Any) -> str:
    if isinstance(data, BaseModel):
        data = data.json()
    else:
        data = json.dumps(data)

    return f"data: {data}\n\n"
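Note (illustration, not part of this commit): create_sse_event() frames each streamed item as a server-sent event, i.e. a "data: ..." line followed by a blank line. For example:

# Illustration only: SSE framing of a plain dict.
print(create_sse_event({"delta": "hello"}), end="")
# prints: data: {"delta": "hello"}   (followed by a blank line)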
async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    client = httpx.AsyncClient()

    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    body = await request.body()

    try:
        response = await client.request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=body,
            params=request.query_params,
        )
        return StreamingResponse(
            response.iter_bytes(),
            status_code=response.status_code,
            headers=dict(response.headers),
        )
    finally:
        await client.aclose()


def handle_sigint(*args, **kwargs):
    print("SIGINT or CTRL-C detected. Exiting gracefully", args)
    loop = asyncio.get_event_loop()
    for task in asyncio.all_tasks(loop):
        task.cancel()
    loop.stop()


@asynccontextmanager
async def lifespan(app: FastAPI):
    print("Starting up")
    yield
    print("Shutting down")


def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    async def endpoint(request: Request):
        return await passthrough(request, downstream_url, downstream_headers)

    return endpoint


def create_dynamic_typed_route(func: Any):
    hints = get_type_hints(func)
    request_model = next(iter(hints.values()))
    response_model = hints["return"]

    is_streaming = is_async_iterator_type(response_model)

    if is_streaming:

        async def endpoint(request: request_model):
            async def event_generator():
                async for item in func(request):
                    yield create_sse_event(item)
                    await asyncio.sleep(0.001)

            return StreamingResponse(event_generator(), media_type="text/event-stream")

    else:

        async def endpoint(request: request_model):
            return func(request)

    return endpoint
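Note (illustration, not part of this commit): create_dynamic_typed_route() keys off the endpoint's type hints, taking the first parameter as the request model and using the return annotation to choose between a unary JSON response and an SSE stream. A small sketch of what it inspects, using a toy signature with is_async_iterator_type from above in scope:

# Illustration only: inspecting a toy streaming endpoint.
from typing import AsyncIterator, get_type_hints


async def toy_chat(request: dict) -> AsyncIterator[str]:  # hypothetical signature
    yield "hello"


hints = get_type_hints(toy_chat)
print(next(iter(hints.values())))  # <class 'dict'> -> used as the request model
print(is_async_iterator_type(hints["return"]))  # True -> served as an SSE stream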
def main(
    dist_name: str, yaml_config: str, port: int = 5000, disable_ipv6: bool = False
):
    dist = resolve_distribution(dist_name)
    if dist is None:
        raise ValueError(f"Could not find distribution {dist_name}")

    with open(yaml_config, "r") as fp:
        config = yaml.safe_load(fp)

    app = FastAPI()

    all_endpoints = api_surface_endpoints()

    adapter_configs = config["adapters"]
    for surface, adapter in dist.adapters.items():
        if surface.value not in adapter_configs:
            raise ValueError(
                f"Could not find adapter config for {surface}. Please add it to the config"
            )

        adapter_config = adapter_configs[surface.value]
        endpoints = all_endpoints[surface]
        if isinstance(adapter, PassthroughApiAdapter):
            for endpoint in endpoints:
                url = adapter.base_url.rstrip("/") + endpoint.route
                getattr(app, endpoint.method)(endpoint.route)(
                    create_dynamic_passthrough(url)
                )
        else:
            impl = instantiate_adapter(adapter, adapter_config)
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)
                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
                    create_dynamic_typed_route(impl_method)
                )

    for route in app.routes:
        if isinstance(route, APIRoute):
            cprint(
                f"Serving {next(iter(route.methods))} {route.path}",
                "white",
                attrs=["bold"],
            )

    signal.signal(signal.SIGINT, handle_sigint)

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)


if __name__ == "__main__":
    fire.Fire(main)
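Note (illustration, not part of this commit): main() expects the YAML config to carry an "adapters" mapping keyed by API surface name, with each value holding that adapter's config (the kwargs for a source adapter's config class, or base_url/headers for a passthrough). A hedged sketch of the dict shape yaml.safe_load() would need to produce, with made-up values, plus a possible launch command under this commit's module layout:

# Illustration only: keys and values here are hypothetical, not taken from a shipped distribution.
example_config = {
    "adapters": {
        "inference": {"base_url": "http://localhost:11434", "headers": {}},
        "safety": {},
    }
}

# Launching (placeholders for the distribution name and config path):
#   python -m llama_toolchain.distribution.server <dist_name> <path/to/config.yaml> --port 5000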