From 15f630e5dab99ef0f47ec90e61f79211c073524a Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 16 Jun 2025 22:43:35 -0700 Subject: [PATCH] feat: support pagination in inference/responses stores (#2397) # What does this PR do? ## Test Plan added unit tests --- .../providers/inline/files/localfs/files.py | 22 +- .../utils/inference/inference_store.py | 19 +- .../utils/responses/responses_store.py | 54 ++- llama_stack/providers/utils/sqlstore/api.py | 18 +- .../utils/sqlstore/sqlalchemy_sqlstore.py | 88 +++- tests/unit/files/test_files.py | 25 +- .../utils/inference/test_inference_store.py | 203 +++++++++ .../utils/responses/test_responses_store.py | 366 ++++++++++++++++ tests/unit/utils/sqlstore/test_sqlstore.py | 390 ++++++++++++++++++ tests/unit/utils/test_sqlstore.py | 62 --- 10 files changed, 1130 insertions(+), 117 deletions(-) create mode 100644 tests/unit/utils/inference/test_inference_store.py create mode 100644 tests/unit/utils/responses/test_responses_store.py create mode 100644 tests/unit/utils/sqlstore/test_sqlstore.py delete mode 100644 tests/unit/utils/test_sqlstore.py diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index f2891c528..851ce2a6a 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -114,18 +114,18 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - # TODO: Implement 'after' pagination properly - if after: - raise NotImplementedError("After pagination not yet implemented") + if not order: + order = Order.desc - where = None + where_conditions = {} if purpose: - where = {"purpose": purpose.value} + where_conditions["purpose"] = purpose.value - rows = await self.sql_store.fetch_all( - "openai_files", - where=where, - order_by=[("created_at", order.value if order else Order.desc.value)], + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions if where_conditions else None, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, limit=limit, ) @@ -138,12 +138,12 @@ class LocalfsFilesImpl(Files): created_at=row["created_at"], expires_at=row["expires_at"], ) - for row in rows + for row in paginated_result.data ] return ListOpenAIFileResponse( data=files, - has_more=False, # TODO: Implement proper pagination + has_more=paginated_result.has_more, first_id=files[0].id if files else "", last_id=files[-1].id if files else "", ) diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 7b6bc2e3d..ab43e48b1 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -76,16 +76,18 @@ class InferenceStore: if not self.sql_store: raise ValueError("Inference store is not initialized") - # TODO: support after - if after: - raise NotImplementedError("After is not supported for SQLite") if not order: order = Order.desc - rows = await self.sql_store.fetch_all( - "chat_completions", - where={"model": model} if model else None, + where_conditions = {} + if model: + where_conditions["model"] = model + + paginated_result = await self.sql_store.fetch_all( + table="chat_completions", + where=where_conditions if where_conditions else None, order_by=[("created", order.value)], + cursor=("id", after) if after else None, limit=limit, ) @@ -97,12 +99,11 @@ class InferenceStore: choices=row["choices"], input_messages=row["input_messages"], ) - for row in rows + for row in paginated_result.data ] return ListOpenAIChatCompletionResponse( data=data, - # TODO: implement has_more - has_more=False, + has_more=paginated_result.has_more, first_id=data[0].id if data else "", last_id=data[-1].id if data else "", ) diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 15354e3e2..151b020f7 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -70,24 +70,25 @@ class ResponsesStore: :param model: The model to filter by. :param order: The order to sort the responses by. """ - # TODO: support after - if after: - raise NotImplementedError("After is not supported for SQLite") if not order: order = Order.desc - rows = await self.sql_store.fetch_all( - "openai_responses", - where={"model": model} if model else None, + where_conditions = {} + if model: + where_conditions["model"] = model + + paginated_result = await self.sql_store.fetch_all( + table="openai_responses", + where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, limit=limit, ) - data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows] + data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] return ListOpenAIResponseObject( data=data, - # TODO: implement has_more - has_more=False, + has_more=paginated_result.has_more, first_id=data[0].id if data else "", last_id=data[-1].id if data else "", ) @@ -117,19 +118,38 @@ class ResponsesStore: :param limit: A limit on the number of objects to be returned. :param order: The order to return the input items in. """ - # TODO: support after/before pagination - if after or before: - raise NotImplementedError("After/before pagination is not supported yet") if include: raise NotImplementedError("Include is not supported yet") + if before and after: + raise ValueError("Cannot specify both 'before' and 'after' parameters") response_with_input = await self.get_response_object(response_id) - input_items = response_with_input.input + items = response_with_input.input if order == Order.desc: - input_items = list(reversed(input_items)) + items = list(reversed(items)) - if limit is not None and len(input_items) > limit: - input_items = input_items[:limit] + start_index = 0 + end_index = len(items) - return ListOpenAIResponseInputItem(data=input_items) + if after or before: + for i, item in enumerate(items): + item_id = getattr(item, "id", None) + if after and item_id == after: + start_index = i + 1 + if before and item_id == before: + end_index = i + break + + if after and start_index == 0: + raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'") + if before and end_index == len(items): + raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'") + + items = items[start_index:end_index] + + # Apply limit + if limit is not None: + items = items[:limit] + + return ListOpenAIResponseInputItem(data=items) diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py index ace40e4c4..248dbb38e 100644 --- a/llama_stack/providers/utils/sqlstore/api.py +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -10,6 +10,8 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel +from llama_stack.apis.common.responses import PaginatedResponse + class ColumnType(Enum): INTEGER = "INTEGER" @@ -51,9 +53,21 @@ class SqlStore(Protocol): where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> list[dict[str, Any]]: + cursor: tuple[str, str] | None = None, + ) -> PaginatedResponse: """ - Fetch all rows from a table. + Fetch all rows from a table with optional cursor-based pagination. + + :param table: The table name + :param where: WHERE conditions + :param limit: Maximum number of records to return + :param order_by: List of (column, order) tuples for sorting + :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page) + Requires order_by with exactly one column when used + :return: PaginatedResult with data and has_more flag + + Note: Cursor pagination only supports single-column ordering for simplicity. + Multi-column ordering is allowed without cursor but will raise an error with cursor. """ pass diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 825220679..db8180e74 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -21,6 +21,8 @@ from sqlalchemy import ( ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from llama_stack.apis.common.responses import PaginatedResponse + from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig @@ -94,14 +96,56 @@ class SqlAlchemySqlStoreImpl(SqlStore): where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, - ) -> list[dict[str, Any]]: + cursor: tuple[str, str] | None = None, + ) -> PaginatedResponse: async with self.async_session() as session: - query = select(self.metadata.tables[table]) + table_obj = self.metadata.tables[table] + query = select(table_obj) + if where: for key, value in where.items(): - query = query.where(self.metadata.tables[table].c[key] == value) - if limit: - query = query.limit(limit) + query = query.where(table_obj.c[key] == value) + + # Handle cursor-based pagination + if cursor: + # Validate cursor tuple format + if not isinstance(cursor, tuple) or len(cursor) != 2: + raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}") + + # Require order_by for cursor pagination + if not order_by: + raise ValueError("order_by is required when using cursor pagination") + + # Only support single-column ordering for cursor pagination + if len(order_by) != 1: + raise ValueError( + f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns" + ) + + cursor_key_column, cursor_id = cursor + order_column, order_direction = order_by[0] + + # Verify cursor_key_column exists + if cursor_key_column not in table_obj.c: + raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'") + + # Get cursor value for the order column + cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id) + cursor_result = await session.execute(cursor_query) + cursor_row = cursor_result.fetchone() + + if not cursor_row: + raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'") + + cursor_value = cursor_row[0] + + # Apply cursor condition based on sort direction + if order_direction == "desc": + query = query.where(table_obj.c[order_column] < cursor_value) + else: + query = query.where(table_obj.c[order_column] > cursor_value) + + # Apply ordering if order_by: if not isinstance(order_by, list): raise ValueError( @@ -113,16 +157,36 @@ class SqlAlchemySqlStoreImpl(SqlStore): f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" ) name, order_type = order + if name not in table_obj.c: + raise ValueError(f"Column '{name}' not found in table '{table}'") if order_type == "asc": - query = query.order_by(self.metadata.tables[table].c[name].asc()) + query = query.order_by(table_obj.c[name].asc()) elif order_type == "desc": - query = query.order_by(self.metadata.tables[table].c[name].desc()) + query = query.order_by(table_obj.c[name].desc()) else: raise ValueError(f"Invalid order '{order_type}' for column '{name}'") + + # Fetch limit + 1 to determine has_more + fetch_limit = limit + if limit: + fetch_limit = limit + 1 + + if fetch_limit: + query = query.limit(fetch_limit) + result = await session.execute(query) if result.rowcount == 0: - return [] - return [dict(row._mapping) for row in result] + rows = [] + else: + rows = [dict(row._mapping) for row in result] + + # Always return pagination result + has_more = False + if limit and len(rows) > limit: + has_more = True + rows = rows[:limit] + + return PaginatedResponse(data=rows, has_more=has_more) async def fetch_one( self, @@ -130,10 +194,10 @@ class SqlAlchemySqlStoreImpl(SqlStore): where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: - rows = await self.fetch_all(table, where, limit=1, order_by=order_by) - if not rows: + result = await self.fetch_all(table, where, limit=1, order_by=order_by) + if not result.data: return None - return rows[0] + return result.data[0] async def update( self, diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index 32006cbff..ef1dc9743 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -328,7 +328,24 @@ class TestOpenAIFilesAPI: assert uploaded_file.filename == "uploaded_file" # Default filename @pytest.mark.asyncio - async def test_after_pagination_not_implemented(self, files_provider): - """Test that 'after' pagination raises NotImplementedError.""" - with pytest.raises(NotImplementedError, match="After pagination not yet implemented"): - await files_provider.openai_list_files(after="file-some-id") + async def test_after_pagination_works(self, files_provider, sample_text_file): + """Test that 'after' pagination works correctly.""" + # Upload multiple files to test pagination + uploaded_files = [] + for _ in range(5): + file = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS) + uploaded_files.append(file) + + # Get first page without 'after' parameter + first_page = await files_provider.openai_list_files(limit=2, order=Order.desc) + assert len(first_page.data) == 2 + assert first_page.has_more is True + + # Get second page using 'after' parameter + second_page = await files_provider.openai_list_files(after=first_page.data[-1].id, limit=2, order=Order.desc) + assert len(second_page.data) <= 2 + + # Verify no overlap between pages + first_page_ids = {f.id for f in first_page.data} + second_page_ids = {f.id for f in second_page.data} + assert first_page_ids.isdisjoint(second_page_ids) diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py new file mode 100644 index 000000000..b30748617 --- /dev/null +++ b/tests/unit/utils/inference/test_inference_store.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from tempfile import TemporaryDirectory + +import pytest + +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChoice, + OpenAIUserMessageParam, + Order, +) +from llama_stack.providers.utils.inference.inference_store import InferenceStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig + + +def create_test_chat_completion( + completion_id: str, created_timestamp: int, model: str = "test-model" +) -> OpenAIChatCompletion: + """Helper to create a test chat completion.""" + return OpenAIChatCompletion( + id=completion_id, + created=created_timestamp, + model=model, + object="chat.completion", + choices=[ + OpenAIChoice( + index=0, + message=OpenAIAssistantMessageParam( + role="assistant", + content=f"Response for {completion_id}", + ), + finish_reason="stop", + ) + ], + ) + + +@pytest.mark.asyncio +async def test_inference_store_pagination_basic(): + """Test basic pagination functionality.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data with different timestamps + base_time = int(time.time()) + test_data = [ + ("zebra-task", base_time + 1), + ("apple-job", base_time + 2), + ("moon-work", base_time + 3), + ("banana-run", base_time + 4), + ("car-exec", base_time + 5), + ] + + # Store test chat completions + for completion_id, timestamp in test_data: + completion = create_test_chat_completion(completion_id, timestamp) + input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] + await store.store_chat_completion(completion, input_messages) + + # Test 1: First page with limit=2, descending order (default) + result = await store.list_chat_completions(limit=2, order=Order.desc) + assert len(result.data) == 2 + assert result.data[0].id == "car-exec" # Most recent first + assert result.data[1].id == "banana-run" + assert result.has_more is True + assert result.last_id == "banana-run" + + # Test 2: Second page using 'after' parameter + result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc) + assert len(result2.data) == 2 + assert result2.data[0].id == "moon-work" + assert result2.data[1].id == "apple-job" + assert result2.has_more is True + + # Test 3: Final page + result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc) + assert len(result3.data) == 1 + assert result3.data[0].id == "zebra-task" + assert result3.has_more is False + + +@pytest.mark.asyncio +async def test_inference_store_pagination_ascending(): + """Test pagination with ascending order.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data + base_time = int(time.time()) + test_data = [ + ("delta-item", base_time + 1), + ("charlie-task", base_time + 2), + ("alpha-work", base_time + 3), + ] + + # Store test chat completions + for completion_id, timestamp in test_data: + completion = create_test_chat_completion(completion_id, timestamp) + input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] + await store.store_chat_completion(completion, input_messages) + + # Test ascending order pagination + result = await store.list_chat_completions(limit=1, order=Order.asc) + assert len(result.data) == 1 + assert result.data[0].id == "delta-item" # Oldest first + assert result.has_more is True + + # Second page with ascending order + result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc) + assert len(result2.data) == 1 + assert result2.data[0].id == "charlie-task" + assert result2.has_more is True + + +@pytest.mark.asyncio +async def test_inference_store_pagination_with_model_filter(): + """Test pagination combined with model filtering.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data with different models + base_time = int(time.time()) + test_data = [ + ("xyz-task", base_time + 1, "model-a"), + ("def-work", base_time + 2, "model-b"), + ("pqr-job", base_time + 3, "model-a"), + ("abc-run", base_time + 4, "model-b"), + ] + + # Store test chat completions + for completion_id, timestamp, model in test_data: + completion = create_test_chat_completion(completion_id, timestamp, model) + input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] + await store.store_chat_completion(completion, input_messages) + + # Test pagination with model filter + result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc) + assert len(result.data) == 1 + assert result.data[0].id == "pqr-job" # Most recent model-a + assert result.data[0].model == "model-a" + assert result.has_more is True + + # Second page with model filter + result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc) + assert len(result2.data) == 1 + assert result2.data[0].id == "xyz-task" + assert result2.data[0].model == "model-a" + assert result2.has_more is False + + +@pytest.mark.asyncio +async def test_inference_store_pagination_invalid_after(): + """Test error handling for invalid 'after' parameter.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Try to paginate with non-existent ID + with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"): + await store.list_chat_completions(after="non-existent", limit=2) + + +@pytest.mark.asyncio +async def test_inference_store_pagination_no_limit(): + """Test pagination behavior when no limit is specified.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data + base_time = int(time.time()) + test_data = [ + ("omega-first", base_time + 1), + ("beta-second", base_time + 2), + ] + + # Store test chat completions + for completion_id, timestamp in test_data: + completion = create_test_chat_completion(completion_id, timestamp) + input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] + await store.store_chat_completion(completion, input_messages) + + # Test without limit + result = await store.list_chat_completions(order=Order.desc) + assert len(result.data) == 2 + assert result.data[0].id == "beta-second" # Most recent first + assert result.data[1].id == "omega-first" + assert result.has_more is False diff --git a/tests/unit/utils/responses/test_responses_store.py b/tests/unit/utils/responses/test_responses_store.py new file mode 100644 index 000000000..51fcb1ec2 --- /dev/null +++ b/tests/unit/utils/responses/test_responses_store.py @@ -0,0 +1,366 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from tempfile import TemporaryDirectory + +import pytest + +from llama_stack.apis.agents import Order +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInput, + OpenAIResponseObject, +) +from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig + + +def create_test_response_object( + response_id: str, created_timestamp: int, model: str = "test-model" +) -> OpenAIResponseObject: + """Helper to create a test response object.""" + return OpenAIResponseObject( + id=response_id, + created_at=created_timestamp, + model=model, + object="response", + output=[], # Required field + status="completed", # Required field + ) + + +def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInput: + """Helper to create a test response input.""" + from llama_stack.apis.agents.openai_responses import OpenAIResponseMessage + + return OpenAIResponseMessage( + id=input_id, + content=content, + role="user", + type="message", + ) + + +@pytest.mark.asyncio +async def test_responses_store_pagination_basic(): + """Test basic pagination functionality for responses store.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data with different timestamps + base_time = int(time.time()) + test_data = [ + ("zebra-resp", base_time + 1), + ("apple-resp", base_time + 2), + ("moon-resp", base_time + 3), + ("banana-resp", base_time + 4), + ("car-resp", base_time + 5), + ] + + # Store test responses + for response_id, timestamp in test_data: + response = create_test_response_object(response_id, timestamp) + input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] + await store.store_response_object(response, input_list) + + # Test 1: First page with limit=2, descending order (default) + result = await store.list_responses(limit=2, order=Order.desc) + assert len(result.data) == 2 + assert result.data[0].id == "car-resp" # Most recent first + assert result.data[1].id == "banana-resp" + assert result.has_more is True + assert result.last_id == "banana-resp" + + # Test 2: Second page using 'after' parameter + result2 = await store.list_responses(after="banana-resp", limit=2, order=Order.desc) + assert len(result2.data) == 2 + assert result2.data[0].id == "moon-resp" + assert result2.data[1].id == "apple-resp" + assert result2.has_more is True + + # Test 3: Final page + result3 = await store.list_responses(after="apple-resp", limit=2, order=Order.desc) + assert len(result3.data) == 1 + assert result3.data[0].id == "zebra-resp" + assert result3.has_more is False + + +@pytest.mark.asyncio +async def test_responses_store_pagination_ascending(): + """Test pagination with ascending order.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data + base_time = int(time.time()) + test_data = [ + ("delta-resp", base_time + 1), + ("charlie-resp", base_time + 2), + ("alpha-resp", base_time + 3), + ] + + # Store test responses + for response_id, timestamp in test_data: + response = create_test_response_object(response_id, timestamp) + input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] + await store.store_response_object(response, input_list) + + # Test ascending order pagination + result = await store.list_responses(limit=1, order=Order.asc) + assert len(result.data) == 1 + assert result.data[0].id == "delta-resp" # Oldest first + assert result.has_more is True + + # Second page with ascending order + result2 = await store.list_responses(after="delta-resp", limit=1, order=Order.asc) + assert len(result2.data) == 1 + assert result2.data[0].id == "charlie-resp" + assert result2.has_more is True + + +@pytest.mark.asyncio +async def test_responses_store_pagination_with_model_filter(): + """Test pagination combined with model filtering.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data with different models + base_time = int(time.time()) + test_data = [ + ("xyz-resp", base_time + 1, "model-a"), + ("def-resp", base_time + 2, "model-b"), + ("pqr-resp", base_time + 3, "model-a"), + ("abc-resp", base_time + 4, "model-b"), + ] + + # Store test responses + for response_id, timestamp, model in test_data: + response = create_test_response_object(response_id, timestamp, model) + input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] + await store.store_response_object(response, input_list) + + # Test pagination with model filter + result = await store.list_responses(limit=1, model="model-a", order=Order.desc) + assert len(result.data) == 1 + assert result.data[0].id == "pqr-resp" # Most recent model-a + assert result.data[0].model == "model-a" + assert result.has_more is True + + # Second page with model filter + result2 = await store.list_responses(after="pqr-resp", limit=1, model="model-a", order=Order.desc) + assert len(result2.data) == 1 + assert result2.data[0].id == "xyz-resp" + assert result2.data[0].model == "model-a" + assert result2.has_more is False + + +@pytest.mark.asyncio +async def test_responses_store_pagination_invalid_after(): + """Test error handling for invalid 'after' parameter.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Try to paginate with non-existent ID + with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"): + await store.list_responses(after="non-existent", limit=2) + + +@pytest.mark.asyncio +async def test_responses_store_pagination_no_limit(): + """Test pagination behavior when no limit is specified.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Create test data + base_time = int(time.time()) + test_data = [ + ("omega-resp", base_time + 1), + ("beta-resp", base_time + 2), + ] + + # Store test responses + for response_id, timestamp in test_data: + response = create_test_response_object(response_id, timestamp) + input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")] + await store.store_response_object(response, input_list) + + # Test without limit (should use default of 50) + result = await store.list_responses(order=Order.desc) + assert len(result.data) == 2 + assert result.data[0].id == "beta-resp" # Most recent first + assert result.data[1].id == "omega-resp" + assert result.has_more is False + + +@pytest.mark.asyncio +async def test_responses_store_get_response_object(): + """Test retrieving a single response object.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Store a test response + response = create_test_response_object("test-resp", int(time.time())) + input_list = [create_test_response_input("Test input content", "input-test-resp")] + await store.store_response_object(response, input_list) + + # Retrieve the response + retrieved = await store.get_response_object("test-resp") + assert retrieved.id == "test-resp" + assert retrieved.model == "test-model" + assert len(retrieved.input) == 1 + assert retrieved.input[0].content == "Test input content" + + # Test error for non-existent response + with pytest.raises(ValueError, match="Response with id non-existent not found"): + await store.get_response_object("non-existent") + + +@pytest.mark.asyncio +async def test_responses_store_input_items_pagination(): + """Test pagination functionality for input items.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Store a test response with many inputs with explicit IDs + response = create_test_response_object("test-resp", int(time.time())) + input_list = [ + create_test_response_input("First input", "input-1"), + create_test_response_input("Second input", "input-2"), + create_test_response_input("Third input", "input-3"), + create_test_response_input("Fourth input", "input-4"), + create_test_response_input("Fifth input", "input-5"), + ] + await store.store_response_object(response, input_list) + + # Verify all items are stored correctly with explicit IDs + all_items = await store.list_response_input_items("test-resp", order=Order.desc) + assert len(all_items.data) == 5 + + # In desc order: [Fifth, Fourth, Third, Second, First] + assert all_items.data[0].content == "Fifth input" + assert all_items.data[0].id == "input-5" + assert all_items.data[1].content == "Fourth input" + assert all_items.data[1].id == "input-4" + assert all_items.data[2].content == "Third input" + assert all_items.data[2].id == "input-3" + assert all_items.data[3].content == "Second input" + assert all_items.data[3].id == "input-2" + assert all_items.data[4].content == "First input" + assert all_items.data[4].id == "input-1" + + # Test basic pagination with after parameter using actual IDs + result = await store.list_response_input_items("test-resp", limit=2, order=Order.desc) + assert len(result.data) == 2 + assert result.data[0].content == "Fifth input" # Most recent first (reversed order) + assert result.data[1].content == "Fourth input" + + # Test pagination using after with actual ID + result2 = await store.list_response_input_items("test-resp", after="input-5", limit=2, order=Order.desc) + assert len(result2.data) == 2 + assert result2.data[0].content == "Fourth input" # Next item after Fifth + assert result2.data[1].content == "Third input" + + # Test final page + result3 = await store.list_response_input_items("test-resp", after="input-3", limit=2, order=Order.desc) + assert len(result3.data) == 2 + assert result3.data[0].content == "Second input" + assert result3.data[1].content == "First input" + + # Test ascending order pagination + result_asc = await store.list_response_input_items("test-resp", limit=2, order=Order.asc) + assert len(result_asc.data) == 2 + assert result_asc.data[0].content == "First input" # Oldest first + assert result_asc.data[1].content == "Second input" + + # Test pagination with ascending order + result_asc2 = await store.list_response_input_items("test-resp", after="input-1", limit=2, order=Order.asc) + assert len(result_asc2.data) == 2 + assert result_asc2.data[0].content == "Second input" + assert result_asc2.data[1].content == "Third input" + + # Test error for non-existent after ID + with pytest.raises(ValueError, match="Input item with id 'non-existent' not found for response 'test-resp'"): + await store.list_response_input_items("test-resp", after="non-existent") + + # Test error for unsupported features + with pytest.raises(NotImplementedError, match="Include is not supported yet"): + await store.list_response_input_items("test-resp", include=["some-field"]) + + # Test error for mutually exclusive parameters + with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"): + await store.list_response_input_items("test-resp", before="some-id", after="other-id") + + +@pytest.mark.asyncio +async def test_responses_store_input_items_before_pagination(): + """Test before pagination functionality for input items.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = ResponsesStore(SqliteSqlStoreConfig(db_path=db_path)) + await store.initialize() + + # Store a test response with many inputs with explicit IDs + response = create_test_response_object("test-resp-before", int(time.time())) + input_list = [ + create_test_response_input("First input", "before-1"), + create_test_response_input("Second input", "before-2"), + create_test_response_input("Third input", "before-3"), + create_test_response_input("Fourth input", "before-4"), + create_test_response_input("Fifth input", "before-5"), + ] + await store.store_response_object(response, input_list) + + # Test before pagination with descending order + # In desc order: [Fifth, Fourth, Third, Second, First] + # before="before-3" should return [Fifth, Fourth] + result = await store.list_response_input_items("test-resp-before", before="before-3", order=Order.desc) + assert len(result.data) == 2 + assert result.data[0].content == "Fifth input" + assert result.data[1].content == "Fourth input" + + # Test before pagination with limit + result2 = await store.list_response_input_items( + "test-resp-before", before="before-2", limit=3, order=Order.desc + ) + assert len(result2.data) == 3 + assert result2.data[0].content == "Fifth input" + assert result2.data[1].content == "Fourth input" + assert result2.data[2].content == "Third input" + + # Test before pagination with ascending order + # In asc order: [First, Second, Third, Fourth, Fifth] + # before="before-4" should return [First, Second, Third] + result3 = await store.list_response_input_items("test-resp-before", before="before-4", order=Order.asc) + assert len(result3.data) == 3 + assert result3.data[0].content == "First input" + assert result3.data[1].content == "Second input" + assert result3.data[2].content == "Third input" + + # Test before with limit in ascending order + result4 = await store.list_response_input_items("test-resp-before", before="before-5", limit=2, order=Order.asc) + assert len(result4.data) == 2 + assert result4.data[0].content == "First input" + assert result4.data[1].content == "Second input" + + # Test error for non-existent before ID + with pytest.raises( + ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'" + ): + await store.list_response_input_items("test-resp-before", before="non-existent") diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py new file mode 100644 index 000000000..c4230a396 --- /dev/null +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -0,0 +1,390 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from tempfile import TemporaryDirectory + +import pytest + +from llama_stack.providers.utils.sqlstore.api import ColumnType +from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig + + +@pytest.mark.asyncio +async def test_sqlite_sqlstore(): + with TemporaryDirectory() as tmp_dir: + db_name = "test.db" + sqlstore = SqlAlchemySqlStoreImpl( + SqliteSqlStoreConfig( + db_path=tmp_dir + "/" + db_name, + ) + ) + await sqlstore.create_table( + table="test", + schema={ + "id": ColumnType.INTEGER, + "name": ColumnType.STRING, + }, + ) + await sqlstore.insert("test", {"id": 1, "name": "test"}) + await sqlstore.insert("test", {"id": 12, "name": "test12"}) + result = await sqlstore.fetch_all("test") + assert result.data == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}] + assert result.has_more is False + + row = await sqlstore.fetch_one("test", {"id": 1}) + assert row == {"id": 1, "name": "test"} + + row = await sqlstore.fetch_one("test", {"name": "test12"}) + assert row == {"id": 12, "name": "test12"} + + # order by + result = await sqlstore.fetch_all("test", order_by=[("id", "asc")]) + assert result.data == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}] + + result = await sqlstore.fetch_all("test", order_by=[("id", "desc")]) + assert result.data == [{"id": 12, "name": "test12"}, {"id": 1, "name": "test"}] + + # limit + result = await sqlstore.fetch_all("test", limit=1) + assert result.data == [{"id": 1, "name": "test"}] + assert result.has_more is True + + # update + await sqlstore.update("test", {"name": "test123"}, {"id": 1}) + row = await sqlstore.fetch_one("test", {"id": 1}) + assert row == {"id": 1, "name": "test123"} + + # delete + await sqlstore.delete("test", {"id": 1}) + result = await sqlstore.fetch_all("test") + assert result.data == [{"id": 12, "name": "test12"}] + assert result.has_more is False + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_basic(): + """Test basic pagination functionality at the SQL store level.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + # Create test table + await store.create_table( + "test_records", + { + "id": ColumnType.STRING, + "created_at": ColumnType.INTEGER, + "name": ColumnType.STRING, + }, + ) + + # Insert test data + base_time = int(time.time()) + test_data = [ + {"id": "zebra", "created_at": base_time + 1, "name": "First"}, + {"id": "apple", "created_at": base_time + 2, "name": "Second"}, + {"id": "moon", "created_at": base_time + 3, "name": "Third"}, + {"id": "banana", "created_at": base_time + 4, "name": "Fourth"}, + {"id": "car", "created_at": base_time + 5, "name": "Fifth"}, + ] + + for record in test_data: + await store.insert("test_records", record) + + # Test 1: First page (no cursor) + result = await store.fetch_all( + table="test_records", + order_by=[("created_at", "desc")], + limit=2, + ) + assert len(result.data) == 2 + assert result.data[0]["id"] == "car" # Most recent first + assert result.data[1]["id"] == "banana" + assert result.has_more is True + + # Test 2: Second page using cursor + result2 = await store.fetch_all( + table="test_records", + order_by=[("created_at", "desc")], + cursor=("id", "banana"), + limit=2, + ) + assert len(result2.data) == 2 + assert result2.data[0]["id"] == "moon" + assert result2.data[1]["id"] == "apple" + assert result2.has_more is True + + # Test 3: Final page + result3 = await store.fetch_all( + table="test_records", + order_by=[("created_at", "desc")], + cursor=("id", "apple"), + limit=2, + ) + assert len(result3.data) == 1 + assert result3.data[0]["id"] == "zebra" + assert result3.has_more is False + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_with_filter(): + """Test pagination with WHERE conditions.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + # Create test table + await store.create_table( + "test_records", + { + "id": ColumnType.STRING, + "created_at": ColumnType.INTEGER, + "category": ColumnType.STRING, + }, + ) + + # Insert test data with categories + base_time = int(time.time()) + test_data = [ + {"id": "xyz", "created_at": base_time + 1, "category": "A"}, + {"id": "def", "created_at": base_time + 2, "category": "B"}, + {"id": "pqr", "created_at": base_time + 3, "category": "A"}, + {"id": "abc", "created_at": base_time + 4, "category": "B"}, + ] + + for record in test_data: + await store.insert("test_records", record) + + # Test pagination with filter + result = await store.fetch_all( + table="test_records", + where={"category": "A"}, + order_by=[("created_at", "desc")], + limit=1, + ) + assert len(result.data) == 1 + assert result.data[0]["id"] == "pqr" # Most recent category A + assert result.has_more is True + + # Second page with filter + result2 = await store.fetch_all( + table="test_records", + where={"category": "A"}, + order_by=[("created_at", "desc")], + cursor=("id", "pqr"), + limit=1, + ) + assert len(result2.data) == 1 + assert result2.data[0]["id"] == "xyz" + assert result2.has_more is False + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_ascending_order(): + """Test pagination with ascending order.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + # Create test table + await store.create_table( + "test_records", + { + "id": ColumnType.STRING, + "created_at": ColumnType.INTEGER, + }, + ) + + # Insert test data + base_time = int(time.time()) + test_data = [ + {"id": "gamma", "created_at": base_time + 1}, + {"id": "alpha", "created_at": base_time + 2}, + {"id": "beta", "created_at": base_time + 3}, + ] + + for record in test_data: + await store.insert("test_records", record) + + # Test ascending order + result = await store.fetch_all( + table="test_records", + order_by=[("created_at", "asc")], + limit=1, + ) + assert len(result.data) == 1 + assert result.data[0]["id"] == "gamma" # Oldest first + assert result.has_more is True + + # Second page with ascending order + result2 = await store.fetch_all( + table="test_records", + order_by=[("created_at", "asc")], + cursor=("id", "gamma"), + limit=1, + ) + assert len(result2.data) == 1 + assert result2.data[0]["id"] == "alpha" + assert result2.has_more is True + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_multi_column_ordering_error(): + """Test that multi-column ordering raises an error when using cursor pagination.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + # Create test table + await store.create_table( + "test_records", + { + "id": ColumnType.STRING, + "priority": ColumnType.INTEGER, + "created_at": ColumnType.INTEGER, + }, + ) + + await store.insert("test_records", {"id": "task1", "priority": 1, "created_at": 12345}) + + # Test that multi-column ordering with cursor raises error + with pytest.raises(ValueError, match="Cursor pagination only supports single-column ordering, got 2 columns"): + await store.fetch_all( + table="test_records", + order_by=[("priority", "asc"), ("created_at", "desc")], + cursor=("id", "task1"), + limit=2, + ) + + # Test that multi-column ordering without cursor works fine + result = await store.fetch_all( + table="test_records", + order_by=[("priority", "asc"), ("created_at", "desc")], + limit=2, + ) + assert len(result.data) == 1 + assert result.data[0]["id"] == "task1" + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_cursor_requires_order_by(): + """Test that cursor pagination requires order_by parameter.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + await store.create_table("test_records", {"id": ColumnType.STRING}) + await store.insert("test_records", {"id": "task1"}) + + # Test that cursor without order_by raises error + with pytest.raises(ValueError, match="order_by is required when using cursor pagination"): + await store.fetch_all( + table="test_records", + cursor=("id", "task1"), + ) + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_error_handling(): + """Test error handling for invalid columns and cursor IDs.""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + # Create test table + await store.create_table( + "test_records", + { + "id": ColumnType.STRING, + "name": ColumnType.STRING, + }, + ) + + await store.insert("test_records", {"id": "test1", "name": "Test"}) + + # Test invalid cursor tuple format + with pytest.raises(ValueError, match="Cursor must be a tuple of"): + await store.fetch_all( + table="test_records", + order_by=[("name", "asc")], + cursor="invalid", # Should be tuple + ) + + # Test invalid cursor_key_column + with pytest.raises(ValueError, match="Cursor key column 'nonexistent' not found in table"): + await store.fetch_all( + table="test_records", + order_by=[("name", "asc")], + cursor=("nonexistent", "test1"), + ) + + # Test invalid order_by column + with pytest.raises(ValueError, match="Column 'invalid_col' not found in table"): + await store.fetch_all( + table="test_records", + order_by=[("invalid_col", "asc")], + ) + + # Test nonexistent cursor_id + with pytest.raises(ValueError, match="Record with id='nonexistent' not found in table"): + await store.fetch_all( + table="test_records", + order_by=[("name", "asc")], + cursor=("id", "nonexistent"), + ) + + +@pytest.mark.asyncio +async def test_sqlstore_pagination_custom_key_column(): + """Test pagination with custom primary key column (not 'id').""" + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + # Create test table with custom primary key + await store.create_table( + "custom_table", + { + "uuid": ColumnType.STRING, + "timestamp": ColumnType.INTEGER, + "data": ColumnType.STRING, + }, + ) + + # Insert test data + base_time = int(time.time()) + test_data = [ + {"uuid": "uuid-alpha", "timestamp": base_time + 1, "data": "First"}, + {"uuid": "uuid-beta", "timestamp": base_time + 2, "data": "Second"}, + {"uuid": "uuid-gamma", "timestamp": base_time + 3, "data": "Third"}, + ] + + for record in test_data: + await store.insert("custom_table", record) + + # Test pagination with custom key column + result = await store.fetch_all( + table="custom_table", + order_by=[("timestamp", "desc")], + limit=2, + ) + assert len(result.data) == 2 + assert result.data[0]["uuid"] == "uuid-gamma" # Most recent + assert result.data[1]["uuid"] == "uuid-beta" + assert result.has_more is True + + # Second page using custom key column + result2 = await store.fetch_all( + table="custom_table", + order_by=[("timestamp", "desc")], + cursor=("uuid", "uuid-beta"), # Use uuid as key column + limit=2, + ) + assert len(result2.data) == 1 + assert result2.data[0]["uuid"] == "uuid-alpha" + assert result2.has_more is False diff --git a/tests/unit/utils/test_sqlstore.py b/tests/unit/utils/test_sqlstore.py deleted file mode 100644 index 6231e9082..000000000 --- a/tests/unit/utils/test_sqlstore.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from tempfile import TemporaryDirectory - -import pytest - -from llama_stack.providers.utils.sqlstore.api import ColumnType -from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl -from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig - - -@pytest.mark.asyncio -async def test_sqlite_sqlstore(): - with TemporaryDirectory() as tmp_dir: - db_name = "test.db" - sqlstore = SqlAlchemySqlStoreImpl( - SqliteSqlStoreConfig( - db_path=tmp_dir + "/" + db_name, - ) - ) - await sqlstore.create_table( - table="test", - schema={ - "id": ColumnType.INTEGER, - "name": ColumnType.STRING, - }, - ) - await sqlstore.insert("test", {"id": 1, "name": "test"}) - await sqlstore.insert("test", {"id": 12, "name": "test12"}) - rows = await sqlstore.fetch_all("test") - assert rows == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}] - - row = await sqlstore.fetch_one("test", {"id": 1}) - assert row == {"id": 1, "name": "test"} - - row = await sqlstore.fetch_one("test", {"name": "test12"}) - assert row == {"id": 12, "name": "test12"} - - # order by - rows = await sqlstore.fetch_all("test", order_by=[("id", "asc")]) - assert rows == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}] - - rows = await sqlstore.fetch_all("test", order_by=[("id", "desc")]) - assert rows == [{"id": 12, "name": "test12"}, {"id": 1, "name": "test"}] - - # limit - rows = await sqlstore.fetch_all("test", limit=1) - assert rows == [{"id": 1, "name": "test"}] - - # update - await sqlstore.update("test", {"name": "test123"}, {"id": 1}) - row = await sqlstore.fetch_one("test", {"id": 1}) - assert row == {"id": 1, "name": "test123"} - - # delete - await sqlstore.delete("test", {"id": 1}) - rows = await sqlstore.fetch_all("test") - assert rows == [{"id": 12, "name": "test12"}]