fix(user_api_key_cache): fix check to not raise error if team object is missing

This commit is contained in:
Krrish Dholakia 2024-07-30 18:25:04 -07:00
parent 6c0506a144
commit b77edc59ed
6 changed files with 34 additions and 22 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -396,6 +396,7 @@ async def get_team_object(
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
check_cache_only: Optional[bool] = None,
) -> LiteLLM_TeamTableCachedObj:
"""
- Check if team id in proxy Team Table
@ -431,6 +432,12 @@ async def get_team_object(
return LiteLLM_TeamTableCachedObj(**cached_team_obj)
elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
return cached_team_obj
if check_cache_only:
raise Exception(
f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}. Create team via `/team/new` call."
)
# else, check db
try:
response = await prisma_client.db.litellm_teamtable.find_unique(

View file

@ -463,25 +463,29 @@ async def user_api_key_auth(
and valid_token.team_id is not None
):
## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token
team_obj: LiteLLM_TeamTableCachedObj = await get_team_object(
team_id=valid_token.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
try:
team_obj: LiteLLM_TeamTableCachedObj = await get_team_object(
team_id=valid_token.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
check_cache_only=True,
)
if (
team_obj.last_refreshed_at is not None
and valid_token.last_refreshed_at is not None
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
):
team_obj_dict = team_obj.__dict__
if (
team_obj.last_refreshed_at is not None
and valid_token.last_refreshed_at is not None
and team_obj.last_refreshed_at > valid_token.last_refreshed_at
):
team_obj_dict = team_obj.__dict__
for k, v in team_obj_dict.items():
field_name = f"team_{k}"
if field_name in valid_token.__fields__:
setattr(valid_token, field_name, v)
for k, v in team_obj_dict.items():
field_name = f"team_{k}"
if field_name in valid_token.__fields__:
setattr(valid_token, field_name, v)
except Exception as e:
verbose_logger.warning(e)
try:
is_master_key_valid = secrets.compare_digest(api_key, master_key) # type: ignore

View file

@ -61,7 +61,11 @@ async def test_check_blocked_team():
from fastapi import Request
from starlette.datastructures import URL
from litellm.proxy._types import LiteLLM_TeamTable, UserAPIKeyAuth
from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
@ -75,7 +79,7 @@ async def test_check_blocked_team():
last_refreshed_at=time.time(),
)
await asyncio.sleep(1)
team_obj = LiteLLM_TeamTable(
team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id, blocked=False, last_refreshed_at=time.time()
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)