diff --git a/litellm/proxy/types.py b/litellm/proxy/_types.py
similarity index 100%
rename from litellm/proxy/types.py
rename to litellm/proxy/_types.py
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 54131ee27..51ae34cf9 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -96,7 +96,7 @@ from litellm.proxy.utils import (
     get_instance_fn
 )
 import pydantic
-from litellm.proxy.types import *
+from litellm.proxy._types import *
 from litellm.caching import DualCache
 from litellm.health_check import perform_health_check
 litellm.suppress_debug_info = True
@@ -212,14 +212,15 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()
 
-
+def _get_bearer_token(api_key: str):
+    assert api_key.startswith("Bearer ")  # ensure Bearer token passed in
+    api_key = api_key.replace("Bearer ", "")  # extract the token
+    return api_key
 async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)) -> UserAPIKeyAuth:
     global master_key, prisma_client, llm_model_list, user_custom_auth
     try:
         if isinstance(api_key, str):
-            assert api_key.startswith("Bearer ")  # ensure Bearer token passed in
-            api_key = api_key.replace("Bearer ", "")  # extract the token
+            api_key = _get_bearer_token(api_key=api_key)
         ### USER-DEFINED AUTH FUNCTION ###
         if user_custom_auth:
             response = await user_custom_auth(request=request, api_key=api_key)
 
@@ -745,6 +746,36 @@ def litellm_completion(*args, **kwargs):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
 
+@app.middleware("http")
+async def rate_limit_per_token(request: Request, call_next):
+    global user_api_key_cache, general_settings
+    max_parallel_requests = general_settings.get("max_parallel_requests", None)
+    api_key = request.headers.get("Authorization")
+    if max_parallel_requests is not None and api_key is not None:  # rate limiting is enabled
+        api_key = _get_bearer_token(api_key=api_key)
+        # CHECK IF REQUEST ALLOWED
+        request_count_api_key = f"{api_key}_request_count"
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        if current is None:
+            user_api_key_cache.set_cache(request_count_api_key, 1)
+        elif int(current) < max_parallel_requests:
+            # Increase count for this token
+            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+        else:
+            raise HTTPException(status_code=429, detail="Too many requests.")
+
+        response = await call_next(request)
+
+        # Decrease count for this token
+        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
+
+        return response
+    else:  # rate limiting is not enabled, just pass the request through
+        response = await call_next(request)
+        return response
+
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key
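For reference, a minimal sketch of the counter logic the new middleware applies per token. It is driven by the `max_parallel_requests` key read from `general_settings` in the diff above. Everything here is hypothetical stand-in code for illustration: `FakeCache` substitutes a plain dict for the proxy's DualCache, and `try_acquire`/`release` are invented names mirroring the middleware's pre-request check and post-response decrement.

# Hypothetical stand-ins mirroring the middleware's cache accounting;
# FakeCache models only the get_cache/set_cache calls the middleware uses.
class FakeCache:
    def __init__(self):
        self._store = {}

    def get_cache(self, key):
        return self._store.get(key)

    def set_cache(self, key, value):
        self._store[key] = value


def try_acquire(cache, api_key, max_parallel_requests):
    # Mirrors the pre-request check: True means the request may proceed,
    # False corresponds to the middleware raising HTTPException(429).
    request_count_api_key = f"{api_key}_request_count"
    current = cache.get_cache(key=request_count_api_key)
    if current is None:
        cache.set_cache(request_count_api_key, 1)
        return True
    if int(current) < max_parallel_requests:
        cache.set_cache(request_count_api_key, int(current) + 1)
        return True
    return False  # over the limit -> 429


def release(cache, api_key):
    # Mirrors the post-response decrement after call_next returns.
    request_count_api_key = f"{api_key}_request_count"
    current = cache.get_cache(key=request_count_api_key)
    cache.set_cache(request_count_api_key, int(current) - 1)


cache = FakeCache()
assert try_acquire(cache, "sk-test", max_parallel_requests=2)      # 1st in flight
assert try_acquire(cache, "sk-test", max_parallel_requests=2)      # 2nd in flight
assert not try_acquire(cache, "sk-test", max_parallel_requests=2)  # 3rd rejected
release(cache, "sk-test")                                          # one finishes
assert try_acquire(cache, "sk-test", max_parallel_requests=2)      # allowed again

One design note: the get/set pair is not an atomic increment, but because the middleware does not await between reading and writing the count, the two operations run within a single event-loop turn, so concurrent coroutines in one proxy process should not interleave between them.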