Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
Pull (extract) provider data from the provider instead of pushing from the top (#148)
parent f6a6598d1a
commit 5bf679cab6

4 changed files with 32 additions and 38 deletions
@@ -6,21 +6,36 @@
 import json
 import threading
-from typing import Any, Dict, List
+from typing import Any, Dict

 from .utils.dynamic import instantiate_class_type

 _THREAD_LOCAL = threading.local()


-def get_request_provider_data() -> Any:
-    return getattr(_THREAD_LOCAL, "provider_data", None)
+class NeedsRequestProviderData:
+    def get_request_provider_data(self) -> Any:
+        spec = self.__provider_spec__
+        assert spec, f"Provider spec not set on {self.__class__}"
+
+        provider_id = spec.provider_id
+        validator_class = spec.provider_data_validator
+        if not validator_class:
+            raise ValueError(f"Provider {provider_id} does not have a validator")
+
+        val = _THREAD_LOCAL.provider_data_header_value
+        if not val:
+            return None
+
+        validator = instantiate_class_type(validator_class)
+        try:
+            provider_data = validator(**val)
+            return provider_data
+        except Exception as e:
+            print("Error parsing provider data", e)


-def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
-    if not validator_classes:
-        return
-
+def set_request_provider_data(headers: Dict[str, str]):
     keys = [
         "X-LlamaStack-ProviderData",
         "x-llamastack-providerdata",
@@ -39,12 +54,4 @@ def set_request_provider_data(headers: Dict[str, str], validator_classes: List[str]):
        print("Provider data not encoded as a JSON object!", val)
        return

-    for validator_class in validator_classes:
-        validator = instantiate_class_type(validator_class)
-        try:
-            provider_data = validator(**val)
-            if provider_data:
-                _THREAD_LOCAL.provider_data = provider_data
-                return
-        except Exception as e:
-            print("Error parsing provider data", e)
+    _THREAD_LOCAL.provider_data_header_value = val
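
The two hunks above replace the old push model (the server looping over validator classes and stashing an already-parsed object) with a pull model: the server only stashes the raw header value, and each provider parses and validates it on demand through the NeedsRequestProviderData mixin. Below is a minimal, self-contained sketch of that pattern, not the actual llama-stack wiring: ExampleProviderDataValidator, ExampleAdapter, and the __provider_data_validator__ attribute are illustrative stand-ins for the provider-spec lookup shown in the diff.

import json
import threading
from typing import Any, Dict, Optional

from pydantic import BaseModel

_THREAD_LOCAL = threading.local()


class ExampleProviderDataValidator(BaseModel):
    # Illustrative validator; real providers declare their own fields.
    api_key: Optional[str] = None


class NeedsRequestProviderData:
    # Simplified stand-in for the mixin added above: the provider pulls and
    # validates its own data when it needs it.
    def get_request_provider_data(self) -> Any:
        validator_class = getattr(self, "__provider_data_validator__", None)
        if validator_class is None:
            raise ValueError(f"{self.__class__} does not have a validator")
        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
        if not val:
            return None
        return validator_class(**val)


class ExampleAdapter(NeedsRequestProviderData):
    # In llama-stack the validator comes from __provider_spec__; it is attached
    # directly here so the example runs on its own.
    __provider_data_validator__ = ExampleProviderDataValidator


def set_request_provider_data(headers: Dict[str, str]) -> None:
    # The push side shrinks to "stash the raw header value for this request".
    raw = headers.get("X-LlamaStack-ProviderData") or headers.get(
        "x-llamastack-providerdata"
    )
    if raw is not None:
        _THREAD_LOCAL.provider_data_header_value = json.loads(raw)


set_request_provider_data({"X-LlamaStack-ProviderData": '{"api_key": "my-key"}'})
print(ExampleAdapter().get_request_provider_data())  # api_key='my-key'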
@@ -207,9 +207,7 @@ def create_dynamic_passthrough(
     return endpoint


-def create_dynamic_typed_route(
-    func: Any, method: str, provider_data_validators: List[str]
-):
+def create_dynamic_typed_route(func: Any, method: str):
     hints = get_type_hints(func)
     response_model = hints.get("return")

@@ -224,7 +222,7 @@ def create_dynamic_typed_route(
         async def endpoint(request: Request, **kwargs):
             await start_trace(func.__name__)

-            set_request_provider_data(request.headers, provider_data_validators)
+            set_request_provider_data(request.headers)

             async def sse_generator(event_gen):
                 try:
@@ -255,7 +253,7 @@ def create_dynamic_typed_route(
         async def endpoint(request: Request, **kwargs):
             await start_trace(func.__name__)

-            set_request_provider_data(request.headers, provider_data_validators)
+            set_request_provider_data(request.headers)

             try:
                 return (
@@ -462,21 +460,10 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):

         impl_method = getattr(impl, endpoint.name)

-        validators = []
-        if isinstance(provider_spec, AutoRoutedProviderSpec):
-            inner_specs = specs[provider_spec.routing_table_api].inner_specs
-            for spec in inner_specs:
-                if spec.provider_data_validator:
-                    validators.append(spec.provider_data_validator)
-        elif not isinstance(provider_spec, RoutingTableProviderSpec):
-            if provider_spec.provider_data_validator:
-                validators.append(provider_spec.provider_data_validator)
-
         getattr(app, endpoint.method)(endpoint.route, response_model=None)(
             create_dynamic_typed_route(
                 impl_method,
                 endpoint.method,
-                validators,
             )
         )

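
With the validator plumbing removed from main(), the per-request contract is just the header that set_request_provider_data(request.headers) reads. Here is a hedged client-side sketch of that contract; the port 5000 comes from main()'s default above, while the route, request body shape, and use of the requests library are assumptions for illustration only.

import json

import requests  # assumed available; any HTTP client works

provider_data = {"together_api_key": "<your api key>"}

response = requests.post(
    "http://localhost:5000/inference/chat_completion",  # illustrative route
    headers={"X-LlamaStack-ProviderData": json.dumps(provider_data)},
    json={
        # Request body shape is sketched, not taken from this diff.
        "model": "Meta-Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.status_code)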
@@ -15,7 +15,7 @@ from llama_models.sku_list import resolve_model
 from together import Together

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.distribution.request_headers import get_request_provider_data
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
@@ -32,7 +32,7 @@ TOGETHER_SUPPORTED_MODELS = {
 }


-class TogetherInferenceAdapter(Inference):
+class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
     def __init__(self, config: TogetherImplConfig) -> None:
         self.config = config
         tokenizer = Tokenizer.get_instance()
@@ -103,7 +103,7 @@ class TogetherInferenceAdapter(Inference):
     ) -> AsyncGenerator:

         together_api_key = None
-        provider_data = get_request_provider_data()
+        provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.together_api_key:
             raise ValueError(
                 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
@@ -13,7 +13,7 @@ from llama_stack.apis.safety import (
     SafetyViolation,
     ViolationLevel,
 )
-from llama_stack.distribution.request_headers import get_request_provider_data
+from llama_stack.distribution.request_headers import NeedsRequestProviderData

 from .config import TogetherSafetyConfig

@@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str:
     return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))


-class TogetherSafetyImpl(Safety):
+class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
     def __init__(self, config: TogetherSafetyConfig) -> None:
         self.config = config

@@ -52,7 +52,7 @@ class TogetherSafetyImpl(Safety):
     ) -> RunShieldResponse:

         together_api_key = None
-        provider_data = get_request_provider_data()
+        provider_data = self.get_request_provider_data()
         if provider_data is None or not provider_data.together_api_key:
             raise ValueError(
                 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
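
Both Together adapters now pull the same piece of per-request data, so they can share one validator contract. A minimal sketch of such a validator and of the guard both adapters apply after pulling provider data is below; the class and function names are assumptions, only the together_api_key field and the error message are implied by the diff.

from typing import Optional

from pydantic import BaseModel


class TogetherProviderDataValidator(BaseModel):
    # Hypothetical name; together_api_key is the field the adapters read.
    together_api_key: Optional[str] = None


def require_together_api_key(
    provider_data: Optional[TogetherProviderDataValidator],
) -> str:
    # Mirrors the check performed in both adapters after the pull.
    if provider_data is None or not provider_data.together_api_key:
        raise ValueError(
            'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
        )
    return provider_data.together_api_key


print(require_together_api_key(TogetherProviderDataValidator(together_api_key="my-key")))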