Merge pull request #4810 from BerriAI/litellm_team_modify_guardrails

feat(auth_checks.py): Allow admin to disable team from turning on/off guardrails

Commit c4db6aa15e: 6 changed files, 95 additions, 1 deletion
@@ -57,6 +57,7 @@ def common_checks(
     4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget
     5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints
     6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget
+    7. [OPTIONAL] If guardrails modified - is request allowed to change this
     """
     _model = request_body.get("model", None)
     if team_object is not None and team_object.blocked is True:
@@ -158,6 +159,22 @@ def common_checks(
             raise litellm.BudgetExceededError(
                 current_cost=global_proxy_spend, max_budget=litellm.max_budget
             )
+
+    _request_metadata: dict = request_body.get("metadata", {}) or {}
+    if _request_metadata.get("guardrails"):
+        # check if team allowed to modify guardrails
+        from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails
+
+        can_modify: bool = can_modify_guardrails(team_object)
+        if can_modify is False:
+            from fastapi import HTTPException
+
+            raise HTTPException(
+                status_code=403,
+                detail={
+                    "error": "Your team does not have permission to modify guardrails."
+                },
+            )
     return True
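For context on the block above: the new check fires whenever the request body carries guardrail overrides under `metadata`. A minimal sketch of such a payload (the model name and the `hide_secrets` guardrail key are illustrative, borrowed from the test at the end of this diff):

# Hypothetical /chat/completions body that trips the new permission check.
request_body = {
    "model": "gpt-3.5-turbo",  # illustrative
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {"guardrails": {"hide_secrets": False}},  # guardrail override
}

_request_metadata = request_body.get("metadata", {}) or {}
print(bool(_request_metadata.get("guardrails")))  # True -> can_modify_guardrails runs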
@@ -924,6 +924,7 @@ async def user_api_key_auth(
                 rpm_limit=valid_token.team_rpm_limit,
                 blocked=valid_token.team_blocked,
                 models=valid_token.team_models,
+                metadata=valid_token.team_metadata,
             )

             user_api_key_cache.set_cache(
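The single added line above matters downstream: the team object rebuilt from the cached token is what the guardrail permission check inspects, so if `team_metadata` is not copied across, the check always sees an empty value. A minimal sketch of that failure mode, with plain dicts standing in for the token and team objects:

# What the guardrail check sees with and without the added line.
valid_token_team_metadata = {"guardrails": {"modify_guardrails": False}}

for metadata in (valid_token_team_metadata, None):  # copied vs. left unset
    guardrails = (metadata or {}).get("guardrails", {})
    print(guardrails.get("modify_guardrails"))  # False, then None (never denied)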
@@ -1,9 +1,26 @@
+from typing import Dict
+
 import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy.proxy_server import UserAPIKeyAuth
+from litellm.proxy.proxy_server import LiteLLM_TeamTable, UserAPIKeyAuth
 from litellm.types.guardrails import *


+def can_modify_guardrails(team_obj: Optional[LiteLLM_TeamTable]) -> bool:
+    if team_obj is None:
+        return True
+
+    team_metadata = team_obj.metadata or {}
+
+    if team_metadata.get("guardrails", None) is not None and isinstance(
+        team_metadata.get("guardrails"), Dict
+    ):
+        if team_metadata.get("guardrails", {}).get("modify_guardrails", None) is False:
+            return False
+
+    return True
+
+
 async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool:
     """
     checks if this guardrail should be applied to this call
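To make the new helper concrete, here is a self-contained usage sketch. `FakeTeam` is an illustrative stand-in for `LiteLLM_TeamTable` so the snippet runs on its own, and the helper body is the diff's logic in condensed form:

from typing import Any, Dict, Optional


class FakeTeam:  # illustrative stand-in for LiteLLM_TeamTable
    def __init__(self, metadata: Optional[Dict[str, Any]] = None) -> None:
        self.metadata = metadata


def can_modify_guardrails(team_obj: Optional[FakeTeam]) -> bool:
    # Condensed version of the logic added above.
    if team_obj is None:
        return True  # no team context: nothing to enforce
    guardrails = (team_obj.metadata or {}).get("guardrails")
    if isinstance(guardrails, dict) and guardrails.get("modify_guardrails") is False:
        return False  # admin explicitly locked guardrail changes
    return True


print(can_modify_guardrails(None))        # True
print(can_modify_guardrails(FakeTeam()))  # True: no guardrails metadata set
print(can_modify_guardrails(FakeTeam({"guardrails": {"modify_guardrails": False}})))  # False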
@@ -363,6 +363,7 @@ async def update_team(
         # set the budget_reset_at in DB
         updated_kv["budget_reset_at"] = reset_at

+    updated_kv = prisma_client.jsonify_object(data=updated_kv)
     team_row: Optional[
         LiteLLM_TeamTable
     ] = await prisma_client.db.litellm_teamtable.update(
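The added `jsonify_object` call is what lets nested values like the guardrails metadata survive the Prisma update. A hedged sketch of the idea; the helper below is a hypothetical stand-in with the same intent, not the actual implementation:

import json

def jsonify_object(data: dict) -> dict:
    # Hypothetical stand-in: Prisma JSON columns expect serialized values,
    # so nested dicts/lists in the update payload are dumped to strings.
    return {
        k: (json.dumps(v) if isinstance(v, (dict, list)) else v)
        for k, v in data.items()
    }

updated_kv = {"metadata": {"guardrails": {"modify_guardrails": False}}}
print(jsonify_object(updated_kv))
# {'metadata': '{"guardrails": {"modify_guardrails": false}}'}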
@@ -1322,6 +1322,7 @@ class PrismaClient:
                 t.metadata AS team_metadata,
                 t.blocked AS team_blocked,
                 t.team_alias AS team_alias,
+                t.metadata AS team_metadata,
                 tm.spend AS team_member_spend,
                 m.aliases as team_model_aliases
             FROM "LiteLLM_VerificationToken" AS v
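Selecting `t.metadata AS team_metadata` here is what feeds the `metadata=valid_token.team_metadata` assignment in the auth hunk above. A rough sketch of the mapping, with an illustrative row dict in place of the real query result:

# Illustrative joined row; aliases follow the SELECT list above.
row = {
    "team_blocked": False,
    "team_metadata": {"guardrails": {"modify_guardrails": False}},
}
# The auth layer forwards these fields onto the token object, roughly:
#   valid_token.team_metadata = row["team_metadata"]
print(row["team_metadata"]["guardrails"]["modify_guardrails"])  # False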
@@ -173,6 +173,63 @@ def test_chat_completion(mock_acompletion, client_no_auth):
         pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")


+@mock_patch_acompletion()
+@pytest.mark.asyncio
+async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
+    """
+    If a team is not allowed to turn guardrails on/off:
+
+    raise a 403 Forbidden error when the team makes a request to `/key/generate` or `/chat/completions`.
+    """
+    import asyncio
+    import json
+    import time
+
+    from fastapi import HTTPException, Request
+    from starlette.datastructures import URL
+
+    from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, UserAPIKeyAuth
+    from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+    from litellm.proxy.proxy_server import hash_token, user_api_key_cache
+
+    _team_id = "1234"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        team_blocked=True,
+        token=hash_token(user_key),
+        last_refreshed_at=time.time(),
+    )
+    await asyncio.sleep(1)
+    team_obj = LiteLLM_TeamTable(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+
+    body = {"metadata": {"guardrails": {"hide_secrets": False}}}
+    json_bytes = json.dumps(body).encode("utf-8")
+
+    request._body = json_bytes
+
+    try:
+        await user_api_key_auth(request=request, api_key="Bearer " + user_key)
+        pytest.fail("Expected to raise 403 forbidden error.")
+    except ProxyException as e:
+        assert e.code == 403
+
+
 from litellm.tests.test_custom_callback_input import CompletionCustomHandler
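Outside the unit test, the end-to-end behavior this asserts would look roughly like the following client call against a running proxy (the URL, key, and model are illustrative; assumes the team's metadata carries `modify_guardrails: false`):

import requests  # third-party HTTP client, used here for illustration

resp = requests.post(
    "http://localhost:4000/chat/completions",  # assumed local proxy address
    headers={"Authorization": "Bearer sk-12345678"},  # key of the restricted team
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
        "metadata": {"guardrails": {"hide_secrets": False}},  # attempt to toggle a guardrail
    },
)
print(resp.status_code)  # expect 403: team may not modify guardrails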
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue