Pull (extract) provider data from the provider instead of pushing from the top (#148)

This commit is contained in:
Ashwin Bharambe 2024-09-29 20:00:51 -07:00 committed by GitHub
parent f6a6598d1a
commit 5bf679cab6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 32 additions and 38 deletions

View file

@ -6,21 +6,36 @@
import json
import threading
from typing import Any, Dict, List
from typing import Any, Dict
from .utils.dynamic import instantiate_class_type
_THREAD_LOCAL = threading.local()
def get_request_provider_data() -> Any:
return getattr(_THREAD_LOCAL, "provider_data", None)
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"
provider_id = spec.provider_id
validator_class = spec.provider_data_validator
if not validator_class:
raise ValueError(f"Provider {provider_id} does not have a validator")
val = _THREAD_LOCAL.provider_data_header_value
if not val:
return None
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
return provider_data
except Exception as e:
print("Error parsing provider data", e)
def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
if not validator_classes:
return
def set_request_provider_data(headers: Dict[str, str]):
keys = [
"X-LlamaStack-ProviderData",
"x-llamastack-providerdata",
@ -39,12 +54,4 @@ def set_request_provider_data(headers: Dict[str, str], validator_classes: List[s
print("Provider data not encoded as a JSON object!", val)
return
for validator_class in validator_classes:
validator = instantiate_class_type(validator_class)
try:
provider_data = validator(**val)
if provider_data:
_THREAD_LOCAL.provider_data = provider_data
return
except Exception as e:
print("Error parsing provider data", e)
_THREAD_LOCAL.provider_data_header_value = val