forked from phoenix/litellm-mirror

fix(proxy_server.py): enable pre+post-call hooks and max parallel request limits

commit 5fa2b6e5ad (parent 977bfaaab9)
9 changed files with 213 additions and 130 deletions
@@ -3,7 +3,7 @@ repos:
     rev: 3.8.4  # The version of flake8 to use
     hooks:
     -   id: flake8
-        exclude: ^litellm/tests/|^litellm/proxy/|^litellm/integrations/
+        exclude: ^litellm/tests/|^litellm/proxy/proxy_server.py|^litellm/integrations/
         additional_dependencies: [flake8-print]
         files: litellm/.*\.py
 -   repo: local
@@ -76,12 +76,13 @@ class ModelParams(BaseModel):
         protected_namespaces = ()

 class GenerateKeyRequest(BaseModel):
-    duration: str = "1h"
-    models: list = []
-    aliases: dict = {}
-    config: dict = {}
-    spend: int = 0
+    duration: Optional[str] = "1h"
+    models: Optional[list] = []
+    aliases: Optional[dict] = {}
+    config: Optional[dict] = {}
+    spend: Optional[float] = 0
     user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None

 class GenerateKeyResponse(BaseModel):
     key: str
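Note (not part of the diff): with the new optional fields, a per-key parallel-request limit is set at key creation time. A minimal client-side sketch, assuming a proxy running locally and an example master key:

# Hypothetical usage sketch; base URL and master key are placeholder values.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # proxy master key (example)
    json={"duration": "20m", "max_parallel_requests": 1},
)
print(resp.json()["key"])  # new virtual key limited to 1 concurrent request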
@@ -96,8 +97,17 @@ class DeleteKeyRequest(BaseModel):


 class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+    """
+    Return the row in the db
+    """
     api_key: Optional[str] = None
+    models: list = []
+    aliases: dict = {}
+    config: dict = {}
+    spend: Optional[float] = 0
     user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
+    duration: str = "1h"

 class ConfigGeneralSettings(BaseModel):
     """
litellm/proxy/hooks/__init__.py (new file, +1)

@@ -0,0 +1 @@
+from . import *
litellm/proxy/hooks/parallel_request_limiter.py (new file, +33)
@@ -0,0 +1,33 @@
+from typing import Optional
+from litellm.caching import DualCache
+from fastapi import HTTPException
+
+async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
+    if api_key is None:
+        return
+
+    if max_parallel_requests is None:
+        return
+
+    # CHECK IF REQUEST ALLOWED
+    request_count_api_key = f"{api_key}_request_count"
+    current = user_api_key_cache.get_cache(key=request_count_api_key)
+    if current is None:
+        user_api_key_cache.set_cache(request_count_api_key, 1)
+    elif int(current) < max_parallel_requests:
+        # Increase count for this token
+        user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+    else:
+        raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
+
+
+async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
+    if api_key is None:
+        return
+
+    request_count_api_key = f"{api_key}_request_count"
+    # Decrease count for this token
+    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
+    user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
+
+    return
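Note (not part of the diff): the two helpers are designed to bracket a proxied call — increment-or-reject before the call, decrement after it. In the commit this pairing happens inside CallHooks (pre_call / post_call_success / post_call_failure); the sketch below shows the same pattern with a try/finally instead:

# Usage sketch under assumed values; mirrors the helpers' intended pairing.
import asyncio
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import (
    max_parallel_request_allow_request,
    max_parallel_request_update_count,
)

async def guarded_call(api_key: str, cache: DualCache):
    # raises fastapi.HTTPException(429) once the per-key counter reaches the limit
    await max_parallel_request_allow_request(
        max_parallel_requests=1, api_key=api_key, user_api_key_cache=cache
    )
    try:
        await asyncio.sleep(0.1)  # stand-in for the actual LLM call
    finally:
        # decrement the per-key counter whether the call succeeded or failed
        await max_parallel_request_update_count(api_key=api_key, user_api_key_cache=cache)

asyncio.run(guarded_call("sk-example", DualCache()))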
@@ -102,7 +102,7 @@ from litellm.proxy._types import *
 from litellm.caching import DualCache
 from litellm.proxy.health_check import perform_health_check
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
+from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks, Header
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
 from fastapi.encoders import jsonable_encoder
@@ -198,7 +198,7 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
-call_hooks = CallHooks()
+call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
 proxy_logging_obj: Optional[ProxyLogging] = None
 ### REDIS QUEUE ###
 async_result = None
@@ -259,10 +259,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if prisma_client:
             ## check for cache hit (In-Memory Cache)
             valid_token = user_api_key_cache.get_cache(key=api_key)
             print(f"valid_token from cache: {valid_token}")
             if valid_token is None:
                 ## check db
-                cleaned_api_key = api_key
-                valid_token = await prisma_client.get_data(token=cleaned_api_key, expires=datetime.utcnow())
+                valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
+                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
             elif valid_token is not None:
                 print(f"API Key Cache Hit!")
@@ -274,10 +274,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                     llm_model_list = model_list
                     print("\n new llm router model list", llm_model_list)
                 if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                    return_dict = {"api_key": valid_token.token}
-                    if valid_token.user_id:
-                        return_dict["user_id"] = valid_token.user_id
-                    return UserAPIKeyAuth(**return_dict)
+                    api_key = valid_token.token
+                    valid_token_dict = valid_token.model_dump()
+                    valid_token_dict.pop("token", None)
+                    return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
                 else:
                     data = await request.json()
                     model = data.get("model", None)
@@ -285,10 +285,10 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
                         model = litellm.model_alias_map[model]
                     if model and model not in valid_token.models:
                         raise Exception(f"Token not allowed to access model")
-                    return_dict = {"api_key": valid_token.token}
-                    if valid_token.user_id:
-                        return_dict["user_id"] = valid_token.user_id
-                    return UserAPIKeyAuth(**return_dict)
+                    api_key = valid_token.token
+                    valid_token_dict = valid_token.model_dump()
+                    valid_token.pop("token", None)
+                    return UserAPIKeyAuth(api_key=api_key, **valid_token)
         else:
             raise Exception(f"Invalid token")
     except Exception as e:
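Note (not part of the diff): the rewritten branches dump the DB row, drop the raw token, and rebuild a UserAPIKeyAuth carrying the caller's api_key so later hooks can read max_parallel_requests. A toy illustration of that dump/pop/rebuild pattern, using UserAPIKeyAuth itself as a stand-in for the DB row:

# Illustrative only; field values are made up.
from litellm.proxy._types import UserAPIKeyAuth

row = UserAPIKeyAuth(api_key="hashed-token-value", user_id="user-1", max_parallel_requests=2)
d = row.model_dump()
d.pop("api_key", None)              # analogous to valid_token_dict.pop("token", None)
auth = UserAPIKeyAuth(api_key="sk-plaintext", **d)
print(auth.max_parallel_requests)   # 2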
@@ -588,7 +588,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
     return router, model_list, general_settings

-async def generate_key_helper_fn(duration_str: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None):
+async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
     global prisma_client

     if prisma_client is None:
@@ -617,11 +617,11 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
         else:
             raise ValueError("Unsupported duration unit")

-    if duration_str is None: # allow tokens that never expire
+    if duration is None: # allow tokens that never expire
         expires = None
     else:
-        duration = _duration_in_seconds(duration=duration_str)
-        expires = datetime.utcnow() + timedelta(seconds=duration)
+        duration_s = _duration_in_seconds(duration=duration)
+        expires = datetime.utcnow() + timedelta(seconds=duration_s)

     aliases_json = json.dumps(aliases)
     config_json = json.dumps(config)
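Note (not part of the diff): _duration_in_seconds itself is outside this hunk; a sketch of what such a parser plausibly does, consistent with the "Unsupported duration unit" error above:

def duration_in_seconds_sketch(duration: str) -> int:
    # assumption: formats like "30s", "20m", "1h", "7d"
    value, unit = int(duration[:-1]), duration[-1]
    if unit == "s":
        return value
    elif unit == "m":
        return value * 60
    elif unit == "h":
        return value * 3600
    elif unit == "d":
        return value * 86400
    raise ValueError("Unsupported duration unit")

print(duration_in_seconds_sketch("20m"))  # 1200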
@@ -635,7 +635,8 @@ async def generate_key_helper_fn(duration_str: Optional[str], models: list, alia
             "aliases": aliases_json,
             "config": config_json,
             "spend": spend,
-            "user_id": user_id
+            "user_id": user_id,
+            "max_parallel_requests": max_parallel_requests
         }
         new_verification_token = await prisma_client.insert_data(data=verification_token_data)
     except Exception as e:
@@ -755,14 +756,12 @@ def data_generator(response):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"

-async def async_data_generator(response):
+async def async_data_generator(response, user_api_key_dict):
     global call_hooks

     print_verbose("inside generator")
     async for chunk in response:
         print_verbose(f"returned chunk: {chunk}")
-        ### CALL HOOKS ### - modify outgoing response
-        response = call_hooks.post_call_success(chunk=chunk, call_type="completion")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
         except:
@@ -812,36 +811,6 @@ def get_litellm_model_info(model: dict = {}):
         # if litellm does not have info on the model it should return {}
         return {}

-@app.middleware("http")
-async def rate_limit_per_token(request: Request, call_next):
-    global user_api_key_cache, general_settings
-    max_parallel_requests = general_settings.get("max_parallel_requests", None)
-    api_key = request.headers.get("Authorization")
-    if max_parallel_requests is not None and api_key is not None: # Rate limiting is enabled
-        api_key = _get_bearer_token(api_key=api_key)
-        # CHECK IF REQUEST ALLOWED
-        request_count_api_key = f"{api_key}_request_count"
-        current = user_api_key_cache.get_cache(key=request_count_api_key)
-        if current is None:
-            user_api_key_cache.set_cache(request_count_api_key, 1)
-        elif int(current) < max_parallel_requests:
-            # Increase count for this token
-            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
-        else:
-            raise HTTPException(status_code=429, detail="Too many requests.")
-
-        response = await call_next(request)
-
-        # Decrease count for this token
-        current = user_api_key_cache.get_cache(key=request_count_api_key)
-        user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
-
-        return response
-    else: # Rate limiting is not enabled, just pass the request
-        response = await call_next(request)
-        return response
-
 @router.on_event("startup")
 async def startup_event():
     global prisma_client, master_key, use_background_health_checks
@@ -868,7 +837,7 @@ async def startup_event():

     if prisma_client is not None and master_key is not None:
         # add master key to db
-        await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
+        await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)

 @router.on_event("shutdown")
 async def shutdown_event():
@@ -1008,7 +977,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["api_base"] = user_api_base

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = call_hooks.pre_call(data=data, call_type="completion")
+        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")

         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1021,15 +990,19 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         else: # router is not set
             response = await litellm.acompletion(**data)

         print(f"final response: {response}")
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
-            return StreamingResponse(async_data_generator(response), media_type='text/event-stream')
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')

         ### CALL HOOKS ### - modify outgoing response
-        response = call_hooks.post_call_success(response=response, call_type="completion")
+        response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")

         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
+        print(f"Exception received: {str(e)}")
+        raise e
+        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
+        print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data.get("model", "") in router_model_names:
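Note (not part of the diff): the hook order this endpoint now follows is pre_call → model call → post_call_success, with post_call_failure on the error path. A condensed sketch (not the actual endpoint body):

async def handle_completion_sketch(call_hooks, user_api_key_dict, data):
    data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
    try:
        response = {"mock": "response"}  # stand-in for litellm.acompletion(**data) / llm_router.acompletion(...)
        return await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
    except Exception as e:
        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
        raise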
@@ -1046,23 +1019,26 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
                 print(f"{key}: {value}")
         if user_debug:
             traceback.print_exc()
-        error_traceback = traceback.format_exc()
-        error_msg = f"{str(e)}\n\n{error_traceback}"
-        try:
-            status = e.status_code # type: ignore
-        except:
-            status = 500
-        raise HTTPException(
-            status_code=status,
-            detail=error_msg
-        )

+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            try:
+                status = e.status_code # type: ignore
+            except:
+                status = 500
+            raise HTTPException(
+                status_code=status,
+                detail=error_msg
+            )

 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global call_hooks
     try:
+        global call_hooks

         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
         data = orjson.loads(body)
@@ -1105,7 +1081,7 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                     break

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = call_hooks.pre_call(data=data, call_type="embeddings")
+        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")

         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1117,19 +1093,18 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL

         ### CALL HOOKS ### - modify outgoing response
-        data = call_hooks.post_call_success(response=response, call_type="embeddings")
+        data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")

         return response
     except Exception as e:
+        await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
         traceback.print_exc()
         raise e
-    except Exception as e:
-        pass

 #### KEY MANAGEMENT ####

 @router.post("/key/generate", tags=["key management"], dependencies=[Depends(user_api_key_auth)], response_model=GenerateKeyResponse)
-async def generate_key_fn(request: Request, data: GenerateKeyRequest):
+async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorization: Optional[str] = Header(None)):
     """
     Generate an API key based on the provided data.

@@ -1141,26 +1116,17 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
     - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
     - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
     - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
+    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.

     Returns:
-    - key: The generated api key
-    - expires: Datetime object for when key expires.
+    - key: (str) The generated api key
+    - expires: (datetime) Datetime object for when key expires.
+    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
     # data = await request.json()
-    duration_str = data.duration # Default to 1 hour if duration is not provided
-    models = data.models # Default to an empty list (meaning allow token to call all models)
-    aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
-    config = data.config
-    spend = data.spend
-    user_id = data.user_id
-    if isinstance(models, list):
-        response = await generate_key_helper_fn(duration_str=duration_str, models=models, aliases=aliases, config=config, spend=spend, user_id=user_id)
-        return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail={"error": "models param must be a list"},
-        )
+    data_json = data.model_dump()
+    response = await generate_key_helper_fn(**data_json)
+    return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])

 @router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
 async def delete_key_fn(request: Request, data: DeleteKeyRequest):
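Note (not part of the diff): passing **data.model_dump() works because the GenerateKeyRequest field names now match generate_key_helper_fn's parameters one-to-one. For example:

from litellm.proxy._types import GenerateKeyRequest

req = GenerateKeyRequest(duration="20m", max_parallel_requests=1)
print(req.model_dump())
# roughly: {'duration': '20m', 'models': [], 'aliases': {}, 'config': {},
#           'spend': 0, 'user_id': None, 'max_parallel_requests': 1}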
@@ -16,4 +16,5 @@ model LiteLLM_VerificationToken {
   aliases Json @default("{}")
   config Json @default("{}")
   user_id String?
+  max_parallel_requests Int?
 }
@@ -1,7 +1,13 @@
 from typing import Optional, List, Any, Literal
 import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.caching import DualCache
+from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count

+def print_verbose(print_statement):
+    if litellm.set_verbose:
+        print(print_statement) # noqa
 ### LOGGING ###
 class ProxyLogging:
     """
@@ -17,7 +23,6 @@ class ProxyLogging:
         self._init_litellm_callbacks()
         pass
-

     def _init_litellm_callbacks(self):
         if len(litellm.callbacks) > 0:
             for callback in litellm.callbacks:
@@ -69,11 +74,11 @@ class ProxyLogging:
     # Function to be called whenever a retry is about to happen
     def on_backoff(details):
         # The 'tries' key in the details dictionary contains the number of completed tries
-        print(f"Backing off... this was attempt #{details['tries']}")
+        print_verbose(f"Backing off... this was attempt #{details['tries']}")

 class PrismaClient:
     def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
-        print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
+        print_verbose("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
         ## init logging object
         self.proxy_logging_obj = proxy_logging_obj
@@ -109,20 +114,22 @@ class PrismaClient:
         max_time=10, # maximum total time to retry for
         on_backoff=on_backoff, # specifying the function to call on backoff
     )
     async def get_data(self, token: str, expires: Optional[Any]=None):
         try:
-            hashed_token = self.hash_token(token=token)
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
             if expires:
                 response = await self.db.litellm_verificationtoken.find_first(
                     where={
-                        "token": hashed_token,
+                        "token": token,
                         "expires": {"gte": expires} # Check if the token is not expired
                     }
                 )
             else:
                 response = await self.db.litellm_verificationtoken.find_unique(
                     where={
-                        "token": hashed_token
+                        "token": token
                     }
                 )
             return response
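Note (not part of the diff): the "sk-" prefix check implies plaintext keys are hashed before lookup while already-hashed tokens are passed through. The hash_token helper is not shown in this hunk; a sha-256 hexdigest like the sketch below is an assumption, consistent with the hashlib import in this file:

import hashlib

def hash_token_sketch(token: str) -> str:
    # assumption: hash_token returns a hex digest of the plaintext key
    return hashlib.sha256(token.encode()).hexdigest()

print(hash_token_sketch("sk-1234")[:12])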
@@ -175,25 +182,23 @@ class PrismaClient:
         Update existing data
         """
         try:
-            hashed_token = self.hash_token(token=token)
-            data["token"] = hashed_token
-            await self.db.litellm_verificationtoken.update(
+            print_verbose(f"token: {token}")
+            # check if plain text or hash
+            if token.startswith("sk-"):
+                token = self.hash_token(token=token)
+
+            data["token"] = token
+            response = await self.db.litellm_verificationtoken.update(
                 where={
-                    "token": hashed_token
+                    "token": token
                 },
                 data={**data} # type: ignore
             )
-            print("\033[91m" + f"DB write succeeded" + "\033[0m")
+            print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
             return {"token": token, "data": data}
         except Exception as e:
             asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
             print()
             print()
             print()
-            print("\033[91m" + f"DB write failed: {e}" + "\033[0m")
+            print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
             print()
             print()
             print()
             raise e
@@ -252,7 +257,7 @@ class PrismaClient:
 ### CUSTOM FILE ###
 def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     try:
-        print(f"value: {value}")
+        print_verbose(f"value: {value}")
         # Split the path by dots to separate module from instance
         parts = value.split(".")
@@ -285,8 +290,6 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
-
-

 ### CALL HOOKS ###
 class CallHooks:
     """
@@ -297,20 +300,55 @@ class CallHooks:
     2. /embeddings
     """

-    def __init__(self, *args, **kwargs):
-        self.call_details = {}
+    def __init__(self, user_api_key_cache: DualCache):
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache

     def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
         self.call_details["litellm_settings"] = litellm_settings
         self.call_details["general_settings"] = general_settings
         self.call_details["model_list"] = model_list

-    def pre_call(self, data: dict, call_type: Literal["completion", "embeddings"]):
-        self.call_details["data"] = data
-        return data
+    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e

-    def post_call_success(self, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        return response
+    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
+        try:
+            # check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## decrement call, once complete
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return response
+        except Exception as e:
+            raise e

-    def post_call_failure(self, *args, **kwargs):
-        pass
+    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                    and original_exception.status_code == 429
+                    and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await max_parallel_request_update_count(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
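Note (not part of the diff): an end-to-end sketch of the new hook flow, under the assumption that CallHooks lives in litellm.proxy.utils (the file path is not shown on this page):

import asyncio
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import CallHooks  # assumed module path

async def demo():
    hooks = CallHooks(user_api_key_cache=DualCache())
    key = UserAPIKeyAuth(api_key="sk-demo", max_parallel_requests=1)
    data = await hooks.pre_call(user_api_key_dict=key, data={"model": "gpt-3.5-turbo"}, call_type="completion")
    # a second pre_call for the same key before post_call_success would raise HTTPException(429)
    await hooks.post_call_success(user_api_key_dict=key, response={"ok": True}, call_type="completion")

asyncio.run(demo())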
@@ -3,7 +3,6 @@ general_settings:
   master_key: os.environ/PROXY_MASTER_KEY
 litellm_settings:
   drop_params: true
-  set_verbose: true
   success_callback: ["langfuse"]

 model_list:
@@ -1,4 +1,4 @@
-import sys, os
+import sys, os, time
 import traceback
 from dotenv import load_dotenv
@@ -19,7 +19,7 @@ logging.basicConfig(
     level=logging.DEBUG, # Set the desired logging level
     format="%(asctime)s - %(levelname)s - %(message)s",
 )

+from concurrent.futures import ThreadPoolExecutor
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
@@ -62,6 +62,41 @@ def test_add_new_key(client):
         assert result["key"].startswith("sk-")
         print(f"Received response: {result}")
     except Exception as e:
-        pytest.fail("LiteLLM Proxy test failed. Exception", e)
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")

 # # Run the test - only runs via pytest

+def test_add_new_key_max_parallel_limit(client):
+    try:
+        # Your test data
+        test_data = {"duration": "20m", "max_parallel_requests": 1}
+        # Your bearer token
+        token = os.getenv("PROXY_MASTER_KEY")
+
+        headers = {
+            "Authorization": f"Bearer {token}"
+        }
+        response = client.post("/key/generate", json=test_data, headers=headers)
+        print(f"response: {response.text}")
+        assert response.status_code == 200
+        result = response.json()
+        def _post_data():
+            json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}]}
+            response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
+            return response
+        def _run_in_parallel():
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                future1 = executor.submit(_post_data)
+                future2 = executor.submit(_post_data)
+
+                # Obtain the results from the futures
+                response1 = future1.result()
+                response2 = future2.result()
+                if response1.status_code == 429 or response2.status_code == 429:
+                    pass
+                else:
+                    raise Exception()
+        _run_in_parallel()
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
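Note (not part of the diff): one way to run just this test locally (assumes PROXY_MASTER_KEY is set and the proxy config defines an "azure-model" entry):

import pytest

pytest.main(["-k", "test_add_new_key_max_parallel_limit", "-x", "-s"])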