forked from phoenix/litellm-mirror
fix(anthropic.py): fix anthropic tool calling + streaming
Fixes https://github.com/BerriAI/litellm/issues/4537
This commit is contained in:
parent 86632f6da0
commit f2dabc65be
6 changed files with 19 additions and 8 deletions
@@ -12,6 +12,7 @@ import requests  # type: ignore

import litellm
import litellm.litellm_core_utils
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
@@ -730,6 +731,7 @@ class ModelResponseIterator:

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            verbose_logger.debug(f"\n\nRaw chunk:\n{chunk}\n")
            type_chunk = chunk.get("type", "") or ""

            text = ""
@@ -770,9 +772,7 @@ class ModelResponseIterator:
                         "type": "function",
                         "function": {
                             "name": content_block_start["content_block"]["name"],
-                            "arguments": json.dumps(
-                                content_block_start["content_block"]["input"]
-                            ),
+                            "arguments": "",
                         },
                     }
             elif type_chunk == "message_delta":
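Background on the chunk_parser change above: in Anthropic's Messages streaming API, a content_block_start event for a tool_use block arrives with an empty input object, and the tool's arguments are streamed afterwards as input_json_delta fragments on content_block_delta events. Serializing the start event's input with json.dumps therefore prepended a stray "{}" to the argument string that clients concatenate, so the parser now emits an empty string at block start. The sketch below illustrates that event flow; the event payloads are hand-written examples and accumulate_tool_arguments is a made-up helper, not code from this repository.

```python
import json

# Minimal sketch (not from this repo) of assembling Anthropic tool-use
# streaming events client-side. Event shapes follow the public Messages API.

def accumulate_tool_arguments(events: list[dict]) -> str:
    """Concatenate streamed tool-call argument fragments into one JSON string."""
    arguments = ""
    for event in events:
        if event.get("type") == "content_block_start":
            block = event["content_block"]
            if block.get("type") == "tool_use":
                # block["input"] is still empty here, so emit "" rather than
                # json.dumps(block["input"]), which would prepend a stray "{}".
                arguments += ""
        elif event.get("type") == "content_block_delta":
            delta = event["delta"]
            if delta.get("type") == "input_json_delta":
                arguments += delta["partial_json"]
    return arguments


events = [
    {"type": "content_block_start", "index": 1,
     "content_block": {"type": "tool_use", "id": "toolu_01",
                       "name": "get_current_weather", "input": {}}},
    {"type": "content_block_delta", "index": 1,
     "delta": {"type": "input_json_delta", "partial_json": '{"location":'}},
    {"type": "content_block_delta", "index": 1,
     "delta": {"type": "input_json_delta", "partial_json": ' "Boston, MA"}'}},
]

print(json.loads(accumulate_tool_arguments(events)))  # {'location': 'Boston, MA'}
```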
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -38,6 +38,9 @@ model_list:
   - litellm_params:
       model: anthropic.claude-3-sonnet-20240229-v1:0
     model_name: bedrock-anthropic-claude-3
+  - litellm_params:
+      model: claude-3-haiku-20240307
+    model_name: anthropic-claude-3
   - litellm_params:
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_key: os.environ/AZURE_API_KEY
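The new model_list entry exposes claude-3-haiku-20240307 under the model_name anthropic-claude-3. Assuming a proxy started with this config on the default port and a placeholder master key (neither is specified by this commit), a client could exercise streaming tool calls against it through the OpenAI-compatible endpoint roughly as follows; the tool schema here is an illustrative placeholder as well.

```python
# Rough sketch of calling the proxy entry added above; assumes a proxy running
# locally on port 4000 and a placeholder key. Not part of this commit.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

stream = client.chat.completions.create(
    model="anthropic-claude-3",  # routed to claude-3-haiku-20240307 by the config above
    messages=[{"role": "user", "content": "What's the weather in Boston today?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }],
    tool_choice="required",
    stream=True,
)

arguments = ""
for chunk in stream:
    if not chunk.choices:
        continue
    tool_calls = chunk.choices[0].delta.tool_calls
    if tool_calls:
        arguments += tool_calls[0].function.arguments or ""
print(arguments)  # should be a single valid JSON object once the stream ends
```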
@@ -2559,9 +2559,16 @@ def streaming_and_function_calling_format_tests(idx, chunk):


 @pytest.mark.parametrize(
-    "model", ["gpt-3.5-turbo", "anthropic.claude-3-sonnet-20240229-v1:0"]
+    "model",
+    [
+        "gpt-3.5-turbo",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "claude-3-haiku-20240307",
+    ],
 )
 def test_streaming_and_function_calling(model):
+    import json
+
     tools = [
         {
             "type": "function",
@@ -2594,6 +2601,7 @@ def test_streaming_and_function_calling(model):
             tool_choice="required",
         )  # type: ignore
         # Add any assertions here to check the response
+        json_str = ""
         for idx, chunk in enumerate(response):
             # continue
             print("\n{}\n".format(chunk))
@@ -2604,7 +2612,10 @@ def test_streaming_and_function_calling(model):
                assert isinstance(
                    chunk.choices[0].delta.tool_calls[0].function.arguments, str
                )
            # assert False
            if chunk.choices[0].delta.tool_calls is not None:
                json_str += chunk.choices[0].delta.tool_calls[0].function.arguments

        print(json.loads(json_str))
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
        raise e
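The test concatenates every argument fragment into a single json_str, which is enough when the model emits one tool call. A consumer handling responses that may stream several tool calls would typically bucket the fragments by tool-call index before parsing, along these lines; the chunk shape mirrors the OpenAI-style deltas litellm yields, but the helper itself is only an illustration, not part of this change.

```python
# Generalisation of the test's json_str accumulation for responses that may
# stream more than one tool call; illustrative only.
import json
from collections import defaultdict


def collect_tool_calls(chunks) -> dict[int, dict]:
    calls: dict[int, dict] = defaultdict(lambda: {"name": None, "arguments": ""})
    for chunk in chunks:
        if not chunk.choices:
            continue
        for tc in chunk.choices[0].delta.tool_calls or []:
            entry = calls[tc.index]
            if tc.function.name:
                entry["name"] = tc.function.name
            entry["arguments"] += tc.function.arguments or ""
    # Parse each accumulated argument string once the stream is exhausted
    return {i: {"name": c["name"], "arguments": json.loads(c["arguments"])}
            for i, c in calls.items()}
```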