test(verification): add streaming tool calling test (#1933)

# What does this PR do?


## Test Plan

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1933).
* #1934
* __->__ #1933
This commit is contained in:
ehhuang 2025-04-10 16:58:06 -07:00 committed by GitHub
parent 49955a06b1
commit a4cc4b7e31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import re
from typing import Any
@ -225,6 +226,60 @@ def test_chat_non_streaming_tool_calling(request, openai_client, model, provider
# TODO: add detailed type validation
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
ids=case_id_generator,
)
def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
stream = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
tools=case["input"]["tools"],
stream=True,
)
# Accumulate partial tool_calls here
tool_calls_buffer = {}
current_id = None
# Process streaming chunks
for chunk in stream:
choice = chunk.choices[0]
delta = choice.delta
if delta.tool_calls is None:
continue
for tool_call_delta in delta.tool_calls:
if tool_call_delta.id:
current_id = tool_call_delta.id
call_id = current_id
func_delta = tool_call_delta.function
if call_id not in tool_calls_buffer:
tool_calls_buffer[call_id] = {
"id": call_id,
"type": tool_call_delta.type,
"name": func_delta.name,
"arguments": "",
}
if func_delta.arguments:
tool_calls_buffer[call_id]["arguments"] += func_delta.arguments
assert len(tool_calls_buffer) == 1
for call in tool_calls_buffer.values():
assert len(call["id"]) > 0
assert call["name"] == "get_weather"
args_dict = json.loads(call["arguments"])
assert "san francisco" in args_dict["location"].lower()
# --- Helper functions (structured output validation) ---