Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
fix(aim.py): fix linting error
This commit is contained in:
parent ee6c9576d4
commit 997f2f0b3e
3 changed files with 57 additions and 18 deletions
|
@@ -1,7 +1,16 @@
|
||||||
#### What this does ####
|
#### What this does ####
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
@@ -14,6 +23,7 @@ from litellm.types.utils import (
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ImageResponse,
|
ImageResponse,
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
|
ModelResponseStream,
|
||||||
StandardCallbackDynamicParams,
|
StandardCallbackDynamicParams,
|
||||||
StandardLoggingPayload,
|
StandardLoggingPayload,
|
||||||
)
|
)
|
||||||
|
@@ -256,7 +266,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
response: Any,
|
response: Any,
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
) -> Any:
|
) -> AsyncGenerator[ModelResponseStream, None]:
|
||||||
async for item in response:
|
async for item in response:
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
|
@@ -7,7 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, AsyncGenerator, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@@ -36,8 +36,12 @@ class AimGuardrailMissingSecrets(Exception):
|
||||||
|
|
||||||
|
|
||||||
class AimGuardrail(CustomGuardrail):
|
class AimGuardrail(CustomGuardrail):
|
||||||
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs):
|
def __init__(
|
||||||
self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback)
|
self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
|
||||||
|
):
|
||||||
|
self.async_handler = get_async_httpx_client(
|
||||||
|
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||||
|
)
|
||||||
self.api_key = api_key or os.environ.get("AIM_API_KEY")
|
self.api_key = api_key or os.environ.get("AIM_API_KEY")
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
msg = (
|
msg = (
|
||||||
|
@@ -45,8 +49,12 @@ class AimGuardrail(CustomGuardrail):
|
||||||
"pass it as a parameter to the guardrail in the config file"
|
"pass it as a parameter to the guardrail in the config file"
|
||||||
)
|
)
|
||||||
raise AimGuardrailMissingSecrets(msg)
|
raise AimGuardrailMissingSecrets(msg)
|
||||||
self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
self.api_base = (
|
||||||
self.ws_api_base = self.api_base.replace("http://", "ws://").replace("https://", "wss://")
|
api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
|
||||||
|
)
|
||||||
|
self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
|
||||||
|
"https://", "wss://"
|
||||||
|
)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
async def async_pre_call_hook(
|
async def async_pre_call_hook(
|
||||||
|
@@ -111,11 +119,16 @@ class AimGuardrail(CustomGuardrail):
|
||||||
if detected:
|
if detected:
|
||||||
raise HTTPException(status_code=400, detail=res["detection_message"])
|
raise HTTPException(status_code=400, detail=res["detection_message"])
|
||||||
|
|
||||||
async def call_aim_guardrail_on_output(self, request_data: dict, output: str, hook: str) -> None:
|
async def call_aim_guardrail_on_output(
|
||||||
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
self, request_data: dict, output: str, hook: str
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | (
|
) -> None:
|
||||||
{"x-aim-user-email": user_email} if user_email else {}
|
user_email = (
|
||||||
|
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
)
|
)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"x-aim-litellm-hook": hook,
|
||||||
|
} | ({"x-aim-user-email": user_email} if user_email else {})
|
||||||
response = await self.async_handler.post(
|
response = await self.async_handler.post(
|
||||||
f"{self.api_base}/detect/output",
|
f"{self.api_base}/detect/output",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
@@ -140,9 +153,15 @@ class AimGuardrail(CustomGuardrail):
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if isinstance(response, ModelResponse) and response.choices and isinstance(response.choices[0], Choices):
|
if (
|
||||||
|
isinstance(response, ModelResponse)
|
||||||
|
and response.choices
|
||||||
|
and isinstance(response.choices[0], Choices)
|
||||||
|
):
|
||||||
content = response.choices[0].message.content or ""
|
content = response.choices[0].message.content or ""
|
||||||
detection = await self.call_aim_guardrail_on_output(data, content, hook="output")
|
detection = await self.call_aim_guardrail_on_output(
|
||||||
|
data, content, hook="output"
|
||||||
|
)
|
||||||
if detection:
|
if detection:
|
||||||
raise HTTPException(status_code=400, detail=detection)
|
raise HTTPException(status_code=400, detail=detection)
|
||||||
|
|
||||||
|
@@ -151,13 +170,19 @@ class AimGuardrail(CustomGuardrail):
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
response,
|
response,
|
||||||
request_data: dict,
|
request_data: dict,
|
||||||
) -> Any:
|
) -> AsyncGenerator[ModelResponseStream, None]:
|
||||||
user_email = request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
user_email = (
|
||||||
|
request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
|
||||||
|
)
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
} | ({"x-aim-user-email": user_email} if user_email else {})
|
} | ({"x-aim-user-email": user_email} if user_email else {})
|
||||||
async with connect(f"{self.ws_api_base}/detect/output/ws", additional_headers=headers) as websocket:
|
async with connect(
|
||||||
sender = asyncio.create_task(self.forward_the_stream_to_aim(websocket, response))
|
f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
|
||||||
|
) as websocket:
|
||||||
|
sender = asyncio.create_task(
|
||||||
|
self.forward_the_stream_to_aim(websocket, response)
|
||||||
|
)
|
||||||
while True:
|
while True:
|
||||||
result = json.loads(await websocket.recv())
|
result = json.loads(await websocket.recv())
|
||||||
if verified_chunk := result.get("verified_chunk"):
|
if verified_chunk := result.get("verified_chunk"):
|
||||||
|
@@ -168,7 +193,9 @@ class AimGuardrail(CustomGuardrail):
|
||||||
return
|
return
|
||||||
if blocking_message := result.get("blocking_message"):
|
if blocking_message := result.get("blocking_message"):
|
||||||
raise StreamingCallbackError(blocking_message)
|
raise StreamingCallbackError(blocking_message)
|
||||||
verbose_proxy_logger.error(f"Unknown message received from AIM: {result}")
|
verbose_proxy_logger.error(
|
||||||
|
f"Unknown message received from AIM: {result}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
async def forward_the_stream_to_aim(
|
async def forward_the_stream_to_aim(
|
||||||
|
|
|
@@ -151,6 +151,8 @@ async def _realtime_health_check(
|
||||||
url = openai_realtime._construct_url(
|
url = openai_realtime._construct_url(
|
||||||
api_base=api_base or "https://api.openai.com/", model=model
|
api_base=api_base or "https://api.openai.com/", model=model
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported model: {model}")
|
||||||
async with websockets.connect( # type: ignore
|
async with websockets.connect( # type: ignore
|
||||||
url,
|
url,
|
||||||
extra_headers={
|
extra_headers={
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue