diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 778a012b6f..0869713660 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -458,8 +458,19 @@ class LiteLLM_VerificationToken(LiteLLMBase):
         protected_namespaces = ()
 
 
+class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
+    """
+    Combined view of litellm verification token + litellm team table (select values)
+    """
+
+    team_spend: Optional[float] = None
+    team_tpm_limit: Optional[int] = None
+    team_rpm_limit: Optional[int] = None
+    team_max_budget: Optional[float] = None
+
+
 class UserAPIKeyAuth(
-    LiteLLM_VerificationToken
+    LiteLLM_VerificationTokenView
 ):  # the expected response object for user api key auth
     """
     Return the row in the db
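The view model keeps every `team_*` column optional, so a token row that has no team (or an object created before the view existed) still validates, and downstream `getattr` guards can rely on a `None` default. A minimal sketch of that behavior, using plain pydantic stand-ins rather than the actual litellm base classes:

```python
from typing import Optional
from pydantic import BaseModel


class VerificationToken(BaseModel):
    # Trimmed stand-in for LiteLLM_VerificationToken.
    token: Optional[str] = None
    spend: float = 0.0
    team_id: Optional[str] = None


class VerificationTokenView(VerificationToken):
    """Token row joined with selected team columns - all optional."""

    team_spend: Optional[float] = None
    team_tpm_limit: Optional[int] = None
    team_rpm_limit: Optional[int] = None
    team_max_budget: Optional[float] = None


# A row from the combined view carries the team columns ...
row_with_team = {"token": "hashed-key", "spend": 1.5, "team_id": "t1", "team_max_budget": 10.0}
# ... while a plain token row simply leaves them as None.
row_without_team = {"token": "hashed-key", "spend": 1.5}

print(VerificationTokenView(**row_with_team).team_max_budget)     # 10.0
print(VerificationTokenView(**row_without_team).team_max_budget)  # None
```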
diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py
index fb61fe3da6..4221b064ee 100644
--- a/litellm/proxy/hooks/parallel_request_limiter.py
+++ b/litellm/proxy/hooks/parallel_request_limiter.py
@@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         current = cache.get_cache(
             key=request_count_api_key
         )  # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
-        # print(f"current: {current}")
         if current is None:
             new_val = {
                 "current_requests": 1,
@@ -73,8 +72,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
         api_key = user_api_key_dict.api_key
         max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
-        tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
-        rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
+        tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
+        if tpm_limit is None:
+            tpm_limit = sys.maxsize
+        rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
+        if rpm_limit is None:
+            rpm_limit = sys.maxsize
 
         if api_key is None:
             return
@@ -131,17 +134,46 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         _user_id_rate_limits = user_api_key_dict.user_id_rate_limits
 
         # get user tpm/rpm limits
-        if _user_id_rate_limits is None or _user_id_rate_limits == {}:
-            return
-        user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
-        user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
-        if user_tpm_limit is None:
-            user_tpm_limit = sys.maxsize
-        if user_rpm_limit is None:
-            user_rpm_limit = sys.maxsize
+        if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
+            user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
+            user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
+            if user_tpm_limit is None:
+                user_tpm_limit = sys.maxsize
+            if user_rpm_limit is None:
+                user_rpm_limit = sys.maxsize
+
+            # now do the same tpm/rpm checks
+            request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+
+            # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
+            await self.check_key_in_limits(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+                max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
+                request_count_api_key=request_count_api_key,
+                tpm_limit=user_tpm_limit,
+                rpm_limit=user_rpm_limit,
+            )
+
+        # TEAM RATE LIMITS
+        ## get team tpm/rpm limits
+        team_id = user_api_key_dict.team_id
+        team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
+        if team_tpm_limit is None:
+            team_tpm_limit = sys.maxsize
+        team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
+        if team_rpm_limit is None:
+            team_rpm_limit = sys.maxsize
+
+        if team_tpm_limit is None:
+            team_tpm_limit = sys.maxsize
+        if team_rpm_limit is None:
+            team_rpm_limit = sys.maxsize
 
         # now do the same tpm/rpm checks
-        request_count_api_key = f"{user_id}::{precise_minute}::request_count"
+        request_count_api_key = f"{team_id}::{precise_minute}::request_count"
 
         # print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
         await self.check_key_in_limits(
@@ -151,8 +183,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             call_type=call_type,
             max_parallel_requests=sys.maxsize,  # TODO: Support max parallel requests for a user
             request_count_api_key=request_count_api_key,
-            tpm_limit=user_tpm_limit,
-            rpm_limit=user_rpm_limit,
+            tpm_limit=team_tpm_limit,
+            rpm_limit=team_rpm_limit,
         )
 
         return
@@ -163,6 +195,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
             "user_api_key_user_id", None
         )
+        user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
+            "user_api_key_team_id", None
+        )
 
         if user_api_key is None:
             return
@@ -212,7 +247,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
         # ------------
         # Update usage - User
         # ------------
-        if user_api_key_user_id is None:
+        if user_api_key_user_id is not None:
+            total_tokens = 0
+
+            if isinstance(response_obj, ModelResponse):
+                total_tokens = response_obj.usage.total_tokens
+
+            request_count_api_key = (
+                f"{user_api_key_user_id}::{precise_minute}::request_count"
+            )
+
+            current = self.user_api_key_cache.get_cache(
+                key=request_count_api_key
+            ) or {
+                "current_requests": 1,
+                "current_tpm": total_tokens,
+                "current_rpm": 1,
+            }
+
+            new_val = {
+                "current_requests": max(current["current_requests"] - 1, 0),
+                "current_tpm": current["current_tpm"] + total_tokens,
+                "current_rpm": current["current_rpm"] + 1,
+            }
+
+            self.print_verbose(
+                f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
+            )
+            self.user_api_key_cache.set_cache(
+                request_count_api_key, new_val, ttl=60
+            )  # store in cache for 1 min.
+
+        # ------------
+        # Update usage - Team
+        # ------------
+        if user_api_key_team_id is None:
             return
 
         total_tokens = 0
@@ -221,7 +290,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
             total_tokens = response_obj.usage.total_tokens
 
         request_count_api_key = (
-            f"{user_api_key_user_id}::{precise_minute}::request_count"
+            f"{user_api_key_team_id}::{precise_minute}::request_count"
         )
 
         current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
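The limiter tracks every entity (API key, end user, and now team) with the same per-minute counter keyed as `{id}::{precise_minute}::request_count`. A self-contained sketch of that pattern, with a plain dict standing in for the proxy's `DualCache` and a simplified minute format:

```python
import sys
from datetime import datetime

cache: dict[str, dict] = {}  # in-memory stand-in for the proxy's DualCache


def precise_minute() -> str:
    # Simplified stand-in for the hook's precise_minute string.
    now = datetime.now()
    return f"{now.year}-{now.month}-{now.day}-{now.hour}-{now.minute}"


def check_limits(entity_id: str, tpm_limit: int = sys.maxsize, rpm_limit: int = sys.maxsize) -> None:
    """Raise if this entity's per-minute counters are already at their limits."""
    key = f"{entity_id}::{precise_minute()}::request_count"
    current = cache.get(key, {"current_requests": 0, "current_tpm": 0, "current_rpm": 0})
    if current["current_tpm"] >= tpm_limit or current["current_rpm"] >= rpm_limit:
        raise RuntimeError("429: rate limit reached for this minute")
    # Pre-call: reserve a parallel-request slot; tpm/rpm are updated on success.
    cache[key] = {**current, "current_requests": current["current_requests"] + 1}


def record_success(entity_id: str, total_tokens: int) -> None:
    """Mirror async_log_success_event: release a slot, add tokens, count the request."""
    key = f"{entity_id}::{precise_minute()}::request_count"
    current = cache.get(key, {"current_requests": 1, "current_tpm": total_tokens, "current_rpm": 1})
    cache[key] = {
        "current_requests": max(current["current_requests"] - 1, 0),
        "current_tpm": current["current_tpm"] + total_tokens,
        "current_rpm": current["current_rpm"] + 1,
    }


# The pre-call hook now runs the same check three times:
# once for the API key, once for the user id, and once for the team id.
for entity in ("sk-hashed-key", "user-123", "team-456"):
    check_limits(entity, rpm_limit=10)
    record_success(entity, total_tokens=42)
```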
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4a0aafdcce..6b9e3f644c 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -356,7 +356,7 @@ async def user_api_key_auth(
         verbose_proxy_logger.debug(f"api key: {api_key}")
         if prisma_client is not None:
             valid_token = await prisma_client.get_data(
-                token=api_key,
+                token=api_key, table_name="combined_view"
             )
 
         elif custom_db_client is not None:
@@ -381,7 +381,8 @@ async def user_api_key_auth(
             # 4. If token is expired
             # 5. If token spend is under Budget for the token
             # 6. If token spend per model is under budget per model
-
+            # 7. If token spend is under team budget
+            # 8. If team spend is under team budget
             request_data = await _read_request_body(
                 request=request
             )  # request data, used across all checks. Making this easily available
@@ -610,6 +611,47 @@ async def user_api_key_auth(
                             f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
                         )
 
+            # Check 6. Token spend is under Team budget
+            if (
+                valid_token.spend is not None
+                and hasattr(valid_token, "team_max_budget")
+                and valid_token.team_max_budget is not None
+            ):
+                asyncio.create_task(
+                    proxy_logging_obj.budget_alerts(
+                        user_max_budget=valid_token.team_max_budget,
+                        user_current_spend=valid_token.spend,
+                        type="token_budget",
+                        user_info=valid_token,
+                    )
+                )
+
+                if valid_token.spend >= valid_token.team_max_budget:
+                    raise Exception(
+                        f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
+                    )
+
+            # Check 7. Team spend is under Team budget
+            if (
+                hasattr(valid_token, "team_spend")
+                and valid_token.team_spend is not None
+                and hasattr(valid_token, "team_max_budget")
+                and valid_token.team_max_budget is not None
+            ):
+                asyncio.create_task(
+                    proxy_logging_obj.budget_alerts(
+                        user_max_budget=valid_token.team_max_budget,
+                        user_current_spend=valid_token.team_spend,
+                        type="token_budget",
+                        user_info=valid_token,
+                    )
+                )
+
+                if valid_token.team_spend >= valid_token.team_max_budget:
+                    raise Exception(
+                        f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
+                    )
+
             # Token passed all checks
             api_key = valid_token.token
 
@@ -1870,14 +1912,6 @@ async def generate_key_helper_fn(
             saved_token["expires"], datetime
         ):
             saved_token["expires"] = saved_token["expires"].isoformat()
-        if key_data["token"] is not None and isinstance(key_data["token"], str):
-            hashed_token = hash_token(key_data["token"])
-            saved_token["token"] = hashed_token
-            user_api_key_cache.set_cache(
-                key=hashed_token,
-                value=LiteLLM_VerificationToken(**saved_token),  # type: ignore
-                ttl=600,
-            )
         if prisma_client is not None:
             ## CREATE USER (If necessary)
             verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
@@ -2263,6 +2297,11 @@ async def startup_event():
             duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
         )
 
+    ### CHECK IF VIEW EXISTS ###
+    if prisma_client is not None:
+        create_view_response = await prisma_client.check_view_exists()
+        print(f"create_view_response: {create_view_response}")  # noqa
+
     ### START BUDGET SCHEDULER ###
     if prisma_client is not None:
         scheduler = AsyncIOScheduler()
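Checks 6 and 7 both compare against `team_max_budget`, but against different spends: first the key's own spend, then the team's aggregate spend pulled in from the view. A condensed sketch of that ordering (budget alerting omitted, and `getattr` standing in for the diff's `hasattr` guards), with a stub object instead of the real `valid_token`:

```python
class TokenRow:
    """Simplified stand-in for the valid_token returned from the combined view."""

    def __init__(self, spend, team_spend=None, team_max_budget=None):
        self.spend = spend
        self.team_spend = team_spend
        self.team_max_budget = team_max_budget


def run_team_budget_checks(valid_token) -> None:
    # Check 6: this key's own spend must stay under the team budget.
    if (
        valid_token.spend is not None
        and getattr(valid_token, "team_max_budget", None) is not None
        and valid_token.spend >= valid_token.team_max_budget
    ):
        raise Exception(
            f"ExceededTokenBudget: spend={valid_token.spend} >= team_max_budget={valid_token.team_max_budget}"
        )

    # Check 7: the team's aggregate spend must also stay under the team budget.
    if (
        getattr(valid_token, "team_spend", None) is not None
        and getattr(valid_token, "team_max_budget", None) is not None
        and valid_token.team_spend >= valid_token.team_max_budget
    ):
        raise Exception(
            f"ExceededTokenBudget: team_spend={valid_token.team_spend} >= team_max_budget={valid_token.team_max_budget}"
        )


run_team_budget_checks(TokenRow(spend=1.0, team_spend=5.0, team_max_budget=10.0))  # passes
try:
    run_team_budget_checks(TokenRow(spend=1.0, team_spend=12.0, team_max_budget=10.0))
except Exception as e:
    print(e)  # ExceededTokenBudget: team_spend=12.0 >= team_max_budget=10.0
```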
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8defd918cc..f814d60988 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -5,6 +5,7 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
     DynamoDBArgs,
     LiteLLM_VerificationToken,
+    LiteLLM_VerificationTokenView,
     LiteLLM_SpendLogs,
 )
 from litellm.caching import DualCache
@@ -479,6 +480,46 @@ class PrismaClient:
                     db_data[k] = json.dumps(v)
         return db_data
 
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
+    async def check_view_exists(self):
+        """
+        Checks if the LiteLLM_VerificationTokenView exists in the user's db.
+
+        This is used for getting the token + team data in user_api_key_auth
+
+        If the view doesn't exist, one will be created.
+        """
+        try:
+            # Try to select one row from the view
+            await self.db.execute_raw(
+                """SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
+            )
+            return "LiteLLM_VerificationTokenView Exists!"
+        except Exception as e:
+            # If an error occurs, the view does not exist, so create it
+            value = await self.health_check()
+            await self.db.execute_raw(
+                """
+                CREATE VIEW "LiteLLM_VerificationTokenView" AS
+                SELECT
+                    v.*,
+                    t.spend AS team_spend,
+                    t.max_budget AS team_max_budget,
+                    t.tpm_limit AS team_tpm_limit,
+                    t.rpm_limit AS team_rpm_limit
+                FROM "LiteLLM_VerificationToken" v
+                LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
+                """
+            )
+
+            return "LiteLLM_VerificationTokenView Created!"
+
     @backoff.on_exception(
         backoff.expo,
         Exception,  # base exception to catch for the backoff
@@ -535,7 +576,15 @@ class PrismaClient:
         team_id_list: Optional[list] = None,
         key_val: Optional[dict] = None,
         table_name: Optional[
-            Literal["user", "key", "config", "spend", "team", "user_notification"]
+            Literal[
+                "user",
+                "key",
+                "config",
+                "spend",
+                "team",
+                "user_notification",
+                "combined_view",
+            ]
         ] = None,
         query_type: Literal["find_unique", "find_all"] = "find_unique",
         expires: Optional[datetime] = None,
@@ -543,7 +592,9 @@ class PrismaClient:
     ):
         try:
             response: Any = None
-            if token is not None or (table_name is not None and table_name == "key"):
+            if (token is not None and table_name is None) or (
+                table_name is not None and table_name == "key"
+            ):
                 # check if plain text or hash
                 if token is not None:
                     if isinstance(token, str):
@@ -723,6 +774,38 @@ class PrismaClient:
                 elif query_type == "find_all":
                     response = await self.db.litellm_usernotifications.find_many()  # type: ignore
                 return response
+            elif table_name == "combined_view":
+                # check if plain text or hash
+                if token is not None:
+                    if isinstance(token, str):
+                        hashed_token = token
+                        if token.startswith("sk-"):
+                            hashed_token = self.hash_token(token=token)
+                        verbose_proxy_logger.debug(
+                            f"PrismaClient: find_unique for token: {hashed_token}"
+                        )
+                if query_type == "find_unique":
+                    if token is None:
+                        raise HTTPException(
+                            status_code=400,
+                            detail={"error": f"No token passed in. Token={token}"},
+                        )
+
+                    sql_query = f"""
+                        SELECT *
+                        FROM "LiteLLM_VerificationTokenView"
+                        WHERE token = '{token}'
+                    """
+
+                    response = await self.db.query_first(query=sql_query)
+                    if response is not None:
+                        response = LiteLLM_VerificationTokenView(**response)
+                        # for prisma we need to cast the expires time to str
+                        if response.expires is not None and isinstance(
+                            response.expires, datetime
+                        ):
+                            response.expires = response.expires.isoformat()
+                return response
         except Exception as e:
             print_verbose(f"LiteLLM Prisma Client Exception: {e}")
             import traceback
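What the view buys the auth path is a single lookup that returns token and team columns together; tokens with no team simply get NULLs. A runnable illustration using SQLite in place of Postgres/Prisma (table shapes are trimmed to the joined columns, and the lookup is shown parameterized rather than interpolated):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE LiteLLM_VerificationToken (token TEXT PRIMARY KEY, spend REAL, team_id TEXT);
    CREATE TABLE LiteLLM_TeamTable (team_id TEXT PRIMARY KEY, spend REAL, max_budget REAL,
                                    tpm_limit INTEGER, rpm_limit INTEGER);

    -- Same shape as the view created by check_view_exists(): every token row,
    -- with the owning team's spend/limits joined in (NULL when there is no team).
    CREATE VIEW LiteLLM_VerificationTokenView AS
    SELECT
        v.*,
        t.spend AS team_spend,
        t.max_budget AS team_max_budget,
        t.tpm_limit AS team_tpm_limit,
        t.rpm_limit AS team_rpm_limit
    FROM LiteLLM_VerificationToken v
    LEFT JOIN LiteLLM_TeamTable t ON v.team_id = t.team_id;

    INSERT INTO LiteLLM_VerificationToken VALUES ('hashed-key-1', 2.5, 'team-a');
    INSERT INTO LiteLLM_VerificationToken VALUES ('hashed-key-2', 0.0, NULL);
    INSERT INTO LiteLLM_TeamTable VALUES ('team-a', 40.0, 100.0, 1000, 60);
    """
)

# Equivalent of get_data(..., table_name="combined_view"): one query returns
# token + team columns together. Binding the token as a parameter avoids
# interpolating it into the SQL string.
row = conn.execute(
    "SELECT * FROM LiteLLM_VerificationTokenView WHERE token = ?", ("hashed-key-1",)
).fetchone()
print(row)  # ('hashed-key-1', 2.5, 'team-a', 40.0, 100.0, 1000, 60)
```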
diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py
index e1c87f1a32..a7b0c937f0 100644
--- a/litellm/tests/test_custom_logger.py
+++ b/litellm/tests/test_custom_logger.py
@@ -483,9 +483,12 @@ def test_redis_cache_completion_stream():
             max_tokens=40,
             temperature=0.2,
             stream=True,
+            caching=True,
         )
         response_1_content = ""
+        response_1_id = None
         for chunk in response1:
+            response_1_id = chunk.id
             print(chunk)
             response_1_content += chunk.choices[0].delta.content or ""
         print(response_1_content)
@@ -497,16 +500,22 @@ def test_redis_cache_completion_stream():
             max_tokens=40,
             temperature=0.2,
             stream=True,
+            caching=True,
         )
         response_2_content = ""
+        response_2_id = None
         for chunk in response2:
+            response_2_id = chunk.id
             print(chunk)
             response_2_content += chunk.choices[0].delta.content or ""
         print("\nresponse 1", response_1_content)
         print("\nresponse 2", response_2_content)
         assert (
-            response_1_content == response_2_content
+            response_1_id == response_2_id
         ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
+        # assert (
+        #     response_1_content == response_2_content
+        # ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
         litellm.success_callback = []
         litellm._async_success_callback = []
         litellm.cache = None
diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py
index 0d78b6d3c5..3660d6371c 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -124,25 +124,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
             bearer_token = "Bearer " + generated_key
 
             assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
-            assert (
-                hash_token(generated_key)
-                in user_api_key_cache.in_memory_cache.cache_dict
-            )
-            cached_value = user_api_key_cache.in_memory_cache.cache_dict[
-                hash_token(generated_key)
-            ]
-
-            print("cached value=", cached_value)
-            print("cached token", cached_value.token)
-
-            value_from_prisma = valid_token = await prisma_client.get_data(
+            value_from_prisma = await prisma_client.get_data(
                 token=generated_key,
             )
             print("token from prisma", value_from_prisma)
 
-            assert value_from_prisma.token == cached_value.token
-
             request = Request(scope={"type": "http"})
             request._url = URL(url="/chat/completions")
@@ -1241,7 +1228,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
 
         # use generated key to auth in
         result = await user_api_key_auth(request=request, api_key=bearer_token)
-        print("result from user auth with new key", result)
+        print(f"result from user auth with new key: {result}")
 
         # update spend using track_cost callback, make 2nd request, it should fail
         from litellm.proxy.proxy_server import (
@@ -1312,7 +1299,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
         generated_key = key.key
         user_id = key.user_id
         bearer_token = "Bearer " + generated_key
-
+        print(f"generated_key: {generated_key}")
         request = Request(scope={"type": "http"})
         request._url = URL(url="/chat/completions")
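The updated redis-cache test asserts on response ids rather than concatenated content: a cached streamed response replays chunks that carry the original response id, while chunk boundaries (and so the re-assembled text) can differ between the two runs. A rough sketch of that idea with hypothetical chunk objects:

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    id: str
    content: str


def consume_stream(chunks):
    """Collect the response id and concatenated text, as the test does."""
    response_id, text = None, ""
    for chunk in chunks:
        response_id = chunk.id
        text += chunk.content
    return response_id, text


first_call = [Chunk("chatcmpl-abc123", "Hello"), Chunk("chatcmpl-abc123", " world")]
cached_replay = [Chunk("chatcmpl-abc123", "Hello world")]  # same id, different chunking

id_1, text_1 = consume_stream(first_call)
id_2, text_2 = consume_stream(cached_replay)
assert id_1 == id_2  # the assertion the test now relies on
```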
diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py
index e402b617b7..bd5185a23a 100644
--- a/litellm/tests/test_parallel_request_limiter.py
+++ b/litellm/tests/test_parallel_request_limiter.py
@@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
         assert e.status_code == 429
 
 
+@pytest.mark.asyncio
+async def test_pre_call_hook_team_rpm_limits():
+    """
+    Test if error raised on hitting team rpm limits
+    """
+    litellm.set_verbose = True
+    _api_key = "sk-12345"
+    _team_id = "unique-team-id"
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key=_api_key,
+        max_parallel_requests=1,
+        tpm_limit=9,
+        rpm_limit=10,
+        team_rpm_limit=1,
+        team_id=_team_id,
+    )
+    local_cache = DualCache()
+    parallel_request_handler = MaxParallelRequestsHandler()
+
+    await parallel_request_handler.async_pre_call_hook(
+        user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
+    )
+
+    kwargs = {
+        "litellm_params": {
+            "metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
+        }
+    }
+
+    await parallel_request_handler.async_log_success_event(
+        kwargs=kwargs,
+        response_obj="",
+        start_time="",
+        end_time="",
+    )
+
+    print(f"local_cache: {local_cache}")
+
+    ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
+
+    try:
+        await parallel_request_handler.async_pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            cache=local_cache,
+            data={},
+            call_type="",
+        )
+
+        pytest.fail(f"Expected call to fail")
+    except Exception as e:
+        assert e.status_code == 429
+
+
 @pytest.mark.asyncio
 async def test_pre_call_hook_tpm_limits():
     """
diff --git a/litellm/utils.py b/litellm/utils.py
index c59a3c1e5c..4502e14de2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1183,7 +1183,7 @@ class Logging:
             verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
             ## BUILD COMPLETE STREAMED RESPONSE
             complete_streaming_response = None
-            if self.stream:
+            if self.stream and isinstance(result, ModelResponse):
                 if (
                     result.choices[0].finish_reason is not None
                 ):  # if it's the last chunk
@@ -8682,6 +8682,8 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if hasattr(chunk, "id"):
+                    model_response.id = chunk.id
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj[
                         "finish_reason"
                     ]
@@ -8704,6 +8706,8 @@ class CustomStreamWrapper:
                 model_response.system_fingerprint = getattr(
                     response_obj["original_chunk"], "system_fingerprint", None
                 )
+                if hasattr(response_obj["original_chunk"], "id"):
+                    model_response.id = response_obj["original_chunk"].id
                 if response_obj["logprobs"] is not None:
                     model_response.choices[0].logprobs = response_obj["logprobs"]
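The `CustomStreamWrapper` changes make that id comparison possible: when the provider (or cache) chunk exposes an `id`, it is copied onto the per-chunk `ModelResponse` instead of the freshly generated one. A simplified stand-in for that propagation (the class names below are illustrative, not the litellm objects):

```python
class ModelResponseStub:
    def __init__(self):
        self.id = "chatcmpl-new-uuid"  # default id generated for this chunk


class CachedChunk:
    id = "chatcmpl-original-id"  # id carried by the original / cached chunk


def wrap_chunk(original_chunk, model_response):
    # Mirror the diff: only overwrite the id when the source chunk has one.
    if hasattr(original_chunk, "id"):
        model_response.id = original_chunk.id
    return model_response


wrapped = wrap_chunk(CachedChunk(), ModelResponseStub())
print(wrapped.id)  # chatcmpl-original-id -> cached streams keep the original id
```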