mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Support post-call guards for stream and non-stream responses
This commit is contained in:
parent
44184c4113
commit
b01cf5577c
8 changed files with 297 additions and 33 deletions
|
@ -4,11 +4,14 @@
|
|||
# https://www.aim.security/
|
||||
#
|
||||
# +-------------------------------------------------------------+
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from websockets.asyncio.client import ClientConnection, connect
|
||||
|
||||
from litellm import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@ -18,7 +21,14 @@ from litellm.llms.custom_httpx.http_handler import (
|
|||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
from litellm.proxy.proxy_server import StreamingCallbackError
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
)
|
||||
|
||||
class AimGuardrailMissingSecrets(Exception):
|
||||
pass
|
||||
|
@ -38,9 +48,8 @@ class AimGuardrail(CustomGuardrail):
|
|||
"pass it as a parameter to the guardrail in the config file"
|
||||
)
|
||||
raise AimGuardrailMissingSecrets(msg)
|
||||
self.api_base = (
|
||||
api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
||||
)
|
||||
self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
||||
self.ws_api_base = self.api_base.replace("http://", "ws://").replace("https://", "wss://")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def async_pre_call_hook(
|
||||
|
@ -98,8 +107,82 @@ class AimGuardrail(CustomGuardrail):
|
|||
detected = res["detected"]
|
||||
verbose_proxy_logger.info(
|
||||
"Aim: detected: {detected}, enabled policies: {policies}".format(
|
||||
detected=detected, policies=list(res["details"].keys())
|
||||
)
|
||||
detected=detected,
|
||||
policies=list(res["details"].keys()),
|
||||
),
|
||||
)
|
||||
if detected:
|
||||
raise HTTPException(status_code=400, detail=res["detection_message"])
|
||||
|
||||
async def call_aim_guardrail_on_output(self, request_data: dict, output: str, hook: str) -> None:
|
||||
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | (
|
||||
{"x-aim-user-email": user_email} if user_email else {}
|
||||
)
|
||||
response = await self.async_handler.post(
|
||||
f"{self.api_base}/detect",
|
||||
headers=headers,
|
||||
json={"user_prompt": output},
|
||||
)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
detected = res["detected"]
|
||||
verbose_proxy_logger.info(
|
||||
"Aim: detected: {detected}, enabled policies: {policies}".format(
|
||||
detected=detected,
|
||||
policies=list(res["details"].keys()),
|
||||
),
|
||||
)
|
||||
if detected:
|
||||
return res["detection_message"]
|
||||
return None
|
||||
|
||||
async def async_post_call_success_hook(
|
||||
self,
|
||||
data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
||||
) -> Any:
|
||||
if isinstance(response, ModelResponse) and response.choices and isinstance(response.choices[0], Choices):
|
||||
content = response.choices[0].message.content or ""
|
||||
detection = await self.call_aim_guardrail_on_output(data, content, hook="output")
|
||||
if detection:
|
||||
raise HTTPException(status_code=400, detail=detection)
|
||||
|
||||
async def async_post_call_streaming_iterator_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
response,
|
||||
request_data: dict,
|
||||
) -> Any:
|
||||
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
} | ({"x-aim-user-email": user_email} if user_email else {})
|
||||
async with connect(f"{self.ws_api_base}/detect/output/ws", additional_headers=headers) as websocket:
|
||||
sender = asyncio.create_task(self.forward_the_stream_to_aim(websocket, response))
|
||||
while True:
|
||||
result = json.loads(await websocket.recv())
|
||||
if verified_chunk := result.get("verified_chunk"):
|
||||
yield ModelResponseStream.model_validate(verified_chunk)
|
||||
else:
|
||||
sender.cancel()
|
||||
if result.get("done"):
|
||||
return
|
||||
if blocking_message := result.get("blocking_message"):
|
||||
raise StreamingCallbackError(blocking_message)
|
||||
verbose_proxy_logger.error(f"Unknown message received from AIM: {result}")
|
||||
return
|
||||
|
||||
async def forward_the_stream_to_aim(
|
||||
self,
|
||||
websocket: ClientConnection,
|
||||
response_iter,
|
||||
) -> None:
|
||||
async for chunk in response_iter:
|
||||
if isinstance(chunk, BaseModel):
|
||||
chunk = chunk.model_dump_json()
|
||||
if isinstance(chunk, dict):
|
||||
chunk = json.dumps(chunk)
|
||||
await websocket.send(chunk)
|
||||
await websocket.send(json.dumps({"done": True}))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue