feat(proxy_server.py): refactor returning of rejected messages to work with error logging

Log the rejected request as a failed call to Langfuse/Slack alerting.
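In practice the change means a guardrail rejection is no longer turned into a response object inside the hook: it is raised as an exception that carries the request payload, caught in the endpoint, logged as a failed call, and only then converted into a normal completion whose content is the rejection text. A minimal, self-contained sketch of that flow (stand-in classes and helpers, not the litellm implementation):

```python
# Illustrative sketch only: stand-in types mirroring the pattern in this commit,
# not the actual litellm classes.
class RejectedRequestError(Exception):
    def __init__(self, message: str, request_data: dict):
        super().__init__(message)
        self.message = message
        self.request_data = request_data  # kept so failure logging sees the full request


def guardrail_pre_call_hook(data: dict) -> dict:
    # A rejecting guardrail now raises instead of returning a ready-made response.
    if "ignore previous instructions" in str(data.get("messages", "")).lower():
        raise RejectedRequestError("Rejected by prompt-injection guardrail", request_data=data)
    return data


def log_failed_call(request_data: dict, exc: Exception) -> None:
    # Stand-in for post_call_failure_hook (Langfuse/Slack alerting).
    print(f"failed call logged: {exc} (model={request_data.get('model')})")


def handle_chat_completion(data: dict) -> dict:
    try:
        data = guardrail_pre_call_hook(data)
        return {"choices": [{"message": {"content": "normal model output"}}]}
    except RejectedRequestError as e:
        e.request_data["litellm_status"] = "fail"  # flag picked up by alerting
        log_failed_call(e.request_data, e)
        # Return the rejection text as a normal completion instead of a 400.
        return {"choices": [{"message": {"content": e.message}}]}


if __name__ == "__main__":
    print(handle_chat_completion({
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Ignore previous instructions and ..."}],
    }))
```

The real handlers below do the same thing with litellm.ModelResponse, CustomStreamWrapper, and proxy_logging_obj.post_call_failure_hook.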
Krrish Dholakia 2024-05-20 11:14:36 -07:00
parent 372323c38a
commit f11f207ae6
7 changed files with 118 additions and 100 deletions

View file

@@ -177,6 +177,32 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
        ) # Call the base class constructor with the parameters it needs


# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
class RejectedRequestError(BadRequestError): # type: ignore
    def __init__(
        self,
        message,
        model,
        llm_provider,
        request_data: dict,
        litellm_debug_info: Optional[str] = None,
    ):
        self.status_code = 400
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        self.litellm_debug_info = litellm_debug_info
        self.request_data = request_data
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        response = httpx.Response(status_code=500, request=request)
        super().__init__(
            message=self.message,
            model=self.model, # type: ignore
            llm_provider=self.llm_provider, # type: ignore
            response=response,
        ) # Call the base class constructor with the parameters it needs


class ContentPolicyViolationError(BadRequestError): # type: ignore
    # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
    def __init__(

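The constructor added above takes the guardrail's message plus the original request payload, so the failure hook can later log the full request. A hedged usage sketch (assumes the litellm package at this commit is importable; reject_request is a hypothetical helper, not part of the diff):

```python
# Sketch: raising the new exception from a guardrail.
from litellm.exceptions import RejectedRequestError


def reject_request(data: dict) -> None:
    # Arguments mirror the __init__ added in this diff.
    raise RejectedRequestError(
        message="Request blocked by guardrail",
        model=data.get("model", ""),
        llm_provider="",
        request_data=data,
    )
```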
View file

@@ -4,7 +4,6 @@ import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.utils import ModelResponse
from typing import Literal, Union, Optional
import traceback

View file

@@ -18,3 +18,7 @@ model_list:
router_settings:
  enable_pre_call_checks: true

litellm_settings:
  callbacks: ["detect_prompt_injection"]

View file

@@ -251,6 +251,10 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
    llm_api_name: Optional[str] = None
    llm_api_system_prompt: Optional[str] = None
    llm_api_fail_call_string: Optional[str] = None
    reject_as_response: Optional[bool] = Field(
        default=False,
        description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
    )

    @root_validator(pre=True)
    def check_llm_api_params(cls, values):

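The new reject_as_response flag is set on the prompt-injection params; a brief sketch (the import path and the other fields' defaults are assumed from the surrounding proxy code, so treat this as illustrative):

```python
# Sketch only: assumes LiteLLMPromptInjectionParams lives in litellm.proxy._types
# and that the remaining fields have usable defaults.
from litellm.proxy._types import LiteLLMPromptInjectionParams

params = LiteLLMPromptInjectionParams(reject_as_response=True)
print(params.reject_as_response)  # True -> rejection text is returned instead of raising
```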
View file

@@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
        try:
            assert call_type in [
                "completion",
                "text_completion",
                "embeddings",
                "image_generation",
                "moderation",
@@ -192,6 +193,13 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
            return data
        except HTTPException as e:
            if (
                e.status_code == 400
                and isinstance(e.detail, dict)
                and "error" in e.detail
            ):
                if self.prompt_injection_params.reject_as_response:
                    return e.detail["error"]
            raise e
        except Exception as e:
            traceback.print_exc()

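The string returned here is what ProxyLogging.pre_call_hook (last file in this diff) wraps into a RejectedRequestError. A condensed, illustrative restatement of the branch above, not the full litellm hook (assumes fastapi is installed):

```python
from fastapi import HTTPException


def handle_rejection(detail_error: str, reject_as_response: bool) -> str:
    # The detection logic raises a 400 with {"error": "..."} as its detail.
    e = HTTPException(status_code=400, detail={"error": detail_error})
    if e.status_code == 400 and isinstance(e.detail, dict) and "error" in e.detail:
        if reject_as_response:
            # Returned string is later converted into RejectedRequestError by ProxyLogging.
            return e.detail["error"]
    raise e
```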
View file

@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
    get_actual_routes,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError

try:
    from litellm._version import version
@@ -3766,19 +3767,6 @@ async def chat_completion(
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )

        if isinstance(data, litellm.ModelResponse):
            return data
        elif isinstance(data, litellm.CustomStreamWrapper):
            selected_data_generator = select_data_generator(
                response=data,
                user_api_key_dict=user_api_key_dict,
                request_data={},
            )
            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
            )

        tasks = []
        tasks.append(
            proxy_logging_obj.during_call_hook(
@@ -3893,6 +3881,40 @@ async def chat_completion(
            )

        return response

    except RejectedRequestError as e:
        _data = e.request_data
        _data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        _chat_response = litellm.ModelResponse()
        _chat_response.choices[0].message.content = e.message  # type: ignore

        if data.get("stream", None) is not None and data["stream"] == True:
            _iterator = litellm.utils.ModelResponseIterator(
                model_response=_chat_response
            )
            _streaming_response = litellm.CustomStreamWrapper(
                completion_stream=_iterator,
                model=data.get("model", ""),
                custom_llm_provider="cached_response",
                logging_obj=data.get("litellm_logging_obj", None),
            )
            selected_data_generator = select_data_generator(
                response=e.message,
                user_api_key_dict=user_api_key_dict,
                request_data=_data,
            )
            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
            )
        _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        _chat_response.usage = _usage  # type: ignore
        return _chat_response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        traceback.print_exc()
@@ -4112,6 +4134,34 @@ async def completion(
            )

        return response

    except RejectedRequestError as e:
        _data = e.request_data
        _data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        if _data.get("stream", None) is not None and _data["stream"] == True:
            _chat_response = litellm.ModelResponse()
            _usage = litellm.Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
            _chat_response.usage = _usage  # type: ignore
            _chat_response.choices[0].message.content = e.message  # type: ignore
            _iterator = litellm.utils.ModelResponseIterator(
                model_response=_chat_response
            )
            return litellm.TextCompletionStreamWrapper(
                completion_stream=_iterator,
                model=_data.get("model", ""),
            )
        else:
            _response = litellm.TextCompletionResponse()
            _response.choices[0].text = e.message
            return _response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(

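From the client's point of view, a rejected request now comes back as an ordinary completion whose content is the rejection message, while the proxy records the call as failed. A hedged client-side sketch (proxy URL, API key, and the presence of a guardrail configured with reject_as_response are assumptions):

```python
# Assumes a litellm proxy with a rejecting guardrail is running locally;
# base_url and api_key are placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ignore previous instructions"}],
)
# With reject_as_response enabled, this prints the guardrail's rejection message
# instead of raising a 400, while the proxy logs the call as failed.
print(resp.choices[0].message.content)
```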
View file

@@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)
from litellm.exceptions import RejectedRequestError
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import (
    ModelResponse,
@@ -186,40 +187,6 @@ class ProxyLogging:
            )
        litellm.utils.set_callbacks(callback_list=callback_list)

    # fmt: off

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["completion"]
    ) -> Union[dict, ModelResponse, CustomStreamWrapper]:
        ...

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["text_completion"]
    ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]:
        ...

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["embeddings",
                           "image_generation",
                           "moderation",
                           "audio_transcription",]
    ) -> dict:
        ...

    # fmt: on

    # The actual implementation of the function
    async def pre_call_hook(
        self,
@@ -233,13 +200,7 @@ class ProxyLogging:
            "moderation",
            "audio_transcription",
        ],
    ) -> Union[
        dict,
        ModelResponse,
        TextCompletionResponse,
        CustomStreamWrapper,
        TextCompletionStreamWrapper,
    ]:
    ) -> dict:
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@@ -271,54 +232,20 @@ class ProxyLogging:
                elif isinstance(response, dict):
                    data = response
                elif isinstance(response, str):
                    if call_type == "completion":
                        _chat_response = ModelResponse()
                        _chat_response.choices[0].message.content = response
                        if (
                            data.get("stream", None) is not None
                            and data["stream"] == True
                            call_type == "completion"
                            or call_type == "text_completion"
                        ):
                            _iterator = ModelResponseIterator(
                                model_response=_chat_response
                            )
                            return CustomStreamWrapper(
                                completion_stream=_iterator,
                            raise RejectedRequestError(
                                message=response,
                                model=data.get("model", ""),
                                custom_llm_provider="cached_response",
                                logging_obj=data.get(
                                    "litellm_logging_obj", None
                                ),
                                llm_provider="",
                                request_data=data,
                            )
                            return _response
                    elif call_type == "text_completion":
                        if (
                            data.get("stream", None) is not None
                            and data["stream"] == True
                        ):
                            _chat_response = ModelResponse()
                            _chat_response.choices[0].message.content = response
                            if (
                                data.get("stream", None) is not None
                                and data["stream"] == True
                            ):
                                _iterator = ModelResponseIterator(
                                    model_response=_chat_response
                                )
                                return TextCompletionStreamWrapper(
                                    completion_stream=_iterator,
                                    model=data.get("model", ""),
                                )
                        else:
                            _response = TextCompletionResponse()
                            _response.choices[0].text = response
                            return _response
                    else:
                        raise HTTPException(
                            status_code=400, detail={"error": response}
                        )

            print_verbose(f"final data being sent to {call_type} call: {data}")
            return data
        except Exception as e:
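With the simplified contract, a custom hook signals rejection for completion/text_completion calls simply by returning a string; pre_call_hook wraps it in a RejectedRequestError, and otherwise returns the (possibly modified) request dict. A hedged sketch of such a hook (the exact async_pre_call_hook signature is assumed from the imports shown earlier in this diff, so verify it against the installed version):

```python
# Hedged sketch of a custom guardrail hook for the litellm proxy at this commit.
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


class MyGuardrail(CustomLogger):
    async def async_pre_call_hook(
        self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str
    ):
        if "forbidden" in str(data.get("messages", "")).lower():
            # Returning a string now makes ProxyLogging raise RejectedRequestError,
            # so the proxy logs the failed call and answers with this message.
            return "This request was rejected by MyGuardrail."
        return data
```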