diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index f75c35314..46ed8c1d1 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -23,6 +23,7 @@ from sqlalchemy import ( ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio.engine import AsyncEngine +from sqlalchemy.sql.elements import ColumnElement from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.log import get_logger @@ -43,6 +44,30 @@ TYPE_MAPPING: dict[ColumnType, Any] = { } +def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement: + """Return a SQLAlchemy expression for a where condition. + + `value` may be a simple scalar (equality) or a mapping like {">": 123}. + The returned expression is a SQLAlchemy ColumnElement usable in query.where(...). + """ + if isinstance(value, Mapping): + if len(value) != 1: + raise ValueError(f"Operator mapping must have a single operator, got: {value}") + op, operand = next(iter(value.items())) + if op == "==" or op == "=": + return column == operand + if op == ">": + return column > operand + if op == "<": + return column < operand + if op == ">=": + return column >= operand + if op == "<=": + return column <= operand + raise ValueError(f"Unsupported operator '{op}' in where mapping") + return column == value + + class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config @@ -111,7 +136,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): if where: for key, value in where.items(): - query = query.where(table_obj.c[key] == value) + query = query.where(_build_where_expr(table_obj.c[key], value)) if where_sql: query = query.where(text(where_sql)) @@ -222,7 +247,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): async with self.async_session() as session: stmt = self.metadata.tables[table].update() for key, value in where.items(): - stmt = stmt.where(self.metadata.tables[table].c[key] == value) + stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value)) await session.execute(stmt, data) await session.commit() @@ -233,7 +258,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): async with self.async_session() as session: stmt = self.metadata.tables[table].delete() for key, value in where.items(): - stmt = stmt.where(self.metadata.tables[table].c[key] == value) + stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value)) await session.execute(stmt) await session.commit() diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index 778f0b658..ba59ec7ec 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -332,6 +332,63 @@ async def test_sqlstore_pagination_error_handling(): ) +async def test_where_operator_gt_and_update_delete(): + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + await store.create_table( + "items", + { + "id": ColumnType.INTEGER, + "value": ColumnType.INTEGER, + "name": ColumnType.STRING, + }, + ) + + await store.insert("items", {"id": 1, "value": 10, "name": "one"}) + await store.insert("items", {"id": 2, "value": 20, "name": "two"}) + await store.insert("items", {"id": 3, "value": 30, "name": "three"}) + + result = await store.fetch_all("items", where={"value": {">": 15}}) + assert {r["id"] for r in result.data} == {2, 3} + + row = await store.fetch_one("items", where={"value": {">=": 30}}) + assert row["id"] == 3 + + await store.update("items", {"name": "small"}, {"value": {"<": 25}}) + rows = (await store.fetch_all("items")).data + names = {r["id"]: r["name"] for r in rows} + assert names[1] == "small" + assert names[2] == "small" + assert names[3] == "three" + + await store.delete("items", {"id": {"==": 2}}) + rows_after = (await store.fetch_all("items")).data + assert {r["id"] for r in rows_after} == {1, 3} + + +async def test_where_operator_edge_cases(): + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + await store.create_table( + "events", + {"id": ColumnType.STRING, "ts": ColumnType.INTEGER}, + ) + + base = 1024 + await store.insert("events", {"id": "a", "ts": base - 10}) + await store.insert("events", {"id": "b", "ts": base + 10}) + + row = await store.fetch_one("events", where={"id": "a"}) + assert row["id"] == "a" + + with pytest.raises(ValueError, match="Unsupported operator"): + await store.fetch_all("events", where={"ts": {"!=": base}}) + + async def test_sqlstore_pagination_custom_key_column(): """Test pagination with custom primary key column (not 'id').""" with TemporaryDirectory() as tmp_dir: