mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Add a special header per-client call to parser provider data
This commit is contained in:
parent
a6be32bc3d
commit
32beecb20d
11 changed files with 955 additions and 104 deletions
|
@ -49,6 +49,7 @@ from typing_extensions import Annotated
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
|
||||
|
||||
|
@ -177,9 +178,9 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
cprint(f"> create_dynamic_typed_route func={func}", "red")
|
||||
cprint(f"> create_dynamic_typed_route method={method}", "red")
|
||||
def create_dynamic_typed_route(
|
||||
func: Any, method: str, provider_data_validator: Optional[str]
|
||||
):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
|
@ -191,9 +192,11 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -220,8 +223,11 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
|
||||
else:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
|
@ -235,20 +241,23 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
)
|
||||
]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
new_params = [new_params[0]] + [
|
||||
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
||||
endpoint.__signature__ = sig.replace(parameters=new_params)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
@ -420,7 +429,11 @@ def run_main_DEPRECATED(
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
provider_spec.provider_data_validator,
|
||||
)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue