mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(anthropic.py): handle scenario where anthropic returns invalid json string for tool call while streaming
Fixes https://github.com/BerriAI/litellm/issues/5063
This commit is contained in:
parent
1008f24b16
commit
4919cc4d25
5 changed files with 105 additions and 7 deletions
|
@ -2,6 +2,7 @@ import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import types
|
import types
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -36,6 +37,7 @@ from litellm.types.llms.anthropic import (
|
||||||
AnthropicResponseUsageBlock,
|
AnthropicResponseUsageBlock,
|
||||||
ContentBlockDelta,
|
ContentBlockDelta,
|
||||||
ContentBlockStart,
|
ContentBlockStart,
|
||||||
|
ContentBlockStop,
|
||||||
ContentJsonBlockDelta,
|
ContentJsonBlockDelta,
|
||||||
ContentTextBlockDelta,
|
ContentTextBlockDelta,
|
||||||
MessageBlockDelta,
|
MessageBlockDelta,
|
||||||
|
@ -920,7 +922,12 @@ class AnthropicChatCompletion(BaseLLM):
|
||||||
model=model, messages=messages, custom_llm_provider="anthropic"
|
model=model, messages=messages, custom_llm_provider="anthropic"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AnthropicError(status_code=400, message=str(e))
|
raise AnthropicError(
|
||||||
|
status_code=400,
|
||||||
|
message="{}\n{}\nReceived Messages={}".format(
|
||||||
|
str(e), traceback.format_exc(), messages
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
## Load Config
|
## Load Config
|
||||||
config = litellm.AnthropicConfig.get_config()
|
config = litellm.AnthropicConfig.get_config()
|
||||||
|
@ -1079,10 +1086,30 @@ class ModelResponseIterator:
|
||||||
def __init__(self, streaming_response, sync_stream: bool):
|
def __init__(self, streaming_response, sync_stream: bool):
|
||||||
self.streaming_response = streaming_response
|
self.streaming_response = streaming_response
|
||||||
self.response_iterator = self.streaming_response
|
self.response_iterator = self.streaming_response
|
||||||
|
self.content_blocks: List[ContentBlockDelta] = []
|
||||||
|
|
||||||
|
def check_empty_tool_call_args(self) -> bool:
    """
    Report whether the deltas buffered for the current block form a tool call with no arguments.

    Returns:
        True when `self.content_blocks` is non-empty, its first delta is not a
        `text_delta`, and the concatenation of `partial_json` across all
        `input_json_delta` entries is the empty string. False otherwise.
    """
    buffered = self.content_blocks

    # nothing buffered yet, or this is a text content block -> skip
    if not buffered or buffered[0]["delta"]["type"] == "text_delta":
        return False

    # stitch together every json fragment streamed for this block
    joined_args = "".join(
        entry["delta"].get("partial_json", "")  # type: ignore
        for entry in buffered
        if entry["delta"]["type"] == "input_json_delta"
    )
    return not joined_args
|
||||||
|
|
||||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||||
try:
|
try:
|
||||||
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
|
|
||||||
type_chunk = chunk.get("type", "") or ""
|
type_chunk = chunk.get("type", "") or ""
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
|
@ -1098,6 +1125,7 @@ class ModelResponseIterator:
|
||||||
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
|
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
|
||||||
"""
|
"""
|
||||||
content_block = ContentBlockDelta(**chunk) # type: ignore
|
content_block = ContentBlockDelta(**chunk) # type: ignore
|
||||||
|
self.content_blocks.append(content_block)
|
||||||
if "text" in content_block["delta"]:
|
if "text" in content_block["delta"]:
|
||||||
text = content_block["delta"]["text"]
|
text = content_block["delta"]["text"]
|
||||||
elif "partial_json" in content_block["delta"]:
|
elif "partial_json" in content_block["delta"]:
|
||||||
|
@ -1116,6 +1144,7 @@ class ModelResponseIterator:
|
||||||
data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}
|
data: {"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}}
|
||||||
"""
|
"""
|
||||||
content_block_start = ContentBlockStart(**chunk) # type: ignore
|
content_block_start = ContentBlockStart(**chunk) # type: ignore
|
||||||
|
self.content_blocks = [] # reset content blocks when new block starts
|
||||||
if content_block_start["content_block"]["type"] == "text":
|
if content_block_start["content_block"]["type"] == "text":
|
||||||
text = content_block_start["content_block"]["text"]
|
text = content_block_start["content_block"]["text"]
|
||||||
elif content_block_start["content_block"]["type"] == "tool_use":
|
elif content_block_start["content_block"]["type"] == "tool_use":
|
||||||
|
@ -1128,6 +1157,20 @@ class ModelResponseIterator:
|
||||||
},
|
},
|
||||||
"index": content_block_start["index"],
|
"index": content_block_start["index"],
|
||||||
}
|
}
|
||||||
|
elif type_chunk == "content_block_stop":
|
||||||
|
content_block_stop = ContentBlockStop(**chunk) # type: ignore
|
||||||
|
# check if tool call content block
|
||||||
|
is_empty = self.check_empty_tool_call_args()
|
||||||
|
if is_empty:
|
||||||
|
tool_use = {
|
||||||
|
"id": None,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": None,
|
||||||
|
"arguments": "{}",
|
||||||
|
},
|
||||||
|
"index": content_block_stop["index"],
|
||||||
|
}
|
||||||
elif type_chunk == "message_delta":
|
elif type_chunk == "message_delta":
|
||||||
"""
|
"""
|
||||||
Anthropic
|
Anthropic
|
||||||
|
|
|
@ -5113,7 +5113,9 @@ def stream_chunk_builder(
|
||||||
prev_index = curr_index
|
prev_index = curr_index
|
||||||
prev_id = curr_id
|
prev_id = curr_id
|
||||||
|
|
||||||
combined_arguments = "".join(argument_list)
|
combined_arguments = (
|
||||||
|
"".join(argument_list) or "{}"
|
||||||
|
) # base case, return empty dict
|
||||||
tool_calls_list.append(
|
tool_calls_list.append(
|
||||||
{
|
{
|
||||||
"id": id,
|
"id": id,
|
||||||
|
|
|
@ -4346,3 +4346,51 @@ def test_moderation():
|
||||||
|
|
||||||
|
|
||||||
# test_moderation()
|
# test_moderation()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", ["gpt-3.5-turbo", "claude-3-5-sonnet-20240620"])
def test_streaming_tool_calls_valid_json_str(model):
    """
    Stream a completion that offers a parameter-less tool and check that the
    arguments accumulated for each tool call id can be parsed by `json.loads`.
    """
    messages = [
        {"role": "user", "content": "Hit the snooze button."},
    ]

    # tool whose parameter schema has no properties
    snooze_tool = {
        "type": "function",
        "function": {
            "name": "snooze",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    }
    tools = [snooze_tool]

    stream = litellm.completion(model, messages, tools=tools, stream=True)
    chunks = [*stream]
    print(chunks)

    tool_call_id_arg_map = {}
    active_id = None
    pending_args = ""
    for chunk in chunks:
        tool_calls = chunk.choices[0].delta.tool_calls
        if tool_calls is None:
            continue

        first_call = tool_calls[0]
        if first_call.id is not None:
            # a new id begins a new tool call -> store what the previous one accumulated
            if active_id is not None:
                tool_call_id_arg_map[active_id] = pending_args
                pending_args = ""
            active_id = first_call.id
            tool_call_id_arg_map[active_id] = ""

        fragment = first_call.function.arguments
        if fragment is not None:
            pending_args += fragment

    # store the arguments of the last tool call seen
    if active_id is not None:
        tool_call_id_arg_map[active_id] = pending_args

    for k, v in tool_call_id_arg_map.items():
        print("k={}, v={}".format(k, v))
        json.loads(v)  # valid json str
|
||||||
|
|
|
@ -2596,8 +2596,8 @@ def streaming_and_function_calling_format_tests(idx, chunk):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model",
|
||||||
[
|
[
|
||||||
"gpt-3.5-turbo",
|
# "gpt-3.5-turbo",
|
||||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
# "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
"claude-3-haiku-20240307",
|
"claude-3-haiku-20240307",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -2627,7 +2627,7 @@ def test_streaming_and_function_calling(model):
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
# litellm.set_verbose = True
|
||||||
response: litellm.CustomStreamWrapper = completion(
|
response: litellm.CustomStreamWrapper = completion(
|
||||||
model=model,
|
model=model,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
|
@ -2639,7 +2639,7 @@ def test_streaming_and_function_calling(model):
|
||||||
json_str = ""
|
json_str = ""
|
||||||
for idx, chunk in enumerate(response):
|
for idx, chunk in enumerate(response):
|
||||||
# continue
|
# continue
|
||||||
print("\n{}\n".format(chunk))
|
# print("\n{}\n".format(chunk))
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
assert (
|
assert (
|
||||||
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
chunk.choices[0].delta.tool_calls[0].function.arguments is not None
|
||||||
|
|
|
@ -141,6 +141,11 @@ class ContentBlockDelta(TypedDict):
|
||||||
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
|
delta: Union[ContentTextBlockDelta, ContentJsonBlockDelta]
|
||||||
|
|
||||||
|
|
||||||
|
class ContentBlockStop(TypedDict):
    """
    Shape of a streamed `content_block_stop` event.

    Holds the literal event type tag and the integer index carried by the event.
    """

    type: Literal["content_block_stop"]
    index: int
|
||||||
|
|
||||||
|
|
||||||
class ToolUseBlock(TypedDict):
|
class ToolUseBlock(TypedDict):
|
||||||
"""
|
"""
|
||||||
"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}
|
"content_block":{"type":"tool_use","id":"toolu_01T1x1fJ34qAmk2tNTrN7Up6","name":"get_weather","input":{}}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue