Add a special header per-client call to pass provider data

Ashwin Bharambe 2024-09-18 09:17:59 -07:00 committed by Xi Yan
parent a6be32bc3d
commit 32beecb20d
11 changed files with 955 additions and 104 deletions

View file

@@ -92,6 +92,9 @@ Fully-qualified name of the module to import. The module is expected to have:
default=None,
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
@json_schema_type
@@ -115,6 +118,9 @@ Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
provider_data_validator: Optional[str] = Field(
default=None,
)
class RemoteProviderConfig(BaseModel):
@@ -159,6 +165,12 @@ as being "Llama Stack compatible"
return self.adapter.pip_packages
return []
@property
def provider_data_validator(self) -> Optional[str]:
if self.adapter:
return self.adapter.provider_data_validator
return None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
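The `provider_data_validator` is a fully-qualified class name that the server instantiates per request and feeds the JSON decoded from the provider-data header (see `request_headers.py` below). A minimal sketch of such a validator; the class and field names are illustrative, not part of this commit:

```python
# Hypothetical validator referenced via provider_data_validator; the class and
# field names are illustrative. Any class that accepts the decoded header JSON
# as keyword arguments works, e.g. a Pydantic model.
from pydantic import BaseModel


class ExampleProviderData(BaseModel):
    api_key: str


# Referenced from the spec by its fully-qualified name, e.g.:
#   provider_data_validator="my_package.my_provider.ExampleProviderData"
```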

View file

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import threading
from typing import Any, Dict, Optional
from .utils.dynamic import instantiate_class_type
_THREAD_LOCAL = threading.local()
def get_request_provider_data() -> Any:
return getattr(_THREAD_LOCAL, "provider_data", None)
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
if not validator_class:
return
keys = [
"X-LlamaStack-ProviderData",
"x-llamastack-providerdata",
]
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return
try:
val = json.loads(val)
except json.JSONDecodeError:
print("Provider data not encoded as a JSON object!", val)
return
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
except Exception as e:
print("Error parsing provider data", e)
return
_THREAD_LOCAL.provider_data = provider_data
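On the provider side, the validated object is read back from the thread-local via `get_request_provider_data()`. A hedged sketch of a consumer; the helper and its fallback logic are assumptions, not code from this commit:

```python
# Illustrative consumer of per-request provider data; only
# get_request_provider_data() comes from this commit.
from typing import Optional

from llama_stack.distribution.request_headers import get_request_provider_data


def resolve_api_key(config_api_key: Optional[str] = None) -> str:
    provider_data = get_request_provider_data()  # validator instance or None
    if provider_data is not None and getattr(provider_data, "api_key", None):
        return provider_data.api_key
    if config_api_key:
        return config_api_key
    raise ValueError("No API key in provider data header or provider config")
```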

View file

@@ -49,6 +49,7 @@ from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
@@ -177,9 +178,9 @@ def create_dynamic_passthrough(
return endpoint
def create_dynamic_typed_route(func: Any, method: str):
cprint(f"> create_dynamic_typed_route func={func}", "red")
cprint(f"> create_dynamic_typed_route method={method}", "red")
def create_dynamic_typed_route(
func: Any, method: str, provider_data_validator: Optional[str]
):
hints = get_type_hints(func)
response_model = hints.get("return")
@@ -191,9 +192,11 @@ def create_dynamic_typed_route(func: Any, method: str):
if is_streaming:
async def endpoint(**kwargs):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
async def sse_generator(event_gen):
try:
async for item in event_gen:
@@ -220,8 +223,11 @@ def create_dynamic_typed_route(func: Any, method: str):
else:
async def endpoint(**kwargs):
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validator)
try:
return (
await func(**kwargs)
@@ -235,20 +241,23 @@ def create_dynamic_typed_route(func: Any, method: str):
await end_trace()
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params.extend(sig.parameters.values())
if method == "post":
# make sure every parameter is annotated with Body() so FASTAPI doesn't
# do anything too intelligent and ask for some parameters in the query
# and some in the body
endpoint.__signature__ = sig.replace(
parameters=[
param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
for param in sig.parameters.values()
]
)
else:
endpoint.__signature__ = sig
new_params = [new_params[0]] + [
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
for param in new_params[1:]
]
endpoint.__signature__ = sig.replace(parameters=new_params)
return endpoint
@@ -420,7 +429,11 @@ def run_main_DEPRECATED(
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
create_dynamic_typed_route(
impl_method,
endpoint.method,
provider_spec.provider_data_validator,
)
)
for route in app.routes:
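With the routes wired this way, every request can carry provider data as a JSON-encoded header that the dispatcher forwards to `set_request_provider_data()`. A client-side sketch; the server address, route, and payload are assumptions, not part of this diff:

```python
# Illustrative client call; URL, route, and body are assumptions.
import json

import requests

response = requests.post(
    "http://localhost:5000/inference/chat_completion",
    headers={
        # Parsed by set_request_provider_data() and validated against the
        # provider's provider_data_validator class.
        "X-LlamaStack-ProviderData": json.dumps({"api_key": "sk-example"}),
    },
    json={"model": "Meta-Llama3.1-8B-Instruct", "messages": []},
)
print(response.status_code)
```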

View file

@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import TogetherImplConfig
from .config import TogetherImplConfig, TogetherHeaderExtractor
async def get_adapter_impl(config: TogetherImplConfig, _deps):

View file

@@ -4,9 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_models.schema_utils import json_schema_type
from llama_stack.distribution.request_headers import annotate_header
class TogetherHeaderExtractor(BaseModel):
api_key: annotate_header(
"X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
)
@json_schema_type
class TogetherImplConfig(BaseModel):

View file

@@ -63,6 +63,7 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
),
),
]
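`TogetherHeaderExtractor` declares a dedicated header for the Together API key. A hedged sketch of a client supplying it, assuming the server extracts headers declared via `annotate_header`; the route and payload are illustrative:

```python
# Illustrative only: passing the Together API key via the dedicated header
# declared by TogetherHeaderExtractor. Route and payload are assumptions.
import requests

requests.post(
    "http://localhost:5000/inference/chat_completion",
    headers={"X-LlamaStack-Together-ApiKey": "tok-example"},
    json={"model": "Meta-Llama3.1-8B-Instruct", "messages": []},
)
```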

View file

@@ -16,6 +16,7 @@ import httpx
import numpy as np
from numpy.typing import NDArray
from pypdf import PdfReader
from termcolor import cprint
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
@@ -160,6 +161,8 @@ class BankWithIndex:
self.bank.config.overlap_size_in_tokens
or (self.bank.config.chunk_size_in_tokens // 4),
)
if not chunks:
continue
embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
await self.index.add_chunks(chunks, embeddings)