Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-14 17:16:09 +00:00
fix: sambanova shields and model validation (#2693)
# What does this PR do?

Updates the SambaNova shield registration validation so that it no longer raises when the shield's model is unavailable at the configured base URL endpoint, and instead only logs a warning. The inference adapter's model registration gains the same warning when a model is not listed by the endpoint.

## Test Plan

Run the starter distro with SambaNova enabled.
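A minimal sketch of the new validation pattern, assuming an OpenAI-compatible `/models` endpoint that returns `{"data": [{"id": ...}]}`; the standalone helper name is illustrative, and the real logic lives in the adapter diffs below:

```python
import logging

import requests

logger = logging.getLogger(__name__)


def warn_if_model_unavailable(base_url: str, model_id: str) -> None:
    """Warn, rather than raise, when a model is missing from the endpoint's catalog."""
    list_models_url = base_url + "/models"
    try:
        response = requests.get(list_models_url)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Connectivity failures are still fatal; only a missing model is downgraded to a warning.
        raise RuntimeError(f"Request to {list_models_url} failed") from e
    available = [m.get("id") for m in response.json().get("data", [])]
    if model_id not in available:
        logger.warning(f"Model {model_id} not available in {list_models_url}")
```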
parent 30b2e6a495
commit aa2595c7c3

3 changed files with 31 additions and 12 deletions
`SambaNovaInferenceAdapter` (remote inference provider):

```diff
@@ -7,6 +7,7 @@
 import json
 from collections.abc import Iterable
 
+import requests
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
 )
```
```diff
@@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
```
```diff
@@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
 
     def __init__(self, config: SambaNovaImplConfig):
         self.config = config
+        self.environment_available_models = []
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            api_key_from_config=self.config.api_key,
+            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
         )
 
```
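The `api_key_from_config` change unwraps a pydantic `SecretStr` before handing it to the mixin. A minimal sketch of that pattern (the config class here is illustrative, not the actual `SambaNovaImplConfig`):

```python
from pydantic import BaseModel, SecretStr


class ExampleConfig(BaseModel):
    api_key: SecretStr | None = None


config = ExampleConfig(api_key=SecretStr("sk-123"))
print(config.api_key)  # **********  (masked in repr/logs)
print(config.api_key.get_secret_value())  # sk-123  (the raw value the client needs)
print(ExampleConfig().api_key)  # None -- hence the `if self.config.api_key else None` guard
```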
```diff
@@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             **get_sampling_options(request.sampling_params),
         }
 
+    async def register_model(self, model: Model) -> Model:
+        model_id = self.get_provider_model_id(model.provider_resource_id)
+
+        list_models_url = self.config.url + "/models"
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
+
+        if model_id.split("sambanova/")[-1] not in self.environment_available_models:
+            logger.warning(f"Model {model_id} not available in {list_models_url}")
+        return model
+
     async def initialize(self):
         await super().initialize()
 
```
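The catalog is fetched once and cached on `self.environment_available_models`, so later registrations skip the HTTP round-trip. A minimal parsing sketch with a hypothetical payload in the OpenAI-compatible list format (the model IDs are illustrative):

```python
# Hypothetical GET {base_url}/models payload (OpenAI-compatible list format).
payload = {
    "object": "list",
    "data": [
        {"id": "Meta-Llama-3.3-70B-Instruct", "object": "model"},
        {"id": "Meta-Llama-Guard-3-8B", "object": "model"},
    ],
}

# Same extraction the adapter performs; cached so it runs at most once per process.
environment_available_models = [m.get("id") for m in payload.get("data", [])]
print(environment_available_models)
# ['Meta-Llama-3.3-70B-Instruct', 'Meta-Llama-Guard-3-8B']
```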
`SambaNovaSafetyAdapter` (remote safety provider):

```diff
@@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
+        self.environment_available_models = []
 
     async def initialize(self) -> None:
         pass
```
```diff
@@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
 
     async def register_shield(self, shield: Shield) -> None:
         list_models_url = self.config.url + "/models"
-        try:
-            response = requests.get(list_models_url)
-            response.raise_for_status()
-        except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Request to {list_models_url} failed") from e
-        available_models = [model.get("id") for model in response.json().get("data", {})]
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
         if (
-            len(available_models) == 0
-            or "guard" not in shield.provider_resource_id.lower()
-            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
+            "guard" not in shield.provider_resource_id.lower()
+            or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
         ):
-            raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
+            logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
 
     async def run_shield(
         self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
```
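With the relaxed check, inputs that previously raised `ValueError` now only produce a warning. A runnable sketch of the new condition, with hypothetical shield and URL values:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def validate_shield(provider_resource_id: str, available_models: list[str], list_models_url: str) -> None:
    """Sketch of the relaxed validation: a warning instead of a ValueError."""
    if (
        "guard" not in provider_resource_id.lower()
        or provider_resource_id.split("sambanova/")[-1] not in available_models
    ):
        logger.warning(f"Shield {provider_resource_id} not available in {list_models_url}")


# An empty catalog used to raise; now registration proceeds with a warning.
validate_shield("sambanova/Meta-Llama-Guard-3-8B", [], "https://api.sambanova.ai/v1/models")
```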
Integration test skip list (removes `remote::sambanova` from the providers skipped for OpenAI chat completion tests):

```diff
@@ -71,7 +71,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
         "remote::cerebras",
         "remote::databricks",
         "remote::runpod",
-        "remote::sambanova",
         "remote::tgi",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
```