forked from phoenix-oss/llama-stack-mirror
This PR makes several core changes to the developer experience surrounding Llama Stack. Background: PR #92 introduced the notion of "routing" to the Llama Stack. It introduced three object types: (1) models, (2) shields and (3) memory banks. Each of these objects can be associated with a distinct provider, so you can, for example, have model A served locally while models B and C are served remotely. However, this had a few drawbacks: (a) You could not address the provider instances -- i.e., if you configured "meta-reference" with a given model, you could not assign an identifier to this instance which you could re-use later. (b) The above meant that you could not register a "routing_key" (e.g. a model) dynamically and say "please use this existing provider I have already configured" for a new model. (c) The terms "routing_table" and "routing_key" were exposed directly to the user. In my view, this is way too much overhead for a new user (which almost everyone is): people come to the stack wanting to do ML and encounter a completely unexpected term. What this PR does: First, this PR structures the run config with only a single prominent key: - providers. Providers are instances of configured provider types. Here's an example which shows two instances of the remote::tgi provider that are serving two different models. providers: inference: - provider_id: foo provider_type: remote::tgi config: { ... } - provider_id: bar provider_type: remote::tgi config: { ... } Secondly, the PR adds dynamic registration of { models | shields | memory_banks } to the API surface. The distribution still acts like a "routing table" (as previously) except that it asks the backing providers for a listing of these objects. For example, it asks a TGI or Ollama inference adapter what models it is serving. Only the models that are actually being served can be requested by the user for inference. Otherwise, the Stack server will throw an error.
When dynamically registering these objects, you can use the provider IDs shown above. Info about providers can be obtained using the Api.inspect set of endpoints (/providers, /routes, etc.) The above example shows the correspondence between inference providers and models registry items. Things work similarly for the safety <=> shields and memory <=> memory_banks pairs. Registry: This PR also makes it so that Providers need to implement additional methods for registering and listing objects. For example, each Inference provider is now expected to implement the ModelsProtocolPrivate protocol (naming is not great!), which consists of two methods: register_model and list_models. The goal is to inform the provider that a certain model needs to be supported so the provider can make any relevant backend changes if needed (or throw an error if the model cannot be supported.) There are many other cleanups included, some of which are detailed in a follow-up comment.
343 lines
10 KiB
Python
343 lines
10 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import asyncio
|
|
import functools
|
|
import inspect
|
|
import json
|
|
import signal
|
|
import traceback
|
|
|
|
from contextlib import asynccontextmanager
|
|
from ssl import SSLError
|
|
from typing import Any, Dict, Optional
|
|
|
|
import fire
|
|
import httpx
|
|
import yaml
|
|
|
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from pydantic import BaseModel, ValidationError
|
|
from termcolor import cprint
|
|
from typing_extensions import Annotated
|
|
|
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
|
|
|
from llama_stack.providers.utils.telemetry.tracing import (
|
|
end_trace,
|
|
setup_logger,
|
|
SpanStatus,
|
|
start_trace,
|
|
)
|
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
|
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
|
from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|
|
|
from .endpoints import get_all_api_endpoints
|
|
|
|
|
|
def create_sse_event(data: Any) -> str:
    """Render ``data`` as a single Server-Sent Events ``data:`` frame.

    Pydantic models are serialized with their own ``.json()`` method; any
    other value goes through ``json.dumps``. The returned string ends with
    the blank line that terminates an SSE frame.
    """
    payload = data.json() if isinstance(data, BaseModel) else json.dumps(data)
    return f"data: {payload}\n\n"
|
|
|
|
|
|
async def global_exception_handler(request: Request, exc: Exception):
    """Turn an unhandled exception into a JSON error response.

    The traceback is printed, the exception is mapped to an HTTP error by
    ``translate_exception``, and the result is returned with the body shape
    ``{"error": {"detail": ...}}``. ``request`` is accepted but not used.
    """
    traceback.print_exception(exc)

    translated = translate_exception(exc)
    body = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=translated.status_code, content=body)
|
|
|
|
|
|
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
    """Map an arbitrary exception to the ``HTTPException`` reported to clients.

    Mapping, checked in this order:

    * pydantic ``ValidationError`` / FastAPI ``RequestValidationError`` -> 400,
      with ``{"errors": [{"loc": ..., "msg": ..., "type": ...}, ...]}`` as detail
    * ``ValueError`` -> 400
    * ``PermissionError`` -> 403
    * ``TimeoutError`` -> 504
    * ``NotImplementedError`` -> 501
    * anything else -> 500 with a generic message (the original error text is
      deliberately not included)

    Every branch *returns* an ``HTTPException``; nothing is raised here.
    """
    # Both validation error types expose ``errors()``. Reading it directly
    # avoids ``ValidationError.raw_errors``, which pydantic v2 does not provide,
    # so a validation failure no longer blows up inside the error handler.
    if isinstance(exc, (ValidationError, RequestValidationError)):
        return HTTPException(
            status_code=400,
            detail={
                "errors": [
                    {
                        "loc": list(error["loc"]),
                        "msg": error["msg"],
                        "type": error["type"],
                    }
                    for error in exc.errors()
                ]
            },
        )

    if isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")

    if isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")

    if isinstance(exc, TimeoutError):
        return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")

    if isinstance(exc, NotImplementedError):
        return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")

    return HTTPException(
        status_code=500,
        detail="Internal server error: An unexpected error occurred.",
    )
|
|
|
|
|
|
async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    """Forward ``request`` to ``downstream_url`` and stream the reply back.

    The incoming headers (minus ``host``) are forwarded, overlaid with
    ``downstream_headers``; method, body and query parameters are passed
    through unchanged. The downstream body is relayed as a
    ``StreamingResponse`` carrying the downstream status code and headers.

    Failures while contacting the downstream server are converted into plain
    responses instead of being raised: read timeout -> 504; network error,
    too many redirects and SSL error -> 502; ``httpx.HTTPStatusError`` -> the
    downstream status; anything else -> 500.

    The call is wrapped in a trace that is closed with ``SpanStatus.ERROR``
    when one of those failures occurred and ``SpanStatus.OK`` otherwise.
    """
    # Starlette's Request exposes the path on ``url``; it has no ``path`` attribute.
    await start_trace(request.url.path, {"downstream_url": downstream_url})

    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    content = await request.body()

    client = httpx.AsyncClient()
    erred = False
    try:
        req = client.build_request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=content,
            params=request.query_params,
        )
        response = await client.send(req, stream=True)

        async def stream_response():
            # Release the response and the client even when the consumer stops
            # iterating early or the downstream read fails mid-stream.
            try:
                async for chunk in response.aiter_raw(chunk_size=64):
                    yield chunk
            finally:
                try:
                    await response.aclose()
                finally:
                    await client.aclose()

        return StreamingResponse(
            stream_response(),
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.headers.get("content-type"),
        )

    except httpx.ReadTimeout:
        erred = True
        return Response(content="Downstream server timed out", status_code=504)
    except httpx.NetworkError as e:
        erred = True
        return Response(content=f"Network error: {str(e)}", status_code=502)
    except httpx.TooManyRedirects:
        erred = True
        return Response(content="Too many redirects", status_code=502)
    except SSLError as e:
        erred = True
        return Response(content=f"SSL error: {str(e)}", status_code=502)
    except httpx.HTTPStatusError as e:
        erred = True
        return Response(content=str(e), status_code=e.response.status_code)
    except Exception as e:
        erred = True
        return Response(content=f"Unexpected error: {str(e)}", status_code=500)
    finally:
        try:
            if erred:
                # On the failure paths no stream will ever be consumed, so the
                # client has to be released here instead of in stream_response().
                await client.aclose()
        finally:
            await end_trace(SpanStatus.ERROR if erred else SpanStatus.OK)
|
|
|
|
|
|
# Strong references to in-flight shutdown tasks: the event loop itself only
# keeps weak references to the tasks it runs.
_SHUTDOWN_TASKS = set()


def handle_sigint(app, *args, **kwargs):
    """SIGINT handler: shut down every implementation, then stop the loop.

    ``app`` must carry the implementations in ``app.__llama_stack_impls__``;
    each one's ``shutdown()`` coroutine is awaited in turn. The remaining
    positional/keyword arguments (signal number and frame when invoked by the
    ``signal`` module) are ignored.

    Two situations are handled:

    * An event loop is running in this thread. ``asyncio.run()`` would raise
      there, so the shutdown is scheduled on the running loop; once it has
      finished, all other tasks are cancelled and the loop is stopped.
    * No event loop is running. The shutdown coroutine is driven to completion
      with ``asyncio.run()``; there is no loop to stop or tasks to cancel.
    """
    print("SIGINT or CTRL-C detected. Exiting gracefully...")

    async def run_shutdown():
        for impl in app.__llama_stack_impls__.values():
            print(f"Shutting down {impl}")
            await impl.shutdown()

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so a private one can drive shutdown.
        asyncio.run(run_shutdown())
        return

    async def shutdown_then_stop():
        this_task = asyncio.current_task()
        try:
            await run_shutdown()
        finally:
            # Cancel everything except the task doing the cancelling.
            for task in asyncio.all_tasks(loop):
                if task is not this_task:
                    task.cancel()
            loop.stop()

    def schedule():
        task = loop.create_task(shutdown_then_stop())
        _SHUTDOWN_TASKS.add(task)
        task.add_done_callback(_SHUTDOWN_TASKS.discard)

    # call_soon_threadsafe also wakes the loop if it is blocked waiting for I/O,
    # which matters because a signal handler can interrupt it at any point.
    loop.call_soon_threadsafe(schedule)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Async context manager bracketing the application's lifetime.

    Entering only prints a startup message. On exit a shutdown message is
    printed and ``shutdown()`` is awaited, one at a time, on every
    implementation stored in ``app.__llama_stack_impls__``.
    """
    print("Starting up")
    yield

    print("Shutting down")
    impls = app.__llama_stack_impls__
    for impl in impls.values():
        await impl.shutdown()
|
|
|
|
|
|
def create_dynamic_passthrough(
    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
    """Build a route handler that proxies each request to ``downstream_url``.

    The returned coroutine function takes the incoming ``Request`` and hands
    it to ``passthrough`` together with the URL and optional extra headers
    captured here.
    """

    async def endpoint(request: Request):
        proxied = await passthrough(request, downstream_url, downstream_headers)
        return proxied

    return endpoint
|
|
|
|
|
|
def is_streaming_request(func_name: str, request: Request, **kwargs):
    """Report whether the caller asked for a streamed response.

    Only the ``stream`` keyword argument is consulted (defaulting to
    ``False``); ``func_name`` and ``request`` are currently unused.
    """
    # TODO: pass the api method and punt it to the Protocol definition directly
    stream_flag = kwargs.get("stream", False)
    return stream_flag
|
|
|
|
|
|
async def maybe_await(value):
    """Return the result of ``value`` if it is awaitable, else ``value`` itself.

    Any awaitable is resolved — coroutine objects as well as futures, tasks
    and objects implementing ``__await__`` — so callers can treat sync and
    async implementations uniformly.
    """
    if inspect.isawaitable(value):
        return await value
    return value
|
|
|
|
|
|
async def sse_generator(event_gen):
    """Adapt an async iterator of events into a stream of SSE frames.

    Each item from ``event_gen`` is emitted through ``create_sse_event``. If
    the stream is cancelled, ``event_gen`` is closed and the stream ends
    quietly. Any other exception is printed and reported to the client as a
    final ``{"error": {"message": ...}}`` frame. The trace is ended in every
    case.
    """
    try:
        async for event in event_gen:
            yield create_sse_event(event)
            # Brief pause after every frame (presumably to hand control back
            # to the event loop between events).
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        print("Generator cancelled")
        await event_gen.aclose()
    except Exception as err:
        traceback.print_exception(err)
        error_message = str(translate_exception(err))
        yield create_sse_event({"error": {"message": error_message}})
    finally:
        await end_trace()
|
|
|
|
|
|
def create_dynamic_typed_route(func: Any, method: str):
    """Wrap ``func`` in a coroutine suitable for registration as a FastAPI route.

    The wrapper starts a trace named after ``func``, records provider data
    from the request headers, and then either streams ``func``'s output as
    server-sent events (when the request asks for streaming) or returns its
    possibly-awaited result. Exceptions are printed and re-raised in the form
    produced by ``translate_exception``.

    The wrapper advertises ``func``'s own parameters, preceded by a ``request``
    parameter, through ``__signature__``. For ``method == "post"`` every
    parameter taken from ``func`` is additionally annotated with
    ``Body(..., embed=True)``.
    """

    async def endpoint(request: Request, **kwargs):
        await start_trace(func.__name__)

        set_request_provider_data(request.headers)

        wants_stream = is_streaming_request(func.__name__, request, **kwargs)
        try:
            if not wants_stream:
                return await maybe_await(func(**kwargs))

            return StreamingResponse(
                sse_generator(func(**kwargs)), media_type="text/event-stream"
            )
        except Exception as err:
            traceback.print_exception(err)
            raise translate_exception(err) from err
        finally:
            # NOTE(review): for streaming responses this runs before the body is
            # produced, and sse_generator() calls end_trace() again later —
            # confirm that ending the trace twice/early is intended.
            await end_trace()

    sig = inspect.signature(func)
    request_param = inspect.Parameter(
        "request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
    )
    func_params = list(sig.parameters.values())

    if method == "post":
        # make sure every parameter is annotated with Body() so FASTAPI doesn't
        # do anything too intelligent and ask for some parameters in the query
        # and some in the body
        func_params = [
            p.replace(annotation=Annotated[p.annotation, Body(..., embed=True)])
            for p in func_params
        ]

    endpoint.__signature__ = sig.replace(parameters=[request_param, *func_params])

    return endpoint
|
|
|
|
|
|
def main(
    yaml_config: str = "llamastack-run.yaml",
    port: int = 5000,
    disable_ipv6: bool = False,
):
    """Load a run config, mount the served APIs on a FastAPI app and run it.

    Steps: parse ``yaml_config`` into a ``StackRunConfig``, resolve the
    implementations, register one route per endpoint of every API being
    served (a passthrough proxy or a typed wrapper around the implementation
    method), install the exception and SIGINT handlers, and start uvicorn.

    Args:
        yaml_config: path of the YAML run configuration.
        port: TCP port passed to uvicorn.
        disable_ipv6: when true listen on ``0.0.0.0``; otherwise on ``::``.

    Raises:
        ValueError: if an implementation lacks a method named after one of
            its API's endpoints.
    """
    with open(yaml_config, "r") as fp:
        config = StackRunConfig(**yaml.safe_load(fp))

    # NOTE(review): the ``lifespan`` context manager defined in this module is
    # not passed to FastAPI() here — confirm whether that is intentional.
    app = FastAPI()

    impls = asyncio.run(resolve_impls_with_routing(config))
    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])

    all_endpoints = get_all_api_endpoints()

    apis_to_serve = set(config.apis) if config.apis else set(impls.keys())
    apis_to_serve.update(
        inf.routing_table_api.value for inf in builtin_automatically_routed_apis()
    )
    apis_to_serve.add("inspect")

    for api_str in apis_to_serve:
        api = Api(api_str)

        endpoints = all_endpoints[api]
        impl = impls[api]

        if is_passthrough(impl.__provider_spec__):
            for endpoint in endpoints:
                url = impl.__provider_config__.url.rstrip("/") + endpoint.route
                register = getattr(app, endpoint.method)(endpoint.route)
                register(create_dynamic_passthrough(url))
        else:
            for endpoint in endpoints:
                if not hasattr(impl, endpoint.name):
                    # ideally this should be a typing violation already
                    raise ValueError(
                        f"Could not find method {endpoint.name} on {impl}!!"
                    )

                impl_method = getattr(impl, endpoint.name)

                register = getattr(app, endpoint.method)(
                    endpoint.route, response_model=None
                )
                register(create_dynamic_typed_route(impl_method, endpoint.method))

        cprint(f"Serving API {api_str}", "white", attrs=["bold"])
        for endpoint in endpoints:
            cprint(f" {endpoint.method.upper()} {endpoint.route}", "white")

    print("")
    for handled_type in (RequestValidationError, Exception):
        app.exception_handler(handled_type)(global_exception_handler)
    signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))

    app.__llama_stack_impls__ = impls

    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "0.0.0.0" if disable_ipv6 else "::"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)
|
|
|
|
|
|
# Script entry point: when this file is executed directly, hand ``main`` to
# python-fire (``fire.Fire``).
if __name__ == "__main__":
    fire.Fire(main)
|