fix(proxy/utils.py): fix failure logging for rejected requests + unit tests

Krrish Dholakia 2024-07-16 17:15:20 -07:00
parent 06efe28132
commit ec03e675c9
5 changed files with 113 additions and 18 deletions

File diff suppressed because one or more lines are too long (3 files)

@@ -593,12 +593,41 @@ class ProxyLogging:
            )
        if litellm_logging_obj is not None:
            ## UPDATE LOGGING INPUT
            _optional_params = {}
            for k, v in request_data.items():
                if k != "model" and k != "user" and k != "litellm_params":
                    _optional_params[k] = v
            litellm_logging_obj.update_environment_variables(
                model=request_data.get("model", ""),
                user=request_data.get("user", ""),
                optional_params=_optional_params,
                litellm_params=request_data.get("litellm_params", {}),
            )
            input: Union[list, str, dict] = ""
            if "messages" in request_data and isinstance(
                request_data["messages"], list
            ):
                input = request_data["messages"]
            elif "prompt" in request_data and isinstance(
                request_data["prompt"], str
            ):
                input = request_data["prompt"]
            elif "input" in request_data and isinstance(
                request_data["input"], list
            ):
                input = request_data["input"]
            litellm_logging_obj.pre_call(
                input=input,
                api_key="",
            )
            # log the custom exception
            await litellm_logging_obj.async_failure_handler(
                exception=original_exception,
                traceback_exception=traceback.format_exc(),
                start_time=time.time(),
                end_time=time.time(),
            )
            threading.Thread(
@@ -606,8 +635,6 @@ class ProxyLogging:
                args=(
                    original_exception,
                    traceback.format_exc(),
                    time.time(),
                    time.time(),
                ),
            ).start()
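The hunk above seeds the logging object before the failure handler runs: everything except model, user, and litellm_params is routed into optional_params, and the logged input is chosen by payload shape. A standalone sketch of that selection logic follows; the helper name is illustrative, not part of the commit.

from typing import Union


def extract_request_input(request_data: dict) -> Union[list, str, dict]:
    # Chat, text-completion, and embedding payloads each carry their
    # input under a different key; mirror the fallback order in the hunk.
    if isinstance(request_data.get("messages"), list):
        return request_data["messages"]  # /chat/completions
    if isinstance(request_data.get("prompt"), str):
        return request_data["prompt"]  # /completions
    if isinstance(request_data.get("input"), list):
        return request_data["input"]  # /embeddings
    return ""  # nothing recognizable to log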


@@ -1,20 +1,29 @@
import sys, os
import os
import sys
import traceback
from unittest import mock
from dotenv import load_dotenv
import litellm.proxy
import litellm.proxy.proxy_server
load_dotenv()
import os, io
import io
import os
# this file is to test litellm/proxy
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest, logging, asyncio
import asyncio
import logging
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
# Configure logging
logging.basicConfig(
@@ -22,14 +31,20 @@ logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
)
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import (
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import (  # Replace with the actual module where your FastAPI router is defined
    app,
    save_worker_config,
    initialize,
)  # Replace with the actual module where your FastAPI router is defined
    save_worker_config,
)
from litellm.proxy.utils import ProxyLogging
# Your bearer token
token = "sk-1234"
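For context, the tests below drive the proxy through FastAPI's TestClient; a minimal sketch of that request pattern, using the app imported and the token defined above (the payload is illustrative, not a fixture from this file):

from fastapi.testclient import TestClient

client = TestClient(app)
response = client.post(
    "/v1/chat/completions",
    json={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
    headers={"Authorization": f"Bearer {token}"},
)
print(response.status_code, response.json())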
@@ -158,6 +173,61 @@ def test_chat_completion(mock_acompletion, client_no_auth):
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


from litellm.tests.test_custom_callback_input import CompletionCustomHandler


@mock_patch_acompletion()
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import hash_token, user_api_key_cache

    rpm_limit = 0

    mock_api_key = "sk-my-test-key"
    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)

    mock_logger = CustomLogger()
    mock_logger_unit_tests = CompletionCustomHandler()
    proxy_logging_obj: ProxyLogging = getattr(
        litellm.proxy.proxy_server, "proxy_logging_obj"
    )

    litellm.callbacks = [mock_logger, mock_logger_unit_tests]
    proxy_logging_obj._init_litellm_callbacks(llm_router=None)

    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)

    with patch.object(
        mock_logger, "async_log_failure_event", new=AsyncMock()
    ) as mock_failed_alert:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {"role": "user", "content": "hi"},
            ],
            "max_tokens": 10,
        }

        print("testing proxy server with chat completions")
        response = client_no_auth.post(
            "/v1/chat/completions",
            json=test_data,
            headers={"Authorization": "Bearer {}".format(mock_api_key)},
        )
        assert response.status_code == 429

        # confirm async_log_failure_event is called
        mock_failed_alert.assert_called()

        assert len(mock_logger_unit_tests.errors) == 0
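The two asserts above split the claim: the patched mock proves the async_log_failure_event hook fires for the rejected (429) request, while the empty errors list on CompletionCustomHandler proves the kwargs it received were well-formed. Below is a minimal sketch of a real handler consuming that event; the subclass is hypothetical, and the kwargs fields read here are assumptions about what the seeding in proxy/utils.py makes available.

from litellm.integrations.custom_logger import CustomLogger


class RejectedRequestLogger(CustomLogger):  # hypothetical example, not in this commit
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # With this fix, model/user/input should be populated even when the
        # proxy rejects the request before any LLM call is made.
        model = kwargs.get("model", "")  # assumed field names
        exception = kwargs.get("exception")
        print(f"request rejected: model={model!r} error={exception}")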
@mock_patch_acompletion()
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
    global headers
@@ -422,9 +492,10 @@ def test_add_new_model(client_no_auth):
def test_health(client_no_auth):
    global headers
    import time
    from litellm._logging import verbose_logger, verbose_proxy_logger
    import logging
    import time
    from litellm._logging import verbose_logger, verbose_proxy_logger
    verbose_proxy_logger.setLevel(logging.DEBUG)