diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index 38247cbe07..007004ca39 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,9 +1,10 @@
 from typing import Optional
-import litellm
+import litellm, traceback
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
 
 
 class MaxParallelRequestsHandler(CustomLogger):
@@ -14,8 +15,7 @@ class MaxParallelRequestsHandler(CustomLogger):
         pass
 
     def print_verbose(self, print_statement):
-        if litellm.set_verbose is True:
-            print(print_statement)  # noqa
+        verbose_proxy_logger.debug(print_statement)
 
     async def async_pre_call_hook(
         self,
@@ -52,7 +52,7 @@ class MaxParallelRequestsHandler(CustomLogger):
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
+            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
             if user_api_key is None:
                 return
@@ -61,28 +61,19 @@ class MaxParallelRequestsHandler(CustomLogger):
                 return
 
             request_count_api_key = f"{user_api_key}_request_count"
-            # check if it has collected an entire stream response
-            self.print_verbose(
-                f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
-            )
-            if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
-                # Decrease count for this token
-                current = (
-                    self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
-                )
-                new_val = current - 1
-                self.print_verbose(f"updated_value in success call: {new_val}")
-                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+            # Decrease count for this token
+            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+            new_val = current - 1
+            self.print_verbose(f"updated_value in success call: {new_val}")
+            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
         except Exception as e:
             self.print_verbose(e)  # noqa
 
-    async def async_log_failure_call(
-        self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
-    ):
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
             self.print_verbose(f"Inside Max Parallel Request Failure Hook")
-            api_key = user_api_key_dict.api_key
-            if api_key is None:
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            if user_api_key is None:
                 return
 
             if self.user_api_key_cache is None:
@@ -90,13 +81,13 @@
 
             ## decrement call count if call failed
             if (
-                hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)
+                hasattr(kwargs["exception"], "status_code")
+                and kwargs["exception"].status_code == 429
+                and "Max parallel request limit reached" in str(kwargs["exception"])
             ):
                 pass  # ignore failed calls due to max limit being reached
             else:
-                request_count_api_key = f"{api_key}_request_count"
+                request_count_api_key = f"{user_api_key}_request_count"
                 # Decrease count for this token
                 current = (
                     self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index cde397d387..10c968b1c5 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1102,7 +1102,7 @@ async def generate_key_helper_fn(
         }
         if prisma_client is not None:
             ## CREATE USER (If necessary)
-            verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
+            verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
             user_row = await prisma_client.insert_data(
                 data=user_data, table_name="user"
             )
@@ -1111,7 +1111,7 @@ async def generate_key_helper_fn(
             if len(user_row.models) > 0 and len(key_data["models"]) == 0:  # type: ignore
                 key_data["models"] = user_row.models
             ## CREATE KEY
-            verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
+            verbose_proxy_logger.debug(f"prisma_client: Creating Key={key_data}")
             await prisma_client.insert_data(data=key_data, table_name="key")
         elif custom_db_client is not None:
             ## CREATE USER (If necessary)
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
new file mode 100644
index 0000000000..41c9d3c828
--- /dev/null
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -0,0 +1,332 @@
+# What this tests?
+## Unit Tests for the max parallel request limiter for the proxy
+
+import sys, os, asyncio, time, random
+from datetime import datetime
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+from litellm import Router
+from litellm.proxy.utils import ProxyLogging
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+
+## On Request received
+## On Request success
+## On Request failure
+
+
+@pytest.mark.asyncio
+async def test_pre_call_hook():
+    """
+    Test if cache updated on call being received
+    """
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler()
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    print(
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+    )
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 1
+    )
+
+
+@pytest.mark.asyncio
+async def test_success_call_hook():
+    """
+    Test if on success, cache correctly decremented
+    """
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler()
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 1
+    )
+
+    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs, response_obj="", start_time="", end_time=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 0
+    )
+
+
+@pytest.mark.asyncio
+async def test_failure_call_hook():
+    """
+    Test if on failure, cache correctly decremented
+    """
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler()
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 1
+    )
+
+    kwargs = {
+        "litellm_params": {"metadata": {"user_api_key": _api_key}},
+        "exception": Exception(),
+    }
+
+    await parallel_request_handler.async_log_failure_event(
+        kwargs=kwargs, response_obj="", start_time="", end_time=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 0
+    )
+
+
+"""
+Test with Router
+- normal call
+- streaming call
+- bad call
+"""
+
+
+@pytest.mark.asyncio
+async def test_normal_router_call():
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 1
+    )
+
+    # normal call
+    response = await router.acompletion(
+        model="azure-model",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        metadata={"user_api_key": _api_key},
+    )
+    await asyncio.sleep(1)  # success is done in a separate thread
+    print(f"response: {response}")
+    value = parallel_request_handler.user_api_key_cache.get_cache(
+        key=f"{_api_key}_request_count"
+    )
+    print(f"cache value: {value}")
+
+    assert value == 0
+
+
+@pytest.mark.asyncio
+async def test_streaming_router_call():
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 1
+    )
+
+    # streaming call
+    response = await router.acompletion(
+        model="azure-model",
+        messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        stream=True,
+        metadata={"user_api_key": _api_key},
+    )
+    async for chunk in response:
+        continue
+    await asyncio.sleep(1)  # success is done in a separate thread
+    value = parallel_request_handler.user_api_key_cache.get_cache(
+        key=f"{_api_key}_request_count"
+    )
+    print(f"cache value: {value}")
+
+    assert value == 0
+
+
+@pytest.mark.asyncio
+async def test_bad_router_call():
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    _api_key = "sk-12345"
+    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
+    local_cache = DualCache()
+    pl = ProxyLogging(user_api_key_cache=local_cache)
+    pl._init_litellm_callbacks()
+    print(f"litellm callbacks: {litellm.callbacks}")
+    parallel_request_handler = pl.max_parallel_request_limiter
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    assert (
+        parallel_request_handler.user_api_key_cache.get_cache(
+            key=f"{_api_key}_request_count"
+        )
+        == 1
+    )
+
+    # bad streaming call
+    try:
+        response = await router.acompletion(
+            model="azure-model",
+            messages=[{"role": "user2", "content": "Hey, how's it going?"}],
+            stream=True,
+            metadata={"user_api_key": _api_key},
+        )
+    except:
+        pass
+    value = parallel_request_handler.user_api_key_cache.get_cache(
+        key=f"{_api_key}_request_count"
+    )
+    print(f"cache value: {value}")
+
+    assert value == 0
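
Usage sketch (illustrative only, not part of the patch): the limiter keeps a per-key counter in the DualCache under f"{api_key}_request_count" — async_pre_call_hook increments it (and raises a 429 "Max parallel request limit reached" once max_parallel_requests is hit), while async_log_success_event and async_log_failure_event decrement it. A minimal standalone exercise of that lifecycle, assuming a litellm checkout on the path; the "sk-12345" key is a placeholder, as in the tests above:

    import asyncio
    from litellm.caching import DualCache
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler


    async def main():
        _api_key = "sk-12345"  # placeholder key, mirroring the unit tests
        handler = MaxParallelRequestsHandler()
        local_cache = DualCache()
        user = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)

        # pre-call hook: sets {_api_key}_request_count to 1 in the cache
        await handler.async_pre_call_hook(
            user_api_key_dict=user, cache=local_cache, data={}, call_type=""
        )

        # success event: reads user_api_key from litellm_params metadata
        # and decrements the counter back to 0
        kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}
        await handler.async_log_success_event(
            kwargs=kwargs, response_obj="", start_time="", end_time=""
        )
        print(handler.user_api_key_cache.get_cache(key=f"{_api_key}_request_count"))


    asyncio.run(main())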