Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)
Fixes for multi-turn tool calls in Responses API
Testing with Codex locally, I found another issue in how we plumb tool calls through multi-turn scenarios: specifically, in how tool call inputs and outputs from previous turns get passed back in as input to later turns. That led me to realize we were missing the function tool call output type in the Responses API entirely. This change adds that type and plumbs handling of it through the Responses API input-to-chat-completion conversion code.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
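For context, a minimal sketch of the multi-turn round trip this fixes, written against the OpenAI-compatible Responses endpoint. This is not taken from the repo's tests; the base URL, model ID, and tool schema are illustrative placeholders.

```python
# Sketch of the multi-turn flow, assuming an OpenAI client pointed at a
# Llama Stack server. base_url and model are placeholders, not repo values.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
model = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model ID

# Turn 1: the model emits a function_call output item instead of text.
first = client.responses.create(
    model=model,
    input="What's the weather in Boston?",
    tools=[{
        "type": "function",
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }],
)
function_call = first.output[0]  # assuming the first output item is the call

# Turn 2: replay the prior function_call and supply its result. The
# function_call_output item type is what this commit adds.
second = client.responses.create(
    model=model,
    input=[
        function_call.model_dump(),
        {
            "type": "function_call_output",
            "call_id": function_call.call_id,
            "output": '{"temperature_f": 58, "conditions": "sunny"}',
        },
    ],
)
print(second.output_text)
```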
parent 65c56d0ee8
commit 4df8caab41

4 changed files with 187 additions and 69 deletions
docs/_static/llama-stack-spec.html (vendored, 106 lines changed)
```diff
@@ -6471,11 +6471,47 @@
         {
           "$ref": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall"
         },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+        },
+        {
+          "$ref": "#/components/schemas/OpenAIResponseInputFunctionToolCallOutput"
+        },
         {
           "$ref": "#/components/schemas/OpenAIResponseMessage"
         }
       ]
     },
+    "OpenAIResponseInputFunctionToolCallOutput": {
+      "type": "object",
+      "properties": {
+        "call_id": {
+          "type": "string"
+        },
+        "output": {
+          "type": "string"
+        },
+        "type": {
+          "type": "string",
+          "const": "function_call_output",
+          "default": "function_call_output"
+        },
+        "id": {
+          "type": "string"
+        },
+        "status": {
+          "type": "string"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "call_id",
+        "output",
+        "type"
+      ],
+      "title": "OpenAIResponseInputFunctionToolCallOutput",
+      "description": "This represents the output of a function call that gets passed back to the model."
+    },
     "OpenAIResponseInputMessageContent": {
       "oneOf": [
         {
```
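For reference, a minimal input item that satisfies the new schema; only call_id, output, and type are required, while id and status stay optional. The values here are made up.

```python
# Minimal item conforming to OpenAIResponseInputFunctionToolCallOutput.
tool_output_item = {
    "type": "function_call_output",     # const discriminator for the input union
    "call_id": "call_abc123",           # ties the output to its function_call
    "output": '{"temperature_f": 58}',  # tool result, serialized as a string
}
```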
```diff
@@ -6764,6 +6800,41 @@
       ],
       "title": "OpenAIResponseOutputMessageContentOutputText"
     },
+    "OpenAIResponseOutputMessageFunctionToolCall": {
+      "type": "object",
+      "properties": {
+        "arguments": {
+          "type": "string"
+        },
+        "call_id": {
+          "type": "string"
+        },
+        "name": {
+          "type": "string"
+        },
+        "type": {
+          "type": "string",
+          "const": "function_call",
+          "default": "function_call"
+        },
+        "id": {
+          "type": "string"
+        },
+        "status": {
+          "type": "string"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "arguments",
+        "call_id",
+        "name",
+        "type",
+        "id",
+        "status"
+      ],
+      "title": "OpenAIResponseOutputMessageFunctionToolCall"
+    },
     "OpenAIResponseOutputMessageWebSearchToolCall": {
       "type": "object",
       "properties": {
```
```diff
@@ -6934,41 +7005,6 @@
         }
       }
     },
-    "OpenAIResponseOutputMessageFunctionToolCall": {
-      "type": "object",
-      "properties": {
-        "arguments": {
-          "type": "string"
-        },
-        "call_id": {
-          "type": "string"
-        },
-        "name": {
-          "type": "string"
-        },
-        "type": {
-          "type": "string",
-          "const": "function_call",
-          "default": "function_call"
-        },
-        "id": {
-          "type": "string"
-        },
-        "status": {
-          "type": "string"
-        }
-      },
-      "additionalProperties": false,
-      "required": [
-        "arguments",
-        "call_id",
-        "name",
-        "type",
-        "id",
-        "status"
-      ],
-      "title": "OpenAIResponseOutputMessageFunctionToolCall"
-    },
     "OpenAIResponseObjectStream": {
       "oneOf": [
         {
```
docs/_static/llama-stack-spec.yaml (vendored, 81 lines changed)
```diff
@@ -4537,7 +4537,34 @@ components:
     OpenAIResponseInput:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+        - $ref: '#/components/schemas/OpenAIResponseInputFunctionToolCallOutput'
         - $ref: '#/components/schemas/OpenAIResponseMessage'
+    "OpenAIResponseInputFunctionToolCallOutput":
+      type: object
+      properties:
+        call_id:
+          type: string
+        output:
+          type: string
+        type:
+          type: string
+          const: function_call_output
+          default: function_call_output
+        id:
+          type: string
+        status:
+          type: string
+      additionalProperties: false
+      required:
+        - call_id
+        - output
+        - type
+      title: >-
+        OpenAIResponseInputFunctionToolCallOutput
+      description: >-
+        This represents the output of a function call that gets passed back to the
+        model.
     OpenAIResponseInputMessageContent:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
```
```diff
@@ -4721,6 +4748,33 @@ components:
         - type
       title: >-
         OpenAIResponseOutputMessageContentOutputText
+    "OpenAIResponseOutputMessageFunctionToolCall":
+      type: object
+      properties:
+        arguments:
+          type: string
+        call_id:
+          type: string
+        name:
+          type: string
+        type:
+          type: string
+          const: function_call
+          default: function_call
+        id:
+          type: string
+        status:
+          type: string
+      additionalProperties: false
+      required:
+        - arguments
+        - call_id
+        - name
+        - type
+        - id
+        - status
+      title: >-
+        OpenAIResponseOutputMessageFunctionToolCall
     "OpenAIResponseOutputMessageWebSearchToolCall":
       type: object
       properties:
```
```diff
@@ -4840,33 +4894,6 @@ components:
           message: '#/components/schemas/OpenAIResponseMessage'
           web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
           function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
-    "OpenAIResponseOutputMessageFunctionToolCall":
-      type: object
-      properties:
-        arguments:
-          type: string
-        call_id:
-          type: string
-        name:
-          type: string
-        type:
-          type: string
-          const: function_call
-          default: function_call
-        id:
-          type: string
-        status:
-          type: string
-      additionalProperties: false
-      required:
-        - arguments
-        - call_id
-        - name
-        - type
-        - id
-        - status
-      title: >-
-        OpenAIResponseOutputMessageFunctionToolCall
     OpenAIResponseObjectStream:
       oneOf:
         - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
```
llama_stack/apis/agents/openai_responses.py

```diff
@@ -130,9 +130,24 @@ OpenAIResponseObjectStream = Annotated[
 register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
 
 
+@json_schema_type
+class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
+    """
+    This represents the output of a function call that gets passed back to the model.
+    """
+
+    call_id: str
+    output: str
+    type: Literal["function_call_output"] = "function_call_output"
+    id: str | None = None
+    status: str | None = None
+
+
 OpenAIResponseInput = Annotated[
     # Responses API allows output messages to be passed in as input
     OpenAIResponseOutputMessageWebSearchToolCall
+    | OpenAIResponseOutputMessageFunctionToolCall
+    | OpenAIResponseInputFunctionToolCallOutput
     |
     # Fallback to the generic message type as a last resort
     OpenAIResponseMessage,
```
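A quick sanity check of the new model, assuming the definitions in the hunk above are importable from the module path used by the conversion code below:

```python
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputFunctionToolCallOutput,
)

item = OpenAIResponseInputFunctionToolCallOutput(
    call_id="call_abc123",
    output='{"temperature_f": 58}',
)
assert item.type == "function_call_output"      # supplied by the Literal default
assert item.id is None and item.status is None  # both fields are optional
```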
```diff
@@ -14,6 +14,7 @@ from pydantic import BaseModel
 
 from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseInput,
+    OpenAIResponseInputFunctionToolCallOutput,
     OpenAIResponseInputItemList,
     OpenAIResponseInputMessageContent,
     OpenAIResponseInputMessageContentImage,
```
```diff
@@ -38,6 +39,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartParam,
     OpenAIChatCompletionContentPartTextParam,
+    OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIDeveloperMessageParam,
```
```diff
@@ -97,13 +99,31 @@ async def _convert_response_input_to_chat_messages(
     messages: list[OpenAIMessageParam] = []
     if isinstance(input, list):
         for input_message in input:
-            content = await _convert_response_content_to_chat_content(input_message.content)
-            message_type = await _get_message_type_by_role(input_message.role)
-            if message_type is None:
-                raise ValueError(
-                    f"Llama Stack OpenAI Responses does not yet support message role '{input_message.role}' in this context"
-                )
-            messages.append(message_type(content=content))
+            if isinstance(input_message, OpenAIResponseInputFunctionToolCallOutput):
+                messages.append(
+                    OpenAIToolMessageParam(
+                        content=input_message.output,
+                        tool_call_id=input_message.call_id,
+                    )
+                )
+            elif isinstance(input_message, OpenAIResponseOutputMessageFunctionToolCall):
+                tool_call = OpenAIChatCompletionToolCall(
+                    index=0,
+                    id=input_message.call_id,
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=input_message.name,
+                        arguments=input_message.arguments,
+                    ),
+                )
+                messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+            else:
+                content = await _convert_response_content_to_chat_content(input_message.content)
+                message_type = await _get_message_type_by_role(input_message.role)
+                if message_type is None:
+                    raise ValueError(
+                        f"Llama Stack OpenAI Responses does not yet support message role '{input_message.role}' in this context"
+                    )
+                messages.append(message_type(content=content))
     else:
         messages.append(OpenAIUserMessageParam(content=input))
     return messages
```
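To illustrate the two new branches, a sketch of what one prior turn converts to (run inside an async context; the field values are made up):

```python
history = [
    OpenAIResponseOutputMessageFunctionToolCall(
        id="fc_1",
        call_id="call_abc123",
        name="get_weather",
        arguments='{"city": "Boston"}',
        status="completed",
    ),
    OpenAIResponseInputFunctionToolCallOutput(
        call_id="call_abc123",
        output='{"temperature_f": 58}',
    ),
]
messages = await _convert_response_input_to_chat_messages(history)
# messages[0]: OpenAIAssistantMessageParam(tool_calls=[...])  # replays the call
# messages[1]: OpenAIToolMessageParam(content='{"temperature_f": 58}',
#                                     tool_call_id="call_abc123")
```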
```diff
@@ -222,6 +242,7 @@
         # TODO: refactor this into a separate method that handles streaming
         chat_response_id = ""
         chat_response_content = []
+        chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
         # TODO: these chunk_ fields are hacky and only take the last chunk into account
         chunk_created = 0
         chunk_model = ""
```
```diff
@@ -235,7 +256,26 @@
                 chat_response_content.append(chunk_choice.delta.content or "")
                 if chunk_choice.finish_reason:
                     chunk_finish_reason = chunk_choice.finish_reason
-        assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
+
+                if chunk_choice.delta.tool_calls:
+                    for tool_call in chunk_choice.delta.tool_calls:
+                        if tool_call.index not in chat_response_tool_calls:
+                            chat_response_tool_calls[tool_call.index] = OpenAIChatCompletionToolCall(
+                                **tool_call.model_dump()
+                            )
+                        chat_response_tool_calls[tool_call.index].function.arguments = (
+                            chat_response_tool_calls[tool_call.index].function.arguments
+                            + tool_call.function.arguments
+                        )
+
+        if chat_response_tool_calls:
+            tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
+        else:
+            tool_calls = None
+        assistant_message = OpenAIAssistantMessageParam(
+            content="".join(chat_response_content),
+            tool_calls=tool_calls,
+        )
         chat_response = OpenAIChatCompletion(
             id=chat_response_id,
             choices=[
```
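The accumulation above keys streamed tool-call deltas by index and concatenates their argument fragments. A stripped-down illustration with plain strings in place of the delta objects:

```python
# Streamed deltas for one tool call arrive with the same index and partial
# `arguments` strings; concatenating them in order rebuilds the JSON payload.
deltas = [(0, '{"ci'), (0, 'ty": "Bos'), (0, 'ton"}')]  # (index, fragment)
accumulated: dict[int, str] = {}
for index, fragment in deltas:
    accumulated[index] = accumulated.get(index, "") + fragment
assert accumulated[0] == '{"city": "Boston"}'
```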