fix unit tests

This commit is contained in:
Ashwin Bharambe 2025-08-12 22:16:48 -07:00
parent a4b8eed576
commit 76e4c52f74
2 changed files with 20 additions and 10 deletions

View file

@ -487,7 +487,8 @@ class OpenAIResponsesImpl:
for tool_call in chunk_choice.delta.tool_calls: for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None) response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
# Create new tool call entry if this is the first chunk for this index # Create new tool call entry if this is the first chunk for this index
if response_tool_call is None: is_new_tool_call = response_tool_call is None
if is_new_tool_call:
tool_call_dict: dict[str, Any] = tool_call.model_dump() tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None) tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
@ -524,10 +525,11 @@ class OpenAIResponsesImpl:
sequence_number=sequence_number, sequence_number=sequence_number,
) )
# Accumulate arguments for final response # Accumulate arguments for final response (only for subsequent chunks)
response_tool_call.function.arguments = ( if not is_new_tool_call:
response_tool_call.function.arguments or "" response_tool_call.function.arguments = (
) + tool_call.function.arguments response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Emit function_call_arguments.done events for completed tool calls # Emit function_call_arguments.done events for completed tool calls
for tool_call_index in sorted(chat_response_tool_calls.keys()): for tool_call_index in sorted(chat_response_tool_calls.keys()):

View file

@ -272,7 +272,9 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Check that we got the content from our mocked tool execution result # Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result] chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed # Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 6
# Verify inference API was called correctly (after iterating over result) # Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
@ -284,11 +286,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[0].type == "response.created" assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0 assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.delta"
assert chunks[3].type == "response.function_call_arguments.done"
assert chunks[4].type == "response.output_item.done"
# Check response.completed event (should have the tool call) # Check response.completed event (should have the tool call)
assert chunks[1].type == "response.completed" assert chunks[5].type == "response.completed"
assert len(chunks[1].response.output) == 1 assert len(chunks[5].response.output) == 1
assert chunks[1].response.output[0].type == "function_call" assert chunks[5].response.output[0].type == "function_call"
assert chunks[1].response.output[0].name == "get_weather" assert chunks[5].response.output[0].name == "get_weather"
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):