fix(main.py): safely fail stream_chunk_builder calls

Krrish Dholakia 2024-08-10 10:22:26 -07:00
parent dd2ea72cb4
commit 068ee12c30
3 changed files with 259 additions and 231 deletions
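In effect, the builder now degrades gracefully instead of surfacing raw exceptions to streaming callers: an empty chunk list is treated as a base case that returns None, and any other failure while rebuilding the response is re-raised as litellm.APIError. A minimal behavioral sketch, not part of the diff itself; it assumes litellm is importable and mirrors the new test added below:

import litellm
from litellm import stream_chunk_builder

# Base case added by this commit: an empty chunk list returns None
# instead of failing on chunks[0].
assert stream_chunk_builder(chunks=[]) is None

# Any other failure while rebuilding the response (here chunks=None, so
# len(chunks) raises) is re-raised as litellm.APIError with status_code=500.
try:
    stream_chunk_builder(chunks=None)
except litellm.APIError as err:
    print(err)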


@@ -5005,231 +5005,249 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
) -> Union[ModelResponse, TextCompletionResponse]:
model_response = litellm.ModelResponse()
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
if chunks[0]._hidden_params.get("created_at", None):
print_verbose("Chunks have a created at hidden param")
# Sort chunks based on created_at in ascending order
chunks = sorted(
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
)
print_verbose("Chunks sorted")
# set hidden params from chunk to model_response
if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params = chunks[0].get("_hidden_params", {})
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
model = chunks[0]["model"]
system_fingerprint = chunks[0].get("system_fingerprint", None)
if isinstance(
chunks[0]["choices"][0], litellm.utils.TextChoices
): # route to the text completion logic
return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
role = chunks[0]["choices"][0]["delta"]["role"]
finish_reason = chunks[-1]["choices"][0]["finish_reason"]
# Initialize the response dictionary
response = {
"id": id,
"object": object,
"created": created,
"model": model,
"system_fingerprint": system_fingerprint,
"choices": [
{
"index": 0,
"message": {"role": role, "content": ""},
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0, # Modify as needed
"completion_tokens": 0, # Modify as needed
"total_tokens": 0, # Modify as needed
},
}
# Extract the "content" strings from the nested dictionaries within "choices"
content_list = []
combined_content = ""
combined_arguments = ""
tool_call_chunks = [
chunk
for chunk in chunks
if "tool_calls" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["tool_calls"] is not None
]
if len(tool_call_chunks) > 0:
argument_list = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
message = response["choices"][0]["message"]
message["tool_calls"] = []
id = None
name = None
type = None
tool_calls_list = []
prev_index = None
prev_id = None
curr_id = None
curr_index = 0
for chunk in tool_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
tool_calls = delta.get("tool_calls", "")
# Check if a tool call is present
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
argument_list.append(arguments)
if tool_calls[0].function.name:
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if prev_index is None:
prev_index = curr_index
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls_list.append(
{
"id": prev_id,
"index": prev_index,
"function": {"arguments": combined_arguments, "name": name},
"type": type,
}
)
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict
tool_calls_list.append(
{
"id": id,
"index": curr_index,
"function": {"arguments": combined_arguments, "name": name},
"type": type,
}
)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["tool_calls"] = tool_calls_list
function_call_chunks = [
chunk
for chunk in chunks
if "function_call" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["function_call"] is not None
]
if len(function_call_chunks) > 0:
argument_list = []
delta = function_call_chunks[0]["choices"][0]["delta"]
function_call = delta.get("function_call", "")
function_call_name = function_call.name
message = response["choices"][0]["message"]
message["function_call"] = {}
message["function_call"]["name"] = function_call_name
for chunk in function_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
function_call = delta.get("function_call", "")
# Check if a function call is present
if function_call:
# Now, function_call is expected to be a dictionary
arguments = function_call.arguments
argument_list.append(arguments)
combined_arguments = "".join(argument_list)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["function_call"][
"arguments"
] = combined_arguments
content_chunks = [
chunk
for chunk in chunks
if "content" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["content"] is not None
]
if len(content_chunks) > 0:
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
content = delta.get("content", "")
if content == None:
continue # openai v1.0.0 sets content = None for chunks
content_list.append(content)
# Combine the "content" strings into a single string || combine the 'function' strings into a single string
combined_content = "".join(content_list)
# Update the "content" field within the response dictionary
response["choices"][0]["message"]["content"] = combined_content
completion_output = ""
if len(combined_content) > 0:
completion_output += combined_content
if len(combined_arguments) > 0:
completion_output += combined_arguments
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
usage_chunk = chunk.usage
elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
usage_chunk = chunk._hidden_params["usage"]
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
)
except (
Exception
):  # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
model=model,
text=completion_output,
count_response_tokens=True,  # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
)
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
)
return convert_to_model_response_object(
response_object=response,
model_response_object=model_response,
start_time=start_time,
end_time=end_time,
)

def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
try:
model_response = litellm.ModelResponse()
### BASE-CASE ###
if len(chunks) == 0:
return None
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
if chunks[0]._hidden_params.get("created_at", None):
print_verbose("Chunks have a created at hidden param")
# Sort chunks based on created_at in ascending order
chunks = sorted(
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
)
print_verbose("Chunks sorted")
# set hidden params from chunk to model_response
if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params = chunks[0].get("_hidden_params", {})
id = chunks[0]["id"]
object = chunks[0]["object"]
created = chunks[0]["created"]
model = chunks[0]["model"]
system_fingerprint = chunks[0].get("system_fingerprint", None)
if isinstance(
chunks[0]["choices"][0], litellm.utils.TextChoices
): # route to the text completion logic
return stream_chunk_builder_text_completion(
chunks=chunks, messages=messages
)
role = chunks[0]["choices"][0]["delta"]["role"]
finish_reason = chunks[-1]["choices"][0]["finish_reason"]
# Initialize the response dictionary
response = {
"id": id,
"object": object,
"created": created,
"model": model,
"system_fingerprint": system_fingerprint,
"choices": [
{
"index": 0,
"message": {"role": role, "content": ""},
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": 0, # Modify as needed
"completion_tokens": 0, # Modify as needed
"total_tokens": 0, # Modify as needed
},
}
# Extract the "content" strings from the nested dictionaries within "choices"
content_list = []
combined_content = ""
combined_arguments = ""
tool_call_chunks = [
chunk
for chunk in chunks
if "tool_calls" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["tool_calls"] is not None
]
if len(tool_call_chunks) > 0:
argument_list = []
delta = tool_call_chunks[0]["choices"][0]["delta"]
message = response["choices"][0]["message"]
message["tool_calls"] = []
id = None
name = None
type = None
tool_calls_list = []
prev_index = None
prev_id = None
curr_id = None
curr_index = 0
for chunk in tool_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
tool_calls = delta.get("tool_calls", "")
# Check if a tool call is present
if tool_calls and tool_calls[0].function is not None:
if tool_calls[0].id:
id = tool_calls[0].id
curr_id = id
if prev_id is None:
prev_id = curr_id
if tool_calls[0].index:
curr_index = tool_calls[0].index
if tool_calls[0].function.arguments:
# Now, tool_calls is expected to be a dictionary
arguments = tool_calls[0].function.arguments
argument_list.append(arguments)
if tool_calls[0].function.name:
name = tool_calls[0].function.name
if tool_calls[0].type:
type = tool_calls[0].type
if prev_index is None:
prev_index = curr_index
if curr_index != prev_index: # new tool call
combined_arguments = "".join(argument_list)
tool_calls_list.append(
{
"id": prev_id,
"index": prev_index,
"function": {"arguments": combined_arguments, "name": name},
"type": type,
}
)
argument_list = [] # reset
prev_index = curr_index
prev_id = curr_id
combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict
tool_calls_list.append(
{
"id": id,
"index": curr_index,
"function": {"arguments": combined_arguments, "name": name},
"type": type,
}
)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["tool_calls"] = tool_calls_list
function_call_chunks = [
chunk
for chunk in chunks
if "function_call" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["function_call"] is not None
]
if len(function_call_chunks) > 0:
argument_list = []
delta = function_call_chunks[0]["choices"][0]["delta"]
function_call = delta.get("function_call", "")
function_call_name = function_call.name
message = response["choices"][0]["message"]
message["function_call"] = {}
message["function_call"]["name"] = function_call_name
for chunk in function_call_chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
function_call = delta.get("function_call", "")
# Check if a function call is present
if function_call:
# Now, function_call is expected to be a dictionary
arguments = function_call.arguments
argument_list.append(arguments)
combined_arguments = "".join(argument_list)
response["choices"][0]["message"]["content"] = None
response["choices"][0]["message"]["function_call"][
"arguments"
] = combined_arguments
content_chunks = [
chunk
for chunk in chunks
if "content" in chunk["choices"][0]["delta"]
and chunk["choices"][0]["delta"]["content"] is not None
]
if len(content_chunks) > 0:
for chunk in chunks:
choices = chunk["choices"]
for choice in choices:
delta = choice.get("delta", {})
content = delta.get("content", "")
if content == None:
continue # openai v1.0.0 sets content = None for chunks
content_list.append(content)
# Combine the "content" strings into a single string || combine the 'function' strings into a single string
combined_content = "".join(content_list)
# Update the "content" field within the response dictionary
response["choices"][0]["message"]["content"] = combined_content
completion_output = ""
if len(combined_content) > 0:
completion_output += combined_content
if len(combined_arguments) > 0:
completion_output += combined_arguments
# # Update usage information if needed
prompt_tokens = 0
completion_tokens = 0
for chunk in chunks:
usage_chunk: Optional[Usage] = None
if "usage" in chunk:
usage_chunk = chunk.usage
elif hasattr(chunk, "_hidden_params") and "usage" in chunk._hidden_params:
usage_chunk = chunk._hidden_params["usage"]
if usage_chunk is not None:
if "prompt_tokens" in usage_chunk:
prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0
if "completion_tokens" in usage_chunk:
completion_tokens = usage_chunk.get("completion_tokens", 0) or 0
try:
response["usage"]["prompt_tokens"] = prompt_tokens or token_counter(
model=model, messages=messages
)
except (
Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = completion_tokens or token_counter(
model=model,
text=completion_output,
count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages
)
response["usage"]["total_tokens"] = (
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
)
return convert_to_model_response_object(
response_object=response,
model_response_object=model_response,
start_time=start_time,
end_time=end_time,
) # type: ignore
except Exception as e:
verbose_logger.error(
"litellm.main.py::stream_chunk_builder() - Exception occurred - {}\n{}".format(
str(e), traceback.format_exc()
)
)
raise litellm.APIError(
status_code=500,
message="Error building chunks for logging/streaming usage calculation",
llm_provider="",
model="",
)
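
For context, a typical call site looks roughly like the sketch below. This is illustrative only and not taken from the diff: the model name and prompt are placeholders, and a configured API key for the chosen provider is assumed. Because the return type is now Optional, the caller guards before reading fields:

from litellm import completion, stream_chunk_builder

messages = [{"role": "user", "content": "Write a haiku about streaming."}]
stream = completion(model="gpt-3.5-turbo", messages=messages, stream=True)

# Collect the streamed chunks, then rebuild a single non-streaming response.
chunks = [chunk for chunk in stream]
rebuilt = stream_chunk_builder(chunks=chunks, messages=messages)

# stream_chunk_builder may now return None, so guard before using the result.
if rebuilt is not None:
    print(rebuilt.choices[0].message.content)
    print(rebuilt.usage)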


@@ -16,9 +16,8 @@ import pytest
from openai import OpenAI

import litellm
-from litellm import completion, stream_chunk_builder
import litellm.tests.stream_chunk_testdata
+from litellm import completion, stream_chunk_builder

dotenv.load_dotenv()
@@ -219,3 +218,11 @@ def test_stream_chunk_builder_litellm_mixed_calls():
"id": "toolu_01H3AjkLpRtGQrof13CBnWfK",
"type": "function",
}
+
+
+def test_stream_chunk_builder_litellm_empty_chunks():
+    with pytest.raises(litellm.APIError):
+        response = stream_chunk_builder(chunks=None)
+
+    response = stream_chunk_builder(chunks=[])
+    assert response is None


@@ -10307,7 +10307,8 @@ class CustomStreamWrapper:
chunks=self.chunks, messages=self.messages
)
response = self.model_response_creator()
-response.usage = complete_streaming_response.usage  # type: ignore
-response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
+if complete_streaming_response is not None:
+    response.usage = complete_streaming_response.usage
+    response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
## LOGGING
threading.Thread(
@@ -10504,7 +10505,8 @@ class CustomStreamWrapper:
chunks=self.chunks, messages=self.messages
)
response = self.model_response_creator()
-response.usage = complete_streaming_response.usage
+if complete_streaming_response is not None:
+    setattr(response, "usage", complete_streaming_response.usage)
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
@@ -10544,7 +10546,8 @@ class CustomStreamWrapper:
chunks=self.chunks, messages=self.messages
)
response = self.model_response_creator()
-response.usage = complete_streaming_response.usage
+if complete_streaming_response is not None:
+    response.usage = complete_streaming_response.usage
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
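
All three CustomStreamWrapper call sites converge on the same defensive pattern: rebuild the complete response, then copy usage over only if the rebuild actually produced one. A rough standalone sketch of that pattern; attach_usage is a hypothetical helper, not a function in litellm, and ModelResponse is assumed to be importable from the package top level:

from litellm import ModelResponse, stream_chunk_builder


def attach_usage(response: ModelResponse, chunks: list, messages: list) -> ModelResponse:
    # stream_chunk_builder may now return None (e.g. for an empty chunk list),
    # so the usage copy is guarded instead of assuming a full response.
    complete_streaming_response = stream_chunk_builder(chunks=chunks, messages=messages)
    if complete_streaming_response is not None:
        setattr(response, "usage", complete_streaming_response.usage)
        response._hidden_params["usage"] = complete_streaming_response.usage
    return response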