Virtual key based policies in Aim Guardrails (#9499)

* report key alias to aim

* send litellm version to aim

* Update docs

* blacken

* add docs

* Add info part about virtual keys specific guards

* sort guardrails alphabetically

* fix ruff
This commit is contained in:
Tomer Bin 2025-04-02 07:57:23 +03:00 committed by GitHub
parent ac3399238e
commit 0690f7a3cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 78 additions and 29 deletions

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel
from websockets.asyncio.client import ClientConnection, connect
from litellm import DualCache
from litellm._version import version as litellm_version
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
@ -75,7 +76,9 @@ class AimGuardrail(CustomGuardrail):
) -> Union[Exception, str, dict, None]:
verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
await self.call_aim_guardrail(data, hook="pre_call")
await self.call_aim_guardrail(
data, hook="pre_call", key_alias=user_api_key_dict.key_alias
)
return data
async def async_moderation_hook(
@ -93,15 +96,18 @@ class AimGuardrail(CustomGuardrail):
) -> Union[Exception, str, dict, None]:
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
await self.call_aim_guardrail(data, hook="moderation")
await self.call_aim_guardrail(
data, hook="moderation", key_alias=user_api_key_dict.key_alias
)
return data
async def call_aim_guardrail(self, data: dict, hook: str) -> None:
async def call_aim_guardrail(
self, data: dict, hook: str, key_alias: Optional[str]
) -> None:
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
headers = self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
)
response = await self.async_handler.post(
f"{self.api_base}/detect/openai",
headers=headers,
@ -120,18 +126,16 @@ class AimGuardrail(CustomGuardrail):
raise HTTPException(status_code=400, detail=res["detection_message"])
async def call_aim_guardrail_on_output(
self, request_data: dict, output: str, hook: str
self, request_data: dict, output: str, hook: str, key_alias: Optional[str]
) -> Optional[str]:
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
response = await self.async_handler.post(
f"{self.api_base}/detect/output",
headers=headers,
headers=self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
),
json={"output": output, "messages": request_data.get("messages", [])},
)
response.raise_for_status()
@ -147,6 +151,32 @@ class AimGuardrail(CustomGuardrail):
return res["detection_message"]
return None
def _build_aim_headers(
self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
):
"""
A helper function to build the http headers that are required by AIM guardrails.
"""
return (
{
"Authorization": f"Bearer {self.api_key}",
# Used by Aim to apply only the guardrails that should be applied in a specific request phase.
"x-aim-litellm-hook": hook,
# Used by Aim to track LiteLLM version and provide backward compatibility.
"x-aim-litellm-version": litellm_version,
}
# Used by Aim to track guardrails violations by user.
| ({"x-aim-user-email": user_email} if user_email else {})
| (
{
# Used by Aim apply only the guardrails that are associated with the key alias.
"x-aim-litellm-key-alias": key_alias,
}
if key_alias
else {}
)
)
async def async_post_call_success_hook(
self,
data: dict,
@ -160,7 +190,7 @@ class AimGuardrail(CustomGuardrail):
):
content = response.choices[0].message.content or ""
detection = await self.call_aim_guardrail_on_output(
data, content, hook="output"
data, content, hook="output", key_alias=user_api_key_dict.key_alias
)
if detection:
raise HTTPException(status_code=400, detail=detection)
@ -174,11 +204,13 @@ class AimGuardrail(CustomGuardrail):
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = {
"Authorization": f"Bearer {self.api_key}",
} | ({"x-aim-user-email": user_email} if user_email else {})
async with connect(
f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
f"{self.ws_api_base}/detect/output/ws",
additional_headers=self._build_aim_headers(
hook="output",
key_alias=user_api_key_dict.key_alias,
user_email=user_email,
),
) as websocket:
sender = asyncio.create_task(
self.forward_the_stream_to_aim(websocket, response)