fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits

Krrish Dholakia 2023-12-08 17:10:57 -08:00
parent 977bfaaab9
commit 5fa2b6e5ad
9 changed files with 213 additions and 130 deletions

File: .pre-commit-config.yaml

@@ -3,7 +3,7 @@ repos:
     rev: 3.8.4  # The version of flake8 to use
     hooks:
     -   id: flake8
-        exclude: ^litellm/tests/|^litellm/proxy/|^litellm/integrations/
+        exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
         additional_dependencies: [flake8-print]
         files: litellm/.*\.py
 -   repo: local

File: litellm/proxy/_types.py

@@ -76,12 +76,13 @@ class ModelParams(BaseModel):
         protected_namespaces = ()

 class GenerateKeyRequest(BaseModel):
-    duration: str = "1h"
-    models: list = []
-    aliases: dict = {}
-    config: dict = {}
-    spend: int = 0
+    duration: Optional[str] = "1h"
+    models: Optional[list] = []
+    aliases: Optional[dict] = {}
+    config: Optional[dict] = {}
+    spend: Optional[float] = 0
     user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None

 class GenerateKeyResponse(BaseModel):
     key: str
@@ -96,8 +97,17 @@ class DeleteKeyRequest(BaseModel):

 class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+    """
+    Return the row in the db
+    """
     api_key: Optional[str] = None
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: Optional[float] = 0
     user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
+    duration: str = "1h"

 class ConfigGeneralSettings(BaseModel):
     """

File: litellm/proxy/hooks/__init__.py (new file)

@@ -0,0 +1 @@
+from . import *

File: litellm/proxy/hooks/parallel_request_limiter.py (new file)

@@ -0,0 +1,33 @@
+from typing import Optional
+from litellm.caching import DualCache
+from fastapi import HTTPException
+
+async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
+    if api_key is None:
+        return
+
+    if max_parallel_requests is None:
+        return
+
+    # CHECK IF REQUEST ALLOWED
+    request_count_api_key = f"{api_key}_request_count"
+    current = user_api_key_cache.get_cache(key=request_count_api_key)
+    if current is None:
+        user_api_key_cache.set_cache(request_count_api_key, 1)
+    elif int(current) < max_parallel_requests:
+        # Increase count for this token
+        user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+    else:
+        raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
+
+async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
+    if api_key is None:
+        return
+
+    request_count_api_key = f"{api_key}_request_count"
+    # Decrease count for this token
+    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
+    user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
+
+    return
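These two helpers are meant to bracket a request: the first bumps a per-key counter in the `DualCache` and raises a 429 once the ceiling is hit, the second decrements the counter when the call finishes. A small standalone sketch of that sequence (the key value is a placeholder):

```python
import asyncio
from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
    max_parallel_request_allow_request,
    max_parallel_request_update_count,
)

async def demo():
    cache = DualCache()
    api_key = "sk-hypothetical-key"  # placeholder, not a real token

    # First request: counter goes from unset to 1, request is allowed.
    await max_parallel_request_allow_request(max_parallel_requests=1, api_key=api_key, user_api_key_cache=cache)

    # Second concurrent request: counter already at the limit, so this raises 429.
    try:
        await max_parallel_request_allow_request(max_parallel_requests=1, api_key=api_key, user_api_key_cache=cache)
    except HTTPException as e:
        print(e.status_code, e.detail)  # 429 Max parallel request limit reached.

    # When the first request finishes (success or failure), release its slot.
    await max_parallel_request_update_count(api_key=api_key, user_api_key_cache=cache)

asyncio.run(demo())
```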

File: litellm/proxy/proxy_server.py

@@ -102,7 +102,7 @@ from litellm.proxy._types import *
 from litellm.caching import DualCache
 from litellm.proxy.health_check import perform_health_check
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
+from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
 from fastapi.encoders import jsonable_encoder
@@ -198,7 +198,7 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
-call_hooks = CallHooks()
+call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
 proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
@@ -259,10 +259,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if prisma_client:
             ## check for cache hit (In-Memory Cache)
             valid_token = user_api_key_cache.get_cache(key=api_key)
+            print(f"valid_token from cache: {valid_token}")
             if valid_token is None:
                 ## check db
-                cleaned_api_key = api_key
-                valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
+                valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
                 user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
             elif valid_token is not None:
                 print(f"API Key Cache Hit!")
@@ -274,10 +274,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                     llm_model_list = model_list
                    print("\n new llm router model list", llm_model_list)
                 if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                    return_dict = {"api_key": valid_token.token}
-                    if valid_token.user_id:
-                        return_dict["user_id"] = valid_token.user_id
-                    return UserAPIKeyAuth(**return_dict)
+                    api_key = valid_token.token
+                    valid_token_dict = valid_token.model_dump()
+                    valid_token_dict.pop("token", None)
+                    return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
                 else:
                     data = await request.json()
                     model = data.get("model", None)
@@ -285,10 +285,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                         model = litellm.model_alias_map[model]
                     if model and model not in valid_token.models:
                         raise Exception(f"Token not allowed to access model")
-                    return_dict = {"api_key": valid_token.token}
-                    if valid_token.user_id:
-                        return_dict["user_id"] = valid_token.user_id
-                    return UserAPIKeyAuth(**return_dict)
+                    api_key = valid_token.token
+                    valid_token_dict = valid_token.model_dump()
+                    valid_token.pop("token", None)
+                    return UserAPIKeyAuth(api_key=api_key, **valid_token)
             else:
                 raise Exception(f"Invalid token")
     except Exception as e:
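The `user_api_key_auth` change above stops hand-copying individual fields and instead dumps the whole verification-token row into `UserAPIKeyAuth`, which is what lets `max_parallel_requests` reach the call hooks. A rough illustration, using a stand-in dict in place of the Prisma row returned by `prisma_client.get_data`:

```python
from litellm.proxy._types import UserAPIKeyAuth

# Stand-in for the verification-token row; the real object is a Prisma model,
# this dict is only illustrative.
row = {
    "token": "<hashed-token>",
    "models": [],
    "spend": 0.0,
    "user_id": "user-1234",
    "max_parallel_requests": 1,
}
api_key = row.pop("token", None)  # keep the token separately, pass the rest through
auth = UserAPIKeyAuth(api_key=api_key, **row)
print(auth.max_parallel_requests)  # 1 -- now visible to the call hooks downstream
```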
@@ -588,7 +588,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
     return router, model_list, general_settings

-async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
+async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
     global prisma_client

     if prisma_client is None:
@@ -617,11 +617,11 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
         else:
             raise ValueError("Unsupported duration unit")

-    if duration_str is None: # allow tokens that never expire
+    if duration is None: # allow tokens that never expire
         expires = None
     else:
-        duration = _duration_in_seconds(duration=duration_str)
-        expires = datetime.utcnow() + timedelta(seconds=duration)
+        duration_s = _duration_in_seconds(duration=duration)
+        expires = datetime.utcnow() + timedelta(seconds=duration_s)

     aliases_json = json.dumps(aliases)
     config_json = json.dumps(config)
@@ -635,7 +635,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
             "aliases": aliases_json,
             "config": config_json,
             "spend": spend,
-            "user_id": user_id
+            "user_id": user_id,
+            "max_parallel_requests": max_parallel_requests
         }
         new_verification_token = await prisma_client.insert_data(data=verification_token_data)
     except Exception as e:
@@ -755,14 +756,12 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"

-async def async_data_generator(response):
+async def async_data_generator(response, user_api_key_dict):
     global call_hooks
     print_verbose("inside generator")
     async for chunk in response:
         print_verbose(f"returned chunk: {chunk}")
-        ### CALL HOOKS ### - modify outgoing response
-        response = call_hooks.post_call_success(chunk=chunk, call_type="completion")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
         except:
@@ -812,36 +811,6 @@ def get_litellm_model_info(model: dict = {}):
     # if litellm does not have info on the model it should return {}
     return {}

-@app.middleware("http")
-async def rate_limit_per_token(request: Request, call_next):
-    global user_api_key_cache, general_settings
-    max_parallel_requests = general_settings.get("max_parallel_requests", None)
-    api_key = request.headers.get("Authorization")
-    if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled
-        api_key = _get_bearer_token(api_key=api_key)
-        # CHECK IF REQUEST ALLOWED
-        request_count_api_key = f"{api_key}_request_count"
-        current = user_api_key_cache.get_cache(key=request_count_api_key)
-        if current is None:
-            user_api_key_cache.set_cache(request_count_api_key, 1)
-        elif int(current) < max_parallel_requests:
-            # Increase count for this token
-            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
-        else:
-            raise HTTPException(status_code=429, detail="Too many requests.")
-
-        response = await call_next(request)
-
-        # Decrease count for this token
-        current = user_api_key_cache.get_cache(key=request_count_api_key)
-        user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
-        return response
-    else: # Rate limiting is not enabled, just pass the request
-        response = await call_next(request)
-        return response
-
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key, use_background_health_checks
@@ -868,7 +837,7 @@ async def startup_event():
     if prisma_client is not None and master_key is not None:
         # add master key to db
-        await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
+        await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)

 @router.on_event("shutdown")
 async def shutdown_event():
@@ -1008,7 +977,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["api_base"] = user_api_base

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = call_hooks.pre_call(data=data, call_type="completion")
+        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")

         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1021,15 +990,19 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         else: # router is not set
             response = await litellm.acompletion(**data)

+        print(f"final response: {response}")
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')

         ### CALL HOOKS ### - modify outgoing response
-        response = call_hooks.post_call_success(response=response, call_type="completion")
+        response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")

         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
+        print(f"Exception received: {str(e)}")
+        raise e
+        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data.get("model", "") in router_model_names:
@@ -1046,23 +1019,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
                 print(f"{key}: {value}")
         if user_debug:
             traceback.print_exc()
-        error_traceback = traceback.format_exc()
-        error_msg = f"{str(e)}\n\n{error_traceback}"
-        try:
-            status = e.status_code # type: ignore
-        except:
-            status = 500
-        raise HTTPException(
-            status_code=status,
-            detail=error_msg
-        )
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            try:
+                status = e.status_code # type: ignore
+            except:
+                status = 500
+            raise HTTPException(
+                status_code=status,
+                detail=error_msg
+            )

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+    global call_hooks
     try:
-        global call_hooks
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
         data = orjson.loads(body)
@@ -1105,7 +1081,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                 break

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = call_hooks.pre_call(data=data, call_type="embeddings")
+        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")

         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1117,19 +1093,18 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL

         ### CALL HOOKS ### - modify outgoing response
-        data = call_hooks.post_call_success(response=response, call_type="embeddings")
+        data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")

         return response
     except Exception as e:
+        await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
         traceback.print_exc()
         raise e
-    except Exception as e:
-        pass

 #### KEY MANAGEMENT ####
 @router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
-async def generate_key_fn(request: Request, data: GenerateKeyRequest):
+async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)):
     """
     Generate an API key based on the provided data.
@@ -1141,26 +1116,17 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
     - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
     - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
     - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.

     Returns:
-    - key: The generated api key
-    - expires: Datetime object for when key expires.
+    - key: (str) The generated api key
+    - expires: (datetime) Datetime object for when key expires.
+    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
     # data = await request.json()
-    duration_str = data.duration # Default to 1 hour if duration is not provided
-    models = data.models # Default to an empty list (meaning allow token to call all models)
-    aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
-    config = data.config
-    spend = data.spend
-    user_id = data.user_id
-    if isinstance(models, list):
-        response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
-        return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail={"error": "models param must be a list"},
-        )
+    data_json = data.model_dump()
+    response = await generate_key_helper_fn(**data_json)
+    return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])

 @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def delete_key_fn(request: Request, data: DeleteKeyRequest):
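With `/key/generate` now forwarding `data.model_dump()` straight into `generate_key_helper_fn`, every field on `GenerateKeyRequest`, including the new `max_parallel_requests`, ends up on the stored token. A hedged sketch of calling the endpoint against a locally running proxy (base URL and master key are placeholders, and the `requests` package is assumed):

```python
import requests  # any HTTP client works; requests is assumed here

PROXY_BASE_URL = "http://0.0.0.0:8000"  # placeholder: wherever the proxy is running
MASTER_KEY = "sk-1234"                  # placeholder: your PROXY_MASTER_KEY

resp = requests.post(
    f"{PROXY_BASE_URL}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={"duration": "20m", "max_parallel_requests": 1},
)
resp.raise_for_status()
key_info = resp.json()
print(key_info["key"], key_info["expires"])  # new virtual key, limited to 1 request in flight
```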

File: litellm/proxy/schema.prisma

@@ -16,4 +16,5 @@ model LiteLLM_VerificationToken {
   aliases Json @default("{}")
   config Json @default("{}")
   user_id String?
+  max_parallel_requests Int?
 }

File: litellm/proxy/utils.py

@@ -1,7 +1,13 @@
 from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+
+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa

 ### LOGGING ###
 class ProxyLogging:
     """
@@ -17,7 +23,6 @@ class ProxyLogging:
         self._init_litellm_callbacks()
         pass

     def _init_litellm_callbacks(self):
         if len(litellm.callbacks) > 0:
             for callback in litellm.callbacks:
@@ -69,11 +74,11 @@
     # Function to be called whenever a retry is about to happen
     def on_backoff(details):
         # The 'tries' key in the details dictionary contains the number of completed tries
-        print(f"Backing off... this was attempt #{details['tries']}")
+        print_verbose(f"Backing off... this was attempt #{details['tries']}")

 class PrismaClient:
     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
-        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
+        print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")

         ## init logging object
         self.proxy_logging_obj = proxy_logging_obj
@@ -109,20 +114,22 @@ class PrismaClient:
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
     async def get_data(self, token: str, expires: Optional[Any]=None):
         try:
-            hashed_token = self.hash_token(token=token)
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
             if expires:
                 response = await self.db.litellm_verificationtoken.find_first(
                     where={
-                        "token": hashed_token,
+                        "token": token,
                         "expires": {"gte": expires} # Check if the token is not expired
                     }
                 )
             else:
                 response = await self.db.litellm_verificationtoken.find_unique(
                     where={
-                        "token": hashed_token
+                        "token": token
                     }
                 )
             return response
@@ -175,25 +182,23 @@ class PrismaClient:
         Update existing data
         """
         try:
-            hashed_token = self.hash_token(token=token)
-            data["token"] = hashed_token
-            await self.db.litellm_verificationtoken.update(
+            print_verbose(f"token: {token}")
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
+            data["token"] = token
+            response = await self.db.litellm_verificationtoken.update(
                 where={
-                    "token": hashed_token
+                    "token": token
                 },
                 data={**data} # type: ignore
             )
-            print("\033[91m" + f"DB write succeeded" + "\033[0m")
+            print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
-            print()
-            print()
-            print()
-            print("\033[91m" + f"DB write failed: {e}" + "\033[0m")
-            print()
-            print()
-            print()
+            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
             raise e
@@ -252,7 +257,7 @@ class PrismaClient:
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     try:
-        print(f"value: {value}")
+        print_verbose(f"value: {value}")
         # Split the path by dots to separate module from instance
         parts = value.split(".")
@@ -285,8 +290,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e

 ### CALL HOOKS ###
 class CallHooks:
     """
@@ -297,20 +300,55 @@ class CallHooks:
     2. /embeddings
     """

-    def __init__(self, *args, **kwargs):
-        self.call_details = {}
+    def __init__(self, user_api_key_cache: DualCache):
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache

     def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
         self.call_details["litellm_settings"] = litellm_settings
         self.call_details["general_settings"] = general_settings
         self.call_details["model_list"] = model_list

-    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
-        self.call_details["data"] = data
-        return data
+    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e

-    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        return response
+    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
+        try:
+            # check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## decrement call, once complete
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return response
+        except Exception as e:
+            raise e

-    def post_call_failure(self, *args, **kwargs):
-        pass
+    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                    and original_exception.status_code == 429
+                    and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
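Combined with the proxy_server.py wiring above, the per-request lifecycle is: `pre_call` reserves a parallel-request slot for the key, `post_call_success` releases it, and `post_call_failure` releases it too, unless the failure was the limiter's own 429, in which case no slot was ever taken. A condensed sketch of that flow, with stand-ins for the proxy's globals (and assuming `CallHooks` is importable from `litellm.proxy.utils`, which is where this hunk appears to live):

```python
import asyncio
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import CallHooks  # assumed module path for this class

async def handle_request(data: dict):
    cache = DualCache()
    call_hooks = CallHooks(user_api_key_cache=cache)
    user_api_key_dict = UserAPIKeyAuth(api_key="sk-hypothetical", max_parallel_requests=1)

    # Reserve a slot (raises HTTPException 429 if the key is already at its limit).
    data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
    try:
        response = {"choices": []}  # stand-in for the actual litellm.acompletion(**data) call
        # Release the slot on success and return the (possibly modified) response.
        return await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
    except Exception as e:
        # Release the slot on failure, except when the failure is the 429 itself.
        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
        raise

asyncio.run(handle_request({"model": "gpt-3.5-turbo", "messages": []}))
```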

File: proxy config YAML (path not shown)

@@ -3,7 +3,6 @@ general_settings:
   master_key: os.environ/PROXY_MASTER_KEY

 litellm_settings:
   drop_params: true
-  set_verbose: true
   success_callback: ["langfuse"]

 model_list:

File: proxy key-management tests (path not shown)

@@ -1,4 +1,4 @@
-import sys, os
+import sys, os, time
 import traceback
 from dotenv import load_dotenv
@@ -19,7 +19,7 @@ logging.basicConfig(
     level=logging.DEBUG, # Set the desired logging level
     format="%(asctime)s - %(levelname)s - %(message)s",
 )
+from concurrent.futures import ThreadPoolExecutor

 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
@@ -62,6 +62,41 @@ def test_add_new_key(client):
         assert result["key"].startswith("sk-")
         print(f"Received response: {result}")
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")

 # # Run the test - only runs via pytest

+def test_add_new_key_max_parallel_limit(client):
+    try:
+        # Your test data
+        test_data = {"duration": "20m", "max_parallel_requests": 1}
+        # Your bearer token
+        token = os.getenv("PROXY_MASTER_KEY")
+
+        headers = {
+            "Authorization": f"Bearer {token}"
+        }
+        response = client.post("/key/generate", json=test_data, headers=headers)
+        print(f"response: {response.text}")
+        assert response.status_code == 200
+        result = response.json()
+
+        def _post_data():
+            json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
+            response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
+            return response
+
+        def _run_in_parallel():
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                future1 = executor.submit(_post_data)
+                future2 = executor.submit(_post_data)
+
+                # Obtain the results from the futures
+                response1 = future1.result()
+                response2 = future2.result()
+                if response1.status_code == 429 or response2.status_code == 429:
+                    pass
+                else:
+                    raise Exception()
+        _run_in_parallel()
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")