Add litellm call id passing to Aim guardrails on pre and post-hooks calls (#10021)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
Helm unit test / unit-test (push) Successful in 19s

* Pass litellm_call_id to Aim guardrails on pre- and post-call hooks

* Add a test ensuring that pre_call_hook receives the litellm call id when common_processing_pre_call_logic (in common_request_processing) is called
This commit is contained in:
Michael Leshchinsky 2025-04-16 17:41:28 +03:00 committed by GitHub
parent ca593e003a
commit e19d05980c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 96 additions and 7 deletions

View file

@ -152,6 +152,9 @@ class ProxyBaseLLMRequestProcessing:
):
self.data["model"] = litellm.model_alias_map[self.data["model"]]
self.data["litellm_call_id"] = request.headers.get(
"x-litellm-call-id", str(uuid.uuid4())
)
### CALL HOOKS ### - modify/reject incoming data before calling the model
self.data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
@ -159,9 +162,6 @@ class ProxyBaseLLMRequestProcessing:
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
self.data["litellm_call_id"] = request.headers.get(
"x-litellm-call-id", str(uuid.uuid4())
)
logging_obj, self.data = litellm.utils.function_setup(
original_function=route_type,
rules_obj=litellm.utils.Rules(),
@ -384,7 +384,7 @@ class ProxyBaseLLMRequestProcessing:
@staticmethod
def _get_pre_call_type(
route_type: Literal["acompletion", "aresponses"]
route_type: Literal["acompletion", "aresponses"],
) -> Literal["completion", "responses"]:
if route_type == "acompletion":
return "completion"

View file

@ -105,8 +105,12 @@ class AimGuardrail(CustomGuardrail):
self, data: dict, hook: str, key_alias: Optional[str]
) -> None:
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
call_id = data.get("litellm_call_id")
headers = self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
hook=hook,
key_alias=key_alias,
user_email=user_email,
litellm_call_id=call_id,
)
response = await self.async_handler.post(
f"{self.api_base}/detect/openai",
@ -131,10 +135,14 @@ class AimGuardrail(CustomGuardrail):
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
call_id = request_data.get("litellm_call_id")
response = await self.async_handler.post(
f"{self.api_base}/detect/output",
headers=self._build_aim_headers(
hook=hook, key_alias=key_alias, user_email=user_email
hook=hook,
key_alias=key_alias,
user_email=user_email,
litellm_call_id=call_id,
),
json={"output": output, "messages": request_data.get("messages", [])},
)
@ -152,7 +160,12 @@ class AimGuardrail(CustomGuardrail):
return None
def _build_aim_headers(
self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
self,
*,
hook: str,
key_alias: Optional[str],
user_email: Optional[str],
litellm_call_id: Optional[str],
):
"""
A helper function to build the http headers that are required by AIM guardrails.
@ -165,6 +178,8 @@ class AimGuardrail(CustomGuardrail):
# Used by Aim to track LiteLLM version and provide backward compatibility.
"x-aim-litellm-version": litellm_version,
}
# Used by Aim to track together single call input and output
| ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
# Used by Aim to track guardrails violations by user.
| ({"x-aim-user-email": user_email} if user_email else {})
| (
@ -204,12 +219,14 @@ class AimGuardrail(CustomGuardrail):
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
call_id = request_data.get("litellm_call_id")
async with connect(
f"{self.ws_api_base}/detect/output/ws",
additional_headers=self._build_aim_headers(
hook="output",
key_alias=user_api_key_dict.key_alias,
user_email=user_email,
litellm_call_id=call_id,
),
) as websocket:
sender = asyncio.create_task(

View file

@ -0,0 +1,72 @@
import copy
import uuid
import pytest
import litellm
from unittest.mock import AsyncMock, MagicMock
from fastapi import Request
from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
ProxyConfig,
)
from litellm.proxy.utils import ProxyLogging
class TestProxyBaseLLMRequestProcessing:
    """Tests for ``ProxyBaseLLMRequestProcessing`` pre-call processing."""

    @pytest.mark.asyncio
    async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
        self, monkeypatch
    ):
        """Check that ``pre_call_hook`` is handed a ``litellm_call_id``.

        The mocked request carries no headers, and the test asserts that the
        ``data`` passed to ``pre_call_hook`` contains a ``litellm_call_id``
        that parses as a UUID and matches the id in the returned data.
        """

        async def fake_add_litellm_data_to_request(*args, **kwargs):
            # Stand-in for the real request enrichment: yields an empty payload.
            return {}

        async def fake_pre_call_hook(user_api_key_dict, data, call_type):
            # Hand back a deep copy of whatever data the hook was given.
            return copy.deepcopy(data)

        # Swap the module-level helper for the stub above.
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            fake_add_litellm_data_to_request,
        )

        # Request mock with an empty header mapping.
        request_mock = MagicMock(spec=Request)
        request_mock.headers = {}

        # Proxy logging mock whose pre_call_hook records its call arguments.
        proxy_logging_mock = MagicMock(spec=ProxyLogging)
        proxy_logging_mock.pre_call_hook = AsyncMock(side_effect=fake_pre_call_hook)

        processor = ProxyBaseLLMRequestProcessing(data={})

        # Call the actual method.
        returned_data, _logging_obj = await processor.common_processing_pre_call_logic(
            request=request_mock,
            general_settings={},
            user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
            proxy_logging_obj=proxy_logging_mock,
            proxy_config=MagicMock(spec=ProxyConfig),
            route_type="acompletion",
        )

        proxy_logging_mock.pre_call_hook.assert_called_once()
        hook_kwargs = proxy_logging_mock.pre_call_hook.call_args.kwargs
        hook_data = hook_kwargs.get("data", {})

        assert "litellm_call_id" in hook_data

        # The id must parse as a UUID.
        try:
            uuid.UUID(hook_data["litellm_call_id"])
        except ValueError:
            pytest.fail("litellm_call_id is not a valid UUID")

        # The id seen by the hook must match the one in the returned data.
        assert hook_data["litellm_call_id"] == returned_data["litellm_call_id"]