fix(proxy_server.py): enable rate limiting of concurrent user requests

Krrish Dholakia 2023-12-06 15:10:56 -08:00
parent cf1df8204e
commit ad922b205b
2 changed files with 35 additions and 4 deletions
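
The new HTTP middleware counts in-flight requests per API key in user_api_key_cache and rejects, with an HTTP 429, any request that would push a key past general_settings["max_parallel_requests"]; the Bearer-token parsing in user_api_key_auth is extracted into _get_bearer_token so the middleware can reuse it. Limiting only activates when the setting is present, e.g. (a minimal sketch; how general_settings gets populated is outside this diff):

    general_settings = {"max_parallel_requests": 10}  # allow at most 10 concurrent requests per token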

@@ -96,7 +96,7 @@ from litellm.proxy.utils import (
     get_instance_fn
 )
 import pydantic
-from litellm.proxy.types import *
+from litellm.proxy._types import *
 from litellm.caching import DualCache
 from litellm.health_check import perform_health_check
 litellm.suppress_debug_info = True
@@ -212,14 +212,15 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()
 
+def _get_bearer_token(api_key: str):
+    assert api_key.startswith("Bearer ")  # ensure Bearer token passed in
+    api_key = api_key.replace("Bearer ", "")  # extract the token
+    return api_key  # return the stripped token; without this the helper would return None
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
         if isinstance(api_key, str):
-            assert api_key.startswith("Bearer ")  # ensure Bearer token passed in
-            api_key = api_key.replace("Bearer ", "")  # extract the token
+            api_key = _get_bearer_token(api_key=api_key)
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth:
             response = await user_custom_auth(request=request, api_key=api_key)
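
For reference, the extracted helper's contract as a standalone sketch (the function body is condensed from the diff; the asserts are illustrative, not part of the commit):

    def _get_bearer_token(api_key: str):
        assert api_key.startswith("Bearer ")  # ensure Bearer token passed in
        return api_key.replace("Bearer ", "")  # extract and return the raw token

    assert _get_bearer_token(api_key="Bearer sk-1234") == "sk-1234"
    # a header without the "Bearer " prefix fails the assert and surfaces as an auth error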
@@ -745,6 +746,36 @@ def litellm_completion(*args, **kwargs):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
 
+@app.middleware("http")
+async def rate_limit_per_token(request: Request, call_next):
+    global user_api_key_cache, general_settings
+    max_parallel_requests = general_settings.get("max_parallel_requests", None)
+    api_key = request.headers.get("Authorization")
+    if max_parallel_requests is not None and api_key is not None:  # rate limiting is enabled
+        api_key = _get_bearer_token(api_key=api_key)
+        # CHECK IF REQUEST ALLOWED
+        request_count_api_key = f"{api_key}_request_count"
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        if current is None:
+            user_api_key_cache.set_cache(request_count_api_key, 1)
+        elif int(current) < max_parallel_requests:
+            # increase count for this token
+            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+        else:
+            raise HTTPException(status_code=429, detail="Too many requests.")
+
+        response = await call_next(request)
+
+        # decrease count for this token, now that the request has completed
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
+        return response
+    else:  # rate limiting is not enabled, just pass the request through
+        response = await call_next(request)
+        return response
+
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key
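
The counting pattern in rate_limit_per_token can be exercised in isolation. Below is a minimal, runnable sketch: InMemoryCache, try_acquire, and release are hypothetical stand-ins for DualCache and for the middleware's admission check and post-response decrement, written to mirror the diff:

    class InMemoryCache:
        """Stands in for litellm's DualCache in this sketch."""
        def __init__(self):
            self._store = {}
        def get_cache(self, key):
            return self._store.get(key)
        def set_cache(self, key, value):
            self._store[key] = value

    user_api_key_cache = InMemoryCache()
    max_parallel_requests = 2

    def try_acquire(api_key: str) -> bool:
        """Mirror of the middleware's admission check; False maps to HTTP 429."""
        key = f"{api_key}_request_count"
        current = user_api_key_cache.get_cache(key=key)
        if current is None:
            user_api_key_cache.set_cache(key, 1)
            return True
        if int(current) < max_parallel_requests:
            user_api_key_cache.set_cache(key, int(current) + 1)
            return True
        return False

    def release(api_key: str):
        """Mirror of the decrement after call_next returns."""
        key = f"{api_key}_request_count"
        current = user_api_key_cache.get_cache(key=key)
        user_api_key_cache.set_cache(key, int(current) - 1)

    assert try_acquire("sk-1234") is True   # 1st in-flight request admitted
    assert try_acquire("sk-1234") is True   # 2nd admitted, count == max
    assert try_acquire("sk-1234") is False  # 3rd concurrent request -> 429
    release("sk-1234")                      # one request finishes...
    assert try_acquire("sk-1234") is True   # ...freeing a slot for the next

One design note: because get_cache/set_cache are separate calls, the counter is check-then-set rather than atomic; that is unambiguous within a single event loop, but a shared or multi-worker cache backend would need an atomic increment for strict enforcement.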