Merge pull request #4742 from BerriAI/litellm_logging_fix

fix(proxy/utils.py): fix failure logging for rejected requests
Krish Dholakia 2024-07-16 21:45:06 -07:00 committed by GitHub
commit 5e1e413de0
2 changed files with 113 additions and 15 deletions

litellm/proxy/utils.py

@@ -600,12 +600,41 @@ class ProxyLogging:
)
if litellm_logging_obj is not None:
## UPDATE LOGGING INPUT
_optional_params = {}
for k, v in request_data.items():
if k != "model" and k != "user" and k != "litellm_params":
_optional_params[k] = v
litellm_logging_obj.update_environment_variables(
model=request_data.get("model", ""),
user=request_data.get("user", ""),
optional_params=_optional_params,
litellm_params=request_data.get("litellm_params", {}),
)
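# Pick the request input to log: chat messages, a text completion prompt, or an embedding input list.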
input: Union[list, str, dict] = ""
if "messages" in request_data and isinstance(
request_data["messages"], list
):
input = request_data["messages"]
elif "prompt" in request_data and isinstance(
request_data["prompt"], str
):
input = request_data["prompt"]
elif "input" in request_data and isinstance(
request_data["input"], list
):
input = request_data["input"]
litellm_logging_obj.pre_call(
input=input,
api_key="",
)
# log the custom exception
await litellm_logging_obj.async_failure_handler(
exception=original_exception,
traceback_exception=traceback.format_exc(),
start_time=time.time(),
end_time=time.time(),
)
threading.Thread(
@@ -613,8 +642,6 @@ class ProxyLogging:
args=(
original_exception,
traceback.format_exc(),
time.time(),
time.time(),
),
).start()
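For reference, the input-selection logic added above can be read on its own as a small helper. This is an illustrative sketch, not code from the PR: `extract_logging_input` is a hypothetical name, and the empty-string fallback mirrors the default used in the hunk.

from typing import Union


def extract_logging_input(request_data: dict) -> Union[list, str, dict]:
    # Prefer chat messages, then a text prompt, then an embedding input,
    # falling back to an empty string (the default used above).
    if isinstance(request_data.get("messages"), list):
        return request_data["messages"]  # /chat/completions payloads
    if isinstance(request_data.get("prompt"), str):
        return request_data["prompt"]  # /completions payloads
    if isinstance(request_data.get("input"), list):
        return request_data["input"]  # /embeddings payloads
    return ""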

litellm/tests/test_proxy_server.py

@@ -1,20 +1,29 @@
import sys, os
import os
import sys
import traceback
from unittest import mock
from dotenv import load_dotenv
import litellm.proxy
import litellm.proxy.proxy_server
load_dotenv()
import os, io
import io
import os
# this file is to test litellm/proxy
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import asyncio
import logging
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
# Configure logging
logging.basicConfig(
@@ -22,14 +31,20 @@ logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s",
)
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import (
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined
app,
save_worker_config,
initialize,
) # Replace with the actual module where your FastAPI router is defined
save_worker_config,
)
from litellm.proxy.utils import ProxyLogging
# Your bearer token
token = "sk-1234"
@@ -158,6 +173,61 @@ def test_chat_completion(mock_acompletion, client_no_auth):
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
from litellm.tests.test_custom_callback_input import CompletionCustomHandler
@mock_patch_acompletion()
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
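    # Set rpm_limit=0 on a test key so the proxy rejects the request with HTTP 429,
    # then assert the rejected request still reaches the custom logger's failure hook.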
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
rpm_limit = 0
mock_api_key = "sk-my-test-key"
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
mock_logger = CustomLogger()
mock_logger_unit_tests = CompletionCustomHandler()
proxy_logging_obj: ProxyLogging = getattr(
litellm.proxy.proxy_server, "proxy_logging_obj"
)
litellm.callbacks = [mock_logger, mock_logger_unit_tests]
proxy_logging_obj._init_litellm_callbacks(llm_router=None)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
with patch.object(
mock_logger, "async_log_failure_event", new=AsyncMock()
) as mock_failed_alert:
# Your test data
test_data = {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "hi"},
],
"max_tokens": 10,
}
print("testing proxy server with chat completions")
response = client_no_auth.post(
"/v1/chat/completions",
json=test_data,
headers={"Authorization": "Bearer {}".format(mock_api_key)},
)
assert response.status_code == 429
# confirm async_log_failure_event is called
mock_failed_alert.assert_called()
assert len(mock_logger_unit_tests.errors) == 0
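For context, the callback exercised by this test would look roughly like the sketch below in user code. It assumes the `async_log_failure_event(kwargs, response_obj, start_time, end_time)` hook exposed by litellm's CustomLogger; the class name and the fields read from `kwargs` are illustrative, not part of this PR.

import litellm
from litellm.integrations.custom_logger import CustomLogger


class FailureOnlyLogger(CustomLogger):
    # Invoked by litellm whenever a request fails, including requests the
    # proxy rejects before they ever reach a model (the case fixed above).
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        exc = kwargs.get("exception")
        print(f"request failed: model={kwargs.get('model')} error={exc}")


# Register it globally, as the test does with litellm.callbacks.
litellm.callbacks = [FailureOnlyLogger()]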
@mock_patch_acompletion()
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
global headers
@@ -422,9 +492,10 @@ def test_add_new_model(client_no_auth):
def test_health(client_no_auth):
global headers
import time
from litellm._logging import verbose_logger, verbose_proxy_logger
import logging
import time
from litellm._logging import verbose_logger, verbose_proxy_logger
verbose_proxy_logger.setLevel(logging.DEBUG)