Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

feat(proxy_server.py): allow admin to return rejected response as string to user

Closes https://github.com/BerriAI/litellm/issues/3671

parent dbaeb8ff53
commit 372323c38a

4 changed files with 175 additions and 10 deletions
litellm/integrations/custom_logger.py

@@ -4,7 +4,7 @@ import dotenv, os
 
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-
+from litellm.utils import ModelResponse
 from typing import Literal, Union, Optional
 import traceback
 
@@ -64,8 +64,17 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
-        call_type: Literal["completion", "embeddings", "image_generation"],
-    ):
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[
+        Union[Exception, str, dict]
+    ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
         pass
 
     async def async_post_call_failure_hook(
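With the widened hook signature, a proxy admin can reject a request from a custom handler simply by returning a string; the proxy sends that string back to the caller instead of forwarding the call to the model, returning a (possibly modified) dict lets the call proceed, and raising an exception fails it. A minimal sketch of such a handler; the class name and the banned-word rule are illustrative, not part of this commit:

from typing import Literal, Optional, Union

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class _RejectionGuard(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        # Returning a str rejects the request; the proxy wraps the string in a
        # chat / text completion response for the end user.
        messages = data.get("messages") or []
        if any("forbidden" in str(m.get("content", "")) for m in messages):
            return "Sorry, this request was blocked by the proxy admin."
        # Returning the (possibly modified) dict lets the request continue.
        return data


proxy_handler_instance = _RejectionGuard()

The instance would be wired in like any other custom callback, e.g. callbacks: custom_callbacks.proxy_handler_instance under litellm_settings in the proxy config file.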
litellm/proxy/proxy_server.py

@@ -3761,11 +3761,24 @@ async def chat_completion(
 
         data["litellm_logging_obj"] = logging_obj
 
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
+        ### CALL HOOKS ### - modify/reject incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
             user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
         )
 
+        if isinstance(data, litellm.ModelResponse):
+            return data
+        elif isinstance(data, litellm.CustomStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
         tasks = []
         tasks.append(
             proxy_logging_obj.during_call_hook(
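When the hook rejects a non-streaming chat request, the short-circuit above returns the admin's string wrapped in a normal chat completion, so existing clients need no changes. A usage sketch against a locally running proxy; the base URL, API key, and model name are assumptions, not part of this commit:

import openai

# Assumes a LiteLLM proxy running locally with a rejecting pre-call hook installed.
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "please do something forbidden"}],
)

# For a rejected request, the admin's string arrives as the assistant message.
print(resp.choices[0].message.content)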
@@ -3998,10 +4011,24 @@ async def completion(
             data["model"] = litellm.model_alias_map[data["model"]]
 
         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        data = await proxy_logging_obj.pre_call_hook(  # type: ignore
+            user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
 
+        if isinstance(data, litellm.TextCompletionResponse):
+            return data
+        elif isinstance(data, litellm.TextCompletionStreamWrapper):
+            selected_data_generator = select_data_generator(
+                response=data,
+                user_api_key_dict=user_api_key_dict,
+                request_data={},
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+            )
+
         ### ROUTE THE REQUESTs ###
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
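The legacy /completions route gets the same treatment with call_type="text_completion"; a rejected request comes back as a TextCompletionResponse whose first choice carries the admin's string. An illustrative sketch, same assumed proxy as above:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

resp = client.completions.create(
    model="gpt-3.5-turbo-instruct",
    prompt="please do something forbidden",
)

# For a rejected request, the admin's string is returned in choices[0].text.
print(resp.choices[0].text)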
litellm/proxy/utils.py

@@ -19,7 +19,16 @@ from litellm.proxy.hooks.parallel_request_limiter import (
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm._service_logger import ServiceLogging, ServiceTypes
-from litellm import ModelResponse, EmbeddingResponse, ImageResponse
+from litellm import (
+    ModelResponse,
+    EmbeddingResponse,
+    ImageResponse,
+    TranscriptionResponse,
+    TextCompletionResponse,
+    CustomStreamWrapper,
+    TextCompletionStreamWrapper,
+)
+from litellm.utils import ModelResponseIterator
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.tpm_rpm_limiter import _PROXY_MaxTPMRPMLimiter
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
@@ -32,6 +41,7 @@ from email.mime.text import MIMEText
 from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
+from typing_extensions import overload
 
 
 def print_verbose(print_statement):
@@ -176,18 +186,60 @@ class ProxyLogging:
             )
         litellm.utils.set_callbacks(callback_list=callback_list)
 
+    # fmt: off
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["completion"]
+    ) -> Union[dict, ModelResponse, CustomStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["text_completion"]
+    ) -> Union[dict, TextCompletionResponse, TextCompletionStreamWrapper]:
+        ...
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal["embeddings",
+                           "image_generation",
+                           "moderation",
+                           "audio_transcription",]
+    ) -> dict:
+        ...
+
+    # fmt: on
+
+    # The actual implementation of the function
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
         data: dict,
         call_type: Literal[
             "completion",
+            "text_completion",
             "embeddings",
             "image_generation",
             "moderation",
             "audio_transcription",
         ],
-    ):
+    ) -> Union[
+        dict,
+        ModelResponse,
+        TextCompletionResponse,
+        CustomStreamWrapper,
+        TextCompletionStreamWrapper,
+    ]:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
 
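The @overload stubs above exist so type checkers can narrow pre_call_hook's return type from the call_type literal that is passed in, which is what lets the endpoints above branch on isinstance without spurious type errors. A hypothetical type-checking sketch (run it under mypy or pyright; nothing in it is part of the commit):

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging


async def _narrowing_demo(logging_obj: ProxyLogging, key: UserAPIKeyAuth) -> None:
    # First overload: the result may be a dict, a ModelResponse, or a CustomStreamWrapper.
    chat_data = await logging_obj.pre_call_hook(
        user_api_key_dict=key, data={"model": "gpt-3.5-turbo"}, call_type="completion"
    )

    # Third overload: always a dict, so no isinstance checks are needed downstream.
    embedding_data = await logging_obj.pre_call_hook(
        user_api_key_dict=key, data={"model": "text-embedding-ada-002"}, call_type="embeddings"
    )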
@@ -214,7 +266,58 @@ class ProxyLogging:
                         call_type=call_type,
                     )
                     if response is not None:
-                        data = response
+                        if isinstance(response, Exception):
+                            raise response
+                        elif isinstance(response, dict):
+                            data = response
+                        elif isinstance(response, str):
+                            if call_type == "completion":
+                                _chat_response = ModelResponse()
+                                _chat_response.choices[0].message.content = response
+
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _iterator = ModelResponseIterator(
+                                        model_response=_chat_response
+                                    )
+                                    return CustomStreamWrapper(
+                                        completion_stream=_iterator,
+                                        model=data.get("model", ""),
+                                        custom_llm_provider="cached_response",
+                                        logging_obj=data.get(
+                                            "litellm_logging_obj", None
+                                        ),
+                                    )
+                                return _chat_response
+                            elif call_type == "text_completion":
+                                if (
+                                    data.get("stream", None) is not None
+                                    and data["stream"] == True
+                                ):
+                                    _chat_response = ModelResponse()
+                                    _chat_response.choices[0].message.content = response
+
+                                    if (
+                                        data.get("stream", None) is not None
+                                        and data["stream"] == True
+                                    ):
+                                        _iterator = ModelResponseIterator(
+                                            model_response=_chat_response
+                                        )
+                                        return TextCompletionStreamWrapper(
+                                            completion_stream=_iterator,
+                                            model=data.get("model", ""),
+                                        )
+                                else:
+                                    _response = TextCompletionResponse()
+                                    _response.choices[0].text = response
+                                    return _response
+                            else:
+                                raise HTTPException(
+                                    status_code=400, detail={"error": response}
+                                )
 
             print_verbose(f"final data being sent to {call_type} call: {data}")
             return data
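When the rejected request asked for streaming, the branch above replays the string as a single-chunk stream: the full ModelResponse is wrapped in ModelResponseIterator and handed to CustomStreamWrapper with custom_llm_provider="cached_response". A hedged client-side sketch; the proxy URL, key, model, and the exact chunk shape seen by the client are assumptions of this sketch:

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "please do something forbidden"}],
    stream=True,
)

# The rejection is expected to arrive as a short SSE stream carrying the admin's string.
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
print()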
litellm/utils.py

@@ -12187,3 +12187,29 @@ def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
         return request_info
     except:
         return request_info
+
+
+class ModelResponseIterator:
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self.model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self.model_response
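ModelResponseIterator simply yields a pre-built response exactly once, under both sync and async iteration, which is what lets a rejection string masquerade as a one-chunk stream. A standalone usage sketch (illustrative only; the rejection text is made up):

import asyncio

from litellm import ModelResponse
from litellm.utils import ModelResponseIterator

rejected = ModelResponse()
rejected.choices[0].message.content = "Sorry, this request was blocked by the proxy admin."

# Sync iteration: the wrapped response is yielded once, then StopIteration ends the loop.
for chunk in ModelResponseIterator(model_response=rejected):
    print(chunk.choices[0].message.content)


# Async iteration: identical single-shot behavior for async consumers such as the stream wrappers.
async def consume() -> None:
    async for chunk in ModelResponseIterator(model_response=rejected):
        print(chunk.choices[0].message.content)


asyncio.run(consume())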