Merge pull request #1413 from BerriAI/litellm_log_cache_hits

[Feat] Proxy - Log Cache Hits on success callbacks + Testing
Ishaan Jaff 2024-01-11 16:39:22 +05:30 committed by GitHub
commit e5b491b39f
2 changed files with 35 additions and 4 deletions


@@ -93,6 +93,7 @@ class S3Logger:
messages = kwargs.get("messages")
optional_params = kwargs.get("optional_params", {})
call_type = kwargs.get("call_type", "litellm.completion")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
@@ -100,6 +101,7 @@ class S3Logger:
payload = {
"id": id,
"call_type": call_type,
"cache_hit": cache_hit,
"startTime": start_time,
"endTime": end_time,
"model": kwargs.get("model", ""),
@@ -118,7 +120,10 @@ class S3Logger:
except:
# non blocking if it can't cast to a str
pass
s3_object_key = payload["id"]
s3_object_key = (
payload["id"] + "-time=" + str(start_time)
) # we need the s3 key to include the time, so we log cache hits too
import json
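
A minimal sketch (with a hypothetical helper, not the actual S3Logger code) of why the key change matters: a cache hit replays the same response "id", so without the time suffix the second log call would overwrite the first S3 object.

import datetime
import uuid

def build_s3_object_key(response_id, start_time):
    # mirrors the diff: s3_object_key = payload["id"] + "-time=" + str(start_time)
    return response_id + "-time=" + str(start_time)

response_id = "chatcmpl-" + uuid.uuid4().hex  # hypothetical response id
original_key = build_s3_object_key(response_id, datetime.datetime(2024, 1, 11, 16, 39, 22))
cache_hit_key = build_s3_object_key(response_id, datetime.datetime(2024, 1, 11, 16, 39, 27))

assert original_key != cache_hit_key  # each request, including the cache hit, gets its own log object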


@@ -20,8 +20,10 @@ def test_s3_logging():
# since we are modifying stdout, and pytests runs tests in parallel
# on circle ci - we only test litellm.acompletion()
try:
# pre
# redirect stdout to log_file
litellm.cache = litellm.Cache(
type="s3", s3_bucket_name="cache-bucket-litellm", s3_region_name="us-west-2"
)
litellm.success_callback = ["s3"]
litellm.s3_callback_params = {
@@ -35,10 +37,14 @@ def test_s3_logging():
expected_keys = []
import time
curr_time = str(time.time())
async def _test():
return await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "This is a test"}],
messages=[{"role": "user", "content": f"This is a test {curr_time}"}],
max_tokens=10,
temperature=0.7,
user="ishaan-2",
@@ -48,6 +54,19 @@ def test_s3_logging():
print(f"response: {response}")
expected_keys.append(response.id)
async def _test():
return await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"This is a test {curr_time}"}],
max_tokens=10,
temperature=0.7,
user="ishaan-2",
)
response = asyncio.run(_test())
expected_keys.append(response.id)
print(f"response: {response}")
# # streaming + async
# async def _test2():
# response = await litellm.acompletion(
@@ -86,10 +105,17 @@ def test_s3_logging():
)
# Get the keys of the most recent objects
most_recent_keys = [obj["Key"] for obj in objects]
print(most_recent_keys)
# for each key, get the part before "-" as the key. Do it safely
cleaned_keys = []
for key in most_recent_keys:
split_key = key.split("-time=")
cleaned_keys.append(split_key[0])
print("\n most recent keys", most_recent_keys)
print("\n cleaned keys", cleaned_keys)
print("\n Expected keys: ", expected_keys)
for key in expected_keys:
assert key in most_recent_keys
assert key in cleaned_keys
except Exception as e:
pytest.fail(f"An exception occurred - {e}")
finally:
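
For reference, a standalone sketch of the assertion logic the updated test performs; the key strings below are hypothetical examples, not real bucket contents.

most_recent_keys = [
    "chatcmpl-abc123-time=2024-01-11 16:39:22.123456",  # original request
    "chatcmpl-abc123-time=2024-01-11 16:39:27.654321",  # cache-hit replay of the same request
]

# strip the "-time=" suffix to recover the response id logged for each call
cleaned_keys = [key.split("-time=")[0] for key in most_recent_keys]

expected_keys = ["chatcmpl-abc123", "chatcmpl-abc123"]  # response.id from both calls
for key in expected_keys:
    assert key in cleaned_keys  # both the original call and the cache hit were logged to S3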