diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 9be5763aa..aade2f726 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -178,6 +178,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
 
     def __init__(self, config: SambaNovaImplConfig):
         self.config = config
+        self.environment_available_models = []
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
@@ -250,15 +251,18 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
 
     async def register_model(self, model: Model) -> Model:
         model_id = self.get_provider_model_id(model.provider_resource_id)
+
         list_models_url = self.config.url + "/models"
-        try:
-            response = requests.get(list_models_url)
-            response.raise_for_status()
-        except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Request to {list_models_url} failed") from e
-        available_models = [model.get("id") for model in response.json().get("data", {})]
-        if len(available_models) == 0 or model_id.split("sambanova/")[-1] not in available_models:
-            logger.warning(f"Model {model_id} not available in {self.config.url}/models")
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
+
+        if model_id.split("sambanova/")[-1] not in self.environment_available_models:
+            logger.warning(f"Model {model_id} not available in {list_models_url}")
         return model
 
     async def initialize(self):
diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py
index 242da83a6..1a65f6aa1 100644
--- a/llama_stack/providers/remote/safety/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py
@@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
+        self.environment_available_models = []
 
     async def initialize(self) -> None:
         pass
@@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
 
     async def register_shield(self, shield: Shield) -> None:
         list_models_url = self.config.url + "/models"
-        try:
-            response = requests.get(list_models_url)
-            response.raise_for_status()
-        except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Request to {list_models_url} failed") from e
-        available_models = [model.get("id") for model in response.json().get("data", {})]
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
         if (
-            len(available_models) == 0
-            or "guard" not in shield.provider_resource_id.lower()
-            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
+            "guard" not in shield.provider_resource_id.lower()
+            or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
         ):
-            logger.warning(f"Shield {shield.provider_resource_id} not available in {self.config.url}/models")
+            logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
 
     async def run_shield(
         self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None