fix: multiple tool calls in remote-vllm chat_completion (#2161)

# What does this PR do?

This fixes an issue in how the remote-vllm provider used its tool_call_buf for
streaming tool calls: parameters from multiple different tool calls were
concatenated into a single buffer instead of being aggregated separately for
each tool call.
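
To make the intended behavior concrete: OpenAI-style streaming responses deliver each tool call as a sequence of deltas keyed by an `index` field, so arguments have to be accumulated per index rather than in one shared buffer. Below is a minimal, self-contained sketch of that aggregation pattern (illustrative only, not the provider code; the `ToolCallBuf` class and the dict-shaped `deltas` input are hypothetical stand-ins):

```
import json
from dataclasses import dataclass


@dataclass
class ToolCallBuf:
    # hypothetical accumulator, loosely mirroring an UnparseableToolCall-style buffer
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def aggregate_tool_call_deltas(deltas):
    """Rebuild each streamed tool call separately, grouping delta fragments by index."""
    bufs: dict[int, ToolCallBuf] = {}
    for delta in deltas:
        # each delta is assumed to carry optional (index, id, name, arguments) fragments
        buf = bufs.setdefault(delta["index"], ToolCallBuf())
        buf.call_id += delta.get("id") or ""
        buf.tool_name += delta.get("name") or ""
        buf.arguments += delta.get("arguments") or ""
    # parse each call's arguments on its own; concatenating across calls would corrupt the JSON
    return [
        (buf.call_id, buf.tool_name, json.loads(buf.arguments or "{}"))
        for _, buf in sorted(bufs.items())
    ]


# example: two interleaved calls, with one call's arguments split across deltas
deltas = [
    {"index": 0, "id": "call_a", "name": "get_boiling_point", "arguments": '{"liquid_name": '},
    {"index": 1, "id": "call_b", "name": "get_boiling_point", "arguments": '{"liquid_name": "polyjuice", "celsius": false}'},
    {"index": 0, "arguments": '"polyjuice", "celsius": true}'},
]
print(aggregate_tool_call_deltas(deltas))
```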

It also fixes an issue found while digging into that one: we were accidentally
mixing the JSON string form of tool call parameters with the string
representation of the Python form, which meant we'd end up with single quotes
in what should be double-quoted JSON strings.
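
A generic Python illustration of that quoting problem (not the provider code): `str()` on a dict of arguments yields a Python repr with single quotes, which is not parseable as JSON, whereas `json.dumps()` yields the double-quoted form we actually want to accumulate.

```
import json

args = {"liquid_name": "polyjuice", "celsius": True}

print(str(args))         # {'liquid_name': 'polyjuice', 'celsius': True}  <- not valid JSON
print(json.dumps(args))  # {"liquid_name": "polyjuice", "celsius": true}  <- valid JSON

# json.loads rejects the repr form but round-trips the dumped form
try:
    json.loads(str(args))
except json.JSONDecodeError as err:
    print(f"repr form fails to parse: {err}")
assert json.loads(json.dumps(args)) == args
```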

Closes #1120

## Test Plan

The following tests now pass 100% for the remote-vllm provider, where some of
the test_text_inference tests were failing before this change:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/inference/test_text_inference.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"

VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/inference/test_vision_inference.py --vision-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"

```

All but one of the agent tests are passing (including the multi-tool one). See
the PR at https://github.com/vllm-project/vllm/pull/17917 and a gist at
https://gist.github.com/bbrowning/4734240ce96b4264340caa9584e47c9e for the
changes needed on the vLLM side, which will have to land upstream.

Agent tests:

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/integration/agents/test_agents.py --text-model "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
```

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>

Commit 10b1056dea (parent bb5fca9521) by Ben Browning, 2025-05-15 14:23:29 -04:00, committed by GitHub.
4 changed files with 225 additions and 34 deletions.

Streaming tool-call handling in the remote-vllm provider now keeps one buffer per tool call index and emits an explicit start chunk:

```
@@ -162,7 +162,7 @@ def _process_vllm_chat_completion_end_of_stream(
     finish_reason: str | None,
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
-    tool_call_buf: UnparseableToolCall,
+    tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
 ) -> list[OpenAIChatCompletionChunk]:
     chunks = []
@@ -171,9 +171,8 @@ def _process_vllm_chat_completion_end_of_stream(
     else:
         stop_reason = StopReason.end_of_message

-    if tool_call_buf.tool_name:
-        # at least one tool call request is received
+    tool_call_bufs = tool_call_bufs or {}
+    for _index, tool_call_buf in sorted(tool_call_bufs.items()):
         args_str = tool_call_buf.arguments or "{}"
         try:
             args = json.loads(args_str)
@@ -225,8 +224,14 @@ def _process_vllm_chat_completion_end_of_stream(
 async def _process_vllm_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
 ) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
+    tool_call_bufs: dict[str, UnparseableToolCall] = {}
     end_of_stream_processed = False

     async for chunk in stream:
@@ -235,17 +240,22 @@ async def _process_vllm_chat_completion_stream_response(
             return
         choice = chunk.choices[0]
         if choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
+            for delta_tool_call in choice.delta.tool_calls:
+                tool_call = convert_tool_call(delta_tool_call)
+                if delta_tool_call.index not in tool_call_bufs:
+                    tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+                tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                tool_call_buf.tool_name += str(tool_call.tool_name)
+                tool_call_buf.call_id += tool_call.call_id
+                tool_call_buf.arguments += (
+                    tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+                )
         if choice.finish_reason:
             chunks = _process_vllm_chat_completion_end_of_stream(
                 finish_reason=choice.finish_reason,
                 last_chunk_content=choice.delta.content,
                 current_event_type=event_type,
-                tool_call_buf=tool_call_buf,
+                tool_call_bufs=tool_call_bufs,
             )
             for c in chunks:
                 yield c
@@ -266,7 +276,7 @@ async def _process_vllm_chat_completion_stream_response(
         # the stream ended without a chunk containing finish_reason - we have to generate the
         # respective completion chunks manually
         chunks = _process_vllm_chat_completion_end_of_stream(
-            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_buf=tool_call_buf
+            finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
         )
         for c in chunks:
             yield c
```

The OpenAI-compat conversion of assistant messages now always serializes tool call arguments as a JSON string:

```
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             tool_name = tc.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
+
+            # arguments_json can be None, so attempt it first and fall back to arguments
+            if hasattr(tc, "arguments_json") and tc.arguments_json:
+                arguments = tc.arguments_json
+            else:
+                arguments = json.dumps(tc.arguments)
             result["tool_calls"].append(
                 {
                     "id": tc.call_id,
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                        "arguments": arguments,
                     },
                 }
             )
```

The agent integration tests were adjusted to also run against Llama 4 models and to tolerate configurations without input/output shields:

```
@@ -266,6 +266,7 @@ def test_builtin_tool_web_search(llama_stack_client, agent_config):
     assert found_tool_execution


+@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
 def test_builtin_tool_code_execution(llama_stack_client, agent_config):
     agent_config = {
         **agent_config,
@@ -346,7 +347,7 @@ def test_custom_tool(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -420,7 +421,7 @@ def run_agent_with_tool_choice(client, agent_config, tool_choice):
         messages=[
             {
                 "role": "user",
-                "content": "What is the boiling point of polyjuice?",
+                "content": "What is the boiling point of the liquid polyjuice in celsius?",
             },
         ],
         session_id=session_id,
@@ -674,8 +675,8 @@ def test_create_turn_response(llama_stack_client, agent_config, client_tools):
 def test_multi_tool_calls(llama_stack_client, agent_config):
-    if "gpt" not in agent_config["model"]:
-        pytest.xfail("Only tested on GPT models")
+    if "gpt" not in agent_config["model"] and "llama-4" not in agent_config["model"].lower():
+        pytest.xfail("Only tested on GPT and Llama 4 models")

     agent_config = {
         **agent_config,
@@ -689,23 +690,34 @@ def test_multi_tool_calls(llama_stack_client, agent_config):
         messages=[
             {
                 "role": "user",
-                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?",
+                "content": "Call get_boiling_point twice to answer: What is the boiling point of polyjuice in both celsius and fahrenheit?.\nUse the tool responses to answer the question.",
             },
         ],
         session_id=session_id,
         stream=False,
     )
     steps = response.steps
-    assert len(steps) == 7
-    assert steps[0].step_type == "shield_call"
-    assert steps[1].step_type == "inference"
-    assert steps[2].step_type == "shield_call"
-    assert steps[3].step_type == "tool_execution"
-    assert steps[4].step_type == "shield_call"
-    assert steps[5].step_type == "inference"
-    assert steps[6].step_type == "shield_call"
-
-    tool_execution_step = steps[3]
+    has_input_shield = agent_config.get("input_shields")
+    has_output_shield = agent_config.get("output_shields")
+    assert len(steps) == 3 + (2 if has_input_shield else 0) + (2 if has_output_shield else 0)
+    if has_input_shield:
+        assert steps[0].step_type == "shield_call"
+        steps.pop(0)
+    assert steps[0].step_type == "inference"
+    if has_output_shield:
+        assert steps[1].step_type == "shield_call"
+        steps.pop(1)
+    assert steps[1].step_type == "tool_execution"
+    tool_execution_step = steps[1]
+    if has_input_shield:
+        assert steps[2].step_type == "shield_call"
+        steps.pop(2)
+    assert steps[2].step_type == "inference"
+    if has_output_shield:
+        assert steps[3].step_type == "shield_call"
+        steps.pop(3)
+
     assert len(tool_execution_step.tool_calls) == 2
     assert tool_execution_step.tool_calls[0].tool_name.startswith("get_boiling_point")
     assert tool_execution_step.tool_calls[1].tool_name.startswith("get_boiling_point")
```

The unit tests for the vLLM stream-processing helpers were updated for the new start chunk and extended with coverage for dict-form arguments and multiple tool calls:

```
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
 from openai.types.model import Model as OpenAIModel

 from llama_stack.apis.inference import (
@@ -206,8 +212,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 1
-    assert chunks[0].event.stop_reason == StopReason.end_of_turn
+    assert len(chunks) == 2
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "complete"
+    assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_streaming_arguments_dict():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments="",
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 3
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "complete"
+
+
+@pytest.mark.asyncio
+async def test_multiple_tool_calls():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=2,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="multiple",
+                                    arguments='{"first_number": 4, "second_number": 7}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 4
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "progress"
+    assert chunks[2].event.delta.type == "tool_call"
+    assert chunks[2].event.delta.parse_status.value == "succeeded"
+    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[3].event.event_type.value == "complete"


 @pytest.mark.asyncio
@@ -231,7 +393,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 0
+    assert len(chunks) == 1
+    assert chunks[0].event.event_type.value == "start"


 def test_chat_completion_doesnt_block_event_loop(caplog):
@@ -369,7 +532,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -422,7 +585,7 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
@@ -471,7 +634,7 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
             yield chunk

     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 2
+    assert len(chunks) == 3
     assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete
     assert chunks[-2].event.delta.type == "tool_call"
     assert chunks[-2].event.delta.tool_call.tool_name == mock_tool_name
```