fix: sambanova shields and model validation (#2693)

# What does this PR do?
Update SambaNova's shield registration validation so that it no longer raises, but only warns, when the shield model is not available at the configured base URL endpoint. Also add the same warning to the inference adapter's model registration when a model is not available at that endpoint.
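
In short: before this change, registering an unavailable shield raised a `ValueError` and aborted registration; now both the shield and model registration paths log a warning and continue. A minimal standalone sketch of the new behavior (the `warn_if_unavailable` helper and its arguments are illustrative, not code added by this PR):

```python
import logging

logger = logging.getLogger(__name__)


def warn_if_unavailable(provider_resource_id: str, available_models: list[str], list_models_url: str) -> None:
    """Log a warning, instead of raising, when a resource is not served by the endpoint."""
    # The adapter stores IDs prefixed with "sambanova/", while the endpoint lists bare IDs.
    if provider_resource_id.split("sambanova/")[-1] not in available_models:
        logger.warning(f"Model {provider_resource_id} not available in {list_models_url}")


# Registration no longer fails here; a missing model only produces a log line.
warn_if_unavailable("sambanova/Meta-Llama-3.1-8B-Instruct", [], "https://api.sambanova.ai/v1/models")
```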


## Test Plan
Run the starter distro with SambaNova enabled.
Jorge Piedrahita Ortiz 2025-07-11 15:29:15 -05:00 committed by GitHub
parent 30b2e6a495
commit aa2595c7c3
3 changed files with 31 additions and 12 deletions

File 1 of 3: SambaNova inference adapter

@@ -7,6 +7,7 @@
 import json
 from collections.abc import Iterable
+import requests
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
 )
@@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
     def __init__(self, config: SambaNovaImplConfig):
         self.config = config
+        self.environment_available_models = []
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            api_key_from_config=self.config.api_key,
+            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
         )
@@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             **get_sampling_options(request.sampling_params),
         }

+    async def register_model(self, model: Model) -> Model:
+        model_id = self.get_provider_model_id(model.provider_resource_id)
+        list_models_url = self.config.url + "/models"
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
+        if model_id.split("sambanova/")[-1] not in self.environment_available_models:
+            logger.warning(f"Model {model_id} not available in {list_models_url}")
+        return model
+
     async def initialize(self):
         await super().initialize()
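
Note on the new `register_model`: the adapter fetches `<base_url>/models` only once, caches the returned IDs in `environment_available_models`, and later registrations just compare the stripped provider model ID against that cache, warning (not raising) on a miss. A rough self-contained sketch of the same lazy-lookup pattern, assuming an OpenAI-style `/models` response; `ModelCatalog` and its names are illustrative, not part of the adapter:

```python
import logging

import requests

logger = logging.getLogger(__name__)


class ModelCatalog:
    """Illustrative helper mirroring the adapter's lazy /models lookup."""

    def __init__(self, base_url: str) -> None:
        self.list_models_url = base_url + "/models"
        self.available_models: list[str] = []

    def is_available(self, model_id: str) -> bool:
        # Hit the endpoint only on the first check; reuse the cached IDs afterwards.
        if not self.available_models:
            try:
                response = requests.get(self.list_models_url)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                raise RuntimeError(f"Request to {self.list_models_url} failed") from e
            self.available_models = [m.get("id") for m in response.json().get("data", [])]
        # Model IDs are stored with a "sambanova/" prefix; the endpoint lists bare IDs.
        return model_id.split("sambanova/")[-1] in self.available_models


# Usage (the URL and model name are examples; a real call may require an API key):
catalog = ModelCatalog("https://api.sambanova.ai/v1")
if not catalog.is_available("sambanova/Meta-Llama-3.1-8B-Instruct"):
    logger.warning("Model not available at the configured endpoint")
```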

File 2 of 3: SambaNova safety adapter

@@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
+        self.environment_available_models = []

     async def initialize(self) -> None:
         pass
@@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     async def register_shield(self, shield: Shield) -> None:
         list_models_url = self.config.url + "/models"
-        try:
-            response = requests.get(list_models_url)
-            response.raise_for_status()
-        except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Request to {list_models_url} failed") from e
-        available_models = [model.get("id") for model in response.json().get("data", {})]
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
         if (
-            len(available_models) == 0
-            or "guard" not in shield.provider_resource_id.lower()
-            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
+            "guard" not in shield.provider_resource_id.lower()
+            or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
         ):
-            raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
+            logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")

     async def run_shield(
         self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
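
For context, both adapters parse an OpenAI-style model list, and a shield now only has to look like a guard model that the endpoint actually serves; otherwise `register_shield` logs a warning instead of raising. A small sketch of that check in isolation (the response shape is abridged and `is_shield_available` is an illustrative helper, not code from this PR):

```python
# Abridged shape of the GET {base_url}/models response the adapters parse:
# {"data": [{"id": "Meta-Llama-Guard-3-8B"}, {"id": "Meta-Llama-3.1-8B-Instruct"}]}


def is_shield_available(provider_resource_id: str, environment_available_models: list[str]) -> bool:
    # Mirrors the inverted warning condition: valid only if the ID names a guard
    # model and the endpoint actually lists it.
    return (
        "guard" in provider_resource_id.lower()
        and provider_resource_id.split("sambanova/")[-1] in environment_available_models
    )


assert is_shield_available("sambanova/Meta-Llama-Guard-3-8B", ["Meta-Llama-Guard-3-8B"])
assert not is_shield_available("sambanova/Meta-Llama-3.1-8B-Instruct", ["Meta-Llama-3.1-8B-Instruct"])
```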

File 3 of 3: OpenAI chat completion support test

@@ -71,7 +71,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
         "remote::cerebras",
         "remote::databricks",
         "remote::runpod",
-        "remote::sambanova",
         "remote::tgi",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")