diff --git a/llama_stack/providers/adapters/inference/together/config.py b/llama_stack/providers/adapters/inference/together/config.py
index c58f722bc..378e85522 100644
--- a/llama_stack/providers/adapters/inference/together/config.py
+++ b/llama_stack/providers/adapters/inference/together/config.py
@@ -4,17 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel, Field
-
 from llama_models.schema_utils import json_schema_type
-
-from llama_stack.distribution.request_headers import annotate_header
+from pydantic import BaseModel, Field
 
 
 class TogetherHeaderExtractor(BaseModel):
-    api_key: annotate_header(
-        "X-LlamaStack-Together-ApiKey", str, "The API Key for the request"
-    )
+    together_api_key: str
 
 
 @json_schema_type
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index cafca3fdf..18c83aada 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -18,6 +18,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
     augment_messages_for_tools,
 )
+from llama_stack.distribution.request_headers import get_request_provider_data
 
 from .config import TogetherImplConfig
 
@@ -97,6 +98,16 @@ class TogetherInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+
+        together_api_key = None
+        provider_data = get_request_provider_data()
+        if provider_data is None or not provider_data.together_api_key:
+            raise ValueError(
+                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key> }'
+            )
+        together_api_key = provider_data.together_api_key
+
+        client = Together(api_key=together_api_key)
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -116,7 +127,7 @@ class TogetherInferenceAdapter(Inference):
 
         if not request.stream:
             # TODO: might need to add back an async here
-            r = self.client.chat.completions.create(
+            r = client.chat.completions.create(
                 model=together_model,
                 messages=self._messages_to_together_messages(messages),
                 stream=False,
@@ -151,7 +162,7 @@ class TogetherInferenceAdapter(Inference):
         ipython = False
         stop_reason = None
 
-        for chunk in self.client.chat.completions.create(
+        for chunk in client.chat.completions.create(
             model=together_model,
             messages=self._messages_to_together_messages(messages),
             stream=True,
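Note on the inference change above: the adapter no longer reads a dedicated X-LlamaStack-Together-ApiKey header; the Together key now arrives as request provider data, and a Together client is constructed per request. Below is a minimal sketch of a caller supplying the key. The header name and JSON shape come from the error message in the diff; the endpoint URL, port, and model identifier are illustrative assumptions, not part of this patch.

import json

import requests  # assumed HTTP client; any client works

# Sketch of a chat completion request carrying the Together key as provider data.
# Only the header name and its JSON payload shape are taken from this diff;
# the URL, route, and model id below are hypothetical.
resp = requests.post(
    "http://localhost:5000/inference/chat_completion",  # assumed endpoint
    headers={
        "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
    },
    json={
        "model": "Llama3.1-8B-Instruct",  # hypothetical model identifier
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(resp.json())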
supported") - provider_data = get_request_provider_data() - together_api_key = None - if provider_data is not None: - if not isinstance(provider_data, TogetherProviderDataValidator): - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - - together_api_key = provider_data.together_api_key - if not together_api_key: - together_api_key = self.config.api_key - - if not together_api_key: - raise ValueError("The API key must be provider in the header or config") + provider_data = get_request_provider_data() + if provider_data is None or not provider_data.together_api_key: + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' + ) + together_api_key = provider_data.together_api_key # messages can have role assistant or user api_messages = [] @@ -62,7 +61,9 @@ async def get_safety_response( response_text = response.choices[0].message.content if response_text == "safe": - return None + return SafetyViolation( + violation_level=ViolationLevel.INFO, user_message="safe", metadata={} + ) parts = response_text.split("\n") if len(parts) != 2: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 31b3e2c2d..9e7ed90f7 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -91,7 +91,7 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.adapters.inference.together", config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig", - header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor", + provider_data_validator="llama_stack.providers.adapters.safety.together.TogetherProviderDataValidator", ), ), ]