bug fixes

This commit is contained in:
Ashwin Bharambe 2024-09-30 16:15:51 -07:00
parent 878b2c31c7
commit 0996ffb3b3
6 changed files with 27 additions and 19 deletions

View file

@ -23,7 +23,7 @@ class NeedsRequestProviderData:
if not validator_class: if not validator_class:
raise ValueError(f"Provider {provider_id} does not have a validator") raise ValueError(f"Provider {provider_id} does not have a validator")
val = _THREAD_LOCAL.provider_data_header_value val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
if not val: if not val:
return None return None

View file

@ -20,9 +20,9 @@ from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfi
# mapping of Model SKUs to ollama models # mapping of Model SKUs to ollama models
BEDROCK_SUPPORTED_MODELS = { BEDROCK_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0", "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0", "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0", "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
} }

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -14,7 +16,7 @@ class TogetherImplConfig(BaseModel):
default="https://api.together.xyz/v1", default="https://api.together.xyz/v1",
description="The URL for the Together AI server", description="The URL for the Together AI server",
) )
api_key: str = Field( api_key: Optional[str] = Field(
default="", default=None,
description="The Together AI API Key", description="The Together AI API Key",
) )

View file

@ -96,6 +96,9 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator: ) -> AsyncGenerator:
together_api_key = None together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key: if provider_data is None or not provider_data.together_api_key:
raise ValueError( raise ValueError(

View file

@ -51,6 +51,9 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
raise ValueError(f"Unknown safety shield type: {shield_type}") raise ValueError(f"Unknown safety shield type: {shield_type}")
together_api_key = None together_api_key = None
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key: if provider_data is None or not provider_data.together_api_key:
raise ValueError( raise ValueError(

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict from typing import Dict, List
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model