forked from phoenix/litellm-mirror
fix(user_api_key_auth.py): update team values in token cache if refreshed more recently
This commit is contained in:
parent
35e640076b
commit
99aa311083
4 changed files with 31 additions and 9 deletions
|
@ -886,6 +886,7 @@ class LiteLLM_TeamTable(TeamBase):
|
||||||
budget_duration: Optional[str] = None
|
budget_duration: Optional[str] = None
|
||||||
budget_reset_at: Optional[datetime] = None
|
budget_reset_at: Optional[datetime] = None
|
||||||
model_id: Optional[int] = None
|
model_id: Optional[int] = None
|
||||||
|
last_refreshed_at: Optional[float] = None
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
@ -1238,6 +1239,9 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
end_user_rpm_limit: Optional[int] = None
|
end_user_rpm_limit: Optional[int] = None
|
||||||
end_user_max_budget: Optional[float] = None
|
end_user_max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
# Time stamps
|
||||||
|
last_refreshed_at: Optional[float] = None # last time joint view was pulled from db
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(
|
class UserAPIKeyAuth(
|
||||||
LiteLLM_VerificationTokenView
|
LiteLLM_VerificationTokenView
|
||||||
|
|
|
@ -467,12 +467,17 @@ async def user_api_key_auth(
|
||||||
key="team_id:{}".format(valid_token.team_id)
|
key="team_id:{}".format(valid_token.team_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
team_obj_dict = team_obj.__dict__
|
if (
|
||||||
|
team_obj.last_refreshed_at is not None
|
||||||
|
and valid_token.last_refreshed_at is not None
|
||||||
|
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
|
||||||
|
):
|
||||||
|
team_obj_dict = team_obj.__dict__
|
||||||
|
|
||||||
for k, v in team_obj_dict.items():
|
for k, v in team_obj_dict.items():
|
||||||
field_name = f"team_{k}"
|
field_name = f"team_{k}"
|
||||||
if field_name in valid_token.__fields__:
|
if field_name in valid_token.__fields__:
|
||||||
setattr(valid_token, field_name, v)
|
setattr(valid_token, field_name, v)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
|
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore
|
||||||
|
@ -541,10 +546,13 @@ async def user_api_key_auth(
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
proxy_logging_obj=proxy_logging_obj,
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _valid_token is not None:
|
if _valid_token is not None:
|
||||||
|
## update cached token
|
||||||
valid_token = UserAPIKeyAuth(
|
valid_token = UserAPIKeyAuth(
|
||||||
**_valid_token.model_dump(exclude_none=True)
|
**_valid_token.model_dump(exclude_none=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
verbose_proxy_logger.debug("Token from db: %s", valid_token)
|
verbose_proxy_logger.debug("Token from db: %s", valid_token)
|
||||||
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
|
elif valid_token is not None and isinstance(valid_token, UserAPIKeyAuth):
|
||||||
verbose_proxy_logger.debug("API Key Cache Hit!")
|
verbose_proxy_logger.debug("API Key Cache Hit!")
|
||||||
|
|
|
@ -1331,7 +1331,9 @@ class PrismaClient:
|
||||||
response["team_models"] = []
|
response["team_models"] = []
|
||||||
if response["team_blocked"] is None:
|
if response["team_blocked"] is None:
|
||||||
response["team_blocked"] = False
|
response["team_blocked"] = False
|
||||||
response = LiteLLM_VerificationTokenView(**response)
|
response = LiteLLM_VerificationTokenView(
|
||||||
|
**response, last_refreshed_at=time.time()
|
||||||
|
)
|
||||||
# for prisma we need to cast the expires time to str
|
# for prisma we need to cast the expires time to str
|
||||||
if response.expires is not None and isinstance(
|
if response.expires is not None and isinstance(
|
||||||
response.expires, datetime
|
response.expires, datetime
|
||||||
|
|
|
@ -55,6 +55,9 @@ async def test_check_blocked_team():
|
||||||
|
|
||||||
assert team is not blocked
|
assert team is not blocked
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
@ -63,12 +66,17 @@ async def test_check_blocked_team():
|
||||||
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
|
||||||
|
|
||||||
_team_id = "1234"
|
_team_id = "1234"
|
||||||
team_obj = LiteLLM_TeamTable(team_id=_team_id, blocked=False)
|
|
||||||
|
|
||||||
user_key = "sk-12345678"
|
user_key = "sk-12345678"
|
||||||
|
|
||||||
valid_token = UserAPIKeyAuth(
|
valid_token = UserAPIKeyAuth(
|
||||||
team_id=_team_id, team_blocked=True, token=hash_token(user_key)
|
team_id=_team_id,
|
||||||
|
team_blocked=True,
|
||||||
|
token=hash_token(user_key),
|
||||||
|
last_refreshed_at=time.time(),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
team_obj = LiteLLM_TeamTable(
|
||||||
|
team_id=_team_id, blocked=False, last_refreshed_at=time.time()
|
||||||
)
|
)
|
||||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue