Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Fix custom pricing - separate provider info from model info (#7990)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 34s
* fix(utils.py): initial commit fixing custom cost tracking. Refactors provider-specific model info out of `get_model_info` - this was causing custom costs to be registered incorrectly
* fix(utils.py): cleanup `_supports_factory` to check provider info if model info is None - some providers support features like vision across all models
* fix(utils.py): refactor to use `_supports_factory`
* test: update testing
* fix: fix linting errors
* test: fix testing
parent d9b8100cca
commit 03eef5a2a0

10 changed files with 118 additions and 91 deletions
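For context on what this commit fixes: custom pricing is registered through `litellm.register_model`, which writes per-model cost entries into `litellm.model_cost`. A minimal sketch (the model name and prices are illustrative, not from this commit):

```python
import litellm

# Register custom per-token pricing for a model that is not in the public
# model_prices_and_context_window.json map (all values illustrative).
litellm.register_model(
    {
        "my-org/my-custom-model": {
            "litellm_provider": "openai",
            "mode": "chat",
            "max_tokens": 8192,
            "input_cost_per_token": 0.0000015,
            "output_cost_per_token": 0.000002,
        }
    }
)

print(litellm.model_cost["my-org/my-custom-model"]["input_cost_per_token"])
```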
@@ -149,6 +149,7 @@ from litellm.types.utils import (
     ModelResponse,
     ModelResponseStream,
     ProviderField,
+    ProviderSpecificModelInfo,
     StreamingChoices,
     TextChoices,
     TextCompletionResponse,
@@ -1898,6 +1899,13 @@ def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str)
         verbose_logger.debug(
             f"Model not found or error in checking {key} support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
         )
+
+        provider_info = get_provider_info(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        if provider_info is not None and provider_info.get(key, False) is True:
+            return True
         return False
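Pieced together, the refactored helper looks roughly like this. A sketch only: the body of the `try` block is inferred from the old `supports_vision` implementation shown in the next hunk, and the import paths are assumptions:

```python
from typing import Optional

import litellm
from litellm._logging import verbose_logger
from litellm.utils import get_provider_info  # added in this commit


def _supports_factory(model: str, custom_llm_provider: Optional[str], key: str) -> bool:
    # Happy path: resolve the provider and read the per-model capability flag.
    try:
        model, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=model, custom_llm_provider=custom_llm_provider
        )
        model_info = litellm.get_model_info(
            model=model, custom_llm_provider=custom_llm_provider
        )
        return model_info.get(key, False) is True
    except Exception as e:
        verbose_logger.debug(
            f"Model not found or error in checking {key} support. Error: {str(e)}"
        )
        # New fallback: some providers support a feature (e.g. vision) across
        # all of their models, so consult provider-level info before giving up.
        provider_info = get_provider_info(
            model=model, custom_llm_provider=custom_llm_provider
        )
        return provider_info is not None and provider_info.get(key, False) is True
```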
@@ -1958,23 +1966,11 @@ def supports_vision(model: str, custom_llm_provider: Optional[str] = None) -> bool:
     Returns:
         bool: True if the model supports vision, False otherwise.
     """
-    try:
-        model, custom_llm_provider, _, _ = litellm.get_llm_provider(
-            model=model, custom_llm_provider=custom_llm_provider
-        )
-
-        model_info = litellm.get_model_info(
-            model=model, custom_llm_provider=custom_llm_provider
-        )
-
-        if model_info.get("supports_vision", False) is True:
-            return True
-        return False
-    except Exception as e:
-        verbose_logger.error(
-            f"Model not found or error in checking vision support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
-        )
-        return False
+    return _supports_factory(
+        model=model,
+        custom_llm_provider=custom_llm_provider,
+        key="supports_vision",
+    )
 
 
 def supports_embedding_image_input(
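Callers are unaffected by the refactor; a quick usage sketch (model names illustrative):

```python
import litellm

# supports_vision still returns a plain bool and swallows lookup errors.
print(litellm.supports_vision(model="gpt-4o"))
print(litellm.supports_vision(model="gpt-4o", custom_llm_provider="openai"))
```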
@@ -2037,6 +2033,7 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915
         },
     }
     """
 
     loaded_model_cost = {}
     if isinstance(model_cost, dict):
         loaded_model_cost = model_cost
@@ -2054,6 +2051,9 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915
         ## override / add new keys to the existing model cost dictionary
         updated_dictionary = _update_dictionary(existing_model, value)
         litellm.model_cost.setdefault(model_cost_key, {}).update(updated_dictionary)
+        verbose_logger.debug(
+            f"added/updated model={model_cost_key} in litellm.model_cost: {model_cost_key}"
+        )
         # add new model names to provider lists
         if value.get("litellm_provider") == "openai":
             if key not in litellm.open_ai_chat_completion_models:
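The override path merges the incoming keys into any existing cost entry; conceptually the merge behaves like this (a sketch of the semantics, not litellm's actual `_update_dictionary`):

```python
def _update_dictionary(existing_model: dict, value: dict) -> dict:
    # Sketch: keys from the newly registered entry win; keys already present
    # in the existing model-cost entry are kept unless overridden.
    updated = dict(existing_model)
    updated.update(value)
    return updated
```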
@@ -4048,6 +4048,26 @@ def _cached_get_model_info_helper(
     return _get_model_info_helper(model=model, custom_llm_provider=custom_llm_provider)
 
 
+def get_provider_info(
+    model: str, custom_llm_provider: Optional[str]
+) -> Optional[ProviderSpecificModelInfo]:
+    ## PROVIDER-SPECIFIC INFORMATION
+    # if custom_llm_provider == "predibase":
+    #     _model_info["supports_response_schema"] = True
+    provider_config: Optional[BaseLLMModelInfo] = None
+    if custom_llm_provider and custom_llm_provider in LlmProvidersSet:
+        # Check if the provider string exists in LlmProviders enum
+        provider_config = ProviderConfigManager.get_provider_model_info(
+            model=model, provider=LlmProviders(custom_llm_provider)
+        )
+
+    model_info: Optional[ProviderSpecificModelInfo] = None
+    if provider_config:
+        model_info = provider_config.get_provider_info(model=model)
+
+    return model_info
+
+
 def _get_model_info_helper(  # noqa: PLR0915
     model: str, custom_llm_provider: Optional[str] = None
 ) -> ModelInfoBase:
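The new helper can be exercised directly. A sketch: the import path assumes the function lands in `litellm/utils.py` alongside `_get_model_info_helper`, and the model/provider names are illustrative:

```python
from litellm.utils import get_provider_info

# Provider-wide capabilities (e.g. "this provider supports vision on every
# model") now live behind the provider's config class instead of being mixed
# into each model's cost entry.
provider_info = get_provider_info(
    model="my-model", custom_llm_provider="fireworks_ai"
)
if provider_info is not None:
    print(provider_info.get("supports_vision", False))
```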
@@ -4071,6 +4091,11 @@ def _get_model_info_helper(  # noqa: PLR0915
     potential_model_names = _get_potential_model_names(
         model=model, custom_llm_provider=custom_llm_provider
     )
+
+    verbose_logger.debug(
+        f"checking potential_model_names in litellm.model_cost: {potential_model_names}"
+    )
+
     combined_model_name = potential_model_names["combined_model_name"]
     stripped_model_name = potential_model_names["stripped_model_name"]
     combined_stripped_model_name = potential_model_names[
@@ -4111,7 +4136,6 @@ def _get_model_info_helper(  # noqa: PLR0915
 
     _model_info: Optional[Dict[str, Any]] = None
     key: Optional[str] = None
-    provider_config: Optional[BaseLLMModelInfo] = None
 
     if combined_model_name in litellm.model_cost:
         key = combined_model_name
@@ -4121,6 +4145,7 @@ def _get_model_info_helper(  # noqa: PLR0915
         ):
             _model_info = None
     if _model_info is None and model in litellm.model_cost:
         key = model
         _model_info = _get_model_info_from_model_cost(key=key)
         if not _check_provider_match(
@@ -4131,6 +4156,7 @@ def _get_model_info_helper(  # noqa: PLR0915
         _model_info is None
         and combined_stripped_model_name in litellm.model_cost
     ):
         key = combined_stripped_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
         if not _check_provider_match(
@@ -4138,6 +4164,7 @@ def _get_model_info_helper(  # noqa: PLR0915
         ):
             _model_info = None
     if _model_info is None and stripped_model_name in litellm.model_cost:
         key = stripped_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
         if not _check_provider_match(
@@ -4145,6 +4172,7 @@ def _get_model_info_helper(  # noqa: PLR0915
         ):
             _model_info = None
     if _model_info is None and split_model in litellm.model_cost:
         key = split_model
         _model_info = _get_model_info_from_model_cost(key=key)
         if not _check_provider_match(
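Taken together, the hunks above show `_get_model_info_helper` trying progressively looser names until one matches. A schematic simplification (the real code tracks `key` and resets `_model_info` in place, and the provider-match check here stands in for `_check_provider_match`, whose exact signature is an assumption):

```python
from typing import Any, Dict, Optional

import litellm


def _lookup_model_cost_entry(
    names: Dict[str, str], custom_llm_provider: Optional[str]
) -> Optional[Dict[str, Any]]:
    # Fallback chain, loosest last. "model" is passed separately in the real
    # code; it is folded into the names dict here for brevity.
    for name_key in (
        "combined_model_name",           # "<provider>/<model>"
        "model",                         # the raw model string
        "combined_stripped_model_name",  # "<provider>/<model minus version suffix>"
        "stripped_model_name",           # model minus version suffix
        "split_model",                   # model with the provider prefix split off
    ):
        candidate = names.get(name_key)
        if candidate and candidate in litellm.model_cost:
            entry = litellm.model_cost[candidate]
            # A hit registered under a different provider is discarded.
            if custom_llm_provider is None or entry.get("litellm_provider") in (
                None,
                custom_llm_provider,
            ):
                return entry
    return None
```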
@@ -4152,29 +4180,11 @@ def _get_model_info_helper(  # noqa: PLR0915
         ):
             _model_info = None
 
-    if custom_llm_provider and custom_llm_provider in LlmProvidersSet:
-        # Check if the provider string exists in LlmProviders enum
-        provider_config = ProviderConfigManager.get_provider_model_info(
-            model=model, provider=LlmProviders(custom_llm_provider)
-        )
-
-    if _model_info is None and provider_config is not None:
-        _model_info = cast(
-            Optional[Dict],
-            provider_config.get_model_info(
-                model=model, existing_model_info=_model_info
-            ),
-        )
-        key = "provider_specific_model_info"
     if _model_info is None or key is None:
         raise ValueError(
             "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
         )
 
-    ## PROVIDER-SPECIFIC INFORMATION
-    if custom_llm_provider == "predibase":
-        _model_info["supports_response_schema"] = True
-
     _input_cost_per_token: Optional[float] = _model_info.get(
         "input_cost_per_token"
     )
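With the predibase special case removed from `_get_model_info_helper`, provider-wide flags are expected to flow through the provider-info path instead. A sketch, assuming `supports_response_schema` is among the helpers refactored onto `_supports_factory` (model name illustrative):

```python
import litellm

# When the model itself is not mapped, the provider-level fallback in
# _supports_factory can still answer the capability question.
print(
    litellm.supports_response_schema(
        model="llama-3-1-8b-instruct", custom_llm_provider="predibase"
    )
)
```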
@@ -4357,6 +4367,8 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
         custom_llm_provider=custom_llm_provider,
     )
 
+    verbose_logger.debug(f"model_info: {_model_info}")
+
     returned_model_info = ModelInfo(
         **_model_info, supported_openai_params=supported_openai_params
     )
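End to end, `get_model_info` now returns only model-level info plus `supported_openai_params`; a usage sketch (model name illustrative):

```python
import litellm

info = litellm.get_model_info(model="gpt-4o", custom_llm_provider="openai")
# Per-model cost keys stay in ModelInfo; provider-wide capability info is
# fetched separately via get_provider_info.
print(info["input_cost_per_token"], info["output_cost_per_token"])
```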
|
Loading…
Add table
Add a link
Reference in a new issue