From 5bf679cab639b386c93d15ee15b0328d82d802a7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 29 Sep 2024 20:00:51 -0700 Subject: [PATCH] Pull (extract) provider data from the provider instead of pushing from the top (#148) --- llama_stack/distribution/request_headers.py | 39 +++++++++++-------- llama_stack/distribution/server/server.py | 19 ++------- .../adapters/inference/together/together.py | 6 +-- .../adapters/safety/together/together.py | 6 +-- 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 27b8b531f..5ed04a13a 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -6,21 +6,36 @@ import json import threading -from typing import Any, Dict, List +from typing import Any, Dict from .utils.dynamic import instantiate_class_type _THREAD_LOCAL = threading.local() -def get_request_provider_data() -> Any: - return getattr(_THREAD_LOCAL, "provider_data", None) +class NeedsRequestProviderData: + def get_request_provider_data(self) -> Any: + spec = self.__provider_spec__ + assert spec, f"Provider spec not set on {self.__class__}" + + provider_id = spec.provider_id + validator_class = spec.provider_data_validator + if not validator_class: + raise ValueError(f"Provider {provider_id} does not have a validator") + + val = _THREAD_LOCAL.provider_data_header_value + if not val: + return None + + validator = instantiate_class_type(validator_class) + try: + provider_data = validator(**val) + return provider_data + except Exception as e: + print("Error parsing provider data", e) -def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]): - if not validator_classes: - return - +def set_request_provider_data(headers: Dict[str, str]): keys = [ "X-LlamaStack-ProviderData", "x-llamastack-providerdata", @@ -39,12 +54,4 @@ def set_request_provider_data(headers: Dict[str, str], validator_classes: List[s print("Provider data not encoded as a JSON object!", val) return - for validator_class in validator_classes: - validator = instantiate_class_type(validator_class) - try: - provider_data = validator(**val) - if provider_data: - _THREAD_LOCAL.provider_data = provider_data - return - except Exception as e: - print("Error parsing provider data", e) + _THREAD_LOCAL.provider_data_header_value = val diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index a32c470d5..9cebe9b85 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -207,9 +207,7 @@ def create_dynamic_passthrough( return endpoint -def create_dynamic_typed_route( - func: Any, method: str, provider_data_validators: List[str] -): +def create_dynamic_typed_route(func: Any, method: str): hints = get_type_hints(func) response_model = hints.get("return") @@ -224,7 +222,7 @@ def create_dynamic_typed_route( async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) - set_request_provider_data(request.headers, provider_data_validators) + set_request_provider_data(request.headers) async def sse_generator(event_gen): try: @@ -255,7 +253,7 @@ def create_dynamic_typed_route( async def endpoint(request: Request, **kwargs): await start_trace(func.__name__) - set_request_provider_data(request.headers, provider_data_validators) + set_request_provider_data(request.headers) try: return ( @@ -462,21 +460,10 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): impl_method = getattr(impl, endpoint.name) - validators = [] - if isinstance(provider_spec, AutoRoutedProviderSpec): - inner_specs = specs[provider_spec.routing_table_api].inner_specs - for spec in inner_specs: - if spec.provider_data_validator: - validators.append(spec.provider_data_validator) - elif not isinstance(provider_spec, RoutingTableProviderSpec): - if provider_spec.provider_data_validator: - validators.append(provider_spec.provider_data_validator) - getattr(app, endpoint.method)(endpoint.route, response_model=None)( create_dynamic_typed_route( impl_method, endpoint.method, - validators, ) ) diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 0737868ac..7053834bd 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -15,7 +15,7 @@ from llama_models.sku_list import resolve_model from together import Together from llama_stack.apis.inference import * # noqa: F403 -from llama_stack.distribution.request_headers import get_request_provider_data +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.augment_messages import ( augment_messages_for_tools, ) @@ -32,7 +32,7 @@ TOGETHER_SUPPORTED_MODELS = { } -class TogetherInferenceAdapter(Inference): +class TogetherInferenceAdapter(Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: self.config = config tokenizer = Tokenizer.get_instance() @@ -103,7 +103,7 @@ class TogetherInferenceAdapter(Inference): ) -> AsyncGenerator: together_api_key = None - provider_data = get_request_provider_data() + provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.together_api_key: raise ValueError( 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index 8e552fb6c..24fcc63b1 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -13,7 +13,7 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) -from llama_stack.distribution.request_headers import get_request_provider_data +from llama_stack.distribution.request_headers import NeedsRequestProviderData from .config import TogetherSafetyConfig @@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str: return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True)) -class TogetherSafetyImpl(Safety): +class TogetherSafetyImpl(Safety, NeedsRequestProviderData): def __init__(self, config: TogetherSafetyConfig) -> None: self.config = config @@ -52,7 +52,7 @@ class TogetherSafetyImpl(Safety): ) -> RunShieldResponse: together_api_key = None - provider_data = get_request_provider_data() + provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.together_api_key: raise ValueError( 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }'