fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits

Krrish Dholakia 2023-12-08 17:10:57 -08:00
parent 977bfaaab9
commit 5fa2b6e5ad
9 changed files with 213 additions and 130 deletions


@@ -102,7 +102,7 @@ from litellm.proxy._types import *
from litellm.caching import DualCache
from litellm.proxy.health_check import perform_health_check
litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.encoders import jsonable_encoder
@@ -198,7 +198,7 @@ user_custom_auth = None
use_background_health_checks = None
health_check_interval = None
health_check_results = {}
call_hooks = CallHooks()
call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
proxy_logging_obj: Optional[ProxyLogging] = None
### REDIS QUEUE ###
async_result = None
@@ -259,10 +259,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if prisma_client:
## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
cleaned_api_key = api_key
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
@@ -274,10 +274,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
return_dict = {"api_key": valid_token.token}
if valid_token.user_id:
return_dict["user_id"] = valid_token.user_id
return UserAPIKeyAuth(**return_dict)
api_key = valid_token.token
valid_token_dict = valid_token.model_dump()
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
data = await request.json()
model = data.get("model", None)
@@ -285,10 +285,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
return_dict = {"api_key": valid_token.token}
if valid_token.user_id:
return_dict["user_id"] = valid_token.user_id
return UserAPIKeyAuth(**return_dict)
api_key = valid_token.token
valid_token_dict = valid_token.model_dump()
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
except Exception as e:
@@ -588,7 +588,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
global prisma_client
if prisma_client is None:
@@ -617,11 +617,11 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
else:
raise ValueError("Unsupported duration unit")
if duration_str is None: # allow tokens that never expire
if duration is None: # allow tokens that never expire
expires = None
else:
duration = _duration_in_seconds(duration=duration_str)
expires = datetime.utcnow() + timedelta(seconds=duration)
duration_s = _duration_in_seconds(duration=duration)
expires = datetime.utcnow() + timedelta(seconds=duration_s)
aliases_json = json.dumps(aliases)
config_json = json.dumps(config)
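The _duration_in_seconds helper called above is not part of this hunk; as a rough sketch of what such a parser looks like (an assumption consistent with the "Unsupported duration unit" error shown in the surrounding context, not the verbatim litellm helper):

# Hypothetical sketch of a duration parser matching the call above;
# accepts strings such as "30s", "30m", "30h", "30d".
def _duration_in_seconds(duration: str) -> int:
    value, unit = int(duration[:-1]), duration[-1]
    multipliers = {"s": 1, "m": 60, "h": 3600, "d": 86400}
    if unit not in multipliers:
        raise ValueError("Unsupported duration unit")
    return value * multipliers[unit]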
@@ -635,7 +635,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
"aliases": aliases_json,
"config": config_json,
"spend": spend,
"user_id": user_id
"user_id": user_id,
"max_parallel_requests": max_parallel_requests
}
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
except Exception as e:
@@ -755,14 +756,12 @@ def data_generator(response):
except:
yield f"data: {json.dumps(chunk)}\n\n"
async def async_data_generator(response):
async def async_data_generator(response, user_api_key_dict):
global call_hooks
print_verbose("inside generator")
async for chunk in response:
print_verbose(f"returned chunk: {chunk}")
### CALL HOOKS ### - modify outgoing response
response = call_hooks.post_call_success(chunk=chunk, call_type="completion")
try:
yield f"data: {json.dumps(chunk.dict())}\n\n"
except:
@@ -812,36 +811,6 @@ def get_litellm_model_info(model: dict = {}):
# if litellm does not have info on the model it should return {}
return {}
@app.middleware("http")
async def rate_limit_per_token(request: Request, call_next):
global user_api_key_cache, general_settings
max_parallel_requests = general_settings.get("max_parallel_requests", None)
api_key = request.headers.get("Authorization")
if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled
api_key = _get_bearer_token(api_key=api_key)
# CHECK IF REQUEST ALLOWED
request_count_api_key = f"{api_key}_request_count"
current = user_api_key_cache.get_cache(key=request_count_api_key)
if current is None:
user_api_key_cache.set_cache(request_count_api_key, 1)
elif int(current) < max_parallel_requests:
# Increase count for this token
user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
else:
raise HTTPException(status_code=429, detail="Too many requests.")
response = await call_next(request)
# Decrease count for this token
current = user_api_key_cache.get_cache(key=request_count_api_key)
user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
return response
else: # Rate limiting is not enabled, just pass the request
response = await call_next(request)
return response
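With this commit the per-token parallel-request limit is no longer enforced by the HTTP middleware removed above; per the commit title and the new per-key max_parallel_requests field, that logic moves behind the pre/post call hooks. The sketch below only illustrates that pattern, reusing the request-count-in-cache idea from the removed middleware; the class name and hook signatures are assumptions for illustration, not litellm's actual hook implementation.

# Illustrative sketch only -- not the actual litellm CallHooks code.
# Assumes a cache with get_cache/set_cache (like litellm.caching.DualCache)
# and a user_api_key_dict carrying api_key and max_parallel_requests.
from fastapi import HTTPException

class ParallelRequestLimiterSketch:
    def __init__(self, user_api_key_cache):
        self.cache = user_api_key_cache  # shared cache, also used for key lookups

    async def pre_call(self, user_api_key_dict, data, call_type):
        limit = getattr(user_api_key_dict, "max_parallel_requests", None)
        if limit is None:
            return data  # no per-key limit configured
        key = f"{user_api_key_dict.api_key}_request_count"
        current = self.cache.get_cache(key=key) or 0
        if int(current) >= limit:
            raise HTTPException(status_code=429, detail="Too many requests.")
        self.cache.set_cache(key, int(current) + 1)  # reserve a slot
        return data

    def _release(self, user_api_key_dict):
        if getattr(user_api_key_dict, "max_parallel_requests", None) is None:
            return  # nothing was reserved in pre_call
        key = f"{user_api_key_dict.api_key}_request_count"
        current = self.cache.get_cache(key=key) or 1
        self.cache.set_cache(key, int(current) - 1)  # free the slot

    async def post_call_success(self, user_api_key_dict, response, call_type):
        self._release(user_api_key_dict)
        return response

    async def post_call_failure(self, original_exception, user_api_key_dict):
        self._release(user_api_key_dict)

Wiring something like this into CallHooks(user_api_key_cache=user_api_key_cache), as the constructor change earlier in the diff suggests, lets pre_call reject with a 429 before the model call while the post-call hooks release the slot afterwards, including on failures.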
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key, use_background_health_checks
@@ -868,7 +837,7 @@ async def startup_event():
if prisma_client is not None and master_key is not None:
# add master key to db
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)
@router.on_event("shutdown")
async def shutdown_event():
@@ -1008,7 +977,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
data["api_base"] = user_api_base
### CALL HOOKS ### - modify incoming data before calling the model
data = call_hooks.pre_call(data=data, call_type="completion")
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1021,15 +990,19 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
else: # router is not set
response = await litellm.acompletion(**data)
print(f"final response: {response}")
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
### CALL HOOKS ### - modify outgoing response
response = call_hooks.post_call_success(response=response, call_type="completion")
response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
except Exception as e:
print(f"Exception received: {str(e)}")
raise e
await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data.get("model", "") in router_model_names:
@@ -1046,23 +1019,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
print(f"{key}: {value}")
if user_debug:
traceback.print_exc()
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
try:
status = e.status_code # type: ignore
except:
status = 500
raise HTTPException(
status_code=status,
detail=error_msg
)
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global call_hooks
try:
global call_hooks
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
@@ -1105,7 +1081,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
break
### CALL HOOKS ### - modify incoming data before calling the model
data = call_hooks.pre_call(data=data, call_type="embeddings")
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
## ROUTE TO CORRECT ENDPOINT ##
if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1117,19 +1093,18 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
### CALL HOOKS ### - modify outgoing response
data = call_hooks.post_call_success(response=response, call_type="embeddings")
response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
return response
except Exception as e:
await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
raise e
except Exception as e:
pass
#### KEY MANAGEMENT ####
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
async def generate_key_fn(request: Request, data: GenerateKeyRequest):
async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)):
"""
Generate an API key based on the provided data.
@@ -1141,26 +1116,17 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
- spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
- max_parallel_requests: Optional[int] - Rate limit a key by the number of requests it may have in flight at once. The proxy returns a 429 error once the key's parallel requests exceed this limit.
Returns:
- key: The generated api key
- expires: Datetime object for when key expires.
- key: (str) The generated api key
- expires: (datetime) Datetime object for when key expires.
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
"""
# data = await request.json()
duration_str = data.duration # Default to 1 hour if duration is not provided
models = data.models # Default to an empty list (meaning allow token to call all models)
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
config = data.config
spend = data.spend
user_id = data.user_id
if isinstance(models, list):
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "models param must be a list"},
)
data_json = data.model_dump()
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
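For illustration, a key with the new max_parallel_requests cap could be requested from Python like this; the proxy address and master key below are placeholders, not values from this commit:

# Hypothetical call to the /key/generate endpoint documented above.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",           # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
    json={
        "duration": "1h",                 # note: the param is now `duration`, not `duration_str`
        "models": ["gpt-3.5-turbo"],      # an empty list means all models are allowed
        "max_parallel_requests": 5,       # new field: 429 once 5 requests run in parallel
    },
)
print(resp.json())  # expected fields: key, expires, user_id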
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest):