fix(cost_calculator.py): handle custom pricing at deployment level for router (#9855)

* fix(cost_calculator.py): handle custom pricing at deployment level for router

* test: add unit tests

* fix(router.py): show custom pricing on UI

check correct model str

* fix: fix linting error

* docs(custom_pricing.md): clarify custom pricing for proxy

Fixes https://github.com/BerriAI/litellm/issues/8573#issuecomment-2790420740

* test: update code qa test

* fix: cleanup traceback

* fix: handle litellm param custom pricing

* test: update test

* fix(cost_calculator.py): add router model id to list of potential model names

* fix(cost_calculator.py): fix router model id check

* fix: router.py - maintain older model registry approach

* fix: fix ruff check

* fix(router.py): router get deployment info

add custom values to mapped dict

* test: update test

* fix(utils.py): update only if value is non-null

* test: add unit test
This commit is contained in:
Krish Dholakia 2025-04-09 22:13:10 -07:00 committed by GitHub
parent baa9bd6338
commit e1eb5e32c1
16 changed files with 193 additions and 37 deletions

View file

@ -116,6 +116,7 @@ from litellm.types.router import (
AllowedFailsPolicy,
AssistantsTypedDict,
CredentialLiteLLMParams,
CustomPricingLiteLLMParams,
CustomRoutingStrategyBase,
Deployment,
DeploymentTypedDict,
@ -132,6 +133,7 @@ from litellm.types.router import (
)
from litellm.types.services import ServiceTypes
from litellm.types.utils import GenericBudgetConfigType
from litellm.types.utils import ModelInfo
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import (
@ -3324,7 +3326,6 @@ class Router:
return response
except Exception as new_exception:
traceback.print_exc()
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
verbose_router_logger.error(
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
@ -4301,7 +4302,20 @@ class Router:
model_info=_model_info,
)
for field in CustomPricingLiteLLMParams.model_fields.keys():
if deployment.litellm_params.get(field) is not None:
_model_info[field] = deployment.litellm_params[field]
## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
model_id = deployment.model_info.id
if model_id is not None:
litellm.register_model(
model_cost={
model_id: _model_info,
}
)
## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
_model_name = deployment.litellm_params.model
if deployment.litellm_params.custom_llm_provider is not None:
_model_name = (
@ -4802,6 +4816,42 @@ class Router:
model_name = model_info["model_name"]
return self.get_model_list(model_name=model_name)
def get_deployment_model_info(
    self, model_id: str, model_name: str
) -> Optional[ModelInfo]:
    """
    Resolve the model info for a deployment, preferring deployment-level values.

    Lookup order:
    1. Look up model info under `model_id` via `litellm.get_model_info`.
    2. Look up model info under the litellm model name (`model_name`).
    3. If both exist, merge them with `_update_dictionary`, using a copy of
       the `model_name` info as the base and the `model_id` info as the update.
    4. If only one of them exists, return that one.
    5. If neither lookup succeeds, return None.

    Args:
        model_id: Deployment id to look up in the litellm model info.
        model_name: litellm model name to look up in the litellm model info.

    Returns:
        The resolved model info, or None when neither name can be resolved.
    """
    from litellm.utils import _update_dictionary

    def _lookup(model: str) -> Optional[ModelInfo]:
        # Best-effort lookup: a failed lookup is treated as "no info for
        # this name" and must not abort the resolution of the other name.
        try:
            return litellm.get_model_info(model=model)
        except Exception:
            return None

    model_info = _lookup(model_id)
    litellm_model_name_model_info = _lookup(model_name)

    if model_info is None:
        # Step 2 of the documented contract: fall back to the litellm model
        # name when the deployment id is unknown (previously this branch
        # discarded the model-name info and returned None).
        return litellm_model_name_model_info

    if litellm_model_name_model_info is None:
        return model_info

    # Copy the base so the model-name lookup result is not mutated by the merge.
    # NOTE(review): presumably `_update_dictionary` lets non-null deployment
    # values override the base entries - confirm against litellm.utils.
    return cast(
        ModelInfo,
        _update_dictionary(
            cast(dict, litellm_model_name_model_info).copy(),
            cast(dict, model_info),
        ),
    )
def _set_model_group_info( # noqa: PLR0915
self, model_group: str, user_facing_model_group_name: str
) -> Optional[ModelGroupInfo]:
@ -4860,9 +4910,16 @@ class Router:
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
model_id = model.get("model_info", {}).get("id", None)
if model_id is not None:
model_info = self.get_deployment_model_info(
model_id=model_id, model_name=litellm_params.model
)
else:
model_info = None
except Exception:
model_info = None
# get llm provider
litellm_model, llm_provider = "", ""
try: