Fix custom pricing - separate provider info from model info (#7990)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 34s

* fix(utils.py): initial commit fixing custom cost tracking

refactors out provider specific model info from `get_model_info` - this was causing custom costs to be registered incorrectly

* fix(utils.py): cleanup `_supports_factory` to check provider info, if model info is None

some providers support features like vision across all models

* fix(utils.py): refactor to use _supports_factory

* test: update testing

* fix: fix linting errors

* test: fix testing
This commit is contained in:
Krish Dholakia 2025-01-25 21:49:28 -08:00 committed by GitHub
parent d9b8100cca
commit 03eef5a2a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 118 additions and 91 deletions

View file

@ -149,6 +149,7 @@ from litellm.types.utils import (
ModelResponse,
ModelResponseStream,
ProviderField,
ProviderSpecificModelInfo,
StreamingChoices,
TextChoices,
TextCompletionResponse,
@ -1898,6 +1899,13 @@ def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str)
verbose_logger.debug(
f"Model not found or error in checking {key} support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
)
provider_info = get_provider_info(
model=model, custom_llm_provider=custom_llm_provider
)
if provider_info is not None and provider_info.get(key, False) is True:
return True
return False
@ -1958,23 +1966,11 @@ def supports_vision(model: str, custom_llm_provider: Optional[str] = None) -> bo
Returns:
bool: True if the model supports vision, False otherwise.
"""
try:
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
if model_info.get("supports_vision", False) is True:
return True
return False
except Exception as e:
verbose_logger.error(
f"Model not found or error in checking vision support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
)
return False
return _supports_factory(
model=model,
custom_llm_provider=custom_llm_provider,
key="supports_vision",
)
def supports_embedding_image_input(
@ -2037,6 +2033,7 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
},
}
"""
loaded_model_cost = {}
if isinstance(model_cost, dict):
loaded_model_cost = model_cost
@ -2054,6 +2051,9 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
## override / add new keys to the existing model cost dictionary
updated_dictionary = _update_dictionary(existing_model, value)
litellm.model_cost.setdefault(model_cost_key, {}).update(updated_dictionary)
verbose_logger.debug(
f"added/updated model={model_cost_key} in litellm.model_cost: {model_cost_key}"
)
# add new model names to provider lists
if value.get("litellm_provider") == "openai":
if key not in litellm.open_ai_chat_completion_models:
@ -4048,6 +4048,26 @@ def _cached_get_model_info_helper(
return _get_model_info_helper(model=model, custom_llm_provider=custom_llm_provider)
def get_provider_info(
model: str, custom_llm_provider: Optional[str]
) -> Optional[ProviderSpecificModelInfo]:
## PROVIDER-SPECIFIC INFORMATION
# if custom_llm_provider == "predibase":
# _model_info["supports_response_schema"] = True
provider_config: Optional[BaseLLMModelInfo] = None
if custom_llm_provider and custom_llm_provider in LlmProvidersSet:
# Check if the provider string exists in LlmProviders enum
provider_config = ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider)
)
model_info: Optional[ProviderSpecificModelInfo] = None
if provider_config:
model_info = provider_config.get_provider_info(model=model)
return model_info
def _get_model_info_helper( # noqa: PLR0915
model: str, custom_llm_provider: Optional[str] = None
) -> ModelInfoBase:
@ -4071,6 +4091,11 @@ def _get_model_info_helper( # noqa: PLR0915
potential_model_names = _get_potential_model_names(
model=model, custom_llm_provider=custom_llm_provider
)
verbose_logger.debug(
f"checking potential_model_names in litellm.model_cost: {potential_model_names}"
)
combined_model_name = potential_model_names["combined_model_name"]
stripped_model_name = potential_model_names["stripped_model_name"]
combined_stripped_model_name = potential_model_names[
@ -4111,7 +4136,6 @@ def _get_model_info_helper( # noqa: PLR0915
_model_info: Optional[Dict[str, Any]] = None
key: Optional[str] = None
provider_config: Optional[BaseLLMModelInfo] = None
if combined_model_name in litellm.model_cost:
key = combined_model_name
@ -4121,6 +4145,7 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and model in litellm.model_cost:
key = model
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4131,6 +4156,7 @@ def _get_model_info_helper( # noqa: PLR0915
_model_info is None
and combined_stripped_model_name in litellm.model_cost
):
key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4138,6 +4164,7 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:
key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4145,6 +4172,7 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if _model_info is None and split_model in litellm.model_cost:
key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
if not _check_provider_match(
@ -4152,29 +4180,11 @@ def _get_model_info_helper( # noqa: PLR0915
):
_model_info = None
if custom_llm_provider and custom_llm_provider in LlmProvidersSet:
# Check if the provider string exists in LlmProviders enum
provider_config = ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider)
)
if _model_info is None and provider_config is not None:
_model_info = cast(
Optional[Dict],
provider_config.get_model_info(
model=model, existing_model_info=_model_info
),
)
key = "provider_specific_model_info"
if _model_info is None or key is None:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
)
## PROVIDER-SPECIFIC INFORMATION
if custom_llm_provider == "predibase":
_model_info["supports_response_schema"] = True
_input_cost_per_token: Optional[float] = _model_info.get(
"input_cost_per_token"
)
@ -4357,6 +4367,8 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
custom_llm_provider=custom_llm_provider,
)
verbose_logger.debug(f"model_info: {_model_info}")
returned_model_info = ModelInfo(
**_model_info, supported_openai_params=supported_openai_params
)