Support checking provider /models endpoints on proxy /v1/models endpoint (#9958)

* feat(utils.py): support global flag for 'check_provider_endpoint'

enables setting this for `/models` on proxy

* feat(utils.py): add caching to 'get_valid_models'

Prevents checking the provider endpoint repeatedly

* fix(utils.py): ensure mutations don't impact cached results

* test(test_utils.py): add unit test to confirm cache invalidation logic

* feat(utils.py): get_valid_models - support passing litellm params dynamically

Allows for checking endpoints based on received credentials

* test: update test

* feat(model_checks.py): pass router credentials to get_valid_models - ensures it checks the correct credentials

* refactor(utils.py): refactor for simpler functions

* fix: fix linting errors

* fix(utils.py): fix test

* fix(utils.py): set valid providers to custom_llm_provider, if given

* test: update test

* fix: fix ruff check error
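
Taken together, these changes mean `get_valid_models` can be driven either by the global `litellm.check_provider_endpoint` flag or by credentials passed in at call time. A minimal usage sketch (illustrative only, not taken from this PR's tests; it assumes `LiteLLM_Params` is importable from `litellm.types.router` and uses placeholder model/key values):

    import litellm
    from litellm.types.router import LiteLLM_Params  # assumed import path
    from litellm.utils import get_valid_models

    # Global flag (what the proxy /models support builds on); credentials come from env vars,
    # so this call assumes OPENAI_API_KEY is set in the environment
    litellm.check_provider_endpoint = True
    models = get_valid_models(custom_llm_provider="openai")

    # Dynamic credentials: check the provider's endpoint with request-time keys
    dynamic_params = LiteLLM_Params(
        model="openai/gpt-4o",                  # placeholder model
        api_key="sk-received-at-request-time",  # placeholder credential
        api_base="https://api.openai.com/v1",
    )
    models = get_valid_models(
        check_provider_endpoint=True,
        custom_llm_provider="openai",
        litellm_params=dynamic_params,
    )

When `litellm_params` is supplied, the cached entry is keyed on those params (and the env-hash invalidation check is skipped), so results for different request-time credentials don't collide.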
Author: Krish Dholakia (2025-04-14 23:23:20 -07:00), committed by GitHub
parent e94eb4ec70
commit 33ead69c0a
6 changed files with 313 additions and 110 deletions


@@ -5807,8 +5807,133 @@ def trim_messages(
     return messages


+from litellm.caching.in_memory_cache import InMemoryCache
+
+
+class AvailableModelsCache(InMemoryCache):
+    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000):
+        super().__init__(ttl_seconds, max_size)
+        self._env_hash: Optional[str] = None
+
+    def _get_env_hash(self) -> str:
+        """Create a hash of relevant environment variables"""
+        env_vars = {
+            k: v
+            for k, v in os.environ.items()
+            if k.startswith(("OPENAI", "ANTHROPIC", "AZURE", "AWS"))
+        }
+        return str(hash(frozenset(env_vars.items())))
+
+    def _check_env_changed(self) -> bool:
+        """Check if environment variables have changed"""
+        current_hash = self._get_env_hash()
+        if self._env_hash is None:
+            self._env_hash = current_hash
+            return True
+        return current_hash != self._env_hash
+
+    def _get_cache_key(
+        self,
+        custom_llm_provider: Optional[str],
+        litellm_params: Optional[LiteLLM_Params],
+    ) -> str:
+        valid_str = ""
+        if litellm_params is not None:
+            valid_str = litellm_params.model_dump_json()
+        if custom_llm_provider is not None:
+            valid_str = f"{custom_llm_provider}:{valid_str}"
+        return hashlib.sha256(valid_str.encode()).hexdigest()
+
+    def get_cached_model_info(
+        self,
+        custom_llm_provider: Optional[str] = None,
+        litellm_params: Optional[LiteLLM_Params] = None,
+    ) -> Optional[List[str]]:
+        """Get cached model info"""
+        # Check if environment has changed
+        if litellm_params is None and self._check_env_changed():
+            self.cache_dict.clear()
+            return None
+        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
+        result = cast(Optional[List[str]], self.get_cache(cache_key))
+        if result is not None:
+            return copy.deepcopy(result)
+        return result
+
+    def set_cached_model_info(
+        self,
+        custom_llm_provider: str,
+        litellm_params: Optional[LiteLLM_Params],
+        available_models: List[str],
+    ):
+        """Set cached model info"""
+        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
+        self.set_cache(cache_key, copy.deepcopy(available_models))
+
+
+# Global cache instance
+_model_cache = AvailableModelsCache()
+
+
+def _infer_valid_provider_from_env_vars(
+    custom_llm_provider: Optional[str] = None,
+) -> List[str]:
+    valid_providers: List[str] = []
+    environ_keys = os.environ.keys()
+    for provider in litellm.provider_list:
+        if custom_llm_provider and provider != custom_llm_provider:
+            continue
+
+        # edge case litellm has together_ai as a provider, it should be togetherai
+        env_provider_1 = provider.replace("_", "")
+        env_provider_2 = provider
+
+        # litellm standardizes expected provider keys to
+        # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
+        expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+        expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+        if (
+            expected_provider_key_1 in environ_keys
+            or expected_provider_key_2 in environ_keys
+        ):
+            # key is set
+            valid_providers.append(provider)
+
+    return valid_providers
+
+
+def _get_valid_models_from_provider_api(
+    provider_config: BaseLLMModelInfo,
+    custom_llm_provider: str,
+    litellm_params: Optional[LiteLLM_Params] = None,
+) -> List[str]:
+    try:
+        cached_result = _model_cache.get_cached_model_info(
+            custom_llm_provider, litellm_params
+        )
+        if cached_result is not None:
+            return cached_result
+        models = provider_config.get_models(
+            api_key=litellm_params.api_key if litellm_params is not None else None,
+            api_base=litellm_params.api_base if litellm_params is not None else None,
+        )
+        _model_cache.set_cached_model_info(custom_llm_provider, litellm_params, models)
+        return models
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
+        return []
+
+
 def get_valid_models(
-    check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
+    check_provider_endpoint: Optional[bool] = None,
+    custom_llm_provider: Optional[str] = None,
+    litellm_params: Optional[LiteLLM_Params] = None,
 ) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
@@ -5819,31 +5944,21 @@ def get_valid_models(
     Returns:
         A list of valid LLMs
     """
     try:
+        check_provider_endpoint = (
+            check_provider_endpoint or litellm.check_provider_endpoint
+        )
         # get keys set in .env
         environ_keys = os.environ.keys()
-        valid_providers = []
+        valid_providers: List[str] = []
+        valid_models: List[str] = []
         # for all valid providers, make a list of supported llms
-        valid_models = []
-        for provider in litellm.provider_list:
-            if custom_llm_provider and provider != custom_llm_provider:
-                continue
-            # edge case litellm has together_ai as a provider, it should be togetherai
-            env_provider_1 = provider.replace("_", "")
-            env_provider_2 = provider
-            # litellm standardizes expected provider keys to
-            # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
-            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
-            if (
-                expected_provider_key_1 in environ_keys
-                or expected_provider_key_2 in environ_keys
-            ):
-                # key is set
-                valid_providers.append(provider)
+        if custom_llm_provider:
+            valid_providers = [custom_llm_provider]
+        else:
+            valid_providers = _infer_valid_provider_from_env_vars(custom_llm_provider)

         for provider in valid_providers:
             provider_config = ProviderConfigManager.get_provider_model_info(
@@ -5856,15 +5971,24 @@ def get_valid_models(
             if provider == "azure":
                 valid_models.append("Azure-LLM")
-            elif provider_config is not None and check_provider_endpoint:
-                try:
-                    models = provider_config.get_models()
-                    valid_models.extend(models)
-                except Exception as e:
-                    verbose_logger.debug(f"Error getting valid models: {e}")
+            elif (
+                provider_config is not None
+                and check_provider_endpoint
+                and provider is not None
+            ):
+                valid_models.extend(
+                    _get_valid_models_from_provider_api(
+                        provider_config,
+                        provider,
+                        litellm_params,
+                    )
+                )
             else:
-                models_for_provider = litellm.models_by_provider.get(provider, [])
+                models_for_provider = copy.deepcopy(
+                    litellm.models_by_provider.get(provider, [])
+                )
                 valid_models.extend(models_for_provider)
         return valid_models
     except Exception as e:
         verbose_logger.debug(f"Error getting valid models: {e}")
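
For reference, here is an illustrative sketch (not the unit test added in `test_utils.py`) of the cache-invalidation behavior the new `AvailableModelsCache` provides, assuming the class is importable from `litellm.utils` as defined above:

    import os

    from litellm.utils import AvailableModelsCache  # defined in the diff above

    cache = AvailableModelsCache()
    os.environ["OPENAI_API_KEY"] = "sk-first"  # placeholder key

    # The first env-based lookup records the environment hash (and reports a miss).
    assert cache.get_cached_model_info("openai") is None

    cache.set_cached_model_info("openai", None, ["gpt-4o"])
    assert cache.get_cached_model_info("openai") == ["gpt-4o"]  # cache hit

    # Changing a provider-related env var changes the hash, so the next
    # env-based lookup clears the cache and misses.
    os.environ["OPENAI_API_KEY"] = "sk-second"
    assert cache.get_cached_model_info("openai") is None

Lookups that pass `litellm_params` skip the env-hash check entirely and are keyed on the params themselves, which is what lets dynamically supplied credentials be cached per credential set.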