Support checking provider-specific /models endpoints for available models based on key (#7538)

* test(test_utils.py): initial test for valid models

Addresses https://github.com/BerriAI/litellm/issues/7525

* fix: test

* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint

* refactor(fireworks_ai/): support checking model info on `/v1/models` route

* docs(set_keys.md): update docs to clarify how to check provider APIs for valid models (usage sketch below)

* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth

* fix(watsonx): read in watsonx token from env var

* fix: fix linting errors

* fix(utils.py): fix provider config check

* style: cleanup unused imports
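
As context for the `docs(set_keys.md)` change above, here is a minimal usage sketch of the new flag. It assumes `get_valid_models` is imported from `litellm.utils` and uses placeholder API keys; the exact docs wording lives in set_keys.md.

```python
import os

from litellm.utils import get_valid_models

# Placeholder keys, for illustration only.
os.environ["OPENAI_API_KEY"] = "sk-..."
os.environ["FIREWORKS_API_KEY"] = "fw-..."

# Default behavior: list models from litellm's static models_by_provider map
# for every provider whose API key is present in the environment.
static_models = get_valid_models()

# New in this PR: also query the provider's own models endpoint
# (e.g. Fireworks AI's /v1/models) so the list reflects what the key can access.
live_models = get_valid_models(check_provider_endpoint=True)

print(len(static_models), len(live_models))
```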
Krish Dholakia, 2025-01-03 19:29:59 -08:00, committed by GitHub
parent cac06a32b8
commit f770dd0c95
12 changed files with 350 additions and 42 deletions


@@ -4223,6 +4223,7 @@ def _get_model_info_helper( # noqa: PLR0915
         _model_info: Optional[Dict[str, Any]] = None
         key: Optional[str] = None
+        provider_config: Optional[BaseLLMModelInfo] = None
 
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
             _model_info = _get_model_info_from_model_cost(key=key)
@@ -4261,16 +4262,20 @@
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None
 
-        if _model_info is None and ProviderConfigManager.get_provider_model_info(
-            model=model, provider=LlmProviders(custom_llm_provider)
-        ):
+        if custom_llm_provider:
             provider_config = ProviderConfigManager.get_provider_model_info(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
-            if provider_config is not None:
-                _model_info = cast(
-                    dict, provider_config.get_model_info(model=model)
-                )
+
+        if _model_info is None and provider_config is not None:
+            _model_info = cast(
+                Optional[Dict],
+                provider_config.get_model_info(
+                    model=model, existing_model_info=_model_info
+                ),
+            )
+            if key is None:
+                key = "provider_specific_model_info"
         if _model_info is None or key is None:
             raise ValueError(
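
In plain terms, the hunk above now fetches the provider config whenever a provider is known and uses it as a fallback when the static cost map has no entry. A simplified sketch of that lookup order (illustrative only; `lookup_model_info` and its parameters are hypothetical, not the real helper):

```python
from typing import Any, Dict, Optional


def lookup_model_info(
    model: str,
    custom_llm_provider: Optional[str],
    static_map: Dict[str, Dict[str, Any]],
    provider_config: Optional[Any],  # e.g. a provider model-info config object
) -> Dict[str, Any]:
    # 1. Try the static cost map first: "provider/model", then the bare model name.
    info = static_map.get(f"{custom_llm_provider}/{model}") or static_map.get(model)

    # 2. New fallback: if nothing was found and the provider exposes a model-info
    #    config, ask it for the details (this is where an endpoint-backed lookup fits).
    if info is None and provider_config is not None:
        info = provider_config.get_model_info(model=model)

    # 3. Still nothing: raise, mirroring the ValueError path above.
    if info is None:
        raise ValueError(f"no model info found for {model}")
    return info
```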
@@ -5706,12 +5711,12 @@ def trim_messages(
         return messages
 
 
-def get_valid_models() -> List[str]:
+def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
 
     Args:
-        None
+        check_provider_endpoint: If True, will check the provider's endpoint for valid models.
 
     Returns:
         A list of valid LLMs
@@ -5725,22 +5730,36 @@ def get_valid_models() -> List[str]:
 
         for provider in litellm.provider_list:
             # edge case litellm has together_ai as a provider, it should be togetherai
-            provider = provider.replace("_", "")
+            env_provider_1 = provider.replace("_", "")
+            env_provider_2 = provider
 
             # litellm standardizes expected provider keys to
             # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key = f"{provider.upper()}_API_KEY"
-            if expected_provider_key in environ_keys:
+            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+            if (
+                expected_provider_key_1 in environ_keys
+                or expected_provider_key_2 in environ_keys
+            ):
                 # key is set
                 valid_providers.append(provider)
 
         for provider in valid_providers:
+            provider_config = ProviderConfigManager.get_provider_model_info(
+                model=None,
+                provider=LlmProviders(provider),
+            )
+
             if provider == "azure":
                 valid_models.append("Azure-LLM")
+            elif provider_config is not None and check_provider_endpoint:
+                valid_models.extend(provider_config.get_models())
             else:
                 models_for_provider = litellm.models_by_provider.get(provider, [])
                 valid_models.extend(models_for_provider)
 
         return valid_models
-    except Exception:
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
         return []  # NON-Blocking
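
The key-detection change in this hunk means both spellings of an underscore-separated provider's key are now accepted. A small standalone illustration (not litellm code) using `together_ai` as the example provider:

```python
# The user exported the underscore form of the key.
environ_keys = {"TOGETHER_AI_API_KEY"}

provider = "together_ai"
expected_provider_key_1 = f"{provider.replace('_', '').upper()}_API_KEY"  # TOGETHERAI_API_KEY
expected_provider_key_2 = f"{provider.upper()}_API_KEY"                   # TOGETHER_AI_API_KEY

# Previously only the de-underscored form was checked; now either spelling counts.
is_configured = (
    expected_provider_key_1 in environ_keys
    or expected_provider_key_2 in environ_keys
)
assert is_configured
```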
@@ -6291,11 +6310,14 @@ class ProviderConfigManager:
     @staticmethod
     def get_provider_model_info(
-        model: str,
+        model: Optional[str],
         provider: LlmProviders,
     ) -> Optional[BaseLLMModelInfo]:
         if LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIConfig()
+        elif LlmProviders.LITELLM_PROXY == provider:
+            return litellm.LiteLLMProxyChatConfig()
         return None
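
For a sense of what a provider config plugs in here, below is a hypothetical sketch of a model-info config whose `get_models()` is backed by a provider's `/v1/models` route. It is illustrative only: the class name, base URL, and env var are made up, and it does not reproduce litellm's actual `FireworksAIConfig`.

```python
import os
from typing import List, Optional

import httpx


class ExampleProviderModelInfo:
    """Hypothetical provider config: lists available models via /v1/models."""

    api_base = "https://api.example-provider.com/v1"  # placeholder base URL

    def get_models(self, api_key: Optional[str] = None) -> List[str]:
        api_key = api_key or os.getenv("EXAMPLE_PROVIDER_API_KEY")
        if api_key is None:
            return []  # no key set, nothing to list

        response = httpx.get(
            f"{self.api_base}/models",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=10.0,
        )
        response.raise_for_status()
        # Most OpenAI-compatible providers return {"data": [{"id": "..."}, ...]}.
        return [m["id"] for m in response.json().get("data", [])]
```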