fix(cost_calculator.py): fix cost calc

This commit is contained in:
Krrish Dholakia 2024-08-12 16:44:44 -07:00
parent 89e3141e2d
commit ef8fb23334
2 changed files with 22 additions and 8 deletions

View file

@ -490,10 +490,18 @@ def completion_cost(
isinstance(completion_response, BaseModel) isinstance(completion_response, BaseModel)
or isinstance(completion_response, dict) or isinstance(completion_response, dict)
): # tts returns a custom class ): # tts returns a custom class
if isinstance(completion_response, BaseModel) and not isinstance(
completion_response, litellm.Usage usage_obj: Optional[Union[dict, litellm.Usage]] = completion_response.get(
"usage", {}
)
if isinstance(usage_obj, BaseModel) and not isinstance(
usage_obj, litellm.Usage
): ):
completion_response = litellm.Usage(**completion_response.model_dump()) setattr(
completion_response,
"usage",
litellm.Usage(**usage_obj.model_dump()),
)
# get input/output tokens from completion_response # get input/output tokens from completion_response
prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = completion_response.get("usage", {}).get( completion_tokens = completion_response.get("usage", {}).get(

View file

@ -1,11 +1,17 @@
### What this tests #### ### What this tests ####
import sys, os, time, inspect, asyncio, traceback import asyncio
import inspect
import os
import sys
import time
import traceback
import pytest import pytest
sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
from litellm import completion, embedding
import litellm import litellm
from litellm import completion, embedding
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
@ -201,7 +207,7 @@ def test_async_custom_handler_stream():
print("complete_streaming_response: ", complete_streaming_response) print("complete_streaming_response: ", complete_streaming_response)
assert response_in_success_handler == complete_streaming_response assert response_in_success_handler == complete_streaming_response
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}\n{traceback.format_exc()}")
# test_async_custom_handler_stream() # test_async_custom_handler_stream()
@ -457,11 +463,11 @@ async def test_cost_tracking_with_caching():
def test_redis_cache_completion_stream(): def test_redis_cache_completion_stream():
from litellm import Cache
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
import random import random
from litellm import Cache
try: try:
print("\nrunning test_redis_cache_completion_stream") print("\nrunning test_redis_cache_completion_stream")
litellm.set_verbose = True litellm.set_verbose = True