Merge pull request #2208 from BerriAI/litellm_enforce_team_limits

Litellm enforce team limits
Krish Dholakia 2024-02-26 23:10:01 -08:00 committed by GitHub
commit 365e7ed5b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 302 additions and 47 deletions

View file

@@ -458,8 +458,19 @@ class LiteLLM_VerificationToken(LiteLLMBase):
         protected_namespaces = ()


+class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
+    """
+    Combined view of litellm verification token + litellm team table (select values)
+    """
+
+    team_spend: Optional[float] = None
+    team_tpm_limit: Optional[int] = None
+    team_rpm_limit: Optional[int] = None
+    team_max_budget: Optional[float] = None
+
+
 class UserAPIKeyAuth(
-    LiteLLM_VerificationToken
+    LiteLLM_VerificationTokenView
 ):  # the expected response object for user api key auth
     """
     Return the row in the db
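
Editor's sketch (not part of the diff): a minimal, standalone pydantic model mirroring the team_* fields the view adds, assuming pydantic-style optional fields; the class name and values here are hypothetical.

# Minimal sketch, assuming pydantic; a stand-in, not the proxy's real classes.
from typing import Optional
from pydantic import BaseModel


class VerificationTokenViewSketch(BaseModel):
    token: Optional[str] = None
    spend: float = 0.0
    # columns joined in from the team table by the SQL view
    team_spend: Optional[float] = None
    team_tpm_limit: Optional[int] = None
    team_rpm_limit: Optional[int] = None
    team_max_budget: Optional[float] = None


row = VerificationTokenViewSketch(token="hashed-token", spend=1.5, team_max_budget=10.0)
print(row.team_rpm_limit)  # None when the key has no team; an int once the join finds one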

View file

@@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         current = cache.get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
-        # print(f"current: {current}")
         if current is None:
             new_val = {
                 "current_requests": 1,
@@ -73,8 +72,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
-        tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
-        rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
+        tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
+        if tpm_limit is None:
+            tpm_limit = sys.maxsize
+        rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
+        if rpm_limit is None:
+            rpm_limit = sys.maxsize

         if api_key is None:
             return
@@ -131,10 +134,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            _user_id_rate_limits = user_api_key_dict.user_id_rate_limits

            # get user tpm/rpm limits
-           if _user_id_rate_limits is None or _user_id_rate_limits == {}:
-               return
-           user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
-           user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
+           if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
+               user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
+               user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
            if user_tpm_limit is None:
                user_tpm_limit = sys.maxsize
            if user_rpm_limit is None:
@@ -154,6 +156,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                tpm_limit=user_tpm_limit,
                rpm_limit=user_rpm_limit,
            )

+       # TEAM RATE LIMITS
+       ## get team tpm/rpm limits
+       team_id = user_api_key_dict.team_id
+       team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
+       if team_tpm_limit is None:
+           team_tpm_limit = sys.maxsize
+       team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
+       if team_rpm_limit is None:
+           team_rpm_limit = sys.maxsize
+
+       if team_tpm_limit is None:
+           team_tpm_limit = sys.maxsize
+       if team_rpm_limit is None:
+           team_rpm_limit = sys.maxsize
+
+       # now do the same tpm/rpm checks
+       request_count_api_key = f"{team_id}::{precise_minute}::request_count"
+
+       # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+       await self.check_key_in_limits(
+           user_api_key_dict=user_api_key_dict,
+           cache=cache,
+           data=data,
+           call_type=call_type,
+           max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
+           request_count_api_key=request_count_api_key,
+           tpm_limit=team_tpm_limit,
+           rpm_limit=team_rpm_limit,
+       )
+
        return
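
Editor's sketch (not part of the diff) of the per-minute counter key the team check reuses; the exact precise_minute format is an assumption, only the "{id}::{minute}::request_count" pattern comes from the hunk above.

# Illustrative only: how a team's per-minute request-count key could be built.
from datetime import datetime

now = datetime.now()
precise_minute = f"{now.strftime('%Y-%m-%d')}-{now.strftime('%H')}-{now.strftime('%M')}"

team_id = "unique-team-id"
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
print(request_count_api_key)
# the hook stores {"current_requests": ..., "current_tpm": ..., "current_rpm": ...} under this key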
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@@ -163,6 +195,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
                "user_api_key_user_id", None
            )
+           user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
+               "user_api_key_team_id", None
+           )

            if user_api_key is None:
                return
@@ -212,7 +247,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
            # ------------
            # Update usage - User
            # ------------
-           if user_api_key_user_id is None:
+           if user_api_key_user_id is not None:
+               total_tokens = 0
+
+               if isinstance(response_obj, ModelResponse):
+                   total_tokens = response_obj.usage.total_tokens
+
+               request_count_api_key = (
+                   f"{user_api_key_user_id}::{precise_minute}::request_count"
+               )
+               current = self.user_api_key_cache.get_cache(
+                   key=request_count_api_key
+               ) or {
+                   "current_requests": 1,
+                   "current_tpm": total_tokens,
+                   "current_rpm": 1,
+               }
+
+               new_val = {
+                   "current_requests": max(current["current_requests"] - 1, 0),
+                   "current_tpm": current["current_tpm"] + total_tokens,
+                   "current_rpm": current["current_rpm"] + 1,
+               }
+
+               self.print_verbose(
+                   f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+               )
+               self.user_api_key_cache.set_cache(
+                   request_count_api_key, new_val, ttl=60
+               )  # store in cache for 1 min.
+
+           # ------------
+           # Update usage - Team
+           # ------------
+           if user_api_key_team_id is None:
                return

            total_tokens = 0
@@ -221,7 +290,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
                total_tokens = response_obj.usage.total_tokens

            request_count_api_key = (
-               f"{user_api_key_user_id}::{precise_minute}::request_count"
+               f"{user_api_key_team_id}::{precise_minute}::request_count"
            )

            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {

View file

@@ -356,7 +356,7 @@ async def user_api_key_auth(
        verbose_proxy_logger.debug(f"api key: {api_key}")
        if prisma_client is not None:
            valid_token = await prisma_client.get_data(
-               token=api_key,
+               token=api_key, table_name="combined_view"
            )

        elif custom_db_client is not None:
@@ -381,7 +381,8 @@
            # 4. If token is expired
            # 5. If token spend is under Budget for the token
            # 6. If token spend per model is under budget per model
+           # 7. If token spend is under team budget
+           # 8. If team spend is under team budget
            request_data = await _read_request_body(
                request=request
            )  # request data, used across all checks. Making this easily available
@@ -610,6 +611,47 @@
                        f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
                    )

+           # Check 6. Token spend is under Team budget
+           if (
+               valid_token.spend is not None
+               and hasattr(valid_token, "team_max_budget")
+               and valid_token.team_max_budget is not None
+           ):
+               asyncio.create_task(
+                   proxy_logging_obj.budget_alerts(
+                       user_max_budget=valid_token.team_max_budget,
+                       user_current_spend=valid_token.spend,
+                       type="token_budget",
+                       user_info=valid_token,
+                   )
+               )
+
+               if valid_token.spend >= valid_token.team_max_budget:
+                   raise Exception(
+                       f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
+                   )
+
+           # Check 7. Team spend is under Team budget
+           if (
+               hasattr(valid_token, "team_spend")
+               and valid_token.team_spend is not None
+               and hasattr(valid_token, "team_max_budget")
+               and valid_token.team_max_budget is not None
+           ):
+               asyncio.create_task(
+                   proxy_logging_obj.budget_alerts(
+                       user_max_budget=valid_token.team_max_budget,
+                       user_current_spend=valid_token.team_spend,
+                       type="token_budget",
+                       user_info=valid_token,
+                   )
+               )
+
+               if valid_token.team_spend >= valid_token.team_max_budget:
+                   raise Exception(
+                       f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
+                   )
+
            # Token passed all checks
            api_key = valid_token.token
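
Editor's illustration (not part of the diff) of the two new comparisons with hypothetical spend figures; the real code also fires budget alerts asynchronously, as shown above.

# Toy numbers to walk through checks 6 and 7.
token_spend = 4.0       # valid_token.spend
team_spend = 9.5        # valid_token.team_spend, from the combined view
team_max_budget = 10.0  # valid_token.team_max_budget

# Check 6: this key's own spend must stay below the team budget.
assert token_spend < team_max_budget

# Check 7: the team's aggregate spend must also stay below the team budget,
# so every key on the team is rejected once team_spend reaches 10.0.
assert team_spend < team_max_budget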
@@ -1870,14 +1912,6 @@ async def generate_key_helper_fn(
            saved_token["expires"], datetime
        ):
            saved_token["expires"] = saved_token["expires"].isoformat()
-       if key_data["token"] is not None and isinstance(key_data["token"], str):
-           hashed_token = hash_token(key_data["token"])
-           saved_token["token"] = hashed_token
-           user_api_key_cache.set_cache(
-               key=hashed_token,
-               value=LiteLLM_VerificationToken(**saved_token),  # type: ignore
-               ttl=600,
-           )
        if prisma_client is not None:
            ## CREATE USER (If necessary)
            verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
@@ -2263,6 +2297,11 @@ async def startup_event():
            duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
        )

+   ### CHECK IF VIEW EXISTS ###
+   if prisma_client is not None:
+       create_view_response = await prisma_client.check_view_exists()
+       print(f"create_view_response: {create_view_response}")  # noqa
+
    ### START BUDGET SCHEDULER ###
    if prisma_client is not None:
        scheduler = AsyncIOScheduler()

View file

@@ -5,6 +5,7 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
     DynamoDBArgs,
     LiteLLM_VerificationToken,
+    LiteLLM_VerificationTokenView,
     LiteLLM_SpendLogs,
 )
 from litellm.caching import DualCache
@@ -479,6 +480,46 @@ class PrismaClient:
                    db_data[k] = json.dumps(v)
        return db_data

+   @backoff.on_exception(
+       backoff.expo,
+       Exception,  # base exception to catch for the backoff
+       max_tries=3,  # maximum number of retries
+       max_time=10,  # maximum total time to retry for
+       on_backoff=on_backoff,  # specifying the function to call on backoff
+   )
+   async def check_view_exists(self):
+       """
+       Checks if the LiteLLM_VerificationTokenView exists in the user's db.
+
+       This is used for getting the token + team data in user_api_key_auth
+
+       If the view doesn't exist, one will be created.
+       """
+       try:
+           # Try to select one row from the view
+           await self.db.execute_raw(
+               """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
+           )
+           return "LiteLLM_VerificationTokenView Exists!"
+       except Exception as e:
+           # If an error occurs, the view does not exist, so create it
+           value = await self.health_check()
+           await self.db.execute_raw(
+               """
+               CREATE VIEW "LiteLLM_VerificationTokenView" AS
+               SELECT
+                   v.*,
+                   t.spend AS team_spend,
+                   t.max_budget AS team_max_budget,
+                   t.tpm_limit AS team_tpm_limit,
+                   t.rpm_limit AS team_rpm_limit
+               FROM "LiteLLM_VerificationToken" v
+               LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
+               """
+           )
+
+           return "LiteLLM_VerificationTokenView Created!"
+
    @backoff.on_exception(
        backoff.expo,
        Exception,  # base exception to catch for the backoff
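
Editor's toy (not the Prisma/Postgres schema): a self-contained sqlite3 sketch of the shape of data a LEFT JOIN view like the one above produces, where team columns come back as NULL for keys with no team.

# Illustrative only; table and column names here are simplified stand-ins.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE tokens (token TEXT, team_id TEXT, spend REAL);
    CREATE TABLE teams (team_id TEXT, spend REAL, max_budget REAL, tpm_limit INT, rpm_limit INT);
    CREATE VIEW token_view AS
        SELECT v.*, t.spend AS team_spend, t.max_budget AS team_max_budget,
               t.tpm_limit AS team_tpm_limit, t.rpm_limit AS team_rpm_limit
        FROM tokens v LEFT JOIN teams t ON v.team_id = t.team_id;
    INSERT INTO tokens VALUES ('hashed-1', 'team-a', 1.0), ('hashed-2', NULL, 0.0);
    INSERT INTO teams VALUES ('team-a', 9.5, 10.0, 1000, 60);
    """
)
for row in conn.execute("SELECT token, team_max_budget, team_rpm_limit FROM token_view"):
    print(row)  # ('hashed-1', 10.0, 60) and ('hashed-2', None, None)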
@@ -535,7 +576,15 @@
        team_id_list: Optional[list] = None,
        key_val: Optional[dict] = None,
        table_name: Optional[
-           Literal["user", "key", "config", "spend", "team", "user_notification"]
+           Literal[
+               "user",
+               "key",
+               "config",
+               "spend",
+               "team",
+               "user_notification",
+               "combined_view",
+           ]
        ] = None,
        query_type: Literal["find_unique", "find_all"] = "find_unique",
        expires: Optional[datetime] = None,
@@ -543,7 +592,9 @@
    ):
        try:
            response: Any = None
-           if token is not None or (table_name is not None and table_name == "key"):
+           if (token is not None and table_name is None) or (
+               table_name is not None and table_name == "key"
+           ):
                # check if plain text or hash
                if token is not None:
                    if isinstance(token, str):
@@ -723,6 +774,38 @@
                elif query_type == "find_all":
                    response = await self.db.litellm_usernotifications.find_many()  # type: ignore
                return response
+           elif table_name == "combined_view":
+               # check if plain text or hash
+               if token is not None:
+                   if isinstance(token, str):
+                       hashed_token = token
+                       if token.startswith("sk-"):
+                           hashed_token = self.hash_token(token=token)
+                       verbose_proxy_logger.debug(
+                           f"PrismaClient: find_unique for token: {hashed_token}"
+                       )
+               if query_type == "find_unique":
+                   if token is None:
+                       raise HTTPException(
+                           status_code=400,
+                           detail={"error": f"No token passed in. Token={token}"},
+                       )
+
+                   sql_query = f"""
+                       SELECT *
+                       FROM "LiteLLM_VerificationTokenView"
+                       WHERE token = '{token}'
+                   """
+
+                   response = await self.db.query_first(query=sql_query)
+
+                   if response is not None:
+                       response = LiteLLM_VerificationTokenView(**response)
+                       # for prisma we need to cast the expires time to str
+                       if response.expires is not None and isinstance(
+                           response.expires, datetime
+                       ):
+                           response.expires = response.expires.isoformat()
+               return response
        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
            import traceback
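
Editor's sketch (not part of the diff) of the plain-text vs hashed handling seen above; the SHA-256 hexdigest is an assumption about hash_token, not taken from this diff.

# Standalone sketch; assumes hash_token is a SHA-256 hexdigest of the key string.
import hashlib


def hash_token(token: str) -> str:
    return hashlib.sha256(token.encode()).hexdigest()


api_key = "sk-12345"
hashed_token = api_key
if api_key.startswith("sk-"):
    hashed_token = hash_token(api_key)
print(hashed_token)  # raw "sk-" keys are hashed before being compared against stored tokens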

View file

@@ -483,9 +483,12 @@ def test_redis_cache_completion_stream():
            max_tokens=40,
            temperature=0.2,
            stream=True,
+           caching=True,
        )
        response_1_content = ""
+       response_1_id = None
        for chunk in response1:
+           response_1_id = chunk.id
            print(chunk)
            response_1_content += chunk.choices[0].delta.content or ""
        print(response_1_content)
@@ -497,16 +500,22 @@
            max_tokens=40,
            temperature=0.2,
            stream=True,
+           caching=True,
        )
        response_2_content = ""
+       response_2_id = None
        for chunk in response2:
+           response_2_id = chunk.id
            print(chunk)
            response_2_content += chunk.choices[0].delta.content or ""
        print("\nresponse 1", response_1_content)
        print("\nresponse 2", response_2_content)
        assert (
-           response_1_content == response_2_content
+           response_1_id == response_2_id
        ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
+       # assert (
+       #     response_1_content == response_2_content
+       # ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
        litellm.success_callback = []
        litellm._async_success_callback = []
        litellm.cache = None

View file

@@ -124,25 +124,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
        bearer_token = "Bearer " + generated_key

        assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict

-       assert (
-           hash_token(generated_key)
-           in user_api_key_cache.in_memory_cache.cache_dict
-       )
-
-       cached_value = user_api_key_cache.in_memory_cache.cache_dict[
-           hash_token(generated_key)
-       ]
-
-       print("cached value=", cached_value)
-       print("cached token", cached_value.token)
-
-       value_from_prisma = valid_token = await prisma_client.get_data(
+       value_from_prisma = await prisma_client.get_data(
            token=generated_key,
        )
        print("token from prisma", value_from_prisma)

-       assert value_from_prisma.token == cached_value.token
-
        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")
@@ -1241,7 +1228,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
        # use generated key to auth in
        result = await user_api_key_auth(request=request, api_key=bearer_token)
-       print("result from user auth with new key", result)
+       print("result from user auth with new key: {result}")

        # update spend using track_cost callback, make 2nd request, it should fail
        from litellm.proxy.proxy_server import (
@@ -1312,7 +1299,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
        generated_key = key.key
        user_id = key.user_id
        bearer_token = "Bearer " + generated_key
+       print(f"generated_key: {generated_key}")

        request = Request(scope={"type": "http"})
        request._url = URL(url="/chat/completions")

View file

@@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
        assert e.status_code == 429


+@pytest.mark.asyncio
+async def test_pre_call_hook_team_rpm_limits():
+    """
+    Test if error raised on hitting team rpm limits
+    """
+    litellm.set_verbose = True
+    _api_key = "sk-12345"
+    _team_id = "unique-team-id"
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=1,
+        tpm_limit=9,
+        rpm_limit=10,
+        team_rpm_limit=1,
+        team_id=_team_id,
+    )
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler()
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    kwargs = {
+        "litellm_params": {
+            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
+        }
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    print(f"local_cache: {local_cache}")
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+
+
 @pytest.mark.asyncio
 async def test_pre_call_hook_tpm_limits():
     """

View file

@@ -1183,7 +1183,7 @@ class Logging:
        verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
        ## BUILD COMPLETE STREAMED RESPONSE
        complete_streaming_response = None
-       if self.stream:
+       if self.stream and isinstance(result, ModelResponse):
            if (
                result.choices[0].finish_reason is not None
            ):  # if it's the last chunk
@@ -8682,6 +8682,8 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
+               if hasattr(chunk, "id"):
+                   model_response.id = chunk.id
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj[
                        "finish_reason"
@@ -8704,6 +8706,8 @@
                    model_response.system_fingerprint = getattr(
                        response_obj["original_chunk"], "system_fingerprint", None
                    )
+                   if hasattr(response_obj["original_chunk"], "id"):
+                       model_response.id = response_obj["original_chunk"].id
                    if response_obj["logprobs"] is not None:
                        model_response.choices[0].logprobs = response_obj["logprobs"]
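
Editor's note: the two id-propagation hunks above are what the updated caching test relies on; a cached replay of a stream keeps the original response id, so comparing response_1_id and response_2_id is stable even when streamed text differs slightly. A toy illustration with stand-in objects (not litellm classes):

# Copying the upstream chunk id onto the wrapper's response means the cached
# replay reports the same id as the original call.
class ChunkStub:
    def __init__(self, id):
        self.id = id


class ModelResponseStub:
    id = "chatcmpl-default"


model_response = ModelResponseStub()
chunk = ChunkStub(id="chatcmpl-abc123")

if hasattr(chunk, "id"):
    model_response.id = chunk.id

print(model_response.id)  # chatcmpl-abc123 on the first call and on the cache hit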