feat(llama_guard.py): add llama guard support for content moderation + new async_moderation_hook endpoint

This commit is contained in:
Krrish Dholakia 2024-02-16 18:45:25 -08:00
parent 5e7dda4f88
commit 2a4a6995ac
12 changed files with 163 additions and 132 deletions

View file

@ -0,0 +1,71 @@
# +-------------------------------------------------------------+
#
# Llama Guard
# https://huggingface.co/meta-llama/LlamaGuard-7b/tree/main
#
# LLM for Content Moderation
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
litellm.set_verbose = True
class _ENTERPRISE_LlamaGuard(CustomLogger):
# Class variables or attributes
def __init__(self, model_name: Optional[str] = None):
self.model = model_name or litellm.llamaguard_model_name
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
pass
async def async_moderation_hook(
self,
data: dict,
):
"""
- Calls the Llama Guard Endpoint
- Rejects request if it fails safety check
The llama guard prompt template is applied automatically in factory.py
"""
safety_check_messages = data["messages"][
-1
] # get the last response - llama guard has a 4k token limit
response = await litellm.acompletion(
model=self.model,
messages=[safety_check_messages],
hf_model_name="meta-llama/LlamaGuard-7b",
)
if "unsafe" in response.choices[0].message.content:
raise HTTPException(
status_code=400, detail={"error": "Violated content safety policy"}
)
return data

16
enterprise/utils.py Normal file
View file

@ -0,0 +1,16 @@
# Enterprise Proxy Util Endpoints
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
response = await prisma_client.db.query_raw(
"""
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs"
GROUP BY individual_request_tag;
"""
)
return response