Support checking provider /models endpoints on proxy /v1/models endpoint (#9958)

* feat(utils.py): support global flag for 'check_provider_endpoints'

Enables setting this for the `/models` endpoint on the proxy
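For example, once the flag is enabled a client can hit the proxy's /v1/models endpoint and get back the models the configured providers actually report. A minimal sketch; the base URL and master key below are placeholder assumptions, not part of this change:

import requests

response = requests.get(
    "http://localhost:4000/v1/models",  # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # assumed master key
)
print(response.json())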

* feat(utils.py): add caching to 'get_valid_models'

Prevents checking the provider endpoint repeatedly

* fix(utils.py): ensure mutations don't impact cached results

* test(test_utils.py): add unit test to confirm cache invalidation logic
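A rough sketch of the kind of cache these three changes describe and the tests below exercise: an in-memory store that hands back copies (so caller mutations don't leak into the cache) and drops an entry when the API key that backed it disappears. The class name and details are illustrative, not the actual _model_cache implementation in utils.py:

import copy
import os
from typing import Dict, List, Optional


class ModelInfoCache:
    """Illustrative cache keyed by provider name."""

    def __init__(self) -> None:
        self._cache: Dict[str, dict] = {}

    def set_cached_model_info(
        self, provider: str, litellm_params, available_models: List[str]
    ) -> None:
        # litellm_params is accepted only to mirror the observed signature; unused here.
        # Record whether an env API key backed this entry, so it can be invalidated later.
        self._cache[provider] = {
            "models": copy.deepcopy(available_models),
            "had_env_key": f"{provider.upper()}_API_KEY" in os.environ,
        }

    def get_cached_model_info(self, provider: str) -> Optional[List[str]]:
        entry = self._cache.get(provider)
        if entry is None:
            return None
        # Invalidate if the backing API key has since been removed from the environment.
        if entry["had_env_key"] and f"{provider.upper()}_API_KEY" not in os.environ:
            del self._cache[provider]
            return None
        # Return a copy so the cached list can't be mutated by the caller.
        return copy.deepcopy(entry["models"])

    def flush_cache(self) -> None:
        self._cache.clear()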

* feat(utils.py): get_valid_models - support passing litellm params dynamically

Allows for checking endpoints based on received credentials
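For illustration, calling get_valid_models with dynamic credentials looks roughly like this (the API key value is a placeholder; the call shape matches the new tests below):

from litellm.types.router import CredentialLiteLLMParams
from litellm.utils import get_valid_models

# Placeholder key; in practice this is a credential received at request time.
creds = CredentialLiteLLMParams(api_key="sk-ant-placeholder")
models = get_valid_models(
    custom_llm_provider="anthropic",
    litellm_params=creds,
    check_provider_endpoint=True,
)
print(models)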

* test: update test

* feat(model_checks.py): pass router credentials to get_valid_models - ensures it checks correct credentials
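Roughly how the proxy's model check can forward a deployment's credentials. An illustrative sketch only, not the actual model_checks.py logic; the get_model_list helper and the naive provider extraction are assumptions:

from litellm.types.router import CredentialLiteLLMParams
from litellm.utils import get_valid_models


def valid_models_for_router(llm_router) -> list:
    valid: list = []
    for deployment in llm_router.get_model_list() or []:  # assumed Router helper
        params = deployment.get("litellm_params", {})
        # Naive provider extraction from a "provider/model" string; illustrative only.
        provider = str(params.get("model", "")).split("/")[0]
        creds = CredentialLiteLLMParams(api_key=params.get("api_key"))
        valid.extend(
            get_valid_models(
                custom_llm_provider=provider,
                litellm_params=creds,
                check_provider_endpoint=True,
            )
        )
    return valid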

* refactor(utils.py): refactor for simpler functions

* fix: fix linting errors

* fix(utils.py): fix test

* fix(utils.py): set valid providers to custom_llm_provider, if given

* test: update test

* fix: fix ruff check error
Krish Dholakia 2025-04-14 23:23:20 -07:00 committed by GitHub
parent e94eb4ec70
commit 33ead69c0a
6 changed files with 313 additions and 110 deletions

test_utils.py

@@ -41,8 +41,10 @@ from unittest.mock import AsyncMock, MagicMock, patch

# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'


@pytest.fixture(autouse=True)
def reset_mock_cache():
    from litellm.utils import _model_cache

    _model_cache.flush_cache()


# Test 1: Check trimming of normal message
def test_basic_trimming():
    messages = [
@@ -1539,6 +1541,7 @@ def test_get_valid_models_fireworks_ai(monkeypatch):
        litellm.module_level_client, "get", return_value=mock_response
    ) as mock_post:
        valid_models = get_valid_models(check_provider_endpoint=True)
        print("valid_models", valid_models)
        mock_post.assert_called_once()
        assert (
            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
@@ -2122,3 +2125,58 @@ def test_get_provider_audio_transcription_config():
        config = ProviderConfigManager.get_provider_audio_transcription_config(
            model="whisper-1", provider=provider
        )


def test_get_valid_models_from_provider():
    """
    Test that get_valid_models returns the correct models for a given provider,
    and that mutating the returned list does not affect the cached result.
    """
    from litellm.utils import get_valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models

    print("Valid models: ", valid_models)
    valid_models.remove("gpt-4o-mini")
    assert "gpt-4o-mini" not in valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models


def test_get_valid_models_from_provider_cache_invalidation(monkeypatch):
    """
    Test that a provider's cached model info is invalidated once its API key
    is removed from the environment.
    """
    from litellm.utils import _model_cache

    monkeypatch.setenv("OPENAI_API_KEY", "123")
    _model_cache.set_cached_model_info("openai", litellm_params=None, available_models=["gpt-4o-mini"])
    monkeypatch.delenv("OPENAI_API_KEY")

    assert _model_cache.get_cached_model_info("openai") is None


def test_get_valid_models_from_dynamic_api_key():
    """
    Test that get_valid_models checks the provider endpoint with dynamically
    passed credentials instead of environment variables.
    """
    from litellm.utils import get_valid_models
    from litellm.types.router import CredentialLiteLLMParams

    creds = CredentialLiteLLMParams(api_key="123")
    valid_models = get_valid_models(custom_llm_provider="anthropic", litellm_params=creds, check_provider_endpoint=True)
    assert len(valid_models) == 0

    creds = CredentialLiteLLMParams(api_key=os.getenv("ANTHROPIC_API_KEY"))
    valid_models = get_valid_models(custom_llm_provider="anthropic", litellm_params=creds, check_provider_endpoint=True)
    assert len(valid_models) > 0
    assert "anthropic/claude-3-7-sonnet-20250219" in valid_models