mirror of https://github.com/BerriAI/litellm.git
fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits

parent 977bfaaab9
commit 5fa2b6e5ad

9 changed files with 213 additions and 130 deletions
@@ -1,7 +1,13 @@
 from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa
 
 ### LOGGING ###
 class ProxyLogging:
     """
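The two max_parallel_request_* helpers imported here live in litellm/proxy/hooks/parallel_request_limiter.py, which is not part of this excerpt. A minimal sketch of what they plausibly do, assuming DualCache exposes synchronous get_cache/set_cache methods, a per-key counter naming scheme, and FastAPI's HTTPException for the overflow signal; note the 429 detail string matches what post_call_failure (last hunk below) checks for:

from fastapi import HTTPException
from litellm.caching import DualCache

async def max_parallel_request_allow_request(max_parallel_requests: int, api_key: str, user_api_key_cache: DualCache):
    # assumed cache-key scheme: one in-flight counter per API key
    request_count_api_key = f"{api_key}_request_count"
    current = user_api_key_cache.get_cache(key=request_count_api_key)
    if current is None:
        user_api_key_cache.set_cache(request_count_api_key, 1)
    elif int(current) < max_parallel_requests:
        user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
    else:
        # this exact string is what post_call_failure matches on
        raise HTTPException(status_code=429, detail="Max parallel request limit reached")

async def max_parallel_request_update_count(api_key: str, user_api_key_cache: DualCache):
    # decrement the in-flight counter once a request finishes
    request_count_api_key = f"{api_key}_request_count"
    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
    user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)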
@@ -17,7 +23,6 @@ class ProxyLogging:
         self._init_litellm_callbacks()
         pass
 
-
     def _init_litellm_callbacks(self):
         if len(litellm.callbacks) > 0:
             for callback in litellm.callbacks:
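The body of _init_litellm_callbacks is truncated by the hunk above. For context, it plausibly fans the proxy's callback objects out to litellm's global callback lists so that every request routed through the proxy triggers them; the exact list names below are an assumption, not shown in this diff:

def _init_litellm_callbacks(self):
    # register each proxy-level callback on litellm's global hooks,
    # skipping duplicates in case the config is reloaded
    if len(litellm.callbacks) > 0:
        for callback in litellm.callbacks:
            if callback not in litellm.input_callback:
                litellm.input_callback.append(callback)
            if callback not in litellm.success_callback:
                litellm.success_callback.append(callback)
            if callback not in litellm.failure_callback:
                litellm.failure_callback.append(callback)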
@@ -69,11 +74,11 @@
 # Function to be called whenever a retry is about to happen
 def on_backoff(details):
     # The 'tries' key in the details dictionary contains the number of completed tries
-    print(f"Backing off... this was attempt #{details['tries']}")
+    print_verbose(f"Backing off... this was attempt #{details['tries']}")
 
 class PrismaClient:
     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
-        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
+        print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         ## init logging object
         self.proxy_logging_obj = proxy_logging_obj
 
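The print-to-print_verbose swaps in this hunk (and the ones below) route diagnostics through the gate defined at the top of the file, so they only appear when verbose mode is enabled:

import litellm
litellm.set_verbose = True  # surfaces the print_verbose(...) diagnostics above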
@@ -109,20 +114,22 @@ class PrismaClient:
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
     async def get_data(self, token: str, expires: Optional[Any]=None):
         try:
-            hashed_token = self.hash_token(token=token)
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
             if expires:
                 response = await self.db.litellm_verificationtoken.find_first(
                     where={
-                        "token": hashed_token,
+                        "token": token,
                         "expires": {"gte": expires} # Check if the token is not expired
                     }
                 )
             else:
                 response = await self.db.litellm_verificationtoken.find_unique(
                     where={
-                        "token": hashed_token
+                        "token": token
                     }
                 )
             return response
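get_data now accepts either a plain-text key or an already-hashed one: litellm virtual keys start with "sk-", and only hashes are stored in the litellm_verificationtoken table. The same hash-or-plaintext normalization appears again in update_data in the next hunk. PrismaClient.hash_token itself is not shown in this diff; a sketch of the pattern, assuming a SHA-256 hex digest (hashlib is already imported at the top of the file):

import hashlib

def hash_token(token: str) -> str:
    # assumed implementation of PrismaClient.hash_token
    return hashlib.sha256(token.encode()).hexdigest()

token = "sk-1234"            # hypothetical plain-text key
if token.startswith("sk-"):  # plain text -> hash before the DB lookup
    token = hash_token(token)
# an already-hashed token falls through unchanged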
@@ -175,25 +182,23 @@ class PrismaClient:
         Update existing data
         """
         try:
-            hashed_token = self.hash_token(token=token)
-            data["token"] = hashed_token
-            await self.db.litellm_verificationtoken.update(
+            print_verbose(f"token: {token}")
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
+
+            data["token"] = token
+            response = await self.db.litellm_verificationtoken.update(
                 where={
-                    "token": hashed_token
+                    "token": token
                 },
                 data={**data} # type: ignore
             )
-            print("\033[91m" + f"DB write succeeded" + "\033[0m")
+            print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
-            print()
-            print()
-            print()
-            print("\033[91m" + f"DB write failed: {e}" + "\033[0m")
-            print()
-            print()
-            print()
+            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
             raise e
 
 
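Both get_data and update_data sit behind a backoff retry decorator; its tail (max_time=10, on_backoff=on_backoff) is visible at the top of the get_data hunk above. A sketch of that retry pattern with the backoff library; the wait strategy, exception type, and max_tries shown here are assumptions, only max_time=10 and on_backoff come from the diff:

import backoff

def on_backoff(details):
    # 'tries' holds the number of completed attempts so far
    print(f"Backing off... this was attempt #{details['tries']}")

@backoff.on_exception(
    backoff.expo,           # assumed: exponential wait between attempts
    Exception,              # assumed: the diff does not show which exceptions are retried
    max_tries=3,            # assumed retry budget
    max_time=10,            # maximum total time to retry for (matches the diff)
    on_backoff=on_backoff,  # called before each retry
)
async def get_data(token: str):
    ...  # stand-in for the Prisma lookup that may transiently fail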
@@ -252,7 +257,7 @@ class PrismaClient:
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     try:
-        print(f"value: {value}")
+        print_verbose(f"value: {value}")
         # Split the path by dots to separate module from instance
         parts = value.split(".")
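get_instance_fn resolves a dotted path from the proxy config (e.g. "custom_callbacks.proxy_handler_instance") to a live Python object. A minimal sketch of the resolution step it performs, with the error handling and the config_file_path handling (making the config's directory importable) elided:

import importlib

def resolve_instance(value: str):
    # "custom_callbacks.proxy_handler_instance" -> module name + attribute name
    module_name, instance_name = value.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, instance_name)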
@@ -285,8 +290,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
-
-
 ### CALL HOOKS ###
 class CallHooks:
     """
@@ -297,20 +300,55 @@ class CallHooks:
     2. /embeddings
     """
 
-    def __init__(self, *args, **kwargs):
-        self.call_details = {}
+    def __init__(self, user_api_key_cache: DualCache):
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
 
     def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
         self.call_details["litellm_settings"] = litellm_settings
         self.call_details["general_settings"] = general_settings
         self.call_details["model_list"] = model_list
 
-    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
-        self.call_details["data"] = data
-        return data
+    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e
 
-    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        return response
+    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
+        try:
+            # check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## decrement call, once complete
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return response
+        except Exception as e:
+            raise e
 
-    def post_call_failure(self, *args, **kwargs):
-        pass
+    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
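Taken together, the reworked hooks give proxy_server.py a single enforcement point for per-key concurrency: pre_call admits or rejects the request, and exactly one of post_call_success / post_call_failure releases the slot afterwards. A hedged sketch of how an endpoint plausibly drives this lifecycle; the chat_completion wrapper and the router call are assumptions, not shown in this diff:

from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth

user_api_key_cache = DualCache()
call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)

async def chat_completion(data: dict, user_api_key_dict: UserAPIKeyAuth):
    try:
        # raises a 429 if this key is already at max_parallel_requests
        data = await call_hooks.pre_call(
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )
        # `router` is assumed to be a configured litellm.Router
        response = await router.acompletion(**data)
        # decrements the in-flight counter for this key
        return await call_hooks.post_call_success(
            user_api_key_dict=user_api_key_dict, response=response, call_type="completion"
        )
    except Exception as e:
        # also decrements, unless the failure is the parallel-limit 429 itself
        await call_hooks.post_call_failure(
            original_exception=e, user_api_key_dict=user_api_key_dict
        )
        raise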