From c77af015f8d1164685b4d83d8a3c5d5c021ad9ff Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 26 Nov 2024 10:28:40 +0530 Subject: [PATCH] fix(team_endpoints.py): /team/member_add fix adding several new members to team --- litellm/proxy/auth/auth_checks.py | 153 ++++++++++++------ .../management_endpoints/team_endpoints.py | 1 + 2 files changed, 104 insertions(+), 50 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 7d29032c6..5d789436a 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -523,6 +523,10 @@ async def _cache_management_object( proxy_logging_obj: Optional[ProxyLogging], ): await user_api_key_cache.async_set_cache(key=key, value=value) + if proxy_logging_obj is not None: + await proxy_logging_obj.internal_usage_cache.dual_cache.async_set_cache( + key=key, value=value + ) async def _cache_team_object( @@ -586,26 +590,63 @@ async def _get_team_db_check(team_id: str, prisma_client: PrismaClient): ) -async def get_team_object( - team_id: str, - prisma_client: Optional[PrismaClient], - user_api_key_cache: DualCache, - parent_otel_span: Optional[Span] = None, - proxy_logging_obj: Optional[ProxyLogging] = None, - check_cache_only: Optional[bool] = None, -) -> LiteLLM_TeamTableCachedObj: - """ - - Check if team id in proxy Team Table - - if valid, return LiteLLM_TeamTable object with defined limits - - if not, then raise an error - """ - if prisma_client is None: - raise Exception( - "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" - ) +async def _get_team_object_from_db(team_id: str, prisma_client: PrismaClient): + return await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": team_id} + ) - # check if in cache - key = "team_id:{}".format(team_id) + +async def _get_team_object_from_user_api_key_cache( + team_id: str, + prisma_client: PrismaClient, + user_api_key_cache: DualCache, + last_db_access_time: LimitedSizeOrderedDict, + db_cache_expiry: int, + proxy_logging_obj: Optional[ProxyLogging], + key: str, +) -> LiteLLM_TeamTableCachedObj: + db_access_time_key = key + should_check_db = _should_check_db( + key=db_access_time_key, + last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, + ) + if should_check_db: + response = await _get_team_db_check( + team_id=team_id, prisma_client=prisma_client + ) + else: + response = None + + if response is None: + raise Exception + + _response = LiteLLM_TeamTableCachedObj(**response.dict()) + # save the team object to cache + await _cache_team_object( + team_id=team_id, + team_table=_response, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # save to db access time + # save to db access time + _update_last_db_access_time( + key=db_access_time_key, + value=_response, + last_db_access_time=last_db_access_time, + ) + + return _response + + +async def _get_team_object_from_cache( + key: str, + proxy_logging_obj: Optional[ProxyLogging], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span], +) -> Optional[LiteLLM_TeamTableCachedObj]: cached_team_obj: Optional[LiteLLM_TeamTableCachedObj] = None ## CHECK REDIS CACHE ## @@ -613,6 +654,7 @@ async def get_team_object( proxy_logging_obj is not None and proxy_logging_obj.internal_usage_cache.dual_cache ): + cached_team_obj = ( await proxy_logging_obj.internal_usage_cache.dual_cache.async_get_cache( key=key, parent_otel_span=parent_otel_span @@ -628,47 +670,58 @@ async def get_team_object( elif isinstance(cached_team_obj, LiteLLM_TeamTableCachedObj): return cached_team_obj - if check_cache_only: + return None + + +async def get_team_object( + team_id: str, + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, + check_cache_only: Optional[bool] = None, + check_db_only: Optional[bool] = None, +) -> LiteLLM_TeamTableCachedObj: + """ + - Check if team id in proxy Team Table + - if valid, return LiteLLM_TeamTable object with defined limits + - if not, then raise an error + """ + if prisma_client is None: raise Exception( - f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}." + "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" ) + # check if in cache + key = "team_id:{}".format(team_id) + + if not check_db_only: + cached_team_obj = await _get_team_object_from_cache( + key=key, + proxy_logging_obj=proxy_logging_obj, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + ) + + if cached_team_obj is not None: + return cached_team_obj + + if check_cache_only: + raise Exception( + f"Team doesn't exist in cache + check_cache_only=True. Team={team_id}." + ) + # else, check db try: - db_access_time_key = "team_id:{}".format(team_id) - should_check_db = _should_check_db( - key=db_access_time_key, - last_db_access_time=last_db_access_time, - db_cache_expiry=db_cache_expiry, - ) - if should_check_db: - response = await _get_team_db_check( - team_id=team_id, prisma_client=prisma_client - ) - else: - response = None - - if response is None: - raise Exception - - _response = LiteLLM_TeamTableCachedObj(**response.dict()) - # save the team object to cache - await _cache_team_object( + return await _get_team_object_from_user_api_key_cache( team_id=team_id, - team_table=_response, + prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, - ) - - # save to db access time - # save to db access time - _update_last_db_access_time( - key=db_access_time_key, - value=_response, last_db_access_time=last_db_access_time, + db_cache_expiry=db_cache_expiry, + key=key, ) - - return _response except Exception: raise Exception( f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index dc1ec444d..9f749cee1 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -547,6 +547,7 @@ async def team_member_add( parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, check_cache_only=False, + check_db_only=True, ) if existing_team_row is None: raise HTTPException(