diff --git a/docs/my-website/docs/proxy/guardrails/aim_security.md b/docs/my-website/docs/proxy/guardrails/aim_security.md index d588afa424..3de933c0b7 100644 --- a/docs/my-website/docs/proxy/guardrails/aim_security.md +++ b/docs/my-website/docs/proxy/guardrails/aim_security.md @@ -37,7 +37,7 @@ guardrails: - guardrail_name: aim-protected-app litellm_params: guardrail: aim - mode: pre_call + mode: pre_call # 'during_call' is also available api_key: os.environ/AIM_API_KEY api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost ``` diff --git a/docs/my-website/docs/proxy/timeout.md b/docs/my-website/docs/proxy/timeout.md index 2bf93298fe..85428ae53e 100644 --- a/docs/my-website/docs/proxy/timeout.md +++ b/docs/my-website/docs/proxy/timeout.md @@ -166,7 +166,7 @@ response = client.chat.completions.create( {"role": "user", "content": "what color is red"} ], logit_bias={12481: 100}, - timeout=1 + extra_body={"timeout": 1} # 👈 KEY CHANGE ) print(response) diff --git a/litellm/proxy/guardrails/guardrail_hooks/aim.py b/litellm/proxy/guardrails/guardrail_hooks/aim.py index 5f3ec9e880..91d19e277c 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/aim.py +++ b/litellm/proxy/guardrails/guardrail_hooks/aim.py @@ -6,7 +6,7 @@ # +-------------------------------------------------------------+ import os -from typing import Literal, Optional +from typing import Literal, Optional, Union from fastapi import HTTPException @@ -25,12 +25,8 @@ class AimGuardrailMissingSecrets(Exception): class AimGuardrail(CustomGuardrail): - def __init__( - self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs - ): - self.async_handler = get_async_httpx_client( - llm_provider=httpxSpecialProvider.GuardrailCallback - ) + def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs): + self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback) self.api_key = api_key or os.environ.get("AIM_API_KEY") if not self.api_key: msg = ( @@ -38,9 +34,7 @@ class AimGuardrail(CustomGuardrail): "pass it as a parameter to the guardrail in the config file" ) raise AimGuardrailMissingSecrets(msg) - self.api_base = ( - api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security" - ) + self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security" super().__init__(**kwargs) async def async_pre_call_hook( @@ -58,11 +52,32 @@ class AimGuardrail(CustomGuardrail): "pass_through_endpoint", "rerank", ], - ) -> Exception | str | dict | None: + ) -> Union[Exception, str, dict, None]: verbose_proxy_logger.debug("Inside AIM Pre-Call Hook") + await self.call_aim_guardrail(data, hook="pre_call") + return data + + async def async_moderation_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal[ + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ], + ) -> Union[Exception, str, dict, None]: + verbose_proxy_logger.debug("Inside AIM Moderation Hook") + + await self.call_aim_guardrail(data, hook="moderation") + return data + + async def call_aim_guardrail(self, data: dict, hook: str) -> None: user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email") - headers = {"Authorization": f"Bearer {self.api_key}"} | ( + headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | ( {"x-aim-user-email": user_email} if user_email else {} ) response = await self.async_handler.post( @@ -80,4 +95,3 @@ class AimGuardrail(CustomGuardrail): ) if detected: raise HTTPException(status_code=400, detail=res["detection_message"]) - return data diff --git a/litellm/utils.py b/litellm/utils.py index 7e66ad2b22..d181791515 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5884,6 +5884,10 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]): if item.get("type") not in ValidUserMessageContentTypes: raise Exception("invalid content type") except Exception as e: + if isinstance(e, KeyError): + raise Exception( + f"Invalid message={m} at index {idx}. Please ensure all messages are valid OpenAI chat completion messages." + ) if "invalid content type" in str(e): raise Exception( f"Invalid user message={m} at index {idx}. Please ensure all user messages are valid OpenAI chat completion messages." diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py index 4f5e1e2737..4a2f63b51d 100644 --- a/tests/litellm_utils_tests/test_utils.py +++ b/tests/litellm_utils_tests/test_utils.py @@ -1850,3 +1850,15 @@ def test_dict_to_response_format_helper(): "ref_template": "/$defs/{model}", } _dict_to_response_format_helper(**args) + + +def test_validate_user_messages_invalid_content_type(): + from litellm.utils import validate_chat_completion_user_messages + + messages = [{"content": [{"type": "invalid_type", "text": "Hello"}]}] + + with pytest.raises(Exception) as e: + validate_chat_completion_user_messages(messages) + + assert "Invalid message" in str(e) + print(e) diff --git a/tests/local_testing/test_aim_guardrails.py b/tests/local_testing/test_aim_guardrails.py index b68140d37b..d43156fb19 100644 --- a/tests/local_testing/test_aim_guardrails.py +++ b/tests/local_testing/test_aim_guardrails.py @@ -55,15 +55,15 @@ def test_aim_guard_config_no_api_key(): @pytest.mark.asyncio -async def test_callback(): +@pytest.mark.parametrize("mode", ["pre_call", "during_call"]) +async def test_callback(mode: str): init_guardrails_v2( all_guardrails=[ { "guardrail_name": "gibberish-guard", "litellm_params": { "guardrail": "aim", - "guard_name": "gibberish_guard", - "mode": "pre_call", + "mode": mode, "api_key": "hs-aim-key", }, } @@ -89,6 +89,11 @@ async def test_callback(): request=Request(method="POST", url="http://aim"), ), ): - await aim_guardrail.async_pre_call_hook( - data=data, cache=DualCache(), user_api_key_dict=UserAPIKeyAuth(), call_type="completion" - ) + if mode == "pre_call": + await aim_guardrail.async_pre_call_hook( + data=data, cache=DualCache(), user_api_key_dict=UserAPIKeyAuth(), call_type="completion" + ) + else: + await aim_guardrail.async_moderation_hook( + data=data, user_api_key_dict=UserAPIKeyAuth(), call_type="completion" + )