forked from phoenix/litellm-mirror
fix(proxy/utils.py): fix failure logging for rejected requests. + unit tests
This commit is contained in:
parent
06efe28132
commit
ec03e675c9
5 changed files with 113 additions and 18 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -593,12 +593,41 @@ class ProxyLogging:
|
||||||
)
|
)
|
||||||
|
|
||||||
if litellm_logging_obj is not None:
|
if litellm_logging_obj is not None:
|
||||||
|
## UPDATE LOGGING INPUT
|
||||||
|
_optional_params = {}
|
||||||
|
for k, v in request_data.items():
|
||||||
|
if k != "model" and k != "user" and k != "litellm_params":
|
||||||
|
_optional_params[k] = v
|
||||||
|
litellm_logging_obj.update_environment_variables(
|
||||||
|
model=request_data.get("model", ""),
|
||||||
|
user=request_data.get("user", ""),
|
||||||
|
optional_params=_optional_params,
|
||||||
|
litellm_params=request_data.get("litellm_params", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
input: Union[list, str, dict] = ""
|
||||||
|
if "messages" in request_data and isinstance(
|
||||||
|
request_data["messages"], list
|
||||||
|
):
|
||||||
|
input = request_data["messages"]
|
||||||
|
elif "prompt" in request_data and isinstance(
|
||||||
|
request_data["prompt"], str
|
||||||
|
):
|
||||||
|
input = request_data["prompt"]
|
||||||
|
elif "input" in request_data and isinstance(
|
||||||
|
request_data["input"], list
|
||||||
|
):
|
||||||
|
input = request_data["input"]
|
||||||
|
|
||||||
|
litellm_logging_obj.pre_call(
|
||||||
|
input=input,
|
||||||
|
api_key="",
|
||||||
|
)
|
||||||
|
|
||||||
# log the custom exception
|
# log the custom exception
|
||||||
await litellm_logging_obj.async_failure_handler(
|
await litellm_logging_obj.async_failure_handler(
|
||||||
exception=original_exception,
|
exception=original_exception,
|
||||||
traceback_exception=traceback.format_exc(),
|
traceback_exception=traceback.format_exc(),
|
||||||
start_time=time.time(),
|
|
||||||
end_time=time.time(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
threading.Thread(
|
threading.Thread(
|
||||||
|
@ -606,8 +635,6 @@ class ProxyLogging:
|
||||||
args=(
|
args=(
|
||||||
original_exception,
|
original_exception,
|
||||||
traceback.format_exc(),
|
traceback.format_exc(),
|
||||||
time.time(),
|
|
||||||
time.time(),
|
|
||||||
),
|
),
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,29 @@
|
||||||
import sys, os
|
import os
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import litellm.proxy
|
||||||
|
import litellm.proxy.proxy_server
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
import os, io
|
import io
|
||||||
|
import os
|
||||||
|
|
||||||
# this file is to test litellm/proxy
|
# this file is to test litellm/proxy
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import pytest, logging, asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import embedding, completion, completion_cost, Timeout
|
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||||
from litellm import RateLimitError
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
@ -22,14 +31,20 @@ logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
# test /chat/completion request to the proxy
|
# test /chat/completion request to the proxy
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from fastapi import FastAPI
|
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined
|
||||||
app,
|
app,
|
||||||
save_worker_config,
|
|
||||||
initialize,
|
initialize,
|
||||||
) # Replace with the actual module where your FastAPI router is defined
|
save_worker_config,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
|
||||||
# Your bearer token
|
# Your bearer token
|
||||||
token = "sk-1234"
|
token = "sk-1234"
|
||||||
|
@ -158,6 +173,61 @@ def test_chat_completion(mock_acompletion, client_no_auth):
|
||||||
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
from litellm.tests.test_custom_callback_input import CompletionCustomHandler
|
||||||
|
|
||||||
|
|
||||||
|
@mock_patch_acompletion()
|
||||||
|
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
||||||
|
|
||||||
|
rpm_limit = 0
|
||||||
|
|
||||||
|
mock_api_key = "sk-my-test-key"
|
||||||
|
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
|
||||||
|
|
||||||
|
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
|
||||||
|
|
||||||
|
mock_logger = CustomLogger()
|
||||||
|
mock_logger_unit_tests = CompletionCustomHandler()
|
||||||
|
proxy_logging_obj: ProxyLogging = getattr(
|
||||||
|
litellm.proxy.proxy_server, "proxy_logging_obj"
|
||||||
|
)
|
||||||
|
|
||||||
|
litellm.callbacks = [mock_logger, mock_logger_unit_tests]
|
||||||
|
proxy_logging_obj._init_litellm_callbacks(llm_router=None)
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
|
||||||
|
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
mock_logger, "async_log_failure_event", new=AsyncMock()
|
||||||
|
) as mock_failed_alert:
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("testing proxy server with chat completions")
|
||||||
|
response = client_no_auth.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json=test_data,
|
||||||
|
headers={"Authorization": "Bearer {}".format(mock_api_key)},
|
||||||
|
)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# confirm async_log_failure_event is called
|
||||||
|
mock_failed_alert.assert_called()
|
||||||
|
|
||||||
|
assert len(mock_logger_unit_tests.errors) == 0
|
||||||
|
|
||||||
|
|
||||||
@mock_patch_acompletion()
|
@mock_patch_acompletion()
|
||||||
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
|
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
|
@ -422,9 +492,10 @@ def test_add_new_model(client_no_auth):
|
||||||
|
|
||||||
def test_health(client_no_auth):
|
def test_health(client_no_auth):
|
||||||
global headers
|
global headers
|
||||||
import time
|
|
||||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||||
|
|
||||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue