Responses API: Finish wiring up function tool calls

This finishes the plumbing for function tool calls and adds a basic
verification test (which passes for me locally against Llama 4 Scout
in vLLM).

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-04-30 17:01:00 -04:00
parent 1990df2c50
commit 924213a689
6 changed files with 148 additions and 6 deletions
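
For orientation, here is a rough sketch of the flow this commit completes, written against the OpenAI Python client. The base URL and model name are illustrative placeholders, not values taken from this commit:

```python
from openai import OpenAI

# Placeholder endpoint and model; adjust for your Llama Stack deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder
    input="What's the weather like in San Francisco?",
    tools=[
        {
            "type": "function",
            "name": "get_weather",
            "description": "Get current temperature for a given location.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
                "additionalProperties": False,
            },
        }
    ],
    stream=False,
)

# With a client-defined function tool, the server does not execute anything;
# it returns the model's tool call as a "function_call" output item.
call = response.output[0]
print(call.type, call.name, call.arguments)
```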

@@ -6920,16 +6920,55 @@
         },
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
         }
       ],
       "discriminator": {
         "propertyName": "type",
         "mapping": {
           "message": "#/components/schemas/OpenAIResponseMessage",
-          "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
+          "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
+          "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
         }
       }
     },
+    "OpenAIResponseOutputMessageFunctionToolCall": {
+      "type": "object",
+      "properties": {
+        "arguments": {
+          "type": "string"
+        },
+        "call_id": {
+          "type": "string"
+        },
+        "name": {
+          "type": "string"
+        },
+        "type": {
+          "type": "string",
+          "const": "function_call",
+          "default": "function_call"
+        },
+        "id": {
+          "type": "string"
+        },
+        "status": {
+          "type": "string"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "arguments",
+        "call_id",
+        "name",
+        "type",
+        "id",
+        "status"
+      ],
+      "title": "OpenAIResponseOutputMessageFunctionToolCall"
+    },
     "OpenAIResponseObjectStream": {
       "oneOf": [
         {
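
For illustration (not part of the diff), a payload conforming to the new schema, with made-up values:

```python
# Illustrative instance of OpenAIResponseOutputMessageFunctionToolCall.
function_call_item = {
    "type": "function_call",
    "id": "fc_6a2f9c0e",  # server-generated item id, prefixed "fc_" (value made up)
    "call_id": "call_abc123",  # id of the underlying chat-completion tool call
    "name": "get_weather",
    "arguments": '{"location": "San Francisco, USA"}',  # JSON-encoded string
    "status": "completed",
}
```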

@@ -4833,11 +4833,40 @@ components:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseMessage'
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
       discriminator:
         propertyName: type
         mapping:
           message: '#/components/schemas/OpenAIResponseMessage'
           web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+          function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+    "OpenAIResponseOutputMessageFunctionToolCall":
+      type: object
+      properties:
+        arguments:
+          type: string
+        call_id:
+          type: string
+        name:
+          type: string
+        type:
+          type: string
+          const: function_call
+          default: function_call
+        id:
+          type: string
+        status:
+          type: string
+      additionalProperties: false
+      required:
+        - arguments
+        - call_id
+        - name
+        - type
+        - id
+        - status
+      title: >-
+        OpenAIResponseOutputMessageFunctionToolCall
     OpenAIResponseObjectStream:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'

@@ -77,8 +77,20 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"
 
 
+@json_schema_type
+class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
+    arguments: str
+    call_id: str
+    name: str
+    type: Literal["function_call"] = "function_call"
+    id: str
+    status: str
+
+
 OpenAIResponseOutput = Annotated[
-    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseMessage
+    | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
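
As a quick sanity sketch (not in the diff), the discriminated union now routes `"type": "function_call"` payloads to the new model; with Pydantic v2 this can be exercised via TypeAdapter:

```python
from pydantic import TypeAdapter

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseOutput,
    OpenAIResponseOutputMessageFunctionToolCall,
)

# The "type" discriminator selects the new model (field values made up).
item = TypeAdapter(OpenAIResponseOutput).validate_python(
    {
        "type": "function_call",
        "id": "fc_123",
        "call_id": "call_abc",
        "name": "get_weather",
        "arguments": '{"location": "San Francisco"}',
        "status": "completed",
    }
)
assert isinstance(item, OpenAIResponseOutputMessageFunctionToolCall)
```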

@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFunction,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -24,6 +25,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
     OpenAIResponsePreviousResponseWithInputItems,
 )
@@ -221,10 +223,28 @@ class OpenAIResponsesImpl:
             chat_response = OpenAIChatCompletion(**chat_response.model_dump())
 
         output_messages: list[OpenAIResponseOutput] = []
-        if chat_response.choices[0].message.tool_calls:
-            output_messages.extend(
-                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
-            )
+        # TODO: should we check more than choices[0] here?
+        if chat_response.choices[0].message.tool_calls and tools:
+            # TODO: Should we support a mix of custom and builtin tools?
+            # in other words, should we check for more than tools[0]?
+            if isinstance(tools[0], OpenAIResponseInputToolFunction):
+                choice = chat_response.choices[0]
+                for tool_call in choice.message.tool_calls:
+                    output_messages.append(
+                        OpenAIResponseOutputMessageFunctionToolCall(
+                            arguments=tool_call.function.arguments or "",
+                            call_id=tool_call.id,
+                            name=tool_call.function.name or "",
+                            id=f"fc_{uuid.uuid4()}",
+                            status="completed",
+                        )
+                    )
+            else:
+                output_messages.extend(
+                    await self._execute_tool_and_return_final_output(
+                        model, stream, chat_response, messages, temperature
+                    )
+                )
         else:
             output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
         response = OpenAIResponseObject(
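
Pulled out of the hunk above for readability, the per-tool-call mapping the new branch performs amounts to the following. This is a simplified sketch; the real code loops over choice.message.tool_calls inline rather than calling a helper like this hypothetical one:

```python
import uuid

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseOutputMessageFunctionToolCall,
)


def to_function_call_output(tool_call) -> OpenAIResponseOutputMessageFunctionToolCall:
    """Sketch of the mapping from a chat-completion tool call to a Responses output item."""
    return OpenAIResponseOutputMessageFunctionToolCall(
        arguments=tool_call.function.arguments or "",  # JSON-encoded string from the model
        call_id=tool_call.id,  # lets a later tool response reference this call
        name=tool_call.function.name or "",
        id=f"fc_{uuid.uuid4()}",  # fresh Responses-level item id, prefixed "fc_"
        status="completed",
    )
```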

@@ -31,6 +31,26 @@ test_response_web_search:
             search_context_size: "low"
     output: "128"
 
+test_response_custom_tool:
+  test_name: test_response_custom_tool
+  test_params:
+    case:
+      - case_id: "sf_weather"
+        input: "What's the weather like in San Francisco?"
+        tools:
+          - type: function
+            name: get_weather
+            description: Get current temperature for a given location.
+            parameters:
+              additionalProperties: false
+              properties:
+                location:
+                  description: "City and country e.g. Bogot\xE1, Colombia"
+                  type: string
+              required:
+                - location
+              type: object
+
 test_response_image:
   test_name: test_response_image
   test_params:
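
For reference, once this fixture is loaded, case["tools"] deserializes to the standard OpenAI function-tool shape, roughly:

```python
# What case["tools"] from the fixture above amounts to once loaded from YAML.
tools = [
    {
        "type": "function",
        "name": "get_weather",
        "description": "Get current temperature for a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City and country e.g. Bogotá, Colombia",
                }
            },
            "required": ["location"],
            "additionalProperties": False,
        },
    }
]
```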

@@ -124,6 +124,28 @@ def test_response_non_streaming_web_search(request, openai_client, model, provider, verification_config, case):
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_custom_tool(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_image"]["test_params"]["case"],
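
A closing note on the new custom-tool test above: it checks the call's type, status, and name. A further sanity check on the arguments payload (hypothetical, not part of this commit) could look like:

```python
import json

# Hypothetical extra assertion: the arguments should be valid JSON
# that includes the tool's required "location" parameter.
args = json.loads(response.output[0].arguments)
assert "location" in args
```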