fix(parallel_request_limiter.py): decrement count for failed llm calls

https://github.com/BerriAI/litellm/issues/1477
Krrish Dholakia 2024-01-18 12:42:14 -08:00
parent 37e6c6a59f
commit 1ea3833ef7
3 changed files with 350 additions and 27 deletions
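
The gist of the fix: the limiter keys an in-flight counter on `{api_key}_request_count`, increments it when a request is received, and previously only decremented it on success, so a failed LLM call left its parallel-request slot occupied forever. This commit decrements on failure as well (except when the failure is the limiter's own 429 rejection). Below is a minimal sketch of that counter pattern, assuming a plain dict in place of litellm's DualCache and with hypothetical class/function names throughout; it is illustrative only, not the handler code from this commit.

# Illustrative sketch only -- a dict stands in for DualCache and the names are
# hypothetical; the real handler is in parallel_request_limiter.py below.
import asyncio


class ToyParallelRequestLimiter:
    def __init__(self, max_parallel_requests: int):
        self.max_parallel_requests = max_parallel_requests
        self.cache = {}  # stand-in for litellm's DualCache

    async def pre_call(self, api_key: str):
        key = f"{api_key}_request_count"
        current = self.cache.get(key, 0)
        if current >= self.max_parallel_requests:
            raise RuntimeError("Max parallel request limit reached")
        self.cache[key] = current + 1  # claim a slot when the request arrives

    async def on_success(self, api_key: str):
        key = f"{api_key}_request_count"
        self.cache[key] = (self.cache.get(key) or 1) - 1  # release the slot

    async def on_failure(self, api_key: str, exception: Exception):
        # The behavior this commit adds: failed LLM calls must also release
        # their slot, unless the failure is the limiter's own 429 rejection.
        if (
            getattr(exception, "status_code", None) == 429
            and "Max parallel request limit reached" in str(exception)
        ):
            return
        key = f"{api_key}_request_count"
        self.cache[key] = (self.cache.get(key) or 1) - 1  # release the slot


async def _demo():
    limiter = ToyParallelRequestLimiter(max_parallel_requests=1)
    await limiter.pre_call("sk-12345")
    await limiter.on_failure("sk-12345", Exception("upstream error"))
    assert limiter.cache["sk-12345_request_count"] == 0


asyncio.run(_demo())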

litellm/proxy/hooks/parallel_request_limiter.py

@@ -1,9 +1,10 @@
 from typing import Optional
-import litellm
+import litellm, traceback
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger


 class MaxParallelRequestsHandler(CustomLogger):
@@ -14,8 +15,7 @@ class MaxParallelRequestsHandler(CustomLogger):
         pass

     def print_verbose(self, print_statement):
-        if litellm.set_verbose is True:
-            print(print_statement)  # noqa
+        verbose_proxy_logger.debug(print_statement)

     async def async_pre_call_hook(
         self,
@@ -52,7 +52,7 @@ class MaxParallelRequestsHandler(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
-            self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
+            self.print_verbose(f"INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
             user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
             if user_api_key is None:
                 return
@@ -61,28 +61,19 @@ class MaxParallelRequestsHandler(CustomLogger):
                 return

             request_count_api_key = f"{user_api_key}_request_count"
-            # check if it has collected an entire stream response
-            self.print_verbose(
-                f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}"
-            )
-            if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
-                # Decrease count for this token
-                current = (
-                    self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
-                )
-                new_val = current - 1
-                self.print_verbose(f"updated_value in success call: {new_val}")
-                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+            # Decrease count for this token
+            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+            new_val = current - 1
+            self.print_verbose(f"updated_value in success call: {new_val}")
+            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
         except Exception as e:
             self.print_verbose(e)  # noqa

-    async def async_log_failure_call(
-        self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception
-    ):
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
             self.print_verbose(f"Inside Max Parallel Request Failure Hook")
-            api_key = user_api_key_dict.api_key
-            if api_key is None:
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            if user_api_key is None:
                 return

             if self.user_api_key_cache is None:
@@ -90,13 +81,13 @@ class MaxParallelRequestsHandler(CustomLogger):
             ## decrement call count if call failed
             if (
-                hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)
+                hasattr(kwargs["exception"], "status_code")
+                and kwargs["exception"].status_code == 429
+                and "Max parallel request limit reached" in str(kwargs["exception"])
             ):
                 pass  # ignore failed calls due to max limit being reached
             else:
-                request_count_api_key = f"{api_key}_request_count"
+                request_count_api_key = f"{user_api_key}_request_count"
                 # Decrease count for this token
                 current = (
                     self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
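
The failure hook also changes shape: `async_log_failure_call(self, user_api_key_dict, original_exception)` becomes the standard `async_log_failure_event(self, kwargs, response_obj, start_time, end_time)` callback, so the key and the exception are now read out of `kwargs`. The following is a sketch of driving the new hook directly, mirroring the kwargs shape used by the unit tests added below (assumes litellm with this commit installed; the key value is a placeholder):

# Sketch of exercising the new failure-hook signature directly; mirrors the
# kwargs shape used by the unit tests added in this commit.
import asyncio

from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler


async def main():
    handler = MaxParallelRequestsHandler()
    await handler.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-12345", max_parallel_requests=1),
        cache=DualCache(),
        data={},
        call_type="",
    )
    # Key and exception now come from kwargs, not positional arguments.
    await handler.async_log_failure_event(
        kwargs={
            "litellm_params": {"metadata": {"user_api_key": "sk-12345"}},
            "exception": Exception("simulated upstream failure"),
        },
        response_obj="",
        start_time="",
        end_time="",
    )


asyncio.run(main())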

litellm/proxy/proxy_server.py

@@ -1102,7 +1102,7 @@ async def generate_key_helper_fn(
         }
         if prisma_client is not None:
             ## CREATE USER (If necessary)
-            verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
+            verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
             user_row = await prisma_client.insert_data(
                 data=user_data, table_name="user"
             )
@@ -1111,7 +1111,7 @@ async def generate_key_helper_fn(
             if len(user_row.models) > 0 and len(key_data["models"]) == 0:  # type: ignore
                 key_data["models"] = user_row.models
             ## CREATE KEY
-            verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
+            verbose_proxy_logger.debug(f"prisma_client: Creating Key={key_data}")
             await prisma_client.insert_data(data=key_data, table_name="key")
         elif custom_db_client is not None:
             ## CREATE USER (If necessary)

litellm/tests/test_parallel_request_limiter.py (new file)

@@ -0,0 +1,332 @@
# What this tests?
## Unit Tests for the max parallel request limiter for the proxy

import sys, os, asyncio, time, random
from datetime import datetime
import traceback
from dotenv import load_dotenv

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import Router
from litellm.proxy.utils import ProxyLogging
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler

## On Request received
## On Request success
## On Request failure

@pytest.mark.asyncio
async def test_pre_call_hook():
    """
    Test if cache updated on call being received
    """
    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler()

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    print(
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
    )
    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 1
    )

@pytest.mark.asyncio
async def test_success_call_hook():
    """
    Test if on success, cache correctly decremented
    """
    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler()

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 1
    )

    kwargs = {"litellm_params": {"metadata": {"user_api_key": _api_key}}}

    await parallel_request_handler.async_log_success_event(
        kwargs=kwargs, response_obj="", start_time="", end_time=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 0
    )

@pytest.mark.asyncio
async def test_failure_call_hook():
    """
    Test if on failure, cache correctly decremented
    """
    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    parallel_request_handler = MaxParallelRequestsHandler()

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 1
    )

    kwargs = {
        "litellm_params": {"metadata": {"user_api_key": _api_key}},
        "exception": Exception(),
    }

    await parallel_request_handler.async_log_failure_event(
        kwargs=kwargs, response_obj="", start_time="", end_time=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 0
    )

"""
Test with Router
- normal call
- streaming call
- bad call
"""
@pytest.mark.asyncio
async def test_normal_router_call():
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    pl = ProxyLogging(user_api_key_cache=local_cache)
    pl._init_litellm_callbacks()
    print(f"litellm callbacks: {litellm.callbacks}")
    parallel_request_handler = pl.max_parallel_request_limiter

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 1
    )

    # normal call
    response = await router.acompletion(
        model="azure-model",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        metadata={"user_api_key": _api_key},
    )
    await asyncio.sleep(1)  # success is done in a separate thread
    print(f"response: {response}")

    value = parallel_request_handler.user_api_key_cache.get_cache(
        key=f"{_api_key}_request_count"
    )
    print(f"cache value: {value}")

    assert value == 0

@pytest.mark.asyncio
async def test_streaming_router_call():
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    pl = ProxyLogging(user_api_key_cache=local_cache)
    pl._init_litellm_callbacks()
    print(f"litellm callbacks: {litellm.callbacks}")
    parallel_request_handler = pl.max_parallel_request_limiter

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 1
    )

    # streaming call
    response = await router.acompletion(
        model="azure-model",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
        metadata={"user_api_key": _api_key},
    )
    async for chunk in response:
        continue
    await asyncio.sleep(1)  # success is done in a separate thread

    value = parallel_request_handler.user_api_key_cache.get_cache(
        key=f"{_api_key}_request_count"
    )
    print(f"cache value: {value}")

    assert value == 0

@pytest.mark.asyncio
async def test_bad_router_call():
    model_list = [
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-turbo",
                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
                "api_base": "https://openai-france-1234.openai.azure.com",
                "rpm": 1440,
            },
            "model_info": {"id": 1},
        },
        {
            "model_name": "azure-model",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
                "rpm": 6,
            },
            "model_info": {"id": 2},
        },
    ]
    router = Router(
        model_list=model_list,
        set_verbose=False,
        num_retries=3,
    )  # type: ignore

    _api_key = "sk-12345"
    user_api_key_dict = UserAPIKeyAuth(api_key=_api_key, max_parallel_requests=1)
    local_cache = DualCache()
    pl = ProxyLogging(user_api_key_cache=local_cache)
    pl._init_litellm_callbacks()
    print(f"litellm callbacks: {litellm.callbacks}")
    parallel_request_handler = pl.max_parallel_request_limiter

    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
    )

    assert (
        parallel_request_handler.user_api_key_cache.get_cache(
            key=f"{_api_key}_request_count"
        )
        == 1
    )

    # bad streaming call
    try:
        response = await router.acompletion(
            model="azure-model",
            messages=[{"role": "user2", "content": "Hey, how's it going?"}],
            stream=True,
            metadata={"user_api_key": _api_key},
        )
    except:
        pass

    value = parallel_request_handler.user_api_key_cache.get_cache(
        key=f"{_api_key}_request_count"
    )
    print(f"cache value: {value}")

    assert value == 0