mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(cost_calculator.py): handle custom pricing at deployment level fo… (#9855)
* fix(cost_calculator.py): handle custom pricing at deployment level for router * test: add unit tests * fix(router.py): show custom pricing on UI check correct model str * fix: fix linting error * docs(custom_pricing.md): clarify custom pricing for proxy Fixes https://github.com/BerriAI/litellm/issues/8573#issuecomment-2790420740 * test: update code qa test * fix: cleanup traceback * fix: handle litellm param custom pricing * test: update test * fix(cost_calculator.py): add router model id to list of potential model names * fix(cost_calculator.py): fix router model id check * fix: router.py - maintain older model registry approach * fix: fix ruff check * fix(router.py): router get deployment info add custom values to mapped dict * test: update test * fix(utils.py): update only if value is non-null * test: add unit test
This commit is contained in:
parent
0c5b4aa96d
commit
0dbd663877
16 changed files with 193 additions and 37 deletions
|
@ -26,10 +26,12 @@ model_list:
|
|||
- model_name: sagemaker-completion-model
|
||||
litellm_params:
|
||||
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
|
||||
model_info:
|
||||
input_cost_per_second: 0.000420
|
||||
- model_name: sagemaker-embedding-model
|
||||
litellm_params:
|
||||
model: sagemaker/berri-benchmarking-gpt-j-6b-fp16
|
||||
model_info:
|
||||
input_cost_per_second: 0.000420
|
||||
```
|
||||
|
||||
|
@ -55,11 +57,33 @@ model_list:
|
|||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_version: os.envrion/AZURE_API_VERSION
|
||||
model_info:
|
||||
input_cost_per_token: 0.000421 # 👈 ONLY to track cost per token
|
||||
output_cost_per_token: 0.000520 # 👈 ONLY to track cost per token
|
||||
```
|
||||
|
||||
### Debugging
|
||||
## Override Model Cost Map
|
||||
|
||||
You can override [our model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json) with your own custom pricing for a mapped model.
|
||||
|
||||
Just add a `model_info` key to your model in the config, and override the desired keys.
|
||||
|
||||
Example: Override Anthropic's model cost map for the `prod/claude-3-5-sonnet-20241022` model.
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: "prod/claude-3-5-sonnet-20241022"
|
||||
litellm_params:
|
||||
model: "anthropic/claude-3-5-sonnet-20241022"
|
||||
api_key: os.environ/ANTHROPIC_PROD_API_KEY
|
||||
model_info:
|
||||
input_cost_per_token: 0.000006
|
||||
output_cost_per_token: 0.00003
|
||||
cache_creation_input_token_cost: 0.0000075
|
||||
cache_read_input_token_cost: 0.0000006
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
If you're custom pricing is not being used or you're seeing errors, please check the following:
|
||||
|
||||
|
|
|
@ -403,6 +403,7 @@ def _select_model_name_for_cost_calc(
|
|||
base_model: Optional[str] = None,
|
||||
custom_pricing: Optional[bool] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
1. If custom pricing is true, return received model name
|
||||
|
@ -417,12 +418,6 @@ def _select_model_name_for_cost_calc(
|
|||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
if custom_pricing is True:
|
||||
return_model = model
|
||||
|
||||
if base_model is not None:
|
||||
return_model = base_model
|
||||
|
||||
completion_response_model: Optional[str] = None
|
||||
if completion_response is not None:
|
||||
if isinstance(completion_response, BaseModel):
|
||||
|
@ -430,6 +425,16 @@ def _select_model_name_for_cost_calc(
|
|||
elif isinstance(completion_response, dict):
|
||||
completion_response_model = completion_response.get("model", None)
|
||||
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
|
||||
|
||||
if custom_pricing is True:
|
||||
if router_model_id is not None and router_model_id in litellm.model_cost:
|
||||
return_model = router_model_id
|
||||
else:
|
||||
return_model = model
|
||||
|
||||
if base_model is not None:
|
||||
return_model = base_model
|
||||
|
||||
if completion_response_model is None and hidden_params is not None:
|
||||
if (
|
||||
hidden_params.get("model", None) is not None
|
||||
|
@ -559,6 +564,7 @@ def completion_cost( # noqa: PLR0915
|
|||
base_model: Optional[str] = None,
|
||||
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
||||
litellm_model_name: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
|
||||
|
@ -617,12 +623,12 @@ def completion_cost( # noqa: PLR0915
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
custom_pricing=custom_pricing,
|
||||
base_model=base_model,
|
||||
router_model_id=router_model_id,
|
||||
)
|
||||
|
||||
potential_model_names = [selected_model]
|
||||
if model is not None:
|
||||
potential_model_names.append(model)
|
||||
|
||||
for idx, model in enumerate(potential_model_names):
|
||||
try:
|
||||
verbose_logger.info(
|
||||
|
@ -943,6 +949,7 @@ def response_cost_calculator(
|
|||
prompt: str = "",
|
||||
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
||||
litellm_model_name: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> float:
|
||||
"""
|
||||
Returns
|
||||
|
@ -973,6 +980,8 @@ def response_cost_calculator(
|
|||
base_model=base_model,
|
||||
prompt=prompt,
|
||||
standard_built_in_tools_params=standard_built_in_tools_params,
|
||||
litellm_model_name=litellm_model_name,
|
||||
router_model_id=router_model_id,
|
||||
)
|
||||
return response_cost
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.types.utils import LlmProviders, LlmProvidersSet
|
||||
|
||||
|
@ -43,9 +42,6 @@ def get_supported_openai_params( # noqa: PLR0915
|
|||
provider_config = None
|
||||
|
||||
if provider_config and request_type == "chat_completion":
|
||||
verbose_logger.info(
|
||||
f"using provider_config: {provider_config} for checking supported params"
|
||||
)
|
||||
return provider_config.get_supported_openai_params(model=model)
|
||||
|
||||
if custom_llm_provider == "bedrock":
|
||||
|
|
|
@ -622,7 +622,6 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
] = RawRequestTypedDict(
|
||||
error=str(e),
|
||||
)
|
||||
traceback.print_exc()
|
||||
_metadata[
|
||||
"raw_request"
|
||||
] = "Unable to Log \
|
||||
|
@ -906,6 +905,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
],
|
||||
cache_hit: Optional[bool] = None,
|
||||
litellm_model_name: Optional[str] = None,
|
||||
router_model_id: Optional[str] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Calculate response cost using result + logging object variables.
|
||||
|
@ -944,6 +944,7 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
"custom_pricing": custom_pricing,
|
||||
"prompt": prompt,
|
||||
"standard_built_in_tools_params": self.standard_built_in_tools_params,
|
||||
"router_model_id": router_model_id,
|
||||
}
|
||||
except Exception as e: # error creating kwargs for cost calculation
|
||||
debug_info = StandardLoggingModelCostFailureDebugInformation(
|
||||
|
|
|
@ -36,11 +36,16 @@ class ResponseMetadata:
|
|||
self, logging_obj: LiteLLMLoggingObject, model: Optional[str], kwargs: dict
|
||||
) -> None:
|
||||
"""Set hidden parameters on the response"""
|
||||
|
||||
## ADD OTHER HIDDEN PARAMS
|
||||
model_id = kwargs.get("model_info", {}).get("id", None)
|
||||
new_params = {
|
||||
"litellm_call_id": getattr(logging_obj, "litellm_call_id", None),
|
||||
"model_id": kwargs.get("model_info", {}).get("id", None),
|
||||
"api_base": get_api_base(model=model or "", optional_params=kwargs),
|
||||
"response_cost": logging_obj._response_cost_calculator(result=self.result),
|
||||
"model_id": model_id,
|
||||
"response_cost": logging_obj._response_cost_calculator(
|
||||
result=self.result, litellm_model_name=model, router_model_id=model_id
|
||||
),
|
||||
"additional_headers": process_response_headers(
|
||||
self._get_value_from_hidden_params("additional_headers") or {}
|
||||
),
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import copy
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from enum import Enum
|
||||
|
@ -748,7 +747,6 @@ def convert_to_anthropic_image_obj(
|
|||
data=base64_data,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if "Error: Unable to fetch image from URL" in str(e):
|
||||
raise e
|
||||
raise Exception(
|
||||
|
|
|
@ -100,7 +100,6 @@ async def cache_ping():
|
|||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
error_message = {
|
||||
"message": f"Service Unhealthy ({str(e)})",
|
||||
"litellm_cache_params": safe_dumps(litellm_cache_params),
|
||||
|
|
|
@ -816,9 +816,6 @@ async def add_member_to_organization(
|
|||
return user_object, organization_membership
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise ValueError(
|
||||
f"Error adding member={member} to organization={organization_id}: {e}"
|
||||
)
|
||||
|
|
|
@ -116,6 +116,7 @@ from litellm.types.router import (
|
|||
AllowedFailsPolicy,
|
||||
AssistantsTypedDict,
|
||||
CredentialLiteLLMParams,
|
||||
CustomPricingLiteLLMParams,
|
||||
CustomRoutingStrategyBase,
|
||||
Deployment,
|
||||
DeploymentTypedDict,
|
||||
|
@ -132,6 +133,7 @@ from litellm.types.router import (
|
|||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
from litellm.types.utils import GenericBudgetConfigType
|
||||
from litellm.types.utils import ModelInfo
|
||||
from litellm.types.utils import ModelInfo as ModelMapInfo
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
from litellm.utils import (
|
||||
|
@ -3324,7 +3326,6 @@ class Router:
|
|||
|
||||
return response
|
||||
except Exception as new_exception:
|
||||
traceback.print_exc()
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
verbose_router_logger.error(
|
||||
"litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format(
|
||||
|
@ -4301,7 +4302,20 @@ class Router:
|
|||
model_info=_model_info,
|
||||
)
|
||||
|
||||
for field in CustomPricingLiteLLMParams.model_fields.keys():
|
||||
if deployment.litellm_params.get(field) is not None:
|
||||
_model_info[field] = deployment.litellm_params[field]
|
||||
|
||||
## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
|
||||
model_id = deployment.model_info.id
|
||||
if model_id is not None:
|
||||
litellm.register_model(
|
||||
model_cost={
|
||||
model_id: _model_info,
|
||||
}
|
||||
)
|
||||
|
||||
## OLD MODEL REGISTRATION ## Kept to prevent breaking changes
|
||||
_model_name = deployment.litellm_params.model
|
||||
if deployment.litellm_params.custom_llm_provider is not None:
|
||||
_model_name = (
|
||||
|
@ -4802,6 +4816,42 @@ class Router:
|
|||
model_name = model_info["model_name"]
|
||||
return self.get_model_list(model_name=model_name)
|
||||
|
||||
def get_deployment_model_info(
|
||||
self, model_id: str, model_name: str
|
||||
) -> Optional[ModelInfo]:
|
||||
"""
|
||||
For a given model id, return the model info
|
||||
|
||||
1. Check if model_id is in model info
|
||||
2. If not, check if litellm model name is in model info
|
||||
3. If not, return None
|
||||
"""
|
||||
from litellm.utils import _update_dictionary
|
||||
|
||||
model_info: Optional[ModelInfo] = None
|
||||
litellm_model_name_model_info: Optional[ModelInfo] = None
|
||||
|
||||
try:
|
||||
model_info = litellm.get_model_info(model=model_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
litellm_model_name_model_info = litellm.get_model_info(model=model_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if model_info is not None and litellm_model_name_model_info is not None:
|
||||
model_info = cast(
|
||||
ModelInfo,
|
||||
_update_dictionary(
|
||||
cast(dict, litellm_model_name_model_info).copy(),
|
||||
cast(dict, model_info),
|
||||
),
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def _set_model_group_info( # noqa: PLR0915
|
||||
self, model_group: str, user_facing_model_group_name: str
|
||||
) -> Optional[ModelGroupInfo]:
|
||||
|
@ -4860,9 +4910,16 @@ class Router:
|
|||
|
||||
# get model info
|
||||
try:
|
||||
model_info = litellm.get_model_info(model=litellm_params.model)
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
if model_id is not None:
|
||||
model_info = self.get_deployment_model_info(
|
||||
model_id=model_id, model_name=litellm_params.model
|
||||
)
|
||||
else:
|
||||
model_info = None
|
||||
except Exception:
|
||||
model_info = None
|
||||
|
||||
# get llm provider
|
||||
litellm_model, llm_provider = "", ""
|
||||
try:
|
||||
|
|
|
@ -162,7 +162,15 @@ class CredentialLiteLLMParams(BaseModel):
|
|||
watsonx_region_name: Optional[str] = None
|
||||
|
||||
|
||||
class GenericLiteLLMParams(CredentialLiteLLMParams):
|
||||
class CustomPricingLiteLLMParams(BaseModel):
|
||||
## CUSTOM PRICING ##
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
input_cost_per_second: Optional[float] = None
|
||||
output_cost_per_second: Optional[float] = None
|
||||
|
||||
|
||||
class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
|
||||
"""
|
||||
LiteLLM Params without 'model' arg (used across completion / assistants api)
|
||||
"""
|
||||
|
@ -184,12 +192,6 @@ class GenericLiteLLMParams(CredentialLiteLLMParams):
|
|||
## LOGGING PARAMS ##
|
||||
litellm_trace_id: Optional[str] = None
|
||||
|
||||
## CUSTOM PRICING ##
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
input_cost_per_second: Optional[float] = None
|
||||
output_cost_per_second: Optional[float] = None
|
||||
|
||||
max_file_size_mb: Optional[float] = None
|
||||
|
||||
# Deployment budgets
|
||||
|
|
|
@ -2245,7 +2245,8 @@ def supports_embedding_image_input(
|
|||
####### HELPER FUNCTIONS ################
|
||||
def _update_dictionary(existing_dict: Dict, new_dict: dict) -> dict:
|
||||
for k, v in new_dict.items():
|
||||
existing_dict[k] = v
|
||||
if v is not None:
|
||||
existing_dict[k] = v
|
||||
|
||||
return existing_dict
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ def get_all_functions_called_in_tests(base_dir):
|
|||
specifically in files containing the word 'router'.
|
||||
"""
|
||||
called_functions = set()
|
||||
test_dirs = ["local_testing", "router_unit_tests"]
|
||||
test_dirs = ["local_testing", "router_unit_tests", "litellm"]
|
||||
|
||||
for test_dir in test_dirs:
|
||||
dir_path = os.path.join(base_dir, test_dir)
|
||||
|
|
|
@ -151,3 +151,63 @@ def test_handle_realtime_stream_cost_calculation():
|
|||
litellm_model_name="gpt-3.5-turbo",
|
||||
)
|
||||
assert cost == 0.0 # No usage, no cost
|
||||
|
||||
|
||||
def test_custom_pricing_with_router_model_id():
|
||||
from litellm import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{
|
||||
"model_name": "prod/claude-3-5-sonnet-20240620",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/claude-3-5-sonnet-20240620",
|
||||
"api_key": "test_api_key",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "my-unique-model-id",
|
||||
"input_cost_per_token": 0.000006,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"cache_creation_input_token_cost": 0.0000075,
|
||||
"cache_read_input_token_cost": 0.0000006,
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-5-sonnet-20240620",
|
||||
"litellm_params": {
|
||||
"model": "anthropic/claude-3-5-sonnet-20240620",
|
||||
"api_key": "test_api_key",
|
||||
},
|
||||
"model_info": {
|
||||
"input_cost_per_token": 100,
|
||||
"output_cost_per_token": 200,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
result = router.completion(
|
||||
model="claude-3-5-sonnet-20240620",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
mock_response=True,
|
||||
)
|
||||
|
||||
result_2 = router.completion(
|
||||
model="prod/claude-3-5-sonnet-20240620",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
mock_response=True,
|
||||
)
|
||||
|
||||
assert (
|
||||
result._hidden_params["response_cost"]
|
||||
> result_2._hidden_params["response_cost"]
|
||||
)
|
||||
|
||||
model_info = router.get_deployment_model_info(
|
||||
model_id="my-unique-model-id", model_name="anthropic/claude-3-5-sonnet-20240620"
|
||||
)
|
||||
assert model_info is not None
|
||||
assert model_info["input_cost_per_token"] == 0.000006
|
||||
assert model_info["output_cost_per_token"] == 0.00003
|
||||
assert model_info["cache_creation_input_token_cost"] == 0.0000075
|
||||
assert model_info["cache_read_input_token_cost"] == 0.0000006
|
||||
|
|
|
@ -2954,9 +2954,6 @@ def test_cost_calculator_with_custom_pricing():
|
|||
@pytest.mark.asyncio
|
||||
async def test_cost_calculator_with_custom_pricing_router(model_item, custom_pricing):
|
||||
from litellm import Router
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
if custom_pricing == "litellm_params":
|
||||
model_item["litellm_params"]["input_cost_per_token"] = 0.0000008
|
||||
model_item["litellm_params"]["output_cost_per_token"] = 0.0000032
|
||||
|
|
|
@ -314,12 +314,14 @@ def test_get_model_info_custom_model_router():
|
|||
"input_cost_per_token": 1,
|
||||
"output_cost_per_token": 1,
|
||||
"model": "openai/meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"model_id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
|
||||
},
|
||||
"model_info": {
|
||||
"id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca",
|
||||
}
|
||||
}
|
||||
]
|
||||
)
|
||||
info = get_model_info("openai/meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
info = get_model_info("c20d603e-1166-4e0f-aa65-ed9c476ad4ca")
|
||||
print("info", info)
|
||||
assert info is not None
|
||||
|
||||
|
|
|
@ -451,3 +451,11 @@ def test_router_get_deployment_credentials():
|
|||
credentials = router.get_deployment_credentials(model_id="1")
|
||||
assert credentials is not None
|
||||
assert credentials["api_key"] == "123"
|
||||
|
||||
|
||||
def test_router_get_deployment_model_info():
|
||||
router = Router(
|
||||
model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}, "model_info": {"id": "1"}}]
|
||||
)
|
||||
model_info = router.get_deployment_model_info(model_id="1", model_name="gemini/gemini-1.5-flash")
|
||||
assert model_info is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue