Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)
Responses API: Finish wiring up function tool calls

This finishes the plumbing for function tool calls and adds a basic
verification test (which passes locally against Llama 4 Scout on vllm).

Signed-off-by: Ben Browning <bbrownin@redhat.com>

commit 924213a689 (parent 1990df2c50)
6 changed files with 148 additions and 6 deletions
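For context, here is a rough sketch of the flow this change enables, modeled
on the verification test added below. The base URL, API key, and model name
are placeholders, not part of this commit:

    from openai import OpenAI

    # Any OpenAI client pointed at a Llama Stack server should work the same
    # way; the URL and model below are just examples.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

    response = client.responses.create(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        input="What's the weather like in San Francisco?",
        tools=[
            {
                "type": "function",
                "name": "get_weather",
                "description": "Get current temperature for a given location.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                    "additionalProperties": False,
                },
            }
        ],
        stream=False,
    )

    # Instead of a final text answer, the model's tool call comes back as a
    # "function_call" output item for the caller to execute.
    call = response.output[0]
    print(call.type, call.name, call.arguments)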
docs/_static/llama-stack-spec.html (vendored): 41 changed lines
@@ -6920,16 +6920,55 @@
             },
             {
               "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
+            },
+            {
+              "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
             }
           ],
           "discriminator": {
             "propertyName": "type",
             "mapping": {
               "message": "#/components/schemas/OpenAIResponseMessage",
-              "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
+              "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
+              "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
             }
           }
         },
+        "OpenAIResponseOutputMessageFunctionToolCall": {
+          "type": "object",
+          "properties": {
+            "arguments": {
+              "type": "string"
+            },
+            "call_id": {
+              "type": "string"
+            },
+            "name": {
+              "type": "string"
+            },
+            "type": {
+              "type": "string",
+              "const": "function_call",
+              "default": "function_call"
+            },
+            "id": {
+              "type": "string"
+            },
+            "status": {
+              "type": "string"
+            }
+          },
+          "additionalProperties": false,
+          "required": [
+            "arguments",
+            "call_id",
+            "name",
+            "type",
+            "id",
+            "status"
+          ],
+          "title": "OpenAIResponseOutputMessageFunctionToolCall"
+        },
         "OpenAIResponseObjectStream": {
           "oneOf": [
             {
docs/_static/llama-stack-spec.yaml (vendored): 29 changed lines
@@ -4833,11 +4833,40 @@ components:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseMessage'
       - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+      - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
     discriminator:
       propertyName: type
       mapping:
         message: '#/components/schemas/OpenAIResponseMessage'
         web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+  "OpenAIResponseOutputMessageFunctionToolCall":
+    type: object
+    properties:
+      arguments:
+        type: string
+      call_id:
+        type: string
+      name:
+        type: string
+      type:
+        type: string
+        const: function_call
+        default: function_call
+      id:
+        type: string
+      status:
+        type: string
+    additionalProperties: false
+    required:
+      - arguments
+      - call_id
+      - name
+      - type
+      - id
+      - status
+    title: >-
+      OpenAIResponseOutputMessageFunctionToolCall
   OpenAIResponseObjectStream:
     oneOf:
       - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
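For reference, an output item conforming to the schema added in both spec
files might look like this (a hand-written illustration, not taken from the
diff; "arguments" carries the raw JSON string produced by the model):

    function_call_item = {
        "type": "function_call",
        "id": "fc_0123456789ab",
        "call_id": "call_abc123",
        "name": "get_weather",
        "arguments": '{"location": "San Francisco, USA"}',
        "status": "completed",
    }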
@@ -77,8 +77,20 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
     type: Literal["web_search_call"] = "web_search_call"
 
 
+@json_schema_type
+class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
+    arguments: str
+    call_id: str
+    name: str
+    type: Literal["function_call"] = "function_call"
+    id: str
+    status: str
+
+
 OpenAIResponseOutput = Annotated[
-    OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseMessage
+    | OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
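A quick round-trip sketch of the new model and the widened union (a local
illustration, assuming pydantic v2, which these field declarations imply;
the example values are made up):

    from pydantic import TypeAdapter

    from llama_stack.apis.agents.openai_responses import (
        OpenAIResponseOutput,
        OpenAIResponseOutputMessageFunctionToolCall,
    )

    item = OpenAIResponseOutputMessageFunctionToolCall(
        arguments='{"location": "San Francisco, USA"}',
        call_id="call_abc123",
        name="get_weather",
        id="fc_0123456789ab",
        status="completed",
    )

    # "type" defaults to "function_call", so the discriminated union can
    # route a plain dict back to the right class.
    parsed = TypeAdapter(OpenAIResponseOutput).validate_python(item.model_dump())
    assert isinstance(parsed, OpenAIResponseOutputMessageFunctionToolCall)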
@@ -17,6 +17,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
+    OpenAIResponseInputToolFunction,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
@@ -24,6 +25,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectStreamResponseCreated,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
     OpenAIResponsePreviousResponseWithInputItems,
 )
@@ -221,10 +223,28 @@ class OpenAIResponsesImpl:
             chat_response = OpenAIChatCompletion(**chat_response.model_dump())
 
         output_messages: list[OpenAIResponseOutput] = []
-        if chat_response.choices[0].message.tool_calls:
-            output_messages.extend(
-                await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature)
-            )
+        # TODO: should we check more than choices[0] here?
+        if chat_response.choices[0].message.tool_calls and tools:
+            # TODO: Should we support a mix of custom and builtin tools?
+            # in other words, should we check for more than tools[0]?
+            if isinstance(tools[0], OpenAIResponseInputToolFunction):
+                choice = chat_response.choices[0]
+                for tool_call in choice.message.tool_calls:
+                    output_messages.append(
+                        OpenAIResponseOutputMessageFunctionToolCall(
+                            arguments=tool_call.function.arguments or "",
+                            call_id=tool_call.id,
+                            name=tool_call.function.name or "",
+                            id=f"fc_{uuid.uuid4()}",
+                            status="completed",
+                        )
+                    )
+            else:
+                output_messages.extend(
+                    await self._execute_tool_and_return_final_output(
+                        model, stream, chat_response, messages, temperature
+                    )
+                )
         else:
             output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
         response = OpenAIResponseObject(
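Note that in the new branch the server does not execute custom function
tools; it returns the function_call item for the caller to run (builtin
tools still take the _execute_tool_and_return_final_output path). A minimal
client-side dispatch might look like the following sketch; get_weather is a
hypothetical local function, and sending its result back in a follow-up turn
is not wired up by this commit:

    import json

    def get_weather(location: str) -> str:
        # Hypothetical local implementation of the custom tool.
        return f"70 degrees and foggy in {location}"

    def handle_output_item(item) -> None:
        # Dispatch on the "type" discriminator the server now emits.
        if item.type == "function_call":
            args = json.loads(item.arguments or "{}")
            print(f"{item.name}({args}) -> {get_weather(**args)}")
        elif item.type == "message":
            print(item.content)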
@@ -31,6 +31,26 @@ test_response_web_search:
         search_context_size: "low"
       output: "128"
 
+test_response_custom_tool:
+  test_name: test_response_custom_tool
+  test_params:
+    case:
+      - case_id: "sf_weather"
+        input: "What's the weather like in San Francisco?"
+        tools:
+          - type: function
+            name: get_weather
+            description: Get current temperature for a given location.
+            parameters:
+              additionalProperties: false
+              properties:
+                location:
+                  description: "City and country e.g. Bogot\xE1, Colombia"
+                  type: string
+              required:
+                - location
+              type: object
+
 test_response_image:
   test_name: test_response_image
   test_params:
@@ -124,6 +124,28 @@ def test_response_non_streaming_web_search(request, openai_client, model, provider, verification_config, case):
     assert case["output"].lower() in response.output_text.lower().strip()
 
 
+@pytest.mark.parametrize(
+    "case",
+    responses_test_cases["test_response_custom_tool"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_response_non_streaming_custom_tool(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    response = openai_client.responses.create(
+        model=model,
+        input=case["input"],
+        tools=case["tools"],
+        stream=False,
+    )
+
+    assert len(response.output) == 1
+    assert response.output[0].type == "function_call"
+    assert response.output[0].status == "completed"
+    assert response.output[0].name == "get_weather"
+
+
 @pytest.mark.parametrize(
     "case",
     responses_test_cases["test_response_image"]["test_params"]["case"],