feat: add OpenAI moderations check

This commit is contained in:
Ishaan Jaff 2024-05-23 13:08:06 -07:00
parent 84f3690453
commit 1fe035c6dd
3 changed files with 79 additions and 0 deletions

View file

@@ -0,0 +1,66 @@
# +-------------------------------------------------------------+
#
# Use OpenAI /moderations for your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
litellm.set_verbose = True
class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
def __init__(self):
self.model_name = (
litellm.openai_moderations_model_name or "text-moderation-latest"
) # pass the model_name you initialized on litellm.Router()
pass
#### CALL HOOKS - proxy only ####
async def async_moderation_hook( ### 👈 KEY CHANGE ###
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
_messages = data.get("messages", None)
if "messages" in data and isinstance(data["messages"], list):
text = ""
for m in data["messages"]: # assume messages is a list
if "content" in m and isinstance(m["content"], str):
text += m["content"]
from litellm.proxy.proxy_server import llm_router
if llm_router is None:
return
moderation_response = await llm_router.amoderation(
model=self.model_name, input=text
)
if moderation_response.results[0].flagged == True:
raise HTTPException(
status_code=403, detail={"error": "Violated content safety policy"}
)
pass

View file

@@ -97,6 +97,7 @@ ssl_verify: bool = True
disable_streaming_logging: bool = False
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
openai_moderations_model_name: Optional[str] = None
presidio_ad_hoc_recognizers: Optional[str] = None
google_moderation_confidence_threshold: Optional[float] = None
llamaguard_unsafe_content_categories: Optional[str] = None

View file

@@ -2313,6 +2313,18 @@ class ProxyConfig:
llama_guard_object = _ENTERPRISE_LlamaGuard()
imported_list.append(llama_guard_object)
elif (
isinstance(callback, str)
and callback == "openai_moderations"
):
from enterprise.enterprise_hooks.openai_moderation import (
_ENTERPRISE_OpenAI_Moderation,
)
openai_moderations_object = (
_ENTERPRISE_OpenAI_Moderation()
)
imported_list.append(openai_moderations_object)
elif (
isinstance(callback, str)
and callback == "google_text_moderation"