diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py
index 11f6f05ad..baf0ec316 100644
--- a/llama_stack/providers/remote/inference/gemini/gemini.py
+++ b/llama_stack/providers/remote/inference/gemini/gemini.py
@@ -17,6 +17,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
             MODEL_ENTRIES,
             api_key_from_config=config.api_key,
             provider_data_api_key_field="gemini_api_key",
+            litellm_provider_name="gemini",
         )
         self.config = config
 
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 91c6b6c17..5ad4bf49c 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -36,6 +36,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
             model_entries=MODEL_ENTRIES,
             api_key_from_config=config.api_key,
             provider_data_api_key_field="groq_api_key",
+            litellm_provider_name="groq",
         )
         self.config = config
 
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 9c2dda889..bdffb04bf 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -184,6 +184,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             model_entries=MODEL_ENTRIES,
             api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
+            litellm_provider_name="sambanova",
         )
 
     def _get_api_key(self) -> str:
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 0de267f6c..af12fc631 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -71,7 +71,17 @@ class LiteLLMOpenAIMixin(
         api_key_from_config: str | None,
         provider_data_api_key_field: str,
         openai_compat_api_base: str | None = None,
+        litellm_provider_name: str | None = None,
     ):
+        """
+        Initialize the LiteLLMOpenAIMixin.
+
+        :param model_entries: The model entries to register.
+        :param api_key_from_config: The API key to use from the config.
+        :param provider_data_api_key_field: The field in the provider data that contains the API key.
+        :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
+        :param litellm_provider_name: The name of the provider, used for model lookups.
+        """
         ModelRegistryHelper.__init__(self, model_entries)
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
@@ -82,6 +92,8 @@ class LiteLLMOpenAIMixin(
         else:
             self.is_openai_compat = False
 
+        self.litellm_provider_name = litellm_provider_name
+
     async def initialize(self):
         pass
 
@@ -421,3 +433,21 @@ class LiteLLMOpenAIMixin(
         logprobs: LogProbConfig | None = None,
     ):
         raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
+
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available via LiteLLM for the current
+        provider (self.litellm_provider_name).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        if not self.litellm_provider_name:
+            logger.warning("Provider name is not set, cannot check model availability.")
+            return False
+
+        if self.litellm_provider_name not in litellm.models_by_provider:
+            logger.warning(f"Provider {self.litellm_provider_name} is not registered in litellm.")
+            return False
+
+        return model in litellm.models_by_provider[self.litellm_provider_name]