feat(cost_calculator.py): only override base model if custom pricing is set

Krrish Dholakia 2024-08-19 16:05:49 -07:00
parent a494b5b2f3
commit 55217fa8d7
4 changed files with 98 additions and 25 deletions
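Based on the commit title and the new test below, the intended behavior appears to be: when a router deployment maps to a base_model and no custom per-token pricing is set on the deployment itself, the base model is what gets used for the cost lookup, while custom pricing keeps the deployment's own model name. A minimal sketch of that rule, using hypothetical names (pick_model_for_pricing, has_custom_pricing) rather than litellm's actual internals:

from typing import Optional


def pick_model_for_pricing(
    deployment_model: str,
    base_model: Optional[str],
    has_custom_pricing: bool,
) -> str:
    # Hypothetical sketch, not litellm's implementation.
    if has_custom_pricing or base_model is None:
        # Custom input/output costs are attached to the deployment itself,
        # so keep the deployment's model name for the cost calculation.
        return deployment_model
    # Otherwise price against the mapped base model (e.g. "azure/gpt-4").
    return base_model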


@@ -11,6 +11,7 @@ import asyncio
import os
import time
from typing import Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -1031,3 +1032,68 @@ def test_completion_cost_deepseek():
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


def test_completion_cost_azure_common_deployment_name():
    from litellm.utils import (
        CallTypes,
        Choices,
        Delta,
        Message,
        ModelResponse,
        StreamingChoices,
        Usage,
    )

    # Deployment name ("azure/gpt-4-0314") differs from the base_model
    # ("azure/gpt-4") configured in model_info.
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "azure/gpt-4-0314",
                    "max_tokens": 4096,
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                },
                "model_info": {"base_model": "azure/gpt-4"},
            }
        ]
    )

    response = ModelResponse(
        id="chatcmpl-876cce24-e520-4cf8-8649-562a9be11c02",
        choices=[
            Choices(
                finish_reason="stop",
                index=0,
                message=Message(
                    content="Hi! I'm an AI, so I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help with any questions or topics you'd like to discuss! How can I assist you today?",
                    role="assistant",
                ),
            )
        ],
        created=1717519830,
        model="gpt-4",
        object="chat.completion",
        system_fingerprint="fp_c1a4bcec29",
        usage=Usage(completion_tokens=46, prompt_tokens=17, total_tokens=63),
    )
    response._hidden_params["custom_llm_provider"] = "azure"
    print(response)

    # Patch completion_cost so the test can inspect which model name is
    # actually used for the cost lookup.
    with patch.object(
        litellm.cost_calculator, "completion_cost", new=MagicMock()
    ) as mock_client:
        _ = litellm.response_cost_calculator(
            response_object=response,
            model="gpt-4-0314",
            custom_llm_provider="azure",
            call_type=CallTypes.acompletion.value,
            optional_params={},
            base_model="azure/gpt-4",
        )

        mock_client.assert_called()
        print(f"mock_client.call_args: {mock_client.call_args.kwargs}")
        # The configured base_model should be the model passed on for pricing.
        assert "azure/gpt-4" == mock_client.call_args.kwargs["model"]