From 4919cc4d255d3aa42ced83b8b0bc1eb8eed9fbac Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 7 Aug 2024 09:24:11 -0700 Subject: [PATCH] fix(anthropic.py): handle scenario where anthropic returns invalid json string for tool call while streaming Fixes https://github.com/BerriAI/litellm/issues/5063 --- litellm/llms/anthropic.py | 47 +++++++++++++++++++++++++++++-- litellm/main.py | 4 ++- litellm/tests/test_completion.py | 48 ++++++++++++++++++++++++++++++++ litellm/tests/test_streaming.py | 8 +++--- litellm/types/llms/anthropic.py | 5 ++++ 5 files changed, 105 insertions(+), 7 deletions(-) diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 929375ef03..78888cf4ad 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -2,6 +2,7 @@ import copy import json import os import time +import traceback import types from enum import Enum from functools import partial @@ -36,6 +37,7 @@ from litellm.types.llms.anthropic import ( AnthropicResponseUsageBlock, ContentBlockDelta, ContentBlockStart, + ContentBlockStop, ContentJsonBlockDelta, ContentTextBlockDelta, MessageBlockDelta, @@ -920,7 +922,12 @@ class AnthropicChatCompletion(BaseLLM): model=model, messages=messages, custom_llm_provider="anthropic" ) except Exception as e: - raise AnthropicError(status_code=400, message=str(e)) + raise AnthropicError( + status_code=400, + message="{}\n{}\nReceived Messages={}".format( + str(e), traceback.format_exc(), messages + ), + ) ## Load Config config = litellm.AnthropicConfig.get_config() @@ -1079,10 +1086,30 @@ class ModelResponseIterator: def __init__(self, streaming_response, sync_stream: bool): self.streaming_response = streaming_response self.response_iterator = self.streaming_response + self.content_blocks: List[ContentBlockDelta] = [] + + def check_empty_tool_call_args(self) -> bool: + """ + Check if the tool call block so far has been an empty string + """ + args = "" + # if text content block -> skip + if len(self.content_blocks) == 0: + return False + + if self.content_blocks[0]["delta"]["type"] == "text_delta": + return False + + for block in self.content_blocks: + if block["delta"]["type"] == "input_json_delta": + args += block["delta"].get("partial_json", "") # type: ignore + + if len(args) == 0: + return True + return False def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: try: - verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n") type_chunk = chunk.get("type", "") or "" text = "" @@ -1098,6 +1125,7 @@ class ModelResponseIterator: chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} """ content_block = ContentBlockDelta(**chunk) # type: ignore + self.content_blocks.append(content_block) if "text" in content_block["delta"]: text = content_block["delta"]["text"] elif "partial_json" in content_block["delta"]: @@ -1116,6 +1144,7 @@ class ModelResponseIterator: data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}} """ content_block_start = ContentBlockStart(**chunk) # type: ignore + self.content_blocks = [] # reset content blocks when new block starts if content_block_start["content_block"]["type"] == "text": text = content_block_start["content_block"]["text"] elif content_block_start["content_block"]["type"] == "tool_use": @@ -1128,6 +1157,20 @@ class ModelResponseIterator: }, "index": content_block_start["index"], } + elif type_chunk == "content_block_stop": + content_block_stop = ContentBlockStop(**chunk) # type: ignore + # check if tool call content block + is_empty = self.check_empty_tool_call_args() + if is_empty: + tool_use = { + "id": None, + "type": "function", + "function": { + "name": None, + "arguments": "{}", + }, + "index": content_block_stop["index"], + } elif type_chunk == "message_delta": """ Anthropic diff --git a/litellm/main.py b/litellm/main.py index 1209306c8b..0fb26b9c12 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5113,7 +5113,9 @@ def stream_chunk_builder( prev_index = curr_index prev_id = curr_id - combined_arguments = "".join(argument_list) + combined_arguments = ( + "".join(argument_list) or "{}" + ) # base case, return empty dict tool_calls_list.append( { "id": id, diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index eec163f26a..561764f121 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -4346,3 +4346,51 @@ def test_moderation(): # test_moderation() + + +@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "claude-3-5-sonnet-20240620"]) +def test_streaming_tool_calls_valid_json_str(model): + messages = [ + {"role": "user", "content": "Hit the snooze button."}, + ] + + tools = [ + { + "type": "function", + "function": { + "name": "snooze", + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + ] + + stream = litellm.completion(model, messages, tools=tools, stream=True) + chunks = [*stream] + print(chunks) + tool_call_id_arg_map = {} + curr_tool_call_id = None + curr_tool_call_str = "" + for chunk in chunks: + if chunk.choices[0].delta.tool_calls is not None: + if chunk.choices[0].delta.tool_calls[0].id is not None: + # flush prev tool call + if curr_tool_call_id is not None: + tool_call_id_arg_map[curr_tool_call_id] = curr_tool_call_str + curr_tool_call_str = "" + curr_tool_call_id = chunk.choices[0].delta.tool_calls[0].id + tool_call_id_arg_map[curr_tool_call_id] = "" + if chunk.choices[0].delta.tool_calls[0].function.arguments is not None: + curr_tool_call_str += ( + chunk.choices[0].delta.tool_calls[0].function.arguments + ) + # flush prev tool call + if curr_tool_call_id is not None: + tool_call_id_arg_map[curr_tool_call_id] = curr_tool_call_str + + for k, v in tool_call_id_arg_map.items(): + print("k={}, v={}".format(k, v)) + json.loads(v) # valid json str diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 9c53d5cfbc..e6f8641249 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -2596,8 +2596,8 @@ def streaming_and_function_calling_format_tests(idx, chunk): @pytest.mark.parametrize( "model", [ - "gpt-3.5-turbo", - "anthropic.claude-3-sonnet-20240229-v1:0", + # "gpt-3.5-turbo", + # "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-haiku-20240307", ], ) @@ -2627,7 +2627,7 @@ def test_streaming_and_function_calling(model): messages = [{"role": "user", "content": "What is the weather like in Boston?"}] try: - litellm.set_verbose = True + # litellm.set_verbose = True response: litellm.CustomStreamWrapper = completion( model=model, tools=tools, @@ -2639,7 +2639,7 @@ def test_streaming_and_function_calling(model): json_str = "" for idx, chunk in enumerate(response): # continue - print("\n{}\n".format(chunk)) + # print("\n{}\n".format(chunk)) if idx == 0: assert ( chunk.choices[0].delta.tool_calls[0].function.arguments is not None diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py index 60784e9134..36bcb6cc73 100644 --- a/litellm/types/llms/anthropic.py +++ b/litellm/types/llms/anthropic.py @@ -141,6 +141,11 @@ class ContentBlockDelta(TypedDict): delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta] +class ContentBlockStop(TypedDict): + type: Literal["content_block_stop"] + index: int + + class ToolUseBlock(TypedDict): """ "content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}