mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Add litellm call id passing to Aim guardrails on pre and post-hooks calls (#10021)
* Add litellm_call_id passing to aim guardrails on pre and post-hooks * Add test that ensures that pre_call_hook receives litellm call id when common_request_processing called
This commit is contained in:
parent
ca593e003a
commit
e19d05980c
3 changed files with 96 additions and 7 deletions
|
@ -105,8 +105,12 @@ class AimGuardrail(CustomGuardrail):
|
|||
self, data: dict, hook: str, key_alias: Optional[str]
|
||||
) -> None:
|
||||
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
call_id = data.get("litellm_call_id")
|
||||
headers = self._build_aim_headers(
|
||||
hook=hook, key_alias=key_alias, user_email=user_email
|
||||
hook=hook,
|
||||
key_alias=key_alias,
|
||||
user_email=user_email,
|
||||
litellm_call_id=call_id,
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
f"{self.api_base}/detect/openai",
|
||||
|
@ -131,10 +135,14 @@ class AimGuardrail(CustomGuardrail):
|
|||
user_email = (
|
||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
)
|
||||
call_id = request_data.get("litellm_call_id")
|
||||
response = await self.async_handler.post(
|
||||
f"{self.api_base}/detect/output",
|
||||
headers=self._build_aim_headers(
|
||||
hook=hook, key_alias=key_alias, user_email=user_email
|
||||
hook=hook,
|
||||
key_alias=key_alias,
|
||||
user_email=user_email,
|
||||
litellm_call_id=call_id,
|
||||
),
|
||||
json={"output": output, "messages": request_data.get("messages", [])},
|
||||
)
|
||||
|
@ -152,7 +160,12 @@ class AimGuardrail(CustomGuardrail):
|
|||
return None
|
||||
|
||||
def _build_aim_headers(
|
||||
self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
|
||||
self,
|
||||
*,
|
||||
hook: str,
|
||||
key_alias: Optional[str],
|
||||
user_email: Optional[str],
|
||||
litellm_call_id: Optional[str],
|
||||
):
|
||||
"""
|
||||
A helper function to build the http headers that are required by AIM guardrails.
|
||||
|
@ -165,6 +178,8 @@ class AimGuardrail(CustomGuardrail):
|
|||
# Used by Aim to track LiteLLM version and provide backward compatibility.
|
||||
"x-aim-litellm-version": litellm_version,
|
||||
}
|
||||
# Used by Aim to track together single call input and output
|
||||
| ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
|
||||
# Used by Aim to track guardrails violations by user.
|
||||
| ({"x-aim-user-email": user_email} if user_email else {})
|
||||
| (
|
||||
|
@ -204,12 +219,14 @@ class AimGuardrail(CustomGuardrail):
|
|||
user_email = (
|
||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
)
|
||||
call_id = request_data.get("litellm_call_id")
|
||||
async with connect(
|
||||
f"{self.ws_api_base}/detect/output/ws",
|
||||
additional_headers=self._build_aim_headers(
|
||||
hook="output",
|
||||
key_alias=user_api_key_dict.key_alias,
|
||||
user_email=user_email,
|
||||
litellm_call_id=call_id,
|
||||
),
|
||||
) as websocket:
|
||||
sender = asyncio.create_task(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue