diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 9cf562734c..26aa28d62a 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -119,7 +119,9 @@ def initialize_callbacks_on_proxy( ) imported_list.append(lakera_moderations_object) elif isinstance(callback, str) and callback == "aporia_prompt_injection": - from enterprise.enterprise_hooks.aporia_ai import _ENTERPRISE_Aporia + from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import ( + _ENTERPRISE_Aporia, + ) if premium_user is not True: raise Exception( diff --git a/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py new file mode 100644 index 0000000000..29566d94db --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py @@ -0,0 +1,212 @@ +# +-------------------------------------------------------------+ +# +# Use AporiaAI for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! 
# - Krrish & Ishaan

import os
import sys

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import json
import sys
import traceback
import uuid
from datetime import datetime
from typing import Any, List, Literal, Optional, Union

import aiohttp
import httpx
from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks

# NOTE(review): this flips a process-wide litellm flag as a side effect of
# importing this module; it looks like a debugging leftover - confirm before
# removing. Kept so the module's import-time behaviour is unchanged.
litellm.set_verbose = True

GUARDRAIL_NAME = "aporia"


class _ENTERPRISE_Aporia(CustomGuardrail):
    """Guardrail that asks the Aporia AI ``/validate`` endpoint to vet LLM traffic.

    Two proxy hooks are implemented:

    * ``async_moderation_hook`` validates the request messages (during-call).
    * ``async_post_call_success_hook`` validates the request messages together
      with the model response (post-call).

    When Aporia answers with ``action == "block"`` an ``HTTPException`` with
    status code 400 is raised.
    """

    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        """Create the HTTP client and resolve the Aporia credentials.

        Args:
            api_key: Aporia API key. Falls back to the ``APORIO_API_KEY``
                environment variable when falsy.
            api_base: Aporia base URL. Falls back to the ``APORIO_API_BASE``
                environment variable when falsy.
            **kwargs: Forwarded unchanged to ``CustomGuardrail.__init__``.

        Raises:
            KeyError: If a value is not passed and the matching environment
                variable is not set.
        """
        self.async_handler = AsyncHTTPHandler(
            timeout=httpx.Timeout(timeout=600.0, connect=5.0)
        )
        # The environment variables are spelled "APORIO_*" (not "APORIA_*").
        # The spelling is kept as-is so existing deployments keep working.
        self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
        self.event_hook: GuardrailEventHooks

        super().__init__(**kwargs)

    #### CALL HOOKS - proxy only ####
    def transform_messages(self, messages: List[dict]) -> List[dict]:
        """Return a copy of ``messages`` with unsupported roles mapped to ``"other"``.

        Messages whose ``role`` is ``system``, ``user`` or ``assistant`` are
        passed through as the same dict object. Any other message (e.g. role
        ``tool``, or no role at all) is replaced by a new dict that has
        ``"role": "other"`` and all remaining keys of the original message.
        """
        supported_openai_roles = ("system", "user", "assistant")
        default_role = "other"  # for unsupported roles - e.g. tool

        return [
            m
            if m.get("role", "") in supported_openai_roles
            else {
                "role": default_role,
                **{key: value for key, value in m.items() if key != "role"},
            }
            for m in messages
        ]

    async def prepare_aporia_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ) -> dict:
        """Build the JSON-serialisable payload for the Aporia ``/validate`` call.

        Args:
            new_messages: Chat messages to validate; stored under ``"messages"``
                unless ``None``.
            response_string: Model response to validate; stored under
                ``"response"`` unless ``None``.

        Returns:
            The payload dict. ``"validation_target"`` is ``"both"``,
            ``"prompt"`` or ``"response"`` depending on which of the two inputs
            are truthy, and is omitted when neither is.
        """
        data: dict[str, Any] = {}
        if new_messages is not None:
            data["messages"] = new_messages
        if response_string is not None:
            data["response"] = response_string

        # Set validation target
        if new_messages and response_string:
            data["validation_target"] = "both"
        elif new_messages:
            data["validation_target"] = "prompt"
        elif response_string:
            data["validation_target"] = "response"

        verbose_proxy_logger.debug("Aporia AI request: %s", data)
        return data

    async def make_aporia_api_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ) -> None:
        """POST the payload to ``<api_base>/validate`` and enforce the verdict.

        Equivalent request (messages-only payload)::

            export APORIO_API_KEY=
            curl https://gr-prd-trial.aporia.com/some-id/validate \\
                -X POST \\
                -H "X-APORIA-API-KEY: $APORIO_API_KEY" \\
                -H "Content-Type: application/json" \\
                -d '{"messages": [{"role": "user", "content": "This is a test prompt"}]}'

        Args:
            new_messages: Chat messages to validate.
            response_string: Optional model response to validate.

        Raises:
            HTTPException: status code 400 when the response status is 200 and
                its JSON body has ``"action": "block"``.
        """
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )
        _json_data = json.dumps(data)

        response = await self.async_handler.post(
            # Tolerate a trailing slash on the configured base URL.
            url=self.aporia_api_base.rstrip("/") + "/validate",
            data=_json_data,
            headers={
                "X-APORIA-API-KEY": self.aporia_api_key,
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("Aporia AI response: %s", response.text)

        if response.status_code != 200:
            # No verdict is available. The request is still allowed through
            # (unchanged behaviour), but the failure is no longer silent.
            verbose_proxy_logger.warning(
                "Aporia AI: unexpected status code %s, guardrail verdict unavailable",
                response.status_code,
            )
            return

        # check if the response was flagged
        _json_response = response.json()
        action: Optional[str] = _json_response.get(
            "action"
        )  # possible values are modify, passthrough, block, rephrase
        if action == "block":
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated guardrail policy",
                    "aporia_ai_response": _json_response,
                },
            )

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Use this for the post call moderation with Guardrails.

        Does nothing unless ``should_run_guardrail`` returns ``True`` for the
        ``post_call`` event. Otherwise the response is converted to a string
        and, when that string is not ``None``, sent to Aporia together with the
        request messages; the guardrail is then recorded via
        ``add_guardrail_to_applied_guardrails_header``.

        Args:
            data: Request payload; ``data["messages"]`` is read if present.
            user_api_key_dict: Not used by this hook.
            response: LLM response object to validate.

        Raises:
            HTTPException: Propagated from ``make_aporia_api_request`` when
                Aporia blocks the exchange.
        """
        # Imported lazily - presumably to avoid a circular import, since
        # callback_utils itself imports this module. TODO confirm.
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is None:
            return

        # Normalise roles the same way the during-call hook does, so both
        # hooks send Aporia the same message shape.
        messages = data.get("messages", [])
        if isinstance(messages, list):
            messages = self.transform_messages(messages=messages)

        await self.make_aporia_api_request(
            response_string=response_str, new_messages=messages
        )
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

    async def async_moderation_hook(  ### 👈 KEY CHANGE ###
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        """Validate the request messages with Aporia for the during-call event.

        Does nothing unless ``should_run_guardrail`` returns ``True`` for the
        ``during_call`` event and the legacy metadata check
        (``should_proceed_based_on_metadata``) does not return ``False``.
        When ``data["messages"]`` is a list it is role-normalised and sent to
        Aporia; otherwise a warning is logged and nothing is sent.

        Args:
            data: Request payload.
            user_api_key_dict: Not used by this hook.
            call_type: Not used by this hook.

        Raises:
            HTTPException: Propagated from ``make_aporia_api_request`` when
                Aporia blocks the prompt.
        """
        # Imported lazily - presumably to avoid a circular import, since
        # callback_utils itself imports this module. TODO confirm.
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        # old implementation - backwards compatibility
        if (
            await should_proceed_based_on_metadata(
                data=data,
                guardrail_name=GUARDRAIL_NAME,
            )
            is False
        ):
            return

        new_messages: Optional[List[dict]] = None
        if "messages" in data and isinstance(data["messages"], list):
            new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is None:
            verbose_proxy_logger.warning(
                "Aporia AI: not running guardrail. No messages in data"
            )
            return

        await self.make_aporia_api_request(new_messages=new_messages)
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )


# NOTE(review): the companion change in litellm/proxy/guardrails/init_guardrails.py
# imports this class with `from guardrail_hooks.aporia_ai import _ENTERPRISE_Aporia`.
# That is a top-level absolute import and only resolves when the `guardrails`
# directory itself is on sys.path; it should use the full path
# `litellm.proxy.guardrails.guardrail_hooks.aporia_ai`, as callback_utils.py does.