def _collect_streamed_tool_calls(stream):
    """Merge streamed tool-call deltas into one record per tool-call id.

    Args:
        stream: Iterable of chat-completion chunks. Each chunk exposes
            ``choices``; the first choice's ``delta.tool_calls`` (when set)
            is a list of partial tool calls with ``id``, ``type`` and
            ``function`` (``name`` / ``arguments``) attributes.

    Returns:
        A dict keyed by tool-call id. Each value is a dict with the keys
        ``"id"``, ``"type"``, ``"name"`` and ``"arguments"``, where
        ``"arguments"`` is the concatenation of every streamed fragment.
    """
    tool_calls_buffer = {}
    # Only some deltas carry an id; deltas without one are attributed to the
    # most recently seen id.
    current_id = None

    for chunk in stream:
        # A chunk may carry no choices at all (e.g. a usage-only chunk), so
        # indexing choices[0] unconditionally would raise IndexError.
        if not chunk.choices:
            continue

        delta = chunk.choices[0].delta
        if not delta.tool_calls:
            continue

        for tool_call_delta in delta.tool_calls:
            if tool_call_delta.id:
                current_id = tool_call_delta.id

            entry = tool_calls_buffer.setdefault(
                current_id,
                {
                    "id": current_id,
                    "type": tool_call_delta.type,
                    "name": None,
                    "arguments": "",
                },
            )
            if entry["type"] is None and tool_call_delta.type:
                entry["type"] = tool_call_delta.type

            # `function` is optional on a streamed tool-call delta.
            func_delta = tool_call_delta.function
            if func_delta is None:
                continue

            # The name is not guaranteed to arrive in the same delta as the
            # id, so fill it in whenever it first shows up.
            if entry["name"] is None and func_delta.name:
                entry["name"] = func_delta.name
            if func_delta.arguments:
                entry["arguments"] += func_delta.arguments

    return tool_calls_buffer


@pytest.mark.parametrize(
    "case",
    chat_completion_test_cases["test_tool_calling"]["test_params"]["case"],
    ids=case_id_generator,
)
def test_chat_streaming_tool_calling(request, openai_client, model, provider, verification_config, case):
    """Check that a streamed chat completion yields one ``get_weather`` tool call.

    Sends the case's messages and tools with ``stream=True``, reassembles the
    streamed tool-call deltas, and asserts that exactly one tool call was
    produced, that it has a non-empty id, that it targets ``get_weather``, and
    that its JSON arguments contain a ``location`` mentioning San Francisco.
    The test is skipped when ``should_skip_test`` says so for this
    provider/model combination.
    """
    test_name_base = get_base_test_name(request)
    if should_skip_test(verification_config, provider, model, test_name_base):
        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")

    stream = openai_client.chat.completions.create(
        model=model,
        messages=case["input"]["messages"],
        tools=case["input"]["tools"],
        stream=True,
    )

    tool_calls_buffer = _collect_streamed_tool_calls(stream)

    assert len(tool_calls_buffer) == 1
    for call in tool_calls_buffer.values():
        # Truthiness also covers a missing (None) id, which len() would turn
        # into a TypeError instead of a readable assertion failure.
        assert call["id"]
        assert call["name"] == "get_weather"

        args_dict = json.loads(call["arguments"])
        assert "san francisco" in args_dict["location"].lower()