mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
chore(dev): add inequality support to sqlstore where clause (#3272)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Vector IO Integration Tests / test-matrix (push) Failing after 1s
Pre-commit / pre-commit (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 0s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test External API and Providers / test-external (venv) (push) Failing after 1s
UI Tests / ui-tests (22) (push) Failing after 0s
Unit Tests / unit-tests (3.12) (push) Failing after 1s
Unit Tests / unit-tests (3.13) (push) Failing after 1s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Vector IO Integration Tests / test-matrix (push) Failing after 1s
Pre-commit / pre-commit (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 0s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test External API and Providers / test-external (venv) (push) Failing after 1s
UI Tests / ui-tests (22) (push) Failing after 0s
Unit Tests / unit-tests (3.12) (push) Failing after 1s
Unit Tests / unit-tests (3.13) (push) Failing after 1s
# What does this PR do? add the ability to use inequalities in the where clause of the sqlstore. this is infrastructure for files expiration. ## Test Plan unit tests
This commit is contained in:
parent
30117dea22
commit
ed418653ec
2 changed files with 85 additions and 3 deletions
|
@ -23,6 +23,7 @@ from sqlalchemy import (
|
|||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -43,6 +44,30 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
|||
}
|
||||
|
||||
|
||||
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
||||
"""Return a SQLAlchemy expression for a where condition.
|
||||
|
||||
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
|
||||
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
|
||||
"""
|
||||
if isinstance(value, Mapping):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
|
||||
op, operand = next(iter(value.items()))
|
||||
if op == "==" or op == "=":
|
||||
return column == operand
|
||||
if op == ">":
|
||||
return column > operand
|
||||
if op == "<":
|
||||
return column < operand
|
||||
if op == ">=":
|
||||
return column >= operand
|
||||
if op == "<=":
|
||||
return column <= operand
|
||||
raise ValueError(f"Unsupported operator '{op}' in where mapping")
|
||||
return column == value
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
|
@ -111,7 +136,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(table_obj.c[key] == value)
|
||||
query = query.where(_build_where_expr(table_obj.c[key], value))
|
||||
|
||||
if where_sql:
|
||||
query = query.where(text(where_sql))
|
||||
|
@ -222,7 +247,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt, data)
|
||||
await session.commit()
|
||||
|
||||
|
@ -233,7 +258,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
|
|
|
@ -332,6 +332,63 @@ async def test_sqlstore_pagination_error_handling():
|
|||
)
|
||||
|
||||
|
||||
async def test_where_operator_gt_and_update_delete():
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
|
||||
|
||||
await store.create_table(
|
||||
"items",
|
||||
{
|
||||
"id": ColumnType.INTEGER,
|
||||
"value": ColumnType.INTEGER,
|
||||
"name": ColumnType.STRING,
|
||||
},
|
||||
)
|
||||
|
||||
await store.insert("items", {"id": 1, "value": 10, "name": "one"})
|
||||
await store.insert("items", {"id": 2, "value": 20, "name": "two"})
|
||||
await store.insert("items", {"id": 3, "value": 30, "name": "three"})
|
||||
|
||||
result = await store.fetch_all("items", where={"value": {">": 15}})
|
||||
assert {r["id"] for r in result.data} == {2, 3}
|
||||
|
||||
row = await store.fetch_one("items", where={"value": {">=": 30}})
|
||||
assert row["id"] == 3
|
||||
|
||||
await store.update("items", {"name": "small"}, {"value": {"<": 25}})
|
||||
rows = (await store.fetch_all("items")).data
|
||||
names = {r["id"]: r["name"] for r in rows}
|
||||
assert names[1] == "small"
|
||||
assert names[2] == "small"
|
||||
assert names[3] == "three"
|
||||
|
||||
await store.delete("items", {"id": {"==": 2}})
|
||||
rows_after = (await store.fetch_all("items")).data
|
||||
assert {r["id"] for r in rows_after} == {1, 3}
|
||||
|
||||
|
||||
async def test_where_operator_edge_cases():
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
db_path = tmp_dir + "/test.db"
|
||||
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
|
||||
|
||||
await store.create_table(
|
||||
"events",
|
||||
{"id": ColumnType.STRING, "ts": ColumnType.INTEGER},
|
||||
)
|
||||
|
||||
base = 1024
|
||||
await store.insert("events", {"id": "a", "ts": base - 10})
|
||||
await store.insert("events", {"id": "b", "ts": base + 10})
|
||||
|
||||
row = await store.fetch_one("events", where={"id": "a"})
|
||||
assert row["id"] == "a"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported operator"):
|
||||
await store.fetch_all("events", where={"ts": {"!=": base}})
|
||||
|
||||
|
||||
async def test_sqlstore_pagination_custom_key_column():
|
||||
"""Test pagination with custom primary key column (not 'id')."""
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue