mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix importing _ENTERPRISE_Aporia
This commit is contained in:
parent
590c5eade6
commit
6f12f6dcf8
3 changed files with 216 additions and 4 deletions
|
@ -119,7 +119,9 @@ def initialize_callbacks_on_proxy(
|
||||||
)
|
)
|
||||||
imported_list.append(lakera_moderations_object)
|
imported_list.append(lakera_moderations_object)
|
||||||
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
|
elif isinstance(callback, str) and callback == "aporia_prompt_injection":
|
||||||
from enterprise.enterprise_hooks.aporia_ai import _ENTERPRISE_Aporia
|
from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import (
|
||||||
|
_ENTERPRISE_Aporia,
|
||||||
|
)
|
||||||
|
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
212
litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
Normal file
212
litellm/proxy/guardrails/guardrail_hooks/aporia_ai.py
Normal file
|
@ -0,0 +1,212 @@
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
#
|
||||||
|
# Use AporiaAI for your LLM calls
|
||||||
|
#
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
from litellm.litellm_core_utils.logging_utils import (
|
||||||
|
convert_litellm_response_object_to_str,
|
||||||
|
)
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
GUARDRAIL_NAME = "aporia"
|
||||||
|
|
||||||
|
|
||||||
|
class _ENTERPRISE_Aporia(CustomGuardrail):
|
||||||
|
def __init__(
|
||||||
|
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
self.async_handler = AsyncHTTPHandler(
|
||||||
|
timeout=httpx.Timeout(timeout=600.0, connect=5.0)
|
||||||
|
)
|
||||||
|
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
|
||||||
|
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
|
||||||
|
self.event_hook: GuardrailEventHooks
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
#### CALL HOOKS - proxy only ####
|
||||||
|
def transform_messages(self, messages: List[dict]) -> List[dict]:
|
||||||
|
supported_openai_roles = ["system", "user", "assistant"]
|
||||||
|
default_role = "other" # for unsupported roles - e.g. tool
|
||||||
|
new_messages = []
|
||||||
|
for m in messages:
|
||||||
|
if m.get("role", "") in supported_openai_roles:
|
||||||
|
new_messages.append(m)
|
||||||
|
else:
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": default_role,
|
||||||
|
**{key: value for key, value in m.items() if key != "role"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
async def prepare_aporia_request(
|
||||||
|
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||||
|
) -> dict:
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
if new_messages is not None:
|
||||||
|
data["messages"] = new_messages
|
||||||
|
if response_string is not None:
|
||||||
|
data["response"] = response_string
|
||||||
|
|
||||||
|
# Set validation target
|
||||||
|
if new_messages and response_string:
|
||||||
|
data["validation_target"] = "both"
|
||||||
|
elif new_messages:
|
||||||
|
data["validation_target"] = "prompt"
|
||||||
|
elif response_string:
|
||||||
|
data["validation_target"] = "response"
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug("Aporia AI request: %s", data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def make_aporia_api_request(
|
||||||
|
self, new_messages: List[dict], response_string: Optional[str] = None
|
||||||
|
):
|
||||||
|
data = await self.prepare_aporia_request(
|
||||||
|
new_messages=new_messages, response_string=response_string
|
||||||
|
)
|
||||||
|
|
||||||
|
_json_data = json.dumps(data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
export APORIO_API_KEY=<your key>
|
||||||
|
curl https://gr-prd-trial.aporia.com/some-id \
|
||||||
|
-X POST \
|
||||||
|
-H "X-APORIA-API-KEY: $APORIO_API_KEY" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "This is a test prompt"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
'
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
url=self.aporia_api_base + "/validate",
|
||||||
|
data=_json_data,
|
||||||
|
headers={
|
||||||
|
"X-APORIA-API-KEY": self.aporia_api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("Aporia AI response: %s", response.text)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# check if the response was flagged
|
||||||
|
_json_response = response.json()
|
||||||
|
action: str = _json_response.get(
|
||||||
|
"action"
|
||||||
|
) # possible values are modify, passthrough, block, rephrase
|
||||||
|
if action == "block":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Violated guardrail policy",
|
||||||
|
"aporia_ai_response": _json_response,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
"""
|
||||||
|
Use this for the post call moderation with Guardrails
|
||||||
|
"""
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
response_str: Optional[str] = convert_litellm_response_object_to_str(response)
|
||||||
|
if response_str is not None:
|
||||||
|
await self.make_aporia_api_request(
|
||||||
|
response_string=response_str, new_messages=data.get("messages", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
# old implementation - backwards compatibility
|
||||||
|
if (
|
||||||
|
await should_proceed_based_on_metadata(
|
||||||
|
data=data,
|
||||||
|
guardrail_name=GUARDRAIL_NAME,
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
new_messages: Optional[List[dict]] = None
|
||||||
|
if "messages" in data and isinstance(data["messages"], list):
|
||||||
|
new_messages = self.transform_messages(messages=data["messages"])
|
||||||
|
|
||||||
|
if new_messages is not None:
|
||||||
|
await self.make_aporia_api_request(new_messages=new_messages)
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Aporia AI: not running guardrail. No messages in data"
|
||||||
|
)
|
||||||
|
pass
|
|
@ -113,9 +113,7 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
|
|
||||||
# Init guardrail CustomLoggerClass
|
# Init guardrail CustomLoggerClass
|
||||||
if litellm_params["guardrail"] == "aporia":
|
if litellm_params["guardrail"] == "aporia":
|
||||||
from litellm.proxy.enterprise.enterprise_hooks.aporia_ai import (
|
from guardrail_hooks.aporia_ai import _ENTERPRISE_Aporia
|
||||||
_ENTERPRISE_Aporia,
|
|
||||||
)
|
|
||||||
|
|
||||||
_aporia_callback = _ENTERPRISE_Aporia(
|
_aporia_callback = _ENTERPRISE_Aporia(
|
||||||
api_base=litellm_params["api_base"],
|
api_base=litellm_params["api_base"],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue