fix(aim.py): fix linting error

This commit is contained in:
Krrish Dholakia 2025-03-13 15:32:42 -07:00
parent ee6c9576d4
commit 997f2f0b3e
3 changed files with 57 additions and 18 deletions

View file

@ -1,7 +1,16 @@
#### What this does #### #### What this does ####
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import traceback import traceback
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
List,
Literal,
Optional,
Tuple,
Union,
)
from pydantic import BaseModel from pydantic import BaseModel
@ -14,6 +23,7 @@ from litellm.types.utils import (
EmbeddingResponse, EmbeddingResponse,
ImageResponse, ImageResponse,
ModelResponse, ModelResponse,
ModelResponseStream,
StandardCallbackDynamicParams, StandardCallbackDynamicParams,
StandardLoggingPayload, StandardLoggingPayload,
) )
@ -256,7 +266,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response: Any, response: Any,
request_data: dict, request_data: dict,
) -> Any: ) -> AsyncGenerator[ModelResponseStream, None]:
async for item in response: async for item in response:
yield item yield item

View file

@ -7,7 +7,7 @@
import asyncio import asyncio
import json import json
import os import os
from typing import Any, Literal, Optional, Union from typing import Any, AsyncGenerator, Literal, Optional, Union
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel from pydantic import BaseModel
@ -36,8 +36,12 @@ class AimGuardrailMissingSecrets(Exception):
class AimGuardrail(CustomGuardrail): class AimGuardrail(CustomGuardrail):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs): def __init__(
self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback) self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
):
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
self.api_key = api_key or os.environ.get("AIM_API_KEY") self.api_key = api_key or os.environ.get("AIM_API_KEY")
if not self.api_key: if not self.api_key:
msg = ( msg = (
@ -45,8 +49,12 @@ class AimGuardrail(CustomGuardrail):
"pass it as a parameter to the guardrail in the config file" "pass it as a parameter to the guardrail in the config file"
) )
raise AimGuardrailMissingSecrets(msg) raise AimGuardrailMissingSecrets(msg)
self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security" self.api_base = (
self.ws_api_base = self.api_base.replace("http://", "ws://").replace("https://", "wss://") api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
)
self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
"https://", "wss://"
)
super().__init__(**kwargs) super().__init__(**kwargs)
async def async_pre_call_hook( async def async_pre_call_hook(
@ -111,11 +119,16 @@ class AimGuardrail(CustomGuardrail):
if detected: if detected:
raise HTTPException(status_code=400, detail=res["detection_message"]) raise HTTPException(status_code=400, detail=res["detection_message"])
async def call_aim_guardrail_on_output(self, request_data: dict, output: str, hook: str) -> None: async def call_aim_guardrail_on_output(
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") self, request_data: dict, output: str, hook: str
headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | ( ) -> None:
{"x-aim-user-email": user_email} if user_email else {} user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
) )
headers = {
"Authorization": f"Bearer {self.api_key}",
"x-aim-litellm-hook": hook,
} | ({"x-aim-user-email": user_email} if user_email else {})
response = await self.async_handler.post( response = await self.async_handler.post(
f"{self.api_base}/detect/output", f"{self.api_base}/detect/output",
headers=headers, headers=headers,
@ -140,9 +153,15 @@ class AimGuardrail(CustomGuardrail):
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse], response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
) -> Any: ) -> Any:
if isinstance(response, ModelResponse) and response.choices and isinstance(response.choices[0], Choices): if (
isinstance(response, ModelResponse)
and response.choices
and isinstance(response.choices[0], Choices)
):
content = response.choices[0].message.content or "" content = response.choices[0].message.content or ""
detection = await self.call_aim_guardrail_on_output(data, content, hook="output") detection = await self.call_aim_guardrail_on_output(
data, content, hook="output"
)
if detection: if detection:
raise HTTPException(status_code=400, detail=detection) raise HTTPException(status_code=400, detail=detection)
@ -151,13 +170,19 @@ class AimGuardrail(CustomGuardrail):
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response,
request_data: dict, request_data: dict,
) -> Any: ) -> AsyncGenerator[ModelResponseStream, None]:
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") user_email = (
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
)
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
} | ({"x-aim-user-email": user_email} if user_email else {}) } | ({"x-aim-user-email": user_email} if user_email else {})
async with connect(f"{self.ws_api_base}/detect/output/ws", additional_headers=headers) as websocket: async with connect(
sender = asyncio.create_task(self.forward_the_stream_to_aim(websocket, response)) f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
) as websocket:
sender = asyncio.create_task(
self.forward_the_stream_to_aim(websocket, response)
)
while True: while True:
result = json.loads(await websocket.recv()) result = json.loads(await websocket.recv())
if verified_chunk := result.get("verified_chunk"): if verified_chunk := result.get("verified_chunk"):
@ -168,7 +193,9 @@ class AimGuardrail(CustomGuardrail):
return return
if blocking_message := result.get("blocking_message"): if blocking_message := result.get("blocking_message"):
raise StreamingCallbackError(blocking_message) raise StreamingCallbackError(blocking_message)
verbose_proxy_logger.error(f"Unknown message received from AIM: {result}") verbose_proxy_logger.error(
f"Unknown message received from AIM: {result}"
)
return return
async def forward_the_stream_to_aim( async def forward_the_stream_to_aim(

View file

@ -151,6 +151,8 @@ async def _realtime_health_check(
url = openai_realtime._construct_url( url = openai_realtime._construct_url(
api_base=api_base or "https://api.openai.com/", model=model api_base=api_base or "https://api.openai.com/", model=model
) )
else:
raise ValueError(f"Unsupported model: {model}")
async with websockets.connect( # type: ignore async with websockets.connect( # type: ignore
url, url,
extra_headers={ extra_headers={