test(test_parallel_request_limiter.py): unit testing for tpm/rpm rate limits

This commit is contained in:
Krrish Dholakia 2024-01-18 15:28:28 -08:00
parent 3957a8303a
commit 34c3b33b37
2 changed files with 17 additions and 5 deletions

View file

@@ -1,5 +1,5 @@
 from pydantic import BaseModel, Extra, Field, root_validator
-import enum
+import enum, sys
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
@@ -161,6 +161,8 @@ class UserAPIKeyAuth(LiteLLMBase):  # the expected response object for user api key auth
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
     metadata: dict = {}
+    tpm_limit: int = sys.maxsize
+    rpm_limit: int = sys.maxsize

 class GenerateKeyResponse(LiteLLMBase):

View file

@@ -1,5 +1,5 @@
 from typing import Optional
-import litellm, traceback
+import litellm, traceback, sys
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
@@ -28,12 +28,18 @@ class MaxParallelRequestsHandler(CustomLogger):
     ):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
-        max_parallel_requests = user_api_key_dict.max_parallel_requests
+        max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
+        tpm_limit = user_api_key_dict.tpm_limit
+        rpm_limit = user_api_key_dict.rpm_limit

         if api_key is None:
             return

-        if max_parallel_requests is None:
+        if (
+            max_parallel_requests == sys.maxsize
+            and tpm_limit == sys.maxsize
+            and rpm_limit == sys.maxsize
+        ):
             return

         self.user_api_key_cache = cache  # save the api key cache for updating the value
@@ -60,7 +66,11 @@ class MaxParallelRequestsHandler(CustomLogger):
                 "current_rpm": 0,
             }
             cache.set_cache(request_count_api_key, new_val)
-        elif int(current["current_requests"]) < max_parallel_requests:
+        elif (
+            int(current["current_requests"]) < max_parallel_requests
+            and current["current_tpm"] < tpm_limit
+            and current["current_rpm"] < rpm_limit
+        ):
             # Increase count for this token
             new_val = {
                 "current_requests": current["current_requests"] + 1,