feat: support pagination in inference/responses stores (#2397)

# What does this PR do?

Adds cursor-based pagination to the inference and responses stores (and the localfs files provider): `SqlStore.fetch_all` gains a `cursor` parameter and returns a `PaginatedResponse`, so the list endpoints can honor `after` and report a real `has_more` instead of raising `NotImplementedError`.

## Test Plan
Added unit tests.
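
For illustration only, a test along these lines exercises the new cursor path. This is a sketch written against the `fetch_all` contract shown in the diff below, not the tests added in this PR; the `store` fixture, the pre-populated `chat_completions` table, and the use of `pytest-asyncio` are assumptions.

```python
import pytest

# Sketch only: `store` is assumed to be a SqlStore implementation whose
# "chat_completions" table already holds more than two rows with distinct
# `created` timestamps and unique `id` values.
@pytest.mark.asyncio
async def test_fetch_all_cursor_pagination(store):
    # First page: newest two rows.
    page1 = await store.fetch_all(
        table="chat_completions",
        order_by=[("created", "desc")],
        limit=2,
    )
    assert len(page1.data) == 2
    assert page1.has_more

    # Second page resumes strictly after the last row of page 1.
    page2 = await store.fetch_all(
        table="chat_completions",
        order_by=[("created", "desc")],
        cursor=("id", page1.data[-1]["id"]),
        limit=2,
    )
    assert all(row["created"] < page1.data[-1]["created"] for row in page2.data)
```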
ehhuang 2025-06-16 22:43:35 -07:00 committed by GitHub
parent 6f1a935365
commit 15f630e5da
10 changed files with 1130 additions and 117 deletions

View file

@@ -114,18 +114,18 @@ class LocalfsFilesImpl(Files):
         if not self.sql_store:
             raise RuntimeError("Files provider not initialized")
 
-        # TODO: Implement 'after' pagination properly
-        if after:
-            raise NotImplementedError("After pagination not yet implemented")
+        if not order:
+            order = Order.desc
 
-        where = None
+        where_conditions = {}
         if purpose:
-            where = {"purpose": purpose.value}
+            where_conditions["purpose"] = purpose.value
 
-        rows = await self.sql_store.fetch_all(
-            "openai_files",
-            where=where,
-            order_by=[("created_at", order.value if order else Order.desc.value)],
+        paginated_result = await self.sql_store.fetch_all(
+            table="openai_files",
+            where=where_conditions if where_conditions else None,
+            order_by=[("created_at", order.value)],
+            cursor=("id", after) if after else None,
             limit=limit,
         )
@@ -138,12 +138,12 @@ class LocalfsFilesImpl(Files):
                 created_at=row["created_at"],
                 expires_at=row["expires_at"],
             )
-            for row in rows
+            for row in paginated_result.data
         ]
 
         return ListOpenAIFileResponse(
             data=files,
-            has_more=False,  # TODO: Implement proper pagination
+            has_more=paginated_result.has_more,
             first_id=files[0].id if files else "",
             last_id=files[-1].id if files else "",
         )
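
The practical effect for callers is that the list endpoints now report a real `has_more` and honor `after`, so all pages can be walked with a simple loop. A minimal sketch, assuming a hypothetical async `list_files(after=..., limit=...)` wrapper that returns the `ListOpenAIFileResponse` shape above (`data`, `has_more`, `first_id`, `last_id`):

```python
async def iter_all_files(list_files, page_size: int = 100):
    """Yield every file by following the cursor until has_more is False."""
    after = None
    while True:
        page = await list_files(after=after, limit=page_size)
        for f in page.data:
            yield f
        if not page.has_more:
            break
        # Resume from the last item of this page on the next request.
        after = page.last_id
```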

View file

@@ -76,16 +76,18 @@ class InferenceStore:
         if not self.sql_store:
             raise ValueError("Inference store is not initialized")
 
-        # TODO: support after
-        if after:
-            raise NotImplementedError("After is not supported for SQLite")
-
         if not order:
             order = Order.desc
 
-        rows = await self.sql_store.fetch_all(
-            "chat_completions",
-            where={"model": model} if model else None,
+        where_conditions = {}
+        if model:
+            where_conditions["model"] = model
+
+        paginated_result = await self.sql_store.fetch_all(
+            table="chat_completions",
+            where=where_conditions if where_conditions else None,
             order_by=[("created", order.value)],
+            cursor=("id", after) if after else None,
             limit=limit,
         )
@@ -97,12 +99,11 @@ class InferenceStore:
                 choices=row["choices"],
                 input_messages=row["input_messages"],
             )
-            for row in rows
+            for row in paginated_result.data
         ]
 
         return ListOpenAIChatCompletionResponse(
             data=data,
-            # TODO: implement has_more
-            has_more=False,
+            has_more=paginated_result.has_more,
             first_id=data[0].id if data else "",
             last_id=data[-1].id if data else "",
         )

View file

@@ -70,24 +70,25 @@ class ResponsesStore:
         :param model: The model to filter by.
         :param order: The order to sort the responses by.
         """
-        # TODO: support after
-        if after:
-            raise NotImplementedError("After is not supported for SQLite")
-
         if not order:
             order = Order.desc
 
-        rows = await self.sql_store.fetch_all(
-            "openai_responses",
-            where={"model": model} if model else None,
+        where_conditions = {}
+        if model:
+            where_conditions["model"] = model
+
+        paginated_result = await self.sql_store.fetch_all(
+            table="openai_responses",
+            where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
+            cursor=("id", after) if after else None,
             limit=limit,
         )
 
-        data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows]
+        data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
         return ListOpenAIResponseObject(
             data=data,
-            # TODO: implement has_more
-            has_more=False,
+            has_more=paginated_result.has_more,
             first_id=data[0].id if data else "",
             last_id=data[-1].id if data else "",
         )
@@ -117,19 +118,38 @@ class ResponsesStore:
         :param limit: A limit on the number of objects to be returned.
         :param order: The order to return the input items in.
         """
-        # TODO: support after/before pagination
-        if after or before:
-            raise NotImplementedError("After/before pagination is not supported yet")
         if include:
             raise NotImplementedError("Include is not supported yet")
 
+        if before and after:
+            raise ValueError("Cannot specify both 'before' and 'after' parameters")
+
         response_with_input = await self.get_response_object(response_id)
-        input_items = response_with_input.input
+        items = response_with_input.input
 
         if order == Order.desc:
-            input_items = list(reversed(input_items))
+            items = list(reversed(items))
 
-        if limit is not None and len(input_items) > limit:
-            input_items = input_items[:limit]
+        start_index = 0
+        end_index = len(items)
 
-        return ListOpenAIResponseInputItem(data=input_items)
+        if after or before:
+            for i, item in enumerate(items):
+                item_id = getattr(item, "id", None)
+                if after and item_id == after:
+                    start_index = i + 1
+                if before and item_id == before:
+                    end_index = i
+                    break
+
+            if after and start_index == 0:
+                raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
+            if before and end_index == len(items):
+                raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")
+
+            items = items[start_index:end_index]
+
+        # Apply limit
+        if limit is not None:
+            items = items[:limit]
+
+        return ListOpenAIResponseInputItem(data=items)
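
To make the index bookkeeping concrete: with a hypothetical item list `["a", "b", "c", "d"]` (plain strings standing in for input items and their ids), `after="b"` moves `start_index` to 2 and leaves `end_index` at 4, so the returned window is `["c", "d"]` before any `limit` is applied; `before` works symmetrically by pulling `end_index` back. A minimal sketch of that same windowing:

```python
# Stand-ins for input items; the real loop reads item.id via getattr.
items = ["a", "b", "c", "d"]
after = "b"                      # resume strictly after item "b"

start_index, end_index = 0, len(items)
for i, item_id in enumerate(items):
    if after and item_id == after:
        start_index = i + 1      # window starts just past the 'after' item

assert items[start_index:end_index] == ["c", "d"]
```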

View file

@@ -10,6 +10,8 @@ from typing import Any, Literal, Protocol
 
 from pydantic import BaseModel
 
+from llama_stack.apis.common.responses import PaginatedResponse
+
 
 class ColumnType(Enum):
     INTEGER = "INTEGER"
@@ -51,9 +53,21 @@ class SqlStore(Protocol):
         where: Mapping[str, Any] | None = None,
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-    ) -> list[dict[str, Any]]:
+        cursor: tuple[str, str] | None = None,
+    ) -> PaginatedResponse:
         """
-        Fetch all rows from a table.
+        Fetch all rows from a table with optional cursor-based pagination.
+
+        :param table: The table name
+        :param where: WHERE conditions
+        :param limit: Maximum number of records to return
+        :param order_by: List of (column, order) tuples for sorting
+        :param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
+                       Requires order_by with exactly one column when used
+        :return: PaginatedResult with data and has_more flag
+
+        Note: Cursor pagination only supports single-column ordering for simplicity.
+              Multi-column ordering is allowed without cursor but will raise an error with cursor.
         """
         pass
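
For context, `PaginatedResponse` (imported above from `llama_stack.apis.common.responses`) is the envelope every `fetch_all` call now returns. A minimal sketch of the assumed shape and the typical caller pattern; the real model may carry additional fields:

```python
from typing import Any
from pydantic import BaseModel

# Assumed shape of the imported model; the real definition may differ slightly.
class PaginatedResponse(BaseModel):
    data: list[dict[str, Any]]   # one dict per row
    has_more: bool               # True if fetch_all found rows beyond `limit`

# Typical caller pattern against the protocol above (store: SqlStore):
#   page = await store.fetch_all(
#       table="openai_responses",
#       order_by=[("created_at", "desc")],
#       cursor=("id", after) if after else None,
#       limit=limit,
#   )
#   rows, more = page.data, page.has_more
```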

View file

@@ -21,6 +21,8 @@ from sqlalchemy import (
 )
 from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
 
+from llama_stack.apis.common.responses import PaginatedResponse
+
 from .api import ColumnDefinition, ColumnType, SqlStore
 from .sqlstore import SqlAlchemySqlStoreConfig
@@ -94,14 +96,56 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         where: Mapping[str, Any] | None = None,
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
-    ) -> list[dict[str, Any]]:
+        cursor: tuple[str, str] | None = None,
+    ) -> PaginatedResponse:
         async with self.async_session() as session:
-            query = select(self.metadata.tables[table])
+            table_obj = self.metadata.tables[table]
+            query = select(table_obj)
+
             if where:
                 for key, value in where.items():
-                    query = query.where(self.metadata.tables[table].c[key] == value)
-            if limit:
-                query = query.limit(limit)
+                    query = query.where(table_obj.c[key] == value)
+
+            # Handle cursor-based pagination
+            if cursor:
+                # Validate cursor tuple format
+                if not isinstance(cursor, tuple) or len(cursor) != 2:
+                    raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
+
+                # Require order_by for cursor pagination
+                if not order_by:
+                    raise ValueError("order_by is required when using cursor pagination")
+
+                # Only support single-column ordering for cursor pagination
+                if len(order_by) != 1:
+                    raise ValueError(
+                        f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
+                    )
+
+                cursor_key_column, cursor_id = cursor
+                order_column, order_direction = order_by[0]
+
+                # Verify cursor_key_column exists
+                if cursor_key_column not in table_obj.c:
+                    raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
+
+                # Get cursor value for the order column
+                cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
+                cursor_result = await session.execute(cursor_query)
+                cursor_row = cursor_result.fetchone()
+                if not cursor_row:
+                    raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
+                cursor_value = cursor_row[0]
+
+                # Apply cursor condition based on sort direction
+                if order_direction == "desc":
+                    query = query.where(table_obj.c[order_column] < cursor_value)
+                else:
+                    query = query.where(table_obj.c[order_column] > cursor_value)
+
+            # Apply ordering
             if order_by:
                 if not isinstance(order_by, list):
                     raise ValueError(
@@ -113,16 +157,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
                             f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
                         )
                     name, order_type = order
+                    if name not in table_obj.c:
+                        raise ValueError(f"Column '{name}' not found in table '{table}'")
                     if order_type == "asc":
-                        query = query.order_by(self.metadata.tables[table].c[name].asc())
+                        query = query.order_by(table_obj.c[name].asc())
                     elif order_type == "desc":
-                        query = query.order_by(self.metadata.tables[table].c[name].desc())
+                        query = query.order_by(table_obj.c[name].desc())
                     else:
                         raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
 
+            # Fetch limit + 1 to determine has_more
+            fetch_limit = limit
+            if limit:
+                fetch_limit = limit + 1
+
+            if fetch_limit:
+                query = query.limit(fetch_limit)
+
             result = await session.execute(query)
             if result.rowcount == 0:
-                return []
-            return [dict(row._mapping) for row in result]
+                rows = []
+            else:
+                rows = [dict(row._mapping) for row in result]
+
+            # Always return pagination result
+            has_more = False
+            if limit and len(rows) > limit:
+                has_more = True
+                rows = rows[:limit]
+
+            return PaginatedResponse(data=rows, has_more=has_more)
 
     async def fetch_one(
         self,
@@ -130,10 +194,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         where: Mapping[str, Any] | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
     ) -> dict[str, Any] | None:
-        rows = await self.fetch_all(table, where, limit=1, order_by=order_by)
-        if not rows:
+        result = await self.fetch_all(table, where, limit=1, order_by=order_by)
+        if not result.data:
             return None
-        return rows[0]
+        return result.data[0]
 
     async def update(
         self,
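
Taken together, the new `fetch_all` amounts to classic keyset pagination: resolve the cursor row's value in the ordering column, filter strictly past it, and over-fetch by one row to derive `has_more`. A condensed sketch of that shape in SQLAlchemy Core (an illustration of the technique, not the code above verbatim; it collapses the two queries into a scalar subquery and assumes `id` and `created` columns):

```python
from sqlalchemy import select

def keyset_page_query(table_obj, after_id: str | None, limit: int):
    """Build a query for one page ordered by `created` descending."""
    query = select(table_obj).order_by(table_obj.c.created.desc())
    if after_id is not None:
        # Value of the ordering column for the cursor row.
        anchor = (
            select(table_obj.c.created)
            .where(table_obj.c.id == after_id)
            .scalar_subquery()
        )
        query = query.where(table_obj.c.created < anchor)  # strictly older rows
    # Fetch one extra row so the caller can set has_more = len(rows) > limit.
    return query.limit(limit + 1)
```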