fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits

Krrish Dholakia 2023-12-08 17:10:57 -08:00
parent 977bfaaab9
commit 5fa2b6e5ad
9 changed files with 213 additions and 130 deletions
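
Context for the diff below: this commit makes the proxy's CallHooks async and cache-aware. pre_call now enforces a per-key max_parallel_requests limit before a request is routed, and post_call_success / post_call_failure decrement the in-flight counter afterwards. A minimal sketch of the intended call pattern from the proxy server (the call_hooks wiring and the acompletion call are illustrative assumptions, not lines from this commit):

```python
import litellm
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth

# CallHooks is the class patched in this diff; instantiation shown for illustration
call_hooks = CallHooks(user_api_key_cache=DualCache())

async def chat_completion(data: dict, user_api_key_dict: UserAPIKeyAuth):
    # pre_call: raises (429) if the key is already at its max_parallel_requests limit
    data = await call_hooks.pre_call(
        user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
    )
    try:
        response = await litellm.acompletion(**data)
    except Exception as e:
        # failure: decrement the in-flight counter, unless the failure was the limit itself
        await call_hooks.post_call_failure(
            original_exception=e, user_api_key_dict=user_api_key_dict
        )
        raise
    # success: decrement the in-flight counter and return
    return await call_hooks.post_call_success(
        user_api_key_dict=user_api_key_dict, response=response, call_type="completion"
    )
```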


@@ -1,7 +1,13 @@
 from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa
 
 ### LOGGING ###
 class ProxyLogging:
     """
@@ -17,7 +23,6 @@ class ProxyLogging:
         self._init_litellm_callbacks()
         pass
 
-
     def _init_litellm_callbacks(self):
         if len(litellm.callbacks) > 0:
             for callback in litellm.callbacks:
@@ -69,11 +74,11 @@ class ProxyLogging:
 # Function to be called whenever a retry is about to happen
 def on_backoff(details):
     # The 'tries' key in the details dictionary contains the number of completed tries
-    print(f"Backing off... this was attempt #{details['tries']}")
+    print_verbose(f"Backing off... this was attempt #{details['tries']}")
 
 
 class PrismaClient:
     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
-        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
+        print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         ## init logging object
         self.proxy_logging_obj = proxy_logging_obj
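
The on_backoff callback above plugs into the backoff library's on_exception decorator, whose tail is visible at the top of the next hunk. For context, the full decorator pattern looks roughly like this (the exception type and max_tries value are assumptions; only max_time=10 and on_backoff appear in the diff):

```python
import backoff

# Sketch of the decorator wrapping PrismaClient.get_data; only the last two
# arguments are visible in this diff, the rest are assumed for illustration.
@backoff.on_exception(
    backoff.expo,  # exponential backoff between retries
    Exception,  # assumed: the real code may retry on a narrower exception list
    max_tries=3,  # assumed maximum number of attempts
    max_time=10,  # maximum total time to retry for
    on_backoff=on_backoff,  # called before each retry; logs the attempt number
)
async def get_data(self, token: str, expires: Optional[Any] = None):
    ...
```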
@@ -109,20 +114,22 @@ class PrismaClient:
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
-    async def get_data(self, token: str, expires: Optional[Any]=None):
+    async def get_data(self, token: str, expires: Optional[Any]=None):
         try:
-            hashed_token = self.hash_token(token=token)
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
             if expires:
                 response = await self.db.litellm_verificationtoken.find_first(
                     where={
-                        "token": hashed_token,
+                        "token": token,
                         "expires": {"gte": expires} # Check if the token is not expired
                     }
                 )
             else:
                 response = await self.db.litellm_verificationtoken.find_unique(
                     where={
-                        "token": hashed_token
+                        "token": token
                     }
                 )
             return response
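
After this change, get_data (and update_data below) hash only tokens that arrive in plain text, i.e. prefixed with "sk-", so lookups work whether the caller passes a raw key or its hash. hash_token itself is not part of this diff; given the hashlib import added at the top of the file, a plain SHA-256 digest is the likely shape:

```python
import hashlib

def hash_token(self, token: str) -> str:
    # Sketch only: hash_token is not shown in this diff. A SHA-256 hex digest
    # is assumed, consistent with the hashlib import at the top of the file.
    return hashlib.sha256(token.encode()).hexdigest()
```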
@@ -175,25 +182,23 @@ class PrismaClient:
         Update existing data
         """
         try:
-            hashed_token = self.hash_token(token=token)
-            data["token"] = hashed_token
-            await self.db.litellm_verificationtoken.update(
+            print_verbose(f"token: {token}")
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
+            data["token"] = token
+            response = await self.db.litellm_verificationtoken.update(
                 where={
-                    "token": hashed_token
+                    "token": token
                 },
                 data={**data} # type: ignore
             )
-            print("\033[91m" + f"DB write succeeded" + "\033[0m")
+            print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
-            print()
-            print()
-            print()
-            print("\033[91m" + f"DB write failed: {e}" + "\033[0m")
-            print()
-            print()
-            print()
+            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
             raise e
@@ -252,7 +257,7 @@ class PrismaClient:
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     try:
-        print(f"value: {value}")
+        print_verbose(f"value: {value}")
         # Split the path by dots to separate module from instance
         parts = value.split(".")
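
get_instance_fn turns a dotted path from the proxy config into a live Python object: it splits off the module path, imports the module, and looks up the final attribute. A hypothetical usage (the dotted path and config file name here are made up for illustration):

```python
# Hypothetical: resolve "custom_callbacks.proxy_handler_instance" from a user's
# config into the actual object (import the module, then getattr the instance).
instance = get_instance_fn(
    value="custom_callbacks.proxy_handler_instance",
    config_file_path="/path/to/config.yaml",
)
```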
@@ -285,8 +290,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
-
-
 ### CALL HOOKS ###
 class CallHooks:
     """
@@ -297,20 +300,55 @@ class CallHooks:
     2. /embeddings
     """
 
-    def __init__(self, *args, **kwargs):
-        self.call_details = {}
+    def __init__(self, user_api_key_cache: DualCache):
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
 
     def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
         self.call_details["litellm_settings"] = litellm_settings
         self.call_details["general_settings"] = general_settings
         self.call_details["model_list"] = model_list
 
-    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
-        self.call_details["data"] = data
-        return data
+    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e
 
-    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        return response
+    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
+        try:
+            # check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## decrement call, once complete
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return response
+        except Exception as e:
+            raise e
 
-    def post_call_failure(self, *args, **kwargs):
-        pass
+    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
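
The two imported helpers, max_parallel_request_allow_request and max_parallel_request_update_count, live in litellm/proxy/hooks/parallel_request_limiter.py and are not shown in this diff. Inferring from how CallHooks calls them, and from the 429 / "Max parallel request limit reached" check in post_call_failure, a minimal sketch (assuming DualCache exposes get_cache/set_cache and that the counter is keyed per API key) would be:

```python
from fastapi import HTTPException
from litellm.caching import DualCache

# Sketch inferred from the call sites above; the real implementations live in
# litellm/proxy/hooks/parallel_request_limiter.py and may differ in detail.
async def max_parallel_request_allow_request(
    max_parallel_requests: int, api_key: str, user_api_key_cache: DualCache
):
    request_count_api_key = f"{api_key}_request_count"  # assumed cache-key scheme
    current = user_api_key_cache.get_cache(key=request_count_api_key) or 0
    if current >= max_parallel_requests:
        # post_call_failure above checks for exactly this status code and message
        raise HTTPException(status_code=429, detail="Max parallel request limit reached")
    user_api_key_cache.set_cache(key=request_count_api_key, value=current + 1)

async def max_parallel_request_update_count(api_key: str, user_api_key_cache: DualCache):
    # decrement the in-flight counter once a request completes or fails
    request_count_api_key = f"{api_key}_request_count"
    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
    user_api_key_cache.set_cache(key=request_count_api_key, value=current - 1)
```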