forked from phoenix/litellm-mirror
fix(blocked_user_list.py): check if end user blocked in db
This commit is contained in:
parent
ef0002f31c
commit
dd151869a3
4 changed files with 86 additions and 23 deletions
|
@ -9,8 +9,9 @@
|
||||||
|
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal
|
||||||
import litellm
|
import litellm
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_EndUserTable
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
@ -19,13 +20,13 @@ import json, traceback
|
||||||
|
|
||||||
class _ENTERPRISE_BlockedUserList(CustomLogger):
|
class _ENTERPRISE_BlockedUserList(CustomLogger):
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self, prisma_client: Optional[PrismaClient]):
|
||||||
blocked_user_list = litellm.blocked_user_list
|
self.prisma_client = prisma_client
|
||||||
|
|
||||||
|
blocked_user_list = litellm.blocked_user_list
|
||||||
if blocked_user_list is None:
|
if blocked_user_list is None:
|
||||||
raise Exception(
|
self.blocked_user_list = None
|
||||||
"`blocked_user_list` can either be a list or filepath. None set."
|
return
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(blocked_user_list, list):
|
if isinstance(blocked_user_list, list):
|
||||||
self.blocked_user_list = blocked_user_list
|
self.blocked_user_list = blocked_user_list
|
||||||
|
@ -64,17 +65,56 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
|
||||||
"""
|
"""
|
||||||
- check if user id part of call
|
- check if user id part of call
|
||||||
- check if user id part of blocked list
|
- check if user id part of blocked list
|
||||||
|
- if blocked list is none or user not in blocked list
|
||||||
|
- check if end-user in cache
|
||||||
|
- check if end-user in db
|
||||||
"""
|
"""
|
||||||
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
|
self.print_verbose(f"Inside Blocked User List Pre-Call Hook")
|
||||||
if "user_id" in data or "user" in data:
|
if "user_id" in data or "user" in data:
|
||||||
user = data.get("user_id", data.get("user", ""))
|
user = data.get("user_id", data.get("user", ""))
|
||||||
if user in self.blocked_user_list:
|
if (
|
||||||
|
self.blocked_user_list is not None
|
||||||
|
and user in self.blocked_user_list
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail={
|
detail={
|
||||||
"error": f"User blocked from making LLM API Calls. User={user}"
|
"error": f"User blocked from making LLM API Calls. User={user}"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cache_key = f"litellm:end_user_id:{user}"
|
||||||
|
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
|
||||||
|
key=cache_key
|
||||||
|
)
|
||||||
|
if end_user_cache_obj is None and self.prisma_client is not None:
|
||||||
|
# check db
|
||||||
|
end_user_obj = (
|
||||||
|
await self.prisma_client.db.litellm_endusertable.find_unique(
|
||||||
|
where={"user_id": user}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if end_user_obj is None: # user not in db - assume not blocked
|
||||||
|
end_user_obj = LiteLLM_EndUserTable(user_id=user, blocked=False)
|
||||||
|
cache.set_cache(key=cache_key, value=end_user_obj, ttl=60)
|
||||||
|
if end_user_obj is not None and end_user_obj.blocked == True:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": f"User blocked from making LLM API Calls. User={user}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
end_user_cache_obj is not None
|
||||||
|
and end_user_cache_obj.blocked == True
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": f"User blocked from making LLM API Calls. User={user}"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -602,6 +602,22 @@ class LiteLLM_UserTable(LiteLLMBase):
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_EndUserTable(LiteLLMBase):
|
||||||
|
user_id: str
|
||||||
|
blocked: bool
|
||||||
|
alias: Optional[str] = None
|
||||||
|
spend: float = 0.0
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def set_model_info(cls, values):
|
||||||
|
if values.get("spend") is None:
|
||||||
|
values.update({"spend": 0.0})
|
||||||
|
return values
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM_SpendLogs(LiteLLMBase):
|
class LiteLLM_SpendLogs(LiteLLMBase):
|
||||||
request_id: str
|
request_id: str
|
||||||
api_key: str
|
api_key: str
|
||||||
|
|
|
@ -1774,7 +1774,9 @@ class ProxyConfig:
|
||||||
_ENTERPRISE_BlockedUserList,
|
_ENTERPRISE_BlockedUserList,
|
||||||
)
|
)
|
||||||
|
|
||||||
blocked_user_list = _ENTERPRISE_BlockedUserList()
|
blocked_user_list = _ENTERPRISE_BlockedUserList(
|
||||||
|
prisma_client=prisma_client
|
||||||
|
)
|
||||||
imported_list.append(blocked_user_list)
|
imported_list.append(blocked_user_list)
|
||||||
elif (
|
elif (
|
||||||
isinstance(callback, str)
|
isinstance(callback, str)
|
||||||
|
@ -5111,22 +5113,27 @@ async def block_user(data: BlockUsers):
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if prisma_client is not None:
|
try:
|
||||||
for id in data.user_ids:
|
records = []
|
||||||
await prisma_client.db.litellm_endusertable.upsert(
|
if prisma_client is not None:
|
||||||
where={"id": id},
|
for id in data.user_ids:
|
||||||
data={
|
record = await prisma_client.db.litellm_endusertable.upsert(
|
||||||
"create": {"id": id, "blocked": True},
|
where={"user_id": id},
|
||||||
"update": {"blocked": True},
|
data={
|
||||||
},
|
"create": {"user_id": id, "blocked": True},
|
||||||
|
"update": {"blocked": True},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
records.append(record)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={"error": "Postgres DB Not connected"},
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail={"error": "Postgres DB Not connected"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"blocked_users": litellm.blocked_user_list}
|
return {"blocked_users": records}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
|
@ -129,7 +129,7 @@ model LiteLLM_VerificationToken {
|
||||||
}
|
}
|
||||||
|
|
||||||
model LiteLLM_EndUserTable {
|
model LiteLLM_EndUserTable {
|
||||||
id String @id
|
user_id String @id
|
||||||
alias String? // admin-facing alias
|
alias String? // admin-facing alias
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
budget_id String?
|
budget_id String?
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue