fix(cost_calculator.py): handle custom pricing at deployment level for router (#9855)

* fix(cost_calculator.py): handle custom pricing at deployment level for router

* test: add unit tests

* fix(router.py): show custom pricing on UI

check correct model str

* fix: fix linting error

* docs(custom_pricing.md): clarify custom pricing for proxy

Fixes https://github.com/BerriAI/litellm/issues/8573#issuecomment-2790420740

* test: update code qa test

* fix: cleanup traceback

* fix: handle litellm param custom pricing

* test: update test

* fix(cost_calculator.py): add router model id to list of potential model names

* fix(cost_calculator.py): fix router model id check

* fix: router.py - maintain older model registry approach

* fix: fix ruff check

* fix(router.py): router get deployment info

add custom values to mapped dict

* test: update test

* fix(utils.py): update only if value is non-null

* test: add unit test
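
Usage sketch (adapted from the unit test added in this PR; `Router`, `mock_response=True`, and `_hidden_params["response_cost"]` are all exercised there): deployment-level pricing set via `model_info` now drives the reported response cost for router calls.

```python
# Minimal sketch, based on the unit test added in this PR: per-deployment
# custom pricing in `model_info` is used when computing response_cost.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "claude-3-5-sonnet-20240620",
            "litellm_params": {
                "model": "anthropic/claude-3-5-sonnet-20240620",
                "api_key": "test_api_key",  # placeholder; no real call is made
            },
            "model_info": {
                "id": "my-unique-model-id",
                "input_cost_per_token": 0.000006,
                "output_cost_per_token": 0.00003,
            },
        }
    ]
)

result = router.completion(
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Hello, world!"}],
    mock_response=True,  # mocked response, so the fake api_key is never used
)
print(result._hidden_params["response_cost"])  # priced with the custom costs above
```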
Krish Dholakia 2025-04-09 22:13:10 -07:00 committed by GitHub
parent 0c5b4aa96d
commit 0dbd663877
16 changed files with 193 additions and 37 deletions


@@ -26,10 +26,12 @@ model_list:
  - model_name: sagemaker-completion-model
    litellm_params:
      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
    model_info:
      input_cost_per_second: 0.000420
  - model_name: sagemaker-embedding-model
    litellm_params:
      model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
    model_info:
      input_cost_per_second: 0.000420
```
@@ -55,11 +57,33 @@ model_list:
      api_key: os.environ/AZURE_API_KEY
      api_base: os.environ/AZURE_API_BASE
      api_version: os.environ/AZURE_API_VERSION
    model_info:
      input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
      output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
```
### Debugging
## Override Model Cost Map
You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
Just add a `model_info` key to your model in the config, and override the desired keys.
Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
```yaml
model_list:
  - model_name: "prod/claude-3-5-sonnet-20241022"
    litellm_params:
      model: "anthropic/claude-3-5-sonnet-20241022"
      api_key: os.environ/ANTHROPIC_PROD_API_KEY
    model_info:
      input_cost_per_token: 0.000006
      output_cost_per_token: 0.00003
      cache_creation_input_token_cost: 0.0000075
      cache_read_input_token_cost: 0.0000006
```
## Debugging
If your custom pricing is not being used or you're seeing errors, please check the following:


@@ -403,6 +403,7 @@ def _select_model_name_for_cost_calc(
base_model: Optional[str] = None,
custom_pricing: Optional[bool] = None,
custom_llm_provider: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> Optional[str]:
"""
1. If custom pricing is true, return received model name
@@ -417,12 +418,6 @@
model=model, custom_llm_provider=custom_llm_provider
)
if custom_pricing is True:
return_model = model
if base_model is not None:
return_model = base_model
completion_response_model: Optional[str] = None
if completion_response is not None:
if isinstance(completion_response, BaseModel):
@@ -430,6 +425,16 @@
elif isinstance(completion_response, dict):
completion_response_model = completion_response.get("model", None)
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
if custom_pricing is True:
if router_model_id is not None and router_model_id in litellm.model_cost:
return_model = router_model_id
else:
return_model = model
if base_model is not None:
return_model = base_model
if completion_response_model is None and hidden_params is not None:
if (
hidden_params.get("model", None) is not None
@@ -559,6 +564,7 @@ def completion_cost( # noqa: PLR0915
base_model: Optional[str] = None,
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
litellm_model_name: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> float:
"""
Calculate the cost of a given completion call for GPT-3.5-turbo, llama2, or any litellm-supported LLM.
@@ -617,12 +623,12 @@
custom_llm_provider=custom_llm_provider,
custom_pricing=custom_pricing,
base_model=base_model,
router_model_id=router_model_id,
)
potential_model_names = [selected_model]
if model is not None:
potential_model_names.append(model)
for idx, model in enumerate(potential_model_names):
try:
verbose_logger.info(
@@ -943,6 +949,7 @@ def response_cost_calculator(
prompt: str = "",
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
litellm_model_name: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> float:
"""
Returns
@@ -973,6 +980,8 @@
base_model=base_model,
prompt=prompt,
standard_built_in_tools_params=standard_built_in_tools_params,
litellm_model_name=litellm_model_name,
router_model_id=router_model_id,
)
return response_cost
except Exception as e:
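
A rough sketch of the lookup-name precedence this hunk introduces; `pick_cost_lookup_name` and `registered_ids` are hypothetical stand-ins, not litellm internals. For a custom-priced deployment, a router model id that was registered in the cost map is preferred; otherwise the existing `base_model` / received-model fallback applies.

```python
# Hypothetical helper illustrating the precedence added in this change.
from typing import Optional


def pick_cost_lookup_name(
    model: str,
    base_model: Optional[str],
    router_model_id: Optional[str],
    registered_ids: set,  # stand-in for litellm.model_cost keys
) -> str:
    if router_model_id is not None and router_model_id in registered_ids:
        return router_model_id  # deployment-level pricing registered under the router id
    return base_model if base_model is not None else model


assert pick_cost_lookup_name("gpt-4o", None, "deploy-123", {"deploy-123"}) == "deploy-123"
assert pick_cost_lookup_name("gpt-4o", "azure/gpt-4o", None, set()) == "azure/gpt-4o"
```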


@@ -1,7 +1,6 @@
from typing import Literal, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.exceptions import BadRequestError
from litellm.types.utils import LlmProviders, LlmProvidersSet
@@ -43,9 +42,6 @@ def get_supported_openai_params( # noqa: PLR0915
provider_config = None
if provider_config and request_type == "chat_completion":
verbose_logger.info(
f"using provider_config: {provider_config} for checking supported params"
)
return provider_config.get_supported_openai_params(model=model)
if custom_llm_provider == "bedrock":


@@ -622,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
] = RawRequestTypedDict(
error=str(e),
)
traceback.print_exc()
_metadata[
"raw_request"
] = "Unable to Log \
@@ -906,6 +905,7 @@
],
cache_hit: Optional[bool] = None,
litellm_model_name: Optional[str] = None,
router_model_id: Optional[str] = None,
) -> Optional[float]:
"""
Calculate response cost using result + logging object variables.
@@ -944,6 +944,7 @@
"custom_pricing": custom_pricing,
"prompt": prompt,
"standard_built_in_tools_params": self.standard_built_in_tools_params,
"router_model_id": router_model_id,
}
except Exception as e: # error creating kwargs for cost calculation
debug_info = StandardLoggingModelCostFailureDebugInformation(


@@ -36,11 +36,16 @@ class ResponseMetadata:
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
) -> None:
"""Set hidden parameters on the response"""
## ADD OTHER HIDDEN PARAMS
model_id = kwargs.get("model_info", {}).get("id", None)
new_params = {
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
"model_id": kwargs.get("model_info", {}).get("id", None),
"api_base": get_api_base(model=model or "", optional_params=kwargs),
"response_cost": logging_obj._response_cost_calculator(result=self.result),
"model_id": model_id,
"response_cost": logging_obj._response_cost_calculator(
result=self.result, litellm_model_name=model, router_model_id=model_id
),
"additional_headers": process_response_headers(
self._get_value_from_hidden_params("additional_headers") or {}
),


@@ -1,7 +1,6 @@
import copy
import json
import re
import traceback
import uuid
import xml.etree.ElementTree as ET
from enum import Enum
@@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
data=base64_data,
)
except Exception as e:
traceback.print_exc()
if "Error: Unable to fetch image from URL" in str(e):
raise e
raise Exception(


@@ -100,7 +100,6 @@ async def cache_ping():
except Exception as e:
import traceback
traceback.print_exc()
error_message = {
"message": f"Service Unhealthy ({str(e)})",
"litellm_cache_params": safe_dumps(litellm_cache_params),


@@ -816,9 +816,6 @@ async def add_member_to_organization(
return user_object, organization_membership
except Exception as e:
import traceback
traceback.print_exc()
raise ValueError(
f"Error adding member={member} to organization={organization_id}: {e}"
)


@@ -116,6 +116,7 @@ from litellm.types.router import (
AllowedFailsPolicy,
AssistantsTypedDict,
CredentialLiteLLMParams,
CustomPricingLiteLLMParams,
CustomRoutingStrategyBase,
Deployment,
DeploymentTypedDict,
@@ -132,6 +133,7 @@ from litellm.types.router import (
)
from litellm.types.services import ServiceTypes
from litellm.types.utils import GenericBudgetConfigType
from litellm.types.utils import ModelInfo
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import (
@@ -3324,7 +3326,6 @@ class Router:
return response
except Exception as new_exception:
traceback.print_exc()
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
verbose_router_logger.error(
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
@@ -4301,7 +4302,20 @@ class Router:
model_info=_model_info,
)
for field in CustomPricingLiteLLMParams.model_fields.keys():
if deployment.litellm_params.get(field) is not None:
_model_info[field] = deployment.litellm_params[field]
## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
model_id = deployment.model_info.id
if model_id is not None:
litellm.register_model(
model_cost={
model_id: _model_info,
}
)
## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
_model_name = deployment.litellm_params.model
if deployment.litellm_params.custom_llm_provider is not None:
_model_name = (
@@ -4802,6 +4816,42 @@ class Router:
model_name = model_info["model_name"]
return self.get_model_list(model_name=model_name)
def get_deployment_model_info(
self, model_id: str, model_name: str
) -> Optional[ModelInfo]:
"""
For a given model id, return the model info
1. Check if model_id is in model info
2. If not, check if litellm model name is in model info
3. If not, return None
"""
from litellm.utils import _update_dictionary
model_info: Optional[ModelInfo] = None
litellm_model_name_model_info: Optional[ModelInfo] = None
try:
model_info = litellm.get_model_info(model=model_id)
except Exception:
pass
try:
litellm_model_name_model_info = litellm.get_model_info(model=model_name)
except Exception:
pass
if model_info is not None and litellm_model_name_model_info is not None:
model_info = cast(
ModelInfo,
_update_dictionary(
cast(dict, litellm_model_name_model_info).copy(),
cast(dict, model_info),
),
)
return model_info
def _set_model_group_info( # noqa: PLR0915
self, model_group: str, user_facing_model_group_name: str
) -> Optional[ModelGroupInfo]:
@@ -4860,9 +4910,16 @@
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
model_id = model.get("model_info", {}).get("id", None)
if model_id is not None:
model_info = self.get_deployment_model_info(
model_id=model_id, model_name=litellm_params.model
)
else:
model_info = None
except Exception:
model_info = None
# get llm provider
litellm_model, llm_provider = "", ""
try:
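
Usage sketch for the new `Router.get_deployment_model_info`, mirroring the router unit test added at the end of this commit: deployment-specific `model_info` overrides are merged over the base model's cost-map entry.

```python
# Sketch based on the new router unit test in this PR.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gemini/*",
            "litellm_params": {"model": "gemini/*"},
            "model_info": {"id": "1"},
        }
    ]
)

model_info = router.get_deployment_model_info(
    model_id="1", model_name="gemini/gemini-1.5-flash"
)
assert model_info is not None  # deployment info merged with the base model's entry
```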


@@ -162,7 +162,15 @@ class CredentialLiteLLMParams(BaseModel):
watsonx_region_name: Optional[str] = None
class GenericLiteLLMParams(CredentialLiteLLMParams):
class CustomPricingLiteLLMParams(BaseModel):
## CUSTOM PRICING ##
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
input_cost_per_second: Optional[float] = None
output_cost_per_second: Optional[float] = None
class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
"""
LiteLLM Params without 'model' arg (used across completion / assistants api)
"""
@@ -184,12 +192,6 @@ class GenericLiteLLMParams(CredentialLiteLLMParams):
## LOGGING PARAMS ##
litellm_trace_id: Optional[str] = None
## CUSTOM PRICING ##
input_cost_per_token: Optional[float] = None
output_cost_per_token: Optional[float] = None
input_cost_per_second: Optional[float] = None
output_cost_per_second: Optional[float] = None
max_file_size_mb: Optional[float] = None
# Deployment budgets
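
A small sketch of what this split enables (assumed from the updated custom-pricing router test, which sets the costs in `litellm_params`): the pricing fields can be supplied directly as litellm params, since `GenericLiteLLMParams` now inherits them from `CustomPricingLiteLLMParams`.

```python
# Assumed usage: pricing fields passed directly as litellm params.
from litellm.types.router import GenericLiteLLMParams

params = GenericLiteLLMParams(
    input_cost_per_token=0.0000008,
    output_cost_per_token=0.0000032,
)
assert params.input_cost_per_token == 0.0000008
```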


@@ -2245,7 +2245,8 @@ def supports_embedding_image_input(
####### HELPER FUNCTIONS ################
def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
for k, v in new_dict.items():
existing_dict[k] = v
if v is not None:
existing_dict[k] = v
return existing_dict


@@ -31,7 +31,7 @@ def get_all_functions_called_in_tests(base_dir):
specifically in files containing the word 'router'.
"""
called_functions = set()
test_dirs = ["local_testing", "router_unit_tests"]
test_dirs = ["local_testing", "router_unit_tests", "litellm"]
for test_dir in test_dirs:
dir_path = os.path.join(base_dir, test_dir)


@@ -151,3 +151,63 @@ def test_handle_realtime_stream_cost_calculation():
litellm_model_name="gpt-3.5-turbo",
)
assert cost == 0.0 # No usage, no cost
def test_custom_pricing_with_router_model_id():
from litellm import Router
router = Router(
model_list=[
{
"model_name": "prod/claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
"api_key": "test_api_key",
},
"model_info": {
"id": "my-unique-model-id",
"input_cost_per_token": 0.000006,
"output_cost_per_token": 0.00003,
"cache_creation_input_token_cost": 0.0000075,
"cache_read_input_token_cost": 0.0000006,
},
},
{
"model_name": "claude-3-5-sonnet-20240620",
"litellm_params": {
"model": "anthropic/claude-3-5-sonnet-20240620",
"api_key": "test_api_key",
},
"model_info": {
"input_cost_per_token": 100,
"output_cost_per_token": 200,
},
},
]
)
result = router.completion(
model="claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response=True,
)
result_2 = router.completion(
model="prod/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Hello, world!"}],
mock_response=True,
)
assert (
result._hidden_params["response_cost"]
> result_2._hidden_params["response_cost"]
)
model_info = router.get_deployment_model_info(
model_id="my-unique-model-id", model_name="anthropic/claude-3-5-sonnet-20240620"
)
assert model_info is not None
assert model_info["input_cost_per_token"] == 0.000006
assert model_info["output_cost_per_token"] == 0.00003
assert model_info["cache_creation_input_token_cost"] == 0.0000075
assert model_info["cache_read_input_token_cost"] == 0.0000006


@@ -2954,9 +2954,6 @@ def test_cost_calculator_with_custom_pricing():
@pytest.mark.asyncio
async def test_cost_calculator_with_custom_pricing_router(model_item, custom_pricing):
from litellm import Router
litellm._turn_on_debug()
if custom_pricing == "litellm_params":
model_item["litellm_params"]["input_cost_per_token"] = 0.0000008
model_item["litellm_params"]["output_cost_per_token"] = 0.0000032


@@ -314,12 +314,14 @@ def test_get_model_info_custom_model_router():
"input_cost_per_token": 1,
"output_cost_per_token": 1,
"model": "openai/meta-llama/Meta-Llama-3-8B-Instruct",
"model_id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
},
"model_info": {
"id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
}
}
]
)
info = get_model_info("openai/meta-llama/Meta-Llama-3-8B-Instruct")
info = get_model_info("c20d603e-1166-4e0f-aa65-ed9c476ad4ca")
print("info", info)
assert info is not None


@@ -451,3 +451,11 @@ def test_router_get_deployment_credentials():
credentials = router.get_deployment_credentials(model_id="1")
assert credentials is not None
assert credentials["api_key"] == "123"
def test_router_get_deployment_model_info():
router = Router(
model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}, "model_info": {"id": "1"}}]
)
model_info = router.get_deployment_model_info(model_id="1", model_name="gemini/gemini-1.5-flash")
assert model_info is not None