forked from phoenix/litellm-mirror
fix(blocked_user_list.py): check if end user blocked in db
This commit is contained in:
parent
ef0002f31c
commit
dd151869a3
4 changed files with 86 additions and 23 deletions
|
@ -9,8 +9,9 @@
|
|||
|
||||
from typing import Optional, Literal
|
||||
import litellm
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from fastapi import HTTPException
|
||||
|
@ -19,13 +20,13 @@ import json, traceback
|
|||
|
||||
class _ENTERPRISE_BlockedUserList(CustomLogger):
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
blocked_user_list = litellm.blocked_user_list
|
||||
def __init__(self, prisma_client: Optional[PrismaClient]):
|
||||
self.prisma_client = prisma_client
|
||||
|
||||
blocked_user_list = litellm.blocked_user_list
|
||||
if blocked_user_list is None:
|
||||
raise Exception(
|
||||
"`blocked_user_list` can either be a list or filepath. None set."
|
||||
)
|
||||
self.blocked_user_list = None
|
||||
return
|
||||
|
||||
if isinstance(blocked_user_list, list):
|
||||
self.blocked_user_list = blocked_user_list
|
||||
|
@ -64,17 +65,56 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
|
|||
"""
|
||||
- check if user id part of call
|
||||
- check if user id part of blocked list
|
||||
- if blocked list is none or user not in blocked list
|
||||
- check if end-user in cache
|
||||
- check if end-user in db
|
||||
"""
|
||||
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
|
||||
if "user_id" in data or "user" in data:
|
||||
user = data.get("user_id", data.get("user", ""))
|
||||
if user in self.blocked_user_list:
|
||||
if (
|
||||
self.blocked_user_list is not None
|
||||
and user in self.blocked_user_list
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"User blocked from making LLM API Calls. User={user}"
|
||||
},
|
||||
)
|
||||
|
||||
cache_key = f"litellm:end_user_id:{user}"
|
||||
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
|
||||
key=cache_key
|
||||
)
|
||||
if end_user_cache_obj is None and self.prisma_client is not None:
|
||||
# check db
|
||||
end_user_obj = (
|
||||
await self.prisma_client.db.litellm_endusertable.find_unique(
|
||||
where={"user_id": user}
|
||||
)
|
||||
)
|
||||
if end_user_obj is None: # user not in db - assume not blocked
|
||||
end_user_obj = LiteLLM_EndUserTable(user_id=user, blocked=False)
|
||||
cache.set_cache(key=cache_key, value=end_user_obj, ttl=60)
|
||||
if end_user_obj is not None and end_user_obj.blocked == True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"User blocked from making LLM API Calls. User={user}"
|
||||
},
|
||||
)
|
||||
elif (
|
||||
end_user_cache_obj is not None
|
||||
and end_user_cache_obj.blocked == True
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"User blocked from making LLM API Calls. User={user}"
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
|
|
|
@ -602,6 +602,22 @@ class LiteLLM_UserTable(LiteLLMBase):
|
|||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_EndUserTable(LiteLLMBase):
|
||||
user_id: str
|
||||
blocked: bool
|
||||
alias: Optional[str] = None
|
||||
spend: float = 0.0
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("spend") is None:
|
||||
values.update({"spend": 0.0})
|
||||
return values
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
class LiteLLM_SpendLogs(LiteLLMBase):
|
||||
request_id: str
|
||||
api_key: str
|
||||
|
|
|
@ -1774,7 +1774,9 @@ class ProxyConfig:
|
|||
_ENTERPRISE_BlockedUserList,
|
||||
)
|
||||
|
||||
blocked_user_list = _ENTERPRISE_BlockedUserList()
|
||||
blocked_user_list = _ENTERPRISE_BlockedUserList(
|
||||
prisma_client=prisma_client
|
||||
)
|
||||
imported_list.append(blocked_user_list)
|
||||
elif (
|
||||
isinstance(callback, str)
|
||||
|
@ -5111,22 +5113,27 @@ async def block_user(data: BlockUsers):
|
|||
}'
|
||||
```
|
||||
"""
|
||||
if prisma_client is not None:
|
||||
for id in data.user_ids:
|
||||
await prisma_client.db.litellm_endusertable.upsert(
|
||||
where={"id": id},
|
||||
data={
|
||||
"create": {"id": id, "blocked": True},
|
||||
"update": {"blocked": True},
|
||||
},
|
||||
try:
|
||||
records = []
|
||||
if prisma_client is not None:
|
||||
for id in data.user_ids:
|
||||
record = await prisma_client.db.litellm_endusertable.upsert(
|
||||
where={"user_id": id},
|
||||
data={
|
||||
"create": {"user_id": id, "blocked": True},
|
||||
"update": {"blocked": True},
|
||||
},
|
||||
)
|
||||
records.append(record)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Postgres DB Not connected"},
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Postgres DB Not connected"},
|
||||
)
|
||||
|
||||
return {"blocked_users": litellm.blocked_user_list}
|
||||
return {"blocked_users": records}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
@router.post(
|
||||
|
|
|
@ -129,7 +129,7 @@ model LiteLLM_VerificationToken {
|
|||
}
|
||||
|
||||
model LiteLLM_EndUserTable {
|
||||
id String @id
|
||||
user_id String @id
|
||||
alias String? // admin-facing alias
|
||||
spend Float @default(0.0)
|
||||
budget_id String?
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue