diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 20f863665..9b2f7f0e3 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -7,6 +7,7 @@
 import json
 from collections.abc import Iterable
 
+import requests
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
 )
@@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -246,6 +248,21 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             **get_sampling_options(request.sampling_params),
         }
 
+    async def register_model(self, model: Model) -> Model:
+        """Register a model, warning if it is not advertised by the SambaNova /models endpoint."""
+        model_id = self.get_provider_model_id(model.provider_resource_id)
+        list_models_url = self.config.url + "/models"
+        try:
+            response = requests.get(list_models_url)
+            response.raise_for_status()
+        except requests.exceptions.RequestException as e:
+            raise RuntimeError(f"Request to {list_models_url} failed") from e
+        # The endpoint is OpenAI-compatible and returns {"data": [{"id": ...}, ...]}.
+        available_models = [m.get("id") for m in response.json().get("data", [])]
+        if model_id.split("sambanova/")[-1] not in available_models:
+            logger.warning(f"Model {model_id} not listed as available by {list_models_url}")
+        return model
+
     async def initialize(self):
         await super().initialize()
 
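
Note for reviewers: `requests.get` is synchronous, so this `register_model` override blocks the event loop while the listing request is in flight. Below is a minimal non-blocking sketch of the same availability check. It assumes `httpx` is importable (it ships as a transitive dependency of the `openai` client) and that the endpoint returns the OpenAI-style `{"data": [{"id": ...}, ...]}` payload the diff already relies on; the helper name `fetch_available_model_ids` is hypothetical, not part of the adapter.

```python
import httpx


async def fetch_available_model_ids(base_url: str, timeout: float = 30.0) -> list[str]:
    """Fetch model ids from an OpenAI-compatible /models endpoint without blocking the event loop."""
    list_models_url = base_url + "/models"
    try:
        # httpx.AsyncClient performs the GET asynchronously, unlike requests.get.
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(list_models_url)
            response.raise_for_status()
    except httpx.HTTPError as e:
        raise RuntimeError(f"Request to {list_models_url} failed") from e
    # Same response shape as above: {"data": [{"id": ...}, ...]}.
    return [m.get("id") for m in response.json().get("data", [])]
```

With a helper like this, `register_model` could `await fetch_available_model_ids(self.config.url)` in place of the `requests.get` call and keep the rest of the check unchanged.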