fix(anthropic.py): fix anthropic tool calling + streaming

Fixes https://github.com/BerriAI/litellm/issues/4537
This commit is contained in:
Krrish Dholakia 2024-07-04 16:30:24 -07:00
parent 8625770010
commit 00497b408d
6 changed files with 19 additions and 8 deletions

View file

@ -12,6 +12,7 @@ import requests # type: ignore
import litellm import litellm
import litellm.litellm_core_utils import litellm.litellm_core_utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
@ -730,6 +731,7 @@ class ModelResponseIterator:
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try: try:
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
type_chunk = chunk.get("type", "") or "" type_chunk = chunk.get("type", "") or ""
text = "" text = ""
@ -770,9 +772,7 @@ class ModelResponseIterator:
"type": "function", "type": "function",
"function": { "function": {
"name": content_block_start["content_block"]["name"], "name": content_block_start["content_block"]["name"],
"arguments": json.dumps( "arguments": "",
content_block_start["content_block"]["input"]
),
}, },
} }
elif type_chunk == "message_delta": elif type_chunk == "message_delta":

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -38,6 +38,9 @@ model_list:
- litellm_params: - litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0 model: anthropic.claude-3-sonnet-20240229-v1:0
model_name: bedrock-anthropic-claude-3 model_name: bedrock-anthropic-claude-3
- litellm_params:
model: claude-3-haiku-20240307
model_name: anthropic-claude-3
- litellm_params: - litellm_params:
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/ api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_key: os.environ/AZURE_API_KEY api_key: os.environ/AZURE_API_KEY

View file

@ -2559,9 +2559,16 @@ def streaming_and_function_calling_format_tests(idx, chunk):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"] "model",
[
"gpt-3.5-turbo",
"anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-haiku-20240307",
],
) )
def test_streaming_and_function_calling(model): def test_streaming_and_function_calling(model):
import json
tools = [ tools = [
{ {
"type": "function", "type": "function",
@ -2594,6 +2601,7 @@ def test_streaming_and_function_calling(model):
tool_choice="required", tool_choice="required",
) # type: ignore ) # type: ignore
# Add any assertions here to check the response # Add any assertions here to check the response
json_str = ""
for idx, chunk in enumerate(response): for idx, chunk in enumerate(response):
# continue # continue
print("\n{}\n".format(chunk)) print("\n{}\n".format(chunk))
@ -2604,7 +2612,10 @@ def test_streaming_and_function_calling(model):
assert isinstance( assert isinstance(
chunk.choices[0].delta.tool_calls[0].function.arguments, str chunk.choices[0].delta.tool_calls[0].function.arguments, str
) )
# assert False if chunk.choices[0].delta.tool_calls is not None:
json_str += chunk.choices[0].delta.tool_calls[0].function.arguments
print(json.loads(json_str))
except Exception as e: except Exception as e:
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
raise e raise e