Merge pull request #5232 from Penagwin/fix_anthropic_tool_streaming_index

Fixes the `tool_use` indexes not being correctly mapped
This commit is contained in:
Krish Dholakia 2024-08-17 14:33:50 -07:00 committed by GitHub
commit be37310e94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 75 additions and 4 deletions

View file

@@ -1122,6 +1122,7 @@ class ModelResponseIterator:
self.streaming_response = streaming_response
self.response_iterator = self.streaming_response
self.content_blocks: List[ContentBlockDelta] = []
self.tool_index = -1
def check_empty_tool_call_args(self) -> bool:
"""
@@ -1171,7 +1172,7 @@ class ModelResponseIterator:
"name": None,
"arguments": content_block["delta"]["partial_json"],
},
"index": content_block["index"],
"index": self.tool_index,
}
elif type_chunk == "content_block_start":
"""
@@ -1183,6 +1184,7 @@ class ModelResponseIterator:
if content_block_start["content_block"]["type"] == "text":
text = content_block_start["content_block"]["text"]
elif content_block_start["content_block"]["type"] == "tool_use":
self.tool_index += 1
tool_use = {
"id": content_block_start["content_block"]["id"],
"type": "function",
@@ -1190,7 +1192,7 @@ class ModelResponseIterator:
"name": content_block_start["content_block"]["name"],
"arguments": "",
},
"index": content_block_start["index"],
"index": self.tool_index,
}
elif type_chunk == "content_block_stop":
content_block_stop = ContentBlockStop(**chunk) # type: ignore
@@ -1204,7 +1206,7 @@ class ModelResponseIterator:
"name": None,
"arguments": "{}",
},
"index": content_block_stop["index"],
"index": self.tool_index,
}
elif type_chunk == "message_delta":
"""

View file

@@ -10,6 +10,7 @@ from dotenv import load_dotenv
import litellm.types
import litellm.types.utils
from litellm.llms.anthropic import ModelResponseIterator
load_dotenv()
import io
@@ -150,6 +151,74 @@ def test_anthropic_completion_e2e(stream):
assert message_stop_received is True
# Synthetic Anthropic streaming event fixture: a text content block at raw
# index 0, followed by two tool_use blocks at raw indexes 1 and 2. Used to
# verify that chunk_parser remaps tool_use indexes so the first tool starts
# at 0 (OpenAI convention) rather than at the raw Anthropic block index.
anthropic_chunk_list = [
    {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "To"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " answer"}},
    {"type": "content_block_delta", "index": 0,
     "delta": {"type": "text_delta", "text": " your question about the weather"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " in Boston and Los"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " Angeles today, I'll"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " need to"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " use"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " the"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " get_current_weather"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " function"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " for"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " both"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " cities"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": ". Let"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " me fetch"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " that"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " information"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " for"}},
    {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": " you."}},
    {"type": "content_block_stop", "index": 0},
    # First tool_use block — raw Anthropic index 1, expected remapped index 0.
    {"type": "content_block_start", "index": 1,
     "content_block": {"type": "tool_use", "id": "toolu_12345", "name": "get_current_weather", "input": {}}},
    {"type": "content_block_delta", "index": 1, "delta": {"type": "input_json_delta", "partial_json": ""}},
    {"type": "content_block_delta", "index": 1, "delta": {"type": "input_json_delta", "partial_json": "{\"locat"}},
    {"type": "content_block_delta", "index": 1, "delta": {"type": "input_json_delta", "partial_json": "ion\": \"Bos"}},
    {"type": "content_block_delta", "index": 1, "delta": {"type": "input_json_delta", "partial_json": "ton, MA\"}"}},
    {"type": "content_block_stop", "index": 1},
    # Second tool_use block — raw Anthropic index 2, expected remapped index 1.
    {"type": "content_block_start", "index": 2,
     "content_block": {"type": "tool_use", "id": "toolu_023423423", "name": "get_current_weather", "input": {}}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": ""}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": "{\"l"}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": "oca"}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": "tio"}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": "n\": \"Lo"}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": "s Angel"}},
    {"type": "content_block_delta", "index": 2, "delta": {"type": "input_json_delta", "partial_json": "es, CA\"}"}},
    {"type": "content_block_stop", "index": 2},
    {"type": "message_delta", "delta": {"stop_reason": "tool_use", "stop_sequence": None},
     "usage": {"output_tokens": 137}},
    {"type": "message_stop"}
]
def test_anthropic_tool_streaming():
    """
    Verify tool_use index remapping in the Anthropic stream parser.

    OpenAI numbers tool_use entries from 0 for the first tool, regardless of
    any preceding text blocks. Anthropic's raw content-block indexes count the
    text block too, so its first tool often arrives with index 1. The parser
    must translate the raw index to the OpenAI-style one.
    """
    litellm.set_verbose = True
    response_iter = ModelResponseIterator([], False)
    # The first tool block should map to 0; start one below and increment on
    # each new block so the counter is easy to advance in the loop.
    expected_index = -1
    for raw_chunk in anthropic_chunk_list:
        parsed = response_iter.chunk_parser(raw_chunk)
        tool_use = parsed.get('tool_use')
        if not tool_use:
            continue
        # Only a content_block_start carries an id — that marks a new tool.
        if tool_use.get('id') is not None:
            expected_index += 1
        assert tool_use['index'] == expected_index
@pytest.mark.asyncio
async def test_anthropic_router_completion_e2e():
litellm.set_verbose = True
@@ -275,4 +344,4 @@ def test_anthropic_tool_calling_translation():
print(translated_params["messages"])
assert len(translated_params["messages"]) > 0
assert translated_params["messages"][0]["role"] == "user"
assert translated_params["messages"][0]["role"] == "user"