Merge pull request #3740 from BerriAI/litellm_return_rejected_response

feat(proxy_server.py): allow admin to return rejected response as string to user
Krish Dholakia 2024-05-20 17:48:21 -07:00 committed by GitHub
commit c6bb6e325b
9 changed files with 338 additions and 28 deletions
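Context for the diff below: the proxy's pre-call hook can now raise RejectedRequestError, and instead of surfacing an HTTP error the proxy wraps the rejection message in a normal (streaming or non-streaming) completion response. A minimal sketch of a hook that would trigger this path, assuming a CustomLogger-style async_pre_call_hook; the RejectedRequestError constructor arguments shown are an assumption, since the handlers in this diff only read e.message and e.request_data:

# Hypothetical guardrail hook sketch -- constructor kwargs for
# RejectedRequestError are assumed; the proxy handlers below only
# read .message and .request_data.
from litellm.integrations.custom_logger import CustomLogger
from litellm.exceptions import RejectedRequestError


class BannedWordGuardrail(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        # Inspect the last user message before the request reaches the model.
        last_message = str(data.get("messages", [{}])[-1].get("content", ""))
        if "forbidden" in last_message.lower():
            # Raising here makes the proxy return this string to the caller
            # as the completion content instead of an error response.
            raise RejectedRequestError(
                message="This request was blocked by the admin guardrail.",
                model=data.get("model", ""),
                llm_provider="openai",
                request_data=data,
            )
        return data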


@@ -124,6 +124,7 @@ from litellm.proxy.auth.auth_checks import (
get_actual_routes,
)
from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
from litellm.exceptions import RejectedRequestError
try:
from litellm._version import version
@@ -3763,8 +3764,8 @@ async def chat_completion(
data["litellm_logging_obj"] = logging_obj
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
### CALL HOOKS ### - modify/reject incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
)
@@ -3879,6 +3880,40 @@ async def chat_completion(
)
return response
except RejectedRequestError as e:
_data = e.request_data
_data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
_chat_response = litellm.ModelResponse()
_chat_response.choices[0].message.content = e.message # type: ignore
if data.get("stream", None) is not None and data["stream"] == True:
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.CustomStreamWrapper(
completion_stream=_iterator,
model=data.get("model", ""),
custom_llm_provider="cached_response",
logging_obj=data.get("litellm_logging_obj", None),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=_data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
)
_usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
_chat_response.usage = _usage # type: ignore
return _chat_response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
traceback.print_exc()
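A hedged usage sketch of the new behavior on the /chat/completions route: with a rejecting hook configured (as in the sketch above), a standard OpenAI-client call to the proxy receives the rejection string as the assistant message rather than an HTTP error. The base URL, key, and prompt here are placeholders.

# Hypothetical usage sketch: assumes a LiteLLM proxy on localhost:4000 with a
# pre-call hook that raises RejectedRequestError for this prompt.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder key

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "some input the hook rejects"}],
)
# Instead of an error, the proxy returns a normal completion whose content is
# the rejection message set by the hook (usage is reported as zero tokens).
print(resp.choices[0].message.content)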
@@ -3993,8 +4028,8 @@ async def completion(
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
@@ -4077,6 +4112,46 @@ async def completion(
)
return response
except RejectedRequestError as e:
_data = e.request_data
_data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
if _data.get("stream", None) is not None and _data["stream"] == True:
_chat_response = litellm.ModelResponse()
_usage = litellm.Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
_chat_response.usage = _usage # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.TextCompletionStreamWrapper(
completion_stream=_iterator,
model=_data.get("model", ""),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers={},
)
else:
_response = litellm.TextCompletionResponse()
_response.choices[0].text = e.message
return _response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(