mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
test(test_parallel_request_limiter.py): unit testing for tpm/rpm rate limits
This commit is contained in:
parent
3957a8303a
commit
34c3b33b37
2 changed files with 17 additions and 5 deletions
|
@@ -1,5 +1,5 @@
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
import enum
|
import enum, sys
|
||||||
from typing import Optional, List, Union, Dict, Literal
|
from typing import Optional, List, Union, Dict, Literal
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import uuid, json
|
import uuid, json
|
||||||
|
@@ -161,6 +161,8 @@ class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
duration: str = "1h"
|
duration: str = "1h"
|
||||||
metadata: dict = {}
|
metadata: dict = {}
|
||||||
|
tpm_limit: int = sys.maxsize
|
||||||
|
rpm_limit: int = sys.maxsize
|
||||||
|
|
||||||
|
|
||||||
class GenerateKeyResponse(LiteLLMBase):
|
class GenerateKeyResponse(LiteLLMBase):
|
||||||
|
|
|
@@ -1,5 +1,5 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import litellm, traceback
|
import litellm, traceback, sys
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@@ -28,12 +28,18 @@ class MaxParallelRequestsHandler(CustomLogger):
|
||||||
):
|
):
|
||||||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||||
api_key = user_api_key_dict.api_key
|
api_key = user_api_key_dict.api_key
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests
|
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
||||||
|
tpm_limit = user_api_key_dict.tpm_limit
|
||||||
|
rpm_limit = user_api_key_dict.rpm_limit
|
||||||
|
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if max_parallel_requests is None:
|
if (
|
||||||
|
max_parallel_requests == sys.maxsize
|
||||||
|
and tpm_limit == sys.maxsize
|
||||||
|
and rpm_limit == sys.maxsize
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
self.user_api_key_cache = cache # save the api key cache for updating the value
|
self.user_api_key_cache = cache # save the api key cache for updating the value
|
||||||
|
@@ -60,7 +66,11 @@ class MaxParallelRequestsHandler(CustomLogger):
|
||||||
"current_rpm": 0,
|
"current_rpm": 0,
|
||||||
}
|
}
|
||||||
cache.set_cache(request_count_api_key, new_val)
|
cache.set_cache(request_count_api_key, new_val)
|
||||||
elif int(current["current_requests"]) < max_parallel_requests:
|
elif (
|
||||||
|
int(current["current_requests"]) < max_parallel_requests
|
||||||
|
and current["current_tpm"] < tpm_limit
|
||||||
|
and current["current_rpm"] < rpm_limit
|
||||||
|
):
|
||||||
# Increase count for this token
|
# Increase count for this token
|
||||||
new_val = {
|
new_val = {
|
||||||
"current_requests": current["current_requests"] + 1,
|
"current_requests": current["current_requests"] + 1,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue