Pull (extract) provider data from the provider instead of pushing from the top (#148)

This commit is contained in:
Ashwin Bharambe 2024-09-29 20:00:51 -07:00 committed by GitHub
parent f6a6598d1a
commit 5bf679cab6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 32 additions and 38 deletions

View file

@ -6,21 +6,36 @@
import json import json
import threading import threading
from typing import Any, Dict, List from typing import Any, Dict
from .utils.dynamic import instantiate_class_type from .utils.dynamic import instantiate_class_type
# Per-thread storage for the provider-data header value captured for a request.
_THREAD_LOCAL = threading.local()


class NeedsRequestProviderData:
    """Mixin for provider implementations that need per-request provider data.

    The host object is expected to carry a ``__provider_spec__`` attribute that
    exposes ``provider_id`` and ``provider_data_validator``; the provider pulls
    its own data on demand through :meth:`get_request_provider_data`.
    """

    def get_request_provider_data(self) -> Any:
        """Validate and return the provider data stored for the current thread.

        Returns:
            The object produced by calling the provider's validator with the
            stored header value as keyword arguments, or ``None`` when nothing
            was stored on this thread or validation failed.

        Raises:
            ValueError: if the provider spec does not declare a validator.
        """
        spec = self.__provider_spec__
        assert spec, f"Provider spec not set on {self.__class__}"

        provider_id = spec.provider_id
        validator_class = spec.provider_data_validator
        if not validator_class:
            raise ValueError(f"Provider {provider_id} does not have a validator")

        # A threading.local attribute only exists on threads that assigned it.
        # Plain attribute access would raise AttributeError when no header value
        # was ever stored on this thread, so default to None instead.
        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
        if not val:
            return None

        validator = instantiate_class_type(validator_class)
        try:
            return validator(**val)
        except Exception as e:
            # Best-effort: a malformed payload is reported and treated as
            # "no provider data" so the caller can raise its own error.
            print("Error parsing provider data", e)
            return None
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]): def set_request_provider_data(headers: Dict[str, str]):
if not validator_classes:
return
keys = [ keys = [
"X-LlamaStack-ProviderData", "X-LlamaStack-ProviderData",
"x-llamastack-providerdata", "x-llamastack-providerdata",
@ -39,12 +54,4 @@ def set_request_provider_data(headers: Dict[str, str], validator_classes: List[s
print("Provider data not encoded as a JSON object!", val) print("Provider data not encoded as a JSON object!", val)
return return
for validator_class in validator_classes: _THREAD_LOCAL.provider_data_header_value = val
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
if provider_data:
_THREAD_LOCAL.provider_data = provider_data
return
except Exception as e:
print("Error parsing provider data", e)

View file

@ -207,9 +207,7 @@ def create_dynamic_passthrough(
return endpoint return endpoint
def create_dynamic_typed_route( def create_dynamic_typed_route(func: Any, method: str):
func: Any, method: str, provider_data_validators: List[str]
):
hints = get_type_hints(func) hints = get_type_hints(func)
response_model = hints.get("return") response_model = hints.get("return")
@ -224,7 +222,7 @@ def create_dynamic_typed_route(
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__) await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validators) set_request_provider_data(request.headers)
async def sse_generator(event_gen): async def sse_generator(event_gen):
try: try:
@ -255,7 +253,7 @@ def create_dynamic_typed_route(
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__) await start_trace(func.__name__)
set_request_provider_data(request.headers, provider_data_validators) set_request_provider_data(request.headers)
try: try:
return ( return (
@ -462,21 +460,10 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
impl_method = getattr(impl, endpoint.name) impl_method = getattr(impl, endpoint.name)
validators = []
if isinstance(provider_spec, AutoRoutedProviderSpec):
inner_specs = specs[provider_spec.routing_table_api].inner_specs
for spec in inner_specs:
if spec.provider_data_validator:
validators.append(spec.provider_data_validator)
elif not isinstance(provider_spec, RoutingTableProviderSpec):
if provider_spec.provider_data_validator:
validators.append(provider_spec.provider_data_validator)
getattr(app, endpoint.method)(endpoint.route, response_model=None)( getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route( create_dynamic_typed_route(
impl_method, impl_method,
endpoint.method, endpoint.method,
validators,
) )
) )

View file

@ -15,7 +15,7 @@ from llama_models.sku_list import resolve_model
from together import Together from together import Together
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import get_request_provider_data from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.augment_messages import ( from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools, augment_messages_for_tools,
) )
@ -32,7 +32,7 @@ TOGETHER_SUPPORTED_MODELS = {
} }
class TogetherInferenceAdapter(Inference): class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
self.config = config self.config = config
tokenizer = Tokenizer.get_instance() tokenizer = Tokenizer.get_instance()
@ -103,7 +103,7 @@ class TogetherInferenceAdapter(Inference):
) -> AsyncGenerator: ) -> AsyncGenerator:
together_api_key = None together_api_key = None
provider_data = get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key: if provider_data is None or not provider_data.together_api_key:
raise ValueError( raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}' 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'

View file

@ -13,7 +13,7 @@ from llama_stack.apis.safety import (
SafetyViolation, SafetyViolation,
ViolationLevel, ViolationLevel,
) )
from llama_stack.distribution.request_headers import get_request_provider_data from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig from .config import TogetherSafetyConfig
@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str:
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True)) return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
class TogetherSafetyImpl(Safety): class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None: def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config self.config = config
@ -52,7 +52,7 @@ class TogetherSafetyImpl(Safety):
) -> RunShieldResponse: ) -> RunShieldResponse:
together_api_key = None together_api_key = None
provider_data = get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key: if provider_data is None or not provider_data.together_api_key:
raise ValueError( raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}' 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'