mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
update get_apikey in adaptor get_params
This commit is contained in:
parent
3372301fa7
commit
085cc7beed
3 changed files with 24 additions and 9 deletions
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
@ -24,14 +24,14 @@ class SambaNovaImplConfig(BaseModel):
|
||||||
default="https://api.sambanova.ai/v1",
|
default="https://api.sambanova.ai/v1",
|
||||||
description="The URL for the SambaNova AI server",
|
description="The URL for the SambaNova AI server",
|
||||||
)
|
)
|
||||||
api_key: Optional[str] = Field(
|
api_key: Optional[SecretStr] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The SambaNova cloud API Key",
|
description="The SambaNova cloud API Key",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.sambanova.ai/v1",
|
"url": "https://api.sambanova.ai/v1",
|
||||||
"api_key": "${env.SAMBANOVA_API_KEY}",
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,8 +56,8 @@ from llama_stack.apis.inference import (
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||||
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
|
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
convert_tooldef_to_openai_tool,
|
convert_tooldef_to_openai_tool,
|
||||||
|
@ -65,8 +65,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
|
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
|
||||||
|
|
||||||
|
from .config import SambaNovaImplConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="inference")
|
||||||
|
|
||||||
|
|
||||||
async def convert_message_to_openai_dict_with_b64_images(
|
async def convert_message_to_openai_dict_with_b64_images(
|
||||||
message: Message | Dict,
|
message: Message | Dict,
|
||||||
|
@ -172,13 +175,25 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
_config: SambaNovaImplConfig
|
_config: SambaNovaImplConfig
|
||||||
|
|
||||||
def __init__(self, config: SambaNovaImplConfig):
|
def __init__(self, config: SambaNovaImplConfig):
|
||||||
|
self.config = config
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
model_entries=MODEL_ENTRIES,
|
model_entries=MODEL_ENTRIES,
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=self.config.api_key,
|
||||||
provider_data_api_key_field="sambanova_api_key",
|
provider_data_api_key_field="sambanova_api_key",
|
||||||
)
|
)
|
||||||
self.config = config
|
|
||||||
|
def _get_api_key(self) -> str:
|
||||||
|
config_api_key = self.config.api_key if self.config.api_key else None
|
||||||
|
if config_api_key:
|
||||||
|
return config_api_key.get_secret_value()
|
||||||
|
else:
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.sambanova_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
|
||||||
|
)
|
||||||
|
return provider_data.sambanova_api_key
|
||||||
|
|
||||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||||
input_dict = {}
|
input_dict = {}
|
||||||
|
@ -220,7 +235,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
if provider_data and getattr(provider_data, key_field, None):
|
if provider_data and getattr(provider_data, key_field, None):
|
||||||
api_key = getattr(provider_data, key_field)
|
api_key = getattr(provider_data, key_field)
|
||||||
else:
|
else:
|
||||||
api_key = self.api_key_from_config
|
api_key = self._get_api_key()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": request.model,
|
"model": request.model,
|
||||||
|
|
|
@ -38,7 +38,7 @@ providers:
|
||||||
provider_type: remote::sambanova
|
provider_type: remote::sambanova
|
||||||
config:
|
config:
|
||||||
url: https://api.sambanova.ai/v1
|
url: https://api.sambanova.ai/v1
|
||||||
api_key: ${env.SAMBANOVA_API_KEY}
|
api_key: ${env.SAMBANOVA_API_KEY:}
|
||||||
- provider_id: sentence-transformers
|
- provider_id: sentence-transformers
|
||||||
provider_type: inline::sentence-transformers
|
provider_type: inline::sentence-transformers
|
||||||
config: {}
|
config: {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue