mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(anthropic.py): handle scenario where anthropic returns invalid json string for tool call while streaming
Fixes https://github.com/BerriAI/litellm/issues/5063
This commit is contained in:
parent
1008f24b16
commit
4919cc4d25
5 changed files with 105 additions and 7 deletions
|
@ -2,6 +2,7 @@ import copy
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
|
@ -36,6 +37,7 @@ from litellm.types.llms.anthropic import (
|
|||
AnthropicResponseUsageBlock,
|
||||
ContentBlockDelta,
|
||||
ContentBlockStart,
|
||||
ContentBlockStop,
|
||||
ContentJsonBlockDelta,
|
||||
ContentTextBlockDelta,
|
||||
MessageBlockDelta,
|
||||
|
@ -920,7 +922,12 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||
)
|
||||
except Exception as e:
|
||||
raise AnthropicError(status_code=400, message=str(e))
|
||||
raise AnthropicError(
|
||||
status_code=400,
|
||||
message="{}\n{}\nReceived Messages={}".format(
|
||||
str(e), traceback.format_exc(), messages
|
||||
),
|
||||
)
|
||||
|
||||
## Load Config
|
||||
config = litellm.AnthropicConfig.get_config()
|
||||
|
@ -1079,10 +1086,30 @@ class ModelResponseIterator:
|
|||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List[ContentBlockDelta] = []
|
||||
|
||||
def check_empty_tool_call_args(self) -> bool:
    """
    Check if the tool call block so far has been an empty string.

    Returns True only when every streamed delta so far is a tool-call
    (input_json_delta) chunk and their concatenated partial_json is empty —
    i.e. Anthropic sent a tool call with no arguments, which would otherwise
    yield an invalid JSON argument string downstream.
    """
    # Nothing buffered yet -> nothing to flag.
    if not self.content_blocks:
        return False

    # A text content block is never a tool call -> skip.
    if self.content_blocks[0]["delta"]["type"] == "text_delta":
        return False

    # Concatenate every partial_json fragment seen for this block.
    accumulated_args = "".join(
        block["delta"].get("partial_json", "")  # type: ignore
        for block in self.content_blocks
        if block["delta"]["type"] == "input_json_delta"
    )

    return len(accumulated_args) == 0
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
|
||||
type_chunk = chunk.get("type", "") or ""
|
||||
|
||||
text = ""
|
||||
|
@ -1098,6 +1125,7 @@ class ModelResponseIterator:
|
|||
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
|
||||
"""
|
||||
content_block = ContentBlockDelta(**chunk) # type: ignore
|
||||
self.content_blocks.append(content_block)
|
||||
if "text" in content_block["delta"]:
|
||||
text = content_block["delta"]["text"]
|
||||
elif "partial_json" in content_block["delta"]:
|
||||
|
@ -1116,6 +1144,7 @@ class ModelResponseIterator:
|
|||
data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}
|
||||
"""
|
||||
content_block_start = ContentBlockStart(**chunk) # type: ignore
|
||||
self.content_blocks = [] # reset content blocks when new block starts
|
||||
if content_block_start["content_block"]["type"] == "text":
|
||||
text = content_block_start["content_block"]["text"]
|
||||
elif content_block_start["content_block"]["type"] == "tool_use":
|
||||
|
@ -1128,6 +1157,20 @@ class ModelResponseIterator:
|
|||
},
|
||||
"index": content_block_start["index"],
|
||||
}
|
||||
elif type_chunk == "content_block_stop":
|
||||
content_block_stop = ContentBlockStop(**chunk) # type: ignore
|
||||
# check if tool call content block
|
||||
is_empty = self.check_empty_tool_call_args()
|
||||
if is_empty:
|
||||
tool_use = {
|
||||
"id": None,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": None,
|
||||
"arguments": "{}",
|
||||
},
|
||||
"index": content_block_stop["index"],
|
||||
}
|
||||
elif type_chunk == "message_delta":
|
||||
"""
|
||||
Anthropic
|
||||
|
|
|
@ -5113,7 +5113,9 @@ def stream_chunk_builder(
|
|||
prev_index = curr_index
|
||||
prev_id = curr_id
|
||||
|
||||
combined_arguments = "".join(argument_list)
|
||||
combined_arguments = (
|
||||
"".join(argument_list) or "{}"
|
||||
) # base case, return empty dict
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"id": id,
|
||||
|
|
|
@ -4346,3 +4346,51 @@ def test_moderation():
|
|||
|
||||
|
||||
# test_moderation()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "claude-3-5-sonnet-20240620"])
def test_streaming_tool_calls_valid_json_str(model):
    """
    Regression test: every tool call accumulated from a stream must yield a
    valid JSON argument string (empty-arg tool calls previously produced "").
    """
    messages = [
        {"role": "user", "content": "Hit the snooze button."},
    ]

    # A tool with an empty parameter schema, so the model streams no arguments.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "snooze",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
        }
    ]

    stream = litellm.completion(model, messages, tools=tools, stream=True)
    chunks = list(stream)
    print(chunks)

    args_by_tool_call_id = {}
    active_id = None
    active_args = ""
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.tool_calls is None:
            continue
        tool_call = delta.tool_calls[0]
        if tool_call.id is not None:
            # A new id starts a new tool call: flush the previous one first.
            if active_id is not None:
                args_by_tool_call_id[active_id] = active_args
                active_args = ""
            active_id = tool_call.id
            args_by_tool_call_id[active_id] = ""
        if tool_call.function.arguments is not None:
            active_args += tool_call.function.arguments
    # Flush the final tool call after the stream ends.
    if active_id is not None:
        args_by_tool_call_id[active_id] = active_args

    for k, v in args_by_tool_call_id.items():
        print("k={}, v={}".format(k, v))
        json.loads(v)  # valid json str
|
||||
|
|
|
@ -2596,8 +2596,8 @@ def streaming_and_function_calling_format_tests(idx, chunk):
|
|||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"gpt-3.5-turbo",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
# "gpt-3.5-turbo",
|
||||
# "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-haiku-20240307",
|
||||
],
|
||||
)
|
||||
|
@ -2627,7 +2627,7 @@ def test_streaming_and_function_calling(model):
|
|||
|
||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
# litellm.set_verbose = True
|
||||
response: litellm.CustomStreamWrapper = completion(
|
||||
model=model,
|
||||
tools=tools,
|
||||
|
@ -2639,7 +2639,7 @@ def test_streaming_and_function_calling(model):
|
|||
json_str = ""
|
||||
for idx, chunk in enumerate(response):
|
||||
# continue
|
||||
print("\n{}\n".format(chunk))
|
||||
# print("\n{}\n".format(chunk))
|
||||
if idx == 0:
|
||||
assert (
|
||||
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||
|
|
|
@ -141,6 +141,11 @@ class ContentBlockDelta(TypedDict):
|
|||
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
|
||||
|
||||
|
||||
class ContentBlockStop(TypedDict):
    """Anthropic streaming event signalling the end of one content block.

    Example: {"type": "content_block_stop", "index": 0}
    """

    type: Literal["content_block_stop"]
    index: int
|
||||
|
||||
|
||||
class ToolUseBlock(TypedDict):
|
||||
"""
|
||||
"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue