From 1fe035c6ddbc43bca89faf8a75c3460bc101636d Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Thu, 23 May 2024 13:08:06 -0700
Subject: [PATCH] feat - add open ai moderations check

---
 .../enterprise_hooks/openai_moderation.py | 66 +++++++++++++++++++
 litellm/__init__.py                       |  1 +
 litellm/proxy/proxy_server.py             | 12 ++++
 3 files changed, 79 insertions(+)
 create mode 100644 enterprise/enterprise_hooks/openai_moderation.py

diff --git a/enterprise/enterprise_hooks/openai_moderation.py b/enterprise/enterprise_hooks/openai_moderation.py
new file mode 100644
index 000000000..78a13a5db
--- /dev/null
+++ b/enterprise/enterprise_hooks/openai_moderation.py
@@ -0,0 +1,66 @@
+# +-------------------------------------------------------------+
+#
+#        Use OpenAI /moderations for your LLM calls
+#
+# +-------------------------------------------------------------+
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import sys, os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+from typing import Optional, Literal, Union
+import litellm, traceback, uuid
+from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.integrations.custom_logger import CustomLogger
+from fastapi import HTTPException
+from litellm._logging import verbose_proxy_logger
+from litellm.utils import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    StreamingChoices,
+)
+from datetime import datetime
+import aiohttp, asyncio
+
+litellm.set_verbose = True
+
+
+class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
+    def __init__(self):
+        self.model_name = (
+            litellm.openai_moderations_model_name or "text-moderation-latest"
+        )  # pass the model_name you initialized on litellm.Router()
+        pass
+
+    #### CALL HOOKS - proxy only ####
+
+    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal["completion", "embeddings", "image_generation"],
+    ):
+        text = ""  # default so the moderation call below never sees an unbound name
+        if "messages" in data and isinstance(data["messages"], list):
+            # concatenate every string message content into one moderation input
+            for m in data["messages"]:
+                if "content" in m and isinstance(m["content"], str):
+                    text += m["content"]
+
+        from litellm.proxy.proxy_server import llm_router
+
+        if llm_router is None:
+            return
+
+        moderation_response = await llm_router.amoderation(
+            model=self.model_name, input=text
+        )
+        if moderation_response.results[0].flagged:
+            raise HTTPException(
+                status_code=403, detail={"error": "Violated content safety policy"}
+            )
+        pass
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 92610afd9..2af65f790 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -97,6 +97,7 @@ ssl_verify: bool = True
 disable_streaming_logging: bool = False
 ### GUARDRAILS ###
 llamaguard_model_name: Optional[str] = None
+openai_moderations_model_name: Optional[str] = None
 presidio_ad_hoc_recognizers: Optional[str] = None
 google_moderation_confidence_threshold: Optional[float] = None
 llamaguard_unsafe_content_categories: Optional[str] = None
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index b52c9b249..4045c7d91 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2313,6 +2313,18 @@ class ProxyConfig:
                     llama_guard_object = _ENTERPRISE_LlamaGuard()
                     imported_list.append(llama_guard_object)
+                elif (
+                    isinstance(callback, str)
+                    and callback == "openai_moderations"
+                ):
+                    from enterprise.enterprise_hooks.openai_moderation import (
+                        _ENTERPRISE_OpenAI_Moderation,
+                    )
+
+                    openai_moderations_object = (
+                        _ENTERPRISE_OpenAI_Moderation()
+                    )
+                    imported_list.append(openai_moderations_object)
                 elif (
                     isinstance(callback, str)
                     and callback == "google_text_moderation"
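
Usage note: the new hook is enabled by listing "openai_moderations" under callbacks
in the proxy config, and the model it calls comes from the new
litellm.openai_moderations_model_name setting (falling back to
"text-moderation-latest"). Below is a minimal config.yaml sketch, assuming the
proxy's usual mapping of litellm_settings keys onto litellm module attributes;
the model alias and the api_key reference are illustrative, not taken from this
patch:

    model_list:
      - model_name: text-moderation-latest
        litellm_params:
          model: text-moderation-latest
          api_key: os.environ/OPENAI_API_KEY

    litellm_settings:
      callbacks: ["openai_moderations"]
      openai_moderations_model_name: text-moderation-latest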
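
For reference, a standalone sketch of the check async_moderation_hook performs,
calling litellm.amoderation directly instead of going through llm_router; the
sample input string is made up, and an OPENAI_API_KEY environment variable is
assumed to be set:

    import asyncio

    import litellm


    async def screen(text: str) -> None:
        # Same call shape the hook uses via llm_router.amoderation()
        response = await litellm.amoderation(
            model="text-moderation-latest", input=text
        )
        if response.results[0].flagged:
            # The proxy hook converts this condition into an HTTP 403
            print("flagged: request would be rejected")


    asyncio.run(screen("sample user message to screen"))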