forked from phoenix-oss/llama-stack-mirror
test(verification): add streaming tool calling test (#1933)
# What does this PR do? ## Test Plan --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1933). * #1934 * __->__ #1933
This commit is contained in:
parent
49955a06b1
commit
a4cc4b7e31
1 changed file with 55 additions and 0 deletions
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -225,6 +226,60 @@ def test_chat_non_streaming_tool_calling(request, openai_client, model, provider
|
||||||
# TODO: add detailed type validation
|
# TODO: add detailed type validation
|
||||||
|
|
||||||
|
|
||||||
|
def _accumulate_streaming_tool_calls(stream):
    """Merge streamed tool-call deltas into complete tool calls.

    Args:
        stream: Iterable of streaming chat-completion chunks. Each chunk is
            expected to expose ``choices``; each choice a ``delta`` whose
            ``tool_calls`` is either ``None`` or a list of partial tool calls
            carrying ``id``, ``type`` and ``function`` (``name``/``arguments``).

    Returns:
        A dict mapping tool-call id to a dict with the keys ``id``, ``type``,
        ``name`` and ``arguments`` (the concatenated argument fragments).
    """
    tool_calls_buffer = {}
    current_id = None

    for chunk in stream:
        # A chunk may carry no choices at all (e.g. a trailing usage-only
        # chunk); indexing choices[0] unconditionally would raise IndexError.
        if not chunk.choices:
            continue

        delta = chunk.choices[0].delta
        if delta is None or not delta.tool_calls:
            continue

        for tool_call_delta in delta.tool_calls:
            # Only the first fragment of a tool call carries its id; later
            # fragments are attributed to the most recently seen id.
            if tool_call_delta.id:
                current_id = tool_call_delta.id
            call_id = current_id
            if call_id is None:
                # No id seen yet: nothing to attribute this fragment to.
                # Skipping avoids creating a bogus entry keyed by None.
                continue

            entry = tool_calls_buffer.setdefault(
                call_id,
                {"id": call_id, "type": None, "name": None, "arguments": ""},
            )

            # type/name may not be present on the very first fragment, so fill
            # them in from whichever fragment provides them first.
            if entry["type"] is None and tool_call_delta.type:
                entry["type"] = tool_call_delta.type

            func_delta = tool_call_delta.function
            if func_delta is None:
                continue
            if entry["name"] is None and func_delta.name:
                entry["name"] = func_delta.name
            if func_delta.arguments:
                entry["arguments"] += func_delta.arguments

    return tool_calls_buffer


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
    """Verify that a streamed chat completion yields exactly one ``get_weather`` tool call.

    The streamed tool-call fragments are reassembled and the resulting call is
    checked for a non-empty id, the expected function name, and JSON arguments
    whose ``location`` mentions San Francisco.
    """
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        stream=True,
    )

    tool_calls_buffer = _accumulate_streaming_tool_calls(stream)

    assert len(tool_calls_buffer) == 1, f"expected exactly one tool call, got {len(tool_calls_buffer)}"
    for call in tool_calls_buffer.values():
        assert len(call["id"]) > 0
        assert call["name"] == "get_weather", f"unexpected tool name: {call['name']!r}"

        args_dict = json.loads(call["arguments"])
        assert "san francisco" in args_dict["location"].lower()
|
||||||
|
|
||||||
|
|
||||||
# --- Helper functions (structured output validation) ---
|
# --- Helper functions (structured output validation) ---
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue