# litellm-mirror/litellm/proxy/guardrails/guardrail_hooks/aim.py
# 2025-03-12 18:44:51 -07:00
#
# 105 lines
# 3.5 KiB
# Python

# +-------------------------------------------------------------+
#
# Use Aim Security Guardrails for your LLM calls
# https://www.aim.security/
#
# +-------------------------------------------------------------+
import os
from typing import Literal, Optional, Union
from fastapi import HTTPException
from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
class AimGuardrailMissingSecrets(Exception):
    """Raised by ``AimGuardrail.__init__`` when no Aim API key is available.

    The key is looked up from the ``api_key`` argument first and then from the
    ``AIM_API_KEY`` environment variable; this error means both were empty.
    """
class AimGuardrail(CustomGuardrail):
    """Guardrail hook that sends request messages to the Aim detection endpoint.

    Both the pre-call and the moderation hook delegate to
    :meth:`call_aim_guardrail`, which POSTs the request's ``messages`` to
    ``{api_base}/detect/openai`` and raises an HTTP 400 when the response
    reports a detection.
    """

    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        """Resolve credentials and endpoint, then initialise the base guardrail.

        Args:
            api_key: Aim API key. Falls back to the ``AIM_API_KEY`` environment
                variable when not given.
            api_base: Base URL of the Aim API. Falls back to the
                ``AIM_API_BASE`` environment variable and finally to
                ``https://api.aim.security``.
            **kwargs: Forwarded unchanged to ``CustomGuardrail.__init__``.

        Raises:
            AimGuardrailMissingSecrets: If no API key could be resolved.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("AIM_API_KEY")
        if not self.api_key:
            msg = (
                "Couldn't get Aim api key, either set the `AIM_API_KEY` in the environment or "
                "pass it as a parameter to the guardrail in the config file"
            )
            raise AimGuardrailMissingSecrets(msg)
        self.api_base = (
            api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
        )
        super().__init__(**kwargs)

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run the Aim check with hook name ``pre_call`` and return ``data``.

        ``user_api_key_dict``, ``cache`` and ``call_type`` are accepted for the
        hook interface but are not used here. ``data`` is returned unmodified
        when the check does not raise.
        """
        verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
        await self.call_aim_guardrail(data, hook="pre_call")
        return data

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run the Aim check with hook name ``moderation`` and return ``data``.

        ``user_api_key_dict`` and ``call_type`` are accepted for the hook
        interface but are not used here. ``data`` is returned unmodified when
        the check does not raise.
        """
        verbose_proxy_logger.debug("Inside AIM Moderation Hook")
        await self.call_aim_guardrail(data, hook="moderation")
        return data

    async def call_aim_guardrail(self, data: dict, hook: str) -> None:
        """POST the request messages to Aim and block the call on a detection.

        Args:
            data: Request payload. ``data["messages"]`` is sent to Aim, and
                ``data["metadata"]["headers"]["x-aim-user-email"]`` (when
                present and non-empty) is forwarded as a request header.
            hook: Name of the calling hook; sent as ``x-aim-litellm-hook``.

        Raises:
            HTTPException: Status 400 when the response's ``detected`` field is
                truthy. ``detail`` is the response's ``detection_message``, or
                a generic message when that field is missing or empty.
            KeyError: If the response body has no ``detected`` field.

        Any error raised by ``response.raise_for_status()`` propagates to the
        caller (presumably httpx's status error on non-2xx responses — the
        response type is not visible here).
        """
        # `metadata` / `headers` may be present but explicitly None, in which
        # case the `.get(key, {})` default would not apply.
        metadata = data.get("metadata") or {}
        request_headers = metadata.get("headers") or {}
        user_email = request_headers.get("x-aim-user-email")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "x-aim-litellm-hook": hook,
        }
        if user_email:
            headers["x-aim-user-email"] = user_email

        # Tolerate a trailing slash on a user-configured base URL.
        base_url = self.api_base.rstrip("/")
        response = await self.async_handler.post(
            f"{base_url}/detect/openai",
            headers=headers,
            # Requests without chat messages (or with `messages: None`) are
            # sent as an empty list rather than JSON null.
            json={"messages": data.get("messages") or []},
        )
        response.raise_for_status()
        res = response.json()

        # Deliberately strict: a response without `detected` is treated as an
        # error rather than silently letting the request through.
        detected = res["detected"]
        # Logging must never be the reason a request fails, so `details` is
        # read defensively.
        policies = list((res.get("details") or {}).keys())
        verbose_proxy_logger.info(
            "Aim: detected: %s, enabled policies: %s", detected, policies
        )
        if detected:
            raise HTTPException(
                status_code=400,
                detail=res.get("detection_message")
                or "Request blocked by Aim guardrail",
            )