diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 1650eb8aa..91d4b1938 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -57,6 +57,7 @@ def common_checks( 4. If end_user (either via JWT or 'user' passed to /chat/completions, /embeddings endpoint) is in budget 5. [OPTIONAL] If 'enforce_end_user' enabled - did developer pass in 'user' param for openai endpoints 6. [OPTIONAL] If 'litellm.max_budget' is set (>0), is proxy under budget + 7. [OPTIONAL] If guardrails modified - is request allowed to change this """ _model = request_body.get("model", None) if team_object is not None and team_object.blocked is True: @@ -158,6 +159,22 @@ def common_checks( raise litellm.BudgetExceededError( current_cost=global_proxy_spend, max_budget=litellm.max_budget ) + + _request_metadata: dict = request_body.get("metadata", {}) or {} + if _request_metadata.get("guardrails"): + # check if team allowed to modify guardrails + from litellm.proxy.guardrails.guardrail_helpers import can_modify_guardrails + + can_modify: bool = can_modify_guardrails(team_object) + if can_modify is False: + from fastapi import HTTPException + + raise HTTPException( + status_code=403, + detail={ + "error": "Your team does not have permission to modify guardrails." 
+ }, + ) return True diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index b4c88148e..d91baf5ca 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -924,6 +924,7 @@ async def user_api_key_auth( rpm_limit=valid_token.team_rpm_limit, blocked=valid_token.team_blocked, models=valid_token.team_models, + metadata=valid_token.team_metadata, ) user_api_key_cache.set_cache( diff --git a/litellm/proxy/guardrails/guardrail_helpers.py b/litellm/proxy/guardrails/guardrail_helpers.py index d6a081b4d..e0a5f1eb3 100644 --- a/litellm/proxy/guardrails/guardrail_helpers.py +++ b/litellm/proxy/guardrails/guardrail_helpers.py @@ -1,9 +1,26 @@ +from typing import Dict, Optional + import litellm from litellm._logging import verbose_proxy_logger -from litellm.proxy.proxy_server import UserAPIKeyAuth +from litellm.proxy.proxy_server import LiteLLM_TeamTable, UserAPIKeyAuth from litellm.types.guardrails import * +def can_modify_guardrails(team_obj: Optional[LiteLLM_TeamTable]) -> bool: + if team_obj is None: + return True + + team_metadata = team_obj.metadata or {} + + if team_metadata.get("guardrails", None) is not None and isinstance( + team_metadata.get("guardrails"), Dict + ): + if team_metadata.get("guardrails", {}).get("modify_guardrails", None) is False: + return False + + return True + + async def should_proceed_based_on_metadata(data: dict, guardrail_name: str) -> bool: """ checks if this guardrail should be applied to this call diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index bb98a02ec..9ba76a203 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -363,6 +363,7 @@ async def update_team( # set the budget_reset_at in DB updated_kv["budget_reset_at"] = reset_at + updated_kv = prisma_client.jsonify_object(data=updated_kv) team_row: Optional[
LiteLLM_TeamTable ] = await prisma_client.db.litellm_teamtable.update( diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index ed7451c27..f3cb69a08 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -173,6 +173,63 @@ def test_chat_completion(mock_acompletion, client_no_auth): pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+ """ + import asyncio + import json + import time + + from fastapi import HTTPException, Request + from starlette.datastructures import URL + + from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, UserAPIKeyAuth + from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + from litellm.proxy.proxy_server import hash_token, user_api_key_cache + + _team_id = "1234" + user_key = "sk-12345678" + + valid_token = UserAPIKeyAuth( + team_id=_team_id, + team_blocked=True, + token=hash_token(user_key), + last_refreshed_at=time.time(), + ) + await asyncio.sleep(1) + team_obj = LiteLLM_TeamTable( + team_id=_team_id, + blocked=False, + last_refreshed_at=time.time(), + metadata={"guardrails": {"modify_guardrails": False}}, + ) + user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token) + user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj) + + setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + setattr(litellm.proxy.proxy_server, "prisma_client", "hello-world") + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + body = {"metadata": {"guardrails": {"hide_secrets": False}}} + json_bytes = json.dumps(body).encode("utf-8") + + request._body = json_bytes + + try: + await user_api_key_auth(request=request, api_key="Bearer " + user_key) + pytest.fail("Expected to raise 403 forbidden error.") + except ProxyException as e: + assert e.code == 403 + + from litellm.tests.test_custom_callback_input import CompletionCustomHandler