mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 10:46:41 +00:00
feat: support pagination in inference/responses stores (#2397)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 12s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 12s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 25s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 23s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 27s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 29s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 20s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 22s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 23s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 27s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 19s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 44s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 46s
Test External Providers / test-external-providers (venv) (push) Failing after 41s
Unit Tests / unit-tests (3.10) (push) Failing after 52s
Unit Tests / unit-tests (3.12) (push) Failing after 18s
Unit Tests / unit-tests (3.11) (push) Failing after 20s
Unit Tests / unit-tests (3.13) (push) Failing after 16s
Pre-commit / pre-commit (push) Successful in 2m0s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 12s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 12s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 25s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 23s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 27s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 29s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 20s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 22s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 23s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 27s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 19s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 44s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 46s
Test External Providers / test-external-providers (venv) (push) Failing after 41s
Unit Tests / unit-tests (3.10) (push) Failing after 52s
Unit Tests / unit-tests (3.12) (push) Failing after 18s
Unit Tests / unit-tests (3.11) (push) Failing after 20s
Unit Tests / unit-tests (3.13) (push) Failing after 16s
Pre-commit / pre-commit (push) Successful in 2m0s
# What does this PR do? ## Test Plan added unit tests
This commit is contained in:
parent
6f1a935365
commit
15f630e5da
10 changed files with 1130 additions and 117 deletions
|
@ -114,18 +114,18 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
# TODO: Implement 'after' pagination properly
|
||||
if after:
|
||||
raise NotImplementedError("After pagination not yet implemented")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where = None
|
||||
where_conditions = {}
|
||||
if purpose:
|
||||
where = {"purpose": purpose.value}
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"openai_files",
|
||||
where=where,
|
||||
order_by=[("created_at", order.value if order else Order.desc.value)],
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
@ -138,12 +138,12 @@ class LocalfsFilesImpl(Files):
|
|||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in rows
|
||||
for row in paginated_result.data
|
||||
]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=False, # TODO: Implement proper pagination
|
||||
has_more=paginated_result.has_more,
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
|
|
@ -76,16 +76,18 @@ class InferenceStore:
|
|||
if not self.sql_store:
|
||||
raise ValueError("Inference store is not initialized")
|
||||
|
||||
# TODO: support after
|
||||
if after:
|
||||
raise NotImplementedError("After is not supported for SQLite")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"chat_completions",
|
||||
where={"model": model} if model else None,
|
||||
where_conditions = {}
|
||||
if model:
|
||||
where_conditions["model"] = model
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="chat_completions",
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
@ -97,12 +99,11 @@ class InferenceStore:
|
|||
choices=row["choices"],
|
||||
input_messages=row["input_messages"],
|
||||
)
|
||||
for row in rows
|
||||
for row in paginated_result.data
|
||||
]
|
||||
return ListOpenAIChatCompletionResponse(
|
||||
data=data,
|
||||
# TODO: implement has_more
|
||||
has_more=False,
|
||||
has_more=paginated_result.has_more,
|
||||
first_id=data[0].id if data else "",
|
||||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
|
|
@ -70,24 +70,25 @@ class ResponsesStore:
|
|||
:param model: The model to filter by.
|
||||
:param order: The order to sort the responses by.
|
||||
"""
|
||||
# TODO: support after
|
||||
if after:
|
||||
raise NotImplementedError("After is not supported for SQLite")
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
rows = await self.sql_store.fetch_all(
|
||||
"openai_responses",
|
||||
where={"model": model} if model else None,
|
||||
where_conditions = {}
|
||||
if model:
|
||||
where_conditions["model"] = model
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_responses",
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in rows]
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||
return ListOpenAIResponseObject(
|
||||
data=data,
|
||||
# TODO: implement has_more
|
||||
has_more=False,
|
||||
has_more=paginated_result.has_more,
|
||||
first_id=data[0].id if data else "",
|
||||
last_id=data[-1].id if data else "",
|
||||
)
|
||||
|
@ -117,19 +118,38 @@ class ResponsesStore:
|
|||
:param limit: A limit on the number of objects to be returned.
|
||||
:param order: The order to return the input items in.
|
||||
"""
|
||||
# TODO: support after/before pagination
|
||||
if after or before:
|
||||
raise NotImplementedError("After/before pagination is not supported yet")
|
||||
if include:
|
||||
raise NotImplementedError("Include is not supported yet")
|
||||
if before and after:
|
||||
raise ValueError("Cannot specify both 'before' and 'after' parameters")
|
||||
|
||||
response_with_input = await self.get_response_object(response_id)
|
||||
input_items = response_with_input.input
|
||||
items = response_with_input.input
|
||||
|
||||
if order == Order.desc:
|
||||
input_items = list(reversed(input_items))
|
||||
items = list(reversed(items))
|
||||
|
||||
if limit is not None and len(input_items) > limit:
|
||||
input_items = input_items[:limit]
|
||||
start_index = 0
|
||||
end_index = len(items)
|
||||
|
||||
return ListOpenAIResponseInputItem(data=input_items)
|
||||
if after or before:
|
||||
for i, item in enumerate(items):
|
||||
item_id = getattr(item, "id", None)
|
||||
if after and item_id == after:
|
||||
start_index = i + 1
|
||||
if before and item_id == before:
|
||||
end_index = i
|
||||
break
|
||||
|
||||
if after and start_index == 0:
|
||||
raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
|
||||
if before and end_index == len(items):
|
||||
raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")
|
||||
|
||||
items = items[start_index:end_index]
|
||||
|
||||
# Apply limit
|
||||
if limit is not None:
|
||||
items = items[:limit]
|
||||
|
||||
return ListOpenAIResponseInputItem(data=items)
|
||||
|
|
|
@ -10,6 +10,8 @@ from typing import Any, Literal, Protocol
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
|
||||
|
||||
class ColumnType(Enum):
|
||||
INTEGER = "INTEGER"
|
||||
|
@ -51,9 +53,21 @@ class SqlStore(Protocol):
|
|||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""
|
||||
Fetch all rows from a table.
|
||||
Fetch all rows from a table with optional cursor-based pagination.
|
||||
|
||||
:param table: The table name
|
||||
:param where: WHERE conditions
|
||||
:param limit: Maximum number of records to return
|
||||
:param order_by: List of (column, order) tuples for sorting
|
||||
:param cursor: Tuple of (key_column, cursor_id) for pagination (None for first page)
|
||||
Requires order_by with exactly one column when used
|
||||
:return: PaginatedResult with data and has_more flag
|
||||
|
||||
Note: Cursor pagination only supports single-column ordering for simplicity.
|
||||
Multi-column ordering is allowed without cursor but will raise an error with cursor.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@ from sqlalchemy import (
|
|||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
|
||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||
|
||||
|
@ -94,14 +96,56 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
async with self.async_session() as session:
|
||||
query = select(self.metadata.tables[table])
|
||||
table_obj = self.metadata.tables[table]
|
||||
query = select(table_obj)
|
||||
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(self.metadata.tables[table].c[key] == value)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
query = query.where(table_obj.c[key] == value)
|
||||
|
||||
# Handle cursor-based pagination
|
||||
if cursor:
|
||||
# Validate cursor tuple format
|
||||
if not isinstance(cursor, tuple) or len(cursor) != 2:
|
||||
raise ValueError(f"Cursor must be a tuple of (key_column, cursor_id), got: {cursor}")
|
||||
|
||||
# Require order_by for cursor pagination
|
||||
if not order_by:
|
||||
raise ValueError("order_by is required when using cursor pagination")
|
||||
|
||||
# Only support single-column ordering for cursor pagination
|
||||
if len(order_by) != 1:
|
||||
raise ValueError(
|
||||
f"Cursor pagination only supports single-column ordering, got {len(order_by)} columns"
|
||||
)
|
||||
|
||||
cursor_key_column, cursor_id = cursor
|
||||
order_column, order_direction = order_by[0]
|
||||
|
||||
# Verify cursor_key_column exists
|
||||
if cursor_key_column not in table_obj.c:
|
||||
raise ValueError(f"Cursor key column '{cursor_key_column}' not found in table '{table}'")
|
||||
|
||||
# Get cursor value for the order column
|
||||
cursor_query = select(table_obj.c[order_column]).where(table_obj.c[cursor_key_column] == cursor_id)
|
||||
cursor_result = await session.execute(cursor_query)
|
||||
cursor_row = cursor_result.fetchone()
|
||||
|
||||
if not cursor_row:
|
||||
raise ValueError(f"Record with {cursor_key_column}='{cursor_id}' not found in table '{table}'")
|
||||
|
||||
cursor_value = cursor_row[0]
|
||||
|
||||
# Apply cursor condition based on sort direction
|
||||
if order_direction == "desc":
|
||||
query = query.where(table_obj.c[order_column] < cursor_value)
|
||||
else:
|
||||
query = query.where(table_obj.c[order_column] > cursor_value)
|
||||
|
||||
# Apply ordering
|
||||
if order_by:
|
||||
if not isinstance(order_by, list):
|
||||
raise ValueError(
|
||||
|
@ -113,16 +157,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}"
|
||||
)
|
||||
name, order_type = order
|
||||
if name not in table_obj.c:
|
||||
raise ValueError(f"Column '{name}' not found in table '{table}'")
|
||||
if order_type == "asc":
|
||||
query = query.order_by(self.metadata.tables[table].c[name].asc())
|
||||
query = query.order_by(table_obj.c[name].asc())
|
||||
elif order_type == "desc":
|
||||
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
||||
query = query.order_by(table_obj.c[name].desc())
|
||||
else:
|
||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||
|
||||
# Fetch limit + 1 to determine has_more
|
||||
fetch_limit = limit
|
||||
if limit:
|
||||
fetch_limit = limit + 1
|
||||
|
||||
if fetch_limit:
|
||||
query = query.limit(fetch_limit)
|
||||
|
||||
result = await session.execute(query)
|
||||
if result.rowcount == 0:
|
||||
return []
|
||||
return [dict(row._mapping) for row in result]
|
||||
rows = []
|
||||
else:
|
||||
rows = [dict(row._mapping) for row in result]
|
||||
|
||||
# Always return pagination result
|
||||
has_more = False
|
||||
if limit and len(rows) > limit:
|
||||
has_more = True
|
||||
rows = rows[:limit]
|
||||
|
||||
return PaginatedResponse(data=rows, has_more=has_more)
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
|
@ -130,10 +194,10 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
rows = await self.fetch_all(table, where, limit=1, order_by=order_by)
|
||||
if not rows:
|
||||
result = await self.fetch_all(table, where, limit=1, order_by=order_by)
|
||||
if not result.data:
|
||||
return None
|
||||
return rows[0]
|
||||
return result.data[0]
|
||||
|
||||
async def update(
|
||||
self,
|
||||
|
|
|
@ -328,7 +328,24 @@ class TestOpenAIFilesAPI:
|
|||
assert uploaded_file.filename == "uploaded_file" # Default filename
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_pagination_not_implemented(self, files_provider):
|
||||
"""Test that 'after' pagination raises NotImplementedError."""
|
||||
with pytest.raises(NotImplementedError, match="After pagination not yet implemented"):
|
||||
await files_provider.openai_list_files(after="file-some-id")
|
||||
async def test_after_pagination_works(self, files_provider, sample_text_file):
|
||||
"""Test that 'after' pagination works correctly."""
|
||||
# Upload multiple files to test pagination
|
||||
uploaded_files = []
|
||||
for _ in range(5):
|
||||
file = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
uploaded_files.append(file)
|
||||
|
||||
# Get first page without 'after' parameter
|
||||
first_page = await files_provider.openai_list_files(limit=2, order=Order.desc)
|
||||
assert len(first_page.data) == 2
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after' parameter
|
||||
second_page = await files_provider.openai_list_files(after=first_page.data[-1].id, limit=2, order=Order.desc)
|
||||
assert len(second_page.data) <= 2
|
||||
|
||||
# Verify no overlap between pages
|
||||
first_page_ids = {f.id for f in first_page.data}
|
||||
second_page_ids = {f.id for f in second_page.data}
|
||||
assert first_page_ids.isdisjoint(second_page_ids)
|
||||
|
|
203
tests/unit/utils/inference/test_inference_store.py
Normal file
203
tests/unit/utils/inference/test_inference_store.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChoice,
|
||||
OpenAIUserMessageParam,
|
||||
Order,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
def create_test_chat_completion(
|
||||
completion_id: str, created_timestamp: int, model: str = "test-model"
|
||||
) -> OpenAIChatCompletion:
|
||||
"""Helper to create a test chat completion."""
|
||||
return OpenAIChatCompletion(
|
||||
id=completion_id,
|
||||
created=created_timestamp,
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
index=0,
|
||||
message=OpenAIAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=f"Response for {completion_id}",
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inference_store_pagination_basic():
|
||||
"""Test basic pagination functionality."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different timestamps
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("zebra-task", base_time + 1),
|
||||
("apple-job", base_time + 2),
|
||||
("moon-work", base_time + 3),
|
||||
("banana-run", base_time + 4),
|
||||
("car-exec", base_time + 5),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Test 1: First page with limit=2, descending order (default)
|
||||
result = await store.list_chat_completions(limit=2, order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "car-exec" # Most recent first
|
||||
assert result.data[1].id == "banana-run"
|
||||
assert result.has_more is True
|
||||
assert result.last_id == "banana-run"
|
||||
|
||||
# Test 2: Second page using 'after' parameter
|
||||
result2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
|
||||
assert len(result2.data) == 2
|
||||
assert result2.data[0].id == "moon-work"
|
||||
assert result2.data[1].id == "apple-job"
|
||||
assert result2.has_more is True
|
||||
|
||||
# Test 3: Final page
|
||||
result3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
|
||||
assert len(result3.data) == 1
|
||||
assert result3.data[0].id == "zebra-task"
|
||||
assert result3.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inference_store_pagination_ascending():
|
||||
"""Test pagination with ascending order."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("delta-item", base_time + 1),
|
||||
("charlie-task", base_time + 2),
|
||||
("alpha-work", base_time + 3),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Test ascending order pagination
|
||||
result = await store.list_chat_completions(limit=1, order=Order.asc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "delta-item" # Oldest first
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with ascending order
|
||||
result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "charlie-task"
|
||||
assert result2.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inference_store_pagination_with_model_filter():
|
||||
"""Test pagination combined with model filtering."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
await store.initialize()
|
||||
|
||||
# Create test data with different models
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("xyz-task", base_time + 1, "model-a"),
|
||||
("def-work", base_time + 2, "model-b"),
|
||||
("pqr-job", base_time + 3, "model-a"),
|
||||
("abc-run", base_time + 4, "model-b"),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp, model in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp, model)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Test pagination with model filter
|
||||
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == "pqr-job" # Most recent model-a
|
||||
assert result.data[0].model == "model-a"
|
||||
assert result.has_more is True
|
||||
|
||||
# Second page with model filter
|
||||
result2 = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
|
||||
assert len(result2.data) == 1
|
||||
assert result2.data[0].id == "xyz-task"
|
||||
assert result2.data[0].model == "model-a"
|
||||
assert result2.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inference_store_pagination_invalid_after():
|
||||
"""Test error handling for invalid 'after' parameter."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
await store.initialize()
|
||||
|
||||
# Try to paginate with non-existent ID
|
||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||
await store.list_chat_completions(after="non-existent", limit=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inference_store_pagination_no_limit():
|
||||
"""Test pagination behavior when no limit is specified."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path))
|
||||
await store.initialize()
|
||||
|
||||
# Create test data
|
||||
base_time = int(time.time())
|
||||
test_data = [
|
||||
("omega-first", base_time + 1),
|
||||
("beta-second", base_time + 2),
|
||||
]
|
||||
|
||||
# Store test chat completions
|
||||
for completion_id, timestamp in test_data:
|
||||
completion = create_test_chat_completion(completion_id, timestamp)
|
||||
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
|
||||
await store.store_chat_completion(completion, input_messages)
|
||||
|
||||
# Test without limit
|
||||
result = await store.list_chat_completions(order=Order.desc)
|
||||
assert len(result.data) == 2
|
||||
assert result.data[0].id == "beta-second" # Most recent first
|
||||
assert result.data[1].id == "omega-first"
|
||||
assert result.has_more is False
|
366
tests/unit/utils/responses/test_responses_store.py
Normal file
366
tests/unit/utils/responses/test_responses_store.py
Normal file
|
@ -0,0 +1,366 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseObject,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
def create_test_response_object(
    response_id: str, created_timestamp: int, model: str = "test-model"
) -> OpenAIResponseObject:
    """Build a minimal, completed OpenAIResponseObject for use in tests."""
    return OpenAIResponseObject(
        object="response",
        id=response_id,
        model=model,
        created_at=created_timestamp,
        status="completed",  # required by the schema
        output=[],  # required by the schema
    )
|
||||
|
||||
|
||||
def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInput:
    """Build a user-role OpenAIResponseMessage usable as a response input."""
    from llama_stack.apis.agents.openai_responses import OpenAIResponseMessage

    return OpenAIResponseMessage(
        type="message",
        role="user",
        id=input_id,
        content=content,
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_pagination_basic():
    """Walk three descending pages of responses using the 'after' cursor."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # Five responses one second apart; ids deliberately not alphabetical
        # so ordering must come from created_at, not the id.
        now = int(time.time())
        seed = [
            ("zebra-resp", now + 1),
            ("apple-resp", now + 2),
            ("moon-resp", now + 3),
            ("banana-resp", now + 4),
            ("car-resp", now + 5),
        ]
        for resp_id, created in seed:
            obj = create_test_response_object(resp_id, created)
            inputs = [create_test_response_input(f"Input for {resp_id}", f"input-{resp_id}")]
            await responses.store_response_object(obj, inputs)

        # Page 1: the two newest responses.
        page1 = await responses.list_responses(limit=2, order=Order.desc)
        assert len(page1.data) == 2
        assert page1.data[0].id == "car-resp"
        assert page1.data[1].id == "banana-resp"
        assert page1.has_more is True
        assert page1.last_id == "banana-resp"

        # Page 2: resume after page 1's last id.
        page2 = await responses.list_responses(after="banana-resp", limit=2, order=Order.desc)
        assert len(page2.data) == 2
        assert page2.data[0].id == "moon-resp"
        assert page2.data[1].id == "apple-resp"
        assert page2.has_more is True

        # Page 3: the single remaining row.
        page3 = await responses.list_responses(after="apple-resp", limit=2, order=Order.desc)
        assert len(page3.data) == 1
        assert page3.data[0].id == "zebra-resp"
        assert page3.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_pagination_ascending():
    """Cursor pagination should also work oldest-first (ascending)."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # Three responses, oldest id is "delta-resp".
        now = int(time.time())
        for resp_id, created in [
            ("delta-resp", now + 1),
            ("charlie-resp", now + 2),
            ("alpha-resp", now + 3),
        ]:
            obj = create_test_response_object(resp_id, created)
            inputs = [create_test_response_input(f"Input for {resp_id}", f"input-{resp_id}")]
            await responses.store_response_object(obj, inputs)

        # First ascending page returns the oldest response.
        page1 = await responses.list_responses(limit=1, order=Order.asc)
        assert len(page1.data) == 1
        assert page1.data[0].id == "delta-resp"
        assert page1.has_more is True

        # Continuing after it yields the next-oldest.
        page2 = await responses.list_responses(after="delta-resp", limit=1, order=Order.asc)
        assert len(page2.data) == 1
        assert page2.data[0].id == "charlie-resp"
        assert page2.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_pagination_with_model_filter():
    """Cursor pagination must respect a concurrent model filter."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # Interleave two models so filtering has to skip rows.
        now = int(time.time())
        seed = [
            ("xyz-resp", now + 1, "model-a"),
            ("def-resp", now + 2, "model-b"),
            ("pqr-resp", now + 3, "model-a"),
            ("abc-resp", now + 4, "model-b"),
        ]
        for resp_id, created, model_name in seed:
            obj = create_test_response_object(resp_id, created, model_name)
            inputs = [create_test_response_input(f"Input for {resp_id}", f"input-{resp_id}")]
            await responses.store_response_object(obj, inputs)

        # Newest model-a row first.
        page1 = await responses.list_responses(limit=1, model="model-a", order=Order.desc)
        assert len(page1.data) == 1
        assert page1.data[0].id == "pqr-resp"
        assert page1.data[0].model == "model-a"
        assert page1.has_more is True

        # Continue the filtered listing; only one model-a row remains.
        page2 = await responses.list_responses(after="pqr-resp", limit=1, model="model-a", order=Order.desc)
        assert len(page2.data) == 1
        assert page2.data[0].id == "xyz-resp"
        assert page2.data[0].model == "model-a"
        assert page2.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_pagination_invalid_after():
    """An unknown 'after' cursor must raise a descriptive ValueError."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # Paginating from a response id that was never stored should fail loudly.
        with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"):
            await responses.list_responses(after="non-existent", limit=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_pagination_no_limit():
    """Omitting 'limit' should fall back to the default and return everything."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # Seed two responses created one second apart.
        now = int(time.time())
        for resp_id, created in [("omega-resp", now + 1), ("beta-resp", now + 2)]:
            obj = create_test_response_object(resp_id, created)
            inputs = [create_test_response_input(f"Input for {resp_id}", f"input-{resp_id}")]
            await responses.store_response_object(obj, inputs)

        # No explicit limit (default is 50): both rows return, newest first.
        page = await responses.list_responses(order=Order.desc)
        assert len(page.data) == 2
        assert page.data[0].id == "beta-resp"
        assert page.data[1].id == "omega-resp"
        assert page.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_get_response_object():
    """Round-trip a single response and verify the missing-id error path."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # Persist one response with a single input message.
        obj = create_test_response_object("test-resp", int(time.time()))
        inputs = [create_test_response_input("Test input content", "input-test-resp")]
        await responses.store_response_object(obj, inputs)

        # Fetch it back and check the stored fields survived.
        fetched = await responses.get_response_object("test-resp")
        assert fetched.id == "test-resp"
        assert fetched.model == "test-model"
        assert len(fetched.input) == 1
        assert fetched.input[0].content == "Test input content"

        # Looking up an id that was never stored raises.
        with pytest.raises(ValueError, match="Response with id non-existent not found"):
            await responses.get_response_object("non-existent")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_input_items_pagination():
    """Exercise 'after' pagination and the error paths for input items."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # One response carrying five inputs with explicit, predictable ids.
        obj = create_test_response_object("test-resp", int(time.time()))
        inputs = [
            create_test_response_input("First input", "input-1"),
            create_test_response_input("Second input", "input-2"),
            create_test_response_input("Third input", "input-3"),
            create_test_response_input("Fourth input", "input-4"),
            create_test_response_input("Fifth input", "input-5"),
        ]
        await responses.store_response_object(obj, inputs)

        # Full descending listing: all five, newest first.
        everything = await responses.list_response_input_items("test-resp", order=Order.desc)
        assert len(everything.data) == 5
        expected_desc = [
            ("Fifth input", "input-5"),
            ("Fourth input", "input-4"),
            ("Third input", "input-3"),
            ("Second input", "input-2"),
            ("First input", "input-1"),
        ]
        for item, (expected_content, expected_id) in zip(everything.data, expected_desc):
            assert item.content == expected_content
            assert item.id == expected_id

        # First descending page of two.
        page = await responses.list_response_input_items("test-resp", limit=2, order=Order.desc)
        assert len(page.data) == 2
        assert page.data[0].content == "Fifth input"
        assert page.data[1].content == "Fourth input"

        # Resume after the newest item.
        page2 = await responses.list_response_input_items("test-resp", after="input-5", limit=2, order=Order.desc)
        assert len(page2.data) == 2
        assert page2.data[0].content == "Fourth input"
        assert page2.data[1].content == "Third input"

        # Final descending page.
        page3 = await responses.list_response_input_items("test-resp", after="input-3", limit=2, order=Order.desc)
        assert len(page3.data) == 2
        assert page3.data[0].content == "Second input"
        assert page3.data[1].content == "First input"

        # Ascending order: oldest first.
        asc_page = await responses.list_response_input_items("test-resp", limit=2, order=Order.asc)
        assert len(asc_page.data) == 2
        assert asc_page.data[0].content == "First input"
        assert asc_page.data[1].content == "Second input"

        # Ascending continuation after the oldest item.
        asc_page2 = await responses.list_response_input_items("test-resp", after="input-1", limit=2, order=Order.asc)
        assert len(asc_page2.data) == 2
        assert asc_page2.data[0].content == "Second input"
        assert asc_page2.data[1].content == "Third input"

        # Unknown 'after' id is rejected.
        with pytest.raises(ValueError, match="Input item with id 'non-existent' not found for response 'test-resp'"):
            await responses.list_response_input_items("test-resp", after="non-existent")

        # 'include' is not implemented yet.
        with pytest.raises(NotImplementedError, match="Include is not supported yet"):
            await responses.list_response_input_items("test-resp", include=["some-field"])

        # 'before' and 'after' are mutually exclusive.
        with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"):
            await responses.list_response_input_items("test-resp", before="some-id", after="other-id")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_responses_store_input_items_before_pagination():
    """Exercise 'before' pagination for input items in both orders."""
    with TemporaryDirectory() as workdir:
        responses = ResponsesStore(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))
        await responses.initialize()

        # One response carrying five inputs with explicit, predictable ids.
        obj = create_test_response_object("test-resp-before", int(time.time()))
        inputs = [
            create_test_response_input("First input", "before-1"),
            create_test_response_input("Second input", "before-2"),
            create_test_response_input("Third input", "before-3"),
            create_test_response_input("Fourth input", "before-4"),
            create_test_response_input("Fifth input", "before-5"),
        ]
        await responses.store_response_object(obj, inputs)

        # Descending order is [Fifth, Fourth, Third, Second, First];
        # everything strictly before "before-3" is [Fifth, Fourth].
        page = await responses.list_response_input_items("test-resp-before", before="before-3", order=Order.desc)
        assert len(page.data) == 2
        assert page.data[0].content == "Fifth input"
        assert page.data[1].content == "Fourth input"

        # 'before' combined with a limit, still descending.
        page2 = await responses.list_response_input_items(
            "test-resp-before", before="before-2", limit=3, order=Order.desc
        )
        assert len(page2.data) == 3
        assert page2.data[0].content == "Fifth input"
        assert page2.data[1].content == "Fourth input"
        assert page2.data[2].content == "Third input"

        # Ascending order is [First, Second, Third, Fourth, Fifth];
        # everything before "before-4" is [First, Second, Third].
        page3 = await responses.list_response_input_items("test-resp-before", before="before-4", order=Order.asc)
        assert len(page3.data) == 3
        assert page3.data[0].content == "First input"
        assert page3.data[1].content == "Second input"
        assert page3.data[2].content == "Third input"

        # 'before' with a limit in ascending order.
        page4 = await responses.list_response_input_items("test-resp-before", before="before-5", limit=2, order=Order.asc)
        assert len(page4.data) == 2
        assert page4.data[0].content == "First input"
        assert page4.data[1].content == "Second input"

        # Unknown 'before' id is rejected.
        with pytest.raises(
            ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'"
        ):
            await responses.list_response_input_items("test-resp-before", before="non-existent")
|
390
tests/unit/utils/sqlstore/test_sqlstore.py
Normal file
390
tests/unit/utils/sqlstore/test_sqlstore.py
Normal file
|
@ -0,0 +1,390 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlite_sqlstore():
    """CRUD smoke test for the SQLAlchemy-backed SQLite store."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(
            SqliteSqlStoreConfig(
                db_path=f"{workdir}/test.db",
            )
        )
        await backend.create_table(
            table="test",
            schema={
                "id": ColumnType.INTEGER,
                "name": ColumnType.STRING,
            },
        )
        await backend.insert("test", {"id": 1, "name": "test"})
        await backend.insert("test", {"id": 12, "name": "test12"})

        # fetch_all returns a paginated result object, not a bare list.
        listing = await backend.fetch_all("test")
        assert listing.data == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}]
        assert listing.has_more is False

        # Single-row lookups by different columns.
        assert await backend.fetch_one("test", {"id": 1}) == {"id": 1, "name": "test"}
        assert await backend.fetch_one("test", {"name": "test12"}) == {"id": 12, "name": "test12"}

        # Explicit ordering, both directions.
        ascending = await backend.fetch_all("test", order_by=[("id", "asc")])
        assert ascending.data == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}]
        descending = await backend.fetch_all("test", order_by=[("id", "desc")])
        assert descending.data == [{"id": 12, "name": "test12"}, {"id": 1, "name": "test"}]

        # A limit smaller than the row count flags that more rows exist.
        limited = await backend.fetch_all("test", limit=1)
        assert limited.data == [{"id": 1, "name": "test"}]
        assert limited.has_more is True

        # Update in place and re-read.
        await backend.update("test", {"name": "test123"}, {"id": 1})
        assert await backend.fetch_one("test", {"id": 1}) == {"id": 1, "name": "test123"}

        # Delete and confirm only the other row remains.
        await backend.delete("test", {"id": 1})
        remaining = await backend.fetch_all("test")
        assert remaining.data == [{"id": 12, "name": "test12"}]
        assert remaining.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_basic():
    """Walk three cursor pages directly at the SQL-store layer."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        await backend.create_table(
            "test_records",
            {
                "id": ColumnType.STRING,
                "created_at": ColumnType.INTEGER,
                "name": ColumnType.STRING,
            },
        )

        # Five rows one second apart; ids deliberately unsorted so ordering
        # must come from created_at.
        now = int(time.time())
        seed = [
            {"id": "zebra", "created_at": now + 1, "name": "First"},
            {"id": "apple", "created_at": now + 2, "name": "Second"},
            {"id": "moon", "created_at": now + 3, "name": "Third"},
            {"id": "banana", "created_at": now + 4, "name": "Fourth"},
            {"id": "car", "created_at": now + 5, "name": "Fifth"},
        ]
        for row in seed:
            await backend.insert("test_records", row)

        # Page 1: no cursor, newest two rows.
        page1 = await backend.fetch_all(
            table="test_records",
            order_by=[("created_at", "desc")],
            limit=2,
        )
        assert len(page1.data) == 2
        assert page1.data[0]["id"] == "car"
        assert page1.data[1]["id"] == "banana"
        assert page1.has_more is True

        # Page 2: resume from the last row of page 1.
        page2 = await backend.fetch_all(
            table="test_records",
            order_by=[("created_at", "desc")],
            cursor=("id", "banana"),
            limit=2,
        )
        assert len(page2.data) == 2
        assert page2.data[0]["id"] == "moon"
        assert page2.data[1]["id"] == "apple"
        assert page2.has_more is True

        # Page 3: single remaining row.
        page3 = await backend.fetch_all(
            table="test_records",
            order_by=[("created_at", "desc")],
            cursor=("id", "apple"),
            limit=2,
        )
        assert len(page3.data) == 1
        assert page3.data[0]["id"] == "zebra"
        assert page3.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_with_filter():
    """Cursor pagination must compose with a WHERE filter."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        await backend.create_table(
            "test_records",
            {
                "id": ColumnType.STRING,
                "created_at": ColumnType.INTEGER,
                "category": ColumnType.STRING,
            },
        )

        # Interleave categories A and B so the filter has to skip rows.
        now = int(time.time())
        seed = [
            {"id": "xyz", "created_at": now + 1, "category": "A"},
            {"id": "def", "created_at": now + 2, "category": "B"},
            {"id": "pqr", "created_at": now + 3, "category": "A"},
            {"id": "abc", "created_at": now + 4, "category": "B"},
        ]
        for row in seed:
            await backend.insert("test_records", row)

        # Newest category-A row first.
        page1 = await backend.fetch_all(
            table="test_records",
            where={"category": "A"},
            order_by=[("created_at", "desc")],
            limit=1,
        )
        assert len(page1.data) == 1
        assert page1.data[0]["id"] == "pqr"
        assert page1.has_more is True

        # Continue the filtered listing; only one category-A row remains.
        page2 = await backend.fetch_all(
            table="test_records",
            where={"category": "A"},
            order_by=[("created_at", "desc")],
            cursor=("id", "pqr"),
            limit=1,
        )
        assert len(page2.data) == 1
        assert page2.data[0]["id"] == "xyz"
        assert page2.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_ascending_order():
    """Cursor pagination should also work oldest-first (ascending)."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        await backend.create_table(
            "test_records",
            {
                "id": ColumnType.STRING,
                "created_at": ColumnType.INTEGER,
            },
        )

        # Three rows; "gamma" is the oldest despite its id.
        now = int(time.time())
        for row in [
            {"id": "gamma", "created_at": now + 1},
            {"id": "alpha", "created_at": now + 2},
            {"id": "beta", "created_at": now + 3},
        ]:
            await backend.insert("test_records", row)

        # First ascending page returns the oldest row.
        page1 = await backend.fetch_all(
            table="test_records",
            order_by=[("created_at", "asc")],
            limit=1,
        )
        assert len(page1.data) == 1
        assert page1.data[0]["id"] == "gamma"
        assert page1.has_more is True

        # Continuing after it yields the next-oldest.
        page2 = await backend.fetch_all(
            table="test_records",
            order_by=[("created_at", "asc")],
            cursor=("id", "gamma"),
            limit=1,
        )
        assert len(page2.data) == 1
        assert page2.data[0]["id"] == "alpha"
        assert page2.has_more is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_multi_column_ordering_error():
    """Cursor pagination rejects multi-column ordering; plain ordering allows it."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        await backend.create_table(
            "test_records",
            {
                "id": ColumnType.STRING,
                "priority": ColumnType.INTEGER,
                "created_at": ColumnType.INTEGER,
            },
        )
        await backend.insert("test_records", {"id": "task1", "priority": 1, "created_at": 12345})

        # Cursor + two order-by columns is ambiguous and must be refused.
        with pytest.raises(ValueError, match="Cursor pagination only supports single-column ordering, got 2 columns"):
            await backend.fetch_all(
                table="test_records",
                order_by=[("priority", "asc"), ("created_at", "desc")],
                cursor=("id", "task1"),
                limit=2,
            )

        # The same ordering without a cursor is perfectly legal.
        listing = await backend.fetch_all(
            table="test_records",
            order_by=[("priority", "asc"), ("created_at", "desc")],
            limit=2,
        )
        assert len(listing.data) == 1
        assert listing.data[0]["id"] == "task1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_cursor_requires_order_by():
    """A cursor without an order_by has no defined position and must be rejected."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        await backend.create_table("test_records", {"id": ColumnType.STRING})
        await backend.insert("test_records", {"id": "task1"})

        with pytest.raises(ValueError, match="order_by is required when using cursor pagination"):
            await backend.fetch_all(
                table="test_records",
                cursor=("id", "task1"),
            )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_error_handling():
    """Malformed cursors, bad columns, and missing ids all raise ValueError."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        await backend.create_table(
            "test_records",
            {
                "id": ColumnType.STRING,
                "name": ColumnType.STRING,
            },
        )
        await backend.insert("test_records", {"id": "test1", "name": "Test"})

        # Cursor must be a (key_column, id) tuple, not a bare string.
        with pytest.raises(ValueError, match="Cursor must be a tuple of"):
            await backend.fetch_all(
                table="test_records",
                order_by=[("name", "asc")],
                cursor="invalid",
            )

        # Cursor key column must exist in the table schema.
        with pytest.raises(ValueError, match="Cursor key column 'nonexistent' not found in table"):
            await backend.fetch_all(
                table="test_records",
                order_by=[("name", "asc")],
                cursor=("nonexistent", "test1"),
            )

        # order_by columns are validated too.
        with pytest.raises(ValueError, match="Column 'invalid_col' not found in table"):
            await backend.fetch_all(
                table="test_records",
                order_by=[("invalid_col", "asc")],
            )

        # A syntactically valid cursor pointing at a missing row fails.
        with pytest.raises(ValueError, match="Record with id='nonexistent' not found in table"):
            await backend.fetch_all(
                table="test_records",
                order_by=[("name", "asc")],
                cursor=("id", "nonexistent"),
            )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlstore_pagination_custom_key_column():
    """The cursor key column is configurable — it need not be named 'id'."""
    with TemporaryDirectory() as workdir:
        backend = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=f"{workdir}/test.db"))

        # Table keyed by 'uuid' rather than 'id'.
        await backend.create_table(
            "custom_table",
            {
                "uuid": ColumnType.STRING,
                "timestamp": ColumnType.INTEGER,
                "data": ColumnType.STRING,
            },
        )

        now = int(time.time())
        for row in [
            {"uuid": "uuid-alpha", "timestamp": now + 1, "data": "First"},
            {"uuid": "uuid-beta", "timestamp": now + 2, "data": "Second"},
            {"uuid": "uuid-gamma", "timestamp": now + 3, "data": "Third"},
        ]:
            await backend.insert("custom_table", row)

        # First page: two newest rows.
        page1 = await backend.fetch_all(
            table="custom_table",
            order_by=[("timestamp", "desc")],
            limit=2,
        )
        assert len(page1.data) == 2
        assert page1.data[0]["uuid"] == "uuid-gamma"
        assert page1.data[1]["uuid"] == "uuid-beta"
        assert page1.has_more is True

        # Second page, cursoring on the custom 'uuid' key column.
        page2 = await backend.fetch_all(
            table="custom_table",
            order_by=[("timestamp", "desc")],
            cursor=("uuid", "uuid-beta"),
            limit=2,
        )
        assert len(page2.data) == 1
        assert page2.data[0]["uuid"] == "uuid-alpha"
        assert page2.has_more is False
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_sqlite_sqlstore():
    """CRUD smoke test for the SQLAlchemy-backed SQLite store.

    NOTE(review): this is the pre-pagination version of the test (the removed
    side of the diff) — here `fetch_all` still returns a plain list of row
    dicts rather than a paginated result object with `.data` / `.has_more`.
    """
    with TemporaryDirectory() as tmp_dir:
        db_name = "test.db"
        sqlstore = SqlAlchemySqlStoreImpl(
            SqliteSqlStoreConfig(
                db_path=tmp_dir + "/" + db_name,
            )
        )
        # Two-column table: integer id plus a name string.
        await sqlstore.create_table(
            table="test",
            schema={
                "id": ColumnType.INTEGER,
                "name": ColumnType.STRING,
            },
        )
        await sqlstore.insert("test", {"id": 1, "name": "test"})
        await sqlstore.insert("test", {"id": 12, "name": "test12"})
        # fetch_all with no arguments returns every row as a list of dicts.
        rows = await sqlstore.fetch_all("test")
        assert rows == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}]

        # Single-row lookup by primary key.
        row = await sqlstore.fetch_one("test", {"id": 1})
        assert row == {"id": 1, "name": "test"}

        # Single-row lookup by a non-key column.
        row = await sqlstore.fetch_one("test", {"name": "test12"})
        assert row == {"id": 12, "name": "test12"}

        # order by
        rows = await sqlstore.fetch_all("test", order_by=[("id", "asc")])
        assert rows == [{"id": 1, "name": "test"}, {"id": 12, "name": "test12"}]

        rows = await sqlstore.fetch_all("test", order_by=[("id", "desc")])
        assert rows == [{"id": 12, "name": "test12"}, {"id": 1, "name": "test"}]

        # limit
        rows = await sqlstore.fetch_all("test", limit=1)
        assert rows == [{"id": 1, "name": "test"}]

        # update
        await sqlstore.update("test", {"name": "test123"}, {"id": 1})
        row = await sqlstore.fetch_one("test", {"id": 1})
        assert row == {"id": 1, "name": "test123"}

        # delete
        await sqlstore.delete("test", {"id": 1})
        rows = await sqlstore.fetch_all("test")
        assert rows == [{"id": 12, "name": "test12"}]
|
Loading…
Add table
Add a link
Reference in a new issue