fix(proxy_server.py): enable rate limiting concurrent user requests
parent cf1df8204e
commit ad922b205b
2 changed files with 35 additions and 4 deletions
@@ -96,7 +96,7 @@ from litellm.proxy.utils import (
     get_instance_fn
 )
 import pydantic
-from litellm.proxy.types import *
+from litellm.proxy._types import *
 from litellm.caching import DualCache
 from litellm.health_check import perform_health_check
 litellm.suppress_debug_info = True
@@ -212,14 +212,15 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()
 
+def _get_bearer_token(api_key: str):
+    assert api_key.startswith("Bearer ") # ensure Bearer token passed in
+    api_key = api_key.replace("Bearer ", "") # extract the token
+    return api_key
+
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
         if isinstance(api_key, str):
-            assert api_key.startswith("Bearer ") # ensure Bearer token passed in
-            api_key = api_key.replace("Bearer ", "") # extract the token
+            api_key = _get_bearer_token(api_key=api_key)
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth:
             response = await user_custom_auth(request=request, api_key=api_key)
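As a quick sanity check, here is a minimal sketch of how the extracted helper behaves; the test name and token value are illustrative placeholders, not part of the commit:

    def test_get_bearer_token():
        # a well-formed Authorization value yields the bare token
        assert _get_bearer_token(api_key="Bearer sk-1234") == "sk-1234"
        # anything without the "Bearer " prefix trips the assert
        try:
            _get_bearer_token(api_key="sk-1234")
        except AssertionError:
            pass  # expected: the helper only accepts Bearer-prefixed values

Pulling the prefix check into one helper lets user_api_key_auth and the new middleware below parse the Authorization header identically.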
@@ -745,6 +746,36 @@ def litellm_completion(*args, **kwargs):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
 
+@app.middleware("http")
+async def rate_limit_per_token(request: Request, call_next):
+    global user_api_key_cache, general_settings
+    max_parallel_requests = general_settings.get("max_parallel_requests", None)
+    api_key = request.headers.get("Authorization")
+    if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled
+        api_key = _get_bearer_token(api_key=api_key)
+        # CHECK IF REQUEST ALLOWED
+        request_count_api_key = f"{api_key}_request_count"
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        if current is None:
+            user_api_key_cache.set_cache(request_count_api_key, 1)
+        elif int(current) < max_parallel_requests:
+            # Increase count for this token
+            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+        else:
+            raise HTTPException(status_code=429, detail="Too many requests.")
+
+        response = await call_next(request)
+
+        # Decrease count for this token
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
+
+        return response
+    else: # Rate limiting is not enabled, just pass the request
+        response = await call_next(request)
+        return response
+
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key
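The middleware only engages when general_settings carries a max_parallel_requests value and the request has an Authorization header; otherwise requests pass straight through unchanged. A minimal client-side sketch of the limit in action, assuming the proxy listens on localhost:8000 with a low max_parallel_requests and that "sk-1234" is a valid key (endpoint, port, and key are placeholder assumptions):

    import asyncio
    import httpx

    async def spam(n: int = 10):
        # fire n concurrent completions with the same token; once the in-flight
        # count for that token reaches max_parallel_requests, the middleware
        # above rejects the request with HTTPException(429) instead of forwarding it
        headers = {"Authorization": "Bearer sk-1234"}
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
        }
        async with httpx.AsyncClient() as client:
            tasks = [
                client.post(
                    "http://localhost:8000/chat/completions",
                    json=payload,
                    headers=headers,
                )
                for _ in range(n)
            ]
            responses = await asyncio.gather(*tasks)
        # expect successes plus rejections once the limit is hit
        print([r.status_code for r in responses])

    asyncio.run(spam())

One caveat worth noting: the get_cache/set_cache pair is a non-atomic read-modify-write, so two requests landing at the same instant can read the same count and both write the same incremented value, slightly undercounting in-flight requests.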