forked from phoenix-oss/llama-stack-mirror
Pull (extract) provider data from the provider instead of pushing from the top (#148)
This commit is contained in:
parent
f6a6598d1a
commit
5bf679cab6
4 changed files with 32 additions and 38 deletions
|
@ -15,7 +15,7 @@ from llama_models.sku_list import resolve_model
|
|||
from together import Together
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
@ -32,7 +32,7 @@ TOGETHER_SUPPORTED_MODELS = {
|
|||
}
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(Inference):
|
||||
class TogetherInferenceAdapter(Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
self.config = config
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
@ -103,7 +103,7 @@ class TogetherInferenceAdapter(Inference):
|
|||
) -> AsyncGenerator:
|
||||
|
||||
together_api_key = None
|
||||
provider_data = get_request_provider_data()
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.safety import (
|
|||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
|
||||
from .config import TogetherSafetyConfig
|
||||
|
||||
|
@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str:
|
|||
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety):
|
||||
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -52,7 +52,7 @@ class TogetherSafetyImpl(Safety):
|
|||
) -> RunShieldResponse:
|
||||
|
||||
together_api_key = None
|
||||
provider_data = get_request_provider_data()
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.together_api_key:
|
||||
raise ValueError(
|
||||
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue