Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
feat(proxy_server.py): refactor returning the rejected message so it works with error logging
Log the rejected request as a failed call to Langfuse/Slack alerting.
This commit is contained in:
parent 372323c38a
commit f11f207ae6

7 changed files with 118 additions and 100 deletions
@@ -177,6 +177,32 @@ class ContextWindowExceededError(BadRequestError):  # type: ignore
        )  # Call the base class constructor with the parameters it needs


# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
class RejectedRequestError(BadRequestError):  # type: ignore
    def __init__(
        self,
        message,
        model,
        llm_provider,
        request_data: dict,
        litellm_debug_info: Optional[str] = None,
    ):
        self.status_code = 400
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        self.litellm_debug_info = litellm_debug_info
        self.request_data = request_data
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        response = httpx.Response(status_code=500, request=request)
        super().__init__(
            message=self.message,
            model=self.model,  # type: ignore
            llm_provider=self.llm_provider,  # type: ignore
            response=response,
        )  # Call the base class constructor with the parameters it needs


class ContentPolicyViolationError(BadRequestError):  # type: ignore
    # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
    def __init__(
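As a rough sketch of how this new exception is meant to be used (the guardrail function and request payload below are hypothetical; only `RejectedRequestError` and its constructor arguments come from the hunk above): a guardrail raises the error with the original request body attached, and the caller can recover both the rejection message and the request data for failure logging.

```python
from litellm.exceptions import RejectedRequestError

# Hypothetical guardrail check; "block_profanity" is an illustrative name,
# not part of this commit.
def block_profanity(request_data: dict) -> None:
    messages = request_data.get("messages", [])
    if any("badword" in str(m.get("content", "")) for m in messages):
        raise RejectedRequestError(
            message="Rejected by guardrail: profanity detected.",
            model=request_data.get("model", ""),
            llm_provider="",
            request_data=request_data,  # kept so the failure can be logged later
        )

data = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "badword"}]}
try:
    block_profanity(data)
except RejectedRequestError as e:
    # e.message is shown to the user; e.request_data feeds failure logging/alerting
    print(e.message, e.request_data["model"])
```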
@@ -4,7 +4,6 @@ import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.utils import ModelResponse
from typing import Literal, Union, Optional
import traceback
@@ -18,3 +18,7 @@ model_list:
router_settings:
  enable_pre_call_checks: true

litellm_settings:
  callbacks: ["detect_prompt_injection"]
@@ -251,6 +251,10 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
    llm_api_name: Optional[str] = None
    llm_api_system_prompt: Optional[str] = None
    llm_api_fail_call_string: Optional[str] = None
    reject_as_response: Optional[bool] = Field(
        default=False,
        description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
    )

    @root_validator(pre=True)
    def check_llm_api_params(cls, values):
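A minimal sketch of setting the new `reject_as_response` flag when building these params in Python; the import path is an assumption based on the other hunks in this commit, and only the field name and default come from the model above.

```python
from litellm.proxy._types import LiteLLMPromptInjectionParams

# Return the guardrail's rejection text as a normal response instead of raising
# an exception; the default (False) keeps the old raise-an-error behaviour.
params = LiteLLMPromptInjectionParams(reject_as_response=True)
print(params.reject_as_response)  # True
```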
@@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
        try:
            assert call_type in [
                "completion",
                "text_completion",
                "embeddings",
                "image_generation",
                "moderation",
@@ -192,6 +193,13 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
            return data
        except HTTPException as e:
            if (
                e.status_code == 400
                and isinstance(e.detail, dict)
                and "error" in e.detail
            ):
                if self.prompt_injection_params.reject_as_response:
                    return e.detail["error"]
            raise e
        except Exception as e:
            traceback.print_exc()
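The control flow above can be summarised with a small standalone sketch (names here are illustrative, not litellm APIs, apart from FastAPI's `HTTPException`): when the detector raises a 400 whose detail is a `{"error": ...}` dict and `reject_as_response` is enabled, the hook swallows the exception and hands the error string back as the "response"; otherwise the exception propagates.

```python
from fastapi import HTTPException

def resolve_rejection(exc: HTTPException, reject_as_response: bool):
    """Mirror of the except-branch above, pulled out of the litellm class for clarity."""
    if (
        exc.status_code == 400
        and isinstance(exc.detail, dict)
        and "error" in exc.detail
        and reject_as_response
    ):
        return exc.detail["error"]  # becomes the message shown to the end user
    raise exc

rejection = HTTPException(status_code=400, detail={"error": "Rejected: prompt injection detected."})
print(resolve_rejection(rejection, reject_as_response=True))
```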
@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
    get_actual_routes,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError

try:
    from litellm._version import version
@@ -3766,19 +3767,6 @@ async def chat_completion(
            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
        )

        if isinstance(data, litellm.ModelResponse):
            return data
        elif isinstance(data, litellm.CustomStreamWrapper):
            selected_data_generator = select_data_generator(
                response=data,
                user_api_key_dict=user_api_key_dict,
                request_data={},
            )

            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
            )
        tasks = []
        tasks.append(
            proxy_logging_obj.during_call_hook(
@@ -3893,6 +3881,40 @@ async def chat_completion(
            )

        return response
    except RejectedRequestError as e:
        _data = e.request_data
        _data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        _chat_response = litellm.ModelResponse()
        _chat_response.choices[0].message.content = e.message  # type: ignore

        if data.get("stream", None) is not None and data["stream"] == True:
            _iterator = litellm.utils.ModelResponseIterator(
                model_response=_chat_response
            )
            _streaming_response = litellm.CustomStreamWrapper(
                completion_stream=_iterator,
                model=data.get("model", ""),
                custom_llm_provider="cached_response",
                logging_obj=data.get("litellm_logging_obj", None),
            )
            selected_data_generator = select_data_generator(
                response=_streaming_response,
                user_api_key_dict=user_api_key_dict,
                request_data=_data,
            )

            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
            )
        _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        _chat_response.usage = _usage  # type: ignore
        return _chat_response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        traceback.print_exc()
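One consequence of this branch is that, from the client's point of view, a rejected chat request now looks like an ordinary completion whose content is the rejection message (with zero-token usage), rather than an HTTP error. A hypothetical client-side check, assuming a proxy running locally on port 4000 and a placeholder virtual key:

```python
from openai import OpenAI

# Hypothetical local proxy; base_url and api_key are placeholders.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "some prompt the guardrail rejects"}],
)

# With reject_as_response enabled, the guardrail's message arrives here
# instead of surfacing as a 400 error.
print(resp.choices[0].message.content)
print(resp.usage)  # prompt_tokens=0, completion_tokens=0, total_tokens=0
```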
@@ -4112,6 +4134,34 @@ async def completion(
            )

        return response
    except RejectedRequestError as e:
        _data = e.request_data
        _data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        if _data.get("stream", None) is not None and _data["stream"] == True:
            _chat_response = litellm.ModelResponse()
            _usage = litellm.Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
            _chat_response.usage = _usage  # type: ignore
            _chat_response.choices[0].message.content = e.message  # type: ignore
            _iterator = litellm.utils.ModelResponseIterator(
                model_response=_chat_response
            )
            return litellm.TextCompletionStreamWrapper(
                completion_stream=_iterator,
                model=_data.get("model", ""),
            )
        else:
            _response = litellm.TextCompletionResponse()
            _response.choices[0].text = e.message
            return _response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
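The `/completions` path behaves the same way. A hedged sketch of what a caller would observe for a non-streaming text completion, again assuming a local proxy and a placeholder key:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholders

resp = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="some prompt the guardrail rejects",
)

# Non-streaming rejections come back as a TextCompletionResponse whose first
# choice's text is the guardrail message.
print(resp.choices[0].text)
```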
@@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)
from litellm.exceptions import RejectedRequestError
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm import (
    ModelResponse,
@@ -186,40 +187,6 @@ class ProxyLogging:
            )
        litellm.utils.set_callbacks(callback_list=callback_list)

    # fmt: off

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["completion"]
    ) -> Union[dict, ModelResponse, CustomStreamWrapper]:
        ...

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["text_completion"]
    ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]:
        ...

    @overload
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal["embeddings",
                           "image_generation",
                           "moderation",
                           "audio_transcription",]
    ) -> dict:
        ...

    # fmt: on

    # The actual implementation of the function
    async def pre_call_hook(
        self,
@@ -233,13 +200,7 @@ class ProxyLogging:
            "moderation",
            "audio_transcription",
        ],
    ) -> Union[
        dict,
        ModelResponse,
        TextCompletionResponse,
        CustomStreamWrapper,
        TextCompletionStreamWrapper,
    ]:
    ) -> dict:
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@@ -271,54 +232,20 @@ class ProxyLogging:
                elif isinstance(response, dict):
                    data = response
                elif isinstance(response, str):
                    if call_type == "completion":
                        _chat_response = ModelResponse()
                        _chat_response.choices[0].message.content = response

                        if (
                            data.get("stream", None) is not None
                            and data["stream"] == True
                        call_type == "completion"
                        or call_type == "text_completion"
                        ):
                            _iterator = ModelResponseIterator(
                                model_response=_chat_response
                            )
                            return CustomStreamWrapper(
                                completion_stream=_iterator,
                        raise RejectedRequestError(
                            message=response,
                            model=data.get("model", ""),
                                custom_llm_provider="cached_response",
                                logging_obj=data.get(
                                    "litellm_logging_obj", None
                                ),
                            llm_provider="",
                            request_data=data,
                        )
                            return _response
                    elif call_type == "text_completion":
                        if (
                            data.get("stream", None) is not None
                            and data["stream"] == True
                        ):
                            _chat_response = ModelResponse()
                            _chat_response.choices[0].message.content = response

                            if (
                                data.get("stream", None) is not None
                                and data["stream"] == True
                            ):
                                _iterator = ModelResponseIterator(
                                    model_response=_chat_response
                                )
                                return TextCompletionStreamWrapper(
                                    completion_stream=_iterator,
                                    model=data.get("model", ""),
                                )
                        else:
                            _response = TextCompletionResponse()
                            _response.choices[0].text = response
                            return _response
                    else:
                        raise HTTPException(
                            status_code=400, detail={"error": response}
                        )

            print_verbose(f"final data being sent to {call_type} call: {data}")
            return data
        except Exception as e:
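Putting the utils.py and proxy_server.py changes together: `pre_call_hook` now always returns a (possibly modified) request dict and signals rejection by raising `RejectedRequestError`, instead of returning a response object. A minimal sketch of the new caller-side contract (the surrounding function and variable names are illustrative; the hook arguments mirror the diff above):

```python
from litellm.exceptions import RejectedRequestError

async def run_request(proxy_logging_obj, user_api_key_dict, data: dict):
    # Illustrative caller; mirrors the chat_completion/completion handlers above.
    try:
        data = await proxy_logging_obj.pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            data=data,
            call_type="completion",
        )
        # ... forward `data` to the router / LLM as usual ...
        return data
    except RejectedRequestError as e:
        # The request never reaches the LLM; log it as a failed call so
        # Langfuse/Slack alerting see it, then surface e.message to the user.
        e.request_data["litellm_status"] = "fail"
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=e.request_data,
        )
        return {"rejected": True, "message": e.message}
```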