mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Support discovering gemini, anthropic, xai models by calling their /v1/models endpoint (#9530)

* fix: initial commit for adding provider model discovery to gemini
* feat(gemini/): add model discovery for gemini/ route
* docs(set_keys.md): update docs to show you can check available gemini models as well
* feat(anthropic/): add model discovery for anthropic api key
* feat(xai/): add model discovery for XAI - enables checking what models an xai key can call
* ci: bump ci config yml
* fix(topaz/common_utils.py): fix linting error
* fix: fix linting error for python38
This commit is contained in:
parent ff4419e5ee
commit 3543b2a808

9 changed files with 223 additions and 6 deletions
@@ -188,7 +188,13 @@ Currently implemented for:
 - OpenAI (if OPENAI_API_KEY is set)
 - Fireworks AI (if FIREWORKS_AI_API_KEY is set)
 - LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set)
+- Gemini (if GEMINI_API_KEY is set)
+- XAI (if XAI_API_KEY is set)
+- Anthropic (if ANTHROPIC_API_KEY is set)
+
+You can also specify a custom provider to check:
+
+**All providers**:
 ```python
 from litellm import get_valid_models
 
@@ -196,6 +202,14 @@ valid_models = get_valid_models(check_provider_endpoint=True)
 print(valid_models)
 ```
 
+**Specific provider**:
+```python
+from litellm import get_valid_models
+
+valid_models = get_valid_models(check_provider_endpoint=True, custom_llm_provider="openai")
+print(valid_models)
+```
+
 ### `validate_environment(model: str)`
 
 This helper tells you if you have all the required environment variables for a model, and if not - what's missing.
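Taken together with the existing providers, this docs change means discovery for Gemini, XAI, and Anthropic is gated on their respective API-key environment variables. As a rough sketch of how a caller might probe only the newly supported providers (it assumes the relevant keys are exported; the returned model lists depend on the account):

```python
import os

from litellm import get_valid_models

# Only query providers whose API keys are actually set in the environment.
for provider in ("gemini", "anthropic", "xai"):
    if os.getenv(f"{provider.upper()}_API_KEY"):
        models = get_valid_models(
            check_provider_endpoint=True, custom_llm_provider=provider
        )
        print(f"{provider}: {len(models)} models discovered")
```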
@@ -813,6 +813,7 @@ from .llms.oobabooga.chat.transformation import OobaboogaConfig
 from .llms.maritalk import MaritalkConfig
 from .llms.openrouter.chat.transformation import OpenrouterConfig
 from .llms.anthropic.chat.transformation import AnthropicConfig
+from .llms.anthropic.common_utils import AnthropicModelInfo
 from .llms.groq.stt.transformation import GroqSTTConfig
 from .llms.anthropic.completion.transformation import AnthropicTextConfig
 from .llms.triton.completion.transformation import TritonConfig
@@ -848,6 +849,7 @@ from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
     VertexGeminiConfig,
     VertexGeminiConfig as VertexAIConfig,
 )
+from .llms.gemini.common_utils import GeminiModelInfo
 from .llms.gemini.chat.transformation import (
     GoogleAIStudioGeminiConfig,
     GoogleAIStudioGeminiConfig as GeminiConfig,  # aliased to maintain backwards compatibility
@@ -984,6 +986,7 @@ from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
 from .llms.friendliai.chat.transformation import FriendliaiChatConfig
 from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
 from .llms.xai.chat.transformation import XAIChatConfig
+from .llms.xai.common_utils import XAIModelInfo
 from .llms.volcengine import VolcEngineConfig
 from .llms.codestral.completion.transformation import CodestralTextCompletionConfig
 from .llms.azure.azure import (
@@ -6,7 +6,10 @@ from typing import Optional, Union
 
 import httpx
 
+import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.secret_managers.main import get_secret_str
 
 
 class AnthropicError(BaseLLMException):
@@ -19,6 +22,54 @@ class AnthropicError(BaseLLMException):
         super().__init__(status_code=status_code, message=message, headers=headers)
 
 
+class AnthropicModelInfo(BaseLLMModelInfo):
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> str | None:
+        return (
+            api_base
+            or get_secret_str("ANTHROPIC_API_BASE")
+            or "https://api.anthropic.com"
+        )
+
+    @staticmethod
+    def get_api_key(api_key: str | None = None) -> str | None:
+        return api_key or get_secret_str("ANTHROPIC_API_KEY")
+
+    @staticmethod
+    def get_base_model(model: str) -> str | None:
+        return model.replace("anthropic/", "")
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> list[str]:
+        api_base = AnthropicModelInfo.get_api_base(api_base)
+        api_key = AnthropicModelInfo.get_api_key(api_key)
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint."
+            )
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/models",
+            headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
+        )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError:
+            raise Exception(
+                f"Failed to fetch models from Anthropic. Status code: {response.status_code}, Response: {response.text}"
+            )
+
+        models = response.json()["data"]
+
+        litellm_model_names = []
+        for model in models:
+            stripped_model_name = model["id"]
+            litellm_model_name = "anthropic/" + stripped_model_name
+            litellm_model_names.append(litellm_model_name)
+        return litellm_model_names
+
+
 def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
     if "anthropic-ratelimit-requests-limit" in headers:
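For reference, the new class can also be exercised on its own, outside of get_valid_models. A minimal sketch (it assumes ANTHROPIC_API_KEY is set; the printed IDs are illustrative):

```python
from litellm.llms.anthropic.common_utils import AnthropicModelInfo

# With no arguments, get_models falls back to the ANTHROPIC_API_BASE /
# ANTHROPIC_API_KEY environment variables and hits GET /v1/models.
info = AnthropicModelInfo()
models = info.get_models()
print(models)  # e.g. ["anthropic/claude-3-7-sonnet-20250219", ...] (illustrative)
```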
@@ -19,11 +19,19 @@ class BaseLLMModelInfo(ABC):
         self,
         model: str,
     ) -> Optional[ProviderSpecificModelInfo]:
+        """
+        Default values all models of this provider support.
+        """
         return None
 
     @abstractmethod
-    def get_models(self) -> List[str]:
-        pass
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        """
+        Returns a list of models supported by this provider.
+        """
+        return []
 
     @staticmethod
     @abstractmethod
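This interface change is what the Anthropic, Gemini, and XAI classes in this commit implement: get_models now accepts optional explicit credentials instead of only reading the environment, and returns an empty list by default. A sketch of a conforming provider-info subclass under the new signature (the provider name, URL, and model IDs here are made up purely for illustration):

```python
from typing import List, Optional

from litellm.llms.base_llm.base_utils import BaseLLMModelInfo


class ExampleModelInfo(BaseLLMModelInfo):  # hypothetical provider, for illustration only
    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        return api_base or "https://api.example.invalid"

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        return api_key

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        return model.replace("example/", "")

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        # A real implementation would call the provider's /models endpoint here.
        return ["example/model-a", "example/model-b"]
```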
litellm/llms/gemini/common_utils.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from typing import List, Optional
+
+import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.secret_managers.main import get_secret_str
+
+
+class GeminiModelInfo(BaseLLMModelInfo):
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        return (
+            api_base
+            or get_secret_str("GEMINI_API_BASE")
+            or "https://generativelanguage.googleapis.com/v1beta"
+        )
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return api_key or (get_secret_str("GEMINI_API_KEY"))
+
+    @staticmethod
+    def get_base_model(model: str) -> Optional[str]:
+        return model.replace("gemini/", "")
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+
+        api_base = GeminiModelInfo.get_api_base(api_base)
+        api_key = GeminiModelInfo.get_api_key(api_key)
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
+            )
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/models?key={api_key}",
+        )
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Failed to fetch models from Gemini. Status code: {response.status_code}, Response: {response.json()}"
+            )
+
+        models = response.json()["models"]
+
+        litellm_model_names = []
+        for model in models:
+            stripped_model_name = model["name"].strip("models/")
+            litellm_model_name = "gemini/" + stripped_model_name
+            litellm_model_names.append(litellm_model_name)
+        return litellm_model_names
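One subtlety worth noting in the loop above: str.strip("models/") removes any leading and trailing characters drawn from the set m, o, d, e, l, s, / rather than the literal "models/" prefix, so model IDs ending in one of those characters can lose trailing characters. A prefix-only removal, shown here purely as a sketch (the commit keeps the strip call as written), would look like:

```python
def strip_models_prefix(name: str) -> str:
    # Remove only a literal leading "models/", leaving the rest of the ID untouched.
    prefix = "models/"
    return name[len(prefix):] if name.startswith(prefix) else name


print("models/gemini-1.5-pro".strip("models/"))      # "gemini-1.5-pr" (trailing "o" stripped)
print(strip_models_prefix("models/gemini-1.5-pro"))  # "gemini-1.5-pro"
```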
@@ -11,7 +11,9 @@ class TopazException(BaseLLMException):
 
 
 class TopazModelInfo(BaseLLMModelInfo):
-    def get_models(self) -> List[str]:
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
         return [
             "topaz/Standard V2",
             "topaz/Low Resolution V2",
litellm/llms/xai/common_utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+from typing import Optional
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
+from litellm.secret_managers.main import get_secret_str
+
+
+class XAIModelInfo(BaseLLMModelInfo):
+    @staticmethod
+    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"
+
+    @staticmethod
+    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
+        return api_key or get_secret_str("XAI_API_KEY")
+
+    @staticmethod
+    def get_base_model(model: str) -> Optional[str]:
+        return model.replace("xai/", "")
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> list[str]:
+        api_base = self.get_api_base(api_base)
+        api_key = self.get_api_key(api_key)
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "XAI_API_BASE or XAI_API_KEY is not set. Please set the environment variable, to query XAI's `/models` endpoint."
+            )
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        try:
+            response.raise_for_status()
+        except httpx.HTTPStatusError:
+            raise Exception(
+                f"Failed to fetch models from XAI. Status code: {response.status_code}, Response: {response.text}"
+            )
+
+        models = response.json()["data"]
+
+        litellm_model_names = []
+        for model in models:
+            stripped_model_name = model["id"]
+            litellm_model_name = "xai/" + stripped_model_name
+            litellm_model_names.append(litellm_model_name)
+        return litellm_model_names
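For a quick sanity check outside of litellm, the same request get_models issues can be reproduced with httpx directly. A sketch, assuming XAI_API_KEY is exported and the default https://api.x.ai base:

```python
import os

import httpx

# Mirrors XAIModelInfo.get_models: GET {api_base}/v1/models with a Bearer token,
# then prefix each returned id with "xai/".
resp = httpx.get(
    "https://api.x.ai/v1/models",
    headers={"Authorization": f"Bearer {os.environ['XAI_API_KEY']}"},
)
resp.raise_for_status()
print(["xai/" + m["id"] for m in resp.json()["data"]])
```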
@@ -5744,13 +5744,15 @@ def trim_messages(
     return messages
 
 
-def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
+def get_valid_models(
+    check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
+) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
 
     Args:
         check_provider_endpoint: If True, will check the provider's endpoint for valid models.
+        custom_llm_provider: If provided, will only check the provider's endpoint for valid models.
     Returns:
         A list of valid LLMs
     """
@@ -5762,6 +5764,9 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
         valid_models = []
 
         for provider in litellm.provider_list:
+            if custom_llm_provider and provider != custom_llm_provider:
+                continue
+
             # edge case litellm has together_ai as a provider, it should be togetherai
             env_provider_1 = provider.replace("_", "")
             env_provider_2 = provider
@@ -5783,10 +5788,17 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
                 provider=LlmProviders(provider),
             )
 
+            if custom_llm_provider and provider != custom_llm_provider:
+                continue
+
             if provider == "azure":
                 valid_models.append("Azure-LLM")
             elif provider_config is not None and check_provider_endpoint:
-                valid_models.extend(provider_config.get_models())
+                try:
+                    models = provider_config.get_models()
+                    valid_models.extend(models)
+                except Exception as e:
+                    verbose_logger.debug(f"Error getting valid models: {e}")
             else:
                 models_for_provider = litellm.models_by_provider.get(provider, [])
                 valid_models.extend(models_for_provider)
@@ -6400,10 +6412,16 @@ class ProviderConfigManager:
             return litellm.FireworksAIConfig()
         elif LlmProviders.OPENAI == provider:
             return litellm.OpenAIGPTConfig()
+        elif LlmProviders.GEMINI == provider:
+            return litellm.GeminiModelInfo()
         elif LlmProviders.LITELLM_PROXY == provider:
            return litellm.LiteLLMProxyChatConfig()
         elif LlmProviders.TOPAZ == provider:
             return litellm.TopazModelInfo()
+        elif LlmProviders.ANTHROPIC == provider:
+            return litellm.AnthropicModelInfo()
+        elif LlmProviders.XAI == provider:
+            return litellm.XAIModelInfo()
 
         return None
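Two behavioural consequences of the utils.py change are worth spelling out: passing custom_llm_provider restricts discovery to a single provider, and wrapping get_models in try/except means one misconfigured provider no longer aborts discovery for the rest (the failure is only logged at debug level). A short sketch of both, assuming at least one provider key is set:

```python
from litellm import get_valid_models

# Restrict discovery to a single provider's endpoint.
anthropic_models = get_valid_models(
    check_provider_endpoint=True, custom_llm_provider="anthropic"
)

# Across all providers: any provider whose /models call fails (expired key,
# network error) is skipped rather than raising, so this still returns a list.
all_models = get_valid_models(check_provider_endpoint=True)
print(len(anthropic_models), len(all_models))
```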
@@ -303,6 +303,24 @@ def test_aget_valid_models():
     os.environ = old_environ
 
 
+@pytest.mark.parametrize("custom_llm_provider", ["gemini", "anthropic", "xai"])
+def test_get_valid_models_with_custom_llm_provider(custom_llm_provider):
+    from litellm.utils import ProviderConfigManager
+    from litellm.types.utils import LlmProviders
+
+    provider_config = ProviderConfigManager.get_provider_model_info(
+        model=None,
+        provider=LlmProviders(custom_llm_provider),
+    )
+    assert provider_config is not None
+    valid_models = get_valid_models(
+        check_provider_endpoint=True, custom_llm_provider=custom_llm_provider
+    )
+    print(valid_models)
+    assert len(valid_models) > 0
+    assert provider_config.get_models() == valid_models
+
+
 # test_get_valid_models()
 
 