Pull (extract) provider data from the provider instead of pushing from the top (#148)

This commit is contained in:
Ashwin Bharambe 2024-09-29 20:00:51 -07:00 committed by GitHub
parent f6a6598d1a
commit 5bf679cab6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 32 additions and 38 deletions

View file

@ -13,7 +13,7 @@ from llama_stack.apis.safety import (
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from .config import TogetherSafetyConfig
@ -40,7 +40,7 @@ def shield_type_to_model_name(shield_type: str) -> str:
return SAFETY_SHIELD_TYPES.get(model.descriptor(shorten_default_variant=True))
class TogetherSafetyImpl(Safety):
class TogetherSafetyImpl(Safety, NeedsRequestProviderData):
def __init__(self, config: TogetherSafetyConfig) -> None:
self.config = config
@ -52,7 +52,7 @@ class TogetherSafetyImpl(Safety):
) -> RunShieldResponse:
together_api_key = None
provider_data = get_request_provider_data()
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'