fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits

Krrish Dholakia 2023-12-08 17:10:57 -08:00
parent 977bfaaab9
commit 5fa2b6e5ad
9 changed files with 213 additions and 130 deletions

View file

@@ -3,7 +3,7 @@ repos:
rev: 3.8.4 # The version of flake8 to use
hooks:
- id: flake8
exclude: ^litellm/tests/|^litellm/proxy/|^litellm/integrations/
exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
additional_dependencies: [flake8-print]
files: litellm/.*\.py
- repo: local

View file

@@ -76,12 +76,13 @@ class ModelParams(BaseModel):
protected_namespaces = ()
class GenerateKeyRequest(BaseModel):
duration: str = "1h"
models: list = []
aliases: dict = {}
config: dict = {}
spend: int = 0
duration: Optional[str] = "1h"
models: Optional[list] = []
aliases: Optional[dict] = {}
config: Optional[dict] = {}
spend: Optional[float] = 0
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
class GenerateKeyResponse(BaseModel):
key: str
@@ -96,8 +97,17 @@ class DeleteKeyRequest(BaseModel):
class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
"""
Return the row in the db
"""
api_key: Optional[str] = None
models: list = []
aliases: dict = {}
config: dict = {}
spend: Optional[float] = 0
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
duration: str = "1h"
class ConfigGeneralSettings(BaseModel):
"""

View file

@@ -0,0 +1 @@
from . import *

View file

@@ -0,0 +1,33 @@
from typing import Optional
from litellm.caching import DualCache
from fastapi import HTTPException
async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
if api_key is None:
return
if max_parallel_requests is None:
return
# CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count"
current = user_api_key_cache.get_cache(key=request_count_api_key)
if current is None:
user_api_key_cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests:
# Increase count for this token
user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
else:
raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
if api_key is None:
return
request_count_api_key = f"{api_key}_request_count"
# Decrease count for this token
current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
return
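A hedged sketch of exercising the two limiter helpers above directly, using an in-memory DualCache (the key value and limit are illustrative; DualCache() with no arguments is assumed to fall back to in-memory caching only):

import asyncio
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
    max_parallel_request_allow_request,
    max_parallel_request_update_count,
)

async def demo():
    cache = DualCache()      # no redis configured, in-memory only
    api_key = "sk-example"   # hypothetical key
    # first call increments "<api_key>_request_count" and is allowed
    await max_parallel_request_allow_request(max_parallel_requests=1, api_key=api_key, user_api_key_cache=cache)
    # a second concurrent call for the same key would raise HTTPException(429) here
    # ... perform the LLM call ...
    # decrement the in-flight counter once the call finishes
    await max_parallel_request_update_count(api_key=api_key, user_api_key_cache=cache)

asyncio.run(demo())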

View file

@@ -102,7 +102,7 @@ from litellm.proxy._types import *
from litellm.caching import DualCache
from litellm.proxy.health_check import perform_health_check
litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.encoders import jsonable_encoder
@@ -198,7 +198,7 @@ user_custom_auth = None
use_background_health_checks = None
health_check_interval = None
health_check_results = {}
call_hooks = CallHooks()
call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
proxy_logging_obj: Optional[ProxyLogging] = None
### REDIS QUEUE ###
async_result = None
@@ -259,10 +259,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if prisma_client:
## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
cleaned_api_key = api_key
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
@@ -274,10 +274,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
return_dict = {"api_key": valid_token.token}
if valid_token.user_id:
return_dict["user_id"] = valid_token.user_id
return UserAPIKeyAuth(**return_dict)
api_key = valid_token.token
valid_token_dict = valid_token.model_dump()
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
data = await request.json()
model = data.get("model", None)
@@ -285,10 +285,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
return_dict = {"api_key": valid_token.token}
if valid_token.user_id:
return_dict["user_id"] = valid_token.user_id
return UserAPIKeyAuth(**return_dict)
api_key = valid_token.token
valid_token_dict = valid_token.model_dump()
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
except Exception as e:
@@ -588,7 +588,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
global prisma_client
if prisma_client is None:
@@ -617,11 +617,11 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
else:
raise ValueError("Unsupported duration unit")
if duration_str is None: # allow tokens that never expire
if duration is None: # allow tokens that never expire
expires = None
else:
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
duration_s = _duration_in_seconds(duration=duration)
expires = datetime.utcnow() + timedelta(seconds=duration_s)
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
@@ -635,7 +635,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
"aliases": aliases_json,
"config": config_json,
"spend": spend,
"user_id": user_id
"user_id": user_id,
"max_parallel_requests": max_parallel_requests
}
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:
@@ -755,14 +756,12 @@ def data_generator(response):
except:
yield f"data: {json.dumps(chunk)}\n\n"
async def async_data_generator(response):
async def async_data_generator(response, user_api_key_dict):
global call_hooks
print_verbose("inside generator")
async for chunk in response:
print_verbose(f"returned chunk: {chunk}")
### CALL HOOKS ### - modify outgoing response
response = call_hooks.post_call_success(chunk=chunk, call_type="completion")
try:
yield f"data: {json.dumps(chunk.dict())}\n\n"
except:
@@ -812,36 +811,6 @@ def get_litellm_model_info(model: dict = {}):
# if litellm does not have info on the model it should return {}
return {}
@app.middleware("http")
async def rate_limit_per_token(request: Request, call_next):
global user_api_key_cache, general_settings
max_parallel_requests = general_settings.get("max_parallel_requests", None)
api_key = request.headers.get("Authorization")
if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled
api_key = _get_bearer_token(api_key=api_key)
# CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count"
current = user_api_key_cache.get_cache(key=request_count_api_key)
if current is None:
user_api_key_cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests:
# Increase count for this token
user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
else:
raise HTTPException(status_code=429, detail="Too many requests.")
response = await call_next(request)
# Decrease count for this token
current = user_api_key_cache.get_cache(key=request_count_api_key)
user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
return response
else: # Rate limiting is not enabled, just pass the request
response = await call_next(request)
return response
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks
@@ -868,7 +837,7 @@ async def startup_event():
if prisma_client is not None and master_key is not None:
# add master key to db
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)
@router.on_event("shutdown")
async def shutdown_event():
@@ -1008,7 +977,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
data["api_base"] = user_api_base
### CALL HOOKS ### - modify incoming data before calling the model
data = call_hooks.pre_call(data=data, call_type="completion")
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1021,15 +990,19 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
else: # router is not set
response = await litellm.acompletion(**data)
print(f"final response: {response}")
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
### CALL HOOKS ### - modify outgoing response
response = call_hooks.post_call_success(response=response, call_type="completion")
response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
except Exception as e:
print(f"Exception received: {str(e)}")
raise e
await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data.get("model", "") in router_model_names:
@@ -1046,23 +1019,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
print(f"{key}: {value}")
if user_debug:
traceback.print_exc()
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global call_hooks
try:
global call_hooks
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
@@ -1105,7 +1081,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
break
### CALL HOOKS ### - modify incoming data before calling the model
data = call_hooks.pre_call(data=data, call_type="embeddings")
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1117,19 +1093,18 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
### CALL HOOKS ### - modify outgoing response
data = call_hooks.post_call_success(response=response, call_type="embeddings")
data = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
return response
except Exception as e:
await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
raise e
except Exception as e:
pass
#### KEY MANAGEMENT ####
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
async def generate_key_fn(request: Request, data: GenerateKeyRequest):
async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)):
"""
Generate an API key based on the provided data.
@@ -1141,26 +1116,17 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
- spend: Optional[float] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
- max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises a 429 error if the user's parallel requests exceed this limit.
Returns:
- key: The generated api key
- expires: Datetime object for when key expires.
- key: (str) The generated api key
- expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
# data = await request.json()
duration_str = data.duration # Default to 1 hour if duration is not provided
models = data.models # Default to an empty list (meaning allow token to call all models)
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
config = data.config
spend = data.spend
user_id = data.user_id
if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "models param must be a list"},
)
data_json = data.model_dump()
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
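A hedged example of calling the /key/generate endpoint documented above with the new parameter, as a client might after this change (base URL and master key are placeholders; requests is used only for illustration):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",                      # placeholder proxy URL
    headers={"Authorization": "Bearer sk-proxy-master-key"},  # placeholder master key
    json={"duration": "20m", "models": [], "max_parallel_requests": 1},
)
print(resp.json())  # expected shape: {"key": "sk-...", "expires": ..., "user_id": ...}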

View file

@@ -16,4 +16,5 @@ model LiteLLM_VerificationToken {
aliases Json @default("{}")
config Json @default("{}")
user_id String?
max_parallel_requests Int?
}

View file

@@ -1,7 +1,13 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
### LOGGING ###
class ProxyLogging:
"""
@@ -17,7 +23,6 @@ class ProxyLogging:
self._init_litellm_callbacks()
pass
def _init_litellm_callbacks(self):
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
@@ -69,11 +74,11 @@
# Function to be called whenever a retry is about to happen
def on_backoff(details):
# The 'tries' key in the details dictionary contains the number of completed tries
print(f"Backing off... this was attempt #{details['tries']}")
print_verbose(f"Backing off... this was attempt #{details['tries']}")
class PrismaClient:
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
## init logging object
self.proxy_logging_obj = proxy_logging_obj
@@ -109,20 +114,22 @@ class PrismaClient:
max_time=10, # maximum total time to retry for
on_backoff=on_backoff, # specifying the function to call on backoff
)
async def get_data(self, token: str, expires: Optional[Any]=None):
async def get_data(self, token: str, expires: Optional[Any]=None):
try:
hashed_token = self.hash_token(token=token)
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": hashed_token,
"token": token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": hashed_token
"token": token
}
)
return response
@@ -175,25 +182,23 @@
Update existing data
"""
try:
hashed_token = self.hash_token(token=token)
data["token"] = hashed_token
await self.db.litellm_verificationtoken.update(
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": hashed_token
"token": token
},
data={**data} # type: ignore
)
print("\033[91m" + f"DB write succeeded" + "\033[0m")
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
return {"token": token, "data": data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print()
print()
print()
print("\033[91m" + f"DB write failed: {e}" + "\033[0m")
print()
print()
print()
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
raise e
@@ -252,7 +257,7 @@ class PrismaClient:
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
try:
print(f"value: {value}")
print_verbose(f"value: {value}")
# Split the path by dots to separate module from instance
parts = value.split(".")
@@ -285,8 +290,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
except Exception as e:
raise e
### CALL HOOKS ###
class CallHooks:
"""
@@ -297,20 +300,55 @@ class CallHooks:
2. /embeddings
"""
def __init__(self, *args, **kwargs):
self.call_details = {}
def __init__(self, user_api_key_cache: DualCache):
self.call_details: dict = {}
self.call_details["user_api_key_cache"] = user_api_key_cache
def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
self.call_details["litellm_settings"] = litellm_settings
self.call_details["general_settings"] = general_settings
self.call_details["model_list"] = model_list
def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
self.call_details["data"] = data
return data
async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
try:
self.call_details["data"] = data
self.call_details["call_type"] = call_type
def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
return response
## check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## if set, check if request allowed
await max_parallel_request_allow_request(
max_parallel_requests=user_api_key_dict.max_parallel_requests,
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return data
except Exception as e:
raise e
def post_call_failure(self, *args, **kwargs):
pass
async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
try:
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## decrement call, once complete
await max_parallel_request_update_count(
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return response
except Exception as e:
raise e
async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
and original_exception.status_code == 429
and "Max parallel request limit reached" in str(original_exception)):
pass # ignore failed calls due to max limit being reached
else:
await max_parallel_request_update_count(
api_key=user_api_key_dict.api_key,
user_api_key_cache=self.call_details["user_api_key_cache"])
return
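To tie the pieces together, a hedged sketch of the async hook flow introduced here, roughly mirroring what chat_completion now does (CallHooks is assumed to live in litellm.proxy.utils; the completion call is illustrative and needs real provider credentials):

import asyncio
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import CallHooks  # assumed module path

async def demo():
    call_hooks = CallHooks(user_api_key_cache=DualCache())
    user_api_key_dict = UserAPIKeyAuth(api_key="sk-example", max_parallel_requests=1)
    data = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
    # pre_call enforces the per-key parallel request limit before the model is called
    data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
    try:
        response = await litellm.acompletion(**data)  # illustrative model call
        # post_call_success decrements the in-flight counter on success
        return await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
    except Exception as e:
        # post_call_failure decrements it on failure, unless the failure is the 429 itself
        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
        raise

asyncio.run(demo())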

View file

@@ -3,7 +3,6 @@ general_settings:
master_key: os.environ/PROXY_MASTER_KEY
litellm_settings:
drop_params: true
set_verbose: true
success_callback: ["langfuse"]
model_list:

View file

@@ -1,4 +1,4 @@
import sys, os
import sys, os, time
import traceback
from dotenv import load_dotenv
@@ -19,7 +19,7 @@ logging.basicConfig(
level=logging.DEBUG, # Set the desired logging level
format="%(asctime)s - %(levelname)s - %(message)s",
)
from concurrent.futures import ThreadPoolExecutor
# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
@@ -62,6 +62,41 @@ def test_add_new_key(client):
assert result["key"].startswith("sk-")
print(f"Received response: {result}")
except Exception as e:
pytest.fail("LiteLLM Proxy test failed. Exception", e)
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
# # Run the test - only runs via pytest
# # Run the test - only runs via pytest
def test_add_new_key_max_parallel_limit(client):
try:
# Your test data
test_data = {"duration": "20m", "max_parallel_requests": 1}
# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
headers = {
"Authorization": f"Bearer {token}"
}
response = client.post("/key/generate", json=test_data, headers=headers)
print(f"response: {response.text}")
assert response.status_code == 200
result = response.json()
def _post_data():
json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
return response
def _run_in_parallel():
with ThreadPoolExecutor(max_workers=2) as executor:
future1 = executor.submit(_post_data)
future2 = executor.submit(_post_data)
# Obtain the results from the futures
response1 = future1.result()
response2 = future2.result()
if response1.status_code == 429 or response2.status_code == 429:
pass
else:
raise Exception()
_run_in_parallel()
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")