fix(main.py): safely fail stream_chunk_builder calls

Krrish Dholakia 2024-08-10 10:22:26 -07:00
parent 6ff21433da
commit 3fd02a1587
3 changed files with 259 additions and 231 deletions

litellm/main.py

@@ -5005,8 +5005,12 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
 def stream_chunk_builder(
     chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
-) -> Union[ModelResponse, TextCompletionResponse]:
-    model_response = litellm.ModelResponse()
+) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
+    try:
+        model_response = litellm.ModelResponse()
+        ### BASE-CASE ###
+        if len(chunks) == 0:
+            return None
         ### SORT CHUNKS BASED ON CREATED ORDER ##
         print_verbose("Goes into checking if chunk has hiddden created at param")
         if chunks[0]._hidden_params.get("created_at", None):
@@ -5029,7 +5033,9 @@ def stream_chunk_builder(
         if isinstance(
             chunks[0]["choices"][0], litellm.utils.TextChoices
         ):  # route to the text completion logic
-            return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
+            return stream_chunk_builder_text_completion(
+                chunks=chunks, messages=messages
+            )

         role = chunks[0]["choices"][0]["delta"]["role"]
         finish_reason = chunks[-1]["choices"][0]["finish_reason"]
@@ -5232,4 +5238,16 @@ def stream_chunk_builder(
             model_response_object=model_response,
             start_time=start_time,
             end_time=end_time,
-        )
+        )  # type: ignore
+    except Exception as e:
+        verbose_logger.error(
+            "litellm.main.py::stream_chunk_builder() - Exception occurred - {}\n{}".format(
+                str(e), traceback.format_exc()
+            )
+        )
+        raise litellm.APIError(
+            status_code=500,
+            message="Error building chunks for logging/streaming usage calculation",
+            llm_provider="",
+            model="",
+        )
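Taken together, stream_chunk_builder() now fails in two explicit ways instead of leaking an arbitrary traceback: an empty chunk list returns None, and any internal error is re-raised as a single litellm.APIError. A minimal sketch of the new contract, assuming a build that includes this commit (the calls mirror this diff and the test below):

    import litellm
    from litellm import stream_chunk_builder

    # New base case: nothing to assemble -> None instead of a crash
    assert stream_chunk_builder(chunks=[]) is None

    # Internal failures (here, len(None) raising TypeError inside the
    # builder) now surface as one catchable error type
    try:
        stream_chunk_builder(chunks=None)
    except litellm.APIError as e:
        print(e)  # Error building chunks for logging/streaming usage calculation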

litellm/tests/test_stream_chunk_builder.py

@@ -16,9 +16,8 @@ import pytest
 from openai import OpenAI

 import litellm
-from litellm import completion, stream_chunk_builder
 import litellm.tests.stream_chunk_testdata
+from litellm import completion, stream_chunk_builder

 dotenv.load_dotenv()
@@ -219,3 +218,11 @@ def test_stream_chunk_builder_litellm_mixed_calls():
             "id": "toolu_01H3AjkLpRtGQrof13CBnWfK",
             "type": "function",
         }
+
+
+def test_stream_chunk_builder_litellm_empty_chunks():
+    with pytest.raises(litellm.APIError):
+        response = stream_chunk_builder(chunks=None)
+
+    response = stream_chunk_builder(chunks=[])
+    assert response is None
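The new test can be run in isolation with pytest's keyword filter, assuming the file path shown above is correct for this checkout:

    pytest litellm/tests/test_stream_chunk_builder.py -k empty_chunks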

litellm/utils.py

@@ -10307,7 +10307,8 @@ class CustomStreamWrapper:
                         chunks=self.chunks, messages=self.messages
                     )
                     response = self.model_response_creator()
-                    response.usage = complete_streaming_response.usage  # type: ignore
+                    if complete_streaming_response is not None:
+                        response.usage = complete_streaming_response.usage
                     response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
                     ## LOGGING
                     threading.Thread(
@@ -10504,7 +10505,8 @@ class CustomStreamWrapper:
                         chunks=self.chunks, messages=self.messages
                     )
                     response = self.model_response_creator()
-                    response.usage = complete_streaming_response.usage
+                    if complete_streaming_response is not None:
+                        setattr(response, "usage", complete_streaming_response.usage)
                     ## LOGGING
                     threading.Thread(
                         target=self.logging_obj.success_handler,
@@ -10544,6 +10546,7 @@ class CustomStreamWrapper:
                         chunks=self.chunks, messages=self.messages
                     )
                     response = self.model_response_creator()
+                    if complete_streaming_response is not None:
                         response.usage = complete_streaming_response.usage
                     ## LOGGING
                     threading.Thread(
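These three guards protect the same call-site pattern: CustomStreamWrapper rebuilds a complete response from the collected chunks to attach usage, and the builder may now return None. A hedged sketch of the equivalent pattern in user code (the model name and prompt are placeholders; a valid API key for the chosen provider is required):

    import litellm

    messages = [{"role": "user", "content": "hi"}]
    chunks = []
    for chunk in litellm.completion(
        model="gpt-3.5-turbo", messages=messages, stream=True
    ):
        chunks.append(chunk)

    # The builder may return None (e.g. no chunks collected), so guard
    # before reading usage - the same check CustomStreamWrapper now does.
    complete = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
    if complete is not None:
        print(complete.usage)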