fix(user_api_key_auth.py): update team values in token cache if refreshed more recently

This commit is contained in:
Krrish Dholakia 2024-07-19 17:35:59 -07:00
parent 35e640076b
commit 99aa311083
4 changed files with 31 additions and 9 deletions

View file

@ -886,6 +886,7 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
last_refreshed_at: Optional[float] = None
model_config = ConfigDict(protected_namespaces=())
@ -1238,6 +1239,9 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
end_user_rpm_limit: Optional[int] = None
end_user_max_budget: Optional[float] = None
# Time stamps
last_refreshed_at: Optional[float] = None # last time joint view was pulled from db
class UserAPIKeyAuth(
LiteLLM_VerificationTokenView

View file

@ -467,12 +467,17 @@ async def user_api_key_auth(
key="team_id:{}".format(valid_token.team_id)
)
team_obj_dict = team_obj.__dict__
if (
team_obj.last_refreshed_at is not None
and valid_token.last_refreshed_at is not None
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
):
team_obj_dict = team_obj.__dict__
for k, v in team_obj_dict.items():
field_name = f"team_{k}"
if field_name in valid_token.__fields__:
setattr(valid_token, field_name, v)
for k, v in team_obj_dict.items():
field_name = f"team_{k}"
if field_name in valid_token.__fields__:
setattr(valid_token, field_name, v)
try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
@ -541,10 +546,13 @@ async def user_api_key_auth(
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if _valid_token is not None:
## update cached token
valid_token = UserAPIKeyAuth(
**_valid_token.model_dump(exclude_none=True)
)
verbose_proxy_logger.debug("Token from db: %s", valid_token)
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
verbose_proxy_logger.debug("API Key Cache Hit!")

View file

@ -1331,7 +1331,9 @@ class PrismaClient:
response["team_models"] = []
if response["team_blocked"] is None:
response["team_blocked"] = False
response = LiteLLM_VerificationTokenView(**response)
response = LiteLLM_VerificationTokenView(
**response, last_refreshed_at=time.time()
)
# for prisma we need to cast the expires time to str
if response.expires is not None and isinstance(
response.expires, datetime

View file

@ -55,6 +55,9 @@ async def test_check_blocked_team():
assert team is not blocked
"""
import asyncio
import time
from fastapi import Request
from starlette.datastructures import URL
@ -63,12 +66,17 @@ async def test_check_blocked_team():
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
_team_id = "1234"
team_obj = LiteLLM_TeamTable(team_id=_team_id, blocked=False)
user_key = "sk-12345678"
valid_token = UserAPIKeyAuth(
team_id=_team_id, team_blocked=True, token=hash_token(user_key)
team_id=_team_id,
team_blocked=True,
token=hash_token(user_key),
last_refreshed_at=time.time(),
)
await asyncio.sleep(1)
team_obj = LiteLLM_TeamTable(
team_id=_team_id, blocked=False, last_refreshed_at=time.time()
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)