From 3543b2a808330f11e31ad9c2875b7d7ba4bd71b3 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Thu, 27 Mar 2025 22:50:48 -0700 Subject: [PATCH] Support discovering gemini, anthropic, xai models by calling their `/v1/models` endpoint (#9530) * fix: initial commit for adding provider model discovery to gemini * feat(gemini/): add model discovery for gemini/ route * docs(set_keys.md): update docs to show you can check available gemini models as well * feat(anthropic/): add model discovery for anthropic api key * feat(xai/): add model discovery for XAI enables checking what models an xai key can call * ci: bump ci config yml * fix(topaz/common_utils.py): fix linting error * fix: fix linting error for python38 --- docs/my-website/docs/set_keys.md | 14 +++++++ litellm/__init__.py | 3 ++ litellm/llms/anthropic/common_utils.py | 51 ++++++++++++++++++++++++ litellm/llms/base_llm/base_utils.py | 12 +++++- litellm/llms/gemini/common_utils.py | 52 +++++++++++++++++++++++++ litellm/llms/topaz/common_utils.py | 4 +- litellm/llms/xai/common_utils.py | 51 ++++++++++++++++++++++++ litellm/utils.py | 24 ++++++++++-- tests/litellm_utils_tests/test_utils.py | 18 +++++++++ 9 files changed, 223 insertions(+), 6 deletions(-) create mode 100644 litellm/llms/gemini/common_utils.py create mode 100644 litellm/llms/xai/common_utils.py diff --git a/docs/my-website/docs/set_keys.md b/docs/my-website/docs/set_keys.md index 3a5ff08d63..693cf5f7f4 100644 --- a/docs/my-website/docs/set_keys.md +++ b/docs/my-website/docs/set_keys.md @@ -188,7 +188,13 @@ Currently implemented for: - OpenAI (if OPENAI_API_KEY is set) - Fireworks AI (if FIREWORKS_AI_API_KEY is set) - LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set) +- Gemini (if GEMINI_API_KEY is set) +- XAI (if XAI_API_KEY is set) +- Anthropic (if ANTHROPIC_API_KEY is set) +You can also specify a custom provider to check: + +**All providers**: ```python from litellm import get_valid_models @@ -196,6 +202,14 @@ valid_models = get_valid_models(check_provider_endpoint=True) print(valid_models) ``` +**Specific provider**: +```python +from litellm import get_valid_models + +valid_models = get_valid_models(check_provider_endpoint=True, custom_llm_provider="openai") +print(valid_models) +``` + ### `validate_environment(model: str)` This helper tells you if you have all the required environment variables for a model, and if not - what's missing.
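As a quick illustration of the documented behaviour above, here is a minimal sketch of the new discovery flow for one of the providers this patch adds, assuming a valid GEMINI_API_KEY is already exported in the environment:

```python
from litellm import get_valid_models

# Assumes GEMINI_API_KEY is set in the environment (e.g. `export GEMINI_API_KEY=...`).
# With check_provider_endpoint=True, litellm queries Gemini's `/models` endpoint and
# returns the discovered model names prefixed with "gemini/".
valid_models = get_valid_models(
    check_provider_endpoint=True, custom_llm_provider="gemini"
)
print(valid_models)
```

The same call works with `custom_llm_provider="anthropic"` or `custom_llm_provider="xai"` once the corresponding API key is set.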
diff --git a/litellm/__init__.py b/litellm/__init__.py index 8cdde24a6a..a4903f828c 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -813,6 +813,7 @@ from .llms.oobabooga.chat.transformation import OobaboogaConfig from .llms.maritalk import MaritalkConfig from .llms.openrouter.chat.transformation import OpenrouterConfig from .llms.anthropic.chat.transformation import AnthropicConfig +from .llms.anthropic.common_utils import AnthropicModelInfo from .llms.groq.stt.transformation import GroqSTTConfig from .llms.anthropic.completion.transformation import AnthropicTextConfig from .llms.triton.completion.transformation import TritonConfig @@ -848,6 +849,7 @@ from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( VertexGeminiConfig, VertexGeminiConfig as VertexAIConfig, ) +from .llms.gemini.common_utils import GeminiModelInfo from .llms.gemini.chat.transformation import ( GoogleAIStudioGeminiConfig, GoogleAIStudioGeminiConfig as GeminiConfig, # aliased to maintain backwards compatibility @@ -984,6 +986,7 @@ from .llms.fireworks_ai.embed.fireworks_ai_transformation import ( from .llms.friendliai.chat.transformation import FriendliaiChatConfig from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig from .llms.xai.chat.transformation import XAIChatConfig +from .llms.xai.common_utils import XAIModelInfo from .llms.volcengine import VolcEngineConfig from .llms.codestral.completion.transformation import CodestralTextCompletionConfig from .llms.azure.azure import ( diff --git a/litellm/llms/anthropic/common_utils.py b/litellm/llms/anthropic/common_utils.py index 409bbe2d82..52a96f5a30 100644 --- a/litellm/llms/anthropic/common_utils.py +++ b/litellm/llms/anthropic/common_utils.py @@ -6,7 +6,10 @@ from typing import Optional, Union import httpx +import litellm +from litellm.llms.base_llm.base_utils import BaseLLMModelInfo from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.secret_managers.main import get_secret_str class AnthropicError(BaseLLMException): @@ -19,6 +22,54 @@ class AnthropicError(BaseLLMException): super().__init__(status_code=status_code, message=message, headers=headers) +class AnthropicModelInfo(BaseLLMModelInfo): + @staticmethod + def get_api_base(api_base: Optional[str] = None) -> str | None: + return ( + api_base + or get_secret_str("ANTHROPIC_API_BASE") + or "https://api.anthropic.com" + ) + + @staticmethod + def get_api_key(api_key: str | None = None) -> str | None: + return api_key or get_secret_str("ANTHROPIC_API_KEY") + + @staticmethod + def get_base_model(model: str) -> str | None: + return model.replace("anthropic/", "") + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> list[str]: + api_base = AnthropicModelInfo.get_api_base(api_base) + api_key = AnthropicModelInfo.get_api_key(api_key) + if api_base is None or api_key is None: + raise ValueError( + "ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint." + ) + response = litellm.module_level_client.get( + url=f"{api_base}/v1/models", + headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"}, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError: + raise Exception( + f"Failed to fetch models from Anthropic. 
Status code: {response.status_code}, Response: {response.text}" + ) + + models = response.json()["data"] + + litellm_model_names = [] + for model in models: + stripped_model_name = model["id"] + litellm_model_name = "anthropic/" + stripped_model_name + litellm_model_names.append(litellm_model_name) + return litellm_model_names + + def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict: openai_headers = {} if "anthropic-ratelimit-requests-limit" in headers: diff --git a/litellm/llms/base_llm/base_utils.py b/litellm/llms/base_llm/base_utils.py index 919cdbfd02..cef64d01e3 100644 --- a/litellm/llms/base_llm/base_utils.py +++ b/litellm/llms/base_llm/base_utils.py @@ -19,11 +19,19 @@ class BaseLLMModelInfo(ABC): self, model: str, ) -> Optional[ProviderSpecificModelInfo]: + """ + Default values all models of this provider support. + """ return None @abstractmethod - def get_models(self) -> List[str]: - pass + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: + """ + Returns a list of models supported by this provider. + """ + return [] @staticmethod @abstractmethod diff --git a/litellm/llms/gemini/common_utils.py b/litellm/llms/gemini/common_utils.py new file mode 100644 index 0000000000..7f266c0536 --- /dev/null +++ b/litellm/llms/gemini/common_utils.py @@ -0,0 +1,52 @@ +from typing import List, Optional + +import litellm +from litellm.llms.base_llm.base_utils import BaseLLMModelInfo +from litellm.secret_managers.main import get_secret_str + + +class GeminiModelInfo(BaseLLMModelInfo): + @staticmethod + def get_api_base(api_base: Optional[str] = None) -> Optional[str]: + return ( + api_base + or get_secret_str("GEMINI_API_BASE") + or "https://generativelanguage.googleapis.com/v1beta" + ) + + @staticmethod + def get_api_key(api_key: Optional[str] = None) -> Optional[str]: + return api_key or (get_secret_str("GEMINI_API_KEY")) + + @staticmethod + def get_base_model(model: str) -> Optional[str]: + return model.replace("gemini/", "") + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: + + api_base = GeminiModelInfo.get_api_base(api_base) + api_key = GeminiModelInfo.get_api_key(api_key) + if api_base is None or api_key is None: + raise ValueError( + "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint." + ) + + response = litellm.module_level_client.get( + url=f"{api_base}/models?key={api_key}", + ) + + if response.status_code != 200: + raise ValueError( + f"Failed to fetch models from Gemini. 
Status code: {response.status_code}, Response: {response.json()}" ) + + models = response.json()["models"] + + litellm_model_names = [] + for model in models: + stripped_model_name = model["name"].replace("models/", "", 1) + litellm_model_name = "gemini/" + stripped_model_name + litellm_model_names.append(litellm_model_name) + return litellm_model_names diff --git a/litellm/llms/topaz/common_utils.py b/litellm/llms/topaz/common_utils.py index 4ef2315db4..0252585922 100644 --- a/litellm/llms/topaz/common_utils.py +++ b/litellm/llms/topaz/common_utils.py @@ -11,7 +11,9 @@ class TopazException(BaseLLMException): class TopazModelInfo(BaseLLMModelInfo): - def get_models(self) -> List[str]: + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: return [ "topaz/Standard V2", "topaz/Low Resolution V2", diff --git a/litellm/llms/xai/common_utils.py b/litellm/llms/xai/common_utils.py new file mode 100644 index 0000000000..fdf2edbfa3 --- /dev/null +++ b/litellm/llms/xai/common_utils.py @@ -0,0 +1,51 @@ +from typing import Optional + +import httpx + +import litellm +from litellm.llms.base_llm.base_utils import BaseLLMModelInfo +from litellm.secret_managers.main import get_secret_str + + +class XAIModelInfo(BaseLLMModelInfo): + @staticmethod + def get_api_base(api_base: Optional[str] = None) -> Optional[str]: + return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai" + + @staticmethod + def get_api_key(api_key: Optional[str] = None) -> Optional[str]: + return api_key or get_secret_str("XAI_API_KEY") + + @staticmethod + def get_base_model(model: str) -> Optional[str]: + return model.replace("xai/", "") + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> list[str]: + api_base = self.get_api_base(api_base) + api_key = self.get_api_key(api_key) + if api_base is None or api_key is None: + raise ValueError( + "XAI_API_BASE or XAI_API_KEY is not set. Please set the environment variable, to query XAI's `/models` endpoint." + ) + response = litellm.module_level_client.get( + url=f"{api_base}/v1/models", + headers={"Authorization": f"Bearer {api_key}"}, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError: + raise Exception( + f"Failed to fetch models from XAI. Status code: {response.status_code}, Response: {response.text}" + ) + + models = response.json()["data"] + + litellm_model_names = [] + for model in models: + stripped_model_name = model["id"] + litellm_model_name = "xai/" + stripped_model_name + litellm_model_names.append(litellm_model_name) + return litellm_model_names diff --git a/litellm/utils.py b/litellm/utils.py index 3fcb4a803a..3c8b6667f9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5744,13 +5744,15 @@ def trim_messages( return messages -def get_valid_models(check_provider_endpoint: bool = False) -> List[str]: +def get_valid_models( + check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None +) -> List[str]: """ Returns a list of valid LLMs based on the set environment variables Args: check_provider_endpoint: If True, will check the provider's endpoint for valid models. - + custom_llm_provider: If provided, will only check the provider's endpoint for valid models.
Returns: A list of valid LLMs """ @@ -5762,6 +5764,9 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]: valid_models = [] for provider in litellm.provider_list: + if custom_llm_provider and provider != custom_llm_provider: + continue + # edge case litellm has together_ai as a provider, it should be togetherai env_provider_1 = provider.replace("_", "") env_provider_2 = provider @@ -5783,10 +5788,17 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]: provider=LlmProviders(provider), ) + if custom_llm_provider and provider != custom_llm_provider: + continue + if provider == "azure": valid_models.append("Azure-LLM") elif provider_config is not None and check_provider_endpoint: - valid_models.extend(provider_config.get_models()) + try: + models = provider_config.get_models() + valid_models.extend(models) + except Exception as e: + verbose_logger.debug(f"Error getting valid models: {e}") else: models_for_provider = litellm.models_by_provider.get(provider, []) valid_models.extend(models_for_provider) @@ -6400,10 +6412,16 @@ class ProviderConfigManager: return litellm.FireworksAIConfig() elif LlmProviders.OPENAI == provider: return litellm.OpenAIGPTConfig() + elif LlmProviders.GEMINI == provider: + return litellm.GeminiModelInfo() elif LlmProviders.LITELLM_PROXY == provider: return litellm.LiteLLMProxyChatConfig() elif LlmProviders.TOPAZ == provider: return litellm.TopazModelInfo() + elif LlmProviders.ANTHROPIC == provider: + return litellm.AnthropicModelInfo() + elif LlmProviders.XAI == provider: + return litellm.XAIModelInfo() return None diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py index 535861ce1a..3088fa250f 100644 --- a/tests/litellm_utils_tests/test_utils.py +++ b/tests/litellm_utils_tests/test_utils.py @@ -303,6 +303,24 @@ def test_aget_valid_models(): os.environ = old_environ +@pytest.mark.parametrize("custom_llm_provider", ["gemini", "anthropic", "xai"]) +def test_get_valid_models_with_custom_llm_provider(custom_llm_provider): + from litellm.utils import ProviderConfigManager + from litellm.types.utils import LlmProviders + + provider_config = ProviderConfigManager.get_provider_model_info( + model=None, + provider=LlmProviders(custom_llm_provider), + ) + assert provider_config is not None + valid_models = get_valid_models( + check_provider_endpoint=True, custom_llm_provider=custom_llm_provider + ) + print(valid_models) + assert len(valid_models) > 0 + assert provider_config.get_models() == valid_models + + # test_get_valid_models()
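For completeness, a minimal sketch mirroring the parametrized test above: it resolves the provider's model-info config through `ProviderConfigManager` and compares the endpoint listing with the output of `get_valid_models`. It assumes ANTHROPIC_API_KEY is set in the environment; the same pattern applies to the `gemini` and `xai` cases.

```python
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager, get_valid_models

# Resolve the AnthropicModelInfo config registered by this patch.
provider_config = ProviderConfigManager.get_provider_model_info(
    model=None, provider=LlmProviders("anthropic")
)
assert provider_config is not None

# Query Anthropic's /v1/models endpoint directly (requires ANTHROPIC_API_KEY)...
models_from_endpoint = provider_config.get_models()

# ...and via the public helper, restricted to the anthropic provider.
models_from_helper = get_valid_models(
    check_provider_endpoint=True, custom_llm_provider="anthropic"
)

assert models_from_endpoint == models_from_helper
print(models_from_endpoint)
```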