fix(anthropic.py): fix anthropic tool calling + streaming

Fixes https://github.com/BerriAI/litellm/issues/4537
This commit is contained in:
Krrish Dholakia 2024-07-04 16:30:24 -07:00
parent 8625770010
commit 00497b408d
6 changed files with 19 additions and 8 deletions

View file

@@ -12,6 +12,7 @@ import requests # type: ignore
import litellm
import litellm.litellm_core_utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@@ -730,6 +731,7 @@ class ModelResponseIterator:
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
type_chunk = chunk.get("type", "") or ""
text = ""
@@ -770,9 +772,7 @@ class ModelResponseIterator:
"type": "function",
"function": {
"name": content_block_start["content_block"]["name"],
"arguments": json.dumps(
content_block_start["content_block"]["input"]
),
"arguments": "",
},
}
elif type_chunk == "message_delta":

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -38,6 +38,9 @@ model_list:
- litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0
model_name: bedrock-anthropic-claude-3
- litellm_params:
model: claude-3-haiku-20240307
model_name: anthropic-claude-3
- litellm_params:
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_key: os.environ/AZURE_API_KEY

View file

@@ -2559,9 +2559,16 @@ def streaming_and_function_calling_format_tests(idx, chunk):
@pytest.mark.parametrize(
"model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
"model",
[
"gpt-3.5-turbo",
"anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307",
],
)
def test_streaming_and_function_calling(model):
import json
tools = [
{
"type": "function",
@@ -2594,6 +2601,7 @@ def test_streaming_and_function_calling(model):
tool_choice="required",
) # type: ignore
# Add any assertions here to check the response
json_str = ""
for idx, chunk in enumerate(response):
# continue
print("\n{}\n".format(chunk))
@@ -2604,7 +2612,10 @@ def test_streaming_and_function_calling(model):
assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str
)
# assert False
if chunk.choices[0].delta.tool_calls is not None:
json_str += chunk.choices[0].delta.tool_calls[0].function.arguments
print(json.loads(json_str))
except Exception as e:
pytest.fail(f"Error occurred: {e}")
raise e