mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix import loc
This commit is contained in:
parent
b965f1a306
commit
a40ecc3fe4
380 changed files with 1491 additions and 1208 deletions
|
@ -0,0 +1,261 @@
|
|||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use Aim Security Guardrails for your LLM calls
|
||||
# https://www.aim.security/
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Literal, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from websockets.asyncio.client import ClientConnection, connect
|
||||
|
||||
from litellm import DualCache
|
||||
from litellm._version import version as litellm_version
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm_proxy_extras.litellm_proxy.proxy_server import StreamingCallbackError
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
)
|
||||
|
||||
|
||||
class AimGuardrailMissingSecrets(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AimGuardrail(CustomGuardrail):
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
self.api_key = api_key or os.environ.get("AIM_API_KEY")
|
||||
if not self.api_key:
|
||||
msg = (
|
||||
"Couldn't get Aim api key, either set the `AIM_API_KEY` in the environment or "
|
||||
"pass it as a parameter to the guardrail in the config file"
|
||||
)
|
||||
raise AimGuardrailMissingSecrets(msg)
|
||||
self.api_base = (
|
||||
api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
||||
)
|
||||
self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
|
||||
"https://", "wss://"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
],
|
||||
) -> Union[Exception, str, dict, None]:
|
||||
verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
|
||||
|
||||
await self.call_aim_guardrail(
|
||||
data, hook="pre_call", key_alias=user_api_key_dict.key_alias
|
||||
)
|
||||
return data
|
||||
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
) -> Union[Exception, str, dict, None]:
|
||||
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
|
||||
|
||||
await self.call_aim_guardrail(
|
||||
data, hook="moderation", key_alias=user_api_key_dict.key_alias
|
||||
)
|
||||
return data
|
||||
|
||||
async def call_aim_guardrail(
|
||||
self, data: dict, hook: str, key_alias: Optional[str]
|
||||
) -> None:
|
||||
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
call_id = data.get("litellm_call_id")
|
||||
headers = self._build_aim_headers(
|
||||
hook=hook,
|
||||
key_alias=key_alias,
|
||||
user_email=user_email,
|
||||
litellm_call_id=call_id,
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
f"{self.api_base}/detect/openai",
|
||||
headers=headers,
|
||||
json={"messages": data.get("messages", [])},
|
||||
)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
detected = res["detected"]
|
||||
verbose_proxy_logger.info(
|
||||
"Aim: detected: {detected}, enabled policies: {policies}".format(
|
||||
detected=detected,
|
||||
policies=list(res["details"].keys()),
|
||||
),
|
||||
)
|
||||
if detected:
|
||||
raise HTTPException(status_code=400, detail=res["detection_message"])
|
||||
|
||||
async def call_aim_guardrail_on_output(
|
||||
self, request_data: dict, output: str, hook: str, key_alias: Optional[str]
|
||||
) -> Optional[str]:
|
||||
user_email = (
|
||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
)
|
||||
call_id = request_data.get("litellm_call_id")
|
||||
response = await self.async_handler.post(
|
||||
f"{self.api_base}/detect/output",
|
||||
headers=self._build_aim_headers(
|
||||
hook=hook,
|
||||
key_alias=key_alias,
|
||||
user_email=user_email,
|
||||
litellm_call_id=call_id,
|
||||
),
|
||||
json={"output": output, "messages": request_data.get("messages", [])},
|
||||
)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
detected = res["detected"]
|
||||
verbose_proxy_logger.info(
|
||||
"Aim: detected: {detected}, enabled policies: {policies}".format(
|
||||
detected=detected,
|
||||
policies=list(res["details"].keys()),
|
||||
),
|
||||
)
|
||||
if detected:
|
||||
return res["detection_message"]
|
||||
return None
|
||||
|
||||
def _build_aim_headers(
|
||||
self,
|
||||
*,
|
||||
hook: str,
|
||||
key_alias: Optional[str],
|
||||
user_email: Optional[str],
|
||||
litellm_call_id: Optional[str],
|
||||
):
|
||||
"""
|
||||
A helper function to build the http headers that are required by AIM guardrails.
|
||||
"""
|
||||
return (
|
||||
{
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
# Used by Aim to apply only the guardrails that should be applied in a specific request phase.
|
||||
"x-aim-litellm-hook": hook,
|
||||
# Used by Aim to track LiteLLM version and provide backward compatibility.
|
||||
"x-aim-litellm-version": litellm_version,
|
||||
}
|
||||
# Used by Aim to track together single call input and output
|
||||
| ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
|
||||
# Used by Aim to track guardrails violations by user.
|
||||
| ({"x-aim-user-email": user_email} if user_email else {})
|
||||
| (
|
||||
{
|
||||
# Used by Aim apply only the guardrails that are associated with the key alias.
|
||||
"x-aim-litellm-key-alias": key_alias,
|
||||
}
|
||||
if key_alias
|
||||
else {}
|
||||
)
|
||||
)
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
) -> Any:
|
||||
if (
|
||||
isinstance(response, ModelResponse)
|
||||
and response.choices
|
||||
and isinstance(response.choices[0], Choices)
|
||||
):
|
||||
content = response.choices[0].message.content or ""
|
||||
detection = await self.call_aim_guardrail_on_output(
|
||||
data, content, hook="output", key_alias=user_api_key_dict.key_alias
|
||||
)
|
||||
if detection:
|
||||
raise HTTPException(status_code=400, detail=detection)
|
||||
|
||||
async def async_post_call_streaming_iterator_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
request_data: dict,
|
||||
) -> AsyncGenerator[ModelResponseStream, None]:
|
||||
user_email = (
|
||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
)
|
||||
call_id = request_data.get("litellm_call_id")
|
||||
async with connect(
|
||||
f"{self.ws_api_base}/detect/output/ws",
|
||||
additional_headers=self._build_aim_headers(
|
||||
hook="output",
|
||||
key_alias=user_api_key_dict.key_alias,
|
||||
user_email=user_email,
|
||||
litellm_call_id=call_id,
|
||||
),
|
||||
) as websocket:
|
||||
sender = asyncio.create_task(
|
||||
self.forward_the_stream_to_aim(websocket, response)
|
||||
)
|
||||
while True:
|
||||
result = json.loads(await websocket.recv())
|
||||
if verified_chunk := result.get("verified_chunk"):
|
||||
yield ModelResponseStream.model_validate(verified_chunk)
|
||||
else:
|
||||
sender.cancel()
|
||||
if result.get("done"):
|
||||
return
|
||||
if blocking_message := result.get("blocking_message"):
|
||||
raise StreamingCallbackError(blocking_message)
|
||||
verbose_proxy_logger.error(
|
||||
f"Unknown message received from AIM: {result}"
|
||||
)
|
||||
return
|
||||
|
||||
async def forward_the_stream_to_aim(
|
||||
self,
|
||||
websocket: ClientConnection,
|
||||
response_iter,
|
||||
) -> None:
|
||||
async for chunk in response_iter:
|
||||
if isinstance(chunk, BaseModel):
|
||||
chunk = chunk.model_dump_json()
|
||||
if isinstance(chunk, dict):
|
||||
chunk = json.dumps(chunk)
|
||||
await websocket.send(chunk)
|
||||
await websocket.send(json.dumps({"done": True}))
|
|
@ -0,0 +1,228 @@
|
|||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use AporiaAI for your LLM calls
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.litellm_core_utils.logging_utils import (
|
||||
convert_litellm_response_object_to_str,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm_proxy_extras.litellm_proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
GUARDRAIL_NAME = "aporia"
|
||||
|
||||
|
||||
class AporiaGuardrail(CustomGuardrail):
|
||||
def __init__(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
|
||||
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
||||
supported_openai_roles = ["system", "user", "assistant"]
|
||||
default_role = "other" # for unsupported roles - e.g. tool
|
||||
new_messages = []
|
||||
for m in messages:
|
||||
if m.get("role", "") in supported_openai_roles:
|
||||
new_messages.append(m)
|
||||
else:
|
||||
new_messages.append(
|
||||
{
|
||||
"role": default_role,
|
||||
**{key: value for key, value in m.items() if key != "role"},
|
||||
}
|
||||
)
|
||||
|
||||
return new_messages
|
||||
|
||||
async def prepare_aporia_request(
|
||||
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||
) -> dict:
|
||||
data: dict[str, Any] = {}
|
||||
if new_messages is not None:
|
||||
data["messages"] = new_messages
|
||||
if response_string is not None:
|
||||
data["response"] = response_string
|
||||
|
||||
# Set validation target
|
||||
if new_messages and response_string:
|
||||
data["validation_target"] = "both"
|
||||
elif new_messages:
|
||||
data["validation_target"] = "prompt"
|
||||
elif response_string:
|
||||
data["validation_target"] = "response"
|
||||
|
||||
verbose_proxy_logger.debug("Aporia AI request: %s", data)
|
||||
return data
|
||||
|
||||
async def make_aporia_api_request(
|
||||
self,
|
||||
request_data: dict,
|
||||
new_messages: List[dict],
|
||||
response_string: Optional[str] = None,
|
||||
):
|
||||
data = await self.prepare_aporia_request(
|
||||
new_messages=new_messages, response_string=response_string
|
||||
)
|
||||
|
||||
data.update(
|
||||
self.get_guardrail_dynamic_request_body_params(request_data=request_data)
|
||||
)
|
||||
|
||||
_json_data = json.dumps(data)
|
||||
|
||||
"""
|
||||
export APORIO_API_KEY=<your key>
|
||||
curl https://gr-prd-trial.aporia.com/some-id \
|
||||
-X POST \
|
||||
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test prompt"
|
||||
}
|
||||
],
|
||||
}
|
||||
'
|
||||
"""
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=self.aporia_api_base + "/validate",
|
||||
data=_json_data,
|
||||
headers={
|
||||
"X-APORIA-API-KEY": self.aporia_api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
|
||||
if response.status_code == 200:
|
||||
# check if the response was flagged
|
||||
_json_response = response.json()
|
||||
action: str = _json_response.get(
|
||||
"action"
|
||||
) # possible values are modify, passthrough, block, rephrase
|
||||
if action == "block":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated guardrail policy",
|
||||
"aporia_ai_response": _json_response,
|
||||
},
|
||||
)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
from litellm_proxy_extras.litellm_proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
|
||||
"""
|
||||
Use this for the post call moderation with Guardrails
|
||||
"""
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||
if response_str is not None:
|
||||
await self.make_aporia_api_request(
|
||||
request_data=data,
|
||||
response_string=response_str,
|
||||
new_messages=data.get("messages", []),
|
||||
)
|
||||
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
from litellm_proxy_extras.litellm_proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
# old implementation - backwards compatibility
|
||||
if (
|
||||
await should_proceed_based_on_metadata(
|
||||
data=data,
|
||||
guardrail_name=GUARDRAIL_NAME,
|
||||
)
|
||||
is False
|
||||
):
|
||||
return
|
||||
|
||||
new_messages: Optional[List[dict]] = None
|
||||
if "messages" in data and isinstance(data["messages"], list):
|
||||
new_messages = self.transform_messages(messages=data["messages"])
|
||||
|
||||
if new_messages is not None:
|
||||
await self.make_aporia_api_request(
|
||||
request_data=data,
|
||||
new_messages=new_messages,
|
||||
)
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"Aporia AI: not running guardrail. No messages in data"
|
||||
)
|
||||
pass
|
|
@ -0,0 +1,304 @@
|
|||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use Bedrock Guardrails for your LLM calls
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
BedrockContentItem,
|
||||
BedrockRequest,
|
||||
BedrockTextContent,
|
||||
GuardrailEventHooks,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
GUARDRAIL_NAME = "bedrock"
|
||||
|
||||
|
||||
class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||
def __init__(
|
||||
self,
|
||||
guardrailIdentifier: Optional[str] = None,
|
||||
guardrailVersion: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
self.guardrailIdentifier = guardrailIdentifier
|
||||
self.guardrailVersion = guardrailVersion
|
||||
|
||||
# store kwargs as optional_params
|
||||
self.optional_params = kwargs
|
||||
|
||||
super().__init__(**kwargs)
|
||||
BaseAWSLLM.__init__(self)
|
||||
|
||||
def convert_to_bedrock_format(
|
||||
self,
|
||||
messages: Optional[List[AllMessageValues]] = None,
|
||||
response: Optional[Union[Any, ModelResponse]] = None,
|
||||
) -> BedrockRequest:
|
||||
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
|
||||
bedrock_request_content: List[BedrockContentItem] = []
|
||||
|
||||
if messages:
|
||||
for message in messages:
|
||||
bedrock_content_item = BedrockContentItem(
|
||||
text=BedrockTextContent(
|
||||
text=convert_content_list_to_str(message=message)
|
||||
)
|
||||
)
|
||||
bedrock_request_content.append(bedrock_content_item)
|
||||
|
||||
bedrock_request["content"] = bedrock_request_content
|
||||
if response:
|
||||
bedrock_request["source"] = "OUTPUT"
|
||||
if isinstance(response, litellm.ModelResponse):
|
||||
for choice in response.choices:
|
||||
if isinstance(choice, litellm.Choices):
|
||||
if choice.message.content and isinstance(
|
||||
choice.message.content, str
|
||||
):
|
||||
bedrock_content_item = BedrockContentItem(
|
||||
text=BedrockTextContent(text=choice.message.content)
|
||||
)
|
||||
bedrock_request_content.append(bedrock_content_item)
|
||||
bedrock_request["content"] = bedrock_request_content
|
||||
return bedrock_request
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
def _load_credentials(
|
||||
self,
|
||||
):
|
||||
try:
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = self.optional_params.pop("aws_access_key_id", None)
|
||||
aws_session_token = self.optional_params.pop("aws_session_token", None)
|
||||
aws_region_name = self.optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = self.optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = self.optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = self.optional_params.pop("aws_profile_name", None)
|
||||
self.optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
aws_web_identity_token = self.optional_params.pop(
|
||||
"aws_web_identity_token", None
|
||||
)
|
||||
aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None)
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
aws_sts_endpoint=aws_sts_endpoint,
|
||||
)
|
||||
return credentials, aws_region_name
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
credentials,
|
||||
data: dict,
|
||||
optional_params: dict,
|
||||
aws_region_name: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
):
|
||||
try:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
|
||||
|
||||
encoded_data = json.dumps(data).encode("utf-8")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
|
||||
request = AWSRequest(
|
||||
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
if (
|
||||
extra_headers is not None and "Authorization" in extra_headers
|
||||
): # prevent sigv4 from overwriting the auth header
|
||||
request.headers["Authorization"] = extra_headers["Authorization"]
|
||||
|
||||
prepped_request = request.prepare()
|
||||
|
||||
return prepped_request
|
||||
|
||||
async def make_bedrock_api_request(
|
||||
self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
|
||||
):
|
||||
credentials, aws_region_name = self._load_credentials()
|
||||
bedrock_request_data: dict = dict(
|
||||
self.convert_to_bedrock_format(
|
||||
messages=kwargs.get("messages"), response=response
|
||||
)
|
||||
)
|
||||
bedrock_request_data.update(
|
||||
self.get_guardrail_dynamic_request_body_params(request_data=kwargs)
|
||||
)
|
||||
prepared_request = self._prepare_request(
|
||||
credentials=credentials,
|
||||
data=bedrock_request_data,
|
||||
optional_params=self.optional_params,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Bedrock AI request body: %s, url %s, headers: %s",
|
||||
bedrock_request_data,
|
||||
prepared_request.url,
|
||||
prepared_request.headers,
|
||||
)
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=prepared_request.url,
|
||||
data=prepared_request.body, # type: ignore
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
)
|
||||
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
|
||||
if response.status_code == 200:
|
||||
# check if the response was flagged
|
||||
_json_response = response.json()
|
||||
if _json_response.get("action") == "GUARDRAIL_INTERVENED":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated guardrail policy",
|
||||
"bedrock_guardrail_response": _json_response,
|
||||
},
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.error(
|
||||
"Bedrock AI: error in response. Status code: %s, response: %s",
|
||||
response.status_code,
|
||||
response.text,
|
||||
)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
from litellm_proxy_extras.litellm_proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
new_messages: Optional[List[dict]] = data.get("messages")
|
||||
if new_messages is not None:
|
||||
await self.make_bedrock_api_request(kwargs=data)
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"Bedrock AI: not running guardrail. No messages in data"
|
||||
)
|
||||
pass
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
from litellm_proxy_extras.litellm_proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
if (
|
||||
self.should_run_guardrail(
|
||||
data=data, event_type=GuardrailEventHooks.post_call
|
||||
)
|
||||
is not True
|
||||
):
|
||||
return
|
||||
|
||||
new_messages: Optional[List[dict]] = data.get("messages")
|
||||
if new_messages is not None:
|
||||
await self.make_bedrock_api_request(kwargs=data, response=response)
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"Bedrock AI: not running guardrail. No messages in data"
|
||||
)
|
|
@ -0,0 +1,117 @@
|
|||
from typing import Literal, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class myCustomGuardrail(CustomGuardrail):
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
# store kwargs as optional_params
|
||||
self.optional_params = kwargs
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
],
|
||||
) -> Optional[Union[Exception, str, dict]]:
|
||||
"""
|
||||
Runs before the LLM API call
|
||||
Runs on only Input
|
||||
Use this if you want to MODIFY the input
|
||||
"""
|
||||
|
||||
# In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
|
||||
_messages = data.get("messages")
|
||||
if _messages:
|
||||
for message in _messages:
|
||||
_content = message.get("content")
|
||||
if isinstance(_content, str):
|
||||
if "litellm" in _content.lower():
|
||||
_content = _content.replace("litellm", "********")
|
||||
message["content"] = _content
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"async_pre_call_hook: Message after masking %s", _messages
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
"""
|
||||
Runs in parallel to LLM API call
|
||||
Runs on only Input
|
||||
|
||||
This can NOT modify the input, only used to reject or accept a call before going to LLM API
|
||||
"""
|
||||
|
||||
# this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
|
||||
# In this guardrail, if a user inputs `litellm` we will mask it.
|
||||
_messages = data.get("messages")
|
||||
if _messages:
|
||||
for message in _messages:
|
||||
_content = message.get("content")
|
||||
if isinstance(_content, str):
|
||||
if "litellm" in _content.lower():
|
||||
raise ValueError("Guardrail failed words - `litellm` detected")
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
"""
|
||||
Runs on response from LLM API call
|
||||
|
||||
It can be used to reject a response
|
||||
|
||||
If a response contains the word "coffee" -> we will raise an exception
|
||||
"""
|
||||
verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
|
||||
if isinstance(response, litellm.ModelResponse):
|
||||
for choice in response.choices:
|
||||
if isinstance(choice, litellm.Choices):
|
||||
verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice)
|
||||
if (
|
||||
choice.message.content
|
||||
and isinstance(choice.message.content, str)
|
||||
and "coffee" in choice.message.content
|
||||
):
|
||||
raise ValueError("Guardrail failed Coffee Detected")
|
|
@ -0,0 +1,114 @@
|
|||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use GuardrailsAI for your LLM calls
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
# Thank you for using Litellm! - Krrish & Ishaan
|
||||
|
||||
import json
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_content_from_model_response,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm_proxy_extras.litellm_proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
||||
class GuardrailsAIResponse(TypedDict):
|
||||
callId: str
|
||||
rawLlmOutput: str
|
||||
validatedOutput: str
|
||||
validationPassed: bool
|
||||
|
||||
|
||||
class GuardrailsAI(CustomGuardrail):
|
||||
def __init__(
|
||||
self,
|
||||
guard_name: str,
|
||||
api_base: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if guard_name is None:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
|
||||
)
|
||||
# store kwargs as optional_params
|
||||
self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
|
||||
self.guardrails_ai_guard_name = guard_name
|
||||
self.optional_params = kwargs
|
||||
supported_event_hooks = [GuardrailEventHooks.post_call]
|
||||
super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
|
||||
|
||||
async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
|
||||
from httpx import URL
|
||||
|
||||
data = {
|
||||
"llmOutput": llm_output,
|
||||
**self.get_guardrail_dynamic_request_body_params(request_data=request_data),
|
||||
}
|
||||
_json_data = json.dumps(data)
|
||||
response = await litellm.module_level_aclient.post(
|
||||
url=str(
|
||||
URL(self.guardrails_ai_api_base).join(
|
||||
f"guards/{self.guardrails_ai_guard_name}/validate"
|
||||
)
|
||||
),
|
||||
data=_json_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
verbose_proxy_logger.debug("guardrails_ai response: %s", response)
|
||||
_json_response = GuardrailsAIResponse(**response.json()) # type: ignore
|
||||
if _json_response.get("validationPassed") is False:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated guardrail policy",
|
||||
"guardrails_ai_response": _json_response,
|
||||
},
|
||||
)
|
||||
return _json_response
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
):
|
||||
"""
|
||||
Runs on response from LLM API call
|
||||
|
||||
It can be used to reject a response
|
||||
"""
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
if not isinstance(response, litellm.ModelResponse):
|
||||
return
|
||||
|
||||
response_str: str = get_content_from_model_response(response)
|
||||
if response_str is not None and len(response_str) > 0:
|
||||
await self.make_guardrails_ai_api_request(
|
||||
llm_output=response_str, request_data=data
|
||||
)
|
||||
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
request_data=data, guardrail_name=self.guardrail_name
|
||||
)
|
||||
|
||||
return
|
|
@ -0,0 +1,365 @@
|
|||
# +-------------------------------------------------------------+
|
||||
#
|
||||
# Use lakeraAI /moderations for your LLM calls
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import json
|
||||
import sys
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm_proxy_extras.litellm_proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.guardrails import (
|
||||
GuardrailItem,
|
||||
LakeraCategoryThresholds,
|
||||
Role,
|
||||
default_roles,
|
||||
)
|
||||
|
||||
GUARDRAIL_NAME = "lakera_prompt_injection"
|
||||
|
||||
INPUT_POSITIONING_MAP = {
|
||||
Role.SYSTEM.value: 0,
|
||||
Role.USER.value: 1,
|
||||
Role.ASSISTANT.value: 2,
|
||||
}
|
||||
|
||||
|
||||
class lakeraAI_Moderation(CustomGuardrail):
|
||||
def __init__(
|
||||
self,
|
||||
moderation_check: Literal["pre_call", "in_parallel"] = "in_parallel",
|
||||
category_thresholds: Optional[LakeraCategoryThresholds] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
self.lakera_api_key = api_key or os.environ["LAKERA_API_KEY"]
|
||||
self.moderation_check = moderation_check
|
||||
self.category_thresholds = category_thresholds
|
||||
self.api_base = (
|
||||
api_base or get_secret("LAKERA_API_BASE") or "https://api.lakera.ai"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
def _check_response_flagged(self, response: dict) -> None:
|
||||
_results = response.get("results", [])
|
||||
if len(_results) <= 0:
|
||||
return
|
||||
|
||||
flagged = _results[0].get("flagged", False)
|
||||
category_scores: Optional[dict] = _results[0].get("category_scores", None)
|
||||
|
||||
if self.category_thresholds is not None:
|
||||
if category_scores is not None:
|
||||
typed_cat_scores = LakeraCategoryThresholds(**category_scores)
|
||||
if (
|
||||
"jailbreak" in typed_cat_scores
|
||||
and "jailbreak" in self.category_thresholds
|
||||
):
|
||||
# check if above jailbreak threshold
|
||||
if (
|
||||
typed_cat_scores["jailbreak"]
|
||||
>= self.category_thresholds["jailbreak"]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated jailbreak threshold",
|
||||
"lakera_ai_response": response,
|
||||
},
|
||||
)
|
||||
if (
|
||||
"prompt_injection" in typed_cat_scores
|
||||
and "prompt_injection" in self.category_thresholds
|
||||
):
|
||||
if (
|
||||
typed_cat_scores["prompt_injection"]
|
||||
>= self.category_thresholds["prompt_injection"]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated prompt_injection threshold",
|
||||
"lakera_ai_response": response,
|
||||
},
|
||||
)
|
||||
elif flagged is True:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated content safety policy",
|
||||
"lakera_ai_response": response,
|
||||
},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _check( # noqa: PLR0915
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
if (
|
||||
await should_proceed_based_on_metadata(
|
||||
data=data,
|
||||
guardrail_name=GUARDRAIL_NAME,
|
||||
)
|
||||
is False
|
||||
):
|
||||
return
|
||||
text = ""
|
||||
_json_data: str = ""
|
||||
if "messages" in data and isinstance(data["messages"], list):
|
||||
prompt_injection_obj: Optional[
|
||||
GuardrailItem
|
||||
] = litellm.guardrail_name_config_map.get("prompt_injection")
|
||||
if prompt_injection_obj is not None:
|
||||
enabled_roles = prompt_injection_obj.enabled_roles
|
||||
else:
|
||||
enabled_roles = None
|
||||
|
||||
if enabled_roles is None:
|
||||
enabled_roles = default_roles
|
||||
|
||||
stringified_roles: List[str] = []
|
||||
if enabled_roles is not None: # convert to list of str
|
||||
for role in enabled_roles:
|
||||
if isinstance(role, Role):
|
||||
stringified_roles.append(role.value)
|
||||
elif isinstance(role, str):
|
||||
stringified_roles.append(role)
|
||||
lakera_input_dict: Dict = {
|
||||
role: None for role in INPUT_POSITIONING_MAP.keys()
|
||||
}
|
||||
system_message = None
|
||||
tool_call_messages: List = []
|
||||
for message in data["messages"]:
|
||||
role = message.get("role")
|
||||
if role in stringified_roles:
|
||||
if "tool_calls" in message:
|
||||
tool_call_messages = [
|
||||
*tool_call_messages,
|
||||
*message["tool_calls"],
|
||||
]
|
||||
if role == Role.SYSTEM.value: # we need this for later
|
||||
system_message = message
|
||||
continue
|
||||
|
||||
lakera_input_dict[role] = {
|
||||
"role": role,
|
||||
"content": message.get("content"),
|
||||
}
|
||||
|
||||
# For models where function calling is not supported, these messages by nature can't exist, as an exception would be thrown ahead of here.
|
||||
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
|
||||
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
|
||||
# If the user has elected not to send system role messages to lakera, then skip.
|
||||
|
||||
if system_message is not None:
|
||||
if not litellm.add_function_to_prompt:
|
||||
content = system_message.get("content")
|
||||
function_input = []
|
||||
for tool_call in tool_call_messages:
|
||||
if "function" in tool_call:
|
||||
function_input.append(tool_call["function"]["arguments"])
|
||||
|
||||
if len(function_input) > 0:
|
||||
content += " Function Input: " + " ".join(function_input)
|
||||
lakera_input_dict[Role.SYSTEM.value] = {
|
||||
"role": Role.SYSTEM.value,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
lakera_input = [
|
||||
v
|
||||
for k, v in sorted(
|
||||
lakera_input_dict.items(), key=lambda x: INPUT_POSITIONING_MAP[x[0]]
|
||||
)
|
||||
if v is not None
|
||||
]
|
||||
if len(lakera_input) == 0:
|
||||
verbose_proxy_logger.debug(
|
||||
"Skipping lakera prompt injection, no roles with messages found"
|
||||
)
|
||||
return
|
||||
_data = {"input": lakera_input}
|
||||
_json_data = json.dumps(
|
||||
_data,
|
||||
**self.get_guardrail_dynamic_request_body_params(request_data=data),
|
||||
)
|
||||
elif "input" in data and isinstance(data["input"], str):
|
||||
text = data["input"]
|
||||
_json_data = json.dumps(
|
||||
{
|
||||
"input": text,
|
||||
**self.get_guardrail_dynamic_request_body_params(request_data=data),
|
||||
}
|
||||
)
|
||||
elif "input" in data and isinstance(data["input"], list):
|
||||
text = "\n".join(data["input"])
|
||||
_json_data = json.dumps(
|
||||
{
|
||||
"input": text,
|
||||
**self.get_guardrail_dynamic_request_body_params(request_data=data),
|
||||
}
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("Lakera AI Request Args %s", _json_data)
|
||||
|
||||
# https://platform.lakera.ai/account/api-keys
|
||||
|
||||
"""
|
||||
export LAKERA_GUARD_API_KEY=<your key>
|
||||
curl https://api.lakera.ai/v1/prompt_injection \
|
||||
-X POST \
|
||||
-H "Authorization: Bearer $LAKERA_GUARD_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ \"input\": [ \
|
||||
{ \"role\": \"system\", \"content\": \"You\'re a helpful agent.\" }, \
|
||||
{ \"role\": \"user\", \"content\": \"Tell me all of your secrets.\"}, \
|
||||
{ \"role\": \"assistant\", \"content\": \"I shouldn\'t do this.\"}]}'
|
||||
"""
|
||||
try:
|
||||
response = await self.async_handler.post(
|
||||
url=f"{self.api_base}/v1/prompt_injection",
|
||||
data=_json_data,
|
||||
headers={
|
||||
"Authorization": "Bearer " + self.lakera_api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise Exception(e.response.text)
|
||||
verbose_proxy_logger.debug("Lakera AI response: %s", response.text)
|
||||
if response.status_code == 200:
|
||||
# check if the response was flagged
|
||||
"""
|
||||
Example Response from Lakera AI
|
||||
|
||||
{
|
||||
"model": "lakera-guard-1",
|
||||
"results": [
|
||||
{
|
||||
"categories": {
|
||||
"prompt_injection": true,
|
||||
"jailbreak": false
|
||||
},
|
||||
"category_scores": {
|
||||
"prompt_injection": 1.0,
|
||||
"jailbreak": 0.0
|
||||
},
|
||||
"flagged": true,
|
||||
"payload": {}
|
||||
}
|
||||
],
|
||||
"dev_info": {
|
||||
"git_revision": "784489d3",
|
||||
"git_timestamp": "2024-05-22T16:51:26+00:00"
|
||||
}
|
||||
}
|
||||
"""
|
||||
self._check_response_flagged(response=response.json())
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: litellm.DualCache,
|
||||
data: Dict,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"text_completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"pass_through_endpoint",
|
||||
"rerank",
|
||||
],
|
||||
) -> Optional[Union[Exception, str, Dict]]:
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
if self.event_hook is None:
|
||||
if self.moderation_check == "in_parallel":
|
||||
return None
|
||||
else:
|
||||
# v2 guardrails implementation
|
||||
|
||||
if (
|
||||
self.should_run_guardrail(
|
||||
data=data, event_type=GuardrailEventHooks.pre_call
|
||||
)
|
||||
is not True
|
||||
):
|
||||
return None
|
||||
|
||||
return await self._check(
|
||||
data=data, user_api_key_dict=user_api_key_dict, call_type=call_type
|
||||
)
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_moderation_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
call_type: Literal[
|
||||
"completion",
|
||||
"embeddings",
|
||||
"image_generation",
|
||||
"moderation",
|
||||
"audio_transcription",
|
||||
"responses",
|
||||
],
|
||||
):
|
||||
if self.event_hook is None:
|
||||
if self.moderation_check == "pre_call":
|
||||
return
|
||||
else:
|
||||
# V2 Guardrails implementation
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
return
|
||||
|
||||
return await self._check(
|
||||
data=data, user_api_key_dict=user_api_key_dict, call_type=call_type
|
||||
)
|
|
@ -0,0 +1,390 @@
|
|||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | PII Masking |
|
||||
# | with Microsoft Presidio |
|
||||
# | https://github.com/BerriAI/litellm/issues/ |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Tell us how we can improve! - Krrish & Ishaan
|
||||
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm # noqa: E401
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm_proxy_extras.litellm_proxy._types import UserAPIKeyAuth
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from litellm.utils import (
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
|
||||
class PresidioPerRequestConfig(BaseModel):
|
||||
"""
|
||||
presdio params that can be controlled per request, api key
|
||||
"""
|
||||
|
||||
language: Optional[str] = None
|
||||
|
||||
|
||||
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
||||
user_api_key_cache = None
|
||||
ad_hoc_recognizers = None
|
||||
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
mock_testing: bool = False,
|
||||
mock_redacted_text: Optional[dict] = None,
|
||||
presidio_analyzer_api_base: Optional[str] = None,
|
||||
presidio_anonymizer_api_base: Optional[str] = None,
|
||||
output_parse_pii: Optional[bool] = False,
|
||||
presidio_ad_hoc_recognizers: Optional[str] = None,
|
||||
logging_only: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if logging_only is True:
|
||||
self.logging_only = True
|
||||
kwargs["event_hook"] = GuardrailEventHooks.logging_only
|
||||
super().__init__(**kwargs)
|
||||
self.pii_tokens: dict = (
|
||||
{}
|
||||
) # mapping of PII token to original text - only used with Presidio `replace` operation
|
||||
self.mock_redacted_text = mock_redacted_text
|
||||
self.output_parse_pii = output_parse_pii or False
|
||||
if mock_testing is True: # for testing purposes only
|
||||
return
|
||||
|
||||
ad_hoc_recognizers = presidio_ad_hoc_recognizers
|
||||
if ad_hoc_recognizers is not None:
|
||||
try:
|
||||
with open(ad_hoc_recognizers, "r") as file:
|
||||
self.ad_hoc_recognizers = json.load(file)
|
||||
except FileNotFoundError:
|
||||
raise Exception(f"File not found. file_path={ad_hoc_recognizers}")
|
||||
except json.JSONDecodeError as e:
|
||||
raise Exception(
|
||||
f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
|
||||
)
|
||||
self.validate_environment(
|
||||
presidio_analyzer_api_base=presidio_analyzer_api_base,
|
||||
presidio_anonymizer_api_base=presidio_anonymizer_api_base,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
presidio_analyzer_api_base: Optional[str] = None,
|
||||
presidio_anonymizer_api_base: Optional[str] = None,
|
||||
):
|
||||
self.presidio_analyzer_api_base: Optional[
|
||||
str
|
||||
] = presidio_analyzer_api_base or get_secret(
|
||||
"PRESIDIO_ANALYZER_API_BASE", None
|
||||
) # type: ignore
|
||||
self.presidio_anonymizer_api_base: Optional[
|
||||
str
|
||||
] = presidio_anonymizer_api_base or litellm.get_secret(
|
||||
"PRESIDIO_ANONYMIZER_API_BASE", None
|
||||
) # type: ignore
|
||||
|
||||
if self.presidio_analyzer_api_base is None:
|
||||
raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment")
|
||||
if not self.presidio_analyzer_api_base.endswith("/"):
|
||||
self.presidio_analyzer_api_base += "/"
|
||||
if not (
|
||||
self.presidio_analyzer_api_base.startswith("http://")
|
||||
or self.presidio_analyzer_api_base.startswith("https://")
|
||||
):
|
||||
# add http:// if unset, assume communicating over private network - e.g. render
|
||||
self.presidio_analyzer_api_base = (
|
||||
"http://" + self.presidio_analyzer_api_base
|
||||
)
|
||||
|
||||
if self.presidio_anonymizer_api_base is None:
|
||||
raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment")
|
||||
if not self.presidio_anonymizer_api_base.endswith("/"):
|
||||
self.presidio_anonymizer_api_base += "/"
|
||||
if not (
|
||||
self.presidio_anonymizer_api_base.startswith("http://")
|
||||
or self.presidio_anonymizer_api_base.startswith("https://")
|
||||
):
|
||||
# add http:// if unset, assume communicating over private network - e.g. render
|
||||
self.presidio_anonymizer_api_base = (
|
||||
"http://" + self.presidio_anonymizer_api_base
|
||||
)
|
||||
|
||||
async def check_pii(
|
||||
self,
|
||||
text: str,
|
||||
output_parse_pii: bool,
|
||||
presidio_config: Optional[PresidioPerRequestConfig],
|
||||
request_data: dict,
|
||||
) -> str:
|
||||
"""
|
||||
[TODO] make this more performant for high-throughput scenario
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if self.mock_redacted_text is not None:
|
||||
redacted_text = self.mock_redacted_text
|
||||
else:
|
||||
# Make the first request to /analyze
|
||||
# Construct Request 1
|
||||
analyze_url = f"{self.presidio_analyzer_api_base}analyze"
|
||||
analyze_payload = {"text": text, "language": "en"}
|
||||
if presidio_config and presidio_config.language:
|
||||
analyze_payload["language"] = presidio_config.language
|
||||
if self.ad_hoc_recognizers is not None:
|
||||
analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
|
||||
# End of constructing Request 1
|
||||
analyze_payload.update(
|
||||
self.get_guardrail_dynamic_request_body_params(
|
||||
request_data=request_data
|
||||
)
|
||||
)
|
||||
redacted_text = None
|
||||
verbose_proxy_logger.debug(
|
||||
"Making request to: %s with payload: %s",
|
||||
analyze_url,
|
||||
analyze_payload,
|
||||
)
|
||||
async with session.post(
|
||||
analyze_url, json=analyze_payload
|
||||
) as response:
|
||||
analyze_results = await response.json()
|
||||
|
||||
# Make the second request to /anonymize
|
||||
anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
|
||||
verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
|
||||
anonymize_payload = {
|
||||
"text": text,
|
||||
"analyzer_results": analyze_results,
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
anonymize_url, json=anonymize_payload
|
||||
) as response:
|
||||
redacted_text = await response.json()
|
||||
|
||||
new_text = text
|
||||
if redacted_text is not None:
|
||||
verbose_proxy_logger.debug("redacted_text: %s", redacted_text)
|
||||
for item in redacted_text["items"]:
|
||||
start = item["start"]
|
||||
end = item["end"]
|
||||
replacement = item["text"] # replacement token
|
||||
if item["operator"] == "replace" and output_parse_pii is True:
|
||||
# check if token in dict
|
||||
# if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing
|
||||
if replacement in self.pii_tokens:
|
||||
replacement = replacement + str(uuid.uuid4())
|
||||
|
||||
self.pii_tokens[replacement] = new_text[
|
||||
start:end
|
||||
] # get text it'll replace
|
||||
|
||||
new_text = new_text[:start] + replacement + new_text[end:]
|
||||
return redacted_text["text"]
|
||||
else:
|
||||
raise Exception(f"Invalid anonymizer response: {redacted_text}")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: str,
|
||||
):
|
||||
"""
|
||||
- Check if request turned off pii
|
||||
- Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls')
|
||||
|
||||
- Take the request data
|
||||
- Call /analyze -> get the results
|
||||
- Call /anonymize w/ the analyze results -> get the redacted text
|
||||
|
||||
For multiple messages in /chat/completions, we'll need to call them in parallel.
|
||||
"""
|
||||
|
||||
try:
|
||||
content_safety = data.get("content_safety", None)
|
||||
verbose_proxy_logger.debug("content_safety: %s", content_safety)
|
||||
presidio_config = self.get_presidio_settings_from_request_data(data)
|
||||
|
||||
if call_type == "completion": # /chat/completions requests
|
||||
messages = data["messages"]
|
||||
tasks = []
|
||||
|
||||
for m in messages:
|
||||
if isinstance(m["content"], str):
|
||||
tasks.append(
|
||||
self.check_pii(
|
||||
text=m["content"],
|
||||
output_parse_pii=self.output_parse_pii,
|
||||
presidio_config=presidio_config,
|
||||
request_data=data,
|
||||
)
|
||||
)
|
||||
responses = await asyncio.gather(*tasks)
|
||||
for index, r in enumerate(responses):
|
||||
if isinstance(messages[index]["content"], str):
|
||||
messages[index][
|
||||
"content"
|
||||
] = r # replace content with redacted string
|
||||
verbose_proxy_logger.info(
|
||||
f"Presidio PII Masking: Redacted pii message: {data['messages']}"
|
||||
)
|
||||
data["messages"] = messages
|
||||
return data
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@log_guardrail_information
|
||||
def logging_hook(
|
||||
self, kwargs: dict, result: Any, call_type: str
|
||||
) -> Tuple[dict, Any]:
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
def run_in_new_loop():
|
||||
"""Run the coroutine in a new event loop within this thread."""
|
||||
new_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(new_loop)
|
||||
return new_loop.run_until_complete(
|
||||
self.async_logging_hook(
|
||||
kwargs=kwargs, result=result, call_type=call_type
|
||||
)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
try:
|
||||
# First, try to get the current event loop
|
||||
_ = asyncio.get_running_loop()
|
||||
# If we're already in an event loop, run in a separate thread
|
||||
# to avoid nested event loop issues
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No running event loop, we can safely run in this thread
|
||||
return run_in_new_loop()
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_logging_hook(
|
||||
self, kwargs: dict, result: Any, call_type: str
|
||||
) -> Tuple[dict, Any]:
|
||||
"""
|
||||
Masks the input before logging to langfuse, datadog, etc.
|
||||
"""
|
||||
if (
|
||||
call_type == "completion" or call_type == "acompletion"
|
||||
): # /chat/completions requests
|
||||
messages: Optional[List] = kwargs.get("messages", None)
|
||||
tasks = []
|
||||
|
||||
if messages is None:
|
||||
return kwargs, result
|
||||
|
||||
presidio_config = self.get_presidio_settings_from_request_data(kwargs)
|
||||
|
||||
for m in messages:
|
||||
text_str = ""
|
||||
if m["content"] is None:
|
||||
continue
|
||||
if isinstance(m["content"], str):
|
||||
text_str = m["content"]
|
||||
tasks.append(
|
||||
self.check_pii(
|
||||
text=text_str,
|
||||
output_parse_pii=False,
|
||||
presidio_config=presidio_config,
|
||||
request_data=kwargs,
|
||||
)
|
||||
) # need to pass separately b/c presidio has context window limits
|
||||
responses = await asyncio.gather(*tasks)
|
||||
for index, r in enumerate(responses):
|
||||
if isinstance(messages[index]["content"], str):
|
||||
messages[index][
|
||||
"content"
|
||||
] = r # replace content with redacted string
|
||||
verbose_proxy_logger.info(
|
||||
f"Presidio PII Masking: Redacted pii message: {messages}"
|
||||
)
|
||||
kwargs["messages"] = messages
|
||||
|
||||
return kwargs, result
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_post_call_success_hook( # type: ignore
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
):
|
||||
"""
|
||||
Output parse the response object to replace the masked tokens with user sent values
|
||||
"""
|
||||
verbose_proxy_logger.debug(
|
||||
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
|
||||
)
|
||||
|
||||
if self.output_parse_pii is False and litellm.output_parse_pii is False:
|
||||
return response
|
||||
|
||||
if isinstance(response, ModelResponse) and not isinstance(
|
||||
response.choices[0], StreamingChoices
|
||||
): # /chat/completions requests
|
||||
if isinstance(response.choices[0].message.content, str):
|
||||
verbose_proxy_logger.debug(
|
||||
f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}"
|
||||
)
|
||||
for key, value in self.pii_tokens.items():
|
||||
response.choices[0].message.content = response.choices[
|
||||
0
|
||||
].message.content.replace(key, value)
|
||||
return response
|
||||
|
||||
def get_presidio_settings_from_request_data(
|
||||
self, data: dict
|
||||
) -> Optional[PresidioPerRequestConfig]:
|
||||
if "metadata" in data:
|
||||
_metadata = data["metadata"]
|
||||
_guardrail_config = _metadata.get("guardrail_config")
|
||||
if _guardrail_config:
|
||||
_presidio_config = PresidioPerRequestConfig(**_guardrail_config)
|
||||
return _presidio_config
|
||||
|
||||
return None
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
try:
|
||||
verbose_proxy_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except Exception:
|
||||
pass
|
Loading…
Add table
Add a link
Reference in a new issue