fix(auth_checks.py): fix redis usage for team cached objects

This commit is contained in:
Krrish Dholakia 2024-07-30 17:29:38 -07:00
parent c551e5b47a
commit 142f4fefd0
5 changed files with 37 additions and 21 deletions

View file

@ -3,7 +3,7 @@ model_list:
litellm_params: litellm_params:
model: "*" model: "*"
# litellm_settings: litellm_settings:
# cache: true cache: true
# cache_params: cache_params:
# type: redis type: redis

View file

@ -910,7 +910,6 @@ class LiteLLM_TeamTable(TeamBase):
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
last_refreshed_at: Optional[float] = None
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -936,6 +935,10 @@ class LiteLLM_TeamTable(TeamBase):
return values return values
class LiteLLM_TeamTableCachedObj(LiteLLM_TeamTable):
last_refreshed_at: Optional[float] = None
class TeamRequest(LiteLLMBase): class TeamRequest(LiteLLMBase):
teams: List[str] teams: List[str]

View file

@ -8,6 +8,7 @@ Run checks for:
2. If user is in budget 2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
""" """
import time
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, Optional from typing import TYPE_CHECKING, Any, Literal, Optional
@ -19,6 +20,7 @@ from litellm.proxy._types import (
LiteLLM_JWTAuth, LiteLLM_JWTAuth,
LiteLLM_OrganizationTable, LiteLLM_OrganizationTable,
LiteLLM_TeamTable, LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LiteLLM_UserTable, LiteLLM_UserTable,
LiteLLMRoutes, LiteLLMRoutes,
LitellmUserRoles, LitellmUserRoles,
@ -368,11 +370,15 @@ async def get_user_object(
async def _cache_team_object( async def _cache_team_object(
team_id: str, team_id: str,
team_table: LiteLLM_TeamTable, team_table: LiteLLM_TeamTableCachedObj,
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
proxy_logging_obj: Optional[ProxyLogging], proxy_logging_obj: Optional[ProxyLogging],
): ):
key = "team_id:{}".format(team_id) key = "team_id:{}".format(team_id)
## CACHE REFRESH TIME!
team_table.last_refreshed_at = time.time()
value = team_table.model_dump_json(exclude_unset=True) value = team_table.model_dump_json(exclude_unset=True)
await user_api_key_cache.async_set_cache(key=key, value=value) await user_api_key_cache.async_set_cache(key=key, value=value)
@ -390,7 +396,7 @@ async def get_team_object(
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None, parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None, proxy_logging_obj: Optional[ProxyLogging] = None,
) -> LiteLLM_TeamTable: ) -> LiteLLM_TeamTableCachedObj:
""" """
- Check if team id in proxy Team Table - Check if team id in proxy Team Table
- if valid, return LiteLLM_TeamTable object with defined limits - if valid, return LiteLLM_TeamTable object with defined limits
@ -404,20 +410,26 @@ async def get_team_object(
# check if in cache # check if in cache
key = "team_id:{}".format(team_id) key = "team_id:{}".format(team_id)
cached_team_obj: Optional[LiteLLM_TeamTable] = None cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
## CHECK REDIS CACHE ## ## CHECK REDIS CACHE ##
if proxy_logging_obj is not None: if (
cached_team_obj = await proxy_logging_obj.internal_usage_cache.async_get_cache( proxy_logging_obj is not None
and proxy_logging_obj.internal_usage_cache.redis_cache is not None
):
cached_team_obj = (
await proxy_logging_obj.internal_usage_cache.redis_cache.async_get_cache(
key=key key=key
) )
)
if cached_team_obj is None: if cached_team_obj is None:
cached_team_obj = await user_api_key_cache.async_get_cache(key=key) cached_team_obj = await user_api_key_cache.async_get_cache(key=key)
if cached_team_obj is not None: if cached_team_obj is not None:
if isinstance(cached_team_obj, dict): if isinstance(cached_team_obj, dict):
return LiteLLM_TeamTable(**cached_team_obj) return LiteLLM_TeamTableCachedObj(**cached_team_obj)
elif isinstance(cached_team_obj, LiteLLM_TeamTable): elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj):
return cached_team_obj return cached_team_obj
# else, check db # else, check db
try: try:
@ -428,7 +440,7 @@ async def get_team_object(
if response is None: if response is None:
raise Exception raise Exception
_response = LiteLLM_TeamTable(**response.dict()) _response = LiteLLM_TeamTableCachedObj(**response.dict())
# save the team object to cache # save the team object to cache
await _cache_team_object( await _cache_team_object(
team_id=team_id, team_id=team_id,

View file

@ -461,14 +461,14 @@ async def user_api_key_auth(
valid_token is not None valid_token is not None
and isinstance(valid_token, UserAPIKeyAuth) and isinstance(valid_token, UserAPIKeyAuth)
and valid_token.team_id is not None and valid_token.team_id is not None
and user_api_key_cache.get_cache(
key="team_id:{}".format(valid_token.team_id)
)
is not None
): ):
## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token ## UPDATE TEAM VALUES BASED ON CACHED TEAM OBJECT - allows `/team/update` values to work for cached token
team_obj: LiteLLM_TeamTable = user_api_key_cache.get_cache( team_obj: LiteLLM_TeamTableCachedObj = await get_team_object(
key="team_id:{}".format(valid_token.team_id) team_id=valid_token.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
) )
if ( if (

View file

@ -18,6 +18,7 @@ from litellm.proxy._types import (
LiteLLM_AuditLogs, LiteLLM_AuditLogs,
LiteLLM_ModelTable, LiteLLM_ModelTable,
LiteLLM_TeamTable, LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LitellmTableNames, LitellmTableNames,
LitellmUserRoles, LitellmUserRoles,
Member, Member,
@ -379,7 +380,7 @@ async def update_team(
await _cache_team_object( await _cache_team_object(
team_id=team_row.team_id, team_id=team_row.team_id,
team_table=team_row, team_table=LiteLLM_TeamTableCachedObj(**team_row.model_dump()),
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
) )