add custom guardrail reference

Ishaan Jaff 2024-08-23 08:32:07 -07:00
parent e62d0c7922
commit af92cff44d
4 changed files with 342 additions and 39 deletions

View file

@@ -0,0 +1,115 @@
import os
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import json
import sys
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks
class myCustomGuardrail(CustomGuardrail):
def __init__(
self,
**kwargs,
):
# store kwargs as optional_params
self.optional_params = kwargs
super().__init__(**kwargs)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
],
) -> Optional[Union[Exception, str, dict]]:
# In this guardrail, if a user inputs `litellm` we will mask it.
_messages = data.get("messages")
if _messages:
for message in _messages:
_content = message.get("content")
if isinstance(_content, str):
if "litellm" in _content.lower():
_content = _content.replace("litellm", "********")
message["content"] = _content
verbose_proxy_logger.debug(
"async_pre_call_hook: Message after masking %s", _messages
)
return data
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""
        Runs in parallel to the LLM API call.
        Runs only on the input.
"""
        # This works the same as async_pre_call_hook, but runs in parallel with the LLM API call.
# In this guardrail, if a user inputs `litellm` we will mask it.
_messages = data.get("messages")
if _messages:
for message in _messages:
_content = message.get("content")
if isinstance(_content, str):
if "litellm" in _content.lower():
_content = _content.replace("litellm", "********")
message["content"] = _content
verbose_proxy_logger.debug(
"async_pre_call_hook: Message after masking %s", _messages
)
pass
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
"""
        Runs on the response from the LLM API call.
        If the response contains the word "coffee", raise an exception.
"""
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
if isinstance(response, litellm.ModelResponse):
for choice in response.choices:
if isinstance(choice, litellm.Choices):
verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice)
if (
choice.message.content
and isinstance(choice.message.content, str)
and "coffee" in choice.message.content
):
raise ValueError("Guardrail failed Coffee Detected")

View file

@@ -0,0 +1,115 @@
import os
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import json
import sys
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks
class myCustomGuardrail(CustomGuardrail):
def __init__(
self,
**kwargs,
):
# store kwargs as optional_params
self.optional_params = kwargs
super().__init__(**kwargs)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
],
) -> Optional[Union[Exception, str, dict]]:
# In this guardrail, if a user inputs `litellm` we will mask it.
_messages = data.get("messages")
if _messages:
for message in _messages:
_content = message.get("content")
if isinstance(_content, str):
if "litellm" in _content.lower():
_content = _content.replace("litellm", "********")
message["content"] = _content
verbose_proxy_logger.debug(
"async_pre_call_hook: Message after masking %s", _messages
)
return data
async def async_moderation_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal["completion", "embeddings", "image_generation"],
):
"""
        Runs in parallel to the LLM API call.
        Runs only on the input.
"""
        # This works the same as async_pre_call_hook, but runs in parallel with the LLM API call.
# In this guardrail, if a user inputs `litellm` we will mask it.
_messages = data.get("messages")
if _messages:
for message in _messages:
_content = message.get("content")
if isinstance(_content, str):
if "litellm" in _content.lower():
_content = _content.replace("litellm", "********")
message["content"] = _content
verbose_proxy_logger.debug(
"async_pre_call_hook: Message after masking %s", _messages
)
pass
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
"""
        Runs on the response from the LLM API call.
        If the response contains the word "coffee", raise an exception.
"""
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
if isinstance(response, litellm.ModelResponse):
for choice in response.choices:
if isinstance(choice, litellm.Choices):
verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice)
if (
choice.message.content
and isinstance(choice.message.content, str)
and "coffee" in choice.message.content
):
raise ValueError("Guardrail failed Coffee Detected")

View file

@@ -1,17 +1,19 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  - model_name: gpt-4
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      tenant_id: os.environ/AZURE_TENANT_ID
-      client_id: os.environ/AZURE_CLIENT_ID
-      client_secret: os.environ/AZURE_CLIENT_SECRET
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY
 
 guardrails:
-  - guardrail_name: "bedrock-pre-guard"
+  - guardrail_name: "custom-pre-guard"
     litellm_params:
-      guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
+      guardrail: custom_guardrail.myCustomGuardrail
       mode: "pre_call"
+  - guardrail_name: "custom-during-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "during_call"
+  - guardrail_name: "custom-post-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "post_call"
-      guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
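A hedged client-side sketch of exercising the new custom-pre-guard entry, assuming the proxy was started against this config (for example with litellm --config config.yaml) on localhost:4000 with master key sk-1234, and that a request can opt into guardrails by listing their guardrail_name values in the request body:

import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# The pre-call guardrail should mask "litellm" in the prompt before it reaches gpt-4o.
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "please repeat the word litellm"}],
    extra_body={"guardrails": ["custom-pre-guard"]},  # assumed per-request guardrail selection
)
print(response.choices[0].message.content)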

View file

@@ -30,6 +30,7 @@ from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching import DualCache, RedisCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.slack_alerting import SlackAlerting
from litellm.litellm_core_utils.core_helpers import (
@@ -344,6 +345,23 @@ class ProxyLogging:
ttl=alerting_threshold,
)
async def process_pre_call_hook_response(self, response, data, call_type):
if isinstance(response, Exception):
raise response
if isinstance(response, dict):
return response
if isinstance(response, str):
if call_type in ["completion", "text_completion"]:
raise RejectedRequestError(
message=response,
model=data.get("model", ""),
llm_provider="",
request_data=data,
)
else:
raise HTTPException(status_code=400, detail={"error": response})
return data
# The actual implementation of the function
async def pre_call_hook(
self,
@@ -382,7 +400,33 @@
)
else:
_callback = callback # type: ignore
if (
_callback is not None
and isinstance(_callback, CustomGuardrail)
and "pre_call_hook" in vars(_callback.__class__)
):
from litellm.types.guardrails import GuardrailEventHooks
if (
_callback.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.pre_call
)
is not True
):
continue
response = await _callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
call_type=call_type,
)
if response is not None:
data = await self.process_pre_call_hook_response(
response=response, data=data, call_type=call_type
)
elif (
_callback is not None
and isinstance(_callback, CustomLogger)
and "async_pre_call_hook" in vars(_callback.__class__)
@@ -394,25 +438,9 @@
call_type=call_type,
)
if response is not None:
-                        if isinstance(response, Exception):
-                            raise response
-                        elif isinstance(response, dict):
-                            data = response
-                        elif isinstance(response, str):
-                            if (
-                                call_type == "completion"
-                                or call_type == "text_completion"
-                            ):
-                                raise RejectedRequestError(
-                                    message=response,
-                                    model=data.get("model", ""),
-                                    llm_provider="",
-                                    request_data=data,
-                                )
-                            else:
-                                raise HTTPException(
-                                    status_code=400, detail={"error": response}
-                                )
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )
return data
except Exception as e:
@@ -431,11 +459,30 @@
],
):
"""
-        Runs the CustomLogger's async_moderation_hook()
+        Runs the CustomGuardrail's async_moderation_hook()
"""
for callback in litellm.callbacks:
try:
-                if isinstance(callback, CustomLogger):
+                if isinstance(callback, CustomGuardrail):
################################################################
# Check if guardrail should be run for GuardrailEventHooks.during_call hook
################################################################
# V1 implementation - backwards compatibility
if callback.event_hook is None:
if callback.moderation_check == "pre_call":
return
else:
# Main - V2 Guardrails implementation
from litellm.types.guardrails import GuardrailEventHooks
if (
callback.should_run_guardrail(
data=data, event_type=GuardrailEventHooks.during_call
)
is not True
):
continue
await callback.async_moderation_hook(
data=data,
user_api_key_dict=user_api_key_dict,
@@ -737,12 +784,36 @@
)
else:
_callback = callback # type: ignore
-                if _callback is not None and isinstance(_callback, CustomLogger):
-                    await _callback.async_post_call_success_hook(
-                        user_api_key_dict=user_api_key_dict,
-                        data=data,
-                        response=response,
-                    )
+                if _callback is not None:
+                    ############## Handle Guardrails ########################################
+                    #############################################################################
+                    if isinstance(callback, CustomGuardrail):
+                        # Main - V2 Guardrails implementation
+                        from litellm.types.guardrails import GuardrailEventHooks
+                        if (
+                            callback.should_run_guardrail(
+                                data=data, event_type=GuardrailEventHooks.post_call
+                            )
+                            is not True
+                        ):
+                            continue
+                        await callback.async_post_call_success_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            data=data,
+                            response=response,
+                        )
+                    ############ Handle CustomLogger ###############################
+                    #################################################################
+                    elif isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_success_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            data=data,
+                            response=response,
+                        )
except Exception as e:
raise e
return response
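For completeness, a sketch of what the new post_call path looks like from the client side: when async_post_call_success_hook raises on a response containing "coffee", the proxy returns an error instead of the completion. Same assumptions as the client sketch above; the exact status code and message depend on how the proxy wraps the raised ValueError.

import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

try:
    client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "write one sentence about coffee"}],
        extra_body={"guardrails": ["custom-post-guard"]},  # assumed per-request guardrail selection
    )
except openai.APIError as err:
    # The ValueError raised in async_post_call_success_hook surfaces here as an API error.
    print("request blocked by post-call guardrail:", err)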