diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index eeb4e18e97..60050fbeb2 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -152,6 +152,9 @@ class ProxyBaseLLMRequestProcessing: ): self.data["model"] = litellm.model_alias_map[self.data["model"]] + self.data["litellm_call_id"] = request.headers.get( + "x-litellm-call-id", str(uuid.uuid4()) + ) ### CALL HOOKS ### - modify/reject incoming data before calling the model self.data = await proxy_logging_obj.pre_call_hook( # type: ignore user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion" @@ -159,9 +162,6 @@ class ProxyBaseLLMRequestProcessing: ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call ## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse. - self.data["litellm_call_id"] = request.headers.get( - "x-litellm-call-id", str(uuid.uuid4()) - ) logging_obj, self.data = litellm.utils.function_setup( original_function=route_type, rules_obj=litellm.utils.Rules(), @@ -384,7 +384,7 @@ class ProxyBaseLLMRequestProcessing: @staticmethod def _get_pre_call_type( - route_type: Literal["acompletion", "aresponses"] + route_type: Literal["acompletion", "aresponses"], ) -> Literal["completion", "responses"]: if route_type == "acompletion": return "completion" diff --git a/litellm/proxy/guardrails/guardrail_hooks/aim.py b/litellm/proxy/guardrails/guardrail_hooks/aim.py index 4f8e36ae78..86a9cf778a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/aim.py +++ b/litellm/proxy/guardrails/guardrail_hooks/aim.py @@ -105,8 +105,12 @@ class AimGuardrail(CustomGuardrail): self, data: dict, hook: str, key_alias: Optional[str] ) -> None: user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") + call_id = data.get("litellm_call_id") headers = self._build_aim_headers( - hook=hook, key_alias=key_alias, user_email=user_email + hook=hook, + key_alias=key_alias, + user_email=user_email, + litellm_call_id=call_id, ) response = await self.async_handler.post( f"{self.api_base}/detect/openai", @@ -131,10 +135,14 @@ class AimGuardrail(CustomGuardrail): user_email = ( request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") ) + call_id = request_data.get("litellm_call_id") response = await self.async_handler.post( f"{self.api_base}/detect/output", headers=self._build_aim_headers( - hook=hook, key_alias=key_alias, user_email=user_email + hook=hook, + key_alias=key_alias, + user_email=user_email, + litellm_call_id=call_id, ), json={"output": output, "messages": request_data.get("messages", [])}, ) @@ -152,7 +160,12 @@ class AimGuardrail(CustomGuardrail): return None def _build_aim_headers( - self, *, hook: str, key_alias: Optional[str], user_email: Optional[str] + self, + *, + hook: str, + key_alias: Optional[str], + user_email: Optional[str], + litellm_call_id: Optional[str], ): """ A helper function to build the http headers that are required by AIM guardrails. @@ -165,6 +178,8 @@ class AimGuardrail(CustomGuardrail): # Used by Aim to track LiteLLM version and provide backward compatibility. "x-aim-litellm-version": litellm_version, } + # Used by Aim to track together single call input and output + | ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {}) # Used by Aim to track guardrails violations by user. | ({"x-aim-user-email": user_email} if user_email else {}) | ( @@ -204,12 +219,14 @@ class AimGuardrail(CustomGuardrail): user_email = ( request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") ) + call_id = request_data.get("litellm_call_id") async with connect( f"{self.ws_api_base}/detect/output/ws", additional_headers=self._build_aim_headers( hook="output", key_alias=user_api_key_dict.key_alias, user_email=user_email, + litellm_call_id=call_id, ), ) as websocket: sender = asyncio.create_task( diff --git a/tests/litellm/proxy/test_common_request_processing.py b/tests/litellm/proxy/test_common_request_processing.py new file mode 100644 index 0000000000..8e795f8b3b --- /dev/null +++ b/tests/litellm/proxy/test_common_request_processing.py @@ -0,0 +1,72 @@ +import copy +import uuid +import pytest +import litellm +from unittest.mock import AsyncMock, MagicMock +from fastapi import Request + +from litellm.integrations.opentelemetry import UserAPIKeyAuth +from litellm.proxy.common_request_processing import ( + ProxyBaseLLMRequestProcessing, + ProxyConfig, +) +from litellm.proxy.utils import ProxyLogging + + +class TestProxyBaseLLMRequestProcessing: + + @pytest.mark.asyncio + async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id( + self, monkeypatch + ): + + processing_obj = ProxyBaseLLMRequestProcessing(data={}) + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + + async def mock_add_litellm_data_to_request(*args, **kwargs): + return {} + + async def mock_common_processing_pre_call_logic( + user_api_key_dict, data, call_type + ): + data_copy = copy.deepcopy(data) + return data_copy + + mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) + mock_proxy_logging_obj.pre_call_hook = AsyncMock( + side_effect=mock_common_processing_pre_call_logic + ) + monkeypatch.setattr( + litellm.proxy.common_request_processing, + "add_litellm_data_to_request", + mock_add_litellm_data_to_request, + ) + mock_general_settings = {} + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_proxy_config = MagicMock(spec=ProxyConfig) + route_type = "acompletion" + + # Call the actual method. + returned_data, logging_obj = ( + await processing_obj.common_processing_pre_call_logic( + request=mock_request, + general_settings=mock_general_settings, + user_api_key_dict=mock_user_api_key_dict, + proxy_logging_obj=mock_proxy_logging_obj, + proxy_config=mock_proxy_config, + route_type=route_type, + ) + ) + + mock_proxy_logging_obj.pre_call_hook.assert_called_once() + + _, call_kwargs = mock_proxy_logging_obj.pre_call_hook.call_args + data_passed = call_kwargs.get("data", {}) + + assert "litellm_call_id" in data_passed + try: + uuid.UUID(data_passed["litellm_call_id"]) + except ValueError: + pytest.fail("litellm_call_id is not a valid UUID") + assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]