mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(utils.py): fix cost tracking for cache hits (should be 0)
This commit is contained in:
parent
62ad6f19b7
commit
1ed6842009
2 changed files with 129 additions and 19 deletions
|
@ -31,6 +31,7 @@ class MyCustomHandler(CustomLogger):
|
|||
self.sync_stream_collected_response = None # type: ignore
|
||||
self.user = None # type: ignore
|
||||
self.data_sent_to_api: dict = {}
|
||||
self.response_cost = 0
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
|
@ -47,6 +48,8 @@ class MyCustomHandler(CustomLogger):
|
|||
self.success = True
|
||||
if kwargs.get("stream") == True:
|
||||
self.sync_stream_collected_response = response_obj
|
||||
print(f"response cost in log_success_event: {kwargs.get('response_cost')}")
|
||||
self.response_cost = kwargs.get("response_cost", 0)
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
|
@ -64,6 +67,10 @@ class MyCustomHandler(CustomLogger):
|
|||
self.stream_collected_response = response_obj
|
||||
self.async_completion_kwargs = kwargs
|
||||
self.user = kwargs.get("user", None)
|
||||
print(
|
||||
f"response cost in log_async_success_event: {kwargs.get('response_cost')}"
|
||||
)
|
||||
self.response_cost = kwargs.get("response_cost", 0)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Failure")
|
||||
|
@ -400,6 +407,50 @@ async def test_async_custom_handler_embedding_optional_param_bedrock():
|
|||
assert "user" not in customHandler_optional_params.data_sent_to_api
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_tracking_with_caching():
|
||||
"""
|
||||
Important Test - This tests if that cost is 0 for cached responses
|
||||
"""
|
||||
from litellm import Cache
|
||||
|
||||
litellm.set_verbose = False
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
customHandler_optional_params = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_optional_params]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"write a one sentence poem about: {time.time()}",
|
||||
}
|
||||
]
|
||||
response1 = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=0.2,
|
||||
caching=True,
|
||||
)
|
||||
await asyncio.sleep(1) # success callback is async
|
||||
response_cost = customHandler_optional_params.response_cost
|
||||
assert response_cost > 0
|
||||
response2 = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
max_tokens=40,
|
||||
temperature=0.2,
|
||||
caching=True,
|
||||
)
|
||||
await asyncio.sleep(1) # success callback is async
|
||||
response_cost_2 = customHandler_optional_params.response_cost
|
||||
assert response_cost_2 == 0
|
||||
|
||||
|
||||
def test_redis_cache_completion_stream():
|
||||
from litellm import Cache
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue