From dd151869a3ace1f6e422f137b90889431daeb424 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 16 Mar 2024 13:03:52 -0700 Subject: [PATCH] fix(blocked_user_list.py): check if end user blocked in db --- .../enterprise_hooks/blocked_user_list.py | 54 ++++++++++++++++--- litellm/proxy/_types.py | 16 ++++++ litellm/proxy/proxy_server.py | 37 +++++++------ litellm/proxy/schema.prisma | 2 +- 4 files changed, 86 insertions(+), 23 deletions(-) diff --git a/enterprise/enterprise_hooks/blocked_user_list.py b/enterprise/enterprise_hooks/blocked_user_list.py index 686fdf1de..cbc14d2c2 100644 --- a/enterprise/enterprise_hooks/blocked_user_list.py +++ b/enterprise/enterprise_hooks/blocked_user_list.py @@ -9,8 +9,9 @@ from typing import Optional, Literal import litellm +from litellm.proxy.utils import PrismaClient from litellm.caching import DualCache -from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable from litellm.integrations.custom_logger import CustomLogger from litellm._logging import verbose_proxy_logger from fastapi import HTTPException @@ -19,13 +20,13 @@ import json, traceback class _ENTERPRISE_BlockedUserList(CustomLogger): # Class variables or attributes - def __init__(self): - blocked_user_list = litellm.blocked_user_list + def __init__(self, prisma_client: Optional[PrismaClient]): + self.prisma_client = prisma_client + blocked_user_list = litellm.blocked_user_list if blocked_user_list is None: - raise Exception( - "`blocked_user_list` can either be a list or filepath. None set." - ) + self.blocked_user_list = None + return if isinstance(blocked_user_list, list): self.blocked_user_list = blocked_user_list @@ -64,17 +65,56 @@ class _ENTERPRISE_BlockedUserList(CustomLogger): """ - check if user id part of call - check if user id part of blocked list + - if blocked list is none or user not in blocked list + - check if end-user in cache + - check if end-user in db """ self.print_verbose(f"Inside Blocked User List Pre-Call Hook") if "user_id" in data or "user" in data: user = data.get("user_id", data.get("user", "")) - if user in self.blocked_user_list: + if ( + self.blocked_user_list is not None + and user in self.blocked_user_list + ): raise HTTPException( status_code=400, detail={ "error": f"User blocked from making LLM API Calls. User={user}" }, ) + + cache_key = f"litellm:end_user_id:{user}" + end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache( + key=cache_key + ) + if end_user_cache_obj is None and self.prisma_client is not None: + # check db + end_user_obj = ( + await self.prisma_client.db.litellm_endusertable.find_unique( + where={"user_id": user} + ) + ) + if end_user_obj is None: # user not in db - assume not blocked + end_user_obj = LiteLLM_EndUserTable(user_id=user, blocked=False) + cache.set_cache(key=cache_key, value=end_user_obj, ttl=60) + if end_user_obj is not None and end_user_obj.blocked == True: + raise HTTPException( + status_code=400, + detail={ + "error": f"User blocked from making LLM API Calls. User={user}" + }, + ) + elif ( + end_user_cache_obj is not None + and end_user_cache_obj.blocked == True + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"User blocked from making LLM API Calls. User={user}" + }, + ) + except HTTPException as e: raise e except Exception as e: diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8a7efa1a1..a8c0c3d27 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -602,6 +602,22 @@ class LiteLLM_UserTable(LiteLLMBase): protected_namespaces = () +class LiteLLM_EndUserTable(LiteLLMBase): + user_id: str + blocked: bool + alias: Optional[str] = None + spend: float = 0.0 + + @root_validator(pre=True) + def set_model_info(cls, values): + if values.get("spend") is None: + values.update({"spend": 0.0}) + return values + + class Config: + protected_namespaces = () + + class LiteLLM_SpendLogs(LiteLLMBase): request_id: str api_key: str diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b1aac7791..91fc8295d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1774,7 +1774,9 @@ class ProxyConfig: _ENTERPRISE_BlockedUserList, ) - blocked_user_list = _ENTERPRISE_BlockedUserList() + blocked_user_list = _ENTERPRISE_BlockedUserList( + prisma_client=prisma_client + ) imported_list.append(blocked_user_list) elif ( isinstance(callback, str) @@ -5111,22 +5113,27 @@ async def block_user(data: BlockUsers): }' ``` """ - if prisma_client is not None: - for id in data.user_ids: - await prisma_client.db.litellm_endusertable.upsert( - where={"id": id}, - data={ - "create": {"id": id, "blocked": True}, - "update": {"blocked": True}, - }, + try: + records = [] + if prisma_client is not None: + for id in data.user_ids: + record = await prisma_client.db.litellm_endusertable.upsert( + where={"user_id": id}, + data={ + "create": {"user_id": id, "blocked": True}, + "update": {"blocked": True}, + }, + ) + records.append(record) + else: + raise HTTPException( + status_code=500, + detail={"error": "Postgres DB Not connected"}, ) - else: - raise HTTPException( - status_code=500, - detail={"error": "Postgres DB Not connected"}, - ) - return {"blocked_users": litellm.blocked_user_list} + return {"blocked_users": records} + except Exception as e: + raise HTTPException(status_code=500, detail={"error": str(e)}) @router.post( diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index c11a387a8..6dd89bd85 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -129,7 +129,7 @@ model LiteLLM_VerificationToken { } model LiteLLM_EndUserTable { - id String @id + user_id String @id alias String? // admin-facing alias spend Float @default(0.0) budget_id String?