Add litellm call id passing to Aim guardrails on pre and post-hooks calls (#10021)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
Helm unit test / unit-test (push) Successful in 19s

* Add litellm_call_id passing to aim guardrails on pre and post-hooks

* Add test that ensures that pre_call_hook receives litellm call id when common_request_processing called
This commit is contained in:
Michael Leshchinsky 2025-04-16 17:41:28 +03:00 committed by GitHub
parent ca593e003a
commit e19d05980c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 96 additions and 7 deletions

View file

@ -105,8 +105,12 @@ class AimGuardrail(CustomGuardrail):
self, data: dict, hook: str, key_alias: Optional[str]
) -> None:
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
call_id = data.get("litellm_call_id")
headers = self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
hook=hook,
key_alias=key_alias,
user_email=user_email,
litellm_call_id=call_id,
)
response = await self.async_handler.post(
f"{self.api_base}/detect/openai",
@ -131,10 +135,14 @@ class AimGuardrail(CustomGuardrail):
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
call_id = request_data.get("litellm_call_id")
response = await self.async_handler.post(
f"{self.api_base}/detect/output",
headers=self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
hook=hook,
key_alias=key_alias,
user_email=user_email,
litellm_call_id=call_id,
),
json={"output": output, "messages": request_data.get("messages", [])},
)
@ -152,7 +160,12 @@ class AimGuardrail(CustomGuardrail):
return None
def _build_aim_headers(
self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
self,
*,
hook: str,
key_alias: Optional[str],
user_email: Optional[str],
litellm_call_id: Optional[str],
):
"""
A helper function to build the http headers that are required by AIM guardrails.
@ -165,6 +178,8 @@ class AimGuardrail(CustomGuardrail):
# Used by Aim to track LiteLLM version and provide backward compatibility.
"x-aim-litellm-version": litellm_version,
}
# Used by Aim to track together single call input and output
| ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
# Used by Aim to track guardrails violations by user.
| ({"x-aim-user-email": user_email} if user_email else {})
| (
@ -204,12 +219,14 @@ class AimGuardrail(CustomGuardrail):
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
call_id = request_data.get("litellm_call_id")
async with connect(
f"{self.ws_api_base}/detect/output/ws",
additional_headers=self._build_aim_headers(
hook="output",
key_alias=user_api_key_dict.key_alias,
user_email=user_email,
litellm_call_id=call_id,
),
) as websocket:
sender = asyncio.create_task(