fix(cost_calculator.py): handle custom pricing at deployment level for router (#9855)

* fix(cost_calculator.py): handle custom pricing at deployment level for router

* test: add unit tests

* fix(router.py): show custom pricing on UI

check correct model str

* fix: fix linting error

* docs(custom_pricing.md): clarify custom pricing for proxy

Fixes https://github.com/BerriAI/litellm/issues/8573#issuecomment-2790420740

* test: update code qa test

* fix: cleanup traceback

* fix: handle litellm param custom pricing

* test: update test

* fix(cost_calculator.py): add router model id to list of potential model names

* fix(cost_calculator.py): fix router model id check

* fix: router.py - maintain older model registry approach

* fix: fix ruff check

* fix(router.py): router get deployment info

add custom values to mapped dict

* test: update test

* fix(utils.py): update only if value is non-null

* test: add unit test
Krish Dholakia, 2025-04-09 22:13:10 -07:00 (committed by GitHub)
parent 0c5b4aa96d
commit 0dbd663877
16 changed files with 193 additions and 37 deletions
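The change set is easier to follow with a concrete picture of the end state. The sketch below is illustrative only (model name, id, and prices are made up) and mirrors the unit tests added later in this commit: deployment-level custom pricing set on the router is picked up by the cost calculator and surfaced on the response.

```python
# Illustrative sketch of the behavior this commit targets: per-deployment
# custom pricing on the router, reflected in the calculated response cost.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-claude",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "api_key": "test_api_key",
            },
            # deployment-level overrides of the model cost map
            "model_info": {
                "id": "my-unique-model-id",
                "input_cost_per_token": 0.000006,
                "output_cost_per_token": 0.00003,
            },
        }
    ]
)

response = router.completion(
    model="my-claude",
    messages=[{"role": "user", "content": "Hello, world!"}],
    mock_response=True,  # no real API call
)

# cost is computed from the deployment's custom pricing, not the global map
print(response._hidden_params["response_cost"])
```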

View file

@@ -26,10 +26,12 @@ model_list:
   - model_name: sagemaker-completion-model
     litellm_params:
       model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+    model_info:
       input_cost_per_second: 0.000420
   - model_name: sagemaker-embedding-model
     litellm_params:
       model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
+    model_info:
       input_cost_per_second: 0.000420
 ```
@@ -55,11 +57,33 @@ model_list:
       api_key: os.environ/AZURE_API_KEY
       api_base: os.environ/AZURE_API_BASE
       api_version: os.environ/AZURE_API_VERSION
+    model_info:
       input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
       output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
 ```
 
-### Debugging
+## Override Model Cost Map
+
+You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
+
+Just add a `model_info` key to your model in the config, and override the desired keys.
+
+Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
+
+```yaml
+model_list:
+  - model_name: "prod/claude-3-5-sonnet-20241022"
+    litellm_params:
+      model: "anthropic/claude-3-5-sonnet-20241022"
+      api_key: os.environ/ANTHROPIC_PROD_API_KEY
+    model_info:
+      input_cost_per_token: 0.000006
+      output_cost_per_token: 0.00003
+      cache_creation_input_token_cost: 0.0000075
+      cache_read_input_token_cost: 0.0000006
+```
+
+## Debugging
 
 If your custom pricing is not being used or you're seeing errors, please check the following:
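For SDK users who do not run the proxy, a similar override can likely be applied in Python by registering the prices into litellm's model cost map directly. This is a hedged sketch using `litellm.register_model` (the same call the router uses later in this commit), with the illustrative prices from the YAML example above; it assumes a partial cost dict is merged into the existing entry for a mapped model.

```python
# Sketch: the docs example above, applied from the Python SDK instead of the
# proxy config. Prices are the illustrative values from the YAML example.
import litellm

litellm.register_model(
    model_cost={
        "anthropic/claude-3-5-sonnet-20241022": {
            "input_cost_per_token": 0.000006,
            "output_cost_per_token": 0.00003,
            "cache_creation_input_token_cost": 0.0000075,
            "cache_read_input_token_cost": 0.0000006,
        }
    }
)
```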

View file

@@ -403,6 +403,7 @@ def _select_model_name_for_cost_calc(
     base_model: Optional[str] = None,
     custom_pricing: Optional[bool] = None,
     custom_llm_provider: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> Optional[str]:
     """
     1. If custom pricing is true, return received model name
@@ -417,12 +418,6 @@ def _select_model_name_for_cost_calc(
             model=model, custom_llm_provider=custom_llm_provider
         )
 
-    if custom_pricing is True:
-        return_model = model
-
-    if base_model is not None:
-        return_model = base_model
-
     completion_response_model: Optional[str] = None
     if completion_response is not None:
         if isinstance(completion_response, BaseModel):
@@ -430,6 +425,16 @@ def _select_model_name_for_cost_calc(
         elif isinstance(completion_response, dict):
             completion_response_model = completion_response.get("model", None)
 
     hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
+
+    if custom_pricing is True:
+        if router_model_id is not None and router_model_id in litellm.model_cost:
+            return_model = router_model_id
+        else:
+            return_model = model
+
+    if base_model is not None:
+        return_model = base_model
+
     if completion_response_model is None and hidden_params is not None:
         if (
             hidden_params.get("model", None) is not None
@@ -559,6 +564,7 @@ def completion_cost(  # noqa: PLR0915
     base_model: Optional[str] = None,
     standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
     litellm_model_name: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> float:
     """
     Calculate the cost of a given completion call for GPT-3.5-turbo, llama2, any litellm supported llm.
@@ -617,12 +623,12 @@ def completion_cost(  # noqa: PLR0915
         custom_llm_provider=custom_llm_provider,
         custom_pricing=custom_pricing,
         base_model=base_model,
+        router_model_id=router_model_id,
     )
 
     potential_model_names = [selected_model]
     if model is not None:
         potential_model_names.append(model)
 
     for idx, model in enumerate(potential_model_names):
         try:
             verbose_logger.info(
@@ -943,6 +949,7 @@ def response_cost_calculator(
     prompt: str = "",
     standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
     litellm_model_name: Optional[str] = None,
+    router_model_id: Optional[str] = None,
 ) -> float:
     """
     Returns
@@ -973,6 +980,8 @@ def response_cost_calculator(
             base_model=base_model,
             prompt=prompt,
             standard_built_in_tools_params=standard_built_in_tools_params,
+            litellm_model_name=litellm_model_name,
+            router_model_id=router_model_id,
         )
         return response_cost
     except Exception as e:
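To make the new selection order explicit, here is a minimal standalone sketch (a hypothetical helper, not the library function) of the precedence the hunks above implement: with custom pricing enabled, a router deployment id that has been registered in the cost map wins over the plain model name, and an explicit `base_model` still overrides both.

```python
# Minimal sketch of the model-name selection order used for cost calculation.
# This is an illustrative helper, not litellm's _select_model_name_for_cost_calc.
from typing import Optional


def pick_model_for_cost(
    model: str,
    registered_models: set,          # stand-in for litellm.model_cost keys
    custom_pricing: Optional[bool] = None,
    base_model: Optional[str] = None,
    router_model_id: Optional[str] = None,
) -> str:
    return_model = model
    if custom_pricing is True:
        # prefer the deployment id when it is registered in the cost map,
        # so deployment-level prices win over the generic model entry
        if router_model_id is not None and router_model_id in registered_models:
            return_model = router_model_id
        else:
            return_model = model
    if base_model is not None:
        # an explicit base_model still overrides everything else
        return_model = base_model
    return return_model


assert pick_model_for_cost(
    "anthropic/claude-3-5-sonnet-20240620",
    registered_models={"my-unique-model-id"},
    custom_pricing=True,
    router_model_id="my-unique-model-id",
) == "my-unique-model-id"
```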

View file

@@ -1,7 +1,6 @@
 from typing import Literal, Optional
 
 import litellm
-from litellm._logging import verbose_logger
 from litellm.exceptions import BadRequestError
 from litellm.types.utils import LlmProviders, LlmProvidersSet
@@ -43,9 +42,6 @@ def get_supported_openai_params(  # noqa: PLR0915
         provider_config = None
 
     if provider_config and request_type == "chat_completion":
-        verbose_logger.info(
-            f"using provider_config: {provider_config} for checking supported params"
-        )
         return provider_config.get_supported_openai_params(model=model)
 
     if custom_llm_provider == "bedrock":

View file

@@ -622,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
             ] = RawRequestTypedDict(
                 error=str(e),
             )
-            traceback.print_exc()
             _metadata[
                 "raw_request"
             ] = "Unable to Log \
@@ -906,6 +905,7 @@ class Logging(LiteLLMLoggingBaseClass):
         ],
         cache_hit: Optional[bool] = None,
         litellm_model_name: Optional[str] = None,
+        router_model_id: Optional[str] = None,
     ) -> Optional[float]:
         """
         Calculate response cost using result + logging object variables.
@@ -944,6 +944,7 @@ class Logging(LiteLLMLoggingBaseClass):
                 "custom_pricing": custom_pricing,
                 "prompt": prompt,
                 "standard_built_in_tools_params": self.standard_built_in_tools_params,
+                "router_model_id": router_model_id,
             }
         except Exception as e:  # error creating kwargs for cost calculation
             debug_info = StandardLoggingModelCostFailureDebugInformation(

View file

@@ -36,11 +36,16 @@ class ResponseMetadata:
         self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
     ) -> None:
         """Set hidden parameters on the response"""
+
+        ## ADD OTHER HIDDEN PARAMS
+        model_id = kwargs.get("model_info", {}).get("id", None)
         new_params = {
             "litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
-            "model_id": kwargs.get("model_info", {}).get("id", None),
             "api_base": get_api_base(model=model or "", optional_params=kwargs),
-            "response_cost": logging_obj._response_cost_calculator(result=self.result),
+            "model_id": model_id,
+            "response_cost": logging_obj._response_cost_calculator(
+                result=self.result, litellm_model_name=model, router_model_id=model_id
+            ),
             "additional_headers": process_response_headers(
                 self._get_value_from_hidden_params("additional_headers") or {}
             ),

View file

@ -1,7 +1,6 @@
import copy import copy
import json import json
import re import re
import traceback
import uuid import uuid
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
data=base64_data, data=base64_data,
) )
except Exception as e: except Exception as e:
traceback.print_exc()
if "Error: Unable to fetch image from URL" in str(e): if "Error: Unable to fetch image from URL" in str(e):
raise e raise e
raise Exception( raise Exception(

View file

@@ -100,7 +100,6 @@ async def cache_ping():
     except Exception as e:
         import traceback
 
-        traceback.print_exc()
         error_message = {
             "message": f"Service Unhealthy ({str(e)})",
             "litellm_cache_params": safe_dumps(litellm_cache_params),

View file

@@ -816,9 +816,6 @@ async def add_member_to_organization(
         return user_object, organization_membership
 
     except Exception as e:
-        import traceback
-
-        traceback.print_exc()
         raise ValueError(
             f"Error adding member={member} to organization={organization_id}: {e}"
         )

View file

@@ -116,6 +116,7 @@ from litellm.types.router import (
     AllowedFailsPolicy,
     AssistantsTypedDict,
     CredentialLiteLLMParams,
+    CustomPricingLiteLLMParams,
     CustomRoutingStrategyBase,
     Deployment,
     DeploymentTypedDict,
@@ -132,6 +133,7 @@ from litellm.types.router import (
 )
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import GenericBudgetConfigType
+from litellm.types.utils import ModelInfo
 from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.utils import StandardLoggingPayload
 from litellm.utils import (
@@ -3324,7 +3326,6 @@ class Router:
                 return response
             except Exception as new_exception:
-                traceback.print_exc()
                 parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
                 verbose_router_logger.error(
                     "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
@@ -4301,7 +4302,20 @@ class Router:
             model_info=_model_info,
         )
 
+        for field in CustomPricingLiteLLMParams.model_fields.keys():
+            if deployment.litellm_params.get(field) is not None:
+                _model_info[field] = deployment.litellm_params[field]
+
         ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+        model_id = deployment.model_info.id
+        if model_id is not None:
+            litellm.register_model(
+                model_cost={
+                    model_id: _model_info,
+                }
+            )
+
+        ## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
         _model_name = deployment.litellm_params.model
         if deployment.litellm_params.custom_llm_provider is not None:
             _model_name = (
@@ -4802,6 +4816,42 @@ class Router:
             model_name = model_info["model_name"]
         return self.get_model_list(model_name=model_name)
 
+    def get_deployment_model_info(
+        self, model_id: str, model_name: str
+    ) -> Optional[ModelInfo]:
+        """
+        For a given model id, return the model info.
+
+        1. Check if the model id is in the model cost map
+        2. If not, check if the litellm model name is in the model cost map
+        3. If not, return None
+        """
+        from litellm.utils import _update_dictionary
+
+        model_info: Optional[ModelInfo] = None
+        litellm_model_name_model_info: Optional[ModelInfo] = None
+
+        try:
+            model_info = litellm.get_model_info(model=model_id)
+        except Exception:
+            pass
+
+        try:
+            litellm_model_name_model_info = litellm.get_model_info(model=model_name)
+        except Exception:
+            pass
+
+        if model_info is not None and litellm_model_name_model_info is not None:
+            model_info = cast(
+                ModelInfo,
+                _update_dictionary(
+                    cast(dict, litellm_model_name_model_info).copy(),
+                    cast(dict, model_info),
+                ),
+            )
+
+        return model_info
+
     def _set_model_group_info(  # noqa: PLR0915
         self, model_group: str, user_facing_model_group_name: str
     ) -> Optional[ModelGroupInfo]:
@@ -4860,9 +4910,16 @@ class Router:
             # get model info
             try:
-                model_info = litellm.get_model_info(model=litellm_params.model)
+                model_id = model.get("model_info", {}).get("id", None)
+                if model_id is not None:
+                    model_info = self.get_deployment_model_info(
+                        model_id=model_id, model_name=litellm_params.model
+                    )
+                else:
+                    model_info = None
             except Exception:
                 model_info = None
 
             # get llm provider
             litellm_model, llm_provider = "", ""
             try:
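A rough usage sketch of the two router changes above, assuming the deployment id is registered into the cost map as shown in the earlier hunk. The deployment name, id, and price are illustrative; the pattern mirrors the unit tests added at the end of this commit.

```python
# Sketch: deployment-level pricing registered under the deployment id, then
# looked up (and merged with the base model entry) via the new helper.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gemini-flash",
            "litellm_params": {"model": "gemini/gemini-1.5-flash"},
            # the deployment id is registered into litellm.model_cost, so
            # id-specific pricing can be merged over the mapped model's entry
            "model_info": {
                "id": "my-gemini-deployment",
                "input_cost_per_token": 0.000001,
            },
        }
    ]
)

info = router.get_deployment_model_info(
    model_id="my-gemini-deployment", model_name="gemini/gemini-1.5-flash"
)
# non-null id-level values override the mapped model's defaults
print(info["input_cost_per_token"] if info else None)
```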

View file

@@ -162,7 +162,15 @@ class CredentialLiteLLMParams(BaseModel):
     watsonx_region_name: Optional[str] = None
 
 
-class GenericLiteLLMParams(CredentialLiteLLMParams):
+class CustomPricingLiteLLMParams(BaseModel):
+    ## CUSTOM PRICING ##
+    input_cost_per_token: Optional[float] = None
+    output_cost_per_token: Optional[float] = None
+    input_cost_per_second: Optional[float] = None
+    output_cost_per_second: Optional[float] = None
+
+
+class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     """
     LiteLLM Params without 'model' arg (used across completion / assistants api)
     """
@@ -184,12 +192,6 @@ class GenericLiteLLMParams(CredentialLiteLLMParams):
     ## LOGGING PARAMS ##
     litellm_trace_id: Optional[str] = None
 
-    ## CUSTOM PRICING ##
-    input_cost_per_token: Optional[float] = None
-    output_cost_per_token: Optional[float] = None
-    input_cost_per_second: Optional[float] = None
-    output_cost_per_second: Optional[float] = None
-
     max_file_size_mb: Optional[float] = None
 
     # Deployment budgets
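The router hunk earlier iterates `CustomPricingLiteLLMParams.model_fields` to copy deployment-level prices into the registered model info. A quick sketch of what that iteration sees with the new class, assuming pydantic v2's `model_fields`:

```python
from litellm.types.router import CustomPricingLiteLLMParams

# pydantic v2 exposes declared fields via model_fields; these are the keys
# the router copies from litellm_params into the registered model info
print(list(CustomPricingLiteLLMParams.model_fields.keys()))
# expected: ['input_cost_per_token', 'output_cost_per_token',
#            'input_cost_per_second', 'output_cost_per_second']
```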

View file

@@ -2245,7 +2245,8 @@ def supports_embedding_image_input(
 ####### HELPER FUNCTIONS ################
 def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
     for k, v in new_dict.items():
-        existing_dict[k] = v
+        if v is not None:
+            existing_dict[k] = v
 
     return existing_dict
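A tiny illustration of the non-null merge enforced above (a standalone copy of the helper; values made up): `None` entries in the override no longer wipe out existing keys, which is what lets deployment-level pricing overlay the base model info safely.

```python
# Standalone copy of the helper, showing that None values no longer clobber
# existing keys when deployment overrides are merged over base model info.
from typing import Dict


def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
    for k, v in new_dict.items():
        if v is not None:
            existing_dict[k] = v
    return existing_dict


base = {"input_cost_per_token": 0.000001, "output_cost_per_token": 0.000002}
override = {"input_cost_per_token": 0.000006, "output_cost_per_token": None}
assert _update_dictionary(base.copy(), override) == {
    "input_cost_per_token": 0.000006,   # overridden
    "output_cost_per_token": 0.000002,  # kept, None ignored
}
```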

View file

@@ -31,7 +31,7 @@ def get_all_functions_called_in_tests(base_dir):
     specifically in files containing the word 'router'.
     """
     called_functions = set()
-    test_dirs = ["local_testing", "router_unit_tests"]
+    test_dirs = ["local_testing", "router_unit_tests", "litellm"]
 
     for test_dir in test_dirs:
         dir_path = os.path.join(base_dir, test_dir)

View file

@@ -151,3 +151,63 @@ def test_handle_realtime_stream_cost_calculation():
         litellm_model_name="gpt-3.5-turbo",
     )
     assert cost == 0.0  # No usage, no cost
+
+
+def test_custom_pricing_with_router_model_id():
+    from litellm import Router
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "prod/claude-3-5-sonnet-20240620",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "api_key": "test_api_key",
+                },
+                "model_info": {
+                    "id": "my-unique-model-id",
+                    "input_cost_per_token": 0.000006,
+                    "output_cost_per_token": 0.00003,
+                    "cache_creation_input_token_cost": 0.0000075,
+                    "cache_read_input_token_cost": 0.0000006,
+                },
+            },
+            {
+                "model_name": "claude-3-5-sonnet-20240620",
+                "litellm_params": {
+                    "model": "anthropic/claude-3-5-sonnet-20240620",
+                    "api_key": "test_api_key",
+                },
+                "model_info": {
+                    "input_cost_per_token": 100,
+                    "output_cost_per_token": 200,
+                },
+            },
+        ]
+    )
+
+    result = router.completion(
+        model="claude-3-5-sonnet-20240620",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response=True,
+    )
+
+    result_2 = router.completion(
+        model="prod/claude-3-5-sonnet-20240620",
+        messages=[{"role": "user", "content": "Hello, world!"}],
+        mock_response=True,
+    )
+
+    assert (
+        result._hidden_params["response_cost"]
+        > result_2._hidden_params["response_cost"]
+    )
+
+    model_info = router.get_deployment_model_info(
+        model_id="my-unique-model-id", model_name="anthropic/claude-3-5-sonnet-20240620"
+    )
+    assert model_info is not None
+    assert model_info["input_cost_per_token"] == 0.000006
+    assert model_info["output_cost_per_token"] == 0.00003
+    assert model_info["cache_creation_input_token_cost"] == 0.0000075
+    assert model_info["cache_read_input_token_cost"] == 0.0000006

View file

@@ -2954,9 +2954,6 @@ def test_cost_calculator_with_custom_pricing():
 @pytest.mark.asyncio
 async def test_cost_calculator_with_custom_pricing_router(model_item, custom_pricing):
     from litellm import Router
 
-    litellm._turn_on_debug()
-
     if custom_pricing == "litellm_params":
         model_item["litellm_params"]["input_cost_per_token"] = 0.0000008
         model_item["litellm_params"]["output_cost_per_token"] = 0.0000032

View file

@@ -314,12 +314,14 @@ def test_get_model_info_custom_model_router():
                     "input_cost_per_token": 1,
                     "output_cost_per_token": 1,
                     "model": "openai/meta-llama/Meta-Llama-3-8B-Instruct",
-                    "model_id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
                 },
+                "model_info": {
+                    "id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
+                },
             }
         ]
     )
-    info = get_model_info("openai/meta-llama/Meta-Llama-3-8B-Instruct")
+    info = get_model_info("c20d603e-1166-4e0f-aa65-ed9c476ad4ca")
     print("info", info)
     assert info is not None

View file

@@ -451,3 +451,11 @@ def test_router_get_deployment_credentials():
     credentials = router.get_deployment_credentials(model_id="1")
     assert credentials is not None
     assert credentials["api_key"] == "123"
+
+
+def test_router_get_deployment_model_info():
+    router = Router(
+        model_list=[
+            {"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}, "model_info": {"id": "1"}}
+        ]
+    )
+    model_info = router.get_deployment_model_info(
+        model_id="1", model_name="gemini/gemini-1.5-flash"
+    )
+    assert model_info is not None