mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(anthropic.py): fix anthropic tool calling + streaming
Fixes https://github.com/BerriAI/litellm/issues/4537
This commit is contained in:
parent
8625770010
commit
00497b408d
6 changed files with 19 additions and 8 deletions
|
@ -12,6 +12,7 @@ import requests # type: ignore
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
import litellm.litellm_core_utils
|
import litellm.litellm_core_utils
|
||||||
|
from litellm import verbose_logger
|
||||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||||
from litellm.llms.custom_httpx.http_handler import (
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
AsyncHTTPHandler,
|
AsyncHTTPHandler,
|
||||||
|
@ -730,6 +731,7 @@ class ModelResponseIterator:
|
||||||
|
|
||||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||||
try:
|
try:
|
||||||
|
verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
|
||||||
type_chunk = chunk.get("type", "") or ""
|
type_chunk = chunk.get("type", "") or ""
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
|
@ -770,9 +772,7 @@ class ModelResponseIterator:
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": content_block_start["content_block"]["name"],
|
"name": content_block_start["content_block"]["name"],
|
||||||
"arguments": json.dumps(
|
"arguments": "",
|
||||||
content_block_start["content_block"]["input"]
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
elif type_chunk == "message_delta":
|
elif type_chunk == "message_delta":
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -38,6 +38,9 @@ model_list:
|
||||||
- litellm_params:
|
- litellm_params:
|
||||||
model: anthropic.claude-3-sonnet-20240229-v1:0
|
model: anthropic.claude-3-sonnet-20240229-v1:0
|
||||||
model_name: bedrock-anthropic-claude-3
|
model_name: bedrock-anthropic-claude-3
|
||||||
|
- litellm_params:
|
||||||
|
model: claude-3-haiku-20240307
|
||||||
|
model_name: anthropic-claude-3
|
||||||
- litellm_params:
|
- litellm_params:
|
||||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||||
api_key: os.environ/AZURE_API_KEY
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
|
|
@ -2559,9 +2559,16 @@ def streaming_and_function_calling_format_tests(idx, chunk):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
|
"model",
|
||||||
|
[
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"claude-3-haiku-20240307",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_streaming_and_function_calling(model):
|
def test_streaming_and_function_calling(model):
|
||||||
|
import json
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
@ -2594,6 +2601,7 @@ def test_streaming_and_function_calling(model):
|
||||||
tool_choice="required",
|
tool_choice="required",
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
|
json_str = ""
|
||||||
for idx, chunk in enumerate(response):
|
for idx, chunk in enumerate(response):
|
||||||
# continue
|
# continue
|
||||||
print("\n{}\n".format(chunk))
|
print("\n{}\n".format(chunk))
|
||||||
|
@ -2604,7 +2612,10 @@ def test_streaming_and_function_calling(model):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
chunk.choices[0].delta.tool_calls[0].function.arguments, str
|
||||||
)
|
)
|
||||||
# assert False
|
if chunk.choices[0].delta.tool_calls is not None:
|
||||||
|
json_str += chunk.choices[0].delta.tool_calls[0].function.arguments
|
||||||
|
|
||||||
|
print(json.loads(json_str))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue