feat(cost_calculator.py): only override base model if custom pricing is set

Author: Krrish Dholakia
Date:   2024-08-19 16:05:49 -07:00
Commit: 55217fa8d7
Parent: a494b5b2f3
4 changed files with 98 additions and 25 deletions


@@ -757,9 +757,7 @@ def response_cost_calculator(
                 custom_llm_provider=custom_llm_provider,
             )
         else:
-            if (
-                model in litellm.model_cost or custom_pricing is True
-            ):  # override defaults if custom pricing is set
+            if custom_pricing is True:  # override defaults if custom pricing is set
                 base_model = model
             # base_model defaults to None if not set on model_info
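
The fix is deliberately narrow. Previously, base_model was overridden whenever the deployment's model name happened to appear in litellm.model_cost, even with no user-supplied pricing; now the override fires only when custom pricing is actually set. A minimal sketch of the new selection logic, using illustrative standalone names rather than the function's real surroundings:

    # Sketch of the new base-model selection; names are illustrative,
    # not the library's exact code.
    from typing import Optional

    def pick_base_model(
        model: str,
        model_info: dict,
        custom_pricing: bool,
    ) -> Optional[str]:
        # base_model defaults to None if not set on model_info
        base_model: Optional[str] = model_info.get("base_model")
        if base_model is None and custom_pricing is True:
            # Only user-supplied pricing justifies discarding the mapped base model.
            base_model = model
        return base_model

Under the old check, an Azure deployment named after an OpenAI model (for example "gpt-4-0314") would be costed against the OpenAI pricing entry instead of its configured base_model; the new test below pins that case down.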


@@ -689,23 +689,7 @@ class Logging:
                         complete_streaming_response
                     )
                     self.model_call_details["response_cost"] = (
-                        litellm.response_cost_calculator(
-                            response_object=complete_streaming_response,
-                            model=self.model,
-                            cache_hit=self.model_call_details.get("cache_hit", False),
-                            custom_llm_provider=self.model_call_details.get(
-                                "custom_llm_provider", None
-                            ),
-                            base_model=_get_base_model_from_metadata(
-                                model_call_details=self.model_call_details
-                            ),
-                            call_type=self.call_type,
-                            optional_params=(
-                                self.optional_params
-                                if hasattr(self, "optional_params")
-                                else {}
-                            ),
-                        )
+                        self._response_cost_calculator(result=complete_streaming_response)
                     )
                 if self.dynamic_success_callbacks is not None and isinstance(
                     self.dynamic_success_callbacks, list
@@ -1308,9 +1292,10 @@ class Logging:
                         model_call_details=self.model_call_details
                     )
                     # base_model defaults to None if not set on model_info
-                    self.model_call_details["response_cost"] = litellm.completion_cost(
-                        completion_response=complete_streaming_response,
-                        model=base_model,
+                    self.model_call_details["response_cost"] = (
+                        self._response_cost_calculator(
+                            result=complete_streaming_response
+                        )
                     )
                     verbose_logger.debug(
                         f"Model={self.model}; cost={self.model_call_details['response_cost']}"


@@ -11,6 +11,7 @@ import asyncio
 import os
 import time
 from typing import Optional
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
@@ -1031,3 +1032,68 @@ def test_completion_cost_deepseek():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_completion_cost_azure_common_deployment_name():
+    from litellm.utils import (
+        CallTypes,
+        Choices,
+        Delta,
+        Message,
+        ModelResponse,
+        StreamingChoices,
+        Usage,
+    )
+
+    router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {
+                    "model": "azure/gpt-4-0314",
+                    "max_tokens": 4096,
+                    "api_key": os.getenv("AZURE_API_KEY"),
+                    "api_base": os.getenv("AZURE_API_BASE"),
+                },
+                "model_info": {"base_model": "azure/gpt-4"},
+            }
+        ]
+    )
+
+    response = ModelResponse(
+        id="chatcmpl-876cce24-e520-4cf8-8649-562a9be11c02",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="Hi! I'm an AI, so I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help with any questions or topics you'd like to discuss! How can I assist you today?",
+                    role="assistant",
+                ),
+            )
+        ],
+        created=1717519830,
+        model="gpt-4",
+        object="chat.completion",
+        system_fingerprint="fp_c1a4bcec29",
+        usage=Usage(completion_tokens=46, prompt_tokens=17, total_tokens=63),
+    )
+    response._hidden_params["custom_llm_provider"] = "azure"
+    print(response)
+
+    with patch.object(
+        litellm.cost_calculator, "completion_cost", new=MagicMock()
+    ) as mock_client:
+        _ = litellm.response_cost_calculator(
+            response_object=response,
+            model="gpt-4-0314",
+            custom_llm_provider="azure",
+            call_type=CallTypes.acompletion.value,
+            optional_params={},
+            base_model="azure/gpt-4",
+        )
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args: {mock_client.call_args.kwargs}")
+        assert "azure/gpt-4" == mock_client.call_args.kwargs["model"]


@@ -5203,19 +5203,43 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
         if custom_llm_provider == "predibase":
             _model_info["supports_response_schema"] = True
 
+        _input_cost_per_token: Optional[float] = _model_info.get(
+            "input_cost_per_token"
+        )
+        if _input_cost_per_token is None:
+            # default value to 0, be noisy about this
+            verbose_logger.debug(
+                "model={}, custom_llm_provider={} has no input_cost_per_token in model_cost_map. Defaulting to 0.".format(
+                    model, custom_llm_provider
+                )
+            )
+            _input_cost_per_token = 0
+
+        _output_cost_per_token: Optional[float] = _model_info.get(
+            "output_cost_per_token"
+        )
+        if _output_cost_per_token is None:
+            # default value to 0, be noisy about this
+            verbose_logger.debug(
+                "model={}, custom_llm_provider={} has no output_cost_per_token in model_cost_map. Defaulting to 0.".format(
+                    model, custom_llm_provider
+                )
+            )
+            _output_cost_per_token = 0
+
         return ModelInfo(
             key=key,
             max_tokens=_model_info.get("max_tokens", None),
             max_input_tokens=_model_info.get("max_input_tokens", None),
             max_output_tokens=_model_info.get("max_output_tokens", None),
-            input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+            input_cost_per_token=_input_cost_per_token,
             input_cost_per_character=_model_info.get(
                 "input_cost_per_character", None
             ),
             input_cost_per_token_above_128k_tokens=_model_info.get(
                 "input_cost_per_token_above_128k_tokens", None
             ),
-            output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+            output_cost_per_token=_output_cost_per_token,
             output_cost_per_character=_model_info.get(
                 "output_cost_per_character", None
             ),
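
The observable defaults of get_model_info are unchanged (a missing price still becomes 0), but the miss is now logged at debug level so silently-free models can be traced back to a gap in the cost map. The same get-with-noisy-default pattern in isolation, with illustrative names:

    import logging
    from typing import Optional

    logger = logging.getLogger(__name__)

    def cost_or_zero(model_info: dict, key: str, model: str, provider: Optional[str]) -> float:
        """Fetch a per-token cost, defaulting to 0 but logging the gap."""
        cost: Optional[float] = model_info.get(key)
        if cost is None:
            # A missing price usually means a model_cost_map gap, not a
            # genuinely free model, so say so at debug level.
            logger.debug(
                "model=%s, custom_llm_provider=%s has no %s in model_cost_map. "
                "Defaulting to 0.",
                model, provider, key,
            )
            cost = 0.0
        return float(cost)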