mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
Add litellm call id passing to Aim guardrails on pre and post-hooks calls (#10021)
* Add litellm_call_id passing to Aim guardrails on pre and post hooks
* Add a test that ensures pre_call_hook receives the litellm call id when common_request_processing is called
parent ca593e003a
commit e19d05980c
3 changed files with 96 additions and 7 deletions
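
In short, the proxy now resolves a litellm_call_id before the pre-call guardrail hooks run, reusing the x-litellm-call-id request header when present and generating a UUID otherwise, so Aim can tie together a call's input and output checks. A minimal standalone sketch of that resolution step (a plain dict stands in for the real request headers):

import uuid

def resolve_call_id(headers: dict) -> str:
    # Reuse the caller-supplied id when the header is present;
    # otherwise mint a fresh UUID so downstream hooks always see one.
    return headers.get("x-litellm-call-id", str(uuid.uuid4()))

data = {"litellm_call_id": resolve_call_id({"x-litellm-call-id": "abc-123"})}
assert data["litellm_call_id"] == "abc-123"

generated = resolve_call_id({})
uuid.UUID(generated)  # generated ids parse as valid UUIDs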
litellm/proxy/common_request_processing.py

@@ -152,6 +152,9 @@ class ProxyBaseLLMRequestProcessing:
         ):
             self.data["model"] = litellm.model_alias_map[self.data["model"]]
 
+        self.data["litellm_call_id"] = request.headers.get(
+            "x-litellm-call-id", str(uuid.uuid4())
+        )
         ### CALL HOOKS ### - modify/reject incoming data before calling the model
         self.data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"

@@ -159,9 +162,6 @@ class ProxyBaseLLMRequestProcessing:
 
         ## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
         ## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
-        self.data["litellm_call_id"] = request.headers.get(
-            "x-litellm-call-id", str(uuid.uuid4())
-        )
         logging_obj, self.data = litellm.utils.function_setup(
             original_function=route_type,
             rules_obj=litellm.utils.Rules(),

@@ -384,7 +384,7 @@ class ProxyBaseLLMRequestProcessing:
 
     @staticmethod
     def _get_pre_call_type(
-        route_type: Literal["acompletion", "aresponses"]
+        route_type: Literal["acompletion", "aresponses"],
     ) -> Literal["completion", "responses"]:
         if route_type == "acompletion":
             return "completion"
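
Because the id is read from the incoming request headers, a caller can supply its own correlation id per request. A hedged usage sketch, assuming a locally running proxy; the URL, route, model name, and API key below are placeholders rather than part of this change:

import uuid

import httpx

call_id = str(uuid.uuid4())
response = httpx.post(
    "http://localhost:4000/v1/chat/completions",  # assumed local proxy address
    headers={
        "Authorization": "Bearer sk-1234",  # placeholder key
        "x-litellm-call-id": call_id,  # ends up in data["litellm_call_id"]
    },
    json={"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
)
print(response.status_code)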
Aim guardrail (class AimGuardrail):

@@ -105,8 +105,12 @@ class AimGuardrail(CustomGuardrail):
         self, data: dict, hook: str, key_alias: Optional[str]
     ) -> None:
         user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
+        call_id = data.get("litellm_call_id")
         headers = self._build_aim_headers(
-            hook=hook, key_alias=key_alias, user_email=user_email
+            hook=hook,
+            key_alias=key_alias,
+            user_email=user_email,
+            litellm_call_id=call_id,
         )
         response = await self.async_handler.post(
             f"{self.api_base}/detect/openai",

@@ -131,10 +135,14 @@ class AimGuardrail(CustomGuardrail):
         user_email = (
             request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
         )
+        call_id = request_data.get("litellm_call_id")
         response = await self.async_handler.post(
             f"{self.api_base}/detect/output",
             headers=self._build_aim_headers(
-                hook=hook, key_alias=key_alias, user_email=user_email
+                hook=hook,
+                key_alias=key_alias,
+                user_email=user_email,
+                litellm_call_id=call_id,
             ),
             json={"output": output, "messages": request_data.get("messages", [])},
         )

@@ -152,7 +160,12 @@ class AimGuardrail(CustomGuardrail):
         return None
 
     def _build_aim_headers(
-        self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
+        self,
+        *,
+        hook: str,
+        key_alias: Optional[str],
+        user_email: Optional[str],
+        litellm_call_id: Optional[str],
     ):
         """
         A helper function to build the http headers that are required by AIM guardrails.

@@ -165,6 +178,8 @@ class AimGuardrail(CustomGuardrail):
                 # Used by Aim to track LiteLLM version and provide backward compatibility.
                 "x-aim-litellm-version": litellm_version,
             }
+            # Used by Aim to track together single call input and output
+            | ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
             # Used by Aim to track guardrails violations by user.
             | ({"x-aim-user-email": user_email} if user_email else {})
             | (

@@ -204,12 +219,14 @@ class AimGuardrail(CustomGuardrail):
         user_email = (
             request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
         )
+        call_id = request_data.get("litellm_call_id")
         async with connect(
             f"{self.ws_api_base}/detect/output/ws",
             additional_headers=self._build_aim_headers(
                 hook="output",
                 key_alias=user_api_key_dict.key_alias,
                 user_email=user_email,
+                litellm_call_id=call_id,
             ),
         ) as websocket:
             sender = asyncio.create_task(
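
The header construction above relies on Python 3.9+ dict union (|) to merge optional entries only when their values are set. A self-contained sketch of the same pattern with illustrative names (build_headers is not a litellm function):

from typing import Optional


def build_headers(
    api_key: str,
    litellm_call_id: Optional[str],
    user_email: Optional[str],
) -> dict:
    # Always send the base headers; merge optional ones in only when present,
    # mirroring the conditional dict-union used by _build_aim_headers.
    return (
        {"Authorization": f"Bearer {api_key}"}
        | ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
        | ({"x-aim-user-email": user_email} if user_email else {})
    )


assert "x-aim-user-email" not in build_headers("key", "call-1", None)
assert build_headers("key", "call-1", "a@b.com")["x-aim-litellm-call-id"] == "call-1"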
tests/litellm/proxy/test_common_request_processing.py (new file, 72 lines)

@@ -0,0 +1,72 @@
import copy
import uuid
import pytest
import litellm
from unittest.mock import AsyncMock, MagicMock
from fastapi import Request

from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
    ProxyBaseLLMRequestProcessing,
    ProxyConfig,
)
from litellm.proxy.utils import ProxyLogging


class TestProxyBaseLLMRequestProcessing:

    @pytest.mark.asyncio
    async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
        self, monkeypatch
    ):

        processing_obj = ProxyBaseLLMRequestProcessing(data={})
        mock_request = MagicMock(spec=Request)
        mock_request.headers = {}

        async def mock_add_litellm_data_to_request(*args, **kwargs):
            return {}

        async def mock_common_processing_pre_call_logic(
            user_api_key_dict, data, call_type
        ):
            data_copy = copy.deepcopy(data)
            return data_copy

        mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
        mock_proxy_logging_obj.pre_call_hook = AsyncMock(
            side_effect=mock_common_processing_pre_call_logic
        )
        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            mock_add_litellm_data_to_request,
        )
        mock_general_settings = {}
        mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        mock_proxy_config = MagicMock(spec=ProxyConfig)
        route_type = "acompletion"

        # Call the actual method.
        returned_data, logging_obj = (
            await processing_obj.common_processing_pre_call_logic(
                request=mock_request,
                general_settings=mock_general_settings,
                user_api_key_dict=mock_user_api_key_dict,
                proxy_logging_obj=mock_proxy_logging_obj,
                proxy_config=mock_proxy_config,
                route_type=route_type,
            )
        )

        mock_proxy_logging_obj.pre_call_hook.assert_called_once()

        _, call_kwargs = mock_proxy_logging_obj.pre_call_hook.call_args
        data_passed = call_kwargs.get("data", {})

        assert "litellm_call_id" in data_passed
        try:
            uuid.UUID(data_passed["litellm_call_id"])
        except ValueError:
            pytest.fail("litellm_call_id is not a valid UUID")
        assert data_passed["litellm_call_id"] == returned_data["litellm_call_id"]
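
Since litellm_call_id now sits in the request data before pre_call_hook fires, any hook that receives that dict can use it for correlation, not just Aim. A framework-agnostic sketch; the function name and logger below are illustrative, not a litellm API:

import asyncio
import logging
from typing import Optional

logger = logging.getLogger("guardrail")


async def my_pre_call_hook(data: dict, call_type: str) -> dict:
    # Hypothetical hook body: tie guardrail decisions back to the proxy call id.
    call_id: Optional[str] = data.get("litellm_call_id")
    logger.info("pre-call check for call_id=%s call_type=%s", call_id, call_type)
    return data


asyncio.run(my_pre_call_hook({"litellm_call_id": "abc-123"}, "completion"))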