From 0996ffb3b3e83625200b8d1d444f9bf92936ab92 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 30 Sep 2024 16:15:51 -0700 Subject: [PATCH] bug fixes --- llama_stack/distribution/request_headers.py | 2 +- .../adapters/inference/bedrock/bedrock.py | 6 +++--- .../adapters/inference/together/config.py | 6 ++++-- .../adapters/inference/together/together.py | 15 +++++++++------ .../adapters/safety/together/together.py | 15 +++++++++------ llama_stack/providers/utils/inference/routable.py | 2 +- 6 files changed, 27 insertions(+), 19 deletions(-) diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 5ed04a13a..990fa66d5 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -23,7 +23,7 @@ class NeedsRequestProviderData: if not validator_class: raise ValueError(f"Provider {provider_id} does not have a validator") - val = _THREAD_LOCAL.provider_data_header_value + val = getattr(_THREAD_LOCAL, "provider_data_header_value", None) if not val: return None diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 14b506964..de0ee84eb 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -20,9 +20,9 @@ from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfi # mapping of Model SKUs to ollama models BEDROCK_SUPPORTED_MODELS = { - "Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", - "Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", - "Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", + "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", + "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", + "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", } diff --git a/llama_stack/providers/adapters/inference/together/config.py b/llama_stack/providers/adapters/inference/together/config.py index 03ee047d2..e928a771d 100644 --- a/llama_stack/providers/adapters/inference/together/config.py +++ b/llama_stack/providers/adapters/inference/together/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Optional + from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -14,7 +16,7 @@ class TogetherImplConfig(BaseModel): default="https://api.together.xyz/v1", description="The URL for the Together AI server", ) - api_key: str = Field( - default="", + api_key: Optional[str] = Field( + default=None, description="The Together AI API Key", ) diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 2c2c0c4d8..1db354bc3 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -96,12 +96,15 @@ class TogetherInferenceAdapter( ) -> AsyncGenerator: together_api_key = None - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.together_api_key: - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - together_api_key = provider_data.together_api_key + if self.config.api_key is not None: + together_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.together_api_key: + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' + ) + together_api_key = provider_data.together_api_key client = Together(api_key=together_api_key) # wrapper request to make it easier to pass around (internal only, not exposed to API) diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index cb1040d19..06b16d23d 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -51,12 +51,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider): raise ValueError(f"Unknown safety shield type: {shield_type}") together_api_key = None - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.together_api_key: - raise ValueError( - 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' - ) - together_api_key = provider_data.together_api_key + if self.config.api_key is not None: + together_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.together_api_key: + raise ValueError( + 'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": }' + ) + together_api_key = provider_data.together_api_key model_name = SAFETY_SHIELD_TYPES[shield_type] diff --git a/llama_stack/providers/utils/inference/routable.py b/llama_stack/providers/utils/inference/routable.py index 254e12d60..6dd2dd1fe 100644 --- a/llama_stack/providers/utils/inference/routable.py +++ b/llama_stack/providers/utils/inference/routable.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Dict +from typing import Dict, List from llama_models.sku_list import resolve_model