(fixes) gcs bucket key based logging (#6044)

* fixes for gcs bucket logging

* fix StandardCallbackDynamicParams

* fix - gcs logging when payload is not serializable

* add test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket

* working success callbacks

* linting fixes

* fix linting error

* add type hints to functions

* fixes for dynamic success and failure logging

* fix for test_async_chat_openai_stream
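In short: these fixes make per-request (key-based) GCS bucket logging work without registering a global callback — the bucket name and service-account path travel with the request as StandardCallbackDynamicParams. A minimal sketch of the resulting call shape, lifted from the tests in this commit (the bucket name and model are the tests' own placeholder values):

    import asyncio

    import litellm


    async def main():
        # GCS logging configured per request: no litellm.callbacks needed.
        # "gcs_bucket" is attached via the success_callback / failure_callback
        # kwargs, and the bucket comes from the gcs_bucket_name dynamic param.
        response = await litellm.acompletion(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "This is a test"}],
            max_tokens=10,
            gcs_bucket_name="key-logging-project1",
            success_callback=["gcs_bucket"],
            failure_callback=["gcs_bucket"],
        )
        print(response.id)  # the logged object is keyed by date + response id


    asyncio.run(main())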
Ishaan Jaff 2024-10-04 11:56:10 +05:30 committed by GitHub
parent 793593e735
commit 670ecda4e2
9 changed files with 446 additions and 39 deletions

View file

@@ -267,7 +267,7 @@ async def test_basic_gcs_logger_failure():
 @pytest.mark.asyncio
-async def test_basic_gcs_logging_per_request():
+async def test_basic_gcs_logging_per_request_with_callback_set():
     """
     Test GCS Bucket logging per request
@@ -391,3 +391,128 @@ async def test_basic_gcs_logging_per_request():
         object_name=object_name,
         standard_callback_dynamic_params=standard_callback_dynamic_params,
     )
+
+
+@pytest.mark.asyncio
+async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
+    """
+    Test GCS Bucket logging per request
+
+    key difference: no litellm.callbacks set
+
+    Request 1 - pass gcs_bucket_name in kwargs
+    Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket'
+    """
+    import logging
+
+    from litellm._logging import verbose_logger
+
+    verbose_logger.setLevel(logging.DEBUG)
+    load_vertex_ai_credentials()
+    gcs_logger = GCSBucketLogger()
+
+    GCS_BUCKET_NAME = "key-logging-project1"
+    standard_callback_dynamic_params: StandardCallbackDynamicParams = (
+        StandardCallbackDynamicParams(gcs_bucket_name=GCS_BUCKET_NAME)
+    )
+
+    try:
+        response = await litellm.acompletion(
+            model="gpt-4o-mini",
+            temperature=0.7,
+            messages=[{"role": "user", "content": "This is a test"}],
+            max_tokens=10,
+            user="ishaan-2",
+            gcs_bucket_name=GCS_BUCKET_NAME,
+            success_callback=["gcs_bucket"],
+            failure_callback=["gcs_bucket"],
+        )
+    except:
+        pass
+
+    await asyncio.sleep(5)
+
+    # Get the current date
+    current_date = datetime.now().strftime("%Y-%m-%d")
+
+    # Modify the object_name to include the date-based folder
+    object_name = f"{current_date}%2F{response.id}"
+    print("object_name", object_name)
+
+    # Check if object landed on GCS
+    object_from_gcs = await gcs_logger.download_gcs_object(
+        object_name=object_name,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
+    print("object from gcs=", object_from_gcs)
+
+    # convert object_from_gcs from bytes to DICT
+    parsed_data = json.loads(object_from_gcs)
+    print("object_from_gcs as dict", parsed_data)
+    print("type of object_from_gcs", type(parsed_data))
+
+    gcs_payload = StandardLoggingPayload(**parsed_data)
+
+    assert gcs_payload["model"] == "gpt-4o-mini"
+    assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
+    assert gcs_payload["response_cost"] > 0.0
+    assert gcs_payload["status"] == "success"
+
+    # clean up the object from GCS
+    await gcs_logger.delete_gcs_object(
+        object_name=object_name,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
+
+    # make a failure request - assert that failure callback is hit
+    gcs_log_id = f"failure-test-{uuid.uuid4().hex}"
+    try:
+        response = await litellm.acompletion(
+            model="gpt-4o-mini",
+            temperature=0.7,
+            messages=[{"role": "user", "content": "This is a test"}],
+            max_tokens=10,
+            user="ishaan-2",
+            mock_response=litellm.BadRequestError(
+                model="gpt-3.5-turbo",
+                message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.",
+                llm_provider="openai",
+            ),
+            success_callback=["gcs_bucket"],
+            failure_callback=["gcs_bucket"],
+            gcs_bucket_name=GCS_BUCKET_NAME,
+            metadata={
+                "gcs_log_id": gcs_log_id,
+            },
+        )
+    except:
+        pass
+
+    await asyncio.sleep(5)
+
+    # check if the failure object is logged in GCS
+    object_from_gcs = await gcs_logger.download_gcs_object(
+        object_name=gcs_log_id,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
+    print("object from gcs=", object_from_gcs)
+
+    # convert object_from_gcs from bytes to DICT
+    parsed_data = json.loads(object_from_gcs)
+    print("object_from_gcs as dict", parsed_data)
+
+    gcs_payload = StandardLoggingPayload(**parsed_data)
+
+    assert gcs_payload["model"] == "gpt-4o-mini"
+    assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
+    assert gcs_payload["response_cost"] == 0
+    assert gcs_payload["status"] == "failure"
+
+    # clean up the object from GCS
+    await gcs_logger.delete_gcs_object(
+        object_name=gcs_log_id,
+        standard_callback_dynamic_params=standard_callback_dynamic_params,
+    )
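
The second changed file below covers the proxy side: logging settings stored in a virtual key's metadata are expanded by add_litellm_data_to_request into per-request success_callback / failure_callback lists plus the gcs_* dynamic params. For reference, the key metadata shape the test exercises (values copied from the test; the service-account filename is the test fixture's):

    # "logging" entry stored on the virtual key and read by the proxy's
    # pre-call utils; callback_type is "success", "failure", or
    # "success_and_failure".
    key_metadata = {
        "logging": [
            {
                "callback_name": "gcs_bucket",
                "callback_type": "success_and_failure",
                "callback_vars": {
                    "gcs_bucket_name": "key-logging-project1",
                    "gcs_path_service_account": "adroit-crow-413218-a956eef1a2a8.json",
                },
            }
        ]
    }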

View file

@@ -1389,6 +1389,138 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
         assert new_data["failure_callback"] == expected_failure_callbacks
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "callback_type, expected_success_callbacks, expected_failure_callbacks",
+    [
+        ("success", ["gcs_bucket"], []),
+        ("failure", [], ["gcs_bucket"]),
+        ("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]),
+    ],
+)
+async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
+    prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
+):
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/chat/completions")
+
+    test_data = {
+        "model": "azure/chatgpt-v-2",
+        "messages": [
+            {"role": "user", "content": "write 1 sentence poem"},
+        ],
+        "max_tokens": 10,
+        "mock_response": "Hello world",
+        "api_key": "my-fake-key",
+    }
+
+    json_bytes = json.dumps(test_data).encode("utf-8")
+
+    request._body = json_bytes
+
+    data = {
+        "data": {
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        },
+        "request": request,
+        "user_api_key_dict": UserAPIKeyAuth(
+            token=None,
+            key_name=None,
+            key_alias=None,
+            spend=0.0,
+            max_budget=None,
+            expires=None,
+            models=[],
+            aliases={},
+            config={},
+            user_id=None,
+            team_id=None,
+            max_parallel_requests=None,
+            metadata={
+                "logging": [
+                    {
+                        "callback_name": "gcs_bucket",
+                        "callback_type": callback_type,
+                        "callback_vars": {
+                            "gcs_bucket_name": "key-logging-project1",
+                            "gcs_path_service_account": "adroit-crow-413218-a956eef1a2a8.json",
+                        },
+                    }
+                ]
+            },
+            tpm_limit=None,
+            rpm_limit=None,
+            budget_duration=None,
+            budget_reset_at=None,
+            allowed_cache_controls=[],
+            permissions={},
+            model_spend={},
+            model_max_budget={},
+            soft_budget_cooldown=False,
+            litellm_budget_table=None,
+            org_id=None,
+            team_spend=None,
+            team_alias=None,
+            team_tpm_limit=None,
+            team_rpm_limit=None,
+            team_max_budget=None,
+            team_models=[],
+            team_blocked=False,
+            soft_budget=None,
+            team_model_aliases=None,
+            team_member_spend=None,
+            team_metadata=None,
+            end_user_id=None,
+            end_user_tpm_limit=None,
+            end_user_rpm_limit=None,
+            end_user_max_budget=None,
+            last_refreshed_at=None,
+            api_key=None,
+            user_role=None,
+            allowed_model_region=None,
+            parent_otel_span=None,
+        ),
+        "proxy_config": proxy_config,
+        "general_settings": {},
+        "version": "0.0.0",
+    }
+
+    new_data = await add_litellm_data_to_request(**data)
+    print("NEW DATA: {}".format(new_data))
+
+    assert "gcs_bucket_name" in new_data
+    assert new_data["gcs_bucket_name"] == "key-logging-project1"
+
+    assert "gcs_path_service_account" in new_data
+    assert (
+        new_data["gcs_path_service_account"] == "adroit-crow-413218-a956eef1a2a8.json"
+    )
+
+    if expected_success_callbacks:
+        assert "success_callback" in new_data
+        assert new_data["success_callback"] == expected_success_callbacks
+
+    if expected_failure_callbacks:
+        assert "failure_callback" in new_data
+        assert new_data["failure_callback"] == expected_failure_callbacks
+
+
 @pytest.mark.asyncio
 async def test_gemini_pass_through_endpoint():
     from starlette.datastructures import URL
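
For end-to-end use (not shown in this commit), the same metadata would be attached when generating a virtual key. A hypothetical sketch, assuming a proxy running on localhost:4000 with master key sk-1234 and the standard /key/generate route:

    import requests

    # Create a key whose requests are logged to the GCS bucket above.
    resp = requests.post(
        "http://localhost:4000/key/generate",
        headers={"Authorization": "Bearer sk-1234"},  # proxy master key
        json={
            "metadata": {
                "logging": [
                    {
                        "callback_name": "gcs_bucket",
                        "callback_type": "success_and_failure",
                        "callback_vars": {
                            "gcs_bucket_name": "key-logging-project1",
                            "gcs_path_service_account": "adroit-crow-413218-a956eef1a2a8.json",
                        },
                    }
                ]
            }
        },
    )
    print(resp.json()["key"])  # requests made with this key log to GCS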