Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
Support checking provider-specific /models endpoints for available models based on key (#7538)
* test(test_utils.py): initial test for valid models. Addresses https://github.com/BerriAI/litellm/issues/7525
* fix: test
* feat(fireworks_ai/transformation.py): support retrieving valid models from the Fireworks AI endpoint
* refactor(fireworks_ai/): support checking model info on the `/v1/models` route
* docs(set_keys.md): update docs to clarify check llm provider api usage
* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for IAM auth
* fix(watsonx): read the watsonx token from the env var
* fix: fix linting errors
* fix(utils.py): fix provider config check
* style: clean up unused imports
This commit is contained in:
parent cac06a32b8
commit f770dd0c95

12 changed files with 350 additions and 42 deletions (the hunks shown below appear to be from litellm/utils.py only)
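A minimal usage sketch of the new `check_provider_endpoint` flag described above, assuming the relevant *_API_KEY variables are already exported; the `get_valid_models` hunks further down show the underlying logic:

import litellm

# Unchanged default: providers are detected from *_API_KEY env vars and their
# models come from the static litellm.models_by_provider map.
static_models = litellm.get_valid_models()

# New in this commit: providers that expose a provider config (e.g. Fireworks AI)
# are asked for their live model list via the provider's own models endpoint.
live_models = litellm.get_valid_models(check_provider_endpoint=True)

# Models only discoverable via the provider endpoint.
print(sorted(set(live_models) - set(static_models)))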
@@ -4223,6 +4223,7 @@ def _get_model_info_helper( # noqa: PLR0915
         _model_info: Optional[Dict[str, Any]] = None
         key: Optional[str] = None
+        provider_config: Optional[BaseLLMModelInfo] = None
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
             _model_info = _get_model_info_from_model_cost(key=key)
@@ -4261,16 +4262,20 @@ def _get_model_info_helper( # noqa: PLR0915
                 model_info=_model_info, custom_llm_provider=custom_llm_provider
             ):
                 _model_info = None

-        if _model_info is None and ProviderConfigManager.get_provider_model_info(
-            model=model, provider=LlmProviders(custom_llm_provider)
-        ):
+        if custom_llm_provider:
             provider_config = ProviderConfigManager.get_provider_model_info(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
-            if provider_config is not None:
-                _model_info = cast(
-                    dict, provider_config.get_model_info(model=model)
-                )
+
+        if _model_info is None and provider_config is not None:
+            _model_info = cast(
+                Optional[Dict],
+                provider_config.get_model_info(
+                    model=model, existing_model_info=_model_info
+                ),
+            )
+            if key is None:
+                key = "provider_specific_model_info"
         if _model_info is None or key is None:
             raise ValueError(
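Restating the intent of the hunk above as a standalone sketch (the `lookup_model_info` helper and `ProviderModelInfo` class below are hypothetical stand-ins, not litellm's actual `_get_model_info_helper` or `BaseLLMModelInfo`): the static cost map is consulted first, and only when it has no entry does the provider config supply model info, with `key` falling back to a synthetic marker.

from typing import Any, Dict, Optional


class ProviderModelInfo:
    # Hypothetical stand-in for a provider config object.
    def get_model_info(
        self, model: str, existing_model_info: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        info: Dict[str, Any] = {"litellm_provider": "example_provider", "mode": "chat"}
        if existing_model_info:
            info.update(existing_model_info)
        return info


def lookup_model_info(
    model: str,
    model_cost_map: Dict[str, Dict[str, Any]],
    provider_config: Optional[ProviderModelInfo],
) -> Dict[str, Any]:
    # First choice: the static cost map, keyed by the model name.
    model_info = model_cost_map.get(model)
    key = model if model_info is not None else None

    # Fallback added by this commit: ask the provider config for model info.
    if model_info is None and provider_config is not None:
        model_info = provider_config.get_model_info(
            model=model, existing_model_info=model_info
        )
        if key is None:
            key = "provider_specific_model_info"

    if model_info is None or key is None:
        raise ValueError(f"No model info found for {model}")
    return model_info


# A model missing from the cost map now resolves through the provider config.
print(lookup_model_info("accounts/fireworks/models/example", {}, ProviderModelInfo()))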
@@ -5706,12 +5711,12 @@ def trim_messages(
     return messages


-def get_valid_models() -> List[str]:
+def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables

     Args:
-        None
+        check_provider_endpoint: If True, will check the provider's endpoint for valid models.

     Returns:
         A list of valid LLMs
@@ -5725,22 +5730,36 @@ def get_valid_models() -> List[str]:

         for provider in litellm.provider_list:
             # edge case litellm has together_ai as a provider, it should be togetherai
-            provider = provider.replace("_", "")
+            env_provider_1 = provider.replace("_", "")
+            env_provider_2 = provider

             # litellm standardizes expected provider keys to
             # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key = f"{provider.upper()}_API_KEY"
-            if expected_provider_key in environ_keys:
+            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+            if (
+                expected_provider_key_1 in environ_keys
+                or expected_provider_key_2 in environ_keys
+            ):
                 # key is set
                 valid_providers.append(provider)

         for provider in valid_providers:
+            provider_config = ProviderConfigManager.get_provider_model_info(
+                model=None,
+                provider=LlmProviders(provider),
+            )
+
             if provider == "azure":
                 valid_models.append("Azure-LLM")
+            elif provider_config is not None and check_provider_endpoint:
+                valid_models.extend(provider_config.get_models())
             else:
                 models_for_provider = litellm.models_by_provider.get(provider, [])
                 valid_models.extend(models_for_provider)
         return valid_models
-    except Exception:
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
         return []  # NON-Blocking
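The two-candidate env-var check above exists because some provider names contain underscores (e.g. `together_ai`): either the underscore-stripped or the literal spelling of the key is now accepted. A small standalone illustration (the `provider_has_key` helper is hypothetical, not a litellm function):

import os


def provider_has_key(provider: str) -> bool:
    # Mirror the check above: accept either the underscore-stripped spelling
    # (TOGETHERAI_API_KEY) or the literal spelling (TOGETHER_AI_API_KEY).
    candidates = {
        f"{provider.replace('_', '').upper()}_API_KEY",
        f"{provider.upper()}_API_KEY",
    }
    return any(key in os.environ for key in candidates)


os.environ.setdefault("TOGETHER_AI_API_KEY", "placeholder")
print(provider_has_key("together_ai"))  # True: matched via TOGETHER_AI_API_KEY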
@@ -6291,11 +6310,14 @@ class ProviderConfigManager:

     @staticmethod
     def get_provider_model_info(
-        model: str,
+        model: Optional[str],
         provider: LlmProviders,
     ) -> Optional[BaseLLMModelInfo]:
         if LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIConfig()
+        elif LlmProviders.LITELLM_PROXY == provider:
+            return litellm.LiteLLMProxyChatConfig()

         return None
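For the `check_provider_endpoint` path, a provider config returned here is expected to turn the configured key into a live model list via `get_models()`. A rough, hypothetical sketch of that shape, assuming an OpenAI-style models route (this is not litellm's actual `FireworksAIConfig` code; the base URL and response format are assumptions based on the commit message):

import os
from typing import List

import httpx


class ExampleProviderModelInfo:
    # Hypothetical provider config: lists models from an OpenAI-style /models route.
    api_base = "https://api.fireworks.ai/inference/v1"  # assumed base URL

    def get_models(self) -> List[str]:
        api_key = os.environ.get("FIREWORKS_AI_API_KEY")
        if api_key is None:
            return []
        response = httpx.get(
            f"{self.api_base}/models",
            headers={"Authorization": f"Bearer {api_key}"},
        )
        response.raise_for_status()
        # Assumed OpenAI-compatible payload: {"data": [{"id": "..."}, ...]}
        return [f"fireworks_ai/{item['id']}" for item in response.json().get("data", [])]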