forked from phoenix/litellm-mirror
fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits
This commit is contained in:
parent
977bfaaab9
commit
5fa2b6e5ad
9 changed files with 213 additions and 130 deletions
|
@ -102,7 +102,7 @@ from litellm.proxy._types import *
|
|||
from litellm.caching import DualCache
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
litellm.suppress_debug_info = True
|
||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
@ -198,7 +198,7 @@ user_custom_auth = None
|
|||
use_background_health_checks = None
|
||||
health_check_interval = None
|
||||
health_check_results = {}
|
||||
call_hooks = CallHooks()
|
||||
call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
|
||||
proxy_logging_obj: Optional[ProxyLogging] = None
|
||||
### REDIS QUEUE ###
|
||||
async_result = None
|
||||
|
@ -259,10 +259,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
if prisma_client:
|
||||
## check for cache hit (In-Memory Cache)
|
||||
valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||
print(f"valid_token from cache: {valid_token}")
|
||||
if valid_token is None:
|
||||
## check db
|
||||
cleaned_api_key = api_key
|
||||
valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
|
||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
print(f"API Key Cache Hit!")
|
||||
|
@ -274,10 +274,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
llm_model_list = model_list
|
||||
print("\n new llm router model list", llm_model_list)
|
||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return UserAPIKeyAuth(**return_dict)
|
||||
api_key = valid_token.token
|
||||
valid_token_dict = valid_token.model_dump()
|
||||
valid_token_dict.pop("token", None)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
data = await request.json()
|
||||
model = data.get("model", None)
|
||||
|
@ -285,10 +285,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
model = litellm.model_alias_map[model]
|
||||
if model and model not in valid_token.models:
|
||||
raise Exception(f"Token not allowed to access model")
|
||||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return UserAPIKeyAuth(**return_dict)
|
||||
api_key = valid_token.token
|
||||
valid_token_dict = valid_token.model_dump()
|
||||
valid_token.pop("token", None)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token)
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
except Exception as e:
|
||||
|
@ -588,7 +588,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
|
||||
return router, model_list, general_settings
|
||||
|
||||
async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
|
||||
async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
|
||||
global prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
|
@ -617,11 +617,11 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
|
|||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
if duration_str is None: # allow tokens that never expire
|
||||
if duration is None: # allow tokens that never expire
|
||||
expires = None
|
||||
else:
|
||||
duration = _duration_in_seconds(duration=duration_str)
|
||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
||||
duration_s = _duration_in_seconds(duration=duration)
|
||||
expires = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||
|
||||
aliases_json = json.dumps(aliases)
|
||||
config_json = json.dumps(config)
|
||||
|
@ -635,7 +635,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
|
|||
"aliases": aliases_json,
|
||||
"config": config_json,
|
||||
"spend": spend,
|
||||
"user_id": user_id
|
||||
"user_id": user_id,
|
||||
"max_parallel_requests": max_parallel_requests
|
||||
}
|
||||
new_verification_token = await prisma_client.insert_data(data=verification_token_data)
|
||||
except Exception as e:
|
||||
|
@ -755,14 +756,12 @@ def data_generator(response):
|
|||
except:
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
async def async_data_generator(response):
|
||||
async def async_data_generator(response, user_api_key_dict):
|
||||
global call_hooks
|
||||
|
||||
print_verbose("inside generator")
|
||||
async for chunk in response:
|
||||
print_verbose(f"returned chunk: {chunk}")
|
||||
### CALL HOOKS ### - modify outgoing response
|
||||
response = call_hooks.post_call_success(chunk=chunk, call_type="completion")
|
||||
try:
|
||||
yield f"data: {json.dumps(chunk.dict())}\n\n"
|
||||
except:
|
||||
|
@ -812,36 +811,6 @@ def get_litellm_model_info(model: dict = {}):
|
|||
# if litellm does not have info on the model it should return {}
|
||||
return {}
|
||||
|
||||
@app.middleware("http")
|
||||
async def rate_limit_per_token(request: Request, call_next):
|
||||
global user_api_key_cache, general_settings
|
||||
max_parallel_requests = general_settings.get("max_parallel_requests", None)
|
||||
api_key = request.headers.get("Authorization")
|
||||
if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled
|
||||
api_key = _get_bearer_token(api_key=api_key)
|
||||
# CHECK IF REQUEST ALLOWED
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
current = user_api_key_cache.get_cache(key=request_count_api_key)
|
||||
if current is None:
|
||||
user_api_key_cache.set_cache(request_count_api_key, 1)
|
||||
elif int(current) < max_parallel_requests:
|
||||
# Increase count for this token
|
||||
user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
|
||||
else:
|
||||
raise HTTPException(status_code=429, detail="Too many requests.")
|
||||
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# Decrease count for this token
|
||||
current = user_api_key_cache.get_cache(key=request_count_api_key)
|
||||
user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
|
||||
|
||||
return response
|
||||
else: # Rate limiting is not enabled, just pass the request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
global prisma_client, master_key, use_background_health_checks
|
||||
|
@ -868,7 +837,7 @@ async def startup_event():
|
|||
|
||||
if prisma_client is not None and master_key is not None:
|
||||
# add master key to db
|
||||
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
|
||||
await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)
|
||||
|
||||
@router.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
|
@ -1008,7 +977,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
data["api_base"] = user_api_base
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = call_hooks.pre_call(data=data, call_type="completion")
|
||||
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
|
||||
|
||||
### ROUTE THE REQUEST ###
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
|
@ -1021,15 +990,19 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
else: # router is not set
|
||||
response = await litellm.acompletion(**data)
|
||||
|
||||
print(f"final response: {response}")
|
||||
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||
return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
|
||||
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
|
||||
|
||||
### CALL HOOKS ### - modify outgoing response
|
||||
response = call_hooks.post_call_success(response=response, call_type="completion")
|
||||
response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
|
||||
|
||||
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
|
||||
return response
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
print(f"Exception received: {str(e)}")
|
||||
raise e
|
||||
await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
|
||||
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
if llm_router is not None and data.get("model", "") in router_model_names:
|
||||
|
@ -1046,23 +1019,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
print(f"{key}: {value}")
|
||||
if user_debug:
|
||||
traceback.print_exc()
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
try:
|
||||
status = e.status_code # type: ignore
|
||||
except:
|
||||
status = 500
|
||||
raise HTTPException(
|
||||
status_code=status,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
try:
|
||||
status = e.status_code # type: ignore
|
||||
except:
|
||||
status = 500
|
||||
raise HTTPException(
|
||||
status_code=status,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
global call_hooks
|
||||
try:
|
||||
global call_hooks
|
||||
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
@ -1105,7 +1081,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
break
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = call_hooks.pre_call(data=data, call_type="embeddings")
|
||||
data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
if llm_router is not None and data["model"] in router_model_names: # model in router model list
|
||||
|
@ -1117,19 +1093,18 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
|
||||
|
||||
### CALL HOOKS ### - modify outgoing response
|
||||
data = call_hooks.post_call_success(response=response, call_type="embeddings")
|
||||
data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
#### KEY MANAGEMENT ####
|
||||
|
||||
@router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
|
||||
async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
||||
async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
Generate an API key based on the provided data.
|
||||
|
||||
|
@ -1141,26 +1116,17 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
|||
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
|
||||
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
|
||||
- spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
|
||||
- max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
|
||||
|
||||
Returns:
|
||||
- key: The generated api key
|
||||
- expires: Datetime object for when key expires.
|
||||
- key: (str) The generated api key
|
||||
- expires: (datetime) Datetime object for when key expires.
|
||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||
"""
|
||||
# data = await request.json()
|
||||
duration_str = data.duration # Default to 1 hour if duration is not provided
|
||||
models = data.models # Default to an empty list (meaning allow token to call all models)
|
||||
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||
config = data.config
|
||||
spend = data.spend
|
||||
user_id = data.user_id
|
||||
if isinstance(models, list):
|
||||
response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
|
||||
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": "models param must be a list"},
|
||||
)
|
||||
data_json = data.model_dump()
|
||||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
|
||||
|
||||
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue