diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py
index ac8b2fa101..b33cfd8cf0 100644
--- a/litellm/tests/test_custom_callback_router.py
+++ b/litellm/tests/test_custom_callback_router.py
@@ -256,6 +256,7 @@ class CompletionCustomHandler(
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             self.states.append("async_success")
+            print("in async success, kwargs: ", kwargs)
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
@@ -266,6 +267,37 @@ class CompletionCustomHandler(
             )
             ## KWARGS
             assert isinstance(kwargs["model"], str)
+
+            # checking we use base_model for azure cost calculation
+            base_model = (
+                kwargs.get("litellm_params", {})
+                .get("metadata", {})
+                .get("model_info", {})
+                .get("base_model", None)
+            )
+
+            if kwargs["model"] == "chatgpt-v-2" and base_model is not None:
+                # when base_model is set for azure, we should use pricing for the base_model
+                # this checks response_cost == litellm.cost_per_token(model=base_model)
+                assert isinstance(kwargs["response_cost"], float)
+                response_cost = kwargs["response_cost"]
+                print(
+                    f"response_cost: {response_cost}, for model: {kwargs['model']} and base_model: {base_model}"
+                )
+                prompt_tokens = response_obj.usage.prompt_tokens
+                completion_tokens = response_obj.usage.completion_tokens
+                # ensure the pricing is based on the base_model here
+                prompt_price, completion_price = litellm.cost_per_token(
+                    model=base_model,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                )
+                expected_price = prompt_price + completion_price
+                print(f"expected price: {expected_price}")
+                assert (
+                    response_cost == expected_price
+                ), f"response_cost: {response_cost} != expected_price: {expected_price}. For model: {kwargs['model']} and base_model: {base_model}. should have used base_model for price"
+
             assert isinstance(kwargs["messages"], list)
             assert isinstance(kwargs["optional_params"], dict)
             assert isinstance(kwargs["litellm_params"], dict)
@@ -345,6 +377,7 @@ async def test_async_chat_azure():
     customHandler_streaming_azure_router = CompletionCustomHandler()
     customHandler_failure = CompletionCustomHandler()
     litellm.callbacks = [customHandler_completion_azure_router]
+    litellm.set_verbose = True
     model_list = [
         {
             "model_name": "gpt-3.5-turbo",  # openai model name
@@ -354,6 +387,7 @@ async def test_async_chat_azure():
                 "api_version": os.getenv("AZURE_API_VERSION"),
                 "api_base": os.getenv("AZURE_API_BASE"),
             },
+            "model_info": {"base_model": "azure/gpt-4-1106-preview"},
             "tpm": 240000,
             "rpm": 1800,
         },