Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00

Commit 0996ffb3b3 ("bug fixes"), parent 878b2c31c7
6 changed files with 27 additions and 19 deletions
@@ -23,7 +23,7 @@ class NeedsRequestProviderData:
         if not validator_class:
             raise ValueError(f"Provider {provider_id} does not have a validator")

-        val = _THREAD_LOCAL.provider_data_header_value
+        val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
         if not val:
             return None

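The change above replaces a direct attribute read with getattr and a None default: threading.local() only exposes attributes on the thread that set them, so a request handled by a thread that never stored provider data would previously raise AttributeError instead of falling through to return None. A minimal, self-contained sketch of the difference (standard library only, not the repository's module):

import threading

_THREAD_LOCAL = threading.local()

def read_header_value_unsafe():
    # Raises AttributeError on any thread that never set the attribute.
    return _THREAD_LOCAL.provider_data_header_value

def read_header_value_safe():
    # Returns None instead of raising when the attribute was never set.
    return getattr(_THREAD_LOCAL, "provider_data_header_value", None)

def worker():
    # A fresh thread starts with an empty thread-local namespace.
    assert read_header_value_safe() is None

t = threading.Thread(target=worker)
t.start()
t.join()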
@@ -20,9 +20,9 @@ from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig

 # mapping of Model SKUs to ollama models
 BEDROCK_SUPPORTED_MODELS = {
-    "Meta-Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
-    "Meta-Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
-    "Meta-Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
+    "Llama3.1-8B-Instruct": "meta.llama3-1-8b-instruct-v1:0",
+    "Llama3.1-70B-Instruct": "meta.llama3-1-70b-instruct-v1:0",
+    "Llama3.1-405B-Instruct": "meta.llama3-1-405b-instruct-v1:0",
 }


@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

@@ -14,7 +16,7 @@ class TogetherImplConfig(BaseModel):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: str = Field(
-        default="",
+    api_key: Optional[str] = Field(
+        default=None,
         description="The Together AI API Key",
     )
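For orientation, here is a sketch of what the config model looks like after both hunks are applied. The url field name is an assumption (only the default and description are visible in the diff), and the json_schema_type decorator used elsewhere in the file is omitted.

from typing import Optional

from pydantic import BaseModel, Field


class TogetherImplConfig(BaseModel):
    # Field name assumed; the hunk only shows the default and description.
    url: str = Field(
        default="https://api.together.xyz/v1",
        description="The URL for the Together AI server",
    )
    # Optional so the key can instead arrive per request via the
    # X-LlamaStack-ProviderData header.
    api_key: Optional[str] = Field(
        default=None,
        description="The Together AI API Key",
    )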
@@ -96,12 +96,15 @@ class TogetherInferenceAdapter(
     ) -> AsyncGenerator:

         together_api_key = None
-        provider_data = self.get_request_provider_data()
-        if provider_data is None or not provider_data.together_api_key:
-            raise ValueError(
-                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-            )
-        together_api_key = provider_data.together_api_key
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key

         client = Together(api_key=together_api_key)
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
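The rewritten block gives a key from static configuration precedence over the per-request header, and only raises when neither source provides one. A condensed sketch of that precedence as a standalone helper; the function and parameter names are illustrative, not the adapter's real attributes:

from typing import Any, Optional


def resolve_together_api_key(config_api_key: Optional[str], provider_data: Any) -> str:
    # Prefer a key supplied in the provider config.
    if config_api_key is not None:
        return config_api_key
    # Otherwise require the per-request provider data header.
    if provider_data is None or not provider_data.together_api_key:
        raise ValueError(
            'Pass Together API Key in the header X-LlamaStack-ProviderData as '
            '{ "together_api_key": <your api key>}'
        )
    return provider_data.together_api_key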
@@ -51,12 +51,15 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, RoutableProvider):
             raise ValueError(f"Unknown safety shield type: {shield_type}")

         together_api_key = None
-        provider_data = self.get_request_provider_data()
-        if provider_data is None or not provider_data.together_api_key:
-            raise ValueError(
-                'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-            )
-        together_api_key = provider_data.together_api_key
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key

         model_name = SAFETY_SHIELD_TYPES[shield_type]

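The safety provider applies the same fallback, so a client only needs to send the header when the stack was started without a configured key. A hedged client-side example of supplying it; the endpoint path, port, and request body are hypothetical and only show where the header goes:

import json

import requests  # assumption: any HTTP client works; requests is used for brevity

headers = {
    "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "<your api key>"}),
}

# Hypothetical endpoint and body; adjust to your deployment and API.
response = requests.post(
    "http://localhost:5000/inference/chat_completion",
    headers=headers,
    json={
        "model": "Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.status_code)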
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Dict
+from typing import Dict, List

 from llama_models.sku_list import resolve_model
