diff --git a/llama_stack/providers/remote/inference/sambanova/config.py b/llama_stack/providers/remote/inference/sambanova/config.py
index f76262914..6b31af2f0 100644
--- a/llama_stack/providers/remote/inference/sambanova/config.py
+++ b/llama_stack/providers/remote/inference/sambanova/config.py
@@ -6,7 +6,7 @@
 
 from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.schema_utils import json_schema_type
 
@@ -24,14 +24,14 @@ class SambaNovaImplConfig(BaseModel):
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The SambaNova cloud API Key",
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.sambanova.ai/v1",
-            "api_key": "${env.SAMBANOVA_API_KEY}",
+            "api_key": api_key,
         }
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 68d959504..d7387c9ba 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -56,8 +56,8 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
-from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_tooldef_to_openai_tool,
@@ -65,8 +65,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
 
+from .config import SambaNovaImplConfig
 from .models import MODEL_ENTRIES
 
+logger = get_logger(name=__name__, category="inference")
+
 
 async def convert_message_to_openai_dict_with_b64_images(
     message: Message | Dict,
@@ -172,13 +175,25 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
     _config: SambaNovaImplConfig
 
     def __init__(self, config: SambaNovaImplConfig):
+        self.config = config
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            api_key_from_config=config.api_key,
+            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
         )
-        self.config = config
+
+    def _get_api_key(self) -> str:
+        """Return the configured API key, else the key from the request's provider data."""
+        if self.config.api_key:
+            return self.config.api_key.get_secret_value()
+        # No usable key in the config: fall back to the X-LlamaStack-Provider-Data header.
+        provider_data = self.get_request_provider_data()
+        if provider_data is None or not provider_data.sambanova_api_key:
+            raise ValueError(
+                'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
+            )
+        return provider_data.sambanova_api_key
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {}
@@ -220,7 +235,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
         if provider_data and getattr(provider_data, key_field, None):
             api_key = getattr(provider_data, key_field)
         else:
-            api_key = self.api_key_from_config
+            api_key = self._get_api_key()
 
         return {
             "model": request.model,
diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml
index 8aae16103..578ae19d9 100644
--- a/llama_stack/templates/dev/run.yaml
+++ b/llama_stack/templates/dev/run.yaml
@@ -38,7 +38,7 @@ providers:
     provider_type: remote::sambanova
     config:
       url: https://api.sambanova.ai/v1
-      api_key: ${env.SAMBANOVA_API_KEY}
+      api_key: ${env.SAMBANOVA_API_KEY:}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}