feat(proxy_server.py): allow admin to return rejected response as string to user

Closes https://github.com/BerriAI/litellm/issues/3671
Krrish Dholakia 2024-05-20 10:30:23 -07:00
parent 098686c193
commit 45fedb83c6
4 changed files with 175 additions and 10 deletions
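In practice, this lets a proxy admin short-circuit a bad request while still returning a well-formed completion to the caller. A minimal sketch of such a handler (the module, class, and trigger word are hypothetical; handlers are registered through the proxy's `litellm_settings.callbacks` config per the custom-callbacks docs):

# custom_hooks.py -- hypothetical module registered with the proxy
from typing import Optional, Union

from litellm.integrations.custom_logger import CustomLogger


class RejectionHandler(CustomLogger):
    async def async_pre_call_hook(
        self, user_api_key_dict, cache, data, call_type
    ) -> Optional[Union[Exception, str, dict]]:
        last_message = str(data.get("messages", [{}])[-1].get("content", ""))
        if "forbidden" in last_message:
            # New in this commit: returning a plain string sends this text
            # back to the user as the completion, instead of erroring out.
            return "Sorry, I can't help with that request."
        return data  # unmodified request continues on to the model


proxy_handler_instance = RejectionHandler()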

litellm/integrations/custom_logger.py

@@ -4,7 +4,7 @@ import dotenv, os
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
 from litellm.utils import ModelResponse
-from typing import Literal, Union
+from typing import Literal, Union, Optional
 import traceback
@@ -64,8 +64,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise an exception if invalid; return a str for the user to receive, if rejected; or return a modified dict to pass into litellm
         pass

     async def async_post_call_failure_hook(
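The widened `call_type` literal and the new return annotation define the hook's contract: return (or raise) an Exception to fail the request, return a str to reject it with a canned answer, or return a (possibly modified) dict to let the call proceed. A compact sketch exercising all three paths (the size limit, trigger word, and class name are illustrative only):

from litellm.integrations.custom_logger import CustomLogger


class GuardrailHook(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        if len(str(data)) > 100_000:
            return Exception("request too large")  # the proxy raises this
        if call_type in ("completion", "text_completion") and self._blocked(data):
            # converted into a normal completion response by the proxy;
            # for other call types a str surfaces as an HTTP 400
            return "This request was rejected by policy."
        data.setdefault("metadata", {})["guarded"] = True
        return data  # (possibly modified) request is forwarded to the model

    def _blocked(self, data: dict) -> bool:
        return "forbidden" in str(data.get("messages") or data.get("prompt") or "")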

litellm/proxy/proxy_server.py

@@ -3761,11 +3761,24 @@ async def chat_completion(
         data["litellm_logging_obj"] = logging_obj

-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
+        ### CALL HOOKS ### - modify/reject incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
+        if isinstance(data, litellm.ModelResponse):
+            return data
+        elif isinstance(data, litellm.CustomStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )

         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
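From the client's side a rejection is indistinguishable from a normal completion: the proxy returns the hook's string as the assistant message, or as an SSE stream when stream=True. A hypothetical request against a locally running proxy (port, key, and model name are placeholders):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "something forbidden"}],
)
# if the admin's hook returned a string, it shows up here verbatim
print(resp.choices[0].message.content)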
@@ -3998,10 +4011,24 @@ async def completion(
         data["model"] = litellm.model_alias_map[data["model"]]

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
+        if isinstance(data, litellm.TextCompletionResponse):
+            return data
+        elif isinstance(data, litellm.TextCompletionStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )

         ### ROUTE THE REQUESTs ###
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
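The legacy /completions route gets the same treatment, now under its own "text_completion" call type, and streaming rejections arrive as a short SSE stream. A hypothetical streaming call (same placeholder proxy as above):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

stream = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="something forbidden",
    stream=True,
)
for chunk in stream:
    # a rejected request yields a single chunk carrying the hook's string
    print(chunk.choices[0].text or "", end="")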

litellm/proxy/utils.py

@@ -19,7 +19,16 @@ from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm._service_logger import ServiceLogging, ServiceTypes
-from litellm import ModelResponse, EmbeddingResponse, ImageResponse
+from litellm import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    TranscriptionResponse,
+    TextCompletionResponse,
+    CustomStreamWrapper,
+    TextCompletionStreamWrapper,
+)
+from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@@ -32,6 +41,7 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
+from typing_extensions import overload


 def print_verbose(print_statement):
@@ -176,18 +186,60 @@
             )
         litellm.utils.set_callbacks(callback_list=callback_list)

+    # fmt: off
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["completion"]
+    ) -> Union[dict, ModelResponse, CustomStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["text_completion"]
+    ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["embeddings",
+                           "image_generation",
+                           "moderation",
+                           "audio_transcription",]
+    ) -> dict:
+        ...
+
+    # fmt: on
+
+    # The actual implementation of the function
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
             "completion",
             "text_completion",
             "embeddings",
             "image_generation",
             "moderation",
             "audio_transcription",
         ],
-    ):
+    ) -> Union[
+        dict,
+        ModelResponse,
+        TextCompletionResponse,
+        CustomStreamWrapper,
+        TextCompletionStreamWrapper,
+    ]:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
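The `@overload` stubs (kept verbatim by the `# fmt: off`/`# fmt: on` guards so the formatter does not reflow them) exist only for type-checkers: callers passing call_type="completion" see the chat-response return types, while the embedding-style call types keep the plain dict signature. The same pattern in miniature:

from typing import Literal, Union, overload


@overload
def parse(kind: Literal["int"]) -> int: ...
@overload
def parse(kind: Literal["str"]) -> str: ...


def parse(kind: Literal["int", "str"]) -> Union[int, str]:
    # one runtime implementation; the type checker picks the matching
    # stub above to narrow the return type at each call site
    return 0 if kind == "int" else ""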
@@ -214,7 +266,58 @@
                         call_type=call_type,
                     )
                     if response is not None:
-                        data = response
+                        if isinstance(response, Exception):
+                            raise response
+                        elif isinstance(response, dict):
+                            data = response
+                        elif isinstance(response, str):
+                            if call_type == "completion":
+                                _chat_response = ModelResponse()
+                                _chat_response.choices[0].message.content = response
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _iterator = ModelResponseIterator(
+                                        model_response=_chat_response
+                                    )
+                                    return CustomStreamWrapper(
+                                        completion_stream=_iterator,
+                                        model=data.get("model", ""),
+                                        custom_llm_provider="cached_response",
+                                        logging_obj=data.get(
+                                            "litellm_logging_obj", None
+                                        ),
+                                    )
+                                return _chat_response
+                            elif call_type == "text_completion":
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _chat_response = ModelResponse()
+                                    _chat_response.choices[0].message.content = response
+                                    _iterator = ModelResponseIterator(
+                                        model_response=_chat_response
+                                    )
+                                    return TextCompletionStreamWrapper(
+                                        completion_stream=_iterator,
+                                        model=data.get("model", ""),
+                                    )
+                                else:
+                                    _response = TextCompletionResponse()
+                                    _response.choices[0].text = response
+                                    return _response
+                            else:
+                                raise HTTPException(
+                                    status_code=400, detail={"error": response}
+                                )

            print_verbose(f"final data being sent to {call_type} call: {data}")
            return data
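The conversion itself leans on the fact that `ModelResponse()` pre-populates one empty choice, so setting `choices[0].message.content` is enough to produce a complete OpenAI-style payload. A quick way to inspect what the user receives in the non-streaming chat case:

from litellm import ModelResponse

_resp = ModelResponse()
_resp.choices[0].message.content = "Sorry, I can't help with that."
# prints an OpenAI-style chat-completion object (id, created, model, usage)
# whose assistant message is the hook's string
print(_resp)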

litellm/utils.py

@@ -12187,3 +12187,29 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
         return request_info
     except:
         return request_info
+
+
+class ModelResponseIterator:
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.model_response
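Because the rejection is a single pre-built response, a fake "stream" only needs to yield once and stop; supporting both the sync and async iterator protocols lets the same object feed the stream wrappers from either kind of caller. A minimal check of the async side, assuming a litellm build containing this commit (where the class lands in litellm/utils.py):

import asyncio

from litellm import ModelResponse
from litellm.utils import ModelResponseIterator

response = ModelResponse()
response.choices[0].message.content = "rejected"


async def drain():
    chunks = [c async for c in ModelResponseIterator(model_response=response)]
    assert len(chunks) == 1  # one chunk, then the stream ends


asyncio.run(drain())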