fix(user_api_key_auth.py): update team values in token cache if refreshed more recently

This commit is contained in:
Krrish Dholakia 2024-07-19 17:35:59 -07:00
parent 35e640076b
commit 99aa311083
4 changed files with 31 additions and 9 deletions

View file

@ -886,6 +886,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
last_refreshed_at: Optional[float] = None
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -1238,6 +1239,9 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
end_user_rpm_limit: Optional[int] = None end_user_rpm_limit: Optional[int] = None
end_user_max_budget: Optional[float] = None end_user_max_budget: Optional[float] = None
# Time stamps
last_refreshed_at: Optional[float] = None # last time joint view was pulled from db
class UserAPIKeyAuth( class UserAPIKeyAuth(
LiteLLM_VerificationTokenView LiteLLM_VerificationTokenView

View file

@ -467,6 +467,11 @@ async def user_api_key_auth(
key="team_id:{}".format(valid_token.team_id) key="team_id:{}".format(valid_token.team_id)
) )
if (
team_obj.last_refreshed_at is not None
and valid_token.last_refreshed_at is not None
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
):
team_obj_dict = team_obj.__dict__ team_obj_dict = team_obj.__dict__
for k, v in team_obj_dict.items(): for k, v in team_obj_dict.items():
@ -541,10 +546,13 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span, parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
) )
if _valid_token is not None: if _valid_token is not None:
## update cached token
valid_token = UserAPIKeyAuth( valid_token = UserAPIKeyAuth(
**_valid_token.model_dump(exclude_none=True) **_valid_token.model_dump(exclude_none=True)
) )
verbose_proxy_logger.debug("Token from db: %s", valid_token) verbose_proxy_logger.debug("Token from db: %s", valid_token)
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth): elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
verbose_proxy_logger.debug("API Key Cache Hit!") verbose_proxy_logger.debug("API Key Cache Hit!")

View file

@ -1331,7 +1331,9 @@ class PrismaClient:
response["team_models"] = [] response["team_models"] = []
if response["team_blocked"] is None: if response["team_blocked"] is None:
response["team_blocked"] = False response["team_blocked"] = False
response = LiteLLM_VerificationTokenView(**response) response = LiteLLM_VerificationTokenView(
**response, last_refreshed_at=time.time()
)
# for prisma we need to cast the expires time to str # for prisma we need to cast the expires time to str
if response.expires is not None and isinstance( if response.expires is not None and isinstance(
response.expires, datetime response.expires, datetime

View file

@ -55,6 +55,9 @@ async def test_check_blocked_team():
assert team is not blocked assert team is not blocked
""" """
import asyncio
import time
from fastapi import Request from fastapi import Request
from starlette.datastructures import URL from starlette.datastructures import URL
@ -63,12 +66,17 @@ async def test_check_blocked_team():
from litellm.proxy.proxy_server import hash_token, user_api_key_cache from litellm.proxy.proxy_server import hash_token, user_api_key_cache
_team_id = "1234" _team_id = "1234"
team_obj = LiteLLM_TeamTable(team_id=_team_id, blocked=False)
user_key = "sk-12345678" user_key = "sk-12345678"
valid_token = UserAPIKeyAuth( valid_token = UserAPIKeyAuth(
team_id=_team_id, team_blocked=True, token=hash_token(user_key) team_id=_team_id,
team_blocked=True,
token=hash_token(user_key),
last_refreshed_at=time.time(),
)
await asyncio.sleep(1)
team_obj = LiteLLM_TeamTable(
team_id=_team_id, blocked=False, last_refreshed_at=time.time()
) )
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)