Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
fix(utils.py): handle key error in msg validation (#8325)
* fix(utils.py): handle key error in msg validation
* Support running Aim Guard during LLM call (#7918)
  * support running Aim Guard during LLM call
  * Rename header
  * adjust docs and fix type annotations
* fix(timeout.md): doc fix for openai example on dynamic timeouts

---------

Co-authored-by: Tomer Bin <117278227+hxtomer@users.noreply.github.com>
Parent: fac1d2ccef
Commit: f031926b82

6 changed files with 56 additions and 21 deletions
@@ -37,7 +37,7 @@ guardrails:
   - guardrail_name: aim-protected-app
     litellm_params:
       guardrail: aim
-      mode: pre_call
+      mode: pre_call # 'during_call' is also available
       api_key: os.environ/AIM_API_KEY
       api_base: os.environ/AIM_API_BASE # Optional, use only when using a self-hosted Aim Outpost
 ```
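
The docs change above notes that the Aim guardrail's `mode` also accepts `during_call`, which screens the request alongside the LLM call rather than strictly before it. As a rough illustration (not part of the PR), a client request through a proxy running this config could look like the sketch below; the base URL, key, and model name are placeholders.

```python
# Minimal sketch: a chat request that the configured Aim guardrail screens.
# Assumes a LiteLLM proxy is running locally with the config shown above;
# base_url, api_key, and model are placeholders, not values from the PR.
import openai

client = openai.OpenAI(
    base_url="http://0.0.0.0:4000",  # assumed local proxy address
    api_key="sk-1234",               # placeholder proxy key
)

# With mode: pre_call the prompt is checked before the model is invoked;
# with mode: during_call the check is intended to run in parallel with the
# LLM call and can still block the response if a violation is detected.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what color is red"}],
)
print(response)
```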
@@ -166,7 +166,7 @@ response = client.chat.completions.create(
         {"role": "user", "content": "what color is red"}
     ],
     logit_bias={12481: 100},
-    timeout=1
+    extra_body={"timeout": 1} # 👈 KEY CHANGE
 )
 
 print(response)
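
The timeout.md fix replaces `timeout=1` with `extra_body={"timeout": 1}`, so the dynamic timeout is forwarded in the request body to the proxy rather than acting only as the OpenAI client's local timeout. A self-contained version of the corrected example, with placeholder base URL, key, and model, and the failure path handled explicitly (a sketch, not from the PR):

```python
# Sketch of the corrected dynamic-timeout call; base_url/api_key/model are placeholders.
import openai

client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

try:
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "what color is red"}],
        logit_bias={12481: 100},
        extra_body={"timeout": 1},  # 👈 forwarded to the proxy, not consumed locally
    )
    print(response)
except openai.APIError as err:
    # A 1-second budget will often trip; the proxy surfaces the upstream
    # timeout as an API error (the exact error class may vary).
    print(f"request failed: {err}")
```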
@@ -6,7 +6,7 @@
 # +-------------------------------------------------------------+
 
 import os
-from typing import Literal, Optional
+from typing import Literal, Optional, Union
 
 from fastapi import HTTPException
 
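
The only change in this hunk is importing `Union`; the hunks below use it to replace the `Exception | str | dict | None` return annotation. A plausible reason (not stated in the PR) is runtime compatibility with Python versions older than 3.10, where the PEP 604 `X | Y` union syntax in an annotation raises a `TypeError` at function definition time unless annotation evaluation is deferred:

```python
from typing import Union

# Equivalent annotations; the typing.Union spelling also works on Python 3.8/3.9.
def hook_new(data: dict) -> Union[Exception, str, dict, None]:
    return data

# def hook_old(data: dict) -> Exception | str | dict | None:  # needs 3.10+ at runtime
#     return data
```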
@@ -25,12 +25,8 @@ class AimGuardrailMissingSecrets(Exception):
 
 
 class AimGuardrail(CustomGuardrail):
-    def __init__(
-        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
-    ):
-        self.async_handler = get_async_httpx_client(
-            llm_provider=httpxSpecialProvider.GuardrailCallback
-        )
+    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs):
+        self.async_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.GuardrailCallback)
         self.api_key = api_key or os.environ.get("AIM_API_KEY")
         if not self.api_key:
             msg = (
@@ -38,9 +34,7 @@ class AimGuardrail(CustomGuardrail):
                 "pass it as a parameter to the guardrail in the config file"
             )
             raise AimGuardrailMissingSecrets(msg)
-        self.api_base = (
-            api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
-        )
+        self.api_base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
         super().__init__(**kwargs)
 
     async def async_pre_call_hook(
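
The constructor keeps the same resolution order for its settings: explicit parameter, then environment variable, then (for the API base) the hosted default, raising early if no API key can be found. A standalone sketch of that pattern under illustrative names; neither the class name nor the error text is taken from litellm:

```python
# Illustrative sketch of the credential resolution used above:
# explicit argument -> environment variable -> (for the base URL) a default.
import os
from typing import Optional, Tuple


class GuardrailConfigError(Exception):
    pass


def resolve_aim_settings(
    api_key: Optional[str] = None, api_base: Optional[str] = None
) -> Tuple[str, str]:
    key = api_key or os.environ.get("AIM_API_KEY")
    if not key:
        # Paraphrases the error raised above: set AIM_API_KEY or pass api_key
        # as a parameter to the guardrail in the config file.
        raise GuardrailConfigError("AIM_API_KEY is not set and no api_key was passed")
    base = api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
    return key, base
```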
@@ -58,11 +52,32 @@ class AimGuardrail(CustomGuardrail):
             "pass_through_endpoint",
             "rerank",
         ],
-    ) -> Exception | str | dict | None:
+    ) -> Union[Exception, str, dict, None]:
         verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
 
+        await self.call_aim_guardrail(data, hook="pre_call")
+        return data
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal[
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Union[Exception, str, dict, None]:
+        verbose_proxy_logger.debug("Inside AIM Moderation Hook")
+
+        await self.call_aim_guardrail(data, hook="moderation")
+        return data
+
+    async def call_aim_guardrail(self, data: dict, hook: str) -> None:
         user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
-        headers = {"Authorization": f"Bearer {self.api_key}"} | (
+        headers = {"Authorization": f"Bearer {self.api_key}", "x-aim-litellm-hook": hook} | (
             {"x-aim-user-email": user_email} if user_email else {}
         )
         response = await self.async_handler.post(
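
Both hooks now delegate to `call_aim_guardrail`, which tags the outbound request with an `x-aim-litellm-hook` header (the "Rename header" item in the commit message) and forwards the caller's email header when present. The hunk is cut off before the actual `post(...)` arguments, so the sketch below fills in a placeholder endpoint and payload purely for illustration; only the header construction and the blocking behaviour mirror the diff:

```python
# Illustrative sketch of the shared detection call. The endpoint path and the
# payload/response shapes are placeholders (the diff truncates before the real
# post() arguments); the headers and the HTTP 400 on detection follow the hunk.
from typing import Optional

import httpx
from fastapi import HTTPException

AIM_API_BASE = "https://api.aim.security"  # default from the constructor above
AIM_API_KEY = "aim-key-placeholder"


async def call_aim_guardrail_sketch(data: dict, hook: str) -> None:
    user_email: Optional[str] = (
        data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
    )
    headers = {
        "Authorization": f"Bearer {AIM_API_KEY}",
        "x-aim-litellm-hook": hook,  # "pre_call" or "moderation"
    } | ({"x-aim-user-email": user_email} if user_email else {})

    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{AIM_API_BASE}/detect",  # placeholder path, not the real Aim endpoint
            headers=headers,
            json={"messages": data.get("messages", [])},  # assumed payload shape
        )
    res = response.json()
    if res.get("detected"):  # assumed response field
        # Block the request the same way the guardrail does on detection.
        raise HTTPException(status_code=400, detail=res.get("detection_message"))
```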
@@ -80,4 +95,3 @@ class AimGuardrail(CustomGuardrail):
         )
         if detected:
             raise HTTPException(status_code=400, detail=res["detection_message"])
-        return data
 
@@ -5884,6 +5884,10 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
                         if item.get("type") not in ValidUserMessageContentTypes:
                             raise Exception("invalid content type")
         except Exception as e:
+            if isinstance(e, KeyError):
+                raise Exception(
+                    f"Invalid message={m} at index {idx}. Please ensure all messages are valid OpenAI chat completion messages."
+                )
             if "invalid content type" in str(e):
                 raise Exception(
                     f"Invalid user message={m} at index {idx}. Please ensure all user messages are valid OpenAI chat completion messages."
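
The utils.py fix catches a bare `KeyError` (for example a message dict with no `role` key, as in the new test below) and re-raises it with the offending message and its index, instead of letting an opaque `KeyError` escape. A stripped-down sketch of the pattern; the loop body is approximated and the set of valid content types is an illustrative subset, only the `except` branches follow the diff:

```python
# Stripped-down sketch: structural problems (KeyError) and bad content types are
# both re-raised with the offending message and its index.
VALID_CONTENT_TYPES = {"text", "image_url"}  # illustrative subset


def validate_user_messages_sketch(messages: list) -> None:
    for idx, m in enumerate(messages):
        try:
            if m["role"] == "user":  # raises KeyError when "role" is missing
                content = m["content"]
                if isinstance(content, list):
                    for item in content:
                        if item.get("type") not in VALID_CONTENT_TYPES:
                            raise Exception("invalid content type")
        except Exception as e:
            if isinstance(e, KeyError):
                raise Exception(
                    f"Invalid message={m} at index {idx}. Please ensure all messages are valid OpenAI chat completion messages."
                )
            if "invalid content type" in str(e):
                raise Exception(
                    f"Invalid user message={m} at index {idx}. Please ensure all user messages are valid OpenAI chat completion messages."
                )
            raise
```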
@@ -1850,3 +1850,15 @@ def test_dict_to_response_format_helper():
         "ref_template": "/$defs/{model}",
     }
     _dict_to_response_format_helper(**args)
+
+
+def test_validate_user_messages_invalid_content_type():
+    from litellm.utils import validate_chat_completion_user_messages
+
+    messages = [{"content": [{"type": "invalid_type", "text": "Hello"}]}]
+
+    with pytest.raises(Exception) as e:
+        validate_chat_completion_user_messages(messages)
+
+    assert "Invalid message" in str(e)
+    print(e)
 
@@ -55,15 +55,15 @@ def test_aim_guard_config_no_api_key():
 
 
 @pytest.mark.asyncio
-async def test_callback():
+@pytest.mark.parametrize("mode", ["pre_call", "during_call"])
+async def test_callback(mode: str):
     init_guardrails_v2(
         all_guardrails=[
             {
                 "guardrail_name": "gibberish-guard",
                 "litellm_params": {
                     "guardrail": "aim",
-                    "guard_name": "gibberish_guard",
-                    "mode": "pre_call",
+                    "mode": mode,
                     "api_key": "hs-aim-key",
                 },
             }
@@ -89,6 +89,11 @@ async def test_callback():
             request=Request(method="POST", url="http://aim"),
         ),
     ):
-        await aim_guardrail.async_pre_call_hook(
-            data=data, cache=DualCache(), user_api_key_dict=UserAPIKeyAuth(), call_type="completion"
-        )
+        if mode == "pre_call":
+            await aim_guardrail.async_pre_call_hook(
+                data=data, cache=DualCache(), user_api_key_dict=UserAPIKeyAuth(), call_type="completion"
+            )
+        else:
+            await aim_guardrail.async_moderation_hook(
+                data=data, user_api_key_dict=UserAPIKeyAuth(), call_type="completion"
+            )
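
The parametrized test drives each mode through the matching hook while the outbound Aim call is replaced with a canned `httpx.Response`. A self-contained sketch of that pattern with a stand-in class (assumes `pytest` and `pytest-asyncio` are installed; `DummyGuardrail` is not the real `AimGuardrail`):

```python
# Sketch of the test pattern above: parametrize the mode, stub the outbound
# HTTP call with a canned httpx.Response, and dispatch to the matching hook.
from unittest.mock import AsyncMock, patch

import httpx
import pytest
from httpx import Request


class DummyGuardrail:
    async def _post(self, hook: str) -> httpx.Response:
        raise NotImplementedError  # replaced by the patch in the test

    async def async_pre_call_hook(self, data: dict) -> dict:
        await self._post(hook="pre_call")
        return data

    async def async_moderation_hook(self, data: dict) -> dict:
        await self._post(hook="moderation")
        return data


@pytest.mark.asyncio
@pytest.mark.parametrize("mode", ["pre_call", "during_call"])
async def test_callback_sketch(mode: str):
    canned = httpx.Response(
        status_code=200,
        json={"detected": False},
        request=Request(method="POST", url="http://aim"),
    )
    data = {"messages": [{"role": "user", "content": "hi"}]}
    with patch.object(DummyGuardrail, "_post", new=AsyncMock(return_value=canned)):
        guardrail = DummyGuardrail()
        if mode == "pre_call":
            result = await guardrail.async_pre_call_hook(data=data)
        else:
            result = await guardrail.async_moderation_hook(data=data)
    assert result == data
```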