feat(auth_checks.py): Allow admin to disable team from turning on/off guardrails.

This commit is contained in:
Krrish Dholakia 2024-07-20 18:39:05 -07:00
parent 2ff30fdace
commit a351b7cc3e
7 changed files with 96 additions and 1 deletions

View file

@ -173,6 +173,63 @@ def test_chat_completion(mock_acompletion, client_no_auth):
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
@mock_patch_acompletion()
@pytest.mark.asyncio
async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
"""
If team not allowed to turn on/off guardrails
Raise 403 forbidden error, if request is made by team on `/key/generate` or `/chat/completions`.
"""
import asyncio
import json
import time
from fastapi import HTTPException, Request
from starlette.datastructures import URL
from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import hash_token, user_api_key_cache
_team_id = "1234"
user_key = "sk-12345678"
valid_token = UserAPIKeyAuth(
team_id=_team_id,
team_blocked=True,
token=hash_token(user_key),
last_refreshed_at=time.time(),
)
await asyncio.sleep(1)
team_obj = LiteLLM_TeamTable(
team_id=_team_id,
blocked=False,
last_refreshed_at=time.time(),
metadata={"guardrails": {"modify_guardrails": False}},
)
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world")
request = Request(scope={"type": "http"})
request._url = URL(url="/chat/completions")
body = {"metadata": {"guardrails": {"hide_secrets": False}}}
json_bytes = json.dumps(body).encode("utf-8")
request._body = json_bytes
try:
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
pytest.fail("Expected to raise 403 forbidden error.")
except ProxyException as e:
assert e.code == 403
from litellm.tests.test_custom_callback_input import CompletionCustomHandler