mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
Add a special header per-client call to parser provider data
This commit is contained in:
parent
a6be32bc3d
commit
32beecb20d
11 changed files with 955 additions and 104 deletions
|
@ -92,6 +92,9 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
default=None,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -115,6 +118,9 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
|
@ -159,6 +165,12 @@ as being "Llama Stack compatible"
|
|||
return self.adapter.pip_packages
|
||||
return []
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> Optional[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.provider_data_validator
|
||||
return None
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
|
|
49
llama_stack/distribution/request_headers.py
Normal file
49
llama_stack/distribution/request_headers.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
_THREAD_LOCAL = threading.local()
|
||||
|
||||
|
||||
def get_request_provider_data() -> Any:
|
||||
return getattr(_THREAD_LOCAL, "provider_data", None)
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
|
||||
if not validator_class:
|
||||
return
|
||||
|
||||
keys = [
|
||||
"X-LlamaStack-ProviderData",
|
||||
"x-llamastack-providerdata",
|
||||
]
|
||||
for key in keys:
|
||||
val = headers.get(key, None)
|
||||
if val:
|
||||
break
|
||||
|
||||
if not val:
|
||||
return
|
||||
|
||||
try:
|
||||
val = json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
print("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
return
|
||||
|
||||
_THREAD_LOCAL.provider_data = provider_data
|
|
@ -49,6 +49,7 @@ from typing_extensions import Annotated
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
|
||||
|
||||
|
@ -177,9 +178,9 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
cprint(f"> create_dynamic_typed_route func={func}", "red")
|
||||
cprint(f"> create_dynamic_typed_route method={method}", "red")
|
||||
def create_dynamic_typed_route(
|
||||
func: Any, method: str, provider_data_validator: Optional[str]
|
||||
):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
|
@ -191,9 +192,11 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -220,8 +223,11 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
|
||||
else:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
|
@ -235,20 +241,23 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
)
|
||||
]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
new_params = [new_params[0]] + [
|
||||
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
||||
endpoint.__signature__ = sig.replace(parameters=new_params)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
@ -420,7 +429,11 @@ def run_main_DEPRECATED(
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
provider_spec.provider_data_validator,
|
||||
)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue