Support checking provider-specific /models endpoints for available models based on key (#7538)

* test(test_utils.py): initial test for valid models

Addresses https://github.com/BerriAI/litellm/issues/7525

* fix: test

* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint

* refactor(fireworks_ai/): support checking model info on `/v1/models` route

* docs(set_keys.md): update docs to clarify usage of the LLM provider API check

* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for IAM auth

* fix(watsonx): read in watsonx token from env var

* fix: fix linting errors

* fix(utils.py): fix provider config check

* style: cleanup unused imports
Krish Dholakia 2025-01-03 19:29:59 -08:00 committed by GitHub
parent cac06a32b8
commit f770dd0c95
12 changed files with 350 additions and 42 deletions

View file

@@ -14,6 +14,7 @@ os.environ["WATSONX_TOKEN"] = "" # IAM auth token
# optional - can also be passed as params to completion() or embedding()
os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance
os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models
+os.environ["WATSONX_ZENAPIKEY"] = "" # Zen API key (use for long-term api token)
```
See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai.
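
For illustration (not part of the diff above), a minimal sketch of how the new variable might be used. The model and project IDs are placeholders, and any other required variables such as the watsonx endpoint URL are assumed to be set as described earlier in these docs:

```python
import os
import litellm

os.environ["WATSONX_ZENAPIKEY"] = "my-long-lived-zen-token"  # placeholder value
os.environ["WATSONX_PROJECT_ID"] = "my-project-id"           # placeholder value

# Per the change in this commit, the value of WATSONX_ZENAPIKEY is sent
# directly as the Bearer token instead of exchanging an API key via IAM.
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # placeholder model ID
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```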

View file

@@ -179,6 +179,22 @@ assert(valid_models == expected_models)
os.environ = old_environ
```
+
+### `get_valid_models(check_provider_endpoint=True)`
+
+With `check_provider_endpoint=True`, this helper queries the provider's `/models` endpoint for the models actually available to your key.
+
+Currently implemented for:
+- OpenAI (if OPENAI_API_KEY is set)
+- Fireworks AI (if FIREWORKS_AI_API_KEY is set)
+- LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set)
+
+```python
+from litellm import get_valid_models
+
+valid_models = get_valid_models(check_provider_endpoint=True)
+print(valid_models)
+```

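A brief sketch contrasting the two modes (not part of the docs change above; output depends on which provider keys are set in your environment):

```python
from litellm import get_valid_models

# Default: static list built from litellm's model map for each provider key found
default_models = get_valid_models()

# New: live list fetched from each supported provider's /models endpoint
live_models = get_valid_models(check_provider_endpoint=True)

print(len(default_models), len(live_models))
```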
### `validate_environment(model: str)`
This helper tells you if you have all the required environment variables for a model, and if not - what's missing.

View file

@@ -1168,6 +1168,7 @@ from .llms.azure.azure import (
from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
from .llms.azure.completion.transformation import AzureOpenAITextConfig
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig
from .llms.vllm.completion.transformation import VLLMConfig
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
from .llms.lm_studio.chat.transformation import LMStudioChatConfig

View file

@@ -488,11 +488,10 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
    elif custom_llm_provider == "fireworks_ai":
        # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
        (
-            model,
            api_base,
            dynamic_api_key,
        ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
-            model=model, api_base=api_base, api_key=api_key
+            api_base=api_base, api_key=api_key
        )
    elif custom_llm_provider == "azure_ai":
        (

View file

@@ -1,9 +1,18 @@
from abc import ABC, abstractmethod
+from typing import List, Optional

from litellm.types.utils import ModelInfoBase


class BaseLLMModelInfo(ABC):
    @abstractmethod
-    def get_model_info(self, model: str) -> ModelInfoBase:
+    def get_model_info(
+        self,
+        model: str,
+        existing_model_info: Optional[ModelInfoBase] = None,
+    ) -> Optional[ModelInfoBase]:
        pass
+
+    @abstractmethod
+    def get_models(self) -> List[str]:
+        pass
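
To make the updated contract concrete, here is a minimal sketch of a config implementing it (not part of the diff). The provider name, endpoint URL, and response shape are hypothetical; only the method signatures come from the interface above:

```python
from typing import List, Optional

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.types.utils import ModelInfoBase


class MyProviderModelInfo(BaseLLMModelInfo):
    """Hypothetical provider config implementing the new abstract methods."""

    def get_model_info(
        self,
        model: str,
        existing_model_info: Optional[ModelInfoBase] = None,
    ) -> Optional[ModelInfoBase]:
        # Prefer whatever litellm has already resolved (e.g. from its model map)
        if existing_model_info is not None:
            return existing_model_info
        return None  # nothing provider-specific to add

    def get_models(self) -> List[str]:
        # Query the provider's /models endpoint (hypothetical URL and payload)
        response = litellm.module_level_client.get(
            url="https://api.example-provider.com/v1/models",
            headers={"Authorization": "Bearer <api-key>"},
        )
        return ["my_provider/" + m["id"] for m in response.json()["data"]]
```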

View file

@@ -1,7 +1,6 @@
from typing import List, Literal, Optional, Tuple, Union, cast

import litellm
-from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
@@ -9,7 +8,7 @@ from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
from ...openai.chat.gpt_transformation import OpenAIGPTConfig


-class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
+class FireworksAIConfig(OpenAIGPTConfig):
    """
    Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions
@@ -209,8 +208,8 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
        )

    def _get_openai_compatible_provider_info(
-        self, model: str, api_base: Optional[str], api_key: Optional[str]
-    ) -> Tuple[str, Optional[str], Optional[str]]:
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = (
            api_base
            or get_secret_str("FIREWORKS_API_BASE")
@@ -222,4 +221,32 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
            or get_secret_str("FIREWORKSAI_API_KEY")
            or get_secret_str("FIREWORKS_AI_TOKEN")
        )
-        return model, api_base, dynamic_api_key
+        return api_base, dynamic_api_key
+
+    def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
+        api_base, api_key = self._get_openai_compatible_provider_info(
+            api_base=api_base, api_key=api_key
+        )
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
+            )
+
+        account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
+        if account_id is None:
+            raise ValueError(
+                "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
+            )
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/accounts/{account_id}/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
+            )
+
+        models = response.json()["models"]
+        return ["fireworks_ai/" + model["name"] for model in models]
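
A quick sketch of calling the new method directly (not part of the diff); the key and account ID values are placeholders:

```python
import os
import litellm

os.environ["FIREWORKS_API_KEY"] = "fw-..."         # placeholder key
os.environ["FIREWORKS_ACCOUNT_ID"] = "my-account"  # placeholder account id

# Fetches /v1/accounts/{account_id}/models and returns names like
# "fireworks_ai/accounts/<account>/models/<model>"
models = litellm.FireworksAIConfig().get_models()
print(models)
```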

View file

@@ -0,0 +1,29 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions`
+"""
+
+from typing import List, Optional, Tuple
+
+from litellm.secret_managers.main import get_secret_str
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class LiteLLMProxyChatConfig(OpenAIGPTConfig):
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE")  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
+        return api_base, dynamic_api_key
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        api_base, api_key = self._get_openai_compatible_provider_info(api_base, api_key)
+        if api_base is None:
+            raise ValueError(
+                "api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`"
+            )
+        models = super().get_models(api_key=api_key, api_base=api_base)
+        return [f"litellm_proxy/{model}" for model in models]
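
Similarly, a sketch of listing models from a running LiteLLM proxy (not part of the diff); the base URL and key are placeholders:

```python
import os
import litellm

os.environ["LITELLM_PROXY_API_BASE"] = "https://my-proxy.example.com"  # placeholder
os.environ["LITELLM_PROXY_API_KEY"] = "sk-..."                         # placeholder

# Reuses OpenAIGPTConfig.get_models() against the proxy's /v1/models endpoint
# and prefixes each entry, e.g. ["litellm_proxy/gpt-4o", ...]
models = litellm.LiteLLMProxyChatConfig().get_models()
print(models)
```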

View file

@@ -7,9 +7,11 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
import httpx

import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import ModelInfoBase, ModelResponse

from ..common_utils import OpenAIError
@@ -21,7 +23,7 @@ else:
    LiteLLMLoggingObj = Any


-class OpenAIGPTConfig(BaseConfig):
+class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create
@@ -229,3 +231,43 @@
        api_base: Optional[str] = None,
    ) -> dict:
        raise NotImplementedError
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        """
+        Calls OpenAI's `/v1/models` endpoint and returns the list of models.
+        """
+        if api_base is None:
+            api_base = "https://api.openai.com"
+        if api_key is None:
+            api_key = get_secret_str("OPENAI_API_KEY")
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Failed to get models: {response.text}")
+
+        models = response.json()["data"]
+        return [model["id"] for model in models]
+
+    def get_model_info(
+        self, model: str, existing_model_info: Optional[ModelInfoBase] = None
+    ) -> ModelInfoBase:
+        if existing_model_info is not None:
+            return existing_model_info
+        return ModelInfoBase(
+            key=model,
+            litellm_provider="openai",
+            mode="chat",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+            max_tokens=None,
+            max_input_tokens=None,
+            max_output_tokens=None,
+        )
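
For context, a small sketch of the `get_model_info` fallback added here (not part of the diff): when no existing info is passed in, a zero-cost chat stub keyed by the model name is returned rather than raising. The model name below is a placeholder:

```python
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig

config = OpenAIGPTConfig()

# No existing_model_info supplied, so the zero-cost stub is returned
info = config.get_model_info(model="some-unmapped-model")  # placeholder name
print(info)
```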

View file

@@ -175,7 +175,12 @@ class IBMWatsonXMixin:
        if "Authorization" in headers:
            return {**default_headers, **headers}
-        token = cast(Optional[str], optional_params.get("token"))
+        token = cast(
+            Optional[str],
+            optional_params.get("token")
+            or get_secret_str("WATSONX_ZENAPIKEY")
+            or get_secret_str("WATSONX_TOKEN"),
+        )
        if token:
            headers["Authorization"] = f"Bearer {token}"
        else:
@@ -245,6 +250,7 @@
        )
        token: Optional[str] = None
        if wx_credentials is not None:
+            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(

View file

@@ -4223,6 +4223,7 @@ def _get_model_info_helper(  # noqa: PLR0915
        _model_info: Optional[Dict[str, Any]] = None
        key: Optional[str] = None
+        provider_config: Optional[BaseLLMModelInfo] = None

        if combined_model_name in litellm.model_cost:
            key = combined_model_name
            _model_info = _get_model_info_from_model_cost(key=key)
@@ -4261,16 +4262,20 @@
                model_info=_model_info, custom_llm_provider=custom_llm_provider
            ):
                _model_info = None
-        if _model_info is None and ProviderConfigManager.get_provider_model_info(
-            model=model, provider=LlmProviders(custom_llm_provider)
-        ):
+        if custom_llm_provider:
            provider_config = ProviderConfigManager.get_provider_model_info(
                model=model, provider=LlmProviders(custom_llm_provider)
            )
-            if provider_config is not None:
-                _model_info = cast(
-                    dict, provider_config.get_model_info(model=model)
-                )
+
+        if _model_info is None and provider_config is not None:
+            _model_info = cast(
+                Optional[Dict],
+                provider_config.get_model_info(
+                    model=model, existing_model_info=_model_info
+                ),
+            )
+            if key is None:
+                key = "provider_specific_model_info"

        if _model_info is None or key is None:
            raise ValueError(
@@ -5706,12 +5711,12 @@ def trim_messages(
    return messages


-def get_valid_models() -> List[str]:
+def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
    """
    Returns a list of valid LLMs based on the set environment variables

    Args:
-        None
+        check_provider_endpoint: If True, will check the provider's endpoint for valid models.

    Returns:
        A list of valid LLMs
@@ -5725,22 +5730,36 @@
        for provider in litellm.provider_list:
            # edge case litellm has together_ai as a provider, it should be togetherai
-            provider = provider.replace("_", "")
+            env_provider_1 = provider.replace("_", "")
+            env_provider_2 = provider

            # litellm standardizes expected provider keys to
            # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key = f"{provider.upper()}_API_KEY"
-            if expected_provider_key in environ_keys:
+            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+            if (
+                expected_provider_key_1 in environ_keys
+                or expected_provider_key_2 in environ_keys
+            ):
                # key is set
                valid_providers.append(provider)

        for provider in valid_providers:
+            provider_config = ProviderConfigManager.get_provider_model_info(
+                model=None,
+                provider=LlmProviders(provider),
+            )
+
            if provider == "azure":
                valid_models.append("Azure-LLM")
+            elif provider_config is not None and check_provider_endpoint:
+                valid_models.extend(provider_config.get_models())
            else:
                models_for_provider = litellm.models_by_provider.get(provider, [])
                valid_models.extend(models_for_provider)

        return valid_models
-    except Exception:
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
        return []  # NON-Blocking
@@ -6291,11 +6310,14 @@
    @staticmethod
    def get_provider_model_info(
-        model: str,
+        model: Optional[str],
        provider: LlmProviders,
    ) -> Optional[BaseLLMModelInfo]:
        if LlmProviders.FIREWORKS_AI == provider:
            return litellm.FireworksAIConfig()
+        elif LlmProviders.LITELLM_PROXY == provider:
+            return litellm.LiteLLMProxyChatConfig()

        return None
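
Putting the pieces together, roughly how the endpoint check flows per provider (a sketch, not the exact implementation; the `LlmProviders` import path is assumed):

```python
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# get_valid_models(check_provider_endpoint=True) does roughly this for each
# provider whose API key is present in the environment:
provider_config = ProviderConfigManager.get_provider_model_info(
    model=None, provider=LlmProviders.FIREWORKS_AI
)
if provider_config is not None:
    # For Fireworks AI this raises unless FIREWORKS_API_KEY and
    # FIREWORKS_ACCOUNT_ID are set; otherwise it returns prefixed model names.
    models = provider_config.get_models()
    print(models)
```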

View file

@@ -23,31 +23,43 @@ def watsonx_chat_completion_call():
        api_key="test_api_key",
        headers=None,
        client=None,
+        patch_token_call=True,
    ):
        if messages is None:
            messages = [{"role": "user", "content": "Hello, how are you?"}]
        if client is None:
            client = HTTPHandler()

-        mock_response = Mock()
-        mock_response.json.return_value = {
-            "access_token": "mock_access_token",
-            "expires_in": 3600,
-        }
-        mock_response.raise_for_status = Mock()  # No-op to simulate no exception
+        if patch_token_call:
+            mock_response = Mock()
+            mock_response.json.return_value = {
+                "access_token": "mock_access_token",
+                "expires_in": 3600,
+            }
+            mock_response.raise_for_status = Mock()  # No-op to simulate no exception

-        with patch.object(client, "post") as mock_post, patch.object(
-            litellm.module_level_client, "post", return_value=mock_response
-        ) as mock_get:
-            completion(
-                model=model,
-                messages=messages,
-                api_key=api_key,
-                headers=headers or {},
-                client=client,
-            )
+            with patch.object(client, "post") as mock_post, patch.object(
+                litellm.module_level_client, "post", return_value=mock_response
+            ) as mock_get:
+                completion(
+                    model=model,
+                    messages=messages,
+                    api_key=api_key,
+                    headers=headers or {},
+                    client=client,
+                )

-        return mock_post, mock_get
+            return mock_post, mock_get
+        else:
+            with patch.object(client, "post") as mock_post:
+                completion(
+                    model=model,
+                    messages=messages,
+                    api_key=api_key,
+                    headers=headers or {},
+                    client=client,
+                )
+            return mock_post, None

    return _call
@@ -77,6 +89,20 @@ def test_watsonx_custom_auth_header(
    )


+@pytest.mark.parametrize("env_var_key", ["WATSONX_ZENAPIKEY", "WATSONX_TOKEN"])
+def test_watsonx_token_in_env_var(
+    monkeypatch, watsonx_chat_completion_call, env_var_key
+):
+    monkeypatch.setenv(env_var_key, "my-custom-token")
+    mock_post, _ = watsonx_chat_completion_call(patch_token_call=False)
+
+    assert mock_post.call_count == 1
+    assert (
+        mock_post.call_args[1]["headers"]["Authorization"] == "Bearer my-custom-token"
+    )
+
+
def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
    model = "watsonx/another-model"
    messages = [{"role": "user", "content": "Test message"}]

View file

@@ -1270,6 +1270,8 @@ def test_fireworks_ai_document_inlining():
    """
    from litellm.utils import supports_pdf_input, supports_vision

+    litellm._turn_on_debug()
+
    assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
    assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
@@ -1288,3 +1290,131 @@
    assert logprobs.token_logprobs is None
    assert logprobs.tokens is None
    assert logprobs.top_logprobs is None
+
+
+def test_get_valid_models_openai_proxy(monkeypatch):
+    from litellm.utils import get_valid_models
+    import litellm
+
+    litellm._turn_on_debug()
+
+    monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234")
+    monkeypatch.setenv("LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/")
+    monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", None)
+    monkeypatch.delenv("FIREWORKS_AI_API_KEY", None)
+
+    mock_response_data = {
+        "object": "list",
+        "data": [
+            {
+                "id": "gpt-4o",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "organization-owner",
+            },
+        ],
+    }
+
+    # Create a mock response object
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_response_data
+
+    with patch.object(
+        litellm.module_level_client, "get", return_value=mock_response
+    ) as mock_post:
+        valid_models = get_valid_models(check_provider_endpoint=True)
+        assert "litellm_proxy/gpt-4o" in valid_models
+
+
+def test_get_valid_models_fireworks_ai(monkeypatch):
+    from litellm.utils import get_valid_models
+    import litellm
+
+    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234")
+
+    mock_response_data = {
+        "models": [
+            {
+                "name": "accounts/fireworks/models/llama-3.1-8b-instruct",
+                "displayName": "<string>",
+                "description": "<string>",
+                "createTime": "2023-11-07T05:31:56Z",
+                "createdBy": "<string>",
+                "state": "STATE_UNSPECIFIED",
+                "status": {"code": "OK", "message": "<string>"},
+                "kind": "KIND_UNSPECIFIED",
+                "githubUrl": "<string>",
+                "huggingFaceUrl": "<string>",
+                "baseModelDetails": {
+                    "worldSize": 123,
+                    "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED",
+                    "parameterCount": "<string>",
+                    "moe": True,
+                    "tunable": True,
+                },
+                "peftDetails": {
+                    "baseModel": "<string>",
+                    "r": 123,
+                    "targetModules": ["<string>"],
+                },
+                "teftDetails": {},
+                "public": True,
+                "conversationConfig": {
+                    "style": "<string>",
+                    "system": "<string>",
+                    "template": "<string>",
+                },
+                "contextLength": 123,
+                "supportsImageInput": True,
+                "supportsTools": True,
+                "importedFrom": "<string>",
+                "fineTuningJob": "<string>",
+                "defaultDraftModel": "<string>",
+                "defaultDraftTokenCount": 123,
+                "precisions": ["PRECISION_UNSPECIFIED"],
+                "deployedModelRefs": [
+                    {
+                        "name": "<string>",
+                        "deployment": "<string>",
+                        "state": "STATE_UNSPECIFIED",
+                        "default": True,
+                        "public": True,
+                    }
+                ],
+                "cluster": "<string>",
+                "deprecationDate": {"year": 123, "month": 123, "day": 123},
+            }
+        ],
+        "nextPageToken": "<string>",
+        "totalSize": 123,
+    }
+
+    # Create a mock response object
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_response_data
+
+    with patch.object(
+        litellm.module_level_client, "get", return_value=mock_response
+    ) as mock_post:
+        valid_models = get_valid_models(check_provider_endpoint=True)
+        assert (
+            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
+            in valid_models
+        )
+
+
+def test_get_valid_models_default(monkeypatch):
+    """
+    Ensure that the default model list is used when there is an error retrieving models from the provider API.
+
+    Prevents a regression for existing usage.
+    """
+    from litellm.utils import get_valid_models
+    import litellm
+
+    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
+    valid_models = get_valid_models()
+    assert len(valid_models) > 0