diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 20f863665..9c2dda889 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -7,6 +7,7 @@
 import json
 from collections.abc import Iterable
 
+import requests
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
 )
@@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
 
     def __init__(self, config: SambaNovaImplConfig):
         self.config = config
+        self.environment_available_models = []
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            api_key_from_config=self.config.api_key,
+            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
         )
 
@@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             **get_sampling_options(request.sampling_params),
         }
 
+    async def register_model(self, model: Model) -> Model:
+        model_id = self.get_provider_model_id(model.provider_resource_id)
+
+        list_models_url = self.config.url + "/models"
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
+
+        if model_id.split("sambanova/")[-1] not in self.environment_available_models:
+            logger.warning(f"Model {model_id} not available in {list_models_url}")
+        return model
+
     async def initialize(self):
         await super().initialize()
 
diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py
index 84c8267ae..1a65f6aa1 100644
--- a/llama_stack/providers/remote/safety/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py
@@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
+        self.environment_available_models = []
 
     async def initialize(self) -> None:
         pass
@@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
 
     async def register_shield(self, shield: Shield) -> None:
         list_models_url = self.config.url + "/models"
-        try:
-            response = requests.get(list_models_url)
-            response.raise_for_status()
-        except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Request to {list_models_url} failed") from e
-        available_models = [model.get("id") for model in response.json().get("data", {})]
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
         if (
-            len(available_models) == 0
-            or "guard" not in shield.provider_resource_id.lower()
-            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
+            "guard" not in shield.provider_resource_id.lower()
+            or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
         ):
-            raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
+            logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
 
     async def run_shield(
         self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 05aee5096..e82714ffd 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -71,7 +71,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::cerebras",
         "remote::databricks",
         "remote::runpod",
-        "remote::sambanova",
         "remote::tgi",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
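
Both adapters now follow the same availability pattern: fetch the environment's `/models` list once, cache it on the instance, and log a warning instead of raising when a model or shield is missing. A minimal standalone sketch of that pattern follows; the class name `AvailabilityChecker` and the example base URL are illustrative, not part of the patch:

```python
import requests


class AvailabilityChecker:
    """Fetch the environment's model list once and cache it for later lookups."""

    def __init__(self, base_url: str) -> None:
        self.base_url = base_url
        self.environment_available_models: list[str] = []

    def _ensure_model_list(self) -> None:
        # Only hit the /models endpoint on the first call; afterwards the
        # cached list is reused, mirroring the diff's `len(...) == 0` guard.
        if len(self.environment_available_models) == 0:
            list_models_url = self.base_url + "/models"
            try:
                response = requests.get(list_models_url)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                raise RuntimeError(f"Request to {list_models_url} failed") from e
            self.environment_available_models = [m.get("id") for m in response.json().get("data", [])]

    def is_available(self, provider_resource_id: str) -> bool:
        self._ensure_model_list()
        # Strip a leading "sambanova/" namespace before comparing, as the diff does.
        return provider_resource_id.split("sambanova/")[-1] in self.environment_available_models


# Example usage (the base URL is a placeholder, not an endpoint confirmed by the patch):
if __name__ == "__main__":
    checker = AvailabilityChecker("https://api.sambanova.ai/v1")
    if not checker.is_available("sambanova/Meta-Llama-Guard-3-8B"):
        print("model not available in this environment")
```

Caching the list on first use means repeated registrations don't re-hit the endpoint, which is why both `register_model` and `register_shield` guard the fetch with `len(self.environment_available_models) == 0`.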