Mirror of https://github.com/BerriAI/litellm.git
commit 55217fa8d7 (parent a494b5b2f3)

feat(cost_calculator.py): only override base model if custom pricing is set

4 changed files with 98 additions and 25 deletions
@@ -757,9 +757,7 @@ def response_cost_calculator(
                 custom_llm_provider=custom_llm_provider,
             )
         else:
-            if (
-                model in litellm.model_cost or custom_pricing is True
-            ):  # override defaults if custom pricing is set
+            if custom_pricing is True:  # override defaults if custom pricing is set
                 base_model = model
             # base_model defaults to None if not set on model_info
 
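This predicate change is the core of the commit: before, any model present in `litellm.model_cost` overrode the caller-supplied `base_model`; now only explicit custom pricing does, so a base model attached to a deployment (e.g. via the router's `model_info`, as in the test further down) survives the cost lookup. A minimal sketch of the resulting selection logic; the function name and signature here are illustrative, not litellm's API:

    from typing import Optional

    def pick_base_model(
        model: str, base_model: Optional[str], custom_pricing: bool
    ) -> Optional[str]:
        # New behavior: only explicit custom pricing forces the deployment
        # name to be used for pricing lookups.
        if custom_pricing is True:
            return model
        # Otherwise keep whatever base_model the caller supplied; it may be
        # None, in which case downstream pricing falls back to `model` itself.
        return base_model

    # Azure-style deployment: the base model now wins unless pricing is custom.
    assert pick_base_model("azure/gpt-4-0314", "azure/gpt-4", False) == "azure/gpt-4"
    assert pick_base_model("azure/gpt-4-0314", "azure/gpt-4", True) == "azure/gpt-4-0314"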
@@ -689,23 +689,7 @@ class Logging:
                         complete_streaming_response
                     )
                 self.model_call_details["response_cost"] = (
-                    litellm.response_cost_calculator(
-                        response_object=complete_streaming_response,
-                        model=self.model,
-                        cache_hit=self.model_call_details.get("cache_hit", False),
-                        custom_llm_provider=self.model_call_details.get(
-                            "custom_llm_provider", None
-                        ),
-                        base_model=_get_base_model_from_metadata(
-                            model_call_details=self.model_call_details
-                        ),
-                        call_type=self.call_type,
-                        optional_params=(
-                            self.optional_params
-                            if hasattr(self, "optional_params")
-                            else {}
-                        ),
-                    )
+                    self._response_cost_calculator(result=complete_streaming_response)
                 )
                 if self.dynamic_success_callbacks is not None and isinstance(
                     self.dynamic_success_callbacks, list
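Seventeen lines of inline argument-building collapse into one call to `self._response_cost_calculator`, here and in the hunk below. A hedged sketch of what such a wrapper plausibly does, reconstructed only from the arguments the removed call passed; both function bodies are assumptions, with the metadata key layout modeled on the router config in the test further down (which sets `"model_info": {"base_model": ...}`):

    from typing import Optional

    import litellm

    # Assumed stand-in for litellm's internal _get_base_model_from_metadata;
    # the real helper may read different keys.
    def base_model_from_metadata(model_call_details: dict) -> Optional[str]:
        litellm_params = model_call_details.get("litellm_params") or {}
        metadata = litellm_params.get("metadata") or {}
        return (metadata.get("model_info") or {}).get("base_model")

    # Assumed wrapper shape, inferred from the removed call site above; the
    # actual Logging._response_cost_calculator may differ. `logging_obj`
    # stands in for `self`.
    def response_cost_for(logging_obj, result) -> Optional[float]:
        return litellm.response_cost_calculator(
            response_object=result,
            model=logging_obj.model,
            cache_hit=logging_obj.model_call_details.get("cache_hit", False),
            custom_llm_provider=logging_obj.model_call_details.get(
                "custom_llm_provider", None
            ),
            base_model=base_model_from_metadata(
                model_call_details=logging_obj.model_call_details
            ),
            call_type=logging_obj.call_type,
            optional_params=getattr(logging_obj, "optional_params", {}),
        )

Whichever way the real body reads, routing both streaming cost paths through one helper keeps the cost arguments, including the metadata-derived base model, in a single place.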
@@ -1308,9 +1292,10 @@ class Logging:
                     model_call_details=self.model_call_details
                 )
                 # base_model defaults to None if not set on model_info
-                self.model_call_details["response_cost"] = litellm.completion_cost(
-                    completion_response=complete_streaming_response,
-                    model=base_model,
+                self.model_call_details["response_cost"] = (
+                    self._response_cost_calculator(
+                        result=complete_streaming_response
+                    )
                 )
                 verbose_logger.debug(
                     f"Model={self.model}; cost={self.model_call_details['response_cost']}"
@@ -11,6 +11,7 @@ import asyncio
 import os
 import time
 from typing import Optional
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -1031,3 +1032,68 @@ def test_completion_cost_deepseek():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_cost_azure_common_deployment_name():
+    from litellm.utils import (
+        CallTypes,
+        Choices,
+        Delta,
+        Message,
+        ModelResponse,
+        StreamingChoices,
+        Usage,
+    )
+
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "azure/gpt-4-0314",
+                    "max_tokens": 4096,
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                },
+                "model_info": {"base_model": "azure/gpt-4"},
+            }
+        ]
+    )
+
+    response = ModelResponse(
+        id="chatcmpl-876cce24-e520-4cf8-8649-562a9be11c02",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hi! I'm an AI, so I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help with any questions or topics you'd like to discuss! How can I assist you today?",
+                    role="assistant",
+                ),
+            )
+        ],
+        created=1717519830,
+        model="gpt-4",
+        object="chat.completion",
+        system_fingerprint="fp_c1a4bcec29",
+        usage=Usage(completion_tokens=46, prompt_tokens=17, total_tokens=63),
+    )
+    response._hidden_params["custom_llm_provider"] = "azure"
+    print(response)
+
+    with patch.object(
+        litellm.cost_calculator, "completion_cost", new=MagicMock()
+    ) as mock_client:
+        _ = litellm.response_cost_calculator(
+            response_object=response,
+            model="gpt-4-0314",
+            custom_llm_provider="azure",
+            call_type=CallTypes.acompletion.value,
+            optional_params={},
+            base_model="azure/gpt-4",
+        )
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args: {mock_client.call_args.kwargs}")
+        assert "azure/gpt-4" == mock_client.call_args.kwargs["model"]
@@ -5203,19 +5203,43 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
         if custom_llm_provider == "predibase":
             _model_info["supports_response_schema"] = True
 
+        _input_cost_per_token: Optional[float] = _model_info.get(
+            "input_cost_per_token"
+        )
+        if _input_cost_per_token is None:
+            # default value to 0, be noisy about this
+            verbose_logger.debug(
+                "model={}, custom_llm_provider={} has no input_cost_per_token in model_cost_map. Defaulting to 0.".format(
+                    model, custom_llm_provider
+                )
+            )
+            _input_cost_per_token = 0
+
+        _output_cost_per_token: Optional[float] = _model_info.get(
+            "output_cost_per_token"
+        )
+        if _output_cost_per_token is None:
+            # default value to 0, be noisy about this
+            verbose_logger.debug(
+                "model={}, custom_llm_provider={} has no output_cost_per_token in model_cost_map. Defaulting to 0.".format(
+                    model, custom_llm_provider
+                )
+            )
+            _output_cost_per_token = 0
+
         return ModelInfo(
             key=key,
             max_tokens=_model_info.get("max_tokens", None),
             max_input_tokens=_model_info.get("max_input_tokens", None),
             max_output_tokens=_model_info.get("max_output_tokens", None),
-            input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+            input_cost_per_token=_input_cost_per_token,
             input_cost_per_character=_model_info.get(
                 "input_cost_per_character", None
             ),
             input_cost_per_token_above_128k_tokens=_model_info.get(
                 "input_cost_per_token_above_128k_tokens", None
             ),
-            output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+            output_cost_per_token=_output_cost_per_token,
             output_cost_per_character=_model_info.get(
                 "output_cost_per_character", None
             ),
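Instead of silently substituting 0 via `_model_info.get(key, 0)`, the rewritten lookup checks for `None` explicitly and emits a debug log, so a model missing from the cost map is still priced at 0 but leaves a trace. The same pattern in isolation; `logging.getLogger` here is a stand-in for litellm's `verbose_logger`:

    import logging
    from typing import Optional

    logger = logging.getLogger("cost_map")  # stand-in for litellm's verbose_logger

    def cost_or_zero(model_info: dict, key: str, model: str, provider: Optional[str]) -> float:
        cost: Optional[float] = model_info.get(key)
        if cost is None:
            # Default to 0, but be noisy so missing map entries stay visible.
            logger.debug(
                "model=%s, custom_llm_provider=%s has no %s in model_cost_map. Defaulting to 0.",
                model,
                provider,
                key,
            )
            cost = 0
        return cost

    assert cost_or_zero({}, "input_cost_per_token", "my-model", "azure") == 0
    assert cost_or_zero({"input_cost_per_token": 3e-05}, "input_cost_per_token", "gpt-4", None) == 3e-05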