fix(aim.py): fix linting error

This commit is contained in:
Krrish Dholakia 2025-03-13 15:32:42 -07:00
parent ee6c9576d4
commit 997f2f0b3e
3 changed files with 57 additions and 18 deletions

View file

@ -7,7 +7,7 @@
import asyncio
import json
import os
from typing import Any, Literal, Optional, Union
from typing import Any, AsyncGenerator, Literal, Optional, Union
from fastapi import HTTPException
from pydantic import BaseModel
@ -36,8 +36,12 @@ class AimGuardrailMissingSecrets(Exception):
class AimGuardrail(CustomGuardrail):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs):
self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback)
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.api_key = api_key or os.environ.get("AIM_API_KEY")
if not self.api_key:
msg = (
@ -45,8 +49,12 @@ class AimGuardrail(CustomGuardrail):
"pass it as a parameter to the guardrail in the config file"
)
raise AimGuardrailMissingSecrets(msg)
self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
self.ws_api_base = self.api_base.replace("http://", "ws://").replace("https://", "wss://")
self.api_base = (
api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
)
self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
"https://", "wss://"
)
super().__init__(**kwargs)
async def async_pre_call_hook(
@ -111,11 +119,16 @@ class AimGuardrail(CustomGuardrail):
if detected:
raise HTTPException(status_code=400, detail=res["detection_message"])
async def call_aim_guardrail_on_output(self, request_data: dict, output: str, hook: str) -> None:
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | (
{"x-aim-user-email": user_email} if user_email else {}
async def call_aim_guardrail_on_output(
self, request_data: dict, output: str, hook: str
) -> None:
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
response = await self.async_handler.post(
f"{self.api_base}/detect/output",
headers=headers,
@ -140,9 +153,15 @@ class AimGuardrail(CustomGuardrail):
user_api_key_dict: UserAPIKeyAuth,
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
) -> Any:
if isinstance(response, ModelResponse) and response.choices and isinstance(response.choices[0], Choices):
if (
isinstance(response, ModelResponse)
and response.choices
and isinstance(response.choices[0], Choices)
):
content = response.choices[0].message.content or ""
detection = await self.call_aim_guardrail_on_output(data, content, hook="output")
detection = await self.call_aim_guardrail_on_output(
data, content, hook="output"
)
if detection:
raise HTTPException(status_code=400, detail=detection)
@ -151,13 +170,19 @@ class AimGuardrail(CustomGuardrail):
user_api_key_dict: UserAPIKeyAuth,
response,
request_data: dict,
) -> Any:
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
) -> AsyncGenerator[ModelResponseStream, None]:
user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = {
"Authorization": f"Bearer {self.api_key}",
} | ({"x-aim-user-email": user_email} if user_email else {})
async with connect(f"{self.ws_api_base}/detect/output/ws", additional_headers=headers) as websocket:
sender = asyncio.create_task(self.forward_the_stream_to_aim(websocket, response))
async with connect(
f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
) as websocket:
sender = asyncio.create_task(
self.forward_the_stream_to_aim(websocket, response)
)
while True:
result = json.loads(await websocket.recv())
if verified_chunk := result.get("verified_chunk"):
@ -168,7 +193,9 @@ class AimGuardrail(CustomGuardrail):
return
if blocking_message := result.get("blocking_message"):
raise StreamingCallbackError(blocking_message)
verbose_proxy_logger.error(f"Unknown message received from AIM: {result}")
verbose_proxy_logger.error(
f"Unknown message received from AIM: {result}"
)
return
async def forward_the_stream_to_aim(