From eac7e70dcaf26a7484f8374c970cb665c6992efd Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 13:30:08 -0700 Subject: [PATCH 1/8] feat - include litellm_budget table when getting end_user --- litellm/proxy/auth/auth_checks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 08da25556..fce6d4254 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -219,7 +219,8 @@ async def get_end_user_object( # else, check db try: response = await prisma_client.db.litellm_endusertable.find_unique( - where={"user_id": end_user_id} + where={"user_id": end_user_id}, + include={"litellm_budget_table": True}, ) if response is None: From 106910cecf3e839c78bd6310468c3aa29001469b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 14:01:57 -0700 Subject: [PATCH 2/8] feat - add end user rate limiting --- litellm/proxy/_types.py | 5 ++ .../proxy/hooks/parallel_request_limiter.py | 70 ++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8bfa56004..ce38f4f9e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -938,6 +938,11 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None + # End User Params + end_user_id: Optional[str] = None + end_user_tpm_limit: Optional[int] = None + end_user_rpm_limit: Optional[int] = None + class UserAPIKeyAuth( LiteLLM_VerificationTokenView diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index 26238b6c0..0558cdf05 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -64,7 +64,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): cache.set_cache(request_count_api_key, new_val) else: 
raise HTTPException( - status_code=429, detail="Max parallel request limit reached." + status_code=429, + detail=f"LiteLLM Rate Limit Handler: Crossed TPM, RPM Limit. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}", ) async def async_pre_call_hook( @@ -223,6 +224,38 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): rpm_limit=team_rpm_limit, ) + # End-User Rate Limits + # Only enforce if user passed `user` to /chat, /completions, /embeddings + if user_api_key_dict.end_user_id: + end_user_tpm_limit = getattr( + user_api_key_dict, "end_user_tpm_limit", sys.maxsize + ) + end_user_rpm_limit = getattr( + user_api_key_dict, "end_user_rpm_limit", sys.maxsize + ) + + if end_user_tpm_limit is None: + end_user_tpm_limit = sys.maxsize + if end_user_rpm_limit is None: + end_user_rpm_limit = sys.maxsize + + # now do the same tpm/rpm checks + request_count_api_key = ( + f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count" + ) + + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User + request_count_api_key=request_count_api_key, + tpm_limit=end_user_tpm_limit, + rpm_limit=end_user_rpm_limit, + ) + return async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -238,6 +271,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_team_id", None ) + user_api_key_end_user_id = kwargs.get("user") if self.user_api_key_cache is None: return @@ -362,6 +396,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): request_count_api_key, new_val, ttl=60 ) # store in cache for 1 min. 
+ # ------------ + # Update usage - End User + # ------------ + if user_api_key_end_user_id is not None: + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens + + request_count_api_key = ( + f"{user_api_key_end_user_id}::{precise_minute}::request_count" + ) + + current = self.user_api_key_cache.get_cache( + key=request_count_api_key + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. + except Exception as e: self.print_verbose(e) # noqa From bef10f4b01acbb4c9eb61f05b218afceacd27d06 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 15:42:41 -0700 Subject: [PATCH 3/8] test - end user tpm / rpm limiting --- tests/test_end_users.py | 86 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/tests/test_end_users.py b/tests/test_end_users.py index 83bffdc12..9c8b59753 100644 --- a/tests/test_end_users.py +++ b/tests/test_end_users.py @@ -99,7 +99,12 @@ async def generate_key( async def new_end_user( - session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None + session, + i, + user_id=str(uuid.uuid4()), + model_region=None, + default_model=None, + budget_id=None, ): url = "http://0.0.0.0:4000/end_user/new" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} @@ -109,6 +114,10 @@ async def new_end_user( "default_model": default_model, } + if budget_id is not None: + data["budget_id"] = budget_id + print("end user data: {}".format(data)) + async with session.post(url, 
headers=headers, json=data) as response: status = response.status response_text = await response.text() @@ -123,6 +132,23 @@ async def new_end_user( return await response.json() +async def new_budget(session, i, budget_id=None): + url = "http://0.0.0.0:4000/budget/new" + headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + data = { + "budget_id": budget_id, + "tpm_limit": 2, + } + + async with session.post(url, headers=headers, json=data) as response: + status = response.status + response_text = await response.text() + + print(f"Response {i} (Status code: {status}):") + print(response_text) + print() + + @pytest.mark.asyncio async def test_end_user_new(): """ @@ -170,3 +196,61 @@ async def test_end_user_specific_region(): ) assert result.headers.get("x-litellm-model-region") == "eu" + + +@pytest.mark.asyncio +async def test_end_tpm_limits(): + """ + 1. budget_id = Create Budget with tpm_limit = 10 + 2. create end_user with budget_id + 3. Make /chat/completions calls + 4. Sleep 1 second + 4. 
Make /chat/completions call -> expect this to fail because rate limit hit + """ + async with aiohttp.ClientSession() as session: + # create a budget with budget_id = "free-tier" + budget_id = f"free-tier-{uuid.uuid4()}" + await new_budget(session, 0, budget_id=budget_id) + await asyncio.sleep(2) + + end_user_id = str(uuid.uuid4()) + + await new_end_user( + session=session, i=0, user_id=end_user_id, budget_id=budget_id + ) + + ## MAKE CALL ## + key_gen = await generate_key(session=session, i=0, models=[]) + + key = key_gen["key"] + + # chat completion 1 + client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000") + + result = await client.chat.completions.create( + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hey!"}], + user=end_user_id, + ) + + print("\nchat completion result 1=", result) + + await asyncio.sleep(1) + + # chat completion 2 + try: + result = await client.chat.completions.create( + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hey!"}], + user=end_user_id, + ) + pytest.fail( + "User crossed their limit - this should have failed. 
instead got result = {}".format( + result + ) + ) + except Exception as e: + print("got exception 2 =", e) + assert "Crossed TPM, RPM Limit" in str( + e + ), f"Expected 'Crossed TPM, RPM Limit' but got {str(e)}" From e6b406d739e5d7fdf8554a17f418b31e8f789c42 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 15:45:30 -0700 Subject: [PATCH 4/8] feat - enforce end user tpm / rpm limits --- litellm/proxy/proxy_server.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f157f420c..80b7c84d5 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -683,6 +683,17 @@ async def user_api_key_auth( end_user_params["allowed_model_region"] = ( _end_user_object.allowed_model_region ) + if _end_user_object.litellm_budget_table is not None: + budget_info = _end_user_object.litellm_budget_table + end_user_params["end_user_id"] = _end_user_object.user_id + if budget_info.tpm_limit is not None: + end_user_params["end_user_tpm_limit"] = ( + budget_info.tpm_limit + ) + if budget_info.rpm_limit is not None: + end_user_params["end_user_rpm_limit"] = ( + budget_info.rpm_limit + ) except Exception as e: verbose_proxy_logger.debug( "Unable to find user in db. 
Error - {}".format(str(e)) @@ -1148,10 +1159,7 @@ async def user_api_key_auth( valid_token_dict.pop("token", None) if _end_user_object is not None: - valid_token_dict["allowed_model_region"] = ( - _end_user_object.allowed_model_region - ) - + valid_token_dict.update(end_user_params) """ asyncio create task to update the user api key cache with the user db table as well From 42078ac285783bf559355ed4da7bd7143e250073 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 16:15:09 -0700 Subject: [PATCH 5/8] fix - run tpm / rpm checks on proxy admin keys too --- litellm/proxy/proxy_server.py | 45 ++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 80b7c84d5..fbcff31c2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -655,19 +655,6 @@ async def user_api_key_auth( detail="'allow_user_auth' not set or set to False", ) - ### CHECK IF ADMIN ### - # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead - ### CHECK IF ADMIN ### - # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead - ## Check CACHE - valid_token = user_api_key_cache.get_cache(key=hash_token(api_key)) - if ( - valid_token is not None - and isinstance(valid_token, UserAPIKeyAuth) - and valid_token.user_role == "proxy_admin" - ): - return valid_token - ## Check END-USER OBJECT request_data = await _read_request_body(request=request) _end_user_object = None @@ -700,6 +687,27 @@ async def user_api_key_auth( ) pass + ### CHECK IF ADMIN ### + # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead + ### CHECK IF ADMIN ### + # note: never string compare api keys, this is vulnerable to a timing attack. 
Use secrets.compare_digest instead + ## Check CACHE + valid_token = user_api_key_cache.get_cache(key=hash_token(api_key)) + if ( + valid_token is not None + and isinstance(valid_token, UserAPIKeyAuth) + and valid_token.user_role == "proxy_admin" + ): + # update end-user params on valid token + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit") + valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit") + valid_token.allowed_model_region = end_user_params.get( + "allowed_model_region" + ) + + return valid_token + try: is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore except Exception as e: @@ -772,9 +780,18 @@ async def user_api_key_auth( key=original_api_key, table_name="key" ) verbose_proxy_logger.debug("Token from db: %s", valid_token) - elif valid_token is not None: + elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth): verbose_proxy_logger.debug("API Key Cache Hit!") + # update end-user params on valid token + # These can change per request - it's important to update them here + valid_token.end_user_id = end_user_params.get("end_user_id") + valid_token.end_user_tpm_limit = end_user_params.get("end_user_tpm_limit") + valid_token.end_user_rpm_limit = end_user_params.get("end_user_rpm_limit") + valid_token.allowed_model_region = end_user_params.get( + "allowed_model_region" + ) + user_id_information = None if valid_token: # Got Valid Token from Cache, DB From a848a676af57d8453a7519e9517d670eea9a8887 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 16:20:25 -0700 Subject: [PATCH 6/8] docs - end user rate limiting --- docs/my-website/docs/proxy/users.md | 56 +++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/docs/my-website/docs/proxy/users.md b/docs/my-website/docs/proxy/users.md index 6d9c43c5f..866c0c4bb 100644 --- a/docs/my-website/docs/proxy/users.md +++ 
b/docs/my-website/docs/proxy/users.md @@ -374,6 +374,62 @@ curl --location 'http://0.0.0.0:4000/key/generate' \ } ``` + + + +Use this to set rate limits for `user` passed to `/chat/completions`, without needing to create a key for every user + +#### Step 1. Create Budget + +Set a `tpm_limit` on the budget (You can also pass `rpm_limit` if needed) + +```shell +curl --location 'http://0.0.0.0:4000/budget/new' \ +--header 'Authorization: Bearer sk-1234' \ +--header 'Content-Type: application/json' \ +--data '{ + "budget_id" : "free-tier", + "tpm_limit": 5 +}' +``` + + +#### Step 2. Create `End-User` with Budget + +We use `budget_id="free-tier"` from Step 1 when creating this new end user + +```shell +curl --location 'http://0.0.0.0:4000/end_user/new' \ +--header 'Authorization: Bearer sk-1234' \ +--header 'Content-Type: application/json' \ +--data '{ + "user_id" : "palantir", + "budget_id": "free-tier" +}' +``` + + +#### Step 3. Pass end user id in `/chat/completions` requests + +Pass the `user_id` from Step 2 as `user="palantir"` + +```shell +curl --location 'http://localhost:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "llama3", + "user": "palantir", + "messages": [ + { + "role": "user", + "content": "gm" + } + ] +}' +``` + + From 4175d00a24ff090c02f0093e0527a50dc168a63f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 16:23:15 -0700 Subject: [PATCH 7/8] fix - test end user rate limits with master key --- tests/test_end_users.py | 52 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/test_end_users.py b/tests/test_end_users.py index 9c8b59753..9f405ccc5 100644 --- a/tests/test_end_users.py +++ b/tests/test_end_users.py @@ -199,7 +199,7 @@ async def test_end_user_specific_region(): @pytest.mark.asyncio -async def test_end_tpm_limits(): +async def test_enduser_tpm_limits_non_master_key(): """ 1. 
budget_id = Create Budget with tpm_limit = 10 2. create end_user with budget_id @@ -235,6 +235,56 @@ async def test_end_tpm_limits(): print("\nchat completion result 1=", result) + # chat completion 2 + try: + result = await client.chat.completions.create( + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hey!"}], + user=end_user_id, + ) + pytest.fail( + "User crossed their limit - this should have failed. instead got result = {}".format( + result + ) + ) + except Exception as e: + print("got exception 2 =", e) + assert "Crossed TPM, RPM Limit" in str( + e + ), f"Expected 'Crossed TPM, RPM Limit' but got {str(e)}" + + +@pytest.mark.asyncio +async def test_enduser_tpm_limits_with_master_key(): + """ + 1. budget_id = Create Budget with tpm_limit = 2 + 2. create end_user with budget_id + 3. Make /chat/completions calls + 4. Sleep 1 second + 5. Make /chat/completions call -> expect this to fail because rate limit hit + """ + async with aiohttp.ClientSession() as session: + # create a budget with budget_id = "free-tier" + budget_id = f"free-tier-{uuid.uuid4()}" + await new_budget(session, 0, budget_id=budget_id) + + end_user_id = str(uuid.uuid4()) + + await new_end_user( + session=session, i=0, user_id=end_user_id, budget_id=budget_id + ) + + # chat completion 1 + client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") + + result = await client.chat.completions.create( + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hey!"}], + user=end_user_id, + ) + + print("\nchat completion result 1=", result) + await asyncio.sleep(1) # chat completion 2 From a4cf453ad136f546d70bc53517e35fd93efbac20 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 22 May 2024 16:46:19 -0700 Subject: [PATCH 8/8] fix - end user rate limiting tests --- tests/test_end_users.py | 78 ++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 48 deletions(-) diff --git a/tests/test_end_users.py b/tests/test_end_users.py 
index 9f405ccc5..4ee894987 100644 --- a/tests/test_end_users.py +++ b/tests/test_end_users.py @@ -227,31 +227,23 @@ async def test_enduser_tpm_limits_non_master_key(): # chat completion 1 client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000") - result = await client.chat.completions.create( - model="fake-openai-endpoint", - messages=[{"role": "user", "content": "Hey!"}], - user=end_user_id, - ) - - print("\nchat completion result 1=", result) - # chat completion 2 - try: - result = await client.chat.completions.create( - model="fake-openai-endpoint", - messages=[{"role": "user", "content": "Hey!"}], - user=end_user_id, - ) - pytest.fail( - "User crossed their limit - this should have failed. instead got result = {}".format( - result + passed = 0 + for _ in range(10): + try: + result = await client.chat.completions.create( + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hey!"}], + user=end_user_id, ) - ) - except Exception as e: - print("got exception 2 =", e) - assert "Crossed TPM, RPM Limit" in str( - e - ), f"Expected 'Crossed TPM, RPM Limit' but got {str(e)}" + passed += 1 + except: + pass + print("Passed requests=", passed) + + assert ( + passed < 5 + ), f"Sent 10 requests and end-user has tpm_limit of 2. Number requests passed: {passed}. 
Expected less than 5 to pass" @pytest.mark.asyncio @@ -277,30 +269,20 @@ async def test_enduser_tpm_limits_with_master_key(): # chat completion 1 client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000") - result = await client.chat.completions.create( - model="fake-openai-endpoint", - messages=[{"role": "user", "content": "Hey!"}], - user=end_user_id, - ) - - print("\nchat completion result 1=", result) - - await asyncio.sleep(1) - # chat completion 2 - try: - result = await client.chat.completions.create( - model="fake-openai-endpoint", - messages=[{"role": "user", "content": "Hey!"}], - user=end_user_id, - ) - pytest.fail( - "User crossed their limit - this should have failed. instead got result = {}".format( - result + passed = 0 + for _ in range(10): + try: + result = await client.chat.completions.create( + model="fake-openai-endpoint", + messages=[{"role": "user", "content": "Hey!"}], + user=end_user_id, ) - ) - except Exception as e: - print("got exception 2 =", e) - assert "Crossed TPM, RPM Limit" in str( - e - ), f"Expected 'Crossed TPM, RPM Limit' but got {str(e)}" + passed += 1 + except: + pass + print("Passed requests=", passed) + + assert ( + passed < 5 + ), f"Sent 10 requests and end-user has tpm_limit of 2. Number requests passed: {passed}. Expected less than 5 to pass"