mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Add litellm call id passing to Aim guardrails on pre and post-hooks calls (#10021)
* Pass litellm_call_id to Aim guardrails on pre- and post-call hooks * Add a test ensuring that pre_call_hook receives the litellm call id when common_request_processing is called
This commit is contained in:
parent
ca593e003a
commit
e19d05980c
3 changed files with 96 additions and 7 deletions
|
@ -152,6 +152,9 @@ class ProxyBaseLLMRequestProcessing:
|
||||||
):
|
):
|
||||||
self.data["model"] = litellm.model_alias_map[self.data["model"]]
|
self.data["model"] = litellm.model_alias_map[self.data["model"]]
|
||||||
|
|
||||||
|
self.data["litellm_call_id"] = request.headers.get(
|
||||||
|
"x-litellm-call-id", str(uuid.uuid4())
|
||||||
|
)
|
||||||
### CALL HOOKS ### - modify/reject incoming data before calling the model
|
### CALL HOOKS ### - modify/reject incoming data before calling the model
|
||||||
self.data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
self.data = await proxy_logging_obj.pre_call_hook( # type: ignore
|
||||||
user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
|
user_api_key_dict=user_api_key_dict, data=self.data, call_type="completion"
|
||||||
|
@ -159,9 +162,6 @@ class ProxyBaseLLMRequestProcessing:
|
||||||
|
|
||||||
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
|
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
|
||||||
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
|
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
|
||||||
self.data["litellm_call_id"] = request.headers.get(
|
|
||||||
"x-litellm-call-id", str(uuid.uuid4())
|
|
||||||
)
|
|
||||||
logging_obj, self.data = litellm.utils.function_setup(
|
logging_obj, self.data = litellm.utils.function_setup(
|
||||||
original_function=route_type,
|
original_function=route_type,
|
||||||
rules_obj=litellm.utils.Rules(),
|
rules_obj=litellm.utils.Rules(),
|
||||||
|
@ -384,7 +384,7 @@ class ProxyBaseLLMRequestProcessing:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_pre_call_type(
|
def _get_pre_call_type(
|
||||||
route_type: Literal["acompletion", "aresponses"]
|
route_type: Literal["acompletion", "aresponses"],
|
||||||
) -> Literal["completion", "responses"]:
|
) -> Literal["completion", "responses"]:
|
||||||
if route_type == "acompletion":
|
if route_type == "acompletion":
|
||||||
return "completion"
|
return "completion"
|
||||||
|
|
|
@ -105,8 +105,12 @@ class AimGuardrail(CustomGuardrail):
|
||||||
self, data: dict, hook: str, key_alias: Optional[str]
|
self, data: dict, hook: str, key_alias: Optional[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
|
call_id = data.get("litellm_call_id")
|
||||||
headers = self._build_aim_headers(
|
headers = self._build_aim_headers(
|
||||||
hook=hook, key_alias=key_alias, user_email=user_email
|
hook=hook,
|
||||||
|
key_alias=key_alias,
|
||||||
|
user_email=user_email,
|
||||||
|
litellm_call_id=call_id,
|
||||||
)
|
)
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
f"{self.api_base}/detect/openai",
|
f"{self.api_base}/detect/openai",
|
||||||
|
@ -131,10 +135,14 @@ class AimGuardrail(CustomGuardrail):
|
||||||
user_email = (
|
user_email = (
|
||||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
)
|
)
|
||||||
|
call_id = request_data.get("litellm_call_id")
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
f"{self.api_base}/detect/output",
|
f"{self.api_base}/detect/output",
|
||||||
headers=self._build_aim_headers(
|
headers=self._build_aim_headers(
|
||||||
hook=hook, key_alias=key_alias, user_email=user_email
|
hook=hook,
|
||||||
|
key_alias=key_alias,
|
||||||
|
user_email=user_email,
|
||||||
|
litellm_call_id=call_id,
|
||||||
),
|
),
|
||||||
json={"output": output, "messages": request_data.get("messages", [])},
|
json={"output": output, "messages": request_data.get("messages", [])},
|
||||||
)
|
)
|
||||||
|
@ -152,7 +160,12 @@ class AimGuardrail(CustomGuardrail):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _build_aim_headers(
|
def _build_aim_headers(
|
||||||
self, *, hook: str, key_alias: Optional[str], user_email: Optional[str]
|
self,
|
||||||
|
*,
|
||||||
|
hook: str,
|
||||||
|
key_alias: Optional[str],
|
||||||
|
user_email: Optional[str],
|
||||||
|
litellm_call_id: Optional[str],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
A helper function to build the http headers that are required by AIM guardrails.
|
A helper function to build the http headers that are required by AIM guardrails.
|
||||||
|
@ -165,6 +178,8 @@ class AimGuardrail(CustomGuardrail):
|
||||||
# Used by Aim to track LiteLLM version and provide backward compatibility.
|
# Used by Aim to track LiteLLM version and provide backward compatibility.
|
||||||
"x-aim-litellm-version": litellm_version,
|
"x-aim-litellm-version": litellm_version,
|
||||||
}
|
}
|
||||||
|
# Used by Aim to correlate the input and output of a single call.
|
||||||
|
| ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
|
||||||
# Used by Aim to track guardrails violations by user.
|
# Used by Aim to track guardrails violations by user.
|
||||||
| ({"x-aim-user-email": user_email} if user_email else {})
|
| ({"x-aim-user-email": user_email} if user_email else {})
|
||||||
| (
|
| (
|
||||||
|
@ -204,12 +219,14 @@ class AimGuardrail(CustomGuardrail):
|
||||||
user_email = (
|
user_email = (
|
||||||
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
)
|
)
|
||||||
|
call_id = request_data.get("litellm_call_id")
|
||||||
async with connect(
|
async with connect(
|
||||||
f"{self.ws_api_base}/detect/output/ws",
|
f"{self.ws_api_base}/detect/output/ws",
|
||||||
additional_headers=self._build_aim_headers(
|
additional_headers=self._build_aim_headers(
|
||||||
hook="output",
|
hook="output",
|
||||||
key_alias=user_api_key_dict.key_alias,
|
key_alias=user_api_key_dict.key_alias,
|
||||||
user_email=user_email,
|
user_email=user_email,
|
||||||
|
litellm_call_id=call_id,
|
||||||
),
|
),
|
||||||
) as websocket:
|
) as websocket:
|
||||||
sender = asyncio.create_task(
|
sender = asyncio.create_task(
|
||||||
|
|
72
tests/litellm/proxy/test_common_request_processing.py
Normal file
72
tests/litellm/proxy/test_common_request_processing.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
import copy
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from litellm.integrations.opentelemetry import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.common_request_processing import (
|
||||||
|
ProxyBaseLLMRequestProcessing,
|
||||||
|
ProxyConfig,
|
||||||
|
)
|
||||||
|
from litellm.proxy.utils import ProxyLogging
|
||||||
|
|
||||||
|
|
||||||
|
class TestProxyBaseLLMRequestProcessing:
    """Tests for ``ProxyBaseLLMRequestProcessing.common_processing_pre_call_logic``."""

    @pytest.mark.asyncio
    async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
        self, monkeypatch
    ):
        """Check that ``pre_call_hook`` is given data containing ``litellm_call_id``.

        The request carries no headers, the data-enrichment helper is stubbed
        to return an empty dict, and the hook is replaced by an ``AsyncMock``
        so the keyword arguments it was awaited with can be inspected.
        """

        async def fake_add_litellm_data_to_request(*args, **kwargs):
            # Stub for add_litellm_data_to_request: ignores its inputs.
            return {}

        async def fake_pre_call_hook(user_api_key_dict, data, call_type):
            # Hand back an independent deep copy of the data that was passed in.
            return copy.deepcopy(data)

        processor = ProxyBaseLLMRequestProcessing(data={})

        # Request object with an empty headers mapping.
        request_stub = MagicMock(spec=Request)
        request_stub.headers = {}

        proxy_logging_stub = MagicMock(spec=ProxyLogging)
        proxy_logging_stub.pre_call_hook = AsyncMock(side_effect=fake_pre_call_hook)

        monkeypatch.setattr(
            litellm.proxy.common_request_processing,
            "add_litellm_data_to_request",
            fake_add_litellm_data_to_request,
        )

        # Call the actual method.
        returned_data, _logging_obj = (
            await processor.common_processing_pre_call_logic(
                request=request_stub,
                general_settings={},
                user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
                proxy_logging_obj=proxy_logging_stub,
                proxy_config=MagicMock(spec=ProxyConfig),
                route_type="acompletion",
            )
        )

        hook = proxy_logging_stub.pre_call_hook
        hook.assert_called_once()

        # Inspect the ``data`` keyword argument the hook was awaited with.
        data_passed = hook.call_args.kwargs.get("data", {})
        assert "litellm_call_id" in data_passed

        call_id = data_passed["litellm_call_id"]
        try:
            uuid.UUID(call_id)
        except ValueError:
            pytest.fail("litellm_call_id is not a valid UUID")

        # The id seen by the hook must match the one in the returned data.
        assert call_id == returned_data["litellm_call_id"]
|
Loading…
Add table
Add a link
Reference in a new issue