mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

add custom guardrail reference

parent e62d0c7922
commit af92cff44d

4 changed files with 342 additions and 39 deletions
litellm/proxy/custom_guardrail.py (new file, 115 lines)

@@ -0,0 +1,115 @@
import os
import sys

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import json
import traceback
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Union

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks


class myCustomGuardrail(CustomGuardrail):
    def __init__(
        self,
        **kwargs,
    ):
        # store kwargs as optional_params
        self.optional_params = kwargs

        super().__init__(**kwargs)

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        # In this guardrail, if a user inputs `litellm` we will mask it
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str):
                    if "litellm" in _content.lower():
                        _content = _content.replace("litellm", "********")
                        message["content"] = _content

        verbose_proxy_logger.debug(
            "async_pre_call_hook: Message after masking %s", _messages
        )

        return data

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        """
        Runs in parallel to the LLM API call.

        Runs on input only.
        """

        # This works the same as async_pre_call_hook, but runs in parallel with the LLM API call
        # In this guardrail, if a user inputs `litellm` we will mask it
        _messages = data.get("messages")
        if _messages:
            for message in _messages:
                _content = message.get("content")
                if isinstance(_content, str):
                    if "litellm" in _content.lower():
                        _content = _content.replace("litellm", "********")
                        message["content"] = _content

        verbose_proxy_logger.debug(
            "async_moderation_hook: Message after masking %s", _messages
        )

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on the response from the LLM API call.

        If the response contains the word "coffee" -> raise an exception.
        """
        verbose_proxy_logger.debug(
            "async_post_call_success_hook response: %s", response
        )
        if isinstance(response, litellm.ModelResponse):
            for choice in response.choices:
                if isinstance(choice, litellm.Choices):
                    verbose_proxy_logger.debug(
                        "async_post_call_success_hook choice: %s", choice
                    )
                    if (
                        choice.message.content
                        and isinstance(choice.message.content, str)
                        and "coffee" in choice.message.content
                    ):
                        raise ValueError("Guardrail failed Coffee Detected")
litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py (new file, 115 lines)

@@ -0,0 +1,115 @@
(contents identical to litellm/proxy/custom_guardrail.py above)
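A minimal sketch (not part of the commit) of exercising the masking hook directly; it assumes litellm is installed, that the file above is importable as `custom_guardrail`, and that a bare UserAPIKeyAuth and DualCache are acceptable stand-ins for the proxy-supplied arguments:

import asyncio

from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth

from custom_guardrail import myCustomGuardrail  # hypothetical import path


async def main():
    guardrail = myCustomGuardrail()
    data = {"messages": [{"role": "user", "content": "hello litellm"}]}
    masked = await guardrail.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),  # placeholder auth object
        cache=DualCache(),  # in-memory placeholder cache
        data=data,
        call_type="completion",
    )
    print(masked["messages"][0]["content"])  # expected: "hello ********"


asyncio.run(main())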
@@ -1,17 +1,19 @@
 model_list:
-  - model_name: fake-openai-endpoint
+  - model_name: gpt-4
     litellm_params:
-      model: azure/chatgpt-v-2
-      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
-      api_version: "2023-05-15"
-      tenant_id: os.environ/AZURE_TENANT_ID
-      client_id: os.environ/AZURE_CLIENT_ID
-      client_secret: os.environ/AZURE_CLIENT_SECRET
+      model: openai/gpt-4o
+      api_key: os.environ/OPENAI_API_KEY

 guardrails:
-  - guardrail_name: "bedrock-pre-guard"
+  - guardrail_name: "custom-pre-guard"
     litellm_params:
-      guardrail: bedrock  # supported values: "aporia", "bedrock", "lakera"
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "pre_call"
+  - guardrail_name: "custom-during-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
+      mode: "during_call"
+  - guardrail_name: "custom-post-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail
       mode: "post_call"
-      guardrailIdentifier: ff6ujrregl1q
-      guardrailVersion: "DRAFT"
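With the three guardrails registered, a caller opts in per request. A hedged usage sketch, assuming the proxy runs on localhost:4000 and accepts litellm's per-request `guardrails` body parameter:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "say the word litellm"}],
    extra_body={"guardrails": ["custom-pre-guard"]},  # litellm-proxy-specific parameter
)
# The pre-call hook masks "litellm" before the model ever sees it.
print(response.choices[0].message.content)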
@@ -30,6 +30,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching import DualCache, RedisCache
 from litellm.exceptions import RejectedRequestError
+from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.slack_alerting import SlackAlerting
 from litellm.litellm_core_utils.core_helpers import (
@@ -344,6 +345,23 @@ class ProxyLogging:
                 ttl=alerting_threshold,
             )

+    async def process_pre_call_hook_response(self, response, data, call_type):
+        if isinstance(response, Exception):
+            raise response
+        if isinstance(response, dict):
+            return response
+        if isinstance(response, str):
+            if call_type in ["completion", "text_completion"]:
+                raise RejectedRequestError(
+                    message=response,
+                    model=data.get("model", ""),
+                    llm_provider="",
+                    request_data=data,
+                )
+            else:
+                raise HTTPException(status_code=400, detail={"error": response})
+        return data
+
     # The actual implementation of the function
     async def pre_call_hook(
         self,
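The new helper gives every pre-call hook response a single interpretation: an Exception is re-raised, a dict replaces the request body, and a string becomes a rejection. A standalone sketch of that contract, with a plain RuntimeError standing in for RejectedRequestError and FastAPI's HTTPException:

# Illustration only, not litellm's implementation.
def process_hook_response_sketch(response, data, call_type):
    if isinstance(response, Exception):
        raise response  # the hook already built an error -> propagate it
    if isinstance(response, dict):
        return response  # the hook rewrote the request body
    if isinstance(response, str):
        if call_type in ["completion", "text_completion"]:
            raise RuntimeError(f"request rejected: {response}")
        raise RuntimeError(f"HTTP 400: {response}")
    return data  # anything else leaves the body untouched


rewritten = process_hook_response_sketch({"messages": []}, {"messages": ["hi"]}, "completion")
assert rewritten == {"messages": []}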
@@ -382,7 +400,33 @@ class ProxyLogging:
                     )
                 else:
                     _callback = callback  # type: ignore
+
                 if (
+                    _callback is not None
+                    and isinstance(_callback, CustomGuardrail)
+                    and "async_pre_call_hook" in vars(_callback.__class__)
+                ):
+                    from litellm.types.guardrails import GuardrailEventHooks
+
+                    if (
+                        _callback.should_run_guardrail(
+                            data=data, event_type=GuardrailEventHooks.pre_call
+                        )
+                        is not True
+                    ):
+                        continue
+                    response = await _callback.async_pre_call_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        cache=self.call_details["user_api_key_cache"],
+                        data=data,
+                        call_type=call_type,
+                    )
+                    if response is not None:
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )
+
+                elif (
                     _callback is not None
                     and isinstance(_callback, CustomLogger)
                     and "async_pre_call_hook" in vars(_callback.__class__)

@@ -394,24 +438,8 @@ class ProxyLogging:
                         call_type=call_type,
                     )
                     if response is not None:
-                        if isinstance(response, Exception):
-                            raise response
-                        elif isinstance(response, dict):
-                            data = response
-                        elif isinstance(response, str):
-                            if (
-                                call_type == "completion"
-                                or call_type == "text_completion"
-                            ):
-                                raise RejectedRequestError(
-                                    message=response,
-                                    model=data.get("model", ""),
-                                    llm_provider="",
-                                    request_data=data,
-                                )
-                            else:
-                                raise HTTPException(
-                                    status_code=400, detail={"error": response}
-                                )
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )

             return data
@@ -431,11 +459,30 @@ class ProxyLogging:
         ],
     ):
         """
-        Runs the CustomLogger's async_moderation_hook()
+        Runs the CustomGuardrail's async_moderation_hook()
         """
         for callback in litellm.callbacks:
             try:
-                if isinstance(callback, CustomLogger):
+                if isinstance(callback, CustomGuardrail):
+                    ################################################################
+                    # Check if guardrail should be run for GuardrailEventHooks.during_call hook
+                    ################################################################
+
+                    # V1 implementation - backwards compatibility
+                    if callback.event_hook is None:
+                        if callback.moderation_check == "pre_call":
+                            return
+                    else:
+                        # Main - V2 Guardrails implementation
+                        from litellm.types.guardrails import GuardrailEventHooks
+
+                        if (
+                            callback.should_run_guardrail(
+                                data=data, event_type=GuardrailEventHooks.during_call
+                            )
+                            is not True
+                        ):
+                            continue
                     await callback.async_moderation_hook(
                         data=data,
                         user_api_key_dict=user_api_key_dict,
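The branch above keeps V1 guardrails (configured via `moderation_check`) working while routing V2 guardrails through `should_run_guardrail`. A toy sketch of that decision, where `v2_says_run` stands in for the result of callback.should_run_guardrail(...); an illustration of the control flow, not litellm's code:

def runs_during_call(event_hook, moderation_check, v2_says_run):
    if event_hook is None:  # V1 guardrail: keyed off moderation_check
        return moderation_check != "pre_call"
    return v2_says_run  # V2 guardrail: defer to should_run_guardrail()


assert runs_during_call(None, "pre_call", True) is False  # V1 pre_call guardrail skips this hook
assert runs_during_call("during_call", None, True) is True  # V2 guardrail whose mode matches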
@@ -737,7 +784,31 @@ class ProxyLogging:
                 )
             else:
                 _callback = callback  # type: ignore
-            if _callback is not None and isinstance(_callback, CustomLogger):
+
+            if _callback is not None:
+                ############## Handle Guardrails ########################################
+                #############################################################################
+                if isinstance(callback, CustomGuardrail):
+                    # Main - V2 Guardrails implementation
+                    from litellm.types.guardrails import GuardrailEventHooks
+
+                    if (
+                        callback.should_run_guardrail(
+                            data=data, event_type=GuardrailEventHooks.post_call
+                        )
+                        is not True
+                    ):
+                        continue
+
+                    await callback.async_post_call_success_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        data=data,
+                        response=response,
+                    )
+
+                ############ Handle CustomLogger ###############################
+                #################################################################
+                elif isinstance(_callback, CustomLogger):
                     await _callback.async_post_call_success_hook(
                         user_api_key_dict=user_api_key_dict,
                         data=data,
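End to end, the post_call path is visible from the client. A hedged sketch, assuming the proxy from the config above runs on localhost:4000: a response containing "coffee" should surface the guardrail's ValueError as an API error.

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

try:
    client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "I love coffee, tell me about it"}],
        extra_body={"guardrails": ["custom-post-guard"]},
    )
except openai.APIError as err:
    print("post-call guardrail rejected the response:", err)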