update get_apikey in adaptor get_params

jhpiedrahitao 2025-04-02 15:43:14 -05:00
parent 3372301fa7
commit 085cc7beed
3 changed files with 24 additions and 9 deletions


@@ -6,7 +6,7 @@
 from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.schema_utils import json_schema_type
@@ -24,14 +24,14 @@ class SambaNovaImplConfig(BaseModel):
         default="https://api.sambanova.ai/v1",
         description="The URL for the SambaNova AI server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The SambaNova cloud API Key",
     )
 
     @classmethod
-    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+    def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
         return {
             "url": "https://api.sambanova.ai/v1",
-            "api_key": "${env.SAMBANOVA_API_KEY}",
+            "api_key": api_key,
         }
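
Switching api_key to SecretStr changes how the value behaves downstream: pydantic masks it when printed or logged, and the raw key is only reachable through get_secret_value(), which is why the adapter's _get_api_key (further down) has to unwrap it explicitly. A small illustration of pydantic's SecretStr behavior, not code from this commit; the key value is a placeholder:

    from pydantic import SecretStr

    key = SecretStr("sn-example-key")   # hypothetical key value
    print(key)                          # **********  -- masked in logs and reprs
    print(key.get_secret_value())       # sn-example-key -- explicit unwrap only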


@@ -56,8 +56,8 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
-from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_tooldef_to_openai_tool,
@@ -65,8 +65,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
 
+from .config import SambaNovaImplConfig
 from .models import MODEL_ENTRIES
 
+logger = get_logger(name=__name__, category="inference")
+
 
 async def convert_message_to_openai_dict_with_b64_images(
     message: Message | Dict,
@@ -172,13 +175,25 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
     _config: SambaNovaImplConfig
 
     def __init__(self, config: SambaNovaImplConfig):
+        self.config = config
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            api_key_from_config=config.api_key,
+            api_key_from_config=self.config.api_key,
             provider_data_api_key_field="sambanova_api_key",
         )
-        self.config = config
 
+    def _get_api_key(self) -> str:
+        config_api_key = self.config.api_key if self.config.api_key else None
+        if config_api_key:
+            return config_api_key.get_secret_value()
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.sambanova_api_key:
+                raise ValueError(
+                    'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
+                )
+            return provider_data.sambanova_api_key
+
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {}
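
The ValueError above spells out the fallback contract: when the run config carries no key, each request must supply one in the X-LlamaStack-Provider-Data header. A minimal client-side sketch of that header, assuming a Llama Stack server on localhost; only the header name and JSON shape come from the adapter code, the endpoint path and model id are illustrative placeholders:

    import json
    import requests

    headers = {
        # the exact shape the adapter's error message asks for
        "X-LlamaStack-Provider-Data": json.dumps({"sambanova_api_key": "your-sambanova-key"}),
    }
    response = requests.post(
        "http://localhost:8321/v1/inference/chat-completion",    # hypothetical route
        headers=headers,
        json={
            "model_id": "sambanova/Meta-Llama-3.1-8B-Instruct",  # hypothetical model id
            "messages": [{"role": "user", "content": "Hello"}],
        },
    )
    print(response.status_code)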
@@ -220,7 +235,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
         if provider_data and getattr(provider_data, key_field, None):
             api_key = getattr(provider_data, key_field)
         else:
-            api_key = self.api_key_from_config
+            api_key = self._get_api_key()
 
         return {
             "model": request.model,


@@ -38,7 +38,7 @@
     provider_type: remote::sambanova
     config:
       url: https://api.sambanova.ai/v1
-      api_key: ${env.SAMBANOVA_API_KEY}
+      api_key: ${env.SAMBANOVA_API_KEY:}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
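
The template change pairs with the code: the trailing colon in ${env.SAMBANOVA_API_KEY:} appears to give the variable an empty default, so the distribution can start without the env var set and lean on the per-request header instead. Roughly the intended effect, sketched in plain Python rather than the actual llama_stack substitution logic:

    import os

    raw = os.environ.get("SAMBANOVA_API_KEY", "")  # "" instead of a startup failure when unset
    api_key = raw or None                          # None -> adapter falls back to provider data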