From 5844c2da6807d2540b76655866325fb017fce7c6 Mon Sep 17 00:00:00 2001
From: ehhuang
Date: Fri, 23 May 2025 13:16:48 -0700
Subject: [PATCH] feat: add list responses API (#2233)

# What does this PR do?
This endpoint is not part of the official OpenAI API, but we'll use it for the logs UI. To support more filtering options, I'm adopting the newly introduced SQL store in place of the KV store. (A client-side usage sketch for the new endpoint follows the diff.)

## Test Plan
Added integration/unit tests.
---
 docs/_static/llama-stack-spec.html | 170 ++++++++++++++++++
 docs/_static/llama-stack-spec.yaml | 121 +++++++++++++
 llama_stack/apis/agents/agents.py | 21 ++-
 llama_stack/apis/agents/openai_responses.py | 14 ++
 llama_stack/apis/common/responses.py | 6 +
 llama_stack/apis/inference/inference.py | 6 +-
 .../inline/agents/meta_reference/agents.py | 18 +-
 .../inline/agents/meta_reference/config.py | 8 +-
 .../agents/meta_reference/openai_responses.py | 46 +++--
 .../utils/responses/responses_store.py | 98 ++++++++++
 llama_stack/templates/bedrock/run.yaml | 3 +
 llama_stack/templates/cerebras/run.yaml | 3 +
 llama_stack/templates/ci-tests/run.yaml | 3 +
 .../templates/dell/run-with-safety.yaml | 3 +
 llama_stack/templates/dell/run.yaml | 3 +
 .../templates/fireworks/run-with-safety.yaml | 3 +
 llama_stack/templates/fireworks/run.yaml | 3 +
 llama_stack/templates/groq/run.yaml | 3 +
 .../hf-endpoint/run-with-safety.yaml | 3 +
 llama_stack/templates/hf-endpoint/run.yaml | 3 +
 .../hf-serverless/run-with-safety.yaml | 3 +
 llama_stack/templates/hf-serverless/run.yaml | 3 +
 llama_stack/templates/llama_api/run.yaml | 3 +
 .../meta-reference-gpu/run-with-safety.yaml | 3 +
 .../templates/meta-reference-gpu/run.yaml | 3 +
 .../templates/nvidia/run-with-safety.yaml | 3 +
 llama_stack/templates/nvidia/run.yaml | 3 +
 .../templates/ollama/run-with-safety.yaml | 3 +
 llama_stack/templates/ollama/run.yaml | 3 +
 llama_stack/templates/open-benchmark/run.yaml | 3 +
 .../passthrough/run-with-safety.yaml | 3 +
 llama_stack/templates/passthrough/run.yaml | 3 +
 .../remote-vllm/run-with-safety.yaml | 3 +
 llama_stack/templates/remote-vllm/run.yaml | 3 +
 llama_stack/templates/sambanova/run.yaml | 3 +
 llama_stack/templates/starter/run.yaml | 3 +
 .../templates/tgi/run-with-safety.yaml | 3 +
 llama_stack/templates/tgi/run.yaml | 3 +
 .../templates/together/run-with-safety.yaml | 3 +
 llama_stack/templates/together/run.yaml | 3 +
 llama_stack/templates/verification/run.yaml | 3 +
 llama_stack/templates/vllm-gpu/run.yaml | 3 +
 llama_stack/templates/watsonx/run.yaml | 3 +
 .../agents/test_openai_responses.py | 97 ++++++++++
 .../agent/test_meta_reference_agent.py | 4 +
 .../meta_reference/test_openai_responses.py | 70 +++-----
 .../openai-api-verification-run.yaml | 3 +
 47 files changed, 704 insertions(+), 77 deletions(-)
 create mode 100644 llama_stack/providers/utils/responses/responses_store.py
 create mode 100644 tests/integration/agents/test_openai_responses.py

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index b3e982029..d21ff81fd 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -518,6 +518,74 @@ } }, "/v1/openai/v1/responses": { + "get": { + "responses": { + "200": { + "description": "A ListOpenAIResponseObject.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListOpenAIResponseObject" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref":
"#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Agents" + ], + "description": "List all OpenAI responses.", + "parameters": [ + { + "name": "after", + "in": "query", + "description": "The ID of the last response to return.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "limit", + "in": "query", + "description": "The number of responses to return.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "model", + "in": "query", + "description": "The model to filter responses by.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "order", + "in": "query", + "description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').", + "required": false, + "schema": { + "$ref": "#/components/schemas/Order" + } + } + ] + }, "post": { "responses": { "200": { @@ -10179,6 +10247,108 @@ ], "title": "ListModelsResponse" }, + "ListOpenAIResponseObject": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseObjectWithInput" + } + }, + "has_more": { + "type": "boolean" + }, + "first_id": { + "type": "string" + }, + "last_id": { + "type": "string" + }, + "object": { + "type": "string", + "const": "list", + "default": "list" + } + }, + "additionalProperties": false, + "required": [ + "data", + "has_more", + "first_id", + "last_id", + "object" + ], + "title": "ListOpenAIResponseObject" + }, + "OpenAIResponseObjectWithInput": { + "type": "object", + "properties": { + "created_at": { + "type": "integer" + }, + "error": { + "$ref": "#/components/schemas/OpenAIResponseError" + }, + "id": { + "type": "string" + }, + "model": { + "type": "string" + }, + "object": { + "type": "string", + "const": "response", + "default": "response" + }, + "output": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseOutput" + } + }, + "parallel_tool_calls": { + "type": "boolean", + "default": false + }, + "previous_response_id": { + "type": "string" + }, + "status": { + "type": "string" + }, + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "truncation": { + "type": "string" + }, + "user": { + "type": "string" + }, + "input": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIResponseInput" + } + } + }, + "additionalProperties": false, + "required": [ + "created_at", + "id", + "model", + "object", + "output", + "parallel_tool_calls", + "status", + "input" + ], + "title": "OpenAIResponseObjectWithInput" + }, "ListProvidersResponse": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 18cd2b046..8a936fcee 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -349,6 +349,53 @@ paths: $ref: '#/components/schemas/CreateAgentTurnRequest' required: true /v1/openai/v1/responses: + get: + responses: + '200': + description: A ListOpenAIResponseObject. + content: + application/json: + schema: + $ref: '#/components/schemas/ListOpenAIResponseObject' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Agents + description: List all OpenAI responses. 
+ parameters: + - name: after + in: query + description: The ID of the last response to return. + required: false + schema: + type: string + - name: limit + in: query + description: The number of responses to return. + required: false + schema: + type: integer + - name: model + in: query + description: The model to filter responses by. + required: false + schema: + type: string + - name: order + in: query + description: >- + The order to sort responses by when sorted by created_at ('asc' or 'desc'). + required: false + schema: + $ref: '#/components/schemas/Order' post: responses: '200': @@ -7106,6 +7153,80 @@ components: required: - data title: ListModelsResponse + ListOpenAIResponseObject: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseObjectWithInput' + has_more: + type: boolean + first_id: + type: string + last_id: + type: string + object: + type: string + const: list + default: list + additionalProperties: false + required: + - data + - has_more + - first_id + - last_id + - object + title: ListOpenAIResponseObject + OpenAIResponseObjectWithInput: + type: object + properties: + created_at: + type: integer + error: + $ref: '#/components/schemas/OpenAIResponseError' + id: + type: string + model: + type: string + object: + type: string + const: response + default: response + output: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseOutput' + parallel_tool_calls: + type: boolean + default: false + previous_response_id: + type: string + status: + type: string + temperature: + type: number + top_p: + type: number + truncation: + type: string + user: + type: string + input: + type: array + items: + $ref: '#/components/schemas/OpenAIResponseInput' + additionalProperties: false + required: + - created_at + - id + - model + - object + - output + - parallel_tool_calls + - status + - input + title: OpenAIResponseObjectWithInput ListProvidersResponse: type: object properties: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 5e857c895..bb185b8a3 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent -from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.common.responses import Order, PaginatedResponse from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, @@ -31,6 +31,7 @@ from llama_stack.apis.tools import ToolDef from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from .openai_responses import ( + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, @@ -611,3 +612,21 @@ class Agents(Protocol): :returns: An OpenAIResponseObject. """ ... + + @webmethod(route="/openai/v1/responses", method="GET") + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + """List all OpenAI responses. + + :param after: The ID of the last response to return. + :param limit: The number of responses to return. + :param model: The model to filter responses by. + :param order: The order to sort responses by when sorted by created_at ('asc' or 'desc'). + :returns: A ListOpenAIResponseObject. 
+ """ + ... diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index bb463bd57..5d8f2b80b 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -219,3 +219,17 @@ register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") class OpenAIResponseInputItemList(BaseModel): data: list[OpenAIResponseInput] object: Literal["list"] = "list" + + +@json_schema_type +class OpenAIResponseObjectWithInput(OpenAIResponseObject): + input: list[OpenAIResponseInput] + + +@json_schema_type +class ListOpenAIResponseObject(BaseModel): + data: list[OpenAIResponseObjectWithInput] + has_more: bool + first_id: str + last_id: str + object: Literal["list"] = "list" diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py index b3bb5cb6b..5cb41e23d 100644 --- a/llama_stack/apis/common/responses.py +++ b/llama_stack/apis/common/responses.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from enum import Enum from typing import Any from pydantic import BaseModel @@ -11,6 +12,11 @@ from pydantic import BaseModel from llama_stack.schema_utils import json_schema_type +class Order(Enum): + asc = "asc" + desc = "desc" + + @json_schema_type class PaginatedResponse(BaseModel): """A generic paginated response that follows a simple format. diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7f8b20952..e79dc6d94 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.responses import Order from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( @@ -833,11 +834,6 @@ class ListOpenAIChatCompletionResponse(BaseModel): object: Literal["list"] = "list" -class Order(Enum): - asc = "asc" - desc = "desc" - - @runtime_checkable @trace_protocol class InferenceProvider(Protocol): diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index e98799ae6..787135488 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -20,9 +20,11 @@ from llama_stack.apis.agents import ( AgentTurnCreateRequest, AgentTurnResumeRequest, Document, + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, + Order, Session, Turn, ) @@ -39,6 +41,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records +from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig @@ -66,15 +69,17 @@ class MetaReferenceAgentsImpl(Agents): self.tool_groups_api = tool_groups_api self.in_memory_store = InmemoryKVStoreImpl() - self.openai_responses_impl = None + self.openai_responses_impl: 
OpenAIResponsesImpl | None = None async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) + self.responses_store = ResponsesStore(self.config.responses_store) + await self.responses_store.initialize() self.openai_responses_impl = OpenAIResponsesImpl( - self.persistence_store, inference_api=self.inference_api, tool_groups_api=self.tool_groups_api, tool_runtime_api=self.tool_runtime_api, + responses_store=self.responses_store, ) async def create_agent( @@ -323,3 +328,12 @@ class MetaReferenceAgentsImpl(Agents): return await self.openai_responses_impl.create_openai_response( input, model, instructions, previous_response_id, store, stream, temperature, tools ) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) diff --git a/llama_stack/providers/inline/agents/meta_reference/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py index c860e6df1..1c392f29c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -10,10 +10,12 @@ from pydantic import BaseModel from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): persistence_store: KVStoreConfig + responses_store: SqlStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: @@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel): "persistence_store": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, db_name="agents_store.db", - ) + ), + "responses_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="responses_store.db", + ), } diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 92345a12f..939282005 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -12,7 +12,9 @@ from typing import Any, cast from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel +from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseObject, OpenAIResponseInput, OpenAIResponseInputFunctionToolCallOutput, OpenAIResponseInputItemList, @@ -53,7 +55,7 @@ from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolR from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.responses.responses_store import ResponsesStore logger = get_logger(name=__name__, category="openai_responses") @@ -169,34 +171,27 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel): class OpenAIResponsesImpl: def __init__( self, - persistence_store: KVStore, inference_api: Inference, tool_groups_api: ToolGroups, 
tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, ): - self.persistence_store = persistence_store self.inference_api = inference_api self.tool_groups_api = tool_groups_api self.tool_runtime_api = tool_runtime_api - - async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems: - key = f"{OPENAI_RESPONSES_PREFIX}{id}" - response_json = await self.persistence_store.get(key=key) - if response_json is None: - raise ValueError(f"OpenAI response with id '{id}' not found") - return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json) + self.responses_store = responses_store async def _prepend_previous_response( self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None ): if previous_response_id: - previous_response_with_input = await self._get_previous_response_with_input(previous_response_id) + previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) # previous response input items - new_input_items = previous_response_with_input.input_items.data + new_input_items = previous_response_with_input.input # previous response output items - new_input_items.extend(previous_response_with_input.response.output) + new_input_items.extend(previous_response_with_input.output) # new input items from the current request if isinstance(input, str): @@ -216,8 +211,17 @@ class OpenAIResponsesImpl: self, response_id: str, ) -> OpenAIResponseObject: - response_with_input = await self._get_previous_response_with_input(response_id) - return response_with_input.response + response_with_input = await self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) async def create_openai_response( self, @@ -360,15 +364,9 @@ class OpenAIResponsesImpl: else: input_items_data.append(input_item) - input_items = OpenAIResponseInputItemList(data=input_items_data) - prev_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - key = f"{OPENAI_RESPONSES_PREFIX}{response.id}" - await self.persistence_store.set( - key=key, - value=prev_response.model_dump_json(), + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, ) if stream: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py new file mode 100644 index 000000000..19da6785a --- /dev/null +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
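+# A SQL-backed store for OpenAI responses: each response is written as one row
+# of an "openai_responses" table (id, created_at, model, plus the full response
+# and its input items serialized into a JSON column), so responses can later be
+# listed with model filtering and created_at ordering, or fetched back by id.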
+from llama_stack.apis.agents import ( + Order, +) +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseObject, + OpenAIResponseInput, + OpenAIResponseObject, + OpenAIResponseObjectWithInput, +) +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + +from ..sqlstore.api import ColumnDefinition, ColumnType +from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl + + +class ResponsesStore: + def __init__(self, sql_store_config: SqlStoreConfig): + if not sql_store_config: + sql_store_config = SqliteSqlStoreConfig( + db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), + ) + self.sql_store = sqlstore_impl(sql_store_config) + + async def initialize(self): + """Create the necessary tables if they don't exist.""" + await self.sql_store.create_table( + "openai_responses", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "created_at": ColumnType.INTEGER, + "response_object": ColumnType.JSON, + "model": ColumnType.STRING, + }, + ) + + async def store_response_object( + self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput] + ) -> None: + data = response_object.model_dump() + data["input"] = [input_item.model_dump() for input_item in input] + + await self.sql_store.insert( + "openai_responses", + { + "id": data["id"], + "created_at": data["created_at"], + "model": data["model"], + "response_object": data, + }, + ) + + async def list_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + """ + List responses from the database. + + :param after: The ID of the last response to return. + :param limit: The maximum number of responses to return. + :param model: The model to filter by. + :param order: The order to sort the responses by. 
+ """ + # TODO: support after + if after: + raise NotImplementedError("After is not supported for SQLite") + if not order: + order = Order.desc + + rows = await self.sql_store.fetch_all( + "openai_responses", + where={"model": model} if model else None, + order_by=[("created_at", order.value)], + limit=limit, + ) + + data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows] + return ListOpenAIResponseObject( + data=data, + # TODO: implement has_more + has_more=False, + first_id=data[0].id if data else "", + last_id=data[-1].id if data else "", + ) + + async def get_response_object(self, response_id: str) -> OpenAIResponseObjectWithInput: + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) + if not row: + raise ValueError(f"Response with id {response_id} not found") from None + return OpenAIResponseObjectWithInput(**row["response_object"]) diff --git a/llama_stack/templates/bedrock/run.yaml b/llama_stack/templates/bedrock/run.yaml index c39b08ff9..a58068a60 100644 --- a/llama_stack/templates/bedrock/run.yaml +++ b/llama_stack/templates/bedrock/run.yaml @@ -35,6 +35,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/cerebras/run.yaml b/llama_stack/templates/cerebras/run.yaml index 025033f59..c080536b7 100644 --- a/llama_stack/templates/cerebras/run.yaml +++ b/llama_stack/templates/cerebras/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 342388b78..368187d3a 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -38,6 +38,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 77843858c..5c6072245 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index fd0d4a1f6..ffaa0bf2f 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -37,6 +37,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/agents_store.db + responses_store: + type: sqlite + db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/fireworks/run-with-safety.yaml b/llama_stack/templates/fireworks/run-with-safety.yaml index 1f66983f4..41500f6f6 100644 --- a/llama_stack/templates/fireworks/run-with-safety.yaml +++ b/llama_stack/templates/fireworks/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/fireworks/run.yaml b/llama_stack/templates/fireworks/run.yaml index 1fbf4be6e..b1fa03306 100644 --- a/llama_stack/templates/fireworks/run.yaml +++ b/llama_stack/templates/fireworks/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/groq/run.yaml b/llama_stack/templates/groq/run.yaml index 7d257d379..db7ebffee 100644 --- a/llama_stack/templates/groq/run.yaml +++ b/llama_stack/templates/groq/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-endpoint/run-with-safety.yaml b/llama_stack/templates/hf-endpoint/run-with-safety.yaml index b3938bf93..15cf2a47f 100644 --- a/llama_stack/templates/hf-endpoint/run-with-safety.yaml +++ b/llama_stack/templates/hf-endpoint/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-endpoint/run.yaml b/llama_stack/templates/hf-endpoint/run.yaml index 1e60dd25c..428edf9a2 100644 --- a/llama_stack/templates/hf-endpoint/run.yaml +++ b/llama_stack/templates/hf-endpoint/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-serverless/run-with-safety.yaml b/llama_stack/templates/hf-serverless/run-with-safety.yaml index 640506632..ab461c6c3 100644 --- a/llama_stack/templates/hf-serverless/run-with-safety.yaml +++ b/llama_stack/templates/hf-serverless/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db + responses_store: + type: sqlite + db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/hf-serverless/run.yaml b/llama_stack/templates/hf-serverless/run.yaml index a8b46a0aa..d238506fb 100644 --- a/llama_stack/templates/hf-serverless/run.yaml +++ b/llama_stack/templates/hf-serverless/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/llama_api/run.yaml b/llama_stack/templates/llama_api/run.yaml index 1d5739fe2..a7f2b0769 100644 --- a/llama_stack/templates/llama_api/run.yaml +++ b/llama_stack/templates/llama_api/run.yaml @@ -50,6 +50,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index bbf7ad767..2b751a514 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -56,6 +56,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 9ce69c209..a24c5fec5 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 32359b805..c431e12f2 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index d4e935727..5b244081d 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db + 
responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index a19ac73c6..d63c5e366 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -40,6 +40,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 551af3a99..d208cd7f0 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -38,6 +38,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 7b43ce6e7..0e5edf728 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -64,6 +64,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/passthrough/run-with-safety.yaml b/llama_stack/templates/passthrough/run-with-safety.yaml index cddda39fa..bbf5d9a52 100644 --- a/llama_stack/templates/passthrough/run-with-safety.yaml +++ b/llama_stack/templates/passthrough/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/passthrough/run.yaml b/llama_stack/templates/passthrough/run.yaml index 1fc3914a6..146906d9b 100644 --- a/llama_stack/templates/passthrough/run.yaml +++ b/llama_stack/templates/passthrough/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/remote-vllm/run-with-safety.yaml b/llama_stack/templates/remote-vllm/run-with-safety.yaml index 89f3aa082..e83162a4f 100644 --- a/llama_stack/templates/remote-vllm/run-with-safety.yaml +++ b/llama_stack/templates/remote-vllm/run-with-safety.yaml @@ -50,6 +50,9 @@ providers: type: sqlite namespace: null db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/remote-vllm/run.yaml b/llama_stack/templates/remote-vllm/run.yaml index 4d4395fd7..4cdf88c6b 100644 --- a/llama_stack/templates/remote-vllm/run.yaml +++ b/llama_stack/templates/remote-vllm/run.yaml @@ -43,6 +43,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/responses_store.db eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/sambanova/run.yaml b/llama_stack/templates/sambanova/run.yaml index 907bc013e..8c2a933ab 100644 --- a/llama_stack/templates/sambanova/run.yaml +++ b/llama_stack/templates/sambanova/run.yaml @@ -51,6 +51,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 3327e576c..04425ed35 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -72,6 +72,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/starter}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/tgi/run-with-safety.yaml b/llama_stack/templates/tgi/run-with-safety.yaml index bd197b93f..c797b93aa 100644 --- a/llama_stack/templates/tgi/run-with-safety.yaml +++ b/llama_stack/templates/tgi/run-with-safety.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/tgi/run.yaml b/llama_stack/templates/tgi/run.yaml index 230fe9a5a..7e91d20bd 100644 --- a/llama_stack/templates/tgi/run.yaml +++ b/llama_stack/templates/tgi/run.yaml @@ -40,6 +40,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/together/run-with-safety.yaml b/llama_stack/templates/together/run-with-safety.yaml index 1c05e5e42..190a0400b 100644 --- a/llama_stack/templates/together/run-with-safety.yaml +++ b/llama_stack/templates/together/run-with-safety.yaml @@ -46,6 +46,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db + responses_store: + type: sqlite + db_path: 
${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/together/run.yaml b/llama_stack/templates/together/run.yaml index aebf4e1a2..ce9542130 100644 --- a/llama_stack/templates/together/run.yaml +++ b/llama_stack/templates/together/run.yaml @@ -41,6 +41,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/verification/run.yaml b/llama_stack/templates/verification/run.yaml index de8b0d850..58b3c576c 100644 --- a/llama_stack/templates/verification/run.yaml +++ b/llama_stack/templates/verification/run.yaml @@ -74,6 +74,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml index a0257f704..6937e2bac 100644 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ b/llama_stack/templates/vllm-gpu/run.yaml @@ -45,6 +45,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index 86ec01953..e7222fd57 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -42,6 +42,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/responses_store.db telemetry: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py new file mode 100644 index 000000000..8af1c1870 --- /dev/null +++ b/tests/integration/agents/test_openai_responses.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
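+# End-to-end coverage for the responses store: create a response (streaming and
+# non-streaming, with and without tools), list responses over raw HTTP (the
+# OpenAI SDK has no list helper for this endpoint), and retrieve the stored
+# response by id, checking its output against what was returned at create time.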
+ +from urllib.parse import urljoin + +import pytest +import requests +from openai import OpenAI + +from llama_stack.distribution.library_client import LlamaStackAsLibraryClient + + +@pytest.fixture +def openai_client(client_with_models): + base_url = f"{client_with_models.base_url}/v1/openai/v1" + return OpenAI(base_url=base_url, api_key="bar") + + +@pytest.mark.parametrize( + "stream", + [ + True, + False, + ], +) +@pytest.mark.parametrize( + "tools", + [ + [], + [ + { + "type": "function", + "name": "get_weather", + "description": "Get the weather in a given city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city to get the weather for"}, + }, + }, + } + ], + ], +) +def test_responses_store(openai_client, client_with_models, text_model_id, stream, tools): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI responses are not supported when testing with library client yet.") + + client = openai_client + message = "What's the weather in Tokyo?" + ( + " YOU MUST USE THE get_weather function to get the weather." if tools else "" + ) + response = client.responses.create( + model=text_model_id, + input=[ + { + "role": "user", + "content": message, + } + ], + stream=stream, + tools=tools, + ) + if stream: + # accumulate the streamed content + content = "" + response_id = None + for chunk in response: + if response_id is None: + response_id = chunk.response.id + if not tools: + if chunk.type == "response.completed": + response_id = chunk.response.id + content = chunk.response.output[0].content[0].text + else: + response_id = response.id + if not tools: + content = response.output[0].content[0].text + + # list responses is not available in the SDK + url = urljoin(str(client.base_url), "responses") + response = requests.get(url, headers={"Authorization": f"Bearer {client.api_key}"}) + assert response.status_code == 200 + data = response.json()["data"] + assert response_id in [r["id"] for r in data] + + # test retrieve response + retrieved_response = client.responses.retrieve(response_id) + assert retrieved_response.id == response_id + assert retrieved_response.model == text_model_id + if tools: + assert retrieved_response.output[0].type == "function_call" + else: + assert retrieved_response.output[0].content[0].text == content diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py index bef24e123..9549f6df6 100644 --- a/tests/unit/providers/agent/test_meta_reference_agent.py +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -43,6 +43,10 @@ def config(tmp_path): "type": "sqlite", "db_path": str(tmp_path / "test.db"), }, + responses_store={ + "type": "sqlite", + "db_path": str(tmp_path / "test.db"), + }, ) diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 0a8d59306..bf36d7b64 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock import pytest from openai.types.chat.chat_completion_chunk import ( @@ -16,12 +16,11 @@ from openai.types.chat.chat_completion_chunk import ( ) from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseInputItemList, OpenAIResponseInputMessageContentText, OpenAIResponseInputToolFunction, OpenAIResponseInputToolWebSearch, OpenAIResponseMessage, - OpenAIResponseObject, + OpenAIResponseObjectWithInput, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageWebSearchToolCall, ) @@ -33,19 +32,12 @@ from llama_stack.apis.inference.inference import ( ) from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime from llama_stack.providers.inline.agents.meta_reference.openai_responses import ( - OpenAIResponsePreviousResponseWithInputItems, OpenAIResponsesImpl, ) -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.responses.responses_store import ResponsesStore from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture -@pytest.fixture -def mock_kvstore(): - kvstore = AsyncMock(spec=KVStore) - return kvstore - - @pytest.fixture def mock_inference_api(): inference_api = AsyncMock() @@ -65,12 +57,18 @@ def mock_tool_runtime_api(): @pytest.fixture -def openai_responses_impl(mock_kvstore, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api): +def mock_responses_store(): + responses_store = AsyncMock(spec=ResponsesStore) + return responses_store + + +@pytest.fixture +def openai_responses_impl(mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store): return OpenAIResponsesImpl( - persistence_store=mock_kvstore, inference_api=mock_inference_api, tool_groups_api=mock_tool_groups_api, tool_runtime_api=mock_tool_runtime_api, + responses_store=mock_responses_store, ) @@ -100,7 +98,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m stream=False, temperature=0.1, ) - openai_responses_impl.persistence_store.set.assert_called_once() + openai_responses_impl.responses_store.store_response_object.assert_called_once() assert result.model == model assert len(result.output) == 1 assert isinstance(result.output[0], OpenAIResponseMessage) @@ -167,7 +165,7 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon kwargs={"query": "What is the capital of Ireland?"}, ) - openai_responses_impl.persistence_store.set.assert_called_once() + openai_responses_impl.responses_store.store_response_object.assert_called_once() # Check that we got the content from our mocked tool execution result assert len(result.output) >= 1 @@ -292,8 +290,7 @@ async def test_prepend_previous_response_none(openai_responses_impl): @pytest.mark.asyncio -@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") -async def test_prepend_previous_response_basic(get_previous_response_with_input, openai_responses_impl): +async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store): """Test prepending a basic previous response to a new response.""" input_item_message = OpenAIResponseMessage( @@ -301,25 +298,21 @@ async def test_prepend_previous_response_basic(get_previous_response_with_input, content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], role="user", ) - input_items = OpenAIResponseInputItemList(data=[input_item_message]) response_output_message = 
OpenAIResponseMessage( id="123", content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")], status="completed", role="assistant", ) - response = OpenAIResponseObject( + previous_response = OpenAIResponseObjectWithInput( created_at=1, id="resp_123", model="fake_model", output=[response_output_message], status="completed", + input=[input_item_message], ) - previous_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - get_previous_response_with_input.return_value = previous_response + mock_responses_store.get_response_object.return_value = previous_response input = await openai_responses_impl._prepend_previous_response("fake_input", "resp_123") @@ -336,16 +329,13 @@ async def test_prepend_previous_response_basic(get_previous_response_with_input, @pytest.mark.asyncio -@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") -async def test_prepend_previous_response_web_search(get_previous_response_with_input, openai_responses_impl): +async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store): """Test prepending a web search previous response to a new response.""" - input_item_message = OpenAIResponseMessage( id="123", content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], role="user", ) - input_items = OpenAIResponseInputItemList(data=[input_item_message]) output_web_search = OpenAIResponseOutputMessageWebSearchToolCall( id="ws_123", status="completed", @@ -356,18 +346,15 @@ async def test_prepend_previous_response_web_search(get_previous_response_with_i status="completed", role="assistant", ) - response = OpenAIResponseObject( + response = OpenAIResponseObjectWithInput( created_at=1, id="resp_123", model="fake_model", output=[output_web_search, output_message], status="completed", + input=[input_item_message], ) - previous_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - get_previous_response_with_input.return_value = previous_response + mock_responses_store.get_response_object.return_value = response input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123") @@ -464,9 +451,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages( @pytest.mark.asyncio -@patch.object(OpenAIResponsesImpl, "_get_previous_response_with_input") async def test_create_openai_response_with_instructions_and_previous_response( - get_previous_response_with_input, openai_responses_impl, mock_inference_api + openai_responses_impl, mock_responses_store, mock_inference_api ): """Test prepending both instructions and previous response.""" @@ -475,25 +461,21 @@ async def test_create_openai_response_with_instructions_and_previous_response( content="Name some towns in Ireland", role="user", ) - input_items = OpenAIResponseInputItemList(data=[input_item_message]) response_output_message = OpenAIResponseMessage( id="123", content="Galway, Longford, Sligo", status="completed", role="assistant", ) - response = OpenAIResponseObject( + response = OpenAIResponseObjectWithInput( created_at=1, id="resp_123", model="fake_model", output=[response_output_message], status="completed", + input=[input_item_message], ) - previous_response = OpenAIResponsePreviousResponseWithInputItems( - input_items=input_items, - response=response, - ) - get_previous_response_with_input.return_value = 
previous_response + mock_responses_store.get_response_object.return_value = response model = "meta-llama/Llama-3.1-8B-Instruct" instructions = "You are a geography expert. Provide concise answers." @@ -511,7 +493,7 @@ async def test_create_openai_response_with_instructions_and_previous_response( sent_messages = call_args.kwargs["messages"] # Check that instructions were prepended as a system message - assert len(sent_messages) == 4 + assert len(sent_messages) == 4, sent_messages assert sent_messages[0].role == "system" assert sent_messages[0].content == instructions diff --git a/tests/verifications/openai-api-verification-run.yaml b/tests/verifications/openai-api-verification-run.yaml index 4c322af28..d6d8cd07d 100644 --- a/tests/verifications/openai-api-verification-run.yaml +++ b/tests/verifications/openai-api-verification-run.yaml @@ -63,6 +63,9 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/openai}/responses_store.db tool_runtime: - provider_id: brave-search provider_type: remote::brave-search
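
Usage sketch: a minimal client-side walkthrough of the new endpoint, mirroring the integration test above. The server address and model id below are placeholders; the OpenAI SDK has no list helper for responses, so the listing call goes over raw HTTP with the query parameters the new route accepts (after, limit, model, order).

    import requests
    from openai import OpenAI

    base_url = "http://localhost:8321/v1/openai/v1"  # assumed local stack server
    model_id = "meta-llama/Llama-3.1-8B-Instruct"  # any registered model id

    client = OpenAI(base_url=base_url, api_key="none")

    # Create a response so there is something in the store.
    created = client.responses.create(model=model_id, input="What is the capital of Ireland?")

    # List stored responses, newest first, filtered to this model.
    listed = requests.get(
        f"{base_url}/responses",
        params={"model": model_id, "limit": 10, "order": "desc"},
        headers={"Authorization": f"Bearer {client.api_key}"},
    )
    listed.raise_for_status()
    assert created.id in [r["id"] for r in listed.json()["data"]]

    # Retrieve a single stored response through the SDK.
    retrieved = client.responses.retrieve(created.id)
    assert retrieved.model == model_id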