Virtual key based policies in Aim Guardrails (#9499)

* report key alias to aim

* send litellm version to aim

* Update docs

* blacken

* add docs

* Add info part about virtual keys specific guards

* sort guardrails alphabetically

* fix ruff
This commit is contained in:
Tomer Bin 2025-04-02 07:57:23 +03:00 committed by GitHub
parent ac3399238e
commit 0690f7a3cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 78 additions and 29 deletions

View file

@ -23,6 +23,12 @@ In the newly created guard's page, you can find a reference to the prompt policy
You can decide which detections will be enabled, and set the threshold for each detection.
:::info
When using LiteLLM with virtual keys, key-specific policies can be set directly in Aim's guards page by specifying the virtual key alias when creating the guard.
Only the aliases of your virtual keys (and not the actual key secrets) will be sent to Aim.
:::
### 3. Add Aim Guardrail on your LiteLLM config.yaml
Define your guardrails under the `guardrails` section

View file

@ -17,6 +17,14 @@ model_list:
api_key: os.environ/OPENAI_API_KEY
guardrails:
- guardrail_name: general-guard
litellm_params:
guardrail: aim
mode: [pre_call, post_call]
api_key: os.environ/AIM_API_KEY
api_base: os.environ/AIM_API_BASE
default_on: true # Optional
- guardrail_name: "aporia-pre-guard"
litellm_params:
guardrail: aporia # supported values: "aporia", "lakera"
@ -45,6 +53,7 @@ guardrails:
- `pre_call` Run **before** LLM call, on **input**
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call`, but runs in parallel with the LLM call. The response is not returned until the guardrail check completes
- A list of the above values to run multiple modes, e.g. `mode: [pre_call, post_call]`
## 2. Start LiteLLM Gateway
@ -569,4 +578,4 @@ guardrails: Union[
class DynamicGuardrailParams:
extra_body: Dict[str, Any] # Additional parameters for the guardrail
```
```

View file

@ -137,15 +137,17 @@ const sidebars = {
label: "[Beta] Guardrails",
items: [
"proxy/guardrails/quick_start",
"proxy/guardrails/aim_security",
"proxy/guardrails/aporia_api",
"proxy/guardrails/bedrock",
"proxy/guardrails/guardrails_ai",
"proxy/guardrails/lakera_ai",
"proxy/guardrails/pii_masking_v2",
"proxy/guardrails/secret_detection",
"proxy/guardrails/custom_guardrail",
"prompt_injection"
...[
"proxy/guardrails/aim_security",
"proxy/guardrails/aporia_api",
"proxy/guardrails/bedrock",
"proxy/guardrails/guardrails_ai",
"proxy/guardrails/lakera_ai",
"proxy/guardrails/pii_masking_v2",
"proxy/guardrails/secret_detection",
"proxy/guardrails/custom_guardrail",
"proxy/guardrails/prompt_injection",
].sort(),
],
},
{

View file

@ -3,4 +3,4 @@ import importlib_metadata
try:
version = importlib_metadata.version("litellm")
except Exception:
pass
version = "unknown"

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel
from websockets.asyncio.client import ClientConnection, connect
from litellm import DualCache
from litellm._version import version as litellm_version
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
@ -75,7 +76,9 @@ class AimGuardrail(CustomGuardrail):
) -> Union[Exception, str, dict, None]:
verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
await self.call_aim_guardrail(data, hook="pre_call")
await self.call_aim_guardrail(
data, hook="pre_call", key_alias=user_api_key_dict.key_alias
)
return data
async def async_moderation_hook(
@ -93,15 +96,18 @@ class AimGuardrail(CustomGuardrail):
) -> Union[Exception, str, dict, None]:
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
await self.call_aim_guardrail(data, hook="moderation")
await self.call_aim_guardrail(
data, hook="moderation", key_alias=user_api_key_dict.key_alias
)
return data
async def call_aim_guardrail(self, data: dict, hook: str) -> None:
async def call_aim_guardrail(
self, data: dict, hook: str, key_alias: Optional[str]
) -> None:
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
headers = self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
)
response = await self.async_handler.post(
f"{self.api_base}/detect/openai",
headers=headers,
@ -120,18 +126,16 @@ class AimGuardrail(CustomGuardrail):
raise HTTPException(status_code=400, detail=res["detection_message"])
async def call_aim_guardrail_on_output(
self, request_data: dict, output: str, hook: str
self, request_data: dict, output: str, hook: str, key_alias: Optional[str]
) -> Optional[str]:
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
response = await self.async_handler.post(
f"{self.api_base}/detect/output",
headers=headers,
headers=self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
),
json={"output": output, "messages": request_data.get("messages", [])},
)
response.raise_for_status()
@ -147,6 +151,32 @@ class AimGuardrail(CustomGuardrail):
return res["detection_message"]
return None
def _build_aim_headers(
    self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
):
    """
    Assemble the HTTP headers sent along with a request to Aim.

    The Authorization, hook and LiteLLM-version headers are always present.
    The user-email and key-alias headers are added only when the
    corresponding argument is truthy (not None and not an empty string).

    Args:
        hook: Name of the request phase; sent as ``x-aim-litellm-hook``.
        key_alias: Alias of the virtual key, if any; sent as
            ``x-aim-litellm-key-alias``.
        user_email: End-user email, if any; sent as ``x-aim-user-email``.

    Returns:
        A dict mapping header names to their string values.
    """
    aim_headers = {
        "Authorization": f"Bearer {self.api_key}",
        # Lets Aim apply only the guardrails relevant to this request phase.
        "x-aim-litellm-hook": hook,
        # Lets Aim track the LiteLLM version and stay backward compatible.
        "x-aim-litellm-version": litellm_version,
    }
    if user_email:
        # Lets Aim track guardrail violations per user.
        aim_headers["x-aim-user-email"] = user_email
    if key_alias:
        # Lets Aim apply only the guardrails associated with this key alias.
        aim_headers["x-aim-litellm-key-alias"] = key_alias
    return aim_headers
async def async_post_call_success_hook(
self,
data: dict,
@ -160,7 +190,7 @@ class AimGuardrail(CustomGuardrail):
):
content = response.choices[0].message.content or ""
detection = await self.call_aim_guardrail_on_output(
data, content, hook="output"
data, content, hook="output", key_alias=user_api_key_dict.key_alias
)
if detection:
raise HTTPException(status_code=400, detail=detection)
@ -174,11 +204,13 @@ class AimGuardrail(CustomGuardrail):
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = {
"Authorization": f"Bearer {self.api_key}",
} | ({"x-aim-user-email": user_email} if user_email else {})
async with connect(
f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
f"{self.ws_api_base}/detect/output/ws",
additional_headers=self._build_aim_headers(
hook="output",
key_alias=user_api_key_dict.key_alias,
user_email=user_email,
),
) as websocket:
sender = asyncio.create_task(
self.forward_the_stream_to_aim(websocket, response)