diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 6eec8d3cd..a3cb847a4 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -490,10 +490,18 @@ def completion_cost(
         isinstance(completion_response, BaseModel)
         or isinstance(completion_response, dict)
     ):  # tts returns a custom class
-        if isinstance(completion_response, BaseModel) and not isinstance(
-            completion_response, litellm.Usage
+
+        usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get(
+            "usage", {}
+        )
+        if isinstance(usage_obj, BaseModel) and not isinstance(
+            usage_obj, litellm.Usage
         ):
-            completion_response = litellm.Usage(**completion_response.model_dump())
+            setattr(
+                completion_response,
+                "usage",
+                litellm.Usage(**usage_obj.model_dump()),
+            )
         # get input/output tokens from completion_response
         prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
         completion_tokens = completion_response.get("usage", {}).get(
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index e3407c9e1..465012bff 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -1,11 +1,17 @@
 ### What this tests ####
-import sys, os, time, inspect, asyncio, traceback
+import asyncio
+import inspect
+import os
+import sys
+import time
+import traceback
+
 import pytest
 
 sys.path.insert(0, os.path.abspath("../.."))
-from litellm import completion, embedding
 import litellm
+from litellm import completion, embedding
 from litellm.integrations.custom_logger import CustomLogger
 
 
 
@@ -201,7 +207,7 @@ def test_async_custom_handler_stream():
         print("complete_streaming_response: ", complete_streaming_response)
         assert response_in_success_handler == complete_streaming_response
     except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+        pytest.fail(f"Error occurred: {e}\n{traceback.format_exc()}")
 
 
 # test_async_custom_handler_stream()
@@ -457,11 +463,11 @@ async def test_cost_tracking_with_caching():
 
 
 def test_redis_cache_completion_stream():
-    from litellm import Cache
-
     # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
     import random
 
+    from litellm import Cache
+
     try:
         print("\nrunning test_redis_cache_completion_stream")
         litellm.set_verbose = True