Add a special header per-client call to parser provider data

This commit is contained in:
Ashwin Bharambe 2024-09-18 09:17:59 -07:00
parent 942cb87a3c
commit 90a59fd89b
11 changed files with 955 additions and 102 deletions

View file

@ -48,6 +48,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
@ -176,7 +177,9 @@ def create_dynamic_passthrough(
return endpoint
def create_dynamic_typed_route(func: Any, method: str):
def create_dynamic_typed_route(
func: Any, method: str, provider_data_validator: Optional[str]
):
hints = get_type_hints(func)
response_model = hints.get("return")
@ -188,9 +191,11 @@ def create_dynamic_typed_route(func: Any, method: str):
if is_streaming:
async def endpoint(**kwargs):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -217,8 +222,11 @@ def create_dynamic_typed_route(func: Any, method: str):
else:
async def endpoint(**kwargs):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
try:
return (
await func(**kwargs)
@ -232,20 +240,23 @@ def create_dynamic_typed_route(func: Any, method: str):
await end_trace()
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params.extend(sig.parameters.values())
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
new_params = [new_params[0]] + [
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
for param in new_params[1:]
]
endpoint.__signature__ = sig.replace(parameters=new_params)
return endpoint
@ -365,7 +376,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
create_dynamic_typed_route(
impl_method,
endpoint.method,
provider_spec.provider_data_validator,
)
)
for route in app.routes: