forked from phoenix/litellm-mirror
[Fix proxy perf] Use correct cache key when reading from redis cache (#5928)
* fix parallel request limiter to use the correct user id * fix `async def get_user_object(` * use safe get_internal_user_object * fix storing internal users in redis correctly
This commit is contained in:
parent
8b6eec1951
commit
58171f35ef
3 changed files with 51 additions and 7 deletions
|
@ -386,8 +386,6 @@ async def get_user_object(
|
||||||
- if valid, return LiteLLM_UserTable object with defined limits
|
- if valid, return LiteLLM_UserTable object with defined limits
|
||||||
- if not, then raise an error
|
- if not, then raise an error
|
||||||
"""
|
"""
|
||||||
if prisma_client is None:
|
|
||||||
raise Exception("No db connected")
|
|
||||||
|
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
return None
|
return None
|
||||||
|
@ -400,6 +398,8 @@ async def get_user_object(
|
||||||
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
elif isinstance(cached_user_obj, LiteLLM_UserTable):
|
||||||
return cached_user_obj
|
return cached_user_obj
|
||||||
# else, check db
|
# else, check db
|
||||||
|
if prisma_client is None:
|
||||||
|
raise Exception("No db connected")
|
||||||
try:
|
try:
|
||||||
|
|
||||||
response = await prisma_client.db.litellm_usertable.find_unique(
|
response = await prisma_client.db.litellm_usertable.find_unique(
|
||||||
|
@ -415,9 +415,10 @@ async def get_user_object(
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
_response = LiteLLM_UserTable(**dict(response))
|
_response = LiteLLM_UserTable(**dict(response))
|
||||||
|
response_dict = _response.model_dump()
|
||||||
|
|
||||||
# save the user object to cache
|
# save the user object to cache
|
||||||
await user_api_key_cache.async_set_cache(key=user_id, value=_response)
|
await user_api_key_cache.async_set_cache(key=user_id, value=response_dict)
|
||||||
|
|
||||||
return _response
|
return _response
|
||||||
except Exception: # if user not in db
|
except Exception: # if user not in db
|
||||||
|
|
|
@ -322,9 +322,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
# check if REQUEST ALLOWED for user_id
|
# check if REQUEST ALLOWED for user_id
|
||||||
user_id = user_api_key_dict.user_id
|
user_id = user_api_key_dict.user_id
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
_user_id_rate_limits = await self.internal_usage_cache.async_get_cache(
|
_user_id_rate_limits = await self.get_internal_user_object(
|
||||||
key=user_id,
|
user_id=user_id,
|
||||||
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
user_api_key_dict=user_api_key_dict,
|
||||||
)
|
)
|
||||||
# get user tpm/rpm limits
|
# get user tpm/rpm limits
|
||||||
if _user_id_rate_limits is not None and isinstance(
|
if _user_id_rate_limits is not None and isinstance(
|
||||||
|
@ -741,3 +741,39 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
str(e)
|
str(e)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_internal_user_object(
    self,
    user_id: str,
    user_api_key_dict: UserAPIKeyAuth,
) -> Optional[dict]:
    """
    Helper to get the 'Internal User Object' as a plain dict.

    It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`,
    passing `self.internal_usage_cache.dual_cache` as the cache and
    `user_id_upsert=False`.

    We need this because the UserApiKeyAuth object does not contain the rpm/tpm
    limits for a User AND there could be a perf impact by additionally reading
    the UserTable.

    Args:
        user_id: id of the user whose object should be looked up.
        user_api_key_dict: auth object of the current request; only its
            `parent_otel_span` is read here and forwarded to `get_user_object`.

    Returns:
        The `model_dump()` of the object returned by `get_user_object`, or
        `None` when that lookup returns `None` or raises any exception.
        This helper never propagates lookup errors to the caller.
    """
    # NOTE(review): imports are function-local, presumably to avoid a circular
    # import with the proxy server module - confirm before hoisting them.
    from litellm._logging import verbose_proxy_logger
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import prisma_client

    try:
        _user_id_rate_limits = await get_user_object(
            user_id=user_id,
            prisma_client=prisma_client,
            user_api_key_cache=self.internal_usage_cache.dual_cache,
            user_id_upsert=False,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            proxy_logging_obj=None,
        )

        if _user_id_rate_limits is None:
            return None

        return _user_id_rate_limits.model_dump()
    except Exception as e:
        # Best-effort lookup: log and fall back to "no user limits".
        # The "%s" placeholder is required; without it the logging module fails
        # to format the record ("not all arguments converted") and the message
        # is lost.
        verbose_proxy_logger.exception(
            "Parallel Request Limiter: Error getting user object: %s", str(e)
        )
        return None
|
||||||
|
|
|
@ -295,7 +295,13 @@ async def test_pre_call_hook_user_tpm_limits():
|
||||||
local_cache = DualCache()
|
local_cache = DualCache()
|
||||||
# create user with tpm/rpm limits
|
# create user with tpm/rpm limits
|
||||||
user_id = "test-user"
|
user_id = "test-user"
|
||||||
user_obj = {"tpm_limit": 9, "rpm_limit": 10}
|
user_obj = {
|
||||||
|
"tpm_limit": 9,
|
||||||
|
"rpm_limit": 10,
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_email": "user_email",
|
||||||
|
"max_budget": None,
|
||||||
|
}
|
||||||
|
|
||||||
local_cache.set_cache(key=user_id, value=user_obj)
|
local_cache.set_cache(key=user_id, value=user_obj)
|
||||||
|
|
||||||
|
@ -331,6 +337,7 @@ async def test_pre_call_hook_user_tpm_limits():
|
||||||
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print("cache=local_cache", local_cache.in_memory_cache.cache_dict)
|
||||||
await parallel_request_handler.async_pre_call_hook(
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
user_api_key_dict=user_api_key_dict,
|
user_api_key_dict=user_api_key_dict,
|
||||||
cache=local_cache,
|
cache=local_cache,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue