diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 17c144cc0b..6f1ec88d01 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -1,7 +1,16 @@ #### What this does #### # On success, logs events to Promptlayer import traceback -from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + List, + Literal, + Optional, + Tuple, + Union, +) from pydantic import BaseModel @@ -14,6 +23,7 @@ from litellm.types.utils import ( EmbeddingResponse, ImageResponse, ModelResponse, + ModelResponseStream, StandardCallbackDynamicParams, StandardLoggingPayload, ) @@ -256,7 +266,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac user_api_key_dict: UserAPIKeyAuth, response: Any, request_data: dict, - ) -> Any: + ) -> AsyncGenerator[ModelResponseStream, None]: async for item in response: yield item diff --git a/litellm/proxy/guardrails/guardrail_hooks/aim.py b/litellm/proxy/guardrails/guardrail_hooks/aim.py index f062189b53..9b0cd5f3cc 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/aim.py +++ b/litellm/proxy/guardrails/guardrail_hooks/aim.py @@ -7,7 +7,7 @@ import asyncio import json import os -from typing import Any, Literal, Optional, Union +from typing import Any, AsyncGenerator, Literal, Optional, Union from fastapi import HTTPException from pydantic import BaseModel @@ -36,8 +36,12 @@ class AimGuardrailMissingSecrets(Exception): class AimGuardrail(CustomGuardrail): - def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs): - self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback) + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs + ): + self.async_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.GuardrailCallback + ) self.api_key = api_key or os.environ.get("AIM_API_KEY") if not self.api_key: msg = ( @@ -45,8 +49,12 @@ class AimGuardrail(CustomGuardrail): "pass it as a parameter to the guardrail in the config file" ) raise AimGuardrailMissingSecrets(msg) - self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security" - self.ws_api_base = self.api_base.replace("http://", "ws://").replace("https://", "wss://") + self.api_base = ( + api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security" + ) + self.ws_api_base = self.api_base.replace("http://", "ws://").replace( + "https://", "wss://" + ) super().__init__(**kwargs) async def async_pre_call_hook( @@ -111,11 +119,16 @@ class AimGuardrail(CustomGuardrail): if detected: raise HTTPException(status_code=400, detail=res["detection_message"]) - async def call_aim_guardrail_on_output(self, request_data: dict, output: str, hook: str) -> None: - user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") - headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | ( - {"x-aim-user-email": user_email} if user_email else {} + async def call_aim_guardrail_on_output( + self, request_data: dict, output: str, hook: str + ) -> None: + user_email = ( + request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") ) + headers = { + "Authorization": f"Bearer {self.api_key}", + "x-aim-litellm-hook": hook, + } | ({"x-aim-user-email": user_email} if user_email else {}) response = await self.async_handler.post( f"{self.api_base}/detect/output", headers=headers, @@ -140,9 +153,15 @@ class AimGuardrail(CustomGuardrail): user_api_key_dict: UserAPIKeyAuth, response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse], ) -> Any: - if isinstance(response, ModelResponse) and response.choices and isinstance(response.choices[0], Choices): + if ( + isinstance(response, ModelResponse) + and response.choices + and isinstance(response.choices[0], Choices) + ): content = response.choices[0].message.content or "" - detection = await self.call_aim_guardrail_on_output(data, content, hook="output") + detection = await self.call_aim_guardrail_on_output( + data, content, hook="output" + ) if detection: raise HTTPException(status_code=400, detail=detection) @@ -151,13 +170,19 @@ class AimGuardrail(CustomGuardrail): user_api_key_dict: UserAPIKeyAuth, response, request_data: dict, - ) -> Any: - user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") + ) -> AsyncGenerator[ModelResponseStream, None]: + user_email = ( + request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") + ) headers = { "Authorization": f"Bearer {self.api_key}", } | ({"x-aim-user-email": user_email} if user_email else {}) - async with connect(f"{self.ws_api_base}/detect/output/ws", additional_headers=headers) as websocket: - sender = asyncio.create_task(self.forward_the_stream_to_aim(websocket, response)) + async with connect( + f"{self.ws_api_base}/detect/output/ws", additional_headers=headers + ) as websocket: + sender = asyncio.create_task( + self.forward_the_stream_to_aim(websocket, response) + ) while True: result = json.loads(await websocket.recv()) if verified_chunk := result.get("verified_chunk"): @@ -168,7 +193,9 @@ class AimGuardrail(CustomGuardrail): return if blocking_message := result.get("blocking_message"): raise StreamingCallbackError(blocking_message) - verbose_proxy_logger.error(f"Unknown message received from AIM: {result}") + verbose_proxy_logger.error( + f"Unknown message received from AIM: {result}" + ) return async def forward_the_stream_to_aim( diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py index ac39a68c60..8330de0f55 100644 --- a/litellm/realtime_api/main.py +++ b/litellm/realtime_api/main.py @@ -151,6 +151,8 @@ async def _realtime_health_check( url = openai_realtime._construct_url( api_base=api_base or "https://api.openai.com/", model=model ) + else: + raise ValueError(f"Unsupported model: {model}") async with websockets.connect( # type: ignore url, extra_headers={