fix(anthropic.py): handle scenario where anthropic returns invalid json string for tool call while streaming

Fixes https://github.com/BerriAI/litellm/issues/5063
This commit is contained in:
Krrish Dholakia 2024-08-07 09:24:11 -07:00
parent 1008f24b16
commit 4919cc4d25
5 changed files with 105 additions and 7 deletions

View file

@@ -2,6 +2,7 @@ import copy
 import json
 import os
 import time
+import traceback
 import types
 from enum import Enum
 from functools import partial
@@ -36,6 +37,7 @@ from litellm.types.llms.anthropic import (
     AnthropicResponseUsageBlock,
     ContentBlockDelta,
     ContentBlockStart,
+    ContentBlockStop,
     ContentJsonBlockDelta,
     ContentTextBlockDelta,
     MessageBlockDelta,
@@ -920,7 +922,12 @@ class AnthropicChatCompletion(BaseLLM):
                 model=model, messages=messages, custom_llm_provider="anthropic"
             )
         except Exception as e:
-            raise AnthropicError(status_code=400, message=str(e))
+            raise AnthropicError(
+                status_code=400,
+                message="{}\n{}\nReceived Messages={}".format(
+                    str(e), traceback.format_exc(), messages
+                ),
+            )

         ## Load Config
         config = litellm.AnthropicConfig.get_config()
@ -1079,10 +1086,30 @@ class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool): def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response self.streaming_response = streaming_response
self.response_iterator = self.streaming_response self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = []
def check_empty_tool_call_args(self) -> bool:
    """
    Check if the tool call block so far has been an empty string.

    Returns True only when every `input_json_delta` fragment collected in
    ``self.content_blocks`` is empty — i.e. the provider streamed a tool
    call with no argument JSON at all.
    """
    # No deltas collected yet -> nothing to fix up.
    if not self.content_blocks:
        return False
    # A text content block never needs the empty-arguments fixup.
    if self.content_blocks[0]["delta"]["type"] == "text_delta":
        return False
    # Concatenate every partial_json fragment seen for this block.
    accumulated = "".join(
        block["delta"].get("partial_json", "")  # type: ignore
        for block in self.content_blocks
        if block["delta"]["type"] == "input_json_delta"
    )
    return len(accumulated) == 0
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try: try:
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
type_chunk = chunk.get("type", "") or "" type_chunk = chunk.get("type", "") or ""
text = "" text = ""
@ -1098,6 +1125,7 @@ class ModelResponseIterator:
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
""" """
content_block = ContentBlockDelta(**chunk) # type: ignore content_block = ContentBlockDelta(**chunk) # type: ignore
self.content_blocks.append(content_block)
if "text" in content_block["delta"]: if "text" in content_block["delta"]:
text = content_block["delta"]["text"] text = content_block["delta"]["text"]
elif "partial_json" in content_block["delta"]: elif "partial_json" in content_block["delta"]:
@ -1116,6 +1144,7 @@ class ModelResponseIterator:
data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}
""" """
content_block_start = ContentBlockStart(**chunk) # type: ignore content_block_start = ContentBlockStart(**chunk) # type: ignore
self.content_blocks = [] # reset content blocks when new block starts
if content_block_start["content_block"]["type"] == "text": if content_block_start["content_block"]["type"] == "text":
text = content_block_start["content_block"]["text"] text = content_block_start["content_block"]["text"]
elif content_block_start["content_block"]["type"] == "tool_use": elif content_block_start["content_block"]["type"] == "tool_use":
@ -1128,6 +1157,20 @@ class ModelResponseIterator:
}, },
"index": content_block_start["index"], "index": content_block_start["index"],
} }
elif type_chunk == "content_block_stop":
content_block_stop = ContentBlockStop(**chunk) # type: ignore
# check if tool call content block
is_empty = self.check_empty_tool_call_args()
if is_empty:
tool_use = {
"id": None,
"type": "function",
"function": {
"name": None,
"arguments": "{}",
},
"index": content_block_stop["index"],
}
elif type_chunk == "message_delta": elif type_chunk == "message_delta":
""" """
Anthropic Anthropic

View file

@@ -5113,7 +5113,9 @@ def stream_chunk_builder(
                 prev_index = curr_index
                 prev_id = curr_id

-        combined_arguments = "".join(argument_list)
+        combined_arguments = (
+            "".join(argument_list) or "{}"
+        )  # base case, return empty dict
         tool_calls_list.append(
             {
                 "id": id,

View file

@ -4346,3 +4346,51 @@ def test_moderation():
# test_moderation() # test_moderation()
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "claude-3-5-sonnet-20240620"])
def test_streaming_tool_calls_valid_json_str(model):
    """
    Regression test for https://github.com/BerriAI/litellm/issues/5063.

    Streams a tool call for a tool with an empty parameter schema and
    asserts that the accumulated `arguments` string for every tool call
    parses as valid JSON (an empty-args tool must yield "{}" rather than "").
    Live test: requires provider API credentials and network access.
    """
    messages = [
        {"role": "user", "content": "Hit the snooze button."},
    ]
    # Tool with no parameters — the provider may stream zero argument
    # fragments for it, which previously produced an invalid "" args string.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "snooze",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
        }
    ]

    stream = litellm.completion(model, messages, tools=tools, stream=True)
    chunks = [*stream]  # drain the stream so every delta is collected
    print(chunks)

    # Reassemble per-tool-call argument strings: tool_call_id -> args str.
    tool_call_id_arg_map = {}
    curr_tool_call_id = None
    curr_tool_call_str = ""
    for chunk in chunks:
        if chunk.choices[0].delta.tool_calls is not None:
            # A non-None id marks the start of a new tool call.
            if chunk.choices[0].delta.tool_calls[0].id is not None:
                # flush prev tool call
                if curr_tool_call_id is not None:
                    tool_call_id_arg_map[curr_tool_call_id] = curr_tool_call_str
                    curr_tool_call_str = ""
                curr_tool_call_id = chunk.choices[0].delta.tool_calls[0].id
                tool_call_id_arg_map[curr_tool_call_id] = ""
            # Append any argument fragment carried by this chunk.
            if chunk.choices[0].delta.tool_calls[0].function.arguments is not None:
                curr_tool_call_str += (
                    chunk.choices[0].delta.tool_calls[0].function.arguments
                )
    # flush prev tool call (the last one is not flushed inside the loop)
    if curr_tool_call_id is not None:
        tool_call_id_arg_map[curr_tool_call_id] = curr_tool_call_str
    for k, v in tool_call_id_arg_map.items():
        print("k={}, v={}".format(k, v))
        json.loads(v)  # valid json str

View file

@ -2596,8 +2596,8 @@ def streaming_and_function_calling_format_tests(idx, chunk):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model",
[ [
"gpt-3.5-turbo", # "gpt-3.5-turbo",
"anthropic.claude-3-sonnet-20240229-v1:0", # "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
], ],
) )
@ -2627,7 +2627,7 @@ def test_streaming_and_function_calling(model):
messages = [{"role": "user", "content": "What is the weather like in Boston?"}] messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
try: try:
litellm.set_verbose = True # litellm.set_verbose = True
response: litellm.CustomStreamWrapper = completion( response: litellm.CustomStreamWrapper = completion(
model=model, model=model,
tools=tools, tools=tools,
@ -2639,7 +2639,7 @@ def test_streaming_and_function_calling(model):
json_str = "" json_str = ""
for idx, chunk in enumerate(response): for idx, chunk in enumerate(response):
# continue # continue
print("\n{}\n".format(chunk)) # print("\n{}\n".format(chunk))
if idx == 0: if idx == 0:
assert ( assert (
chunk.choices[0].delta.tool_calls[0].function.arguments is not None chunk.choices[0].delta.tool_calls[0].function.arguments is not None

View file

@ -141,6 +141,11 @@ class ContentBlockDelta(TypedDict):
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta] delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
class ContentBlockStop(TypedDict):
    """
    Anthropic streaming `content_block_stop` event, e.g.:
    {"type":"content_block_stop","index":1}
    """

    type: Literal["content_block_stop"]
    index: int
class ToolUseBlock(TypedDict): class ToolUseBlock(TypedDict):
""" """
"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}} "content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}