forked from phoenix/litellm-mirror
fix(proxy/utils.py): fix failure logging for rejected requests. + unit tests
This commit is contained in:
parent
06efe28132
commit
ec03e675c9
5 changed files with 113 additions and 18 deletions
|
@ -1,20 +1,29 @@
|
|||
import sys, os
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from unittest import mock
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm.proxy
|
||||
import litellm.proxy.proxy_server
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
import io
|
||||
import os
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest, logging, asyncio
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
|
@ -22,14 +31,20 @@ logging.basicConfig(
|
|||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
# test /chat/completion request to the proxy
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import (
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.proxy_server import ( # Replace with the actual module where your FastAPI router is defined
|
||||
app,
|
||||
save_worker_config,
|
||||
initialize,
|
||||
) # Replace with the actual module where your FastAPI router is defined
|
||||
save_worker_config,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
# Your bearer token
|
||||
token = "sk-1234"
|
||||
|
@ -158,6 +173,61 @@ def test_chat_completion(mock_acompletion, client_no_auth):
|
|||
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||
|
||||
|
||||
from litellm.tests.test_custom_callback_input import CompletionCustomHandler
|
||||
|
||||
|
||||
@mock_patch_acompletion()
|
||||
def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
||||
|
||||
rpm_limit = 0
|
||||
|
||||
mock_api_key = "sk-my-test-key"
|
||||
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
|
||||
|
||||
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
|
||||
|
||||
mock_logger = CustomLogger()
|
||||
mock_logger_unit_tests = CompletionCustomHandler()
|
||||
proxy_logging_obj: ProxyLogging = getattr(
|
||||
litellm.proxy.proxy_server, "proxy_logging_obj"
|
||||
)
|
||||
|
||||
litellm.callbacks = [mock_logger, mock_logger_unit_tests]
|
||||
proxy_logging_obj._init_litellm_callbacks(llm_router=None)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
|
||||
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
|
||||
|
||||
with patch.object(
|
||||
mock_logger, "async_log_failure_event", new=AsyncMock()
|
||||
) as mock_failed_alert:
|
||||
# Your test data
|
||||
test_data = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
"max_tokens": 10,
|
||||
}
|
||||
|
||||
print("testing proxy server with chat completions")
|
||||
response = client_no_auth.post(
|
||||
"/v1/chat/completions",
|
||||
json=test_data,
|
||||
headers={"Authorization": "Bearer {}".format(mock_api_key)},
|
||||
)
|
||||
assert response.status_code == 429
|
||||
|
||||
# confirm async_log_failure_event is called
|
||||
mock_failed_alert.assert_called()
|
||||
|
||||
assert len(mock_logger_unit_tests.errors) == 0
|
||||
|
||||
|
||||
@mock_patch_acompletion()
|
||||
def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
|
||||
global headers
|
||||
|
@ -422,9 +492,10 @@ def test_add_new_model(client_no_auth):
|
|||
|
||||
def test_health(client_no_auth):
|
||||
global headers
|
||||
import time
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import logging
|
||||
import time
|
||||
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
|
||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue