fix(proxy_server.py): make rejected responses work with streaming

Krrish Dholakia 2024-05-20 12:32:19 -07:00
parent f11f207ae6
commit b41f30ca60
4 changed files with 34 additions and 22 deletions
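Before this change, a rejected request with stream=True was mishandled in both endpoints: chat_completion passed the raw e.message into select_data_generator, and completion returned an unconsumed TextCompletionStreamWrapper instead of an event stream. The diff below converts the synthesized rejection response into delta chunks (convert_to_delta=True) and routes both paths through select_data_generator into a StreamingResponse. A minimal sketch of the resulting flow, with select_data_generator assumed in scope (it is a helper defined in proxy_server.py) and the surrounding rejection handling simplified:

    from fastapi.responses import StreamingResponse
    import litellm

    async def stream_rejected_chat_response(_chat_response, data, _data, user_api_key_dict):
        # Emit the synthesized ModelResponse as delta chunks rather than one full message.
        _iterator = litellm.utils.ModelResponseIterator(
            model_response=_chat_response, convert_to_delta=True
        )
        # Make the iterator look like a normal provider stream; the model and
        # provider arguments here are assumptions (the diff elides them).
        _streaming_response = litellm.CustomStreamWrapper(
            completion_stream=_iterator,
            model=_data.get("model", ""),
            custom_llm_provider="openai",
            logging_obj=data.get("litellm_logging_obj", None),
        )
        # Stream the wrapper, not e.message, through the proxy's SSE generator
        # (select_data_generator is defined in litellm/proxy/proxy_server.py).
        selected_data_generator = select_data_generator(
            response=_streaming_response,
            user_api_key_dict=user_api_key_dict,
            request_data=_data,
        )
        return StreamingResponse(
            selected_data_generator,
            media_type="text/event-stream",
        )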


@@ -3894,7 +3894,7 @@ async def chat_completion(
         if data.get("stream", None) is not None and data["stream"] == True:
             _iterator = litellm.utils.ModelResponseIterator(
-                model_response=_chat_response
+                model_response=_chat_response, convert_to_delta=True
             )
             _streaming_response = litellm.CustomStreamWrapper(
                 completion_stream=_iterator,
@@ -3903,7 +3903,7 @@ async def chat_completion(
                 logging_obj=data.get("litellm_logging_obj", None),
             )
             selected_data_generator = select_data_generator(
-                response=e.message,
+                response=_streaming_response,
                 user_api_key_dict=user_api_key_dict,
                 request_data=_data,
             )
@@ -4037,20 +4037,6 @@ async def completion(
             user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
         )
 
-        if isinstance(data, litellm.TextCompletionResponse):
-            return data
-        elif isinstance(data, litellm.TextCompletionStreamWrapper):
-            selected_data_generator = select_data_generator(
-                response=data,
-                user_api_key_dict=user_api_key_dict,
-                request_data={},
-            )
-            return StreamingResponse(
-                selected_data_generator,
-                media_type="text/event-stream",
-            )
-
         ### ROUTE THE REQUESTs ###
         router_model_names = llm_router.model_names if llm_router is not None else []
         # skip router if user passed their key
@@ -4152,12 +4138,24 @@ async def completion(
             _chat_response.usage = _usage  # type: ignore
             _chat_response.choices[0].message.content = e.message  # type: ignore
             _iterator = litellm.utils.ModelResponseIterator(
-                model_response=_chat_response
+                model_response=_chat_response, convert_to_delta=True
             )
-            return litellm.TextCompletionStreamWrapper(
+            _streaming_response = litellm.TextCompletionStreamWrapper(
                 completion_stream=_iterator,
                 model=_data.get("model", ""),
             )
+
+            selected_data_generator = select_data_generator(
+                response=_streaming_response,
+                user_api_key_dict=user_api_key_dict,
+                request_data=data,
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+                headers={},
+            )
         else:
             _response = litellm.TextCompletionResponse()
             _response.choices[0].text = e.message
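With both paths wired through select_data_generator, a streaming client should receive the rejection text as ordinary delta chunks instead of an error. A hypothetical client-side check (the proxy URL, API key, and model name below are placeholders):

    import openai

    # Point the OpenAI SDK at the litellm proxy; key/port/model are placeholders.
    client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "input that a pre-call hook rejects"}],
        stream=True,
    )
    for chunk in stream:
        # Each chunk's delta should carry a piece of the rejection message.
        print(chunk.choices[0].delta.content or "", end="")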