bug fixes

This commit is contained in:
Ashwin Bharambe 2024-09-30 16:15:51 -07:00
parent 878b2c31c7
commit 0996ffb3b3
6 changed files with 27 additions and 19 deletions

View file

@ -23,7 +23,7 @@ class NeedsRequestProviderData:
if not validator_class:
raise ValueError(f"Provider {provider_id} does not have a validator")
val = _THREAD_LOCAL.provider_data_header_value
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
if not val:
return None

View file

@ -20,9 +20,9 @@ from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfi
# mapping of Model SKUs to ollama models
BEDROCK_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
"Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
"Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
"Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
}

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -14,7 +16,7 @@ class TogetherImplConfig(BaseModel):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: str = Field(
default="",
api_key: Optional[str] = Field(
default=None,
description="The Together AI API Key",
)

View file

@ -96,12 +96,15 @@ class TogetherInferenceAdapter(
) -> AsyncGenerator:
together_api_key = None
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
client = Together(api_key=together_api_key)
# wrapper request to make it easier to pass around (internal only, not exposed to API)

View file

@ -51,12 +51,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
raise ValueError(f"Unknown safety shield type: {shield_type}")
together_api_key = None
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
if self.config.api_key is not None:
together_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
model_name = SAFETY_SHIELD_TYPES[shield_type]

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Dict, List
from llama_models.sku_list import resolve_model