mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
update get_apikey in adaptor get_params
This commit is contained in:
parent
3372301fa7
commit
085cc7beed
3 changed files with 24 additions and 9 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -24,14 +24,14 @@ class SambaNovaImplConfig(BaseModel):
|
|||
default="https://api.sambanova.ai/v1",
|
||||
description="The URL for the SambaNova AI server",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
api_key: Optional[SecretStr] = Field(
|
||||
default=None,
|
||||
description="The SambaNova cloud API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": "${env.SAMBANOVA_API_KEY}",
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
@ -56,8 +56,8 @@ from llama_stack.apis.inference import (
|
|||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.remote.inference.sambanova.config import SambaNovaImplConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_tooldef_to_openai_tool,
|
||||
|
@ -65,8 +65,11 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url
|
||||
|
||||
from .config import SambaNovaImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_with_b64_images(
|
||||
message: Message | Dict,
|
||||
|
@ -172,13 +175,25 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
_config: SambaNovaImplConfig
|
||||
|
||||
def __init__(self, config: SambaNovaImplConfig):
|
||||
self.config = config
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
api_key_from_config=self.config.api_key,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def _get_api_key(self) -> str:
|
||||
config_api_key = self.config.api_key if self.config.api_key else None
|
||||
if config_api_key:
|
||||
return config_api_key.get_secret_value()
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.sambanova_api_key:
|
||||
raise ValueError(
|
||||
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
|
||||
)
|
||||
return provider_data.sambanova_api_key
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
|
@ -220,7 +235,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
if provider_data and getattr(provider_data, key_field, None):
|
||||
api_key = getattr(provider_data, key_field)
|
||||
else:
|
||||
api_key = self.api_key_from_config
|
||||
api_key = self._get_api_key()
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
|
|
|
@ -38,7 +38,7 @@ providers:
|
|||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY}
|
||||
api_key: ${env.SAMBANOVA_API_KEY:}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue