update get_apikey in adaptor get_params

This commit is contained in:
jhpiedrahitao 2025-04-02 15:43:14 -05:00
parent 3372301fa7
commit 085cc7beed
3 changed files with 24 additions and 9 deletions

View file

@ -56,8 +56,8 @@ from llama_stack.apis.inference import (
ToolResponseMessage,
UserMessage,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import (
convert_tooldef_to_openai_tool,
@ -65,8 +65,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
async def convert_message_to_openai_dict_with_b64_images(
message: Message | Dict,
@ -172,13 +175,25 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
_config: SambaNovaImplConfig
def __init__(self, config: SambaNovaImplConfig):
self.config = config
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
api_key_from_config=self.config.api_key,
provider_data_api_key_field="sambanova_api_key",
)
self.config = config
def _get_api_key(self) -> str:
config_api_key = self.config.api_key if self.config.api_key else None
if config_api_key:
return config_api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.sambanova_api_key:
raise ValueError(
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
)
return provider_data.sambanova_api_key
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
@ -220,7 +235,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
api_key = self._get_api_key()
return {
"model": request.model,