Add a special header per-client call to parser provider data

This commit is contained in:
Ashwin Bharambe 2024-09-18 09:17:59 -07:00 committed by Xi Yan
parent a6be32bc3d
commit 32beecb20d
11 changed files with 955 additions and 104 deletions

View file

@ -49,6 +49,7 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
@ -177,9 +178,9 @@ def create_dynamic_passthrough(
return endpoint
def create_dynamic_typed_route(func: Any, method: str):
cprint(f"> create_dynamic_typed_route func={func}", "red")
cprint(f"> create_dynamic_typed_route method={method}", "red")
def create_dynamic_typed_route(
func: Any, method: str, provider_data_validator: Optional[str]
):
hints = get_type_hints(func)
response_model = hints.get("return")
@ -191,9 +192,11 @@ def create_dynamic_typed_route(func: Any, method: str):
if is_streaming:
async def endpoint(**kwargs):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -220,8 +223,11 @@ def create_dynamic_typed_route(func: Any, method: str):
else:
async def endpoint(**kwargs):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
try:
return (
await func(**kwargs)
@ -235,20 +241,23 @@ def create_dynamic_typed_route(func: Any, method: str):
await end_trace()
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params.extend(sig.parameters.values())
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
new_params = [new_params[0]] + [
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
for param in new_params[1:]
]
endpoint.__signature__ = sig.replace(parameters=new_params)
return endpoint
@ -420,7 +429,11 @@ def run_main_DEPRECATED(
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
create_dynamic_typed_route(
impl_method,
endpoint.method,
provider_spec.provider_data_validator,
)
)
for route in app.routes: