From 71d4b7aaf46becdf294c1a90c85be788e7ced337 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 15:45:25 -0800 Subject: [PATCH 01/11] fix(proxy_server.py): enforce team based spend limits --- litellm/proxy/_types.py | 11 +++++ litellm/proxy/proxy_server.py | 49 ++++++++++++++++++- litellm/proxy/utils.py | 90 ++++++++++++++++++++++++++++++++++- 3 files changed, 146 insertions(+), 4 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 778a012b6..908cb58cf 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -458,6 +458,17 @@ class LiteLLM_VerificationToken(LiteLLMBase): protected_namespaces = () +class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): + """ + Combined view of litellm verification token + litellm team table (select values) + """ + + team_spend: Optional[float] = None + team_tpm_limit: Optional[int] = None + team_rpm_limit: Optional[int] = None + team_max_budget: Optional[float] = None + + class UserAPIKeyAuth( LiteLLM_VerificationToken ): # the expected response object for user api key auth diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 56c49ae32..e1f604096 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -350,13 +350,14 @@ async def user_api_key_auth( original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility) if api_key.startswith("sk-"): api_key = hash_token(token=api_key) - valid_token = user_api_key_cache.get_cache(key=api_key) + # valid_token = user_api_key_cache.get_cache(key=api_key) + valid_token = None if valid_token is None: ## check db verbose_proxy_logger.debug(f"api key: {api_key}") if prisma_client is not None: valid_token = await prisma_client.get_data( - token=api_key, + token=api_key, table_name="combined_view" ) elif custom_db_client is not None: @@ -381,6 +382,8 @@ async def user_api_key_auth( # 4. If token is expired # 5. If token spend is under Budget for the token # 6. If token spend per model is under budget per model + # 7. If token spend is under team budget + # 8. If team spend is under team budget request_data = await _read_request_body( request=request @@ -610,6 +613,44 @@ async def user_api_key_auth( f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}" ) + # Check 6. Token spend is under Team budget + if ( + valid_token.spend is not None + and valid_token.team_max_budget is not None + ): + asyncio.create_task( + proxy_logging_obj.budget_alerts( + user_max_budget=valid_token.team_max_budget, + user_current_spend=valid_token.spend, + type="token_budget", + user_info=valid_token, + ) + ) + + if valid_token.spend > valid_token.team_max_budget: + raise Exception( + f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}" + ) + + # Check 7. 
Team spend is under Team budget + if ( + valid_token.team_spend is not None + and valid_token.team_max_budget is not None + ): + asyncio.create_task( + proxy_logging_obj.budget_alerts( + user_max_budget=valid_token.team_max_budget, + user_current_spend=valid_token.team_spend, + type="token_budget", + user_info=valid_token, + ) + ) + + if valid_token.team_spend > valid_token.team_max_budget: + raise Exception( + f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}" + ) + # Token passed all checks api_key = valid_token.token @@ -2256,6 +2297,10 @@ async def startup_event(): duration=None, models=[], aliases={}, config={}, spend=0, token=master_key ) + ### CHECK IF VIEW EXISTS ### + create_view_response = await prisma_client.check_view_exists() + print(f"create_view_response: {create_view_response}") # noqa + ### START BUDGET SCHEDULER ### if prisma_client is not None: scheduler = AsyncIOScheduler() diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8defd918c..6b945ce72 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -5,6 +5,7 @@ from litellm.proxy._types import ( UserAPIKeyAuth, DynamoDBArgs, LiteLLM_VerificationToken, + LiteLLM_VerificationTokenView, LiteLLM_SpendLogs, ) from litellm.caching import DualCache @@ -479,6 +480,49 @@ class PrismaClient: db_data[k] = json.dumps(v) return db_data + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def check_view_exists(self): + """ + Checks if the LiteLLM_VerificationTokenView exists in the user's db. + + This is used for getting the token + team data in user_api_key_auth + + If the view doesn't exist, one will be created. + """ + try: + # Try to select one row from the view + await self.db.execute_raw( + """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""" + ) + return "LiteLLM_VerificationTokenView Exists!" + except Exception as e: + # If an error occurs, the view does not exist, so create it + value = await self.health_check() + if '"litellm_verificationtokenview" does not exist' in str(e): + await self.db.execute_raw( + """ + CREATE VIEW "LiteLLM_VerificationTokenView" AS + SELECT + v.*, + t.spend AS team_spend, + t.max_budget AS team_max_budget, + t.tpm_limit AS team_tpm_limit, + t.rpm_limit AS team_rpm_limit + FROM "LiteLLM_VerificationToken" v + LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; + """ + ) + else: + raise e + + return "LiteLLM_VerificationTokenView Created!" 
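        # Illustrative sketch, not part of this patch, assuming prisma-client-py
        # accepts $1-style positional args in raw queries: because the view is
        # built with a LEFT JOIN, keys without a team_id still resolve, with the
        # team_* columns coming back as NULL / None.
        #
        #     row = await self.db.query_first(
        #         'SELECT * FROM "LiteLLM_VerificationTokenView" WHERE token = $1',
        #         hashed_token,
        #     )
        #     if row is not None:
        #         token_view = LiteLLM_VerificationTokenView(**row)
        #         # token_view.team_max_budget is None for team-less keys, so
        #         # budget checks must guard on None before comparing spend.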
+ @backoff.on_exception( backoff.expo, Exception, # base exception to catch for the backoff @@ -535,7 +579,15 @@ class PrismaClient: team_id_list: Optional[list] = None, key_val: Optional[dict] = None, table_name: Optional[ - Literal["user", "key", "config", "spend", "team", "user_notification"] + Literal[ + "user", + "key", + "config", + "spend", + "team", + "user_notification", + "combined_view", + ] ] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", expires: Optional[datetime] = None, @@ -543,7 +595,9 @@ class PrismaClient: ): try: response: Any = None - if token is not None or (table_name is not None and table_name == "key"): + if (token is not None and table_name is None) or ( + table_name is not None and table_name == "key" + ): # check if plain text or hash if token is not None: if isinstance(token, str): @@ -723,6 +777,38 @@ class PrismaClient: elif query_type == "find_all": response = await self.db.litellm_usernotifications.find_many() # type: ignore return response + elif table_name == "combined_view": + # check if plain text or hash + if token is not None: + if isinstance(token, str): + hashed_token = token + if token.startswith("sk-"): + hashed_token = self.hash_token(token=token) + verbose_proxy_logger.debug( + f"PrismaClient: find_unique for token: {hashed_token}" + ) + if query_type == "find_unique": + if token is None: + raise HTTPException( + status_code=400, + detail={"error": f"No token passed in. Token={token}"}, + ) + + sql_query = f""" + SELECT * + FROM "LiteLLM_VerificationTokenView" + WHERE token = '{token}' + """ + + response = await self.db.query_first(query=sql_query) + if response is not None: + response = LiteLLM_VerificationTokenView(**response) + # for prisma we need to cast the expires time to str + if response.expires is not None and isinstance( + response.expires, datetime + ): + response.expires = response.expires.isoformat() + return response except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") import traceback From f84ac35000a7d6f370f17ecaa4c3d2b38cc28831 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 16:20:41 -0800 Subject: [PATCH 02/11] feat(parallel_request_limiter.py): enforce team based tpm / rpm limits --- litellm/proxy/_types.py | 2 +- .../proxy/hooks/parallel_request_limiter.py | 63 +++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 908cb58cf..086971366 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -470,7 +470,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): class UserAPIKeyAuth( - LiteLLM_VerificationToken + LiteLLM_VerificationTokenView ): # the expected response object for user api key auth """ Return the row in the db diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index fb61fe3da..e0c85cee0 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -154,6 +154,32 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): tpm_limit=user_tpm_limit, rpm_limit=user_rpm_limit, ) + + # TEAM RATE LIMITS + ## get team tpm/rpm limits + team_id = user_api_key_dict.team_id + team_tpm_limit = user_api_key_dict.team_tpm_limit or sys.maxsize + team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize + + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize + + # now 
do the same tpm/rpm checks + request_count_api_key = f"{team_id}::{precise_minute}::request_count" + + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + request_count_api_key=request_count_api_key, + tpm_limit=team_tpm_limit, + rpm_limit=team_rpm_limit, + ) return async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): @@ -163,6 +189,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): user_api_key_user_id = kwargs["litellm_params"]["metadata"].get( "user_api_key_user_id", None ) + user_api_key_team_id = kwargs["litellm_params"]["metadata"].get( + "user_api_key_team_id", None + ) if user_api_key is None: return @@ -243,6 +272,40 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): request_count_api_key, new_val, ttl=60 ) # store in cache for 1 min. + # ------------ + # Update usage - Team + # ------------ + if user_api_key_team_id is None: + return + + total_tokens = 0 + + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens + + request_count_api_key = ( + f"{user_api_key_team_id}::{precise_minute}::request_count" + ) + + current = self.user_api_key_cache.get_cache(key=request_count_api_key) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } + + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } + + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. 
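            # Illustrative walkthrough, not part of this patch, of the per-minute
            # team counter updated above, keyed as
            # f"{team_id}::{precise_minute}::request_count". For a response with
            # 42 total tokens against an existing entry:
            #
            #     current = {"current_requests": 1, "current_tpm": 10, "current_rpm": 2}
            #     new_val = {"current_requests": 0, "current_tpm": 52, "current_rpm": 3}
            #
            # ttl=60 lets the entry expire with its minute window, so team
            # tpm/rpm counters reset without any cleanup job.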
+ except Exception as e: self.print_verbose(e) # noqa From f86ab190675d5de4d57451476ed02aa0ccaf70ec Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 18:06:13 -0800 Subject: [PATCH 03/11] fix(parallel_request_limiter.py): fix team rate limit enforcement --- .../proxy/hooks/parallel_request_limiter.py | 100 +++++++++--------- litellm/proxy/proxy_server.py | 8 +- .../tests/test_parallel_request_limiter.py | 53 ++++++++++ 3 files changed, 105 insertions(+), 56 deletions(-) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index e0c85cee0..a4fb70c57 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): current = cache.get_cache( key=request_count_api_key ) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10} - # print(f"current: {current}") if current is None: new_val = { "current_requests": 1, @@ -73,8 +72,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook") api_key = user_api_key_dict.api_key max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize - tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize - rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize + tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) + rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) if api_key is None: return @@ -131,35 +130,34 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): _user_id_rate_limits = user_api_key_dict.user_id_rate_limits # get user tpm/rpm limits - if _user_id_rate_limits is None or _user_id_rate_limits == {}: - return - user_tpm_limit = _user_id_rate_limits.get("tpm_limit") - user_rpm_limit = _user_id_rate_limits.get("rpm_limit") - if user_tpm_limit is None: - user_tpm_limit = sys.maxsize - if user_rpm_limit is None: - user_rpm_limit = sys.maxsize + if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict): + user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None) + user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None) + if user_tpm_limit is None: + user_tpm_limit = sys.maxsize + if user_rpm_limit is None: + user_rpm_limit = sys.maxsize - # now do the same tpm/rpm checks - request_count_api_key = f"{user_id}::{precise_minute}::request_count" + # now do the same tpm/rpm checks + request_count_api_key = f"{user_id}::{precise_minute}::request_count" - # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") - await self.check_key_in_limits( - user_api_key_dict=user_api_key_dict, - cache=cache, - data=data, - call_type=call_type, - max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user - request_count_api_key=request_count_api_key, - tpm_limit=user_tpm_limit, - rpm_limit=user_rpm_limit, - ) + # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}") + await self.check_key_in_limits( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type=call_type, + max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user + request_count_api_key=request_count_api_key, + tpm_limit=user_tpm_limit, + rpm_limit=user_rpm_limit, + ) # TEAM RATE LIMITS ## get team tpm/rpm limits team_id = user_api_key_dict.team_id - team_tpm_limit = user_api_key_dict.team_tpm_limit or 
sys.maxsize - team_rpm_limit = user_api_key_dict.team_rpm_limit or sys.maxsize + team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize) + team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize) if team_tpm_limit is None: team_tpm_limit = sys.maxsize @@ -241,36 +239,36 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # ------------ # Update usage - User # ------------ - if user_api_key_user_id is None: - return + if user_api_key_user_id is not None: + total_tokens = 0 - total_tokens = 0 + if isinstance(response_obj, ModelResponse): + total_tokens = response_obj.usage.total_tokens - if isinstance(response_obj, ModelResponse): - total_tokens = response_obj.usage.total_tokens + request_count_api_key = ( + f"{user_api_key_user_id}::{precise_minute}::request_count" + ) - request_count_api_key = ( - f"{user_api_key_user_id}::{precise_minute}::request_count" - ) + current = self.user_api_key_cache.get_cache( + key=request_count_api_key + ) or { + "current_requests": 1, + "current_tpm": total_tokens, + "current_rpm": 1, + } - current = self.user_api_key_cache.get_cache(key=request_count_api_key) or { - "current_requests": 1, - "current_tpm": total_tokens, - "current_rpm": 1, - } + new_val = { + "current_requests": max(current["current_requests"] - 1, 0), + "current_tpm": current["current_tpm"] + total_tokens, + "current_rpm": current["current_rpm"] + 1, + } - new_val = { - "current_requests": max(current["current_requests"] - 1, 0), - "current_tpm": current["current_tpm"] + total_tokens, - "current_rpm": current["current_rpm"] + 1, - } - - self.print_verbose( - f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" - ) - self.user_api_key_cache.set_cache( - request_count_api_key, new_val, ttl=60 - ) # store in cache for 1 min. + self.print_verbose( + f"updated_value in success call: {new_val}, precise_minute: {precise_minute}" + ) + self.user_api_key_cache.set_cache( + request_count_api_key, new_val, ttl=60 + ) # store in cache for 1 min. # ------------ # Update usage - Team diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e1f604096..ba756d05c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -350,8 +350,7 @@ async def user_api_key_auth( original_api_key = api_key # (Patch: For DynamoDB Backwards Compatibility) if api_key.startswith("sk-"): api_key = hash_token(token=api_key) - # valid_token = user_api_key_cache.get_cache(key=api_key) - valid_token = None + valid_token = user_api_key_cache.get_cache(key=api_key) if valid_token is None: ## check db verbose_proxy_logger.debug(f"api key: {api_key}") @@ -384,7 +383,6 @@ async def user_api_key_auth( # 6. If token spend per model is under budget per model # 7. If token spend is under team budget # 8. If team spend is under team budget - request_data = await _read_request_body( request=request ) # request data, used across all checks. 
Making this easily available @@ -627,7 +625,7 @@ async def user_api_key_auth( ) ) - if valid_token.spend > valid_token.team_max_budget: + if valid_token.spend >= valid_token.team_max_budget: raise Exception( f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}" ) @@ -646,7 +644,7 @@ async def user_api_key_auth( ) ) - if valid_token.team_spend > valid_token.team_max_budget: + if valid_token.team_spend >= valid_token.team_max_budget: raise Exception( f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}" ) diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index e402b617b..bd5185a23 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits(): assert e.status_code == 429 +@pytest.mark.asyncio +async def test_pre_call_hook_team_rpm_limits(): + """ + Test if error raised on hitting team rpm limits + """ + litellm.set_verbose = True + _api_key = "sk-12345" + _team_id = "unique-team-id" + user_api_key_dict = UserAPIKeyAuth( + api_key=_api_key, + max_parallel_requests=1, + tpm_limit=9, + rpm_limit=10, + team_rpm_limit=1, + team_id=_team_id, + ) + local_cache = DualCache() + parallel_request_handler = MaxParallelRequestsHandler() + + await parallel_request_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type="" + ) + + kwargs = { + "litellm_params": { + "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id} + } + } + + await parallel_request_handler.async_log_success_event( + kwargs=kwargs, + response_obj="", + start_time="", + end_time="", + ) + + print(f"local_cache: {local_cache}") + + ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} + + try: + await parallel_request_handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=local_cache, + data={}, + call_type="", + ) + + pytest.fail(f"Expected call to fail") + except Exception as e: + assert e.status_code == 429 + + @pytest.mark.asyncio async def test_pre_call_hook_tpm_limits(): """ From d0c205fcd19d783cb109913815c87f8734bd1838 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 18:48:55 -0800 Subject: [PATCH 04/11] fix(proxy_server.py): don't cache request on key generate - misses the team related data --- litellm/proxy/proxy_server.py | 8 -------- litellm/tests/test_key_generate_prisma.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ba756d05c..e3cd257fd 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1902,14 +1902,6 @@ async def generate_key_helper_fn( saved_token["expires"], datetime ): saved_token["expires"] = saved_token["expires"].isoformat() - if key_data["token"] is not None and isinstance(key_data["token"], str): - hashed_token = hash_token(key_data["token"]) - saved_token["token"] = hashed_token - user_api_key_cache.set_cache( - key=hashed_token, - value=LiteLLM_VerificationToken(**saved_token), # type: ignore - ttl=600, - ) if prisma_client is not None: ## CREATE USER (If necessary) verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}") diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 
0d78b6d3c..d8d63e75f 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -1241,7 +1241,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
 
     # use generated key to auth in
     result = await user_api_key_auth(request=request, api_key=bearer_token)
-    print("result from user auth with new key", result)
+    print(f"result from user auth with new key: {result}")
 
     # update spend using track_cost callback, make 2nd request, it should fail
     from litellm.proxy.proxy_server import (

From 7942539f9bcca64c249d7c3b253fea80464c29a8 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 26 Feb 2024 19:10:05 -0800
Subject: [PATCH 05/11] test: testing fixes

---
 litellm/tests/test_key_generate_prisma.py | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py
index d8d63e75f..3660d6371 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -124,25 +124,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
             bearer_token = "Bearer " + generated_key
 
             assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
-            assert (
-                hash_token(generated_key)
-                in user_api_key_cache.in_memory_cache.cache_dict
-            )
-            cached_value = user_api_key_cache.in_memory_cache.cache_dict[
-                hash_token(generated_key)
-            ]
-
-            print("cached value=", cached_value)
-            print("cached token", cached_value.token)
-
-            value_from_prisma = valid_token = await prisma_client.get_data(
+            value_from_prisma = await prisma_client.get_data(
                 token=generated_key,
             )
             print("token from prisma", value_from_prisma)
 
-            assert value_from_prisma.token == cached_value.token
-
             request = Request(scope={"type": "http"})
             request._url = URL(url="/chat/completions")
 
@@ -1312,7 +1299,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
     generated_key = key.key
     user_id = key.user_id
     bearer_token = "Bearer " + generated_key
-
+    print(f"generated_key: {generated_key}")
     request = Request(scope={"type": "http"})
     request._url = URL(url="/chat/completions")
 
From 4d9584ecd5f87c5d069a0d7ea1bebcd7dbc4c0f1 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 26 Feb 2024 19:57:13 -0800
Subject: [PATCH 06/11] test(proxy_server.py): check if token has team params

---
 litellm/proxy/proxy_server.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index e3cd257fd..7f8839634 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -614,6 +614,7 @@ async def user_api_key_auth(
             # Check 6. Token spend is under Team budget
             if (
                 valid_token.spend is not None
+                and hasattr(valid_token, "team_max_budget")
                 and valid_token.team_max_budget is not None
             ):
                 asyncio.create_task(
@@ -632,7 +633,9 @@ async def user_api_key_auth(
 
             # Check 7. 
Team spend is under Team budget if ( - valid_token.team_spend is not None + hasattr(valid_token, "team_spend") + and valid_token.team_spend is not None + and hasattr(valid_token, "team_max_budget") and valid_token.team_max_budget is not None ): asyncio.create_task( From b3574f2b379d05140b2783b73a0579f96437231c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 20:09:06 -0800 Subject: [PATCH 07/11] fix(parallel_request_limiter.py): handle none scenario --- litellm/proxy/hooks/parallel_request_limiter.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index a4fb70c57..4221b064e 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -73,7 +73,11 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): api_key = user_api_key_dict.api_key max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize) + if tpm_limit is None: + tpm_limit = sys.maxsize rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize) + if rpm_limit is None: + rpm_limit = sys.maxsize if api_key is None: return @@ -157,7 +161,11 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): ## get team tpm/rpm limits team_id = user_api_key_dict.team_id team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize) + if team_tpm_limit is None: + team_tpm_limit = sys.maxsize team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize) + if team_rpm_limit is None: + team_rpm_limit = sys.maxsize if team_tpm_limit is None: team_tpm_limit = sys.maxsize From a428501e686b39aa962fa3ba6d4d8d970ea494b1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 21:45:20 -0800 Subject: [PATCH 08/11] test(test_streaming.py): add more logging --- litellm/tests/test_streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 66e8be4cb..960bba4ee 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -406,7 +406,7 @@ def test_completion_palm_stream(): def test_completion_gemini_stream(): try: - litellm.set_verbose = False + litellm.set_verbose = True print("Streaming gemini response") messages = [ {"role": "system", "content": "You are a helpful assistant."}, From 1447621128294797c106de068837d62aaf5e3b43 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 22:04:24 -0800 Subject: [PATCH 09/11] fix(utils.py): fix redis cache test --- litellm/tests/test_custom_logger.py | 11 ++++++++++- litellm/utils.py | 6 +++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index e1c87f1a3..a7b0c937f 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -483,9 +483,12 @@ def test_redis_cache_completion_stream(): max_tokens=40, temperature=0.2, stream=True, + caching=True, ) response_1_content = "" + response_1_id = None for chunk in response1: + response_1_id = chunk.id print(chunk) response_1_content += chunk.choices[0].delta.content or "" print(response_1_content) @@ -497,16 +500,22 @@ def test_redis_cache_completion_stream(): max_tokens=40, temperature=0.2, stream=True, + caching=True, ) response_2_content = "" + response_2_id = None for chunk in response2: + response_2_id = chunk.id 
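        # Illustrative note, not part of this patch: chunk.id is recorded on
        # each iteration because the streaming wrapper (see the litellm/utils.py
        # hunk below) now copies the upstream response id onto every rebuilt
        # chunk. A redis cache hit can then be asserted via id equality even
        # when the two streams split the same text at different chunk
        # boundaries.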
print(chunk) response_2_content += chunk.choices[0].delta.content or "" print("\nresponse 1", response_1_content) print("\nresponse 2", response_2_content) assert ( - response_1_content == response_2_content + response_1_id == response_2_id ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" + # assert ( + # response_1_content == response_2_content + # ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}" litellm.success_callback = [] litellm._async_success_callback = [] litellm.cache = None diff --git a/litellm/utils.py b/litellm/utils.py index 0e718d31a..09c23a725 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1169,7 +1169,7 @@ class Logging: verbose_logger.debug(f"success callbacks: {litellm.success_callback}") ## BUILD COMPLETE STREAMED RESPONSE complete_streaming_response = None - if self.stream: + if self.stream and isinstance(result, ModelResponse): if ( result.choices[0].finish_reason is not None ): # if it's the last chunk @@ -8654,6 +8654,8 @@ class CustomStreamWrapper: completion_obj["content"] = response_obj["text"] print_verbose(f"completion obj content: {completion_obj['content']}") + if hasattr(chunk, "id"): + model_response.id = chunk.id if response_obj["is_finished"]: model_response.choices[0].finish_reason = response_obj[ "finish_reason" @@ -8676,6 +8678,8 @@ class CustomStreamWrapper: model_response.system_fingerprint = getattr( response_obj["original_chunk"], "system_fingerprint", None ) + if hasattr(response_obj["original_chunk"], "id"): + model_response.id = response_obj["original_chunk"].id if response_obj["logprobs"] is not None: model_response.choices[0].logprobs = response_obj["logprobs"] From 1815ee16d4da5b36af7c90937f47e169f099aebf Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 22:11:47 -0800 Subject: [PATCH 10/11] fix(proxy_server.py): check if prisma client is initialized before checking if view exists --- litellm/proxy/proxy_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7f8839634..7fc9cd473 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2291,8 +2291,9 @@ async def startup_event(): ) ### CHECK IF VIEW EXISTS ### - create_view_response = await prisma_client.check_view_exists() - print(f"create_view_response: {create_view_response}") # noqa + if prisma_client is not None: + create_view_response = await prisma_client.check_view_exists() + print(f"create_view_response: {create_view_response}") # noqa ### START BUDGET SCHEDULER ### if prisma_client is not None: From 4c134b1ec7bb5fac86d1bd4e0480b8c39b9fe52d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 26 Feb 2024 22:44:05 -0800 Subject: [PATCH 11/11] fix(proxy/utils.py): fix try-except for creating a view --- litellm/proxy/utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 6b945ce72..f814d6098 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -504,9 +504,8 @@ class PrismaClient: except Exception as e: # If an error occurs, the view does not exist, so create it value = await self.health_check() - if '"litellm_verificationtokenview" does not exist' in str(e): - await self.db.execute_raw( - """ + await self.db.execute_raw( + """ CREATE VIEW "LiteLLM_VerificationTokenView" AS SELECT v.*, @@ -517,9 +516,7 @@ class 
PrismaClient: FROM "LiteLLM_VerificationToken" v LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; """ - ) - else: - raise e + ) return "LiteLLM_VerificationTokenView Created!"
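
A hedged sketch for reviewers, not part of the series: the combined_view
branch added to PrismaClient.get_data in PATCH 01 computes hashed_token but
then interpolates the raw token into the SQL string. Assuming prisma-client-py
accepts $1-style positional arguments to query_first, a parameterized variant
would query on the hashed value and avoid any quoting issues:

    # Sketch only; mirrors the combined_view branch of PrismaClient.get_data.
    sql_query = """
        SELECT *
        FROM "LiteLLM_VerificationTokenView"
        WHERE token = $1
    """
    response = await self.db.query_first(sql_query, hashed_token)
    if response is not None:
        response = LiteLLM_VerificationTokenView(**response)
        if response.expires is not None and isinstance(response.expires, datetime):
            # prisma needs the expires time cast to str
            response.expires = response.expires.isoformat()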