Support checking provider /models endpoints on proxy /v1/models endpoint (#9958)

* feat(utils.py): support global flag for 'check_provider_endpoint'

Enables setting this for `/models` on the proxy

* feat(utils.py): add caching to 'get_valid_models'

Prevents checking the provider endpoint repeatedly

* fix(utils.py): ensure mutations don't impact cached results

* test(test_utils.py): add unit test to confirm cache invalidation logic

* feat(utils.py): get_valid_models - support passing litellm params dynamically

Allows checking endpoints based on the received credentials

* test: update test

* feat(model_checks.py): pass router credentials to get_valid_models - ensures it checks the correct credentials

* refactor(utils.py): refactor for simpler functions

* fix: fix linting errors

* fix(utils.py): fix test

* fix(utils.py): set valid providers to custom_llm_provider, if given

* test: update test

* fix: fix ruff check error
Krish Dholakia, 2025-04-14 23:23:20 -07:00, committed via GitHub
commit 33ead69c0a (parent e94eb4ec70)
6 changed files with 313 additions and 110 deletions


@@ -128,19 +128,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
False # if you want to use v1 gcs pubsub logged payload
)
gcs_pub_sub_use_v1: Optional[
bool
] = False # if you want to use v1 gcs pubsub logged payload
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
[]
) # internal variable - async custom callbacks are routed here.
_async_input_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
Union[str, Callable, CustomLogger]
] = [] # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False
@@ -148,18 +148,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
add_user_information_to_llm_headers: Optional[bool] = (
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
)
add_user_information_to_llm_headers: Optional[
bool
] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
store_audit_logs = False # Enterprise feature, allow users to see audit logs
### end of callbacks #############
email: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
token: Optional[str] = (
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
email: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
token: Optional[
str
] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
telemetry = True
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -235,24 +235,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
False # feature-flag for caching on optional params - e.g. 'top_k'
)
caching: bool = (
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
caching_with_models: bool = (
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
cache: Optional[Cache] = (
None # cache object <- use this - https://docs.litellm.ai/docs/caching
)
caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[
Cache
] = None # cache object <- use this - https://docs.litellm.ai/docs/caching
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0 # set the max budget across all providers
budget_duration: Optional[str] = (
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
)
budget_duration: Optional[
str
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
default_soft_budget: float = (
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
)
@@ -261,15 +257,11 @@ forward_traceparent_to_llm_provider: bool = False
_current_cost = 0.0 # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = (
False # if function calling not supported by api, append function call details to system prompt
)
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = (
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
@@ -292,9 +284,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
force_ipv4: bool = (
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
)
force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
module_level_aclient = AsyncHTTPHandler(
timeout=request_timeout, client_alias="module level aclient"
)
@@ -308,13 +298,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
num_retries_per_request: Optional[int] = (
None # for the request overall (incl. fallbacks + model retries)
)
num_retries_per_request: Optional[
int
] = None # for the request overall (incl. fallbacks + model retries)
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = (
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
)
secret_manager_client: Optional[
Any
] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -325,6 +315,7 @@ from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map
model_cost = get_model_cost_map(url=model_cost_map_url)
custom_prompt_dict: Dict[str, dict] = {}
check_provider_endpoint = False
####### THREAD-SPECIFIC DATA ####################
@@ -1064,10 +1055,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[bool] = (
None # disable huggingface tokenizer download. Defaults to openai clk100
)
_custom_providers: List[
str
] = [] # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[
bool
] = None # disable huggingface tokenizer download. Defaults to openai clk100
global_disable_no_log_param: bool = False


@@ -33,6 +33,7 @@ model_list:
litellm_settings:
num_retries: 0
callbacks: ["prometheus"]
check_provider_endpoint: true
files_settings:
- custom_llm_provider: gemini
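
The `check_provider_endpoint: true` setting added above has a module-level equivalent; a minimal sketch of driving it from Python (which models come back depends on which provider API keys are present in the environment):

# Minimal sketch: the global flag makes get_valid_models (and therefore the proxy's
# /v1/models handler) query each configured provider's /models endpoint instead of
# only the static model map.
import litellm
from litellm.utils import get_valid_models

litellm.check_provider_endpoint = True  # same effect as `check_provider_endpoint: true` in litellm_settings
print(get_valid_models())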


@@ -1,11 +1,12 @@
# What is this?
## Common checks for /v1/models and `/model/info`
import copy
from typing import Dict, List, Optional, Set
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.router import Router
from litellm.types.router import LiteLLM_Params
from litellm.utils import get_valid_models
@@ -23,15 +24,20 @@ def _check_wildcard_routing(model: str) -> bool:
return False
def get_provider_models(provider: str) -> Optional[List[str]]:
def get_provider_models(
provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
"""
Returns the list of known models by provider
"""
if provider == "*":
return get_valid_models()
return get_valid_models(litellm_params=litellm_params)
if provider in litellm.models_by_provider:
provider_models = copy.deepcopy(litellm.models_by_provider[provider])
provider_models = get_valid_models(
custom_llm_provider=provider, litellm_params=litellm_params
)
# provider_models = copy.deepcopy(litellm.models_by_provider[provider])
for idx, _model in enumerate(provider_models):
if provider not in _model:
provider_models[idx] = f"{provider}/{_model}"
@@ -118,6 +124,7 @@ def get_complete_model_list(
user_model: Optional[str],
infer_model_from_keys: Optional[bool],
return_wildcard_routes: Optional[bool] = False,
llm_router: Optional[Router] = None,
) -> List[str]:
"""Logic for returning complete model list for a given key + team pair"""
@@ -143,19 +150,25 @@
unique_models.update(valid_models)
all_wildcard_models = _get_wildcard_models(
unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
unique_models=unique_models,
return_wildcard_routes=return_wildcard_routes,
llm_router=llm_router,
)
return list(unique_models) + all_wildcard_models
def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
def get_known_models_from_wildcard(
wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
try:
provider, model = wildcard_model.split("/", 1)
except ValueError: # safely fail
return []
# get all known provider models
wildcard_models = get_provider_models(provider=provider)
wildcard_models = get_provider_models(
provider=provider, litellm_params=litellm_params
)
if wildcard_models is None:
return []
if model == "*":
@@ -172,7 +185,9 @@ def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
def _get_wildcard_models(
unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
unique_models: Set[str],
return_wildcard_routes: Optional[bool] = False,
llm_router: Optional[Router] = None,
) -> List[str]:
models_to_remove = set()
all_wildcard_models = []
@@ -183,12 +198,25 @@
): # will add the wildcard route to the list eg: anthropic/*.
all_wildcard_models.append(model)
# get all known provider models
wildcard_models = get_known_models_from_wildcard(wildcard_model=model)
## get litellm params from model
if llm_router is not None:
model_list = llm_router.get_model_list(model_name=model)
if model_list is not None:
for router_model in model_list:
wildcard_models = get_known_models_from_wildcard(
wildcard_model=model,
litellm_params=LiteLLM_Params(
**router_model["litellm_params"] # type: ignore
),
)
all_wildcard_models.extend(wildcard_models)
else:
# get all known provider models
wildcard_models = get_known_models_from_wildcard(wildcard_model=model)
if wildcard_models is not None:
models_to_remove.add(model)
all_wildcard_models.extend(wildcard_models)
if wildcard_models is not None:
models_to_remove.add(model)
all_wildcard_models.extend(wildcard_models)
for model in models_to_remove:
unique_models.remove(model)
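
With the changes above, a wildcard route such as `anthropic/*` is expanded using the credentials stored on the matching router deployment rather than whatever happens to be in the process environment. A hypothetical standalone sketch (the module path and the API key are assumptions; inside the proxy this is driven by `_get_wildcard_models`, which pulls params from `llm_router.get_model_list`):

# Hypothetical direct use of the helper changed above; key value is a placeholder.
import litellm
from litellm.types.router import LiteLLM_Params
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard  # assumed module path

litellm.check_provider_endpoint = True  # so the expansion hits the provider's live /models endpoint
wildcard_models = get_known_models_from_wildcard(
    wildcard_model="anthropic/*",
    litellm_params=LiteLLM_Params(model="anthropic/*", api_key="sk-ant-placeholder"),
)
print(wildcard_models)  # e.g. ["anthropic/claude-3-7-sonnet-20250219", ...]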


@@ -803,9 +803,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)
redis_usage_cache: Optional[
RedisCache
] = None # redis cache used for tracking spend, tpm/rpm limits
user_custom_auth = None
user_custom_key_generate = None
user_custom_sso = None
@@ -1131,9 +1131,9 @@ async def update_cache( # noqa: PLR0915
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing cost for the given user
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
await user_api_key_cache.async_get_cache(key=_id)
)
existing_spend_obj: Optional[
LiteLLM_TeamTable
] = await user_api_key_cache.async_get_cache(key=_id)
if existing_spend_obj is None:
# do nothing if team not in api key cache
return
@@ -2812,9 +2812,9 @@ async def initialize( # noqa: PLR0915
user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base
if api_version:
os.environ["AZURE_API_VERSION"] = (
api_version # set this for azure - litellm can read this from the env
)
os.environ[
"AZURE_API_VERSION"
] = api_version # set this for azure - litellm can read this from the env
if max_tokens: # model-specific param
dynamic_config[user_model]["max_tokens"] = max_tokens
if temperature: # model-specific param
@@ -3316,6 +3316,7 @@ async def model_list(
user_model=user_model,
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
return_wildcard_routes=return_wildcard_routes,
llm_router=llm_router,
)
return dict(
@@ -7758,9 +7759,9 @@ async def get_config_list(
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
nested_fields[idx].field_description = (
sub_field_info.description
)
nested_fields[
idx
].field_description = sub_field_info.description
idx += 1
_stored_in_db = None


@@ -5807,8 +5807,133 @@ def trim_messages(
return messages
from litellm.caching.in_memory_cache import InMemoryCache
class AvailableModelsCache(InMemoryCache):
    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000):
        super().__init__(ttl_seconds, max_size)
        self._env_hash: Optional[str] = None

    def _get_env_hash(self) -> str:
        """Create a hash of relevant environment variables"""
        env_vars = {
            k: v
            for k, v in os.environ.items()
            if k.startswith(("OPENAI", "ANTHROPIC", "AZURE", "AWS"))
        }
        return str(hash(frozenset(env_vars.items())))

    def _check_env_changed(self) -> bool:
        """Check if environment variables have changed"""
        current_hash = self._get_env_hash()
        if self._env_hash is None:
            self._env_hash = current_hash
            return True
        return current_hash != self._env_hash

    def _get_cache_key(
        self,
        custom_llm_provider: Optional[str],
        litellm_params: Optional[LiteLLM_Params],
    ) -> str:
        valid_str = ""
        if litellm_params is not None:
            valid_str = litellm_params.model_dump_json()
        if custom_llm_provider is not None:
            valid_str = f"{custom_llm_provider}:{valid_str}"
        return hashlib.sha256(valid_str.encode()).hexdigest()

    def get_cached_model_info(
        self,
        custom_llm_provider: Optional[str] = None,
        litellm_params: Optional[LiteLLM_Params] = None,
    ) -> Optional[List[str]]:
        """Get cached model info"""
        # Check if environment has changed
        if litellm_params is None and self._check_env_changed():
            self.cache_dict.clear()
            return None
        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
        result = cast(Optional[List[str]], self.get_cache(cache_key))
        if result is not None:
            return copy.deepcopy(result)
        return result

    def set_cached_model_info(
        self,
        custom_llm_provider: str,
        litellm_params: Optional[LiteLLM_Params],
        available_models: List[str],
    ):
        """Set cached model info"""
        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
        self.set_cache(cache_key, copy.deepcopy(available_models))


# Global cache instance
_model_cache = AvailableModelsCache()
def _infer_valid_provider_from_env_vars(
    custom_llm_provider: Optional[str] = None,
) -> List[str]:
    valid_providers: List[str] = []
    environ_keys = os.environ.keys()
    for provider in litellm.provider_list:
        if custom_llm_provider and provider != custom_llm_provider:
            continue

        # edge case litellm has together_ai as a provider, it should be togetherai
        env_provider_1 = provider.replace("_", "")
        env_provider_2 = provider

        # litellm standardizes expected provider keys to
        # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
        expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
        expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
        if (
            expected_provider_key_1 in environ_keys
            or expected_provider_key_2 in environ_keys
        ):
            # key is set
            valid_providers.append(provider)

    return valid_providers
def _get_valid_models_from_provider_api(
    provider_config: BaseLLMModelInfo,
    custom_llm_provider: str,
    litellm_params: Optional[LiteLLM_Params] = None,
) -> List[str]:
    try:
        cached_result = _model_cache.get_cached_model_info(
            custom_llm_provider, litellm_params
        )
        if cached_result is not None:
            return cached_result
        models = provider_config.get_models(
            api_key=litellm_params.api_key if litellm_params is not None else None,
            api_base=litellm_params.api_base if litellm_params is not None else None,
        )
        _model_cache.set_cached_model_info(custom_llm_provider, litellm_params, models)
        return models
    except Exception as e:
        verbose_logger.debug(f"Error getting valid models: {e}")
        return []
def get_valid_models(
check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
check_provider_endpoint: Optional[bool] = None,
custom_llm_provider: Optional[str] = None,
litellm_params: Optional[LiteLLM_Params] = None,
) -> List[str]:
"""
Returns a list of valid LLMs based on the set environment variables
@@ -5819,31 +5944,21 @@ def get_valid_models(
Returns:
A list of valid LLMs
"""
try:
check_provider_endpoint = (
check_provider_endpoint or litellm.check_provider_endpoint
)
# get keys set in .env
environ_keys = os.environ.keys()
valid_providers = []
valid_providers: List[str] = []
valid_models: List[str] = []
# for all valid providers, make a list of supported llms
valid_models = []
for provider in litellm.provider_list:
if custom_llm_provider and provider != custom_llm_provider:
continue
# edge case litellm has together_ai as a provider, it should be togetherai
env_provider_1 = provider.replace("_", "")
env_provider_2 = provider
# litellm standardizes expected provider keys to
# PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
if (
expected_provider_key_1 in environ_keys
or expected_provider_key_2 in environ_keys
):
# key is set
valid_providers.append(provider)
if custom_llm_provider:
valid_providers = [custom_llm_provider]
else:
valid_providers = _infer_valid_provider_from_env_vars(custom_llm_provider)
for provider in valid_providers:
provider_config = ProviderConfigManager.get_provider_model_info(
@@ -5856,15 +5971,24 @@ def get_valid_models(
if provider == "azure":
valid_models.append("Azure-LLM")
elif provider_config is not None and check_provider_endpoint:
try:
models = provider_config.get_models()
valid_models.extend(models)
except Exception as e:
verbose_logger.debug(f"Error getting valid models: {e}")
elif (
provider_config is not None
and check_provider_endpoint
and provider is not None
):
valid_models.extend(
_get_valid_models_from_provider_api(
provider_config,
provider,
litellm_params,
)
)
else:
models_for_provider = litellm.models_by_provider.get(provider, [])
models_for_provider = copy.deepcopy(
litellm.models_by_provider.get(provider, [])
)
valid_models.extend(models_for_provider)
return valid_models
except Exception as e:
verbose_logger.debug(f"Error getting valid models: {e}")


@@ -41,8 +41,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'


@pytest.fixture(autouse=True)
def reset_mock_cache():
    from litellm.utils import _model_cache

    _model_cache.flush_cache()
# Test 1: Check trimming of normal message
def test_basic_trimming():
messages = [
@@ -1539,6 +1541,7 @@ def test_get_valid_models_fireworks_ai(monkeypatch):
litellm.module_level_client, "get", return_value=mock_response
) as mock_post:
valid_models = get_valid_models(check_provider_endpoint=True)
print("valid_models", valid_models)
mock_post.assert_called_once()
assert (
"fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
@@ -2122,3 +2125,58 @@ def test_get_provider_audio_transcription_config():
    config = ProviderConfigManager.get_provider_audio_transcription_config(
        model="whisper-1", provider=provider
    )


def test_get_valid_models_from_provider():
    """
    Test that get_valid_models returns the correct models for a given provider
    """
    from litellm.utils import get_valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models
    print("Valid models: ", valid_models)

    valid_models.remove("gpt-4o-mini")
    assert "gpt-4o-mini" not in valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models


def test_get_valid_models_from_provider_cache_invalidation(monkeypatch):
    """
    Test that the cached model list is invalidated when the relevant environment variables change
    """
    from litellm.utils import _model_cache

    monkeypatch.setenv("OPENAI_API_KEY", "123")
    _model_cache.set_cached_model_info(
        "openai", litellm_params=None, available_models=["gpt-4o-mini"]
    )
    monkeypatch.delenv("OPENAI_API_KEY")
    assert _model_cache.get_cached_model_info("openai") is None


def test_get_valid_models_from_dynamic_api_key():
    """
    Test that get_valid_models checks the provider endpoint using dynamically passed credentials
    """
    from litellm.utils import get_valid_models
    from litellm.types.router import CredentialLiteLLMParams

    creds = CredentialLiteLLMParams(api_key="123")
    valid_models = get_valid_models(
        custom_llm_provider="anthropic",
        litellm_params=creds,
        check_provider_endpoint=True,
    )
    assert len(valid_models) == 0

    creds = CredentialLiteLLMParams(api_key=os.getenv("ANTHROPIC_API_KEY"))
    valid_models = get_valid_models(
        custom_llm_provider="anthropic",
        litellm_params=creds,
        check_provider_endpoint=True,
    )
    assert len(valid_models) > 0
    assert "anthropic/claude-3-7-sonnet-20250219" in valid_models