diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 2145d1226..8a8494ce4 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -600,12 +600,41 @@ class ProxyLogging:
             )
 
         if litellm_logging_obj is not None:
+            ## UPDATE LOGGING INPUT
+            _optional_params = {}
+            for k, v in request_data.items():
+                if k != "model" and k != "user" and k != "litellm_params":
+                    _optional_params[k] = v
+            litellm_logging_obj.update_environment_variables(
+                model=request_data.get("model", ""),
+                user=request_data.get("user", ""),
+                optional_params=_optional_params,
+                litellm_params=request_data.get("litellm_params", {}),
+            )
+
+            input: Union[list, str, dict] = ""
+            if "messages" in request_data and isinstance(
+                request_data["messages"], list
+            ):
+                input = request_data["messages"]
+            elif "prompt" in request_data and isinstance(
+                request_data["prompt"], str
+            ):
+                input = request_data["prompt"]
+            elif "input" in request_data and isinstance(
+                request_data["input"], list
+            ):
+                input = request_data["input"]
+
+            litellm_logging_obj.pre_call(
+                input=input,
+                api_key="",
+            )
+
             # log the custom exception
             await litellm_logging_obj.async_failure_handler(
                 exception=original_exception,
                 traceback_exception=traceback.format_exc(),
-                start_time=time.time(),
-                end_time=time.time(),
             )
 
             threading.Thread(
@@ -613,8 +642,6 @@ class ProxyLogging:
                 args=(
                     original_exception,
                     traceback.format_exc(),
-                    time.time(),
-                    time.time(),
                 ),
             ).start()
 
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index d8bfb5229..ed7451c27 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -1,20 +1,29 @@
-import sys, os
+import os
+import sys
 import traceback
 from unittest import mock
+
 from dotenv import load_dotenv
 
+import litellm.proxy
+import litellm.proxy.proxy_server
+
 load_dotenv()
-import os, io
+import io
+import os
 
 # this file is to test litellm/proxy
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest, logging, asyncio
+import asyncio
+import logging
+
+import pytest
+
 import litellm
-from litellm import embedding, completion, completion_cost, Timeout
-from litellm import RateLimitError
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
 
 # Configure logging
 logging.basicConfig(
@@ -22,14 +31,20 @@ logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(message)s",
 )
 
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from fastapi import FastAPI
+
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
-from fastapi import FastAPI
-from litellm.proxy.proxy_server import (
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy.proxy_server import (  # Replace with the actual module where your FastAPI router is defined
     app,
-    save_worker_config,
     initialize,
-)  # Replace with the actual module where your FastAPI router is defined
+    save_worker_config,
+)
+from litellm.proxy.utils import ProxyLogging
 
 # Your bearer token
 token = "sk-1234"
@@ -158,6 +173,61 @@ def test_chat_completion(mock_acompletion, client_no_auth):
         pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
 
 
+from litellm.tests.test_custom_callback_input import CompletionCustomHandler
+
+
+@mock_patch_acompletion()
+def test_custom_logger_failure_handler(mock_acompletion, client_no_auth):
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.proxy_server import hash_token, user_api_key_cache
+
+    rpm_limit = 0
+
+    mock_api_key = "sk-my-test-key"
+    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
+
+    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
+
+    mock_logger = CustomLogger()
+    mock_logger_unit_tests = CompletionCustomHandler()
+    proxy_logging_obj: ProxyLogging = getattr(
+        litellm.proxy.proxy_server, "proxy_logging_obj"
+    )
+
+    litellm.callbacks = [mock_logger, mock_logger_unit_tests]
+    proxy_logging_obj._init_litellm_callbacks(llm_router=None)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
+    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
+
+    with patch.object(
+        mock_logger, "async_log_failure_event", new=AsyncMock()
+    ) as mock_failed_alert:
+        # Your test data
+        test_data = {
+            "model": "gpt-3.5-turbo",
+            "messages": [
+                {"role": "user", "content": "hi"},
+            ],
+            "max_tokens": 10,
+        }
+
+        print("testing proxy server with chat completions")
+        response = client_no_auth.post(
+            "/v1/chat/completions",
+            json=test_data,
+            headers={"Authorization": "Bearer {}".format(mock_api_key)},
+        )
+        assert response.status_code == 429
+
+        # confirm async_log_failure_event is called
+        mock_failed_alert.assert_called()
+
+        assert len(mock_logger_unit_tests.errors) == 0
+
+
 @mock_patch_acompletion()
 def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
     global headers
@@ -422,9 +492,10 @@ def test_add_new_model(client_no_auth):
 
 
 def test_health(client_no_auth):
     global headers
-    import time
-    from litellm._logging import verbose_logger, verbose_proxy_logger
     import logging
+    import time
+
+    from litellm._logging import verbose_logger, verbose_proxy_logger
 
     verbose_proxy_logger.setLevel(logging.DEBUG)
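Reviewer note (not part of the patch): the utils.py change above populates the logging object with the original request, model, user, the messages/prompt/input payload, and the remaining fields as optional_params, before the failure handlers run, so custom callbacks no longer see empty placeholders. Below is a minimal sketch of a callback that consumes that payload. CustomLogger.async_log_failure_event is the hook the new test patches; the FailureAuditLogger name and the specific kwargs keys read here (model, messages, exception) are illustrative assumptions based on this diff, not guaranteed by it.

    # failure_audit_logger.py -- illustrative sketch only
    import litellm
    from litellm.integrations.custom_logger import CustomLogger


    class FailureAuditLogger(CustomLogger):
        async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
            # With the utils.py change, update_environment_variables() and
            # pre_call() run before the failure handler, so these fields
            # should reflect the original request. Key names are assumptions;
            # check the payload your litellm version actually emits.
            model = kwargs.get("model", "")
            messages = kwargs.get("messages")
            exception = kwargs.get("exception")
            print(f"LLM call to {model!r} failed: {exception}; input={messages}")


    # register like any other litellm callback
    litellm.callbacks = [FailureAuditLogger()]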