mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
fix use guardrail for pre call hook
This commit is contained in:
parent
6e3f27cf69
commit
a8e192a868
4 changed files with 30 additions and 54 deletions
|
@ -18,16 +18,16 @@ class CustomGuardrail(CustomLogger):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
|
||||||
|
metadata = data.get("metadata") or {}
|
||||||
|
requested_guardrails = metadata.get("guardrails") or []
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s",
|
"inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s",
|
||||||
self.guardrail_name,
|
self.guardrail_name,
|
||||||
event_type,
|
event_type,
|
||||||
self.event_hook,
|
self.event_hook,
|
||||||
|
requested_guardrails,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = data.get("metadata") or {}
|
|
||||||
requested_guardrails = metadata.get("guardrails") or []
|
|
||||||
|
|
||||||
if self.guardrail_name not in requested_guardrails:
|
if self.guardrail_name not in requested_guardrails:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -1,19 +1,5 @@
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -48,7 +34,13 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
"pass_through_endpoint",
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
) -> Optional[Union[Exception, str, dict]]:
|
) -> Optional[Union[Exception, str, dict]]:
|
||||||
# In this guardrail, if a user inputs `litellm` we will mask it.
|
"""
|
||||||
|
Runs before the LLM API call
|
||||||
|
Runs on only Input
|
||||||
|
Use this if you want to MODIFY the input
|
||||||
|
"""
|
||||||
|
|
||||||
|
# In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
|
||||||
_messages = data.get("messages")
|
_messages = data.get("messages")
|
||||||
if _messages:
|
if _messages:
|
||||||
for message in _messages:
|
for message in _messages:
|
||||||
|
@ -73,6 +65,8 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
"""
|
"""
|
||||||
Runs in parallel to LLM API call
|
Runs in parallel to LLM API call
|
||||||
Runs on only Input
|
Runs on only Input
|
||||||
|
|
||||||
|
This can NOT modify the input, only used to reject or accept a call before going to LLM API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
|
# this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
|
||||||
|
@ -83,13 +77,7 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
_content = message.get("content")
|
_content = message.get("content")
|
||||||
if isinstance(_content, str):
|
if isinstance(_content, str):
|
||||||
if "litellm" in _content.lower():
|
if "litellm" in _content.lower():
|
||||||
_content = _content.replace("litellm", "********")
|
raise ValueError("Guardrail failed words - `litellm` detected")
|
||||||
message["content"] = _content
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
"async_pre_call_hook: Message after masking %s", _messages
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def async_post_call_success_hook(
|
async def async_post_call_success_hook(
|
||||||
self,
|
self,
|
||||||
|
@ -100,6 +88,8 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
"""
|
"""
|
||||||
Runs on response from LLM API call
|
Runs on response from LLM API call
|
||||||
|
|
||||||
|
It can be used to reject a response
|
||||||
|
|
||||||
If a response contains the word "coffee" -> we will raise an exception
|
If a response contains the word "coffee" -> we will raise an exception
|
||||||
"""
|
"""
|
||||||
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
|
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
|
||||||
|
|
|
@ -1,19 +1,5 @@
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.insert(
|
|
||||||
0, os.path.abspath("../..")
|
|
||||||
) # Adds the parent directory to the system path
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -48,7 +34,13 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
"pass_through_endpoint",
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
) -> Optional[Union[Exception, str, dict]]:
|
) -> Optional[Union[Exception, str, dict]]:
|
||||||
# In this guardrail, if a user inputs `litellm` we will mask it.
|
"""
|
||||||
|
Runs before the LLM API call
|
||||||
|
Runs on only Input
|
||||||
|
Use this if you want to MODIFY the input
|
||||||
|
"""
|
||||||
|
|
||||||
|
# In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
|
||||||
_messages = data.get("messages")
|
_messages = data.get("messages")
|
||||||
if _messages:
|
if _messages:
|
||||||
for message in _messages:
|
for message in _messages:
|
||||||
|
@ -73,6 +65,8 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
"""
|
"""
|
||||||
Runs in parallel to LLM API call
|
Runs in parallel to LLM API call
|
||||||
Runs on only Input
|
Runs on only Input
|
||||||
|
|
||||||
|
This can NOT modify the input, only used to reject or accept a call before going to LLM API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
|
# this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
|
||||||
|
@ -83,13 +77,7 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
_content = message.get("content")
|
_content = message.get("content")
|
||||||
if isinstance(_content, str):
|
if isinstance(_content, str):
|
||||||
if "litellm" in _content.lower():
|
if "litellm" in _content.lower():
|
||||||
_content = _content.replace("litellm", "********")
|
raise ValueError("Guardrail failed words - `litellm` detected")
|
||||||
message["content"] = _content
|
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
|
||||||
"async_pre_call_hook: Message after masking %s", _messages
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def async_post_call_success_hook(
|
async def async_post_call_success_hook(
|
||||||
self,
|
self,
|
||||||
|
@ -100,6 +88,8 @@ class myCustomGuardrail(CustomGuardrail):
|
||||||
"""
|
"""
|
||||||
Runs on response from LLM API call
|
Runs on response from LLM API call
|
||||||
|
|
||||||
|
It can be used to reject a response
|
||||||
|
|
||||||
If a response contains the word "coffee" -> we will raise an exception
|
If a response contains the word "coffee" -> we will raise an exception
|
||||||
"""
|
"""
|
||||||
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
|
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
|
||||||
|
|
|
@ -393,7 +393,7 @@ class ProxyLogging:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for callback in litellm.callbacks:
|
for callback in litellm.callbacks:
|
||||||
_callback: Optional[CustomLogger] = None
|
_callback = None
|
||||||
if isinstance(callback, str):
|
if isinstance(callback, str):
|
||||||
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
_callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
|
||||||
callback
|
callback
|
||||||
|
@ -401,11 +401,7 @@ class ProxyLogging:
|
||||||
else:
|
else:
|
||||||
_callback = callback # type: ignore
|
_callback = callback # type: ignore
|
||||||
|
|
||||||
if (
|
if _callback is not None and isinstance(_callback, CustomGuardrail):
|
||||||
_callback is not None
|
|
||||||
and isinstance(_callback, CustomGuardrail)
|
|
||||||
and "pre_call_hook" in vars(_callback.__class__)
|
|
||||||
):
|
|
||||||
from litellm.types.guardrails import GuardrailEventHooks
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue