fix(proxy_server.py): support setting tpm/rpm limits per user / per key

This commit is contained in:
Krrish Dholakia 2024-01-18 17:03:18 -08:00
parent 5dac2402ef
commit 1e5efdfa37
5 changed files with 26 additions and 3 deletions

View file

@@ -131,6 +131,8 @@ class GenerateKeyRequest(LiteLLMBase):
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = {}
tpm_limit: int = sys.maxsize
rpm_limit: int = sys.maxsize
class UpdateKeyRequest(LiteLLMBase):
@@ -145,6 +147,8 @@ class UpdateKeyRequest(LiteLLMBase):
user_id: Optional[str] = None
max_parallel_requests: Optional[int] = None
metadata: Optional[dict] = None
tpm_limit: int = sys.maxsize
rpm_limit: int = sys.maxsize
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth

View file

@@ -418,6 +418,7 @@ def run_server(
break # Exit the loop if the subprocess succeeds
except subprocess.CalledProcessError as e:
print(f"Error: {e}")
time.sleep(random.randrange(start=1, stop=5))
finally:
os.chdir(original_dir)
else:

View file

@@ -1040,6 +1040,8 @@ async def generate_key_helper_fn(
user_email: Optional[str] = None,
max_parallel_requests: Optional[int] = None,
metadata: Optional[dict] = {},
tpm_limit: Optional[int] = None,
rpm_limit: Optional[int] = None,
):
global prisma_client, custom_db_client
@@ -1080,6 +1082,8 @@
config_json = json.dumps(config)
metadata_json = json.dumps(metadata)
user_id = user_id or str(uuid.uuid4())
tpm_limit = tpm_limit or sys.maxsize
rpm_limit = rpm_limit or sys.maxsize
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
user_data = {
@@ -1088,6 +1092,9 @@
"user_id": user_id,
"spend": spend,
"models": models,
"max_parallel_requests": max_parallel_requests,
"tpm_limit": tpm_limit,
"rpm_limit": rpm_limit,
}
key_data = {
"token": token,
@@ -1099,6 +1106,8 @@
"user_id": user_id,
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"tpm_limit": tpm_limit,
"rpm_limit": rpm_limit,
}
if prisma_client is not None:
## CREATE USER (If necessary)
@@ -2032,7 +2041,6 @@ async def image_generation(
response_model=GenerateKeyResponse,
)
async def generate_key_fn(
request: Request,
data: GenerateKeyRequest,
Authorization: Optional[str] = Header(None),
):

View file

@@ -12,7 +12,10 @@ model LiteLLM_UserTable {
max_budget Float?
spend Float @default(0.0)
user_email String?
models String[] @default([])
models String[]
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
}
// required for token gen
@@ -20,12 +23,14 @@ model LiteLLM_VerificationToken {
token String @unique
spend Float @default(0.0)
expires DateTime?
models String[] @default([])
models String[]
aliases Json @default("{}")
config Json @default("{}")
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
tpm_limit BigInt?
rpm_limit BigInt?
}
model LiteLLM_Config {

View file

@@ -13,6 +13,9 @@ model LiteLLM_UserTable {
spend Float @default(0.0)
user_email String?
models String[]
max_parallel_requests Int?
tpm_limit BigInt?
rpm_limit BigInt?
}
// required for token gen
@@ -26,6 +29,8 @@ model LiteLLM_VerificationToken {
user_id String?
max_parallel_requests Int?
metadata Json @default("{}")
tpm_limit BigInt?
rpm_limit BigInt?
}
model LiteLLM_Config {