diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b3e982029..d21ff81fd 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -518,6 +518,74 @@ } }, "/v1/openai/v1/responses": { + "get": { + "responses": { + "200": { + "description": "A ListOpenAIResponseObject.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListOpenAIResponseObject" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Agents" + ], + "description": "List all OpenAI responses.", + "parameters": [ + { + "name": "after", + "in": "query", + "description": "The ID of the last response to return.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "limit", + "in": "query", + "description": "The number of responses to return.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "model", + "in": "query", + "description": "The model to filter responses by.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "order", + "in": "query", + "description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').", + "required": false, + "schema": { + "$ref": "#/components/schemas/Order" + } + } + ] + }, "post": { "responses": { "200": { @@ -10179,6 +10247,108 @@ ], "title": "ListModelsResponse" }, + "ListOpenAIResponseObject": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseObjectWithInput" + } + }, + "has_more": { + "type": "boolean" + }, + "first_id": { + "type": "string" + }, + "last_id": { + "type": "string" + }, + "object": { + "type": "string", + "const": "list", + "default": "list" + } + }, + "additionalProperties": false, + "required": [ + "data", + "has_more", + "first_id", + "last_id", + "object" + ], + "title": "ListOpenAIResponseObject" + }, + "OpenAIResponseObjectWithInput": { + "type": "object", + "properties": { + "created_at": { + "type": "integer" + }, + "error": { + "$ref": "#/components/schemas/OpenAIResponseError" + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "object": { + "type": "string", + "const": "response", + "default": "response" + }, + "output": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseOutput" + } + }, + "parallel_tool_calls": { + "type": "boolean", + "default": false + }, + "previous_response_id": { + "type": "string" + }, + "status": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "truncation": { + "type": "string" + }, + "user": { + "type": "string" + }, + "input": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseInput" + } + } + }, + "additionalProperties": false, + "required": [ + "created_at", + "id", + "model", + "object", + "output", + "parallel_tool_calls", + "status", + "input" + ], + "title": "OpenAIResponseObjectWithInput" + }, "ListProvidersResponse": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 18cd2b046..8a936fcee 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ 
-349,6 +349,53 @@ paths: $ref: '#/components/schemas/CreateAgentTurnRequest' required: true /v1/openai/v1/responses: + get: + responses: + '200': + description: A ListOpenAIResponseObject. + content: + application/json: + schema: + $ref: '#/components/schemas/ListOpenAIResponseObject' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Agents + description: List all OpenAI responses. + parameters: + - name: after + in: query + description: The ID of the last response to return. + required: false + schema: + type: string + - name: limit + in: query + description: The number of responses to return. + required: false + schema: + type: integer + - name: model + in: query + description: The model to filter responses by. + required: false + schema: + type: string + - name: order + in: query + description: >- + The order to sort responses by when sorted by created_at ('asc' or 'desc'). + required: false + schema: + $ref: '#/components/schemas/Order' post: responses: '200': @@ -7106,6 +7153,80 @@ components: required: - data title: ListModelsResponse + ListOpenAIResponseObject: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseObjectWithInput' + has_more: + type: boolean + first_id: + type: string + last_id: + type: string + object: + type: string + const: list + default: list + additionalProperties: false + required: + - data + - has_more + - first_id + - last_id + - object + title: ListOpenAIResponseObject + OpenAIResponseObjectWithInput: + type: object + properties: + created_at: + type: integer + error: + $ref: '#/components/schemas/OpenAIResponseError' + id: + type: string + model: + type: string + object: + type: string + const: response + default: response + output: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseOutput' + parallel_tool_calls: + type: boolean + default: false + previous_response_id: + type: string + status: + type: string + temperature: + type: number + top_p: + type: number + truncation: + type: string + user: + type: string + input: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseInput' + additionalProperties: false + required: + - created_at + - id + - model + - object + - output + - parallel_tool_calls + - status + - input + title: OpenAIResponseObjectWithInput ListProvidersResponse: type: object properties: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5e857c895..bb185b8a3 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.common.responses import Order, PaginatedResponse from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, @@ -31,6 +31,7 @@ from llama_stack.apis.tools import ToolDef from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from .openai_responses import ( + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, @@ -611,3 +612,21 @@ class Agents(Protocol): :returns: An OpenAIResponseObject. 
""" ... + + @webmethod(route="/openai/v1/responses", method="GET") + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + """List all OpenAI responses. + + :param after: The ID of the last response to return. + :param limit: The number of responses to return. + :param model: The model to filter responses by. + :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc'). + :returns: A ListOpenAIResponseObject. + """ + ... diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index bb463bd57..5d8f2b80b 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -219,3 +219,17 @@ register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") class OpenAIResponseInputItemList(BaseModel): data: list[OpenAIResponseInput] object: Literal["list"] = "list" + + +@json_schema_type +class OpenAIResponseObjectWithInput(OpenAIResponseObject): + input: list[OpenAIResponseInput] + + +@json_schema_type +class ListOpenAIResponseObject(BaseModel): + data: list[OpenAIResponseObjectWithInput] + has_more: bool + first_id: str + last_id: str + object: Literal["list"] = "list" diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py index b3bb5cb6b..5cb41e23d 100644 --- a/llama_stack/apis/common/responses.py +++ b/llama_stack/apis/common/responses.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any from pydantic import BaseModel @@ -11,6 +12,11 @@ from pydantic import BaseModel from llama_stack.schema_utils import json_schema_type +class Order(Enum): + asc = "asc" + desc = "desc" + + @json_schema_type class PaginatedResponse(BaseModel): """A generic paginated response that follows a simple format. 
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7f8b20952..e79dc6d94 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.responses import Order from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( @@ -833,11 +834,6 @@ class ListOpenAIChatCompletionResponse(BaseModel): object: Literal["list"] = "list" -class Order(Enum): - asc = "asc" - desc = "desc" - - @runtime_checkable @trace_protocol class InferenceProvider(Protocol): diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index e98799ae6..787135488 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -20,9 +20,11 @@ from llama_stack.apis.agents import ( AgentTurnCreateRequest, AgentTurnResumeRequest, Document, + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, + Order, Session, Turn, ) @@ -39,6 +41,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records +from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig @@ -66,15 +69,17 @@ class MetaReferenceAgentsImpl(Agents): self.tool_groups_api = tool_groups_api self.in_memory_store = InmemoryKVStoreImpl() - self.openai_responses_impl = None + self.openai_responses_impl: OpenAIResponsesImpl | None = None async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) + self.responses_store = ResponsesStore(self.config.responses_store) + await self.responses_store.initialize() self.openai_responses_impl = OpenAIResponsesImpl( - self.persistence_store, inference_api=self.inference_api, tool_groups_api=self.tool_groups_api, tool_runtime_api=self.tool_runtime_api, + responses_store=self.responses_store, ) async def create_agent( @@ -323,3 +328,12 @@ class MetaReferenceAgentsImpl(Agents): return await self.openai_responses_impl.create_openai_response( input, model, instructions, previous_response_id, store, stream, temperature, tools ) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) diff --git a/llama_stack/providers/inline/agents/meta_reference/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py index c860e6df1..1c392f29c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -10,10 +10,12 @@ from pydantic import BaseModel from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore.config 
import SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): persistence_store: KVStoreConfig + responses_store: SqlStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: @@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel): "persistence_store": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, db_name="agents_store.db", - ) + ), + "responses_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="responses_store.db", + ), } diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 92345a12f..939282005 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -12,7 +12,9 @@ from typing import Any, cast from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel +from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputFunctionToolCallOutput, OpenAIResponseInputItemList, @@ -53,7 +55,7 @@ from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolR from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.responses.responses_store import ResponsesStore logger = get_logger(name=__name__, category="openai_responses") @@ -169,34 +171,27 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel): class OpenAIResponsesImpl: def __init__( self, - persistence_store: KVStore, inference_api: Inference, tool_groups_api: ToolGroups, tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, ): - self.persistence_store = persistence_store self.inference_api = inference_api self.tool_groups_api = tool_groups_api self.tool_runtime_api = tool_runtime_api - - async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems: - key = f"{OPENAI_RESPONSES_PREFIX}{id}" - response_json = await self.persistence_store.get(key=key) - if response_json is None: - raise ValueError(f"OpenAI response with id '{id}' not found") - return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json) + self.responses_store = responses_store async def _prepend_previous_response( self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None ): if previous_response_id: - previous_response_with_input = await self._get_previous_response_with_input(previous_response_id) + previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) # previous response input items - new_input_items = previous_response_with_input.input_items.data + new_input_items = previous_response_with_input.input # previous response output items - new_input_items.extend(previous_response_with_input.response.output) + new_input_items.extend(previous_response_with_input.output) # new input items from the current request if isinstance(input, str): @@ -216,8 +211,17 @@ class OpenAIResponsesImpl: self, response_id: str, 
) -> OpenAIResponseObject: - response_with_input = await self._get_previous_response_with_input(response_id) - return response_with_input.response + response_with_input = await self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) async def create_openai_response( self, @@ -360,15 +364,9 @@ class OpenAIResponsesImpl: else: input_items_data.append(input_item) - input_items = OpenAIResponseInputItemList(data=input_items_data) - prev_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - key = f"{OPENAI_RESPONSES_PREFIX}{response.id}" - await self.persistence_store.set( - key=key, - value=prev_response.model_dump_json(), + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, ) if stream: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py new file mode 100644 index 000000000..19da6785a --- /dev/null +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.agents import ( + Order, +) +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseObject, + OpenAIResponseInput, + OpenAIResponseObject, + OpenAIResponseObjectWithInput, +) +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl + + +class ResponsesStore: + def __init__(self, sql_store_config: SqlStoreConfig): + if not sql_store_config: + sql_store_config = SqliteSqlStoreConfig( + db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + ) + self.sql_store = sqlstore_impl(sql_store_config) + + async def initialize(self): + """Create the necessary tables if they don't exist.""" + await self.sql_store.create_table( + "openai_responses", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "created_at": ColumnType.INTEGER, + "response_object": ColumnType.JSON, + "model": ColumnType.STRING, + }, + ) + + async def store_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + data = response_object.model_dump() + data["input"] = [input_item.model_dump() for input_item in input] + + await self.sql_store.insert( + "openai_responses", + { + "id": data["id"], + "created_at": data["created_at"], + "model": data["model"], + "response_object": data, + }, + ) + + async def list_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + """ + List responses from the database. + + :param after: The ID of the last response to return. + :param limit: The maximum number of responses to return. + :param model: The model to filter by. + :param order: The order to sort the responses by. 
+ """ + # TODO: support after + if after: + raise NotImplementedError("After is not supported for SQLite") + if not order: + order = Order.desc + + rows = await self.sql_store.fetch_all( + "openai_responses", + where={"model": model} if model else None, + order_by=[("created_at", order.value)], + limit=limit, + ) + + data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows] + return ListOpenAIResponseObject( + data=data, + # TODO: implement has_more + has_more=False, + first_id=data[0].id if data else "", + last_id=data[-1].id if data else "", + ) + + async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput: + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) + if not row: + raise ValueError(f"Response with id {response_id} not found") from None + return OpenAIResponseObjectWithInput(**row["response_object"]) diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index c39b08ff9..a58068a60 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -35,6 +35,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 025033f59..c080536b7 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 342388b78..368187d3a 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -38,6 +38,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 77843858c..5c6072245 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index fd0d4a1f6..ffaa0bf2f 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -37,6 +37,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db + responses_store: + type: sqlite + db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 1f66983f4..41500f6f6 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index 1fbf4be6e..b1fa03306 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index 7d257d379..db7ebffee 100644 --- a/llama_stack/templates/groq/run.yaml +++ b/llama_stack/templates/groq/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index b3938bf93..15cf2a47f 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index 1e60dd25c..428edf9a2 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 640506632..ab461c6c3 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db + responses_store: + type: sqlite + db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index a8b46a0aa..d238506fb 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/llama_api/run.yaml b/llama_stack/templates/llama_api/run.yaml index 1d5739fe2..a7f2b0769 100644 --- a/llama_stack/templates/llama_api/run.yaml +++ b/llama_stack/templates/llama_api/run.yaml @@ -50,6 +50,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index bbf7ad767..2b751a514 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -56,6 +56,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 9ce69c209..a24c5fec5 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 32359b805..c431e12f2 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index d4e935727..5b244081d 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + 
responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index a19ac73c6..d63c5e366 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -40,6 +40,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 551af3a99..d208cd7f0 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -38,6 +38,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 7b43ce6e7..0e5edf728 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -64,6 +64,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/passthrough/run-with-safety.yaml b/llama_stack/templates/passthrough/run-with-safety.yaml index cddda39fa..bbf5d9a52 100644 --- a/llama_stack/templates/passthrough/run-with-safety.yaml +++ b/llama_stack/templates/passthrough/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml index 1fc3914a6..146906d9b 100644 --- a/llama_stack/templates/passthrough/run.yaml +++ b/llama_stack/templates/passthrough/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 89f3aa082..e83162a4f 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -50,6 +50,9 @@ providers: type: sqlite namespace: null db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 4d4395fd7..4cdf88c6b 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -43,6 +43,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 907bc013e..8c2a933ab 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -51,6 +51,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 3327e576c..04425ed35 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -72,6 +72,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index bd197b93f..c797b93aa 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 230fe9a5a..7e91d20bd 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -40,6 +40,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index 1c05e5e42..190a0400b 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db + responses_store: + type: sqlite + db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index aebf4e1a2..ce9542130 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index de8b0d850..58b3c576c 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -74,6 +74,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index a0257f704..6937e2bac 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -45,6 +45,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index 86ec01953..e7222fd57 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -42,6 +42,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py new file mode 100644 index 000000000..8af1c1870 --- /dev/null +++ b/tests/integration/agents/test_openai_responses.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from urllib.parse import urljoin + +import pytest +import requests +from openai import OpenAI + +from llama_stack.distribution.library_client import LlamaStackAsLibraryClient + + +@pytest.fixture +def openai_client(client_with_models): + base_url = f"{client_with_models.base_url}/v1/openai/v1" + return OpenAI(base_url=base_url, api_key="bar") + + +@pytest.mark.parametrize( + "stream", + [ + True, + False, + ], +) +@pytest.mark.parametrize( + "tools", + [ + [], + [ + { + "type": "function", + "name": "get_weather", + "description": "Get the weather in a given city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city to get the weather for"}, + }, + }, + } + ], + ], +) +def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI responses are not supported when testing with library client yet.") + + client = openai_client + message = "What's the weather in Tokyo?" + ( + " YOU MUST USE THE get_weather function to get the weather." if tools else "" + ) + response = client.responses.create( + model=text_model_id, + input=[ + { + "role": "user", + "content": message, + } + ], + stream=stream, + tools=tools, + ) + if stream: + # accumulate the streamed content + content = "" + response_id = None + for chunk in response: + if response_id is None: + response_id = chunk.response.id + if not tools: + if chunk.type == "response.completed": + response_id = chunk.response.id + content = chunk.response.output[0].content[0].text + else: + response_id = response.id + if not tools: + content = response.output[0].content[0].text + + # list responses is not available in the SDK + url = urljoin(str(client.base_url), "responses") + response = requests.get(url, headers={"Authorization": f"Bearer {client.api_key}"}) + assert response.status_code == 200 + data = response.json()["data"] + assert response_id in [r["id"] for r in data] + + # test retrieve response + retrieved_response = client.responses.retrieve(response_id) + assert retrieved_response.id == response_id + assert retrieved_response.model == text_model_id + if tools: + assert retrieved_response.output[0].type == "function_call" + else: + assert retrieved_response.output[0].content[0].text == content diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index bef24e123..9549f6df6 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -43,6 +43,10 @@ def config(tmp_path): "type": "sqlite", "db_path": str(tmp_path / "test.db"), }, + responses_store={ + "type": "sqlite", + "db_path": str(tmp_path / "test.db"), + }, ) diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 0a8d59306..bf36d7b64 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock import pytest from openai.types.chat.chat_completion_chunk import ( @@ -16,12 +16,11 @@ from openai.types.chat.chat_completion_chunk import ( ) from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseInputItemList, OpenAIResponseInputMessageContentText, OpenAIResponseInputToolFunction, OpenAIResponseInputToolWebSearch, OpenAIResponseMessage, - OpenAIResponseObject, + OpenAIResponseObjectWithInput, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageWebSearchToolCall, ) @@ -33,19 +32,12 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.providers.inline.agents.meta_reference.openai_responses import ( - OpenAIResponsePreviousResponseWithInputItems, OpenAIResponsesImpl, ) -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.responses.responses_store import ResponsesStore from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture -@pytest.fixture -def mock_kvstore(): - kvstore = AsyncMock(spec=KVStore) - return kvstore - - @pytest.fixture def mock_inference_api(): inference_api = AsyncMock() @@ -65,12 +57,18 @@ def mock_tool_runtime_api(): @pytest.fixture -def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api): +def mock_responses_store(): + responses_store = AsyncMock(spec=ResponsesStore) + return responses_store + + +@pytest.fixture +def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store): return OpenAIResponsesImpl( - persistence_store=mock_kvstore, inference_api=mock_inference_api, tool_groups_api=mock_tool_groups_api, tool_runtime_api=mock_tool_runtime_api, + responses_store=mock_responses_store, ) @@ -100,7 +98,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m stream=False, temperature=0.1, ) - openai_responses_impl.persistence_store.set.assert_called_once() + openai_responses_impl.responses_store.store_response_object.assert_called_once() assert result.model == model assert len(result.output) == 1 assert isinstance(result.output[0], OpenAIResponseMessage) @@ -167,7 +165,7 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon kwargs={"query": "What is the capital of Ireland?"}, ) - openai_responses_impl.persistence_store.set.assert_called_once() + openai_responses_impl.responses_store.store_response_object.assert_called_once() # Check that we got the content from our mocked tool execution result assert len(result.output) >= 1 @@ -292,8 +290,7 @@ async def test_prepend_previous_response_none(openai_responses_impl): @pytest.mark.asyncio -@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") -async def test_prepend_previous_response_basic(get_previous_response_with_input, openai_responses_impl): +async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store): """Test prepending a basic previous response to a new response.""" input_item_message = OpenAIResponseMessage( @@ -301,25 +298,21 @@ async def test_prepend_previous_response_basic(get_previous_response_with_input, content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], role="user", ) - input_items = OpenAIResponseInputItemList(data=[input_item_message]) response_output_message = 
OpenAIResponseMessage( id="123", content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")], status="completed", role="assistant", ) - response = OpenAIResponseObject( + previous_response = OpenAIResponseObjectWithInput( created_at=1, id="resp_123", model="fake_model", output=[response_output_message], status="completed", + input=[input_item_message], ) - previous_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - get_previous_response_with_input.return_value = previous_response + mock_responses_store.get_response_object.return_value = previous_response input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123") @@ -336,16 +329,13 @@ async def test_prepend_previous_response_basic(get_previous_response_with_input, @pytest.mark.asyncio -@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") -async def test_prepend_previous_response_web_search(get_previous_response_with_input, openai_responses_impl): +async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store): """Test prepending a web search previous response to a new response.""" - input_item_message = OpenAIResponseMessage( id="123", content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], role="user", ) - input_items = OpenAIResponseInputItemList(data=[input_item_message]) output_web_search = OpenAIResponseOutputMessageWebSearchToolCall( id="ws_123", status="completed", @@ -356,18 +346,15 @@ async def test_prepend_previous_response_web_search(get_previous_response_with_i status="completed", role="assistant", ) - response = OpenAIResponseObject( + response = OpenAIResponseObjectWithInput( created_at=1, id="resp_123", model="fake_model", output=[output_web_search, output_message], status="completed", + input=[input_item_message], ) - previous_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - get_previous_response_with_input.return_value = previous_response + mock_responses_store.get_response_object.return_value = response input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123") @@ -464,9 +451,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages( @pytest.mark.asyncio -@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") async def test_create_openai_response_with_instructions_and_previous_response( - get_previous_response_with_input, openai_responses_impl, mock_inference_api + openai_responses_impl, mock_responses_store, mock_inference_api ): """Test prepending both instructions and previous response.""" @@ -475,25 +461,21 @@ async def test_create_openai_response_with_instructions_and_previous_response( content="Name some towns in Ireland", role="user", ) - input_items = OpenAIResponseInputItemList(data=[input_item_message]) response_output_message = OpenAIResponseMessage( id="123", content="Galway, Longford, Sligo", status="completed", role="assistant", ) - response = OpenAIResponseObject( + response = OpenAIResponseObjectWithInput( created_at=1, id="resp_123", model="fake_model", output=[response_output_message], status="completed", + input=[input_item_message], ) - previous_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - get_previous_response_with_input.return_value = 
previous_response + mock_responses_store.get_response_object.return_value = response model = "meta-llama/Llama-3.1-8B-Instruct" instructions = "You are a geography expert. Provide concise answers." @@ -511,7 +493,7 @@ async def test_create_openai_response_with_instructions_and_previous_response( sent_messages = call_args.kwargs["messages"] # Check that instructions were prepended as a system message - assert len(sent_messages) == 4 + assert len(sent_messages) == 4, sent_messages assert sent_messages[0].role == "system" assert sent_messages[0].content == instructions diff --git a/tests/verifications/openai-api-verification-run.yaml b/tests/verifications/openai-api-verification-run.yaml index 4c322af28..d6d8cd07d 100644 --- a/tests/verifications/openai-api-verification-run.yaml +++ b/tests/verifications/openai-api-verification-run.yaml @@ -63,6 +63,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/responses_store.db tool_runtime: - provider_id: brave-search provider_type: remote::brave-search