update get_apikey in adaptor get_params

This commit is contained in:
jhpiedrahitao 2025-04-02 15:43:14 -05:00
parent 3372301fa7
commit 085cc7beed
3 changed files with 24 additions and 9 deletions

View file

@ -6,7 +6,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -24,14 +24,14 @@ class SambaNovaImplConfig(BaseModel):
default="https://api.sambanova.ai/v1", default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server", description="The URL for the SambaNova AI server",
) )
api_key: Optional[str] = Field( api_key: Optional[SecretStr] = Field(
default=None, default=None,
description="The SambaNova cloud API Key", description="The SambaNova cloud API Key",
) )
@classmethod @classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]: def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
return { return {
"url": "https://api.sambanova.ai/v1", "url": "https://api.sambanova.ai/v1",
"api_key": "${env.SAMBANOVA_API_KEY}", "api_key": api_key,
} }

View file

@ -56,8 +56,8 @@ from llama_stack.apis.inference import (
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
convert_tooldef_to_openai_tool, convert_tooldef_to_openai_tool,
@ -65,8 +65,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
) )
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
async def convert_message_to_openai_dict_with_b64_images( async def convert_message_to_openai_dict_with_b64_images(
message: Message | Dict, message: Message | Dict,
@ -172,13 +175,25 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaImplConfig _config: SambaNovaImplConfig
def __init__(self, config: SambaNovaImplConfig): def __init__(self, config: SambaNovaImplConfig):
self.config = config
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
model_entries=MODEL_ENTRIES, model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key, api_key_from_config=self.config.api_key,
provider_data_api_key_field="sambanova_api_key", provider_data_api_key_field="sambanova_api_key",
) )
self.config = config
def _get_api_key(self) -> str:
config_api_key = self.config.api_key if self.config.api_key else None
if config_api_key:
return config_api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.sambanova_api_key:
raise ValueError(
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
)
return provider_data.sambanova_api_key
async def _get_params(self, request: ChatCompletionRequest) -> dict: async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {} input_dict = {}
@ -220,7 +235,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
if provider_data and getattr(provider_data, key_field, None): if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field) api_key = getattr(provider_data, key_field)
else: else:
api_key = self.api_key_from_config api_key = self._get_api_key()
return { return {
"model": request.model, "model": request.model,

View file

@ -38,7 +38,7 @@ providers:
provider_type: remote::sambanova provider_type: remote::sambanova
config: config:
url: https://api.sambanova.ai/v1 url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY} api_key: ${env.SAMBANOVA_API_KEY:}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}