Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 10:54:19 +00:00
feat: Structured output for Responses API
This adds the missing `text` parameter to the Responses API, which is how users control structured outputs. All we do with that parameter is map it to the corresponding chat completion `response_format`. The unit tests exercise the various permutations allowed for this parameter, while a couple of new verification tests actually use it to verify that model outputs follow the requested format.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in: parent 4540c9b3e5, commit badf8594d1
8 changed files with 323 additions and 2 deletions
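For context, here is a minimal client-side sketch of the new parameter, modeled on the verification test added at the end of this commit. The `base_url`, API key, and model are illustrative assumptions, not values from this commit; any OpenAI-compatible client pointed at a Llama Stack server should work the same way.

```python
import json

from openai import OpenAI

# Illustrative endpoint and credentials; adjust for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="What is the capital of France?",
    text={
        "format": {
            "type": "json_schema",
            "name": "capitals",
            "schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
            "strict": True,
        }
    },
)

# Internally the format above is mapped to the chat completion response_format,
# so output_text comes back as JSON conforming to the schema.
print(json.loads(response.output_text)["capital"])
```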
docs/_static/llama-stack-spec.html (vendored, 86 lines changed)

```diff
@@ -7241,6 +7241,79 @@
                 ],
                 "title": "OpenAIResponseOutputMessageWebSearchToolCall"
             },
+            "OpenAIResponseText": {
+                "type": "object",
+                "properties": {
+                    "format": {
+                        "type": "object",
+                        "properties": {
+                            "type": {
+                                "oneOf": [
+                                    {
+                                        "type": "string",
+                                        "const": "text"
+                                    },
+                                    {
+                                        "type": "string",
+                                        "const": "json_schema"
+                                    },
+                                    {
+                                        "type": "string",
+                                        "const": "json_object"
+                                    }
+                                ],
+                                "description": "Must be \"text\", \"json_schema\", or \"json_object\" to identify the format type"
+                            },
+                            "name": {
+                                "type": "string",
+                                "description": "The name of the response format. Only used for json_schema."
+                            },
+                            "schema": {
+                                "type": "object",
+                                "additionalProperties": {
+                                    "oneOf": [
+                                        {
+                                            "type": "null"
+                                        },
+                                        {
+                                            "type": "boolean"
+                                        },
+                                        {
+                                            "type": "number"
+                                        },
+                                        {
+                                            "type": "string"
+                                        },
+                                        {
+                                            "type": "array"
+                                        },
+                                        {
+                                            "type": "object"
+                                        }
+                                    ]
+                                },
+                                "description": "The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. Only used for json_schema."
+                            },
+                            "description": {
+                                "type": "string",
+                                "description": "(Optional) A description of the response format. Only used for json_schema."
+                            },
+                            "strict": {
+                                "type": "boolean",
+                                "description": "(Optional) Whether to strictly enforce the JSON schema. If true, the response must match the schema exactly. Only used for json_schema."
+                            }
+                        },
+                        "additionalProperties": false,
+                        "required": [
+                            "type"
+                        ],
+                        "title": "OpenAIResponseTextFormat",
+                        "description": "Configuration for Responses API text format."
+                    }
+                },
+                "additionalProperties": false,
+                "title": "OpenAIResponseText"
+            },
             "CreateOpenaiResponseRequest": {
                 "type": "object",
                 "properties": {
@@ -7278,6 +7351,9 @@
                 "temperature": {
                     "type": "number"
                 },
+                "text": {
+                    "$ref": "#/components/schemas/OpenAIResponseText"
+                },
                 "tools": {
                     "type": "array",
                     "items": {
@@ -7351,6 +7427,9 @@
                 "temperature": {
                     "type": "number"
                 },
+                "text": {
+                    "$ref": "#/components/schemas/OpenAIResponseText"
+                },
                 "top_p": {
                     "type": "number"
                 },
@@ -7369,7 +7448,8 @@
                 "object",
                 "output",
                 "parallel_tool_calls",
-                "status"
+                "status",
+                "text"
             ],
             "title": "OpenAIResponseObject"
         },
@@ -10406,6 +10486,9 @@
                 "temperature": {
                     "type": "number"
                 },
+                "text": {
+                    "$ref": "#/components/schemas/OpenAIResponseText"
+                },
                 "top_p": {
                     "type": "number"
                 },
@@ -10431,6 +10514,7 @@
                 "output",
                 "parallel_tool_calls",
                 "status",
+                "text",
                 "input"
             ],
             "title": "OpenAIResponseObjectWithInput"
```
docs/_static/llama-stack-spec.yaml (vendored, 59 lines changed)

```diff
@@ -5118,6 +5118,57 @@ components:
      - type
      title: >-
        OpenAIResponseOutputMessageWebSearchToolCall
+    OpenAIResponseText:
+      type: object
+      properties:
+        format:
+          type: object
+          properties:
+            type:
+              oneOf:
+                - type: string
+                  const: text
+                - type: string
+                  const: json_schema
+                - type: string
+                  const: json_object
+              description: >-
+                Must be "text", "json_schema", or "json_object" to identify the format
+                type
+            name:
+              type: string
+              description: >-
+                The name of the response format. Only used for json_schema.
+            schema:
+              type: object
+              additionalProperties:
+                oneOf:
+                  - type: 'null'
+                  - type: boolean
+                  - type: number
+                  - type: string
+                  - type: array
+                  - type: object
+              description: >-
+                The JSON schema the response should conform to. In a Python SDK, this
+                is often a `pydantic` model. Only used for json_schema.
+            description:
+              type: string
+              description: >-
+                (Optional) A description of the response format. Only used for json_schema.
+            strict:
+              type: boolean
+              description: >-
+                (Optional) Whether to strictly enforce the JSON schema. If true, the
+                response must match the schema exactly. Only used for json_schema.
+          additionalProperties: false
+          required:
+            - type
+          title: OpenAIResponseTextFormat
+          description: >-
+            Configuration for Responses API text format.
+      additionalProperties: false
+      title: OpenAIResponseText
    CreateOpenaiResponseRequest:
      type: object
      properties:
@@ -5145,6 +5196,8 @@ components:
          type: boolean
        temperature:
          type: number
+        text:
+          $ref: '#/components/schemas/OpenAIResponseText'
        tools:
          type: array
          items:
@@ -5196,6 +5249,8 @@ components:
          type: string
        temperature:
          type: number
+        text:
+          $ref: '#/components/schemas/OpenAIResponseText'
        top_p:
          type: number
        truncation:
@@ -5211,6 +5266,7 @@ components:
      - output
      - parallel_tool_calls
      - status
+      - text
      title: OpenAIResponseObject
    OpenAIResponseOutput:
      oneOf:
@@ -7288,6 +7344,8 @@ components:
          type: string
        temperature:
          type: number
+        text:
+          $ref: '#/components/schemas/OpenAIResponseText'
        top_p:
          type: number
        truncation:
@@ -7307,6 +7365,7 @@ components:
      - output
      - parallel_tool_calls
      - status
+      - text
      - input
      title: OpenAIResponseObjectWithInput
    ListProvidersResponse:
```
Agents API protocol:

```diff
@@ -37,6 +37,7 @@ from .openai_responses import (
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
+    OpenAIResponseText,
 )

 # TODO: use enum.StrEnum when we drop support for python 3.10
@@ -603,6 +604,7 @@ class Agents(Protocol):
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,  # this is an extension to the OpenAI API
     ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
```
Responses API types (openai_responses):

```diff
@@ -7,6 +7,7 @@
 from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field
+from typing_extensions import TypedDict

 from llama_stack.schema_utils import json_schema_type, register_schema

@@ -126,6 +127,32 @@ OpenAIResponseOutput = Annotated[
 register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")


+# This has to be a TypedDict because we need a "schema" field and our strong
+# typing code in the schema generator doesn't support Pydantic aliases. That also
+# means we can't use a discriminator field here, because TypedDicts don't support
+# default values which the strong typing code requires for discriminators.
+class OpenAIResponseTextFormat(TypedDict, total=False):
+    """Configuration for Responses API text format.
+
+    :param type: Must be "text", "json_schema", or "json_object" to identify the format type
+    :param name: The name of the response format. Only used for json_schema.
+    :param schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. Only used for json_schema.
+    :param description: (Optional) A description of the response format. Only used for json_schema.
+    :param strict: (Optional) Whether to strictly enforce the JSON schema. If true, the response must match the schema exactly. Only used for json_schema.
+    """
+
+    type: Literal["text"] | Literal["json_schema"] | Literal["json_object"]
+    name: str | None
+    schema: dict[str, Any] | None
+    description: str | None
+    strict: bool | None
+
+
+@json_schema_type
+class OpenAIResponseText(BaseModel):
+    format: OpenAIResponseTextFormat | None = None
+
+
 @json_schema_type
 class OpenAIResponseObject(BaseModel):
     created_at: int
@@ -138,6 +165,9 @@ class OpenAIResponseObject(BaseModel):
     previous_response_id: str | None = None
     status: str
     temperature: float | None = None
+    # Default to text format to avoid breaking the loading of old responses
+    # before the field was added. New responses will have this set always.
+    text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
     top_p: float | None = None
     truncation: str | None = None
     user: str | None = None
```
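The docstring above notes that in a Python SDK the `schema` dict often comes from a `pydantic` model. A minimal sketch of building these types that way, assuming pydantic v2 (`model_json_schema()`) and a hypothetical `Capital` model:

```python
from pydantic import BaseModel

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseText,
    OpenAIResponseTextFormat,
)


class Capital(BaseModel):
    """Hypothetical structured-output model, for illustration only."""

    capital: str


# model_json_schema() emits the plain JSON-schema dict the "schema" field expects.
text = OpenAIResponseText(
    format=OpenAIResponseTextFormat(
        type="json_schema",
        name="capital",
        schema=Capital.model_json_schema(),
        strict=True,
    )
)
```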
Meta-reference agents implementation:

```diff
@@ -29,6 +29,7 @@ from llama_stack.apis.agents import (
     Session,
     Turn,
 )
+from llama_stack.apis.agents.openai_responses import OpenAIResponseText
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     Inference,
@@ -324,11 +325,12 @@ class MetaReferenceAgentsImpl(Agents):
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
-            input, model, instructions, previous_response_id, store, stream, temperature, tools, max_infer_iters
+            input, model, instructions, previous_response_id, store, stream, temperature, text, tools, max_infer_iters
         )

     async def list_openai_responses(
```
OpenAI Responses implementation:

```diff
@@ -37,6 +37,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageFunctionToolCall,
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseText,
+    OpenAIResponseTextFormat,
 )
 from llama_stack.apis.inference.inference import (
     Inference,
@@ -50,7 +52,12 @@ from llama_stack.apis.inference.inference import (
     OpenAIChoice,
     OpenAIDeveloperMessageParam,
     OpenAIImageURL,
+    OpenAIJSONSchema,
     OpenAIMessageParam,
+    OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatJSONSchema,
+    OpenAIResponseFormatParam,
+    OpenAIResponseFormatText,
     OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
@@ -158,6 +165,21 @@ async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> Open
     )


+async def _convert_response_text_to_chat_response_format(text: OpenAIResponseText) -> OpenAIResponseFormatParam:
+    """
+    Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
+    """
+    if not text.format or text.format["type"] == "text":
+        return OpenAIResponseFormatText(type="text")
+    if text.format["type"] == "json_object":
+        return OpenAIResponseFormatJSONObject()
+    if text.format["type"] == "json_schema":
+        return OpenAIResponseFormatJSONSchema(
+            json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
+        )
+    raise ValueError(f"Unsupported text format: {text.format}")
+
+
 async def _get_message_type_by_role(role: str):
     role_to_type = {
         "user": OpenAIUserMessageParam,
```
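The shape difference worth noting in this helper: the Responses API takes the format fields flat under `text.format`, while the chat-completions `response_format` nests them under a `json_schema` key. A minimal sketch of that mapping, using plain dicts for the serialized forms with illustrative values:

```python
# Responses API input, serialized as a plain dict:
text = {
    "format": {
        "type": "json_schema",
        "name": "capitals",
        "schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
    }
}

# What the helper above produces for the chat-completions call, serialized:
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": text["format"]["name"],
        "schema": text["format"]["schema"],
    },
}
```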
```diff
@@ -180,6 +202,7 @@ class ChatCompletionContext(BaseModel):
     mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
     stream: bool
     temperature: float | None
+    response_format: OpenAIResponseFormatParam


 class OpenAIResponsesImpl:
@@ -343,10 +366,12 @@ class OpenAIResponsesImpl:
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
+        text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         max_infer_iters: int | None = 10,
     ):
         stream = False if stream is None else stream
+        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text

         output_messages: list[OpenAIResponseOutput] = []

@@ -355,6 +380,9 @@ class OpenAIResponsesImpl:
         messages = await _convert_response_input_to_chat_messages(input)
         await self._prepend_instructions(messages, instructions)

+        # Structured outputs
+        response_format = await _convert_response_text_to_chat_response_format(text)
+
         # Tool setup
         chat_tools, mcp_tool_to_server, mcp_list_message = (
             await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
@@ -369,6 +397,7 @@
             mcp_tool_to_server=mcp_tool_to_server,
             stream=stream,
             temperature=temperature,
+            response_format=response_format,
         )

         # Fork to streaming vs non-streaming - let each handle ALL inference rounds
@@ -379,6 +408,7 @@
                 input=input,
                 model=model,
                 store=store,
+                text=text,
                 tools=tools,
                 max_infer_iters=max_infer_iters,
             )
@@ -389,6 +419,7 @@
                 input=input,
                 model=model,
                 store=store,
+                text=text,
                 tools=tools,
                 max_infer_iters=max_infer_iters,
             )
@@ -400,6 +431,7 @@
         input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
+        text: OpenAIResponseText,
         tools: list[OpenAIResponseInputTool] | None,
         max_infer_iters: int | None,
     ) -> OpenAIResponseObject:
@@ -416,6 +448,7 @@
             tools=ctx.tools,
             stream=False,
             temperature=ctx.temperature,
+            response_format=ctx.response_format,
         )
         current_response = OpenAIChatCompletion(**inference_result.model_dump())

@@ -470,6 +503,7 @@
             object="response",
             status="completed",
             output=output_messages,
+            text=text,
         )
         logger.debug(f"OpenAI Responses response: {response}")

@@ -489,6 +523,7 @@
         input: str | list[OpenAIResponseInput],
         model: str,
         store: bool | None,
+        text: OpenAIResponseText,
         tools: list[OpenAIResponseInputTool] | None,
         max_infer_iters: int | None,
     ) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -503,6 +538,7 @@
             object="response",
             status="in_progress",
             output=output_messages.copy(),
+            text=text,
         )

         # Emit response.created immediately
@@ -520,6 +556,7 @@
             tools=ctx.tools,
             stream=True,
             temperature=ctx.temperature,
+            response_format=ctx.response_format,
         )

         # Process streaming chunks and build complete response
@@ -645,6 +682,7 @@
             model=model,
             object="response",
             status="completed",
+            text=text,
             output=output_messages,
         )
```
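Since both the streaming and non-streaming branches now carry the converted response format, the same `text` parameter applies when streaming. A hedged client-side sketch, reusing the illustrative client and model from the first example (note the verification tests below skip `json_object` because most providers don't actually support it):

```python
from openai import OpenAI

# Same illustrative endpoint as in the first sketch; adjust for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# The text parameter flows through the streaming branch; the terminal
# response.completed event carries the final response, including its text field.
stream = client.responses.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input="What is the capital of France? Answer in JSON.",
    stream=True,
    text={"format": {"type": "json_object"}},
)
for event in stream:
    if event.type == "response.completed":
        print(event.response.output_text)
```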
Unit tests:

```diff
@@ -25,11 +25,17 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseObjectWithInput,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageWebSearchToolCall,
+    OpenAIResponseText,
+    OpenAIResponseTextFormat,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIDeveloperMessageParam,
+    OpenAIJSONSchema,
+    OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatJSONSchema,
+    OpenAIResponseFormatText,
     OpenAIUserMessageParam,
 )
 from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
@@ -96,6 +102,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
     mock_inference_api.openai_chat_completion.assert_called_once_with(
         model=model,
         messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
+        response_format=OpenAIResponseFormatText(),
         tools=None,
         stream=False,
         temperature=0.1,
@@ -320,6 +327,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
         model="fake_model",
         output=[response_output_message],
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
     )
     mock_responses_store.get_response_object.return_value = previous_response
@@ -362,6 +370,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
         model="fake_model",
         output=[output_web_search, output_message],
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
     )
     mock_responses_store.get_response_object.return_value = response
@@ -483,6 +492,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
         model="fake_model",
         output=[response_output_message],
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[input_item_message],
     )
     mock_responses_store.get_response_object.return_value = response
@@ -576,6 +586,7 @@ async def test_responses_store_list_input_items_logic():
         object="response",
         status="completed",
         output=[],
+        text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
         input=input_items,
     )

@@ -644,6 +655,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
         created_at=1234567890,
         model="meta-llama/Llama-3.1-8B-Instruct",
         status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
         input=[
             OpenAIResponseMessage(
                 id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
@@ -694,3 +706,61 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
     # Verify the response itself is correct
     assert result.model == model
     assert result.status == "completed"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "text_format, response_format",
+    [
+        (OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), OpenAIResponseFormatText()),
+        (
+            OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")),
+            OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})),
+        ),
+        (OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()),
+        # ensure text param with no format specified defaults to text
+        (OpenAIResponseText(format=None), OpenAIResponseFormatText()),
+        # ensure text param of None defaults to text
+        (None, OpenAIResponseFormatText()),
+    ],
+)
+async def test_create_openai_response_with_text_format(
+    openai_responses_impl, mock_inference_api, text_format, response_format
+):
+    """Test creating Responses with text formats."""
+    # Setup
+    input_text = "How hot it is in San Francisco today?"
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    # Load the chat completion fixture
+    mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
+    mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
+
+    # Execute
+    _result = await openai_responses_impl.create_openai_response(
+        input=input_text,
+        model=model,
+        text=text_format,
+    )
+
+    # Verify
+    first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
+    assert first_call.kwargs["messages"][0].content == input_text
+    assert first_call.kwargs["response_format"] is not None
+    assert first_call.kwargs["response_format"] == response_format
+
+
+@pytest.mark.asyncio
+async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
+    """Test creating an OpenAI response with an invalid text format."""
+    # Setup
+    input_text = "How hot it is in San Francisco today?"
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    # Execute
+    with pytest.raises(ValueError):
+        _result = await openai_responses_impl.create_openai_response(
+            input=input_text,
+            model=model,
+            text=OpenAIResponseText(format={"type": "invalid"}),
+        )
```
Verification tests:

```diff
@@ -546,3 +546,39 @@ async def test_response_streaming_multi_turn_tool_execution(
     assert expected_output.lower() in final_response.output_text.lower(), (
         f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
     )
+
+
+@pytest.mark.parametrize(
+    "text_format",
+    # Not testing json_object because most providers don't actually support it.
+    [
+        {"type": "text"},
+        {
+            "type": "json_schema",
+            "name": "capitals",
+            "description": "A schema for the capital of each country",
+            "schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
+            "strict": True,
+        },
+    ],
+)
+def test_response_text_format(request, openai_client, model, provider, verification_config, text_format):
+    if isinstance(openai_client, LlamaStackAsLibraryClient):
+        pytest.skip("Responses API text format is not yet supported in library client.")
+
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    stream = False
+    response = openai_client.responses.create(
+        model=model,
+        input="What is the capital of France?",
+        stream=stream,
+        text={"format": text_format},
+    )
+    # by_alias=True is needed because otherwise Pydantic renames our "schema" field
+    assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
+    assert "paris" in response.output_text.lower()
+    if text_format["type"] == "json_schema":
+        assert "paris" in json.loads(response.output_text)["capital"].lower()
```
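On the client side, once a `json_schema` format round-trips like this, the output can be validated back into the model that produced the schema. A short sketch under the same assumptions as the earlier example (hypothetical `Capital` model, pydantic v2):

```python
import json

from pydantic import BaseModel


class Capital(BaseModel):
    """Hypothetical model matching the "capitals" schema used in the test above."""

    capital: str


# Stand-in for response.output_text from a json_schema-formatted response.
output_text = '{"capital": "Paris"}'

# With strict json_schema enforcement, the output parses and validates cleanly.
capital = Capital.model_validate(json.loads(output_text))
assert capital.capital == "Paris"
```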