fix(anthropic.py): handle scenario where anthropic returns invalid json string for tool call while streaming

Fixes https://github.com/BerriAI/litellm/issues/5063
This commit is contained in:
Krrish Dholakia 2024-08-07 09:24:11 -07:00
parent 1008f24b16
commit 4919cc4d25
5 changed files with 105 additions and 7 deletions

View file

@ -2,6 +2,7 @@ import copy
import json
import os
import time
import traceback
import types
from enum import Enum
from functools import partial
@ -36,6 +37,7 @@ from litellm.types.llms.anthropic import (
AnthropicResponseUsageBlock,
ContentBlockDelta,
ContentBlockStart,
ContentBlockStop,
ContentJsonBlockDelta,
ContentTextBlockDelta,
MessageBlockDelta,
@ -920,7 +922,12 @@ class AnthropicChatCompletion(BaseLLM):
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise AnthropicError(status_code=400, message=str(e))
raise AnthropicError(
status_code=400,
message="{}\n{}\nReceived Messages={}".format(
str(e), traceback.format_exc(), messages
),
)
## Load Config
config = litellm.AnthropicConfig.get_config()
@ -1079,10 +1086,30 @@ class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
    # Raw stream handle returned by the HTTP client; iterated directly.
    self.streaming_response = streaming_response
    self.response_iterator = self.streaming_response
    # Accumulates the ContentBlockDelta events of the *current* content block
    # so check_empty_tool_call_args() can inspect the combined partial_json.
    # Reset whenever a new content_block_start event arrives.
    self.content_blocks: List[ContentBlockDelta] = []
def check_empty_tool_call_args(self) -> bool:
    """Return True when the tool-call block seen so far has empty arguments.

    Anthropic may stream a tool call whose accumulated ``partial_json``
    pieces join to an empty string; callers use this to substitute a valid
    ``"{}"`` payload. Text blocks (and an empty buffer) never qualify.
    """
    blocks = self.content_blocks
    # No deltas buffered yet -> nothing to fix up.
    if not blocks:
        return False
    # A text content block is not a tool call; skip it.
    if blocks[0]["delta"]["type"] == "text_delta":
        return False
    combined = "".join(
        block["delta"].get("partial_json", "")  # type: ignore
        for block in blocks
        if block["delta"]["type"] == "input_json_delta"
    )
    return combined == ""
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
type_chunk = chunk.get("type", "") or ""
text = ""
@ -1098,6 +1125,7 @@ class ModelResponseIterator:
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
content_block = ContentBlockDelta(**chunk) # type: ignore
self.content_blocks.append(content_block)
if "text" in content_block["delta"]:
text = content_block["delta"]["text"]
elif "partial_json" in content_block["delta"]:
@ -1116,6 +1144,7 @@ class ModelResponseIterator:
data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}
"""
content_block_start = ContentBlockStart(**chunk) # type: ignore
self.content_blocks = [] # reset content blocks when new block starts
if content_block_start["content_block"]["type"] == "text":
text = content_block_start["content_block"]["text"]
elif content_block_start["content_block"]["type"] == "tool_use":
@ -1128,6 +1157,20 @@ class ModelResponseIterator:
},
"index": content_block_start["index"],
}
elif type_chunk == "content_block_stop":
content_block_stop = ContentBlockStop(**chunk) # type: ignore
# check if tool call content block
is_empty = self.check_empty_tool_call_args()
if is_empty:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": "{}",
},
"index": content_block_stop["index"],
}
elif type_chunk == "message_delta":
"""
Anthropic

View file

@ -5113,7 +5113,9 @@ def stream_chunk_builder(
prev_index = curr_index
prev_id = curr_id
combined_arguments = "".join(argument_list)
combined_arguments = (
"".join(argument_list) or "{}"
) # base case, return empty dict
tool_calls_list.append(
{
"id": id,

View file

@ -4346,3 +4346,51 @@ def test_moderation():
# test_moderation()
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "claude-3-5-sonnet-20240620"])
def test_streaming_tool_calls_valid_json_str(model):
    """Every streamed tool call must accumulate into a parseable JSON string,
    even when the provider streams no argument deltas at all."""
    messages = [
        {"role": "user", "content": "Hit the snooze button."},
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "snooze",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
        }
    ]
    chunks = list(litellm.completion(model, messages, tools=tools, stream=True))
    print(chunks)

    # Re-assemble argument strings per tool-call id from the delta chunks.
    args_by_id = {}
    active_id = None
    active_args = ""
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.tool_calls is None:
            continue
        tool_call = delta.tool_calls[0]
        if tool_call.id is not None:
            # A fresh id starts a new tool call; persist the previous one.
            if active_id is not None:
                args_by_id[active_id] = active_args
                active_args = ""
            active_id = tool_call.id
            args_by_id[active_id] = ""
        if tool_call.function.arguments is not None:
            active_args += tool_call.function.arguments
    # Persist the final (still-open) tool call.
    if active_id is not None:
        args_by_id[active_id] = active_args

    for k, v in args_by_id.items():
        print("k={}, v={}".format(k, v))
        json.loads(v)  # valid json str

View file

@ -2596,8 +2596,8 @@ def streaming_and_function_calling_format_tests(idx, chunk):
@pytest.mark.parametrize(
"model",
[
"gpt-3.5-turbo",
"anthropic.claude-3-sonnet-20240229-v1:0",
# "gpt-3.5-turbo",
# "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307",
],
)
@ -2627,7 +2627,7 @@ def test_streaming_and_function_calling(model):
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
try:
litellm.set_verbose = True
# litellm.set_verbose = True
response: litellm.CustomStreamWrapper = completion(
model=model,
tools=tools,
@ -2639,7 +2639,7 @@ def test_streaming_and_function_calling(model):
json_str = ""
for idx, chunk in enumerate(response):
# continue
print("\n{}\n".format(chunk))
# print("\n{}\n".format(chunk))
if idx == 0:
assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None

View file

@ -141,6 +141,11 @@ class ContentBlockDelta(TypedDict):
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
class ContentBlockStop(TypedDict):
    """Anthropic streaming event signalling the end of one content block.

    Wire shape: {"type": "content_block_stop", "index": <block index>}
    """

    type: Literal["content_block_stop"]
    index: int
class ToolUseBlock(TypedDict):
"""
"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}