diff --git a/docs/my-website/docs/proxy/enterprise.md b/docs/my-website/docs/proxy/enterprise.md
index 26308e56ca..23a645bc6e 100644
--- a/docs/my-website/docs/proxy/enterprise.md
+++ b/docs/my-website/docs/proxy/enterprise.md
@@ -14,6 +14,7 @@ Features here are behind a commercial license in our `/enterprise` folder. [**Se
 Features:
 - [ ] Content Moderation with LlamaGuard
 - [ ] Content Moderation with Google Text Moderations
+- [ ] Content Moderation with LLM Guard
 - [ ] Tracking Spend for Custom Tags
 
 ## Content Moderation with LlamaGuard
@@ -48,6 +49,33 @@ callbacks: ["llamaguard_moderations"]
   llamaguard_unsafe_content_categories: /path/to/llamaguard_prompt.txt
 ```
 
+## Content Moderation with LLM Guard
+
+Set the LLM Guard API Base in your environment
+
+```env
+LLM_GUARD_API_BASE = "http://0.0.0.0:8000"
+```
+
+Add `llmguard_moderations` as a callback
+
+```yaml
+litellm_settings:
+  callbacks: ["llmguard_moderations"]
+```
+
+Now you can easily test it
+
+- Make a regular /chat/completion call
+
+- Check your proxy logs for any statement with `LLM Guard:`
+
+Expected results:
+
+```
+LLM Guard: Received response - {"sanitized_prompt": "hello world", "is_valid": true, "scanners": { "Regex": 0.0 }}
+```
+
 ## Content Moderation with Google Text Moderation
 
 Requires your GOOGLE_APPLICATION_CREDENTIALS to be set in your .env (same as VertexAI).
@@ -102,6 +130,8 @@ Here are the category specific values:
 | "finance" | finance_threshold: 0.1 |
 | "legal" | legal_threshold: 0.1 |
 
+
+
 ## Tracking Spend for Custom Tags
 
 Requirements:
diff --git a/enterprise/enterprise_hooks/llm_guard.py b/enterprise/enterprise_hooks/llm_guard.py
new file mode 100644
index 0000000000..c000f60119
--- /dev/null
+++ b/enterprise/enterprise_hooks/llm_guard.py
@@ -0,0 +1,124 @@
+# +------------------------+
+#
+#           LLM Guard
+#   https://llm-guard.com/
+#
+# +------------------------+
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+## This provides an LLM Guard Integration for content moderation on the proxy
+
+from typing import Optional, Literal, Union
+import litellm, traceback, sys, uuid, os
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm.utils import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    StreamingChoices,
+)
+from datetime import datetime
+import aiohttp, asyncio
+
+litellm.set_verbose = True
+
+
+class _ENTERPRISE_LLMGuard(CustomLogger):
+    # Class variables or attributes
+    def __init__(
+        self, mock_testing: bool = False, mock_redacted_text: Optional[dict] = None
+    ):
+        self.mock_redacted_text = mock_redacted_text
+        if mock_testing == True:  # for testing purposes only
+            return
+        self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
+        if self.llm_guard_api_base is None:
+            raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
+        elif not self.llm_guard_api_base.endswith("/"):
+            self.llm_guard_api_base += "/"
+
+    def print_verbose(self, print_statement):
+        try:
+            verbose_proxy_logger.debug(print_statement)
+            if litellm.set_verbose:
+                print(print_statement)  # noqa
+        except:
+            pass
+
+    async def moderation_check(self, text: str):
+        """
+        [TODO] make this more performant for high-throughput scenario
+        """
+        try:
+            async with aiohttp.ClientSession() as session:
+                if self.mock_redacted_text is not None:
+                    redacted_text = self.mock_redacted_text
+                else:
+                    # Make the request to LLM Guard's /analyze/prompt endpoint
+                    analyze_url = f"{self.llm_guard_api_base}analyze/prompt"
+                    verbose_proxy_logger.debug(f"Making request to: {analyze_url}")
+                    analyze_payload = {"prompt": text}
+                    redacted_text = None
+                    async with session.post(
+                        analyze_url, json=analyze_payload
+                    ) as response:
+                        redacted_text = await response.json()
+                verbose_proxy_logger.info(
+                    f"LLM Guard: Received response - {redacted_text}"
+                )
+                if redacted_text is not None:
+                    if (
+                        redacted_text.get("is_valid", None) is not None
+                        and redacted_text["is_valid"] is not True
+                    ):
+                        raise HTTPException(
+                            status_code=400,
+                            detail={"error": "Violated content safety policy"},
+                        )
+                    else:
+                        pass
+                else:
+                    raise HTTPException(
+                        status_code=500,
+                        detail={
+                            "error": f"Invalid content moderation response: {redacted_text}"
+                        },
+                    )
+        except Exception as e:
+            traceback.print_exc()
+            raise e
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+    ):
+        """
+        - Calls the LLM Guard Endpoint
+        - Rejects request if it fails safety check
+        - Use the sanitized prompt returned
+        - LLM Guard can handle things like PII Masking, etc.
+ """ + if "messages" in data: + safety_check_messages = data["messages"][ + -1 + ] # get the last response - llama guard has a 4k token limit + if ( + isinstance(safety_check_messages, dict) + and "content" in safety_check_messages + and isinstance(safety_check_messages["content"], str) + ): + await self.moderation_check(safety_check_messages["content"]) + + return data + + +# llm_guard = _ENTERPRISE_LLMGuard() + +# asyncio.run( +# llm_guard.async_moderation_hook( +# data={"messages": [{"role": "user", "content": "Hey how's it going?"}]} +# ) +# ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 647599a88d..4c837471c6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1469,6 +1469,16 @@ class ProxyConfig: _ENTERPRISE_GoogleTextModeration() ) imported_list.append(google_text_moderation_obj) + elif ( + isinstance(callback, str) + and callback == "llmguard_moderations" + ): + from litellm.proxy.enterprise.enterprise_hooks.llm_guard import ( + _ENTERPRISE_LLMGuard, + ) + + llm_guard_moderation_obj = _ENTERPRISE_LLMGuard() + imported_list.append(llm_guard_moderation_obj) else: imported_list.append( get_instance_fn( diff --git a/litellm/tests/test_llm_guard.py b/litellm/tests/test_llm_guard.py new file mode 100644 index 0000000000..0f9fad9a4f --- /dev/null +++ b/litellm/tests/test_llm_guard.py @@ -0,0 +1,95 @@ +# What is this? +## This tests the llm guard integration + +# What is this? +## Unit test for presidio pii masking +import sys, os, asyncio, time, random +from datetime import datetime +import traceback +from dotenv import load_dotenv + +load_dotenv() +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest +import litellm +from litellm.proxy.enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard +from litellm import Router, mock_completion +from litellm.proxy.utils import ProxyLogging +from litellm.proxy._types import UserAPIKeyAuth +from litellm.caching import DualCache + +### UNIT TESTS FOR LLM GUARD ### + + +# Test if PII masking works with input A +@pytest.mark.asyncio +async def test_llm_guard_valid_response(): + """ + Tests to see llm guard raises an error for a flagged response + """ + input_a_anonymizer_results = { + "sanitized_prompt": "hello world", + "is_valid": True, + "scanners": {"Regex": 0.0}, + } + llm_guard = _ENTERPRISE_LLMGuard( + mock_testing=True, mock_redacted_text=input_a_anonymizer_results + ) + + _api_key = "sk-12345" + user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) + local_cache = DualCache() + + try: + await llm_guard.async_moderation_hook( + data={ + "messages": [ + { + "role": "user", + "content": "hello world, my name is Jane Doe. 
My number is: 23r323r23r2wwkl", + } + ] + }, + ) + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") + + +# Test if PII masking works with input B (also test if the response != A's response) +@pytest.mark.asyncio +async def test_llm_guard_error_raising(): + """ + Tests to see llm guard raises an error for a flagged response + """ + + input_b_anonymizer_results = { + "sanitized_prompt": "hello world", + "is_valid": False, + "scanners": {"Regex": 0.0}, + } + llm_guard = _ENTERPRISE_LLMGuard( + mock_testing=True, mock_redacted_text=input_b_anonymizer_results + ) + + _api_key = "sk-12345" + user_api_key_dict = UserAPIKeyAuth(api_key=_api_key) + local_cache = DualCache() + + try: + await llm_guard.async_moderation_hook( + data={ + "messages": [ + { + "role": "user", + "content": "hello world, my name is Jane Doe. My number is: 23r323r23r2wwkl", + } + ] + }, + ) + pytest.fail(f"Should have failed - {str(e)}") + except Exception as e: + pass
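
For reference, the "regular /chat/completion call" mentioned in the docs hunk above can be made with the standard OpenAI client pointed at the proxy. This is a minimal sketch, not part of the diff: the `base_url`, port, model name, and virtual key below are placeholders for whatever your running proxy is configured with.

```python
# Minimal sketch: send a normal request through the LiteLLM proxy so the
# `llmguard_moderations` callback runs on the prompt before the LLM call.
# `api_key` and `base_url` are hypothetical values - point them at your own proxy.
import openai

client = openai.OpenAI(
    api_key="sk-1234",  # hypothetical proxy virtual key
    base_url="http://0.0.0.0:8000",  # hypothetical proxy address
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello world"}],
)
print(response.choices[0].message.content)
```

A prompt that LLM Guard flags (`is_valid: false`) should be rejected with the 400 `Violated content safety policy` error raised in `moderation_check`; anything else passes through, and the proxy logs the `LLM Guard: Received response - ...` line shown in the docs.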