mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Virtual key based policies in Aim Guardrails (#9499)
* report key alias to aim
* send litellm version to aim
* Update docs
* blacken
* add docs
* Add info part about virtual keys specific guards
* sort guardrails alphabetically
* fix ruff
This commit is contained in:
parent
ac3399238e
commit
0690f7a3cb
6 changed files with 78 additions and 29 deletions
|
@ -23,6 +23,12 @@ In the newly created guard's page, you can find a reference to the prompt policy
|
||||||
|
|
||||||
You can decide which detections will be enabled, and set the threshold for each detection.
|
You can decide which detections will be enabled, and set the threshold for each detection.
|
||||||
|
|
||||||
|
:::info
|
||||||
|
When using LiteLLM with virtual keys, key-specific policies can be set directly in Aim's guards page by specifying the virtual key alias when creating the guard.
|
||||||
|
|
||||||
|
Only the aliases of your virtual keys (and not the actual key secrets) will be sent to Aim.
|
||||||
|
:::
|
||||||
|
|
||||||
### 3. Add Aim Guardrail on your LiteLLM config.yaml
|
### 3. Add Aim Guardrail on your LiteLLM config.yaml
|
||||||
|
|
||||||
Define your guardrails under the `guardrails` section
|
Define your guardrails under the `guardrails` section
|
||||||
|
|
|
@ -17,6 +17,14 @@ model_list:
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
guardrails:
|
guardrails:
|
||||||
|
- guardrail_name: general-guard
|
||||||
|
litellm_params:
|
||||||
|
guardrail: aim
|
||||||
|
mode: [pre_call, post_call]
|
||||||
|
api_key: os.environ/AIM_API_KEY
|
||||||
|
api_base: os.environ/AIM_API_BASE
|
||||||
|
default_on: true # Optional
|
||||||
|
|
||||||
- guardrail_name: "aporia-pre-guard"
|
- guardrail_name: "aporia-pre-guard"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
guardrail: aporia # supported values: "aporia", "lakera"
|
guardrail: aporia # supported values: "aporia", "lakera"
|
||||||
|
@ -45,6 +53,7 @@ guardrails:
|
||||||
- `pre_call` Run **before** LLM call, on **input**
|
- `pre_call` Run **before** LLM call, on **input**
|
||||||
- `post_call` Run **after** LLM call, on **input & output**
|
- `post_call` Run **after** LLM call, on **input & output**
|
||||||
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call`, but runs in parallel with the LLM call. The response is not returned until the guardrail check completes.
|
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call`, but runs in parallel with the LLM call. The response is not returned until the guardrail check completes.
|
||||||
|
- A list of the above values to run multiple modes, e.g. `mode: [pre_call, post_call]`
|
||||||
|
|
||||||
|
|
||||||
## 2. Start LiteLLM Gateway
|
## 2. Start LiteLLM Gateway
|
||||||
|
|
|
@ -137,15 +137,17 @@ const sidebars = {
|
||||||
label: "[Beta] Guardrails",
|
label: "[Beta] Guardrails",
|
||||||
items: [
|
items: [
|
||||||
"proxy/guardrails/quick_start",
|
"proxy/guardrails/quick_start",
|
||||||
"proxy/guardrails/aim_security",
|
...[
|
||||||
"proxy/guardrails/aporia_api",
|
"proxy/guardrails/aim_security",
|
||||||
"proxy/guardrails/bedrock",
|
"proxy/guardrails/aporia_api",
|
||||||
"proxy/guardrails/guardrails_ai",
|
"proxy/guardrails/bedrock",
|
||||||
"proxy/guardrails/lakera_ai",
|
"proxy/guardrails/guardrails_ai",
|
||||||
"proxy/guardrails/pii_masking_v2",
|
"proxy/guardrails/lakera_ai",
|
||||||
"proxy/guardrails/secret_detection",
|
"proxy/guardrails/pii_masking_v2",
|
||||||
"proxy/guardrails/custom_guardrail",
|
"proxy/guardrails/secret_detection",
|
||||||
"prompt_injection"
|
"proxy/guardrails/custom_guardrail",
|
||||||
|
"proxy/guardrails/prompt_injection",
|
||||||
|
].sort(),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -3,4 +3,4 @@ import importlib_metadata
|
||||||
try:
|
try:
|
||||||
version = importlib_metadata.version("litellm")
|
version = importlib_metadata.version("litellm")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
version = "unknown"
|
||||||
|
|
|
@ -14,6 +14,7 @@ from pydantic import BaseModel
|
||||||
from websockets.asyncio.client import ClientConnection, connect
|
from websockets.asyncio.client import ClientConnection, connect
|
||||||
|
|
||||||
from litellm import DualCache
|
from litellm import DualCache
|
||||||
|
from litellm._version import version as litellm_version
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
@ -75,7 +76,9 @@ class AimGuardrail(CustomGuardrail):
|
||||||
) -> Union[Exception, str, dict, None]:
|
) -> Union[Exception, str, dict, None]:
|
||||||
verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
|
verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
|
||||||
|
|
||||||
await self.call_aim_guardrail(data, hook="pre_call")
|
await self.call_aim_guardrail(
|
||||||
|
data, hook="pre_call", key_alias=user_api_key_dict.key_alias
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def async_moderation_hook(
|
async def async_moderation_hook(
|
||||||
|
@ -93,15 +96,18 @@ class AimGuardrail(CustomGuardrail):
|
||||||
) -> Union[Exception, str, dict, None]:
|
) -> Union[Exception, str, dict, None]:
|
||||||
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
|
verbose_proxy_logger.debug("Inside AIM Moderation Hook")
|
||||||
|
|
||||||
await self.call_aim_guardrail(data, hook="moderation")
|
await self.call_aim_guardrail(
|
||||||
|
data, hook="moderation", key_alias=user_api_key_dict.key_alias
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def call_aim_guardrail(self, data: dict, hook: str) -> None:
|
async def call_aim_guardrail(
|
||||||
|
self, data: dict, hook: str, key_alias: Optional[str]
|
||||||
|
) -> None:
|
||||||
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
headers = {
|
headers = self._build_aim_headers(
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
hook=hook, key_alias=key_alias, user_email=user_email
|
||||||
"x-aim-litellm-hook": hook,
|
)
|
||||||
} | ({"x-aim-user-email": user_email} if user_email else {})
|
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
f"{self.api_base}/detect/openai",
|
f"{self.api_base}/detect/openai",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
@ -120,18 +126,16 @@ class AimGuardrail(CustomGuardrail):
|
||||||
raise HTTPException(status_code=400, detail=res["detection_message"])
|
raise HTTPException(status_code=400, detail=res["detection_message"])
|
||||||
|
|
||||||
async def call_aim_guardrail_on_output(
|
async def call_aim_guardrail_on_output(
|
||||||
self, request_data: dict, output: str, hook: str
|
self, request_data: dict, output: str, hook: str, key_alias: Optional[str]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
user_email = (
|
user_email = (
|
||||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
)
|
)
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"x-aim-litellm-hook": hook,
|
|
||||||
} | ({"x-aim-user-email": user_email} if user_email else {})
|
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
f"{self.api_base}/detect/output",
|
f"{self.api_base}/detect/output",
|
||||||
headers=headers,
|
headers=self._build_aim_headers(
|
||||||
|
hook=hook, key_alias=key_alias, user_email=user_email
|
||||||
|
),
|
||||||
json={"output": output, "messages": request_data.get("messages", [])},
|
json={"output": output, "messages": request_data.get("messages", [])},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -147,6 +151,32 @@ class AimGuardrail(CustomGuardrail):
|
||||||
return res["detection_message"]
|
return res["detection_message"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _build_aim_headers(
|
||||||
|
self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
A helper function to build the http headers that are required by AIM guardrails.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
# Used by Aim to apply only the guardrails that should be applied in a specific request phase.
|
||||||
|
"x-aim-litellm-hook": hook,
|
||||||
|
# Used by Aim to track LiteLLM version and provide backward compatibility.
|
||||||
|
"x-aim-litellm-version": litellm_version,
|
||||||
|
}
|
||||||
|
# Used by Aim to track guardrails violations by user.
|
||||||
|
| ({"x-aim-user-email": user_email} if user_email else {})
|
||||||
|
| (
|
||||||
|
{
|
||||||
|
# Used by Aim to apply only the guardrails that are associated with the key alias.
|
||||||
|
"x-aim-litellm-key-alias": key_alias,
|
||||||
|
}
|
||||||
|
if key_alias
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def async_post_call_success_hook(
|
async def async_post_call_success_hook(
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
|
@ -160,7 +190,7 @@ class AimGuardrail(CustomGuardrail):
|
||||||
):
|
):
|
||||||
content = response.choices[0].message.content or ""
|
content = response.choices[0].message.content or ""
|
||||||
detection = await self.call_aim_guardrail_on_output(
|
detection = await self.call_aim_guardrail_on_output(
|
||||||
data, content, hook="output"
|
data, content, hook="output", key_alias=user_api_key_dict.key_alias
|
||||||
)
|
)
|
||||||
if detection:
|
if detection:
|
||||||
raise HTTPException(status_code=400, detail=detection)
|
raise HTTPException(status_code=400, detail=detection)
|
||||||
|
@ -174,11 +204,13 @@ class AimGuardrail(CustomGuardrail):
|
||||||
user_email = (
|
user_email = (
|
||||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
)
|
)
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
} | ({"x-aim-user-email": user_email} if user_email else {})
|
|
||||||
async with connect(
|
async with connect(
|
||||||
f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
|
f"{self.ws_api_base}/detect/output/ws",
|
||||||
|
additional_headers=self._build_aim_headers(
|
||||||
|
hook="output",
|
||||||
|
key_alias=user_api_key_dict.key_alias,
|
||||||
|
user_email=user_email,
|
||||||
|
),
|
||||||
) as websocket:
|
) as websocket:
|
||||||
sender = asyncio.create_task(
|
sender = asyncio.create_task(
|
||||||
self.forward_the_stream_to_aim(websocket, response)
|
self.forward_the_stream_to_aim(websocket, response)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue