(test) using base_model for cost_calc on router

ishaan-jaff 2024-02-07 16:30:58 -08:00
parent 920d684da4
commit 705396240e
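
Context for the change: when an Azure deployment on the router declares model_info.base_model, litellm should price the response at the base model's rates, because a custom Azure deployment name (here, chatgpt-v-2) has no entry in the pricing map. A minimal sketch of the expected pricing, using the same litellm.cost_per_token call the test relies on; the token counts are illustrative placeholders:

    import litellm

    # Price 100 prompt tokens and 50 completion tokens at the base model's
    # rates. cost_per_token returns (prompt_cost, completion_cost) in USD.
    prompt_price, completion_price = litellm.cost_per_token(
        model="azure/gpt-4-1106-preview",
        prompt_tokens=100,
        completion_tokens=50,
    )
    expected_price = prompt_price + completion_price
    print(f"expected price: {expected_price}")

The test below asserts that the response_cost delivered to the success callback equals this sum whenever base_model is set.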


@@ -256,6 +256,7 @@ class CompletionCustomHandler(
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_success")
            print("in async success, kwargs: ", kwargs)
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
@@ -266,6 +267,37 @@ class CompletionCustomHandler(
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            # checking we use base_model for azure cost calculation
            base_model = (
                kwargs.get("litellm_params", {})
                .get("metadata", {})
                .get("model_info", {})
                .get("base_model", None)
            )
            if kwargs["model"] == "chatgpt-v-2" and base_model is not None:
                # when base_model is set for azure, we should use pricing for the base_model
                # this checks response_cost == litellm.cost_per_token(model=base_model)
                assert isinstance(kwargs["response_cost"], float)
                response_cost = kwargs["response_cost"]
                print(
                    f"response_cost: {response_cost}, for model: {kwargs['model']} and base_model: {base_model}"
                )
                prompt_tokens = response_obj.usage.prompt_tokens
                completion_tokens = response_obj.usage.completion_tokens
                # ensure the pricing is based on the base_model here
                prompt_price, completion_price = litellm.cost_per_token(
                    model=base_model,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )
                expected_price = prompt_price + completion_price
                print(f"expected price: {expected_price}")
                assert (
                    response_cost == expected_price
                ), f"response_cost: {response_cost} != expected_price: {expected_price}. For model: {kwargs['model']} and base_model: {base_model}. should have used base_model for price"
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
@@ -345,6 +377,7 @@ async def test_async_chat_azure():
    customHandler_streaming_azure_router = CompletionCustomHandler()
    customHandler_failure = CompletionCustomHandler()
    litellm.callbacks = [customHandler_completion_azure_router]
    litellm.set_verbose = True
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # openai model name
@@ -354,6 +387,7 @@ async def test_async_chat_azure():
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE"),
            },
            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
            "tpm": 240000,
            "rpm": 1800,
        },
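
End-to-end, the assertion fires inside the router's async success callback. A rough sketch of that flow, assuming the model_list above and a CustomLogger subclass like the one in this test (the handler and function names here are illustrative):

    import asyncio
    import litellm
    from litellm import Router
    from litellm.integrations.custom_logger import CustomLogger

    class CostCheckHandler(CustomLogger):
        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            # base_model reaches the callback via litellm_params.metadata.model_info,
            # the same path the test reads above
            base_model = (
                kwargs.get("litellm_params", {})
                .get("metadata", {})
                .get("model_info", {})
                .get("base_model", None)
            )
            print(kwargs["model"], base_model, kwargs["response_cost"])

    async def main():
        litellm.callbacks = [CostCheckHandler()]
        router = Router(model_list=model_list)  # model_list as configured above
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi"}],
        )

    asyncio.run(main())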