From bf63470c22f3101bb8f390d11141d73c1ac4abbd Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 24 Jul 2025 09:49:32 -0400 Subject: [PATCH] feat: implement dynamic model detection support for inference providers using litellm This enhancement allows inference providers using LiteLLMOpenAIMixin to validate model availability against LiteLLM's official provider model listings, improving reliability and user experience when working with different AI service providers. - Add litellm_provider_name parameter to LiteLLMOpenAIMixin constructor - Add check_model_availability method to LiteLLMOpenAIMixin using litellm.models_by_provider - Update Gemini, Groq, and SambaNova inference adapters to pass litellm_provider_name --- .../remote/inference/gemini/gemini.py | 1 + .../providers/remote/inference/groq/groq.py | 1 + .../remote/inference/sambanova/sambanova.py | 1 + .../utils/inference/litellm_openai_mixin.py | 30 +++++++++++++++++++ 4 files changed, 33 insertions(+) diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index 11f6f05ad..baf0ec316 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -17,6 +17,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin): MODEL_ENTRIES, api_key_from_config=config.api_key, provider_data_api_key_field="gemini_api_key", + litellm_provider_name="gemini", ) self.config = config diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 91c6b6c17..5ad4bf49c 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -36,6 +36,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin): model_entries=MODEL_ENTRIES, api_key_from_config=config.api_key, provider_data_api_key_field="groq_api_key", + litellm_provider_name="groq", ) self.config = config diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 9c2dda889..bdffb04bf 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -184,6 +184,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): model_entries=MODEL_ENTRIES, api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None, provider_data_api_key_field="sambanova_api_key", + litellm_provider_name="sambanova", ) def _get_api_key(self) -> str: diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 0de267f6c..af12fc631 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -71,7 +71,17 @@ class LiteLLMOpenAIMixin( api_key_from_config: str | None, provider_data_api_key_field: str, openai_compat_api_base: str | None = None, + litellm_provider_name: str | None = None, ): + """ + Initialize the LiteLLMOpenAIMixin. + + :param model_entries: The model entries to register. + :param api_key_from_config: The API key to use from the config. + :param provider_data_api_key_field: The field in the provider data that contains the API key. + :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility. + :param litellm_provider_name: The name of the provider, used for model lookups. + """ ModelRegistryHelper.__init__(self, model_entries) self.api_key_from_config = api_key_from_config self.provider_data_api_key_field = provider_data_api_key_field @@ -82,6 +92,8 @@ class LiteLLMOpenAIMixin( else: self.is_openai_compat = False + self.litellm_provider_name = litellm_provider_name + async def initialize(self): pass @@ -421,3 +433,21 @@ class LiteLLMOpenAIMixin( logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") + + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available via LiteLLM for the current + provider (self.litellm_provider_name). + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + if not self.litellm_provider_name: + logger.warning("Provider name is not set, cannot check model availability.") + return False + + if self.litellm_provider_name not in litellm.models_by_provider: + logger.warning(f"Provider {self.litellm_provider_name} is not registered in litellm.") + return False + + return model in litellm.models_by_provider[self.litellm_provider_name]