forked from phoenix/litellm-mirror
fix(user_api_key_auth.py): update team values in token cache if refreshed more recently
This commit is contained in:
parent
35e640076b
commit
99aa311083
4 changed files with 31 additions and 9 deletions
|
@ -886,6 +886,7 @@ class LiteLLM_TeamTable(TeamBase):
|
|||
budget_duration: Optional[str] = None
|
||||
budget_reset_at: Optional[datetime] = None
|
||||
model_id: Optional[int] = None
|
||||
last_refreshed_at: Optional[float] = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
@ -1238,6 +1239,9 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
|||
end_user_rpm_limit: Optional[int] = None
|
||||
end_user_max_budget: Optional[float] = None
|
||||
|
||||
# Time stamps
|
||||
last_refreshed_at: Optional[float] = None # last time joint view was pulled from db
|
||||
|
||||
|
||||
class UserAPIKeyAuth(
|
||||
LiteLLM_VerificationTokenView
|
||||
|
|
|
@ -467,12 +467,17 @@ async def user_api_key_auth(
|
|||
key="team_id:{}".format(valid_token.team_id)
|
||||
)
|
||||
|
||||
team_obj_dict = team_obj.__dict__
|
||||
if (
|
||||
team_obj.last_refreshed_at is not None
|
||||
and valid_token.last_refreshed_at is not None
|
||||
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
|
||||
):
|
||||
team_obj_dict = team_obj.__dict__
|
||||
|
||||
for k, v in team_obj_dict.items():
|
||||
field_name = f"team_{k}"
|
||||
if field_name in valid_token.__fields__:
|
||||
setattr(valid_token, field_name, v)
|
||||
for k, v in team_obj_dict.items():
|
||||
field_name = f"team_{k}"
|
||||
if field_name in valid_token.__fields__:
|
||||
setattr(valid_token, field_name, v)
|
||||
|
||||
try:
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
|
||||
|
@ -541,10 +546,13 @@ async def user_api_key_auth(
|
|||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if _valid_token is not None:
|
||||
## update cached token
|
||||
valid_token = UserAPIKeyAuth(
|
||||
**_valid_token.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("Token from db: %s", valid_token)
|
||||
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
|
||||
verbose_proxy_logger.debug("API Key Cache Hit!")
|
||||
|
|
|
@ -1331,7 +1331,9 @@ class PrismaClient:
|
|||
response["team_models"] = []
|
||||
if response["team_blocked"] is None:
|
||||
response["team_blocked"] = False
|
||||
response = LiteLLM_VerificationTokenView(**response)
|
||||
response = LiteLLM_VerificationTokenView(
|
||||
**response, last_refreshed_at=time.time()
|
||||
)
|
||||
# for prisma we need to cast the expires time to str
|
||||
if response.expires is not None and isinstance(
|
||||
response.expires, datetime
|
||||
|
|
|
@ -55,6 +55,9 @@ async def test_check_blocked_team():
|
|||
|
||||
assert team is not blocked
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
@ -63,12 +66,17 @@ async def test_check_blocked_team():
|
|||
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
||||
|
||||
_team_id = "1234"
|
||||
team_obj = LiteLLM_TeamTable(team_id=_team_id, blocked=False)
|
||||
|
||||
user_key = "sk-12345678"
|
||||
|
||||
valid_token = UserAPIKeyAuth(
|
||||
team_id=_team_id, team_blocked=True, token=hash_token(user_key)
|
||||
team_id=_team_id,
|
||||
team_blocked=True,
|
||||
token=hash_token(user_key),
|
||||
last_refreshed_at=time.time(),
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
team_obj = LiteLLM_TeamTable(
|
||||
team_id=_team_id, blocked=False, last_refreshed_at=time.time()
|
||||
)
|
||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue