diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 80113fff3..52e6100d3 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -386,8 +386,6 @@ async def get_user_object( - if valid, return LiteLLM_UserTable object with defined limits - if not, then raise an error """ - if prisma_client is None: - raise Exception("No db connected") if user_id is None: return None @@ -400,6 +398,8 @@ async def get_user_object( elif isinstance(cached_user_obj, LiteLLM_UserTable): return cached_user_obj # else, check db + if prisma_client is None: + raise Exception("No db connected") try: response = await prisma_client.db.litellm_usertable.find_unique( @@ -415,9 +415,10 @@ async def get_user_object( raise Exception _response = LiteLLM_UserTable(**dict(response)) + response_dict = _response.model_dump() # save the user object to cache - await user_api_key_cache.async_set_cache(key=user_id, value=_response) + await user_api_key_cache.async_set_cache(key=user_id, value=response_dict) return _response except Exception: # if user not in db diff --git a/litellm/proxy/hooks/parallel_request_limiter.py b/litellm/proxy/hooks/parallel_request_limiter.py index f5c45f3bd..7eaf515f2 100644 --- a/litellm/proxy/hooks/parallel_request_limiter.py +++ b/litellm/proxy/hooks/parallel_request_limiter.py @@ -322,9 +322,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): # check if REQUEST ALLOWED for user_id user_id = user_api_key_dict.user_id if user_id is not None: - _user_id_rate_limits = await self.internal_usage_cache.async_get_cache( - key=user_id, - litellm_parent_otel_span=user_api_key_dict.parent_otel_span, + _user_id_rate_limits = await self.get_internal_user_object( + user_id=user_id, + user_api_key_dict=user_api_key_dict, ) # get user tpm/rpm limits if _user_id_rate_limits is not None and isinstance( @@ -741,3 +741,39 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger): str(e) ) ) + + async def get_internal_user_object( + self, + user_id: str, + user_api_key_dict: UserAPIKeyAuth, + ) -> Optional[dict]: + """ + Helper to get the 'Internal User Object' + + It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks` + + We need this because the UserApiKeyAuth object does not contain the rpm/tpm limits for a User AND there could be a perf impact by additionally reading the UserTable. + """ + from litellm._logging import verbose_proxy_logger + from litellm.proxy.auth.auth_checks import get_user_object + from litellm.proxy.proxy_server import prisma_client + + try: + _user_id_rate_limits = await get_user_object( + user_id=user_id, + prisma_client=prisma_client, + user_api_key_cache=self.internal_usage_cache.dual_cache, + user_id_upsert=False, + parent_otel_span=user_api_key_dict.parent_otel_span, + proxy_logging_obj=None, + ) + + if _user_id_rate_limits is None: + return None + + return _user_id_rate_limits.model_dump() + except Exception as e: + verbose_proxy_logger.exception( + "Parallel Request Limiter: Error getting user object", str(e) + ) + return None diff --git a/litellm/tests/test_parallel_request_limiter.py b/litellm/tests/test_parallel_request_limiter.py index 93e86404e..dee7fefbf 100644 --- a/litellm/tests/test_parallel_request_limiter.py +++ b/litellm/tests/test_parallel_request_limiter.py @@ -295,7 +295,13 @@ async def test_pre_call_hook_user_tpm_limits(): local_cache = DualCache() # create user with tpm/rpm limits user_id = "test-user" - user_obj = {"tpm_limit": 9, "rpm_limit": 10} + user_obj = { + "tpm_limit": 9, + "rpm_limit": 10, + "user_id": user_id, + "user_email": "user_email", + "max_budget": None, + } local_cache.set_cache(key=user_id, value=user_obj) @@ -331,6 +337,7 @@ async def test_pre_call_hook_user_tpm_limits(): ## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1} try: + print("cache=local_cache", local_cache.in_memory_cache.cache_dict) await parallel_request_handler.async_pre_call_hook( user_api_key_dict=user_api_key_dict, cache=local_cache,