forked from phoenix/litellm-mirror
fix(proxy_server.py): enforce team based spend limits
This commit is contained in:
parent
f7af18d72a
commit
71d4b7aaf4
3 changed files with 146 additions and 4 deletions
|
@ -458,6 +458,17 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
|
"""
|
||||||
|
Combined view of litellm verification token + litellm team table (select values)
|
||||||
|
"""
|
||||||
|
|
||||||
|
team_spend: Optional[float] = None
|
||||||
|
team_tpm_limit: Optional[int] = None
|
||||||
|
team_rpm_limit: Optional[int] = None
|
||||||
|
team_max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(
|
class UserAPIKeyAuth(
|
||||||
LiteLLM_VerificationToken
|
LiteLLM_VerificationToken
|
||||||
): # the expected response object for user api key auth
|
): # the expected response object for user api key auth
|
||||||
|
|
|
@ -350,13 +350,14 @@ async def user_api_key_auth(
|
||||||
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
|
original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility)
|
||||||
if api_key.startswith("sk-"):
|
if api_key.startswith("sk-"):
|
||||||
api_key = hash_token(token=api_key)
|
api_key = hash_token(token=api_key)
|
||||||
valid_token = user_api_key_cache.get_cache(key=api_key)
|
# valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||||
|
valid_token = None
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
## check db
|
## check db
|
||||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
valid_token = await prisma_client.get_data(
|
valid_token = await prisma_client.get_data(
|
||||||
token=api_key,
|
token=api_key, table_name="combined_view"
|
||||||
)
|
)
|
||||||
|
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
|
@ -381,6 +382,8 @@ async def user_api_key_auth(
|
||||||
# 4. If token is expired
|
# 4. If token is expired
|
||||||
# 5. If token spend is under Budget for the token
|
# 5. If token spend is under Budget for the token
|
||||||
# 6. If token spend per model is under budget per model
|
# 6. If token spend per model is under budget per model
|
||||||
|
# 7. If token spend is under team budget
|
||||||
|
# 8. If team spend is under team budget
|
||||||
|
|
||||||
request_data = await _read_request_body(
|
request_data = await _read_request_body(
|
||||||
request=request
|
request=request
|
||||||
|
@ -610,6 +613,44 @@ async def user_api_key_auth(
|
||||||
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check 6. Token spend is under Team budget
|
||||||
|
if (
|
||||||
|
valid_token.spend is not None
|
||||||
|
and valid_token.team_max_budget is not None
|
||||||
|
):
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
user_max_budget=valid_token.team_max_budget,
|
||||||
|
user_current_spend=valid_token.spend,
|
||||||
|
type="token_budget",
|
||||||
|
user_info=valid_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_token.spend > valid_token.team_max_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check 7. Team spend is under Team budget
|
||||||
|
if (
|
||||||
|
valid_token.team_spend is not None
|
||||||
|
and valid_token.team_max_budget is not None
|
||||||
|
):
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
user_max_budget=valid_token.team_max_budget,
|
||||||
|
user_current_spend=valid_token.team_spend,
|
||||||
|
type="token_budget",
|
||||||
|
user_info=valid_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_token.team_spend > valid_token.team_max_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
# Token passed all checks
|
# Token passed all checks
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
|
|
||||||
|
@ -2256,6 +2297,10 @@ async def startup_event():
|
||||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### CHECK IF VIEW EXISTS ###
|
||||||
|
create_view_response = await prisma_client.check_view_exists()
|
||||||
|
print(f"create_view_response: {create_view_response}") # noqa
|
||||||
|
|
||||||
### START BUDGET SCHEDULER ###
|
### START BUDGET SCHEDULER ###
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
scheduler = AsyncIOScheduler()
|
scheduler = AsyncIOScheduler()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from litellm.proxy._types import (
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
DynamoDBArgs,
|
DynamoDBArgs,
|
||||||
LiteLLM_VerificationToken,
|
LiteLLM_VerificationToken,
|
||||||
|
LiteLLM_VerificationTokenView,
|
||||||
LiteLLM_SpendLogs,
|
LiteLLM_SpendLogs,
|
||||||
)
|
)
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -479,6 +480,49 @@ class PrismaClient:
|
||||||
db_data[k] = json.dumps(v)
|
db_data[k] = json.dumps(v)
|
||||||
return db_data
|
return db_data
|
||||||
|
|
||||||
|
@backoff.on_exception(
|
||||||
|
backoff.expo,
|
||||||
|
Exception, # base exception to catch for the backoff
|
||||||
|
max_tries=3, # maximum number of retries
|
||||||
|
max_time=10, # maximum total time to retry for
|
||||||
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||||
|
)
|
||||||
|
async def check_view_exists(self):
|
||||||
|
"""
|
||||||
|
Checks if the LiteLLM_VerificationTokenView exists in the user's db.
|
||||||
|
|
||||||
|
This is used for getting the token + team data in user_api_key_auth
|
||||||
|
|
||||||
|
If the view doesn't exist, one will be created.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to select one row from the view
|
||||||
|
await self.db.execute_raw(
|
||||||
|
"""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
|
||||||
|
)
|
||||||
|
return "LiteLLM_VerificationTokenView Exists!"
|
||||||
|
except Exception as e:
|
||||||
|
# If an error occurs, the view does not exist, so create it
|
||||||
|
value = await self.health_check()
|
||||||
|
if '"litellm_verificationtokenview" does not exist' in str(e):
|
||||||
|
await self.db.execute_raw(
|
||||||
|
"""
|
||||||
|
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||||
|
SELECT
|
||||||
|
v.*,
|
||||||
|
t.spend AS team_spend,
|
||||||
|
t.max_budget AS team_max_budget,
|
||||||
|
t.tpm_limit AS team_tpm_limit,
|
||||||
|
t.rpm_limit AS team_rpm_limit
|
||||||
|
FROM "LiteLLM_VerificationToken" v
|
||||||
|
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return "LiteLLM_VerificationTokenView Created!"
|
||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
Exception, # base exception to catch for the backoff
|
Exception, # base exception to catch for the backoff
|
||||||
|
@ -535,7 +579,15 @@ class PrismaClient:
|
||||||
team_id_list: Optional[list] = None,
|
team_id_list: Optional[list] = None,
|
||||||
key_val: Optional[dict] = None,
|
key_val: Optional[dict] = None,
|
||||||
table_name: Optional[
|
table_name: Optional[
|
||||||
Literal["user", "key", "config", "spend", "team", "user_notification"]
|
Literal[
|
||||||
|
"user",
|
||||||
|
"key",
|
||||||
|
"config",
|
||||||
|
"spend",
|
||||||
|
"team",
|
||||||
|
"user_notification",
|
||||||
|
"combined_view",
|
||||||
|
]
|
||||||
] = None,
|
] = None,
|
||||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||||
expires: Optional[datetime] = None,
|
expires: Optional[datetime] = None,
|
||||||
|
@ -543,7 +595,9 @@ class PrismaClient:
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
response: Any = None
|
response: Any = None
|
||||||
if token is not None or (table_name is not None and table_name == "key"):
|
if (token is not None and table_name is None) or (
|
||||||
|
table_name is not None and table_name == "key"
|
||||||
|
):
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
if token is not None:
|
if token is not None:
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
|
@ -723,6 +777,38 @@ class PrismaClient:
|
||||||
elif query_type == "find_all":
|
elif query_type == "find_all":
|
||||||
response = await self.db.litellm_usernotifications.find_many() # type: ignore
|
response = await self.db.litellm_usernotifications.find_many() # type: ignore
|
||||||
return response
|
return response
|
||||||
|
elif table_name == "combined_view":
|
||||||
|
# check if plain text or hash
|
||||||
|
if token is not None:
|
||||||
|
if isinstance(token, str):
|
||||||
|
hashed_token = token
|
||||||
|
if token.startswith("sk-"):
|
||||||
|
hashed_token = self.hash_token(token=token)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||||
|
)
|
||||||
|
if query_type == "find_unique":
|
||||||
|
if token is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": f"No token passed in. Token={token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
sql_query = f"""
|
||||||
|
SELECT *
|
||||||
|
FROM "LiteLLM_VerificationTokenView"
|
||||||
|
WHERE token = '{token}'
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = await self.db.query_first(query=sql_query)
|
||||||
|
if response is not None:
|
||||||
|
response = LiteLLM_VerificationTokenView(**response)
|
||||||
|
# for prisma we need to cast the expires time to str
|
||||||
|
if response.expires is not None and isinstance(
|
||||||
|
response.expires, datetime
|
||||||
|
):
|
||||||
|
response.expires = response.expires.isoformat()
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue