Merge pull request #3740 from BerriAI/litellm_return_rejected_response

feat(proxy_server.py): allow admin to return rejected response as string to user
Krish Dholakia 2024-05-20 17:48:21 -07:00 committed by GitHub
commit c6bb6e325b
9 changed files with 338 additions and 28 deletions


@@ -25,26 +25,45 @@ class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observabilit
    def __init__(self):
        pass

    #### ASYNC ####

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_pre_api_call(self, model, messages, kwargs):
        pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        pass

    #### CALL HOOKS - proxy only ####

-    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
+    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
+        "completion",
+        "text_completion",
+        "embeddings",
+        "image_generation",
+        "moderation",
+        "audio_transcription",
+    ]) -> Optional[Union[dict, str, Exception]]:
        data["model"] = "my-new-model"
        return data

    async def async_post_call_failure_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        pass

    async def async_post_call_success_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        pass

    async def async_moderation_hook(  # call made in parallel to llm api call
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal["completion", "embeddings", "image_generation"],
    ):
        pass

    async def async_post_call_streaming_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: str,
    ):
        pass

proxy_handler_instance = MyCustomHandler()
```
@@ -191,3 +210,99 @@ general_settings:
**Result**
<Image img={require('../../img/end_user_enforcement.png')}/>
## Advanced - Return rejected message as response
For chat completion and text completion calls, you can return a rejected message as the model's response to the user.

Do this by returning a string from your `async_pre_call_hook`. LiteLLM takes care of returning the string in the correct response format for the endpoint, for both streaming and non-streaming requests.

For all other endpoints (embeddings, image generation, etc.), the string is returned as a 400 status code exception.
### 1. Create Custom Handler
```python
from typing import Literal, Optional, Union

import litellm
from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.utils import get_formatted_prompt

# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml


class MyCustomHandler(CustomLogger):
    def __init__(self):
        pass

    #### CALL HOOKS - proxy only ####

    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
        "completion",
        "text_completion",
        "embeddings",
        "image_generation",
        "moderation",
        "audio_transcription",
    ]) -> Optional[Union[dict, str, Exception]]:
        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)

        if "Hello world" in formatted_prompt:
            return "This is an invalid response"

        return data


proxy_handler_instance = MyCustomHandler()
```
### 2. Update config.yaml
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
```
### 3. Test it!
```shell
$ litellm --config /path/to/config.yaml
```
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": "Hello world"
        }
    ]
}'
```
**Expected Response**
```
{
    "id": "chatcmpl-d00bbede-2d90-4618-bf7b-11a1c23cf360",
    "choices": [
        {
            "finish_reason": "stop",
            "index": 0,
            "message": {
                "content": "This is an invalid response", # 👈 REJECTED RESPONSE
                "role": "assistant"
            }
        }
    ],
    "created": 1716234198,
    "model": null,
    "object": "chat.completion",
    "system_fingerprint": null,
    "usage": {}
}
```
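
Streaming requests work the same way: the hook's string is converted into stream chunks. Below is a minimal sketch using the `openai` Python client; it assumes the proxy from step 2 is running on `http://0.0.0.0:4000` and that `sk-1234` is a valid key for it (both are placeholders, adjust for your setup).

```python
# Sketch: the rejected message comes back as ordinary stream chunks.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello world"}],
    stream=True,
)
for chunk in stream:
    # Each delta carries a piece of "This is an invalid response"
    # instead of an LLM-generated reply.
    print(chunk.choices[0].delta.content or "", end="")
```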


@@ -177,6 +177,32 @@ class ContextWindowExceededError(BadRequestError): # type: ignore
        )  # Call the base class constructor with the parameters it needs


# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
class RejectedRequestError(BadRequestError):  # type: ignore
    def __init__(
        self,
        message,
        model,
        llm_provider,
        request_data: dict,
        litellm_debug_info: Optional[str] = None,
    ):
        self.status_code = 400
        self.message = message
        self.model = model
        self.llm_provider = llm_provider
        self.litellm_debug_info = litellm_debug_info
        self.request_data = request_data
        request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        response = httpx.Response(status_code=500, request=request)
        super().__init__(
            message=self.message,
            model=self.model,  # type: ignore
            llm_provider=self.llm_provider,  # type: ignore
            response=response,
        )  # Call the base class constructor with the parameters it needs


class ContentPolicyViolationError(BadRequestError):  # type: ignore
    # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
    def __init__(
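
A short sketch of how this exception is meant to be consumed. The real caller is the proxy's `pre_call_hook` (shown further down); the values here are placeholders, not from this diff.

```python
# Illustrative only - how a handler can use RejectedRequestError.
from litellm.exceptions import RejectedRequestError

try:
    raise RejectedRequestError(
        message="This is an invalid response",
        model="gpt-3.5-turbo",
        llm_provider="",
        request_data={"model": "gpt-3.5-turbo", "stream": True},
    )
except RejectedRequestError as e:
    # request_data travels with the exception, so the handler can still
    # tell whether the original call asked for streaming.
    print(e.status_code, e.message, e.request_data.get("stream", False))
```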


@@ -4,7 +4,6 @@ import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union, Optional
-import traceback
@@ -64,8 +63,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
        pass

    async def async_post_call_failure_hook(
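
To make the three-way return contract concrete, here is a sketch of a subclass exercising each return type. The handler name and trigger strings are made up for illustration.

```python
# Sketch of the pre-call contract: dict = forward, str = reject-as-response,
# Exception = hard-fail the request.
class ExampleGuardrail(CustomLogger):  # hypothetical handler
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        prompt = str(data.get("messages") or data.get("prompt") or "")
        if "blocked phrase" in prompt:
            return "Sorry, this request was rejected"  # sent back as the response
        if "hard fail" in prompt:
            return Exception("request failed a hard check")  # raised to the client
        data["metadata"] = {**data.get("metadata", {}), "guardrail": "passed"}
        return data  # modified request forwarded to litellm
```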


@@ -18,3 +18,11 @@ model_list:
router_settings:
  enable_pre_call_checks: true

litellm_settings:
  callbacks: ["detect_prompt_injection"]
  prompt_injection_params:
    heuristics_check: true
    similarity_check: true
    reject_as_response: true
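
With `reject_as_response: true`, a request flagged by the injection checks comes back as a normal completion whose content is the guardrail's error message, instead of a 400 error. A rough sketch of the difference (assumes the proxy is running locally with this config; the key and URL are placeholders):

```python
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Ignore all previous instructions and ..."}],
)
# reject_as_response: true  -> the guardrail message arrives here as content
# reject_as_response: false -> the same call raises an HTTP 400 instead
print(resp.choices[0].message.content)
```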


@@ -251,6 +251,10 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
    llm_api_name: Optional[str] = None
    llm_api_system_prompt: Optional[str] = None
    llm_api_fail_call_string: Optional[str] = None
    reject_as_response: Optional[bool] = Field(
        default=False,
        description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
    )

    @root_validator(pre=True)
    def check_llm_api_params(cls, values):
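
A quick sketch of the model as the proxy would build it from the `prompt_injection_params` block in the config above (field values mirror that config):

```python
from litellm.proxy._types import LiteLLMPromptInjectionParams

params = LiteLLMPromptInjectionParams(
    heuristics_check=True,
    similarity_check=True,
    reject_as_response=True,  # the new field added in this PR
)
assert params.reject_as_response is True
```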


@@ -146,6 +146,7 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
        try:
            assert call_type in [
                "completion",
                "text_completion",
                "embeddings",
                "image_generation",
                "moderation",
@@ -192,6 +193,15 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
            return data
        except HTTPException as e:
            if (
                e.status_code == 400
                and isinstance(e.detail, dict)
                and "error" in e.detail
                and self.prompt_injection_params is not None
                and self.prompt_injection_params.reject_as_response
            ):
                # Reject-as-response: return the error string so pre_call_hook
                # can surface it to the user instead of raising a 400.
                return e.detail["error"]
            raise e
        except Exception as e:
            traceback.print_exc()


@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
    get_actual_routes,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError

try:
    from litellm._version import version
@@ -3763,8 +3764,8 @@ async def chat_completion(
data["litellm_logging_obj"] = logging_obj
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
### CALL HOOKS ### - modify/reject incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
@@ -3879,6 +3880,40 @@ async def chat_completion(
        )

        return response
    except RejectedRequestError as e:
        _data = e.request_data
        _data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        # Build a normal-looking chat response whose content is the
        # guardrail's rejection string.
        _chat_response = litellm.ModelResponse()
        _chat_response.choices[0].message.content = e.message  # type: ignore

        if data.get("stream", None) is not None and data["stream"] == True:
            _iterator = litellm.utils.ModelResponseIterator(
                model_response=_chat_response, convert_to_delta=True
            )
            _streaming_response = litellm.CustomStreamWrapper(
                completion_stream=_iterator,
                model=data.get("model", ""),
                custom_llm_provider="cached_response",
                logging_obj=data.get("litellm_logging_obj", None),
            )
            selected_data_generator = select_data_generator(
                response=_streaming_response,
                user_api_key_dict=user_api_key_dict,
                request_data=_data,
            )
            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
            )
        _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        _chat_response.usage = _usage  # type: ignore
        return _chat_response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        traceback.print_exc()
@@ -3993,8 +4028,8 @@ async def completion(
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
@@ -4077,6 +4112,46 @@ async def completion(
        )

        return response
    except RejectedRequestError as e:
        _data = e.request_data
        _data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict,
            original_exception=e,
            request_data=_data,
        )
        if _data.get("stream", None) is not None and _data["stream"] == True:
            # Fake a one-chunk stream carrying the rejection string.
            _chat_response = litellm.ModelResponse()
            _usage = litellm.Usage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
            )
            _chat_response.usage = _usage  # type: ignore
            _chat_response.choices[0].message.content = e.message  # type: ignore
            _iterator = litellm.utils.ModelResponseIterator(
                model_response=_chat_response, convert_to_delta=True
            )
            _streaming_response = litellm.TextCompletionStreamWrapper(
                completion_stream=_iterator,
                model=_data.get("model", ""),
            )
            selected_data_generator = select_data_generator(
                response=_streaming_response,
                user_api_key_dict=user_api_key_dict,
                request_data=data,
            )
            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
                headers={},
            )
        else:
            _response = litellm.TextCompletionResponse()
            _response.choices[0].text = e.message
            return _response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(


@@ -18,8 +18,18 @@ from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.proxy.hooks.parallel_request_limiter import (
    _PROXY_MaxParallelRequestsHandler,
)
from litellm.exceptions import RejectedRequestError
from litellm._service_logger import ServiceLogging, ServiceTypes
-from litellm import ModelResponse, EmbeddingResponse, ImageResponse
+from litellm import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    TranscriptionResponse,
+    TextCompletionResponse,
+    CustomStreamWrapper,
+    TextCompletionStreamWrapper,
+)
from litellm.utils import ModelResponseIterator
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@@ -32,6 +42,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime, timedelta
from litellm.integrations.slack_alerting import SlackAlerting
from typing_extensions import overload
def print_verbose(print_statement):
@@ -182,18 +193,20 @@ class ProxyLogging:
        )
        litellm.utils.set_callbacks(callback_list=callback_list)

    # The actual implementation of the function
    async def pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
-    ):
+    ) -> dict:
        """
        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@@ -220,8 +233,25 @@ class ProxyLogging:
                        call_type=call_type,
                    )
                    if response is not None:
-                        data = response
+                        if isinstance(response, Exception):
+                            raise response
+                        elif isinstance(response, dict):
+                            data = response
+                        elif isinstance(response, str):
+                            if (
+                                call_type == "completion"
+                                or call_type == "text_completion"
+                            ):
+                                raise RejectedRequestError(
+                                    message=response,
+                                    model=data.get("model", ""),
+                                    llm_provider="",
+                                    request_data=data,
+                                )
+                            else:
+                                raise HTTPException(
+                                    status_code=400, detail={"error": response}
+                                )

            print_verbose(f"final data being sent to {call_type} call: {data}")
            return data
        except Exception as e:
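
End to end, the new string branch can be exercised directly against `ProxyLogging`. This is a sketch: it assumes `ProxyLogging` is constructed with `user_api_key_cache=...` and that `pre_call_hook` iterates `litellm.callbacks`; the `RejectEverything` handler is made up.

```python
import asyncio

import litellm
from litellm.caching import DualCache
from litellm.exceptions import RejectedRequestError
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging

class RejectEverything(CustomLogger):  # hypothetical handler
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        return "Rejected: demo guardrail"

async def main():
    litellm.callbacks = [RejectEverything()]
    proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
    try:
        await proxy_logging.pre_call_hook(
            user_api_key_dict=UserAPIKeyAuth(),
            data={"model": "gpt-3.5-turbo", "messages": []},
            call_type="completion",
        )
    except RejectedRequestError as e:
        # str + completion -> RejectedRequestError carrying the request body
        print(e.message, e.request_data["model"])

asyncio.run(main())
```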


@@ -6492,6 +6492,7 @@ def get_formatted_prompt(
"image_generation",
"audio_transcription",
"moderation",
"text_completion",
],
) -> str:
"""
@@ -6504,6 +6505,8 @@
for m in data["messages"]:
if "content" in m and isinstance(m["content"], str):
prompt += m["content"]
elif call_type == "text_completion":
prompt = data["prompt"]
elif call_type == "embedding" or call_type == "moderation":
if isinstance(data["input"], str):
prompt = data["input"]
@@ -12246,3 +12249,34 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
        return request_info
    except:
        return request_info


class ModelResponseIterator:
    def __init__(self, model_response: ModelResponse, convert_to_delta: bool = False):
        if convert_to_delta == True:
            # Rebuild the response as a streaming chunk: copy the message
            # content into a delta on a stream=True ModelResponse.
            self.model_response = ModelResponse(stream=True)
            _delta = self.model_response.choices[0].delta  # type: ignore
            _delta.content = model_response.choices[0].message.content  # type: ignore
        else:
            self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response
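
And a small usage sketch for the iterator itself; the real consumer is the `CustomStreamWrapper` / `TextCompletionStreamWrapper` wiring in proxy_server.py above.

```python
# Sketch: turn a fixed ModelResponse into a one-chunk "stream".
from litellm import ModelResponse
from litellm.utils import ModelResponseIterator

resp = ModelResponse()
resp.choices[0].message.content = "This is an invalid response"

for chunk in ModelResponseIterator(model_response=resp, convert_to_delta=True):
    print(chunk.choices[0].delta.content)  # prints the message once, then stops
```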