fix(blocked_user_list.py): check if end user blocked in db

This commit is contained in:
Krrish Dholakia 2024-03-16 13:03:52 -07:00
parent ef0002f31c
commit dd151869a3
4 changed files with 86 additions and 23 deletions

View file

@ -9,8 +9,9 @@
from typing import Optional, Literal
import litellm
from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_proxy_logger
from fastapi import HTTPException
@ -19,13 +20,13 @@ import json, traceback
class _ENTERPRISE_BlockedUserList(CustomLogger):
# Class variables or attributes
def __init__(self):
blocked_user_list = litellm.blocked_user_list
def __init__(self, prisma_client: Optional[PrismaClient]):
self.prisma_client = prisma_client
blocked_user_list = litellm.blocked_user_list
if blocked_user_list is None:
raise Exception(
"`blocked_user_list` can either be a list or filepath. None set."
)
self.blocked_user_list = None
return
if isinstance(blocked_user_list, list):
self.blocked_user_list = blocked_user_list
@ -64,17 +65,56 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
"""
- check if user id part of call
- check if user id part of blocked list
- if blocked list is none or user not in blocked list
- check if end-user in cache
- check if end-user in db
"""
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
if "user_id" in data or "user" in data:
user = data.get("user_id", data.get("user", ""))
if user in self.blocked_user_list:
if (
self.blocked_user_list is not None
and user in self.blocked_user_list
):
raise HTTPException(
status_code=400,
detail={
"error": f"User blocked from making LLM API Calls. User={user}"
},
)
cache_key = f"litellm:end_user_id:{user}"
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
key=cache_key
)
if end_user_cache_obj is None and self.prisma_client is not None:
# check db
end_user_obj = (
await self.prisma_client.db.litellm_endusertable.find_unique(
where={"user_id": user}
)
)
if end_user_obj is None: # user not in db - assume not blocked
end_user_obj = LiteLLM_EndUserTable(user_id=user, blocked=False)
cache.set_cache(key=cache_key, value=end_user_obj, ttl=60)
if end_user_obj is not None and end_user_obj.blocked == True:
raise HTTPException(
status_code=400,
detail={
"error": f"User blocked from making LLM API Calls. User={user}"
},
)
elif (
end_user_cache_obj is not None
and end_user_cache_obj.blocked == True
):
raise HTTPException(
status_code=400,
detail={
"error": f"User blocked from making LLM API Calls. User={user}"
},
)
except HTTPException as e:
raise e
except Exception as e:

View file

@ -602,6 +602,22 @@ class LiteLLM_UserTable(LiteLLMBase):
protected_namespaces = ()
class LiteLLM_EndUserTable(LiteLLMBase):
user_id: str
blocked: bool
alias: Optional[str] = None
spend: float = 0.0
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values
class Config:
protected_namespaces = ()
class LiteLLM_SpendLogs(LiteLLMBase):
request_id: str
api_key: str

View file

@ -1774,7 +1774,9 @@ class ProxyConfig:
_ENTERPRISE_BlockedUserList,
)
blocked_user_list = _ENTERPRISE_BlockedUserList()
blocked_user_list = _ENTERPRISE_BlockedUserList(
prisma_client=prisma_client
)
imported_list.append(blocked_user_list)
elif (
isinstance(callback, str)
@ -5111,22 +5113,27 @@ async def block_user(data: BlockUsers):
}'
```
"""
if prisma_client is not None:
for id in data.user_ids:
await prisma_client.db.litellm_endusertable.upsert(
where={"id": id},
data={
"create": {"id": id, "blocked": True},
"update": {"blocked": True},
},
try:
records = []
if prisma_client is not None:
for id in data.user_ids:
record = await prisma_client.db.litellm_endusertable.upsert(
where={"user_id": id},
data={
"create": {"user_id": id, "blocked": True},
"update": {"blocked": True},
},
)
records.append(record)
else:
raise HTTPException(
status_code=500,
detail={"error": "Postgres DB Not connected"},
)
else:
raise HTTPException(
status_code=500,
detail={"error": "Postgres DB Not connected"},
)
return {"blocked_users": litellm.blocked_user_list}
return {"blocked_users": records}
except Exception as e:
raise HTTPException(status_code=500, detail={"error": str(e)})
@router.post(

View file

@ -129,7 +129,7 @@ model LiteLLM_VerificationToken {
}
model LiteLLM_EndUserTable {
id String @id
user_id String @id
alias String? // admin-facing alias
spend Float @default(0.0)
budget_id String?