Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(cost_calculator.py): handle custom pricing at deployment level fo… (#9855)
* fix(cost_calculator.py): handle custom pricing at deployment level for router
* test: add unit tests
* fix(router.py): show custom pricing on UI check correct model str
* fix: fix linting error
* docs(custom_pricing.md): clarify custom pricing for proxy

  Fixes https://github.com/BerriAI/litellm/issues/8573#issuecomment-2790420740
* test: update code qa test
* fix: cleanup traceback
* fix: handle litellm param custom pricing
* test: update test
* fix(cost_calculator.py): add router model id to list of potential model names
* fix(cost_calculator.py): fix router model id check
* fix: router.py - maintain older model registry approach
* fix: fix ruff check
* fix(router.py): router get deployment info add custom values to mapped dict
* test: update test
* fix(utils.py): update only if value is non-null
* test: add unit test
parent 0c5b4aa96d
commit 0dbd663877

16 changed files with 193 additions and 37 deletions
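For context, "custom pricing at the deployment level" means cost fields set directly in a deployment's litellm_params rather than looked up from the global model cost map. A minimal sketch of such a Router setup (the model name, api_base, key, and prices below are illustrative placeholders, not values from this commit):

```python
# Hypothetical deployment-level custom pricing on a litellm Router.
# Model name, api_base, api_key, and per-token prices are placeholders.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-gpt-4o",
            "litellm_params": {
                "model": "azure/my-gpt-4o-deployment",
                "api_base": "https://example-resource.openai.azure.com/",
                "api_key": "sk-placeholder",
                # deployment-level custom pricing (USD per token)
                "input_cost_per_token": 0.0000025,
                "output_cost_per_token": 0.00001,
            },
        }
    ]
)
```

With this fix, cost calculation for responses served by that deployment should pick up these deployment-level prices instead of the base model's defaults.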
```diff
@@ -116,6 +116,7 @@ from litellm.types.router import (
     AllowedFailsPolicy,
     AssistantsTypedDict,
     CredentialLiteLLMParams,
+    CustomPricingLiteLLMParams,
     CustomRoutingStrategyBase,
     Deployment,
     DeploymentTypedDict,
@@ -132,6 +133,7 @@ from litellm.types.router import (
 )
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import GenericBudgetConfigType
+from litellm.types.utils import ModelInfo
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import (
```
```diff
@@ -3324,7 +3326,6 @@ class Router:

                 return response
             except Exception as new_exception:
-                traceback.print_exc()
                 parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                 verbose_router_logger.error(
                     "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
```
```diff
@@ -4301,7 +4302,20 @@ class Router:
             model_info=_model_info,
         )
+
+        for field in CustomPricingLiteLLMParams.model_fields.keys():
+            if deployment.litellm_params.get(field) is not None:
+                _model_info[field] = deployment.litellm_params[field]
+
+        ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+        model_id = deployment.model_info.id
+        if model_id is not None:
+            litellm.register_model(
+                model_cost={
+                    model_id: _model_info,
+                }
+            )

         ## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
         _model_name = deployment.litellm_params.model
         if deployment.litellm_params.custom_llm_provider is not None:
             _model_name = (
```
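The added block above copies any CustomPricingLiteLLMParams fields present in the deployment's litellm_params into _model_info, then registers that info in litellm's model cost map under the deployment's id. A hedged sketch of what that registration amounts to for one deployment (the id and cost values are made up for illustration):

```python
# Rough equivalent of the per-deployment registration above.
# The deployment id and prices are illustrative, not taken from the diff.
import litellm

deployment_model_id = "3e5b7a1c-example"  # stands in for deployment.model_info.id

litellm.register_model(
    model_cost={
        deployment_model_id: {
            "input_cost_per_token": 0.0000025,
            "output_cost_per_token": 0.00001,
        }
    }
)
```

Registering under the deployment id is what lets cost lookups resolve custom pricing by router model id, per the cost_calculator.py notes in the commit message.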
```diff
@@ -4802,6 +4816,42 @@ class Router:
         model_name = model_info["model_name"]
         return self.get_model_list(model_name=model_name)

+    def get_deployment_model_info(
+        self, model_id: str, model_name: str
+    ) -> Optional[ModelInfo]:
+        """
+        For a given model id, return the model info
+
+        1. Check if model_id is in model info
+        2. If not, check if litellm model name is in model info
+        3. If not, return None
+        """
+        from litellm.utils import _update_dictionary
+
+        model_info: Optional[ModelInfo] = None
+        litellm_model_name_model_info: Optional[ModelInfo] = None
+
+        try:
+            model_info = litellm.get_model_info(model=model_id)
+        except Exception:
+            pass
+
+        try:
+            litellm_model_name_model_info = litellm.get_model_info(model=model_name)
+        except Exception:
+            pass
+
+        if model_info is not None and litellm_model_name_model_info is not None:
+            model_info = cast(
+                ModelInfo,
+                _update_dictionary(
+                    cast(dict, litellm_model_name_model_info).copy(),
+                    cast(dict, model_info),
+                ),
+            )
+
+        return model_info
+
     def _set_model_group_info(  # noqa: PLR0915
         self, model_group: str, user_facing_model_group_name: str
     ) -> Optional[ModelGroupInfo]:
```
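get_deployment_model_info merges two lookups: model info keyed by the deployment id (deployment-specific values) and model info keyed by the underlying litellm model name (base values), with the deployment-specific values winning. Together with the commit note "fix(utils.py): update only if value is non-null", the merge behaves roughly like this standalone sketch (merge_model_info is a stand-in for litellm.utils._update_dictionary, whose exact behavior is assumed here):

```python
# Standalone sketch of the merge precedence, under the assumption that
# overrides win and null override values are skipped.
from typing import Optional


def merge_model_info(base: Optional[dict], override: Optional[dict]) -> Optional[dict]:
    """Overlay deployment-specific values on top of base model info."""
    if base is None:
        return override
    if override is None:
        return base
    merged = base.copy()
    for key, value in override.items():
        if value is not None:  # "update only if value is non-null"
            merged[key] = value
    return merged


base_info = {"max_tokens": 4096, "input_cost_per_token": 0.000005}
deployment_info = {"input_cost_per_token": 0.0000025, "output_cost_per_token": None}
print(merge_model_info(base_info, deployment_info))
# -> {'max_tokens': 4096, 'input_cost_per_token': 2.5e-06}
```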
```diff
@@ -4860,9 +4910,16 @@

             # get model info
             try:
-                model_info = litellm.get_model_info(model=litellm_params.model)
+                model_id = model.get("model_info", {}).get("id", None)
+                if model_id is not None:
+                    model_info = self.get_deployment_model_info(
+                        model_id=model_id, model_name=litellm_params.model
+                    )
+                else:
+                    model_info = None
             except Exception:
                 model_info = None

             # get llm provider
             litellm_model, llm_provider = "", ""
             try:
```
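With this hunk, _set_model_group_info resolves model info through the deployment id when one exists, so surfaces that read ModelGroupInfo (such as the proxy UI) can show deployment-level custom pricing. A hypothetical direct call, assuming `router` is the Router instance sketched earlier and the id below exists:

```python
# Hedged usage sketch for the new helper; the id and model name are placeholders,
# and `router` is assumed to be an already-constructed litellm Router.
info = router.get_deployment_model_info(
    model_id="3e5b7a1c-example",
    model_name="azure/my-gpt-4o-deployment",
)
if info is not None:
    # ModelInfo is dict-like model metadata, so .get() works at runtime.
    print(info.get("input_cost_per_token"), info.get("output_cost_per_token"))
```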