forked from phoenix/litellm-mirror
feat(proxy_server.py): support global budget and resets
This commit is contained in:
parent
f0ada5c950
commit
159e54d8be
6 changed files with 120 additions and 18 deletions
|
@ -62,6 +62,9 @@ cache: Optional[
|
|||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
budget_duration: Optional[
|
||||
str
|
||||
] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||
_openai_completion_params = [
|
||||
"functions",
|
||||
"function_call",
|
||||
|
|
|
@ -306,6 +306,10 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
|||
user_id: Union[str, None]
|
||||
max_parallel_requests: Union[int, None]
|
||||
metadata: Dict[str, str] = {}
|
||||
tpm_limit: Optional[int] = None
|
||||
rpm_limit: Optional[int] = None
|
||||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class LiteLLM_Config(LiteLLMBase):
|
||||
|
|
|
@ -1151,6 +1151,7 @@ async def generate_key_helper_fn(
|
|||
metadata: Optional[dict] = {},
|
||||
tpm_limit: Optional[int] = None,
|
||||
rpm_limit: Optional[int] = None,
|
||||
query_type: Literal["insert_data", "update_data"] = "insert_data",
|
||||
):
|
||||
global prisma_client, custom_db_client
|
||||
|
||||
|
@ -1193,6 +1194,12 @@ async def generate_key_helper_fn(
|
|||
duration_s = _duration_in_seconds(duration=key_budget_duration)
|
||||
key_reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||
|
||||
if budget_duration is None: # one-time budget
|
||||
reset_at = None
|
||||
else:
|
||||
duration_s = _duration_in_seconds(duration=budget_duration)
|
||||
reset_at = datetime.utcnow() + timedelta(seconds=duration_s)
|
||||
|
||||
aliases_json = json.dumps(aliases)
|
||||
config_json = json.dumps(config)
|
||||
metadata_json = json.dumps(metadata)
|
||||
|
@ -1213,6 +1220,8 @@ async def generate_key_helper_fn(
|
|||
"max_parallel_requests": max_parallel_requests,
|
||||
"tpm_limit": tpm_limit,
|
||||
"rpm_limit": rpm_limit,
|
||||
"budget_duration": budget_duration,
|
||||
"budget_reset_at": reset_at,
|
||||
}
|
||||
key_data = {
|
||||
"token": token,
|
||||
|
@ -1234,13 +1243,18 @@ async def generate_key_helper_fn(
|
|||
if prisma_client is not None:
|
||||
## CREATE USER (If necessary)
|
||||
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
|
||||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
)
|
||||
if query_type == "insert_data":
|
||||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
)
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
elif query_type == "update_data":
|
||||
user_row = await prisma_client.update_data(
|
||||
data=user_data, table_name="user"
|
||||
)
|
||||
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
## CREATE KEY
|
||||
verbose_proxy_logger.debug(f"prisma_client: Creating Key={key_data}")
|
||||
await prisma_client.insert_data(data=key_data, table_name="key")
|
||||
|
@ -1548,6 +1562,25 @@ async def startup_event():
|
|||
await generate_key_helper_fn(
|
||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||
)
|
||||
|
||||
if (
|
||||
prisma_client is not None
|
||||
and litellm.max_budget > 0
|
||||
and litellm.budget_duration is not None
|
||||
):
|
||||
# add proxy budget to db in the user table
|
||||
await generate_key_helper_fn(
|
||||
user_id="litellm-proxy-budget",
|
||||
duration=None,
|
||||
models=[],
|
||||
aliases={},
|
||||
config={},
|
||||
spend=0,
|
||||
max_budget=litellm.max_budget,
|
||||
budget_duration=litellm.budget_duration,
|
||||
query_type="update_data",
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"custom_db_client client {custom_db_client}. Master_key: {master_key}"
|
||||
)
|
||||
|
|
|
@ -17,6 +17,8 @@ model LiteLLM_UserTable {
|
|||
max_parallel_requests Int?
|
||||
tpm_limit BigInt?
|
||||
rpm_limit BigInt?
|
||||
budget_duration String?
|
||||
budget_reset_at DateTime?
|
||||
}
|
||||
|
||||
// required for token gen
|
||||
|
|
|
@ -425,12 +425,21 @@ class PrismaClient:
|
|||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authentication Error: invalid user key - token does not exist",
|
||||
)
|
||||
elif user_id is not None:
|
||||
response = await self.db.litellm_usertable.find_unique( # type: ignore
|
||||
where={
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
elif user_id is not None or (
|
||||
table_name is not None and table_name == "user"
|
||||
):
|
||||
if query_type == "find_unique":
|
||||
response = await self.db.litellm_usertable.find_unique( # type: ignore
|
||||
where={
|
||||
"user_id": user_id, # type: ignore
|
||||
}
|
||||
)
|
||||
elif query_type == "find_all" and reset_at is not None:
|
||||
response = await self.db.litellm_usertable.find_many(
|
||||
where={ # type:ignore
|
||||
"budget_reset_at": {"lt": reset_at},
|
||||
}
|
||||
)
|
||||
return response
|
||||
elif table_name == "spend":
|
||||
verbose_proxy_logger.debug(
|
||||
|
@ -585,10 +594,16 @@ class PrismaClient:
|
|||
+ "\033[0m"
|
||||
)
|
||||
return {"token": token, "data": db_data}
|
||||
elif user_id is not None:
|
||||
elif (
|
||||
user_id is not None
|
||||
or (table_name is not None and table_name == "user")
|
||||
and query_type == "update"
|
||||
):
|
||||
"""
|
||||
If data['spend'] + data['user'], update the user table with spend info as well
|
||||
"""
|
||||
if user_id is None:
|
||||
user_id = db_data["user_id"]
|
||||
update_user_row = await self.db.litellm_usertable.update(
|
||||
where={"user_id": user_id}, # type: ignore
|
||||
data={**db_data}, # type: ignore
|
||||
|
@ -638,6 +653,30 @@ class PrismaClient:
|
|||
print_verbose(
|
||||
"\033[91m" + f"DB Token Table update succeeded" + "\033[0m"
|
||||
)
|
||||
elif (
|
||||
table_name is not None
|
||||
and table_name == "user"
|
||||
and query_type == "update_many"
|
||||
and data_list is not None
|
||||
and isinstance(data_list, list)
|
||||
):
|
||||
"""
|
||||
Batch write update queries
|
||||
"""
|
||||
batcher = self.db.batch_()
|
||||
for idx, user in enumerate(data_list):
|
||||
try:
|
||||
data_json = self.jsonify_object(data=user.model_dump())
|
||||
except:
|
||||
data_json = self.jsonify_object(data=user.dict())
|
||||
batcher.litellm_usertable.update(
|
||||
where={"user_id": user.user_id}, # type: ignore
|
||||
data={**data_json}, # type: ignore
|
||||
)
|
||||
await batcher.commit()
|
||||
print_verbose(
|
||||
"\033[91m" + f"DB User Table update succeeded" + "\033[0m"
|
||||
)
|
||||
except Exception as e:
|
||||
asyncio.create_task(
|
||||
self.proxy_logging_obj.failure_handler(original_exception=e)
|
||||
|
@ -994,17 +1033,36 @@ async def reset_budget(prisma_client: PrismaClient):
|
|||
Updates db
|
||||
"""
|
||||
if prisma_client is not None:
|
||||
### RESET KEY BUDGET ###
|
||||
now = datetime.utcnow()
|
||||
keys_to_reset = await prisma_client.get_data(
|
||||
table_name="key", query_type="find_all", expires=now, reset_at=now
|
||||
)
|
||||
|
||||
for key in keys_to_reset:
|
||||
key.spend = 0.0
|
||||
duration_s = _duration_in_seconds(duration=key.budget_duration)
|
||||
key.budget_reset_at = key.budget_reset_at + timedelta(seconds=duration_s)
|
||||
if keys_to_reset is not None and len(keys_to_reset) > 0:
|
||||
for key in keys_to_reset:
|
||||
key.spend = 0.0
|
||||
duration_s = _duration_in_seconds(duration=key.budget_duration)
|
||||
key.budget_reset_at = now + timedelta(seconds=duration_s)
|
||||
|
||||
if len(keys_to_reset) > 0:
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many", data_list=keys_to_reset, table_name="key"
|
||||
)
|
||||
|
||||
### RESET USER BUDGET ###
|
||||
now = datetime.utcnow()
|
||||
users_to_reset = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_all", reset_at=now
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(f"users_to_reset from get_data: {users_to_reset}")
|
||||
|
||||
if users_to_reset is not None and len(users_to_reset) > 0:
|
||||
for user in users_to_reset:
|
||||
user.spend = 0.0
|
||||
duration_s = _duration_in_seconds(duration=user.budget_duration)
|
||||
user.budget_reset_at = now + timedelta(seconds=duration_s)
|
||||
|
||||
await prisma_client.update_data(
|
||||
query_type="update_many", data_list=users_to_reset, table_name="user"
|
||||
)
|
||||
|
|
|
@ -17,6 +17,8 @@ model LiteLLM_UserTable {
|
|||
max_parallel_requests Int?
|
||||
tpm_limit BigInt?
|
||||
rpm_limit BigInt?
|
||||
budget_duration String?
|
||||
budget_reset_at DateTime?
|
||||
}
|
||||
|
||||
// required for token gen
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue