Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
Support checking provider /models endpoints on proxy /v1/models endpoint (#9958)
* feat(utils.py): support global flag for 'check_provider_endpoints' enables setting this for `/models` on proxy
* feat(utils.py): add caching to 'get_valid_models' Prevents checking endpoint repeatedly
* fix(utils.py): ensure mutations don't impact cached results
* test(test_utils.py): add unit test to confirm cache invalidation logic
* feat(utils.py): get_valid_models - support passing litellm params dynamically Allows for checking endpoints based on received credentials
* test: update test
* feat(model_checks.py): pass router credentials to get_valid_models - ensures it checks correct credentials
* refactor(utils.py): refactor for simpler functions
* fix: fix linting errors
* fix(utils.py): fix test
* fix(utils.py): set valid providers to custom_llm_provider, if given
* test: update test
* fix: fix ruff check error
parent e94eb4ec70
commit 33ead69c0a
6 changed files with 313 additions and 110 deletions
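For reference, a minimal sketch of how the updated get_valid_models is intended to be called after this change. The key value is a placeholder, the import path assumes litellm's public export, and only check_provider_endpoint, custom_llm_provider, and the new litellm.check_provider_endpoint global flag come from this diff:

import os
import litellm
from litellm import get_valid_models

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # placeholder credential

# Per-call flag: query the provider's /models endpoint instead of the static list.
models = get_valid_models(check_provider_endpoint=True, custom_llm_provider="openai")

# Global flag added in this change: same behaviour without passing the argument
# on every call, which is how the proxy can enable it for /v1/models per the
# commit message.
litellm.check_provider_endpoint = True
models = get_valid_models(custom_llm_provider="openai")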
litellm/utils.py (182)
@@ -5807,8 +5807,133 @@ def trim_messages(
     return messages
 
 
+from litellm.caching.in_memory_cache import InMemoryCache
+
+
+class AvailableModelsCache(InMemoryCache):
+    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000):
+        super().__init__(ttl_seconds, max_size)
+        self._env_hash: Optional[str] = None
+
+    def _get_env_hash(self) -> str:
+        """Create a hash of relevant environment variables"""
+        env_vars = {
+            k: v
+            for k, v in os.environ.items()
+            if k.startswith(("OPENAI", "ANTHROPIC", "AZURE", "AWS"))
+        }
+        return str(hash(frozenset(env_vars.items())))
+
+    def _check_env_changed(self) -> bool:
+        """Check if environment variables have changed"""
+        current_hash = self._get_env_hash()
+        if self._env_hash is None:
+            self._env_hash = current_hash
+            return True
+        return current_hash != self._env_hash
+
+    def _get_cache_key(
+        self,
+        custom_llm_provider: Optional[str],
+        litellm_params: Optional[LiteLLM_Params],
+    ) -> str:
+        valid_str = ""
+
+        if litellm_params is not None:
+            valid_str = litellm_params.model_dump_json()
+        if custom_llm_provider is not None:
+            valid_str = f"{custom_llm_provider}:{valid_str}"
+        return hashlib.sha256(valid_str.encode()).hexdigest()
+
+    def get_cached_model_info(
+        self,
+        custom_llm_provider: Optional[str] = None,
+        litellm_params: Optional[LiteLLM_Params] = None,
+    ) -> Optional[List[str]]:
+        """Get cached model info"""
+        # Check if environment has changed
+        if litellm_params is None and self._check_env_changed():
+            self.cache_dict.clear()
+            return None
+
+        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
+
+        result = cast(Optional[List[str]], self.get_cache(cache_key))
+
+        if result is not None:
+            return copy.deepcopy(result)
+        return result
+
+    def set_cached_model_info(
+        self,
+        custom_llm_provider: str,
+        litellm_params: Optional[LiteLLM_Params],
+        available_models: List[str],
+    ):
+        """Set cached model info"""
+        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
+        self.set_cache(cache_key, copy.deepcopy(available_models))
+
+
+# Global cache instance
+_model_cache = AvailableModelsCache()
+
+
+def _infer_valid_provider_from_env_vars(
+    custom_llm_provider: Optional[str] = None,
+) -> List[str]:
+    valid_providers: List[str] = []
+    environ_keys = os.environ.keys()
+    for provider in litellm.provider_list:
+        if custom_llm_provider and provider != custom_llm_provider:
+            continue
+
+        # edge case litellm has together_ai as a provider, it should be togetherai
+        env_provider_1 = provider.replace("_", "")
+        env_provider_2 = provider
+
+        # litellm standardizes expected provider keys to
+        # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
+        expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+        expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+        if (
+            expected_provider_key_1 in environ_keys
+            or expected_provider_key_2 in environ_keys
+        ):
+            # key is set
+            valid_providers.append(provider)
+
+    return valid_providers
+
+
+def _get_valid_models_from_provider_api(
+    provider_config: BaseLLMModelInfo,
+    custom_llm_provider: str,
+    litellm_params: Optional[LiteLLM_Params] = None,
+) -> List[str]:
+    try:
+        cached_result = _model_cache.get_cached_model_info(
+            custom_llm_provider, litellm_params
+        )
+
+        if cached_result is not None:
+            return cached_result
+        models = provider_config.get_models(
+            api_key=litellm_params.api_key if litellm_params is not None else None,
+            api_base=litellm_params.api_base if litellm_params is not None else None,
+        )
+
+        _model_cache.set_cached_model_info(custom_llm_provider, litellm_params, models)
+        return models
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
+        return []
+
+
 def get_valid_models(
-    check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
+    check_provider_endpoint: Optional[bool] = None,
+    custom_llm_provider: Optional[str] = None,
+    litellm_params: Optional[LiteLLM_Params] = None,
 ) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
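The commit message mentions a unit test confirming the cache invalidation logic. A hedged sketch of what such a test could look like, in pytest style with its built-in monkeypatch fixture; the test name and values are illustrative, not the actual test added in test_utils.py:

from litellm.utils import AvailableModelsCache


def test_available_models_cache_invalidates_on_env_change(monkeypatch):
    cache = AvailableModelsCache()

    monkeypatch.setenv("OPENAI_API_KEY", "sk-first")  # placeholder value

    # First read primes the stored env hash (and misses), mirroring how
    # _get_valid_models_from_provider_api reads before it writes.
    assert cache.get_cached_model_info("openai") is None

    cache.set_cached_model_info("openai", None, ["gpt-4o"])  # placeholder model list
    assert cache.get_cached_model_info("openai") == ["gpt-4o"]

    # Changing a watched env var drops the cached entry on the next read.
    monkeypatch.setenv("OPENAI_API_KEY", "sk-second")  # placeholder value
    assert cache.get_cached_model_info("openai") is None

Because get_cached_model_info and set_cached_model_info both deep-copy, mutating a returned list cannot corrupt the cached value, which is what the "ensure mutations don't impact cached results" fix refers to.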
@@ -5819,31 +5944,21 @@ def get_valid_models(
     Returns:
         A list of valid LLMs
     """
 
     try:
+        check_provider_endpoint = (
+            check_provider_endpoint or litellm.check_provider_endpoint
+        )
         # get keys set in .env
-        environ_keys = os.environ.keys()
-        valid_providers = []
+        valid_providers: List[str] = []
+        valid_models: List[str] = []
         # for all valid providers, make a list of supported llms
-        valid_models = []
 
-        for provider in litellm.provider_list:
-            if custom_llm_provider and provider != custom_llm_provider:
-                continue
-
-            # edge case litellm has together_ai as a provider, it should be togetherai
-            env_provider_1 = provider.replace("_", "")
-            env_provider_2 = provider
-
-            # litellm standardizes expected provider keys to
-            # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
-            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
-            if (
-                expected_provider_key_1 in environ_keys
-                or expected_provider_key_2 in environ_keys
-            ):
-                # key is set
-                valid_providers.append(provider)
+        if custom_llm_provider:
+            valid_providers = [custom_llm_provider]
+        else:
+            valid_providers = _infer_valid_provider_from_env_vars(custom_llm_provider)
 
         for provider in valid_providers:
            provider_config = ProviderConfigManager.get_provider_model_info(
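To make the environment-variable matching concrete, a quick hedged example of what _infer_valid_provider_from_env_vars is expected to return; the key values are placeholders and the exact result depends on which other *_API_KEY variables happen to be set:

import os
from litellm.utils import _infer_valid_provider_from_env_vars

os.environ["OPENAI_API_KEY"] = "sk-placeholder"       # matches provider "openai"
os.environ["TOGETHERAI_API_KEY"] = "tok-placeholder"  # matches "together_ai" via the underscore-stripped form

providers = _infer_valid_provider_from_env_vars()
# Expected to include "openai" and "together_ai", plus any other provider
# whose API key is present in the environment.

openai_only = _infer_valid_provider_from_env_vars(custom_llm_provider="openai")
# Only the requested provider's key is checked, so at most ["openai"] is returned.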
@@ -5856,15 +5971,24 @@ def get_valid_models(
 
             if provider == "azure":
                 valid_models.append("Azure-LLM")
-            elif provider_config is not None and check_provider_endpoint:
-                try:
-                    models = provider_config.get_models()
-                    valid_models.extend(models)
-                except Exception as e:
-                    verbose_logger.debug(f"Error getting valid models: {e}")
+            elif (
+                provider_config is not None
+                and check_provider_endpoint
+                and provider is not None
+            ):
+                valid_models.extend(
+                    _get_valid_models_from_provider_api(
+                        provider_config,
+                        provider,
+                        litellm_params,
+                    )
+                )
             else:
-                models_for_provider = litellm.models_by_provider.get(provider, [])
+                models_for_provider = copy.deepcopy(
+                    litellm.models_by_provider.get(provider, [])
+                )
                 valid_models.extend(models_for_provider)
 
         return valid_models
     except Exception as e:
         verbose_logger.debug(f"Error getting valid models: {e}")
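Beyond env-var based discovery, the new litellm_params argument lets a caller check a specific set of credentials, which is how the proxy's model_checks.py passes router credentials according to the commit message. A hedged sketch, assuming LiteLLM_Params from litellm.types.router with its usual model / api_key / api_base fields; all values below are placeholders:

from litellm.types.router import LiteLLM_Params
from litellm.utils import get_valid_models

params = LiteLLM_Params(
    model="openai/gpt-4o",                 # placeholder deployment
    api_key="sk-deployment-specific",      # placeholder credential
    api_base="https://api.openai.com/v1",  # placeholder endpoint
)

models = get_valid_models(
    check_provider_endpoint=True,
    custom_llm_provider="openai",
    litellm_params=params,
)

When litellm_params is supplied, the cache key is derived from its serialized contents rather than the process environment, so different credentials get separate cache entries and the env-hash invalidation above is skipped.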