feat(proxy_server.py): allow admin to return rejected response as string to user

Closes https://github.com/BerriAI/litellm/issues/3671
Krrish Dholakia 2024-05-20 10:30:23 -07:00
parent dbaeb8ff53
commit 372323c38a
4 changed files with 175 additions and 10 deletions


@@ -4,7 +4,7 @@ import dotenv, os
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
+from litellm.utils import ModelResponse
 from typing import Literal, Union, Optional
 import traceback
@@ -64,8 +64,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
         pass

     async def async_post_call_failure_hook(
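For context, a minimal sketch (not part of this commit) of what an admin-defined handler could look like under the new contract: returning a string rejects the request and that string is sent back to the user, returning a dict forwards the (possibly modified) request, and raising an exception fails it. The class name, the banned-word check, and the `proxy_handler_instance` wiring are illustrative assumptions.

```python
# Hypothetical admin handler (e.g. a custom_callbacks.py loaded by the proxy config);
# a sketch of the new async_pre_call_hook contract, not code from this commit.
from typing import Literal, Optional, Union

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache


class _RejectBannedWordsHandler(CustomLogger):  # illustrative name
    BANNED_WORDS = {"secret_project_x"}  # illustrative list

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        # Inspect the incoming request body (chat messages or text prompt).
        text = str(data.get("messages", "")) + str(data.get("prompt", ""))
        if any(word in text for word in self.BANNED_WORDS):
            # Returning a str rejects the call; the proxy returns this string
            # to the end user as the response content.
            return "This request was rejected by the proxy admin."
        # Returning the (optionally modified) dict lets the call proceed.
        return data


proxy_handler_instance = _RejectBannedWordsHandler()
```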


@@ -3761,11 +3761,24 @@ async def chat_completion(
         data["litellm_logging_obj"] = logging_obj

-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
+        ### CALL HOOKS ### - modify/reject incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
+        if isinstance(data, litellm.ModelResponse):
+            return data
+        elif isinstance(data, litellm.CustomStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
@@ -3998,10 +4011,24 @@ async def completion(
             data["model"] = litellm.model_alias_map[data["model"]]

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
+        if isinstance(data, litellm.TextCompletionResponse):
+            return data
+        elif isinstance(data, litellm.TextCompletionStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+
         ### ROUTE THE REQUESTs ###
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
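From the caller's side, because both endpoints now short-circuit with a regular (or streamed) completion object when the hook rejects a request, the end user receives the admin's string as ordinary response content rather than an HTTP error. A hedged sketch with the OpenAI SDK pointed at a running proxy; the base URL, API key, and model name are placeholders, not values from this commit:

```python
# Illustrative client call against a locally running LiteLLM proxy.
import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "tell me about secret_project_x"}],
)

# If the admin's async_pre_call_hook returned a string, that string comes back
# here as the assistant message instead of a 400 error.
print(resp.choices[0].message.content)
```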


@@ -19,7 +19,16 @@ from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm._service_logger import ServiceLogging, ServiceTypes
-from litellm import ModelResponse, EmbeddingResponse, ImageResponse
+from litellm import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    TranscriptionResponse,
+    TextCompletionResponse,
+    CustomStreamWrapper,
+    TextCompletionStreamWrapper,
+)
+from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@@ -32,6 +41,7 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
+from typing_extensions import overload


 def print_verbose(print_statement):
@@ -176,18 +186,60 @@ class ProxyLogging:
             )
         litellm.utils.set_callbacks(callback_list=callback_list)

+    # fmt: off
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["completion"]
+    ) -> Union[dict, ModelResponse, CustomStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["text_completion"]
+    ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["embeddings",
+                           "image_generation",
+                           "moderation",
+                           "audio_transcription",]
+    ) -> dict:
+        ...
+
+    # fmt: on
+
+    # The actual implementation of the function
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
             "completion",
+            "text_completion",
             "embeddings",
             "image_generation",
             "moderation",
             "audio_transcription",
         ],
-    ):
+    ) -> Union[
+        dict,
+        ModelResponse,
+        TextCompletionResponse,
+        CustomStreamWrapper,
+        TextCompletionStreamWrapper,
+    ]:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@@ -214,7 +266,58 @@ class ProxyLogging:
                         call_type=call_type,
                     )
                     if response is not None:
-                        data = response
+                        if isinstance(response, Exception):
+                            raise response
+                        elif isinstance(response, dict):
+                            data = response
+                        elif isinstance(response, str):
+                            if call_type == "completion":
+                                _chat_response = ModelResponse()
+                                _chat_response.choices[0].message.content = response
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _iterator = ModelResponseIterator(
+                                        model_response=_chat_response
+                                    )
+                                    return CustomStreamWrapper(
+                                        completion_stream=_iterator,
+                                        model=data.get("model", ""),
+                                        custom_llm_provider="cached_response",
+                                        logging_obj=data.get(
+                                            "litellm_logging_obj", None
+                                        ),
+                                    )
+                                return _chat_response
+                            elif call_type == "text_completion":
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _chat_response = ModelResponse()
+                                    _chat_response.choices[0].message.content = response
+                                    _iterator = ModelResponseIterator(
+                                        model_response=_chat_response
+                                    )
+                                    return TextCompletionStreamWrapper(
+                                        completion_stream=_iterator,
+                                        model=data.get("model", ""),
+                                    )
+                                else:
+                                    _response = TextCompletionResponse()
+                                    _response.choices[0].text = response
+                                    return _response
+                            else:
+                                raise HTTPException(
+                                    status_code=400, detail={"error": response}
+                                )

             print_verbose(f"final data being sent to {call_type} call: {data}")
             return data
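A quick illustration of the non-streaming branches above (a standalone sketch, not part of the diff; the rejection text is a placeholder): the hook's string simply becomes the content of an otherwise ordinary response object, which the endpoint then returns directly.

```python
# Minimal sketch of the string-to-response conversion performed above.
from litellm import ModelResponse, TextCompletionResponse

rejection = "This request was rejected by the proxy admin."

# /chat/completions path: the string becomes the assistant message.
chat_resp = ModelResponse()
chat_resp.choices[0].message.content = rejection

# /completions path: the string becomes the completion text.
text_resp = TextCompletionResponse()
text_resp.choices[0].text = rejection

print(chat_resp.choices[0].message.content)  # -> the rejection string
print(text_resp.choices[0].text)             # -> the rejection string
```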


@@ -12187,3 +12187,29 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
         return request_info
     except:
         return request_info
+
+
+class ModelResponseIterator:
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.model_response
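For reference, a small usage sketch of the helper class above (the response text is a placeholder): it yields its wrapped ModelResponse exactly once, in both sync and async iteration, which is what lets the stream wrappers above emit the rejection as a single streamed chunk.

```python
# Illustrative use of ModelResponseIterator added in this commit.
import asyncio

from litellm import ModelResponse
from litellm.utils import ModelResponseIterator

response = ModelResponse()
response.choices[0].message.content = "This request was rejected by the proxy admin."

# Sync: yields the wrapped response once, then stops.
for chunk in ModelResponseIterator(model_response=response):
    print(chunk.choices[0].message.content)


async def consume() -> None:
    # Async: same single-item behaviour via __anext__.
    async for chunk in ModelResponseIterator(model_response=response):
        print(chunk.choices[0].message.content)


asyncio.run(consume())
```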