forked from phoenix/litellm-mirror
fix(proxy_server.py): enable rate limiting of concurrent user requests
This commit is contained in:
parent
cf1df8204e
commit
ad922b205b
2 changed files with 35 additions and 4 deletions
|
@ -96,7 +96,7 @@ from litellm.proxy.utils import (
|
||||||
get_instance_fn
|
get_instance_fn
|
||||||
)
|
)
|
||||||
import pydantic
|
import pydantic
|
||||||
from litellm.proxy.types import *
|
from litellm.proxy._types import *
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.health_check import perform_health_check
|
from litellm.health_check import perform_health_check
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
|
@ -212,14 +212,15 @@ def usage_telemetry(
|
||||||
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
|
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
def _get_bearer_token(api_key: str):
|
||||||
|
assert api_key.startswith("Bearer ") # ensure Bearer token passed in
|
||||||
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||||
|
|
||||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
|
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
|
||||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||||
try:
|
try:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
assert api_key.startswith("Bearer ") # ensure Bearer token passed in
|
api_key = _get_bearer_token(api_key=api_key)
|
||||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
|
||||||
### USER-DEFINED AUTH FUNCTION ###
|
### USER-DEFINED AUTH FUNCTION ###
|
||||||
if user_custom_auth:
|
if user_custom_auth:
|
||||||
response = await user_custom_auth(request=request, api_key=api_key)
|
response = await user_custom_auth(request=request, api_key=api_key)
|
||||||
|
@ -745,6 +746,36 @@ def litellm_completion(*args, **kwargs):
|
||||||
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@app.middleware("http")
async def rate_limit_per_token(request: Request, call_next):
    """Cap the number of concurrent in-flight requests per API key.

    Reads ``general_settings["max_parallel_requests"]``; when it is set and
    the request carries an Authorization header, an in-process counter keyed
    on ``"<token>_request_count"`` in ``user_api_key_cache`` is incremented
    before the handler runs and decremented afterwards. A request that would
    exceed the cap is rejected with HTTP 429.

    NOTE(review): the get-then-set pair is not atomic, so two simultaneous
    requests can race past the cap; acceptable for a best-effort in-process
    limiter — confirm if a hard guarantee is required.
    """
    global user_api_key_cache, general_settings

    max_parallel_requests = general_settings.get("max_parallel_requests", None)
    api_key = request.headers.get("Authorization")

    if max_parallel_requests is None or api_key is None:
        # Rate limiting is not enabled, just pass the request through.
        return await call_next(request)

    api_key = _get_bearer_token(api_key=api_key)

    # CHECK IF REQUEST ALLOWED
    request_count_api_key = f"{api_key}_request_count"
    current = user_api_key_cache.get_cache(key=request_count_api_key)
    if current is None:
        user_api_key_cache.set_cache(request_count_api_key, 1)
    elif int(current) < max_parallel_requests:
        # Increase count for this token
        user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
    else:
        raise HTTPException(status_code=429, detail="Too many requests.")

    try:
        return await call_next(request)
    finally:
        # Decrease count for this token even when the handler raised —
        # without the finally, a failing request would leak its slot and
        # eventually lock the key out entirely.
        current = user_api_key_cache.get_cache(key=request_count_api_key)
        if current is not None:  # entry may have expired / been evicted
            user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
|
||||||
|
|
||||||
@router.on_event("startup")
|
@router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global prisma_client, master_key
|
global prisma_client, master_key
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue