Support checking provider-specific /models endpoints for available models based on key (#7538)

* test(test_utils.py): initial test for valid models

Addresses https://github.com/BerriAI/litellm/issues/7525

* fix: test

* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint

* refactor(fireworks_ai/): support checking model info on `/v1/models` route

* docs(set_keys.md): update docs to clarify LLM provider API check usage

* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth

* fix(watsonx): read in watsonx token from env var

* fix: fix linting errors

* fix(utils.py): fix provider config check

* style: cleanup unused imports
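
A quick usage sketch, inferred from the tests added below (the check_provider_endpoint flag defaults to off, so existing callers keep the static model list):

    import litellm
    from litellm.utils import get_valid_models

    # Default: static model list (existing behavior)
    models = get_valid_models()

    # Opt in: also query each configured provider's /models endpoint;
    # results come back provider-prefixed, e.g. "litellm_proxy/gpt-4o"
    live_models = get_valid_models(check_provider_endpoint=True)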
Krish Dholakia, 2025-01-03 19:29:59 -08:00 (committed by GitHub)
parent cac06a32b8
commit f770dd0c95
12 changed files with 350 additions and 42 deletions

@@ -1270,6 +1270,8 @@ def test_fireworks_ai_document_inlining():
    """
    from litellm.utils import supports_pdf_input, supports_vision

    litellm._turn_on_debug()

    assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
    assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
@@ -1288,3 +1290,131 @@ def test_logprobs_type():
    assert logprobs.token_logprobs is None
    assert logprobs.tokens is None
    assert logprobs.top_logprobs is None


def test_get_valid_models_openai_proxy(monkeypatch):
    from unittest.mock import MagicMock, patch

    from litellm.utils import get_valid_models
    import litellm

    litellm._turn_on_debug()
    monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234")
    monkeypatch.setenv(
        "LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/"
    )
    monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", raising=False)
    monkeypatch.delenv("FIREWORKS_AI_API_KEY", raising=False)

    mock_response_data = {
        "object": "list",
        "data": [
            {
                "id": "gpt-4o",
                "object": "model",
                "created": 1686935002,
                "owned_by": "organization-owner",
            },
        ],
    }

    # Create a mock response object
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = mock_response_data

    with patch.object(
        litellm.module_level_client, "get", return_value=mock_response
    ) as mock_get:
        valid_models = get_valid_models(check_provider_endpoint=True)
        assert "litellm_proxy/gpt-4o" in valid_models


def test_get_valid_models_fireworks_ai(monkeypatch):
    from unittest.mock import MagicMock, patch

    from litellm.utils import get_valid_models
    import litellm

    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234")

    mock_response_data = {
        "models": [
            {
                "name": "accounts/fireworks/models/llama-3.1-8b-instruct",
                "displayName": "<string>",
                "description": "<string>",
                "createTime": "2023-11-07T05:31:56Z",
                "createdBy": "<string>",
                "state": "STATE_UNSPECIFIED",
                "status": {"code": "OK", "message": "<string>"},
                "kind": "KIND_UNSPECIFIED",
                "githubUrl": "<string>",
                "huggingFaceUrl": "<string>",
                "baseModelDetails": {
                    "worldSize": 123,
                    "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED",
                    "parameterCount": "<string>",
                    "moe": True,
                    "tunable": True,
                },
                "peftDetails": {
                    "baseModel": "<string>",
                    "r": 123,
                    "targetModules": ["<string>"],
                },
                "teftDetails": {},
                "public": True,
                "conversationConfig": {
                    "style": "<string>",
                    "system": "<string>",
                    "template": "<string>",
                },
                "contextLength": 123,
                "supportsImageInput": True,
                "supportsTools": True,
                "importedFrom": "<string>",
                "fineTuningJob": "<string>",
                "defaultDraftModel": "<string>",
                "defaultDraftTokenCount": 123,
                "precisions": ["PRECISION_UNSPECIFIED"],
                "deployedModelRefs": [
                    {
                        "name": "<string>",
                        "deployment": "<string>",
                        "state": "STATE_UNSPECIFIED",
                        "default": True,
                        "public": True,
                    }
                ],
                "cluster": "<string>",
                "deprecationDate": {"year": 123, "month": 123, "day": 123},
            }
        ],
        "nextPageToken": "<string>",
        "totalSize": 123,
    }

    # Create a mock response object
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = mock_response_data

    with patch.object(
        litellm.module_level_client, "get", return_value=mock_response
    ) as mock_get:
        valid_models = get_valid_models(check_provider_endpoint=True)
        assert (
            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
            in valid_models
        )
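

# A minimal sketch (hypothetical helper, assuming the Fireworks AI
# /v1/accounts/{account_id}/models response shape mocked above) of how model
# entries map to the provider-prefixed IDs asserted in the test. Each entry's
# "name" is already fully qualified, e.g.
# "accounts/fireworks/models/llama-3.1-8b-instruct".
def _fireworks_models_to_ids(response_json: dict) -> list:
    return [f"fireworks_ai/{m['name']}" for m in response_json.get("models", [])]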


def test_get_valid_models_default(monkeypatch):
    """
    Ensure the default (static) model list is used when models cannot be
    retrieved from the provider API.

    Prevents a regression for existing usage.
    """
    from litellm.utils import get_valid_models
    import litellm

    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
    valid_models = get_valid_models()
    assert len(valid_models) > 0