Merge pull request #3785 from BerriAI/litellm_end_user_rate_limits

[Feat] LiteLLM Proxy: Enforce End-User TPM, RPM Limits
This commit is contained in:
Ishaan Jaff 2024-05-22 17:12:58 -07:00 committed by GitHub
commit 31fc6d79af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 292 additions and 21 deletions

View file

@ -374,6 +374,62 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
} }
``` ```
</TabItem>
<TabItem value="per-end-user" label="For End User">
Use this to set rate limits for `user` passed to `/chat/completions`, without needing to create a key for every user
#### Step 1. Create Budget
Set a `tpm_limit` on the budget (You can also pass `rpm_limit` if needed)
```shell
curl --location 'http://0.0.0.0:4000/budget/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"budget_id" : "free-tier",
"tpm_limit": 5
}'
```
#### Step 2. Create `End-User` with Budget
We use `budget_id="free-tier"` from Step 1 when creating this new end user
```shell
curl --location 'http://0.0.0.0:4000/end_user/new' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"user_id" : "palantir",
"budget_id": "free-tier"
}'
```
#### Step 3. Pass end user id in `/chat/completions` requests
Pass the `user_id` from Step 2 as `user="palantir"`
```shell
curl --location 'http://localhost:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
"model": "llama3",
"user": "palantir",
"messages": [
{
"role": "user",
"content": "gm"
}
]
}'
```
</TabItem> </TabItem>
</Tabs> </Tabs>

View file

@ -938,6 +938,11 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
soft_budget: Optional[float] = None soft_budget: Optional[float] = None
team_model_aliases: Optional[Dict] = None team_model_aliases: Optional[Dict] = None
# End User Params
end_user_id: Optional[str] = None
end_user_tpm_limit: Optional[int] = None
end_user_rpm_limit: Optional[int] = None
class UserAPIKeyAuth( class UserAPIKeyAuth(
LiteLLM_VerificationTokenView LiteLLM_VerificationTokenView

View file

@ -219,7 +219,8 @@ async def get_end_user_object(
# else, check db # else, check db
try: try:
response = await prisma_client.db.litellm_endusertable.find_unique( response = await prisma_client.db.litellm_endusertable.find_unique(
where={"user_id": end_user_id} where={"user_id": end_user_id},
include={"litellm_budget_table": True},
) )
if response is None: if response is None:

View file

@ -64,7 +64,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
cache.set_cache(request_count_api_key, new_val) cache.set_cache(request_count_api_key, new_val)
else: else:
raise HTTPException( raise HTTPException(
status_code=429, detail="Max parallel request limit reached." status_code=429,
detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}",
) )
async def async_pre_call_hook( async def async_pre_call_hook(
@ -223,6 +224,38 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
rpm_limit=team_rpm_limit, rpm_limit=team_rpm_limit,
) )
# End-User Rate Limits
# Only enforce if user passed `user` to /chat, /completions, /embeddings
if user_api_key_dict.end_user_id:
end_user_tpm_limit = getattr(
user_api_key_dict, "end_user_tpm_limit", sys.maxsize
)
end_user_rpm_limit = getattr(
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
)
if end_user_tpm_limit is None:
end_user_tpm_limit = sys.maxsize
if end_user_rpm_limit is None:
end_user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = (
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
)
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
request_count_api_key=request_count_api_key,
tpm_limit=end_user_tpm_limit,
rpm_limit=end_user_rpm_limit,
)
return return
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -238,6 +271,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_team_id", None "user_api_key_team_id", None
) )
user_api_key_end_user_id = kwargs.get("user")
if self.user_api_key_cache is None: if self.user_api_key_cache is None:
return return
@ -362,6 +396,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
request_count_api_key, new_val, ttl=60 request_count_api_key, new_val, ttl=60
) # store in cache for 1 min. ) # store in cache for 1 min.
# ------------
# Update usage - End User
# ------------
if user_api_key_end_user_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens
request_count_api_key = (
f"{user_api_key_end_user_id}::{precise_minute}::request_count"
)
current = self.user_api_key_cache.get_cache(
key=request_count_api_key
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"] + 1,
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
self.user_api_key_cache.set_cache(
request_count_api_key, new_val, ttl=60
) # store in cache for 1 min.
except Exception as e: except Exception as e:
self.print_verbose(e) # noqa self.print_verbose(e) # noqa

View file

@ -655,19 +655,6 @@ async def user_api_key_auth(
detail="'allow_user_auth' not set or set to False", detail="'allow_user_auth' not set or set to False",
) )
### CHECK IF ADMIN ###
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
### CHECK IF ADMIN ###
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
## Check CACHE
valid_token = user_api_key_cache.get_cache(key=hash_token(api_key))
if (
valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth)
and valid_token.user_role == "proxy_admin"
):
return valid_token
## Check END-USER OBJECT ## Check END-USER OBJECT
request_data = await _read_request_body(request=request) request_data = await _read_request_body(request=request)
_end_user_object = None _end_user_object = None
@ -683,12 +670,44 @@ async def user_api_key_auth(
end_user_params["allowed_model_region"] = ( end_user_params["allowed_model_region"] = (
_end_user_object.allowed_model_region _end_user_object.allowed_model_region
) )
if _end_user_object.litellm_budget_table is not None:
budget_info = _end_user_object.litellm_budget_table
end_user_params["end_user_id"] = _end_user_object.user_id
if budget_info.tpm_limit is not None:
end_user_params["end_user_tpm_limit"] = (
budget_info.tpm_limit
)
if budget_info.rpm_limit is not None:
end_user_params["end_user_rpm_limit"] = (
budget_info.rpm_limit
)
except Exception as e: except Exception as e:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"Unable to find user in db. Error - {}".format(str(e)) "Unable to find user in db. Error - {}".format(str(e))
) )
pass pass
### CHECK IF ADMIN ###
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
### CHECK IF ADMIN ###
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
## Check CACHE
valid_token = user_api_key_cache.get_cache(key=hash_token(api_key))
if (
valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth)
and valid_token.user_role == "proxy_admin"
):
# update end-user params on valid token
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
return valid_token
try: try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
except Exception as e: except Exception as e:
@ -761,9 +780,18 @@ async def user_api_key_auth(
key=original_api_key, table_name="key" key=original_api_key, table_name="key"
) )
verbose_proxy_logger.debug("Token from db: %s", valid_token) verbose_proxy_logger.debug("Token from db: %s", valid_token)
elif valid_token is not None: elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
verbose_proxy_logger.debug("API Key Cache Hit!") verbose_proxy_logger.debug("API Key Cache Hit!")
# update end-user params on valid token
# These can change per request - it's important to update them here
valid_token.end_user_id = end_user_params.get("end_user_id")
valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit")
valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit")
valid_token.allowed_model_region = end_user_params.get(
"allowed_model_region"
)
user_id_information = None user_id_information = None
if valid_token: if valid_token:
# Got Valid Token from Cache, DB # Got Valid Token from Cache, DB
@ -1148,10 +1176,7 @@ async def user_api_key_auth(
valid_token_dict.pop("token", None) valid_token_dict.pop("token", None)
if _end_user_object is not None: if _end_user_object is not None:
valid_token_dict["allowed_model_region"] = ( valid_token_dict.update(end_user_params)
_end_user_object.allowed_model_region
)
""" """
asyncio create task to update the user api key cache with the user db table as well asyncio create task to update the user api key cache with the user db table as well

View file

@ -99,7 +99,12 @@ async def generate_key(
async def new_end_user( async def new_end_user(
session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None session,
i,
user_id=str(uuid.uuid4()),
model_region=None,
default_model=None,
budget_id=None,
): ):
url = "http://0.0.0.0:4000/end_user/new" url = "http://0.0.0.0:4000/end_user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@ -109,6 +114,10 @@ async def new_end_user(
"default_model": default_model, "default_model": default_model,
} }
if budget_id is not None:
data["budget_id"] = budget_id
print("end user data: {}".format(data))
async with session.post(url, headers=headers, json=data) as response: async with session.post(url, headers=headers, json=data) as response:
status = response.status status = response.status
response_text = await response.text() response_text = await response.text()
@ -123,6 +132,23 @@ async def new_end_user(
return await response.json() return await response.json()
async def new_budget(session, i, budget_id=None):
    """
    Create a budget via `POST /budget/new` with a hard-coded `tpm_limit` of 2.

    Args:
        session: an open aiohttp ClientSession.
        i: request index, used only for log output.
        budget_id: id to assign to the new budget (None lets the server decide).

    Returns:
        The parsed JSON response body (consistent with `new_end_user`, which
        also returns the parsed response; previously this helper discarded it).
    """
    url = "http://0.0.0.0:4000/budget/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {
        "budget_id": budget_id,
        "tpm_limit": 2,
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(f"Response {i} (Status code: {status}):")
        print(response_text)
        print()

        # body is cached by aiohttp after .text(), so .json() re-parses it
        return await response.json()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_end_user_new(): async def test_end_user_new():
""" """
@ -170,3 +196,93 @@ async def test_end_user_specific_region():
) )
assert result.headers.get("x-litellm-model-region") == "eu" assert result.headers.get("x-litellm-model-region") == "eu"
@pytest.mark.asyncio
async def test_enduser_tpm_limits_non_master_key():
    """
    End-user TPM limits are enforced for requests made with a regular (virtual) key.

    1. Create a budget with tpm_limit = 2 (hard-coded in `new_budget`)
    2. Create an end_user attached to that budget
    3. Generate a virtual key and make 10 /chat/completions calls with user=<end_user_id>
    4. Expect most calls to fail with a rate-limit error once the end-user tpm limit is hit
    """
    async with aiohttp.ClientSession() as session:
        # unique budget id so repeated test runs don't collide
        budget_id = f"free-tier-{uuid.uuid4()}"
        await new_budget(session, 0, budget_id=budget_id)

        # give the proxy a moment to pick up the newly-created budget
        await asyncio.sleep(2)

        end_user_id = str(uuid.uuid4())
        await new_end_user(
            session=session, i=0, user_id=end_user_id, budget_id=budget_id
        )

        ## MAKE CALLS ##
        key_gen = await generate_key(session=session, i=0, models=[])
        key = key_gen["key"]

        client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")

        passed = 0
        for _ in range(10):
            try:
                await client.chat.completions.create(
                    model="fake-openai-endpoint",
                    messages=[{"role": "user", "content": "Hey!"}],
                    user=end_user_id,
                )
                passed += 1
            except Exception:
                # expected once the end-user tpm limit (2) is crossed; a bare
                # `except:` here would also swallow KeyboardInterrupt/SystemExit
                pass

        print("Passed requests=", passed)

        assert (
            passed < 5
        ), f"Sent 10 requests and end-user has tpm_limit of 2. Number requests passed: {passed}. Expected less than 5 to pass"
@pytest.mark.asyncio
async def test_enduser_tpm_limits_with_master_key():
    """
    End-user TPM limits are enforced even when requests use the master key.

    1. Create a budget with tpm_limit = 2 (hard-coded in `new_budget`)
    2. Create an end_user attached to that budget
    3. Make 10 /chat/completions calls with the master key and user=<end_user_id>
    4. Expect most calls to fail with a rate-limit error once the end-user tpm limit is hit
    """
    async with aiohttp.ClientSession() as session:
        # unique budget id so repeated test runs don't collide
        budget_id = f"free-tier-{uuid.uuid4()}"
        await new_budget(session, 0, budget_id=budget_id)

        # give the proxy a moment to pick up the newly-created budget
        # (matches the non-master-key variant of this test)
        await asyncio.sleep(2)

        end_user_id = str(uuid.uuid4())
        await new_end_user(
            session=session, i=0, user_id=end_user_id, budget_id=budget_id
        )

        client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

        passed = 0
        for _ in range(10):
            try:
                await client.chat.completions.create(
                    model="fake-openai-endpoint",
                    messages=[{"role": "user", "content": "Hey!"}],
                    user=end_user_id,
                )
                passed += 1
            except Exception:
                # expected once the end-user tpm limit (2) is crossed; a bare
                # `except:` here would also swallow KeyboardInterrupt/SystemExit
                pass

        print("Passed requests=", passed)

        assert (
            passed < 5
        ), f"Sent 10 requests and end-user has tpm_limit of 2. Number requests passed: {passed}. Expected less than 5 to pass"