mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore(dev): add inequality support to sqlstore where clause (#3272)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Vector IO Integration Tests / test-matrix (push) Failing after 1s
Pre-commit / pre-commit (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 0s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test External API and Providers / test-external (venv) (push) Failing after 1s
UI Tests / ui-tests (22) (push) Failing after 0s
Unit Tests / unit-tests (3.12) (push) Failing after 1s
Unit Tests / unit-tests (3.13) (push) Failing after 1s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 2s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Vector IO Integration Tests / test-matrix (push) Failing after 1s
Pre-commit / pre-commit (push) Failing after 1s
Python Package Build Test / build (3.12) (push) Failing after 0s
Python Package Build Test / build (3.13) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test External API and Providers / test-external (venv) (push) Failing after 1s
UI Tests / ui-tests (22) (push) Failing after 0s
Unit Tests / unit-tests (3.12) (push) Failing after 1s
Unit Tests / unit-tests (3.13) (push) Failing after 1s
# What does this PR do? add the ability to use inequalities in the where clause of the sqlstore. this is infrastructure for files expiration. ## Test Plan unit tests
This commit is contained in:
parent
30117dea22
commit
ed418653ec
2 changed files with 85 additions and 3 deletions
|
@ -23,6 +23,7 @@ from sqlalchemy import (
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||||
|
from sqlalchemy.sql.elements import ColumnElement
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -43,6 +44,30 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
||||||
|
"""Return a SQLAlchemy expression for a where condition.
|
||||||
|
|
||||||
|
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
|
||||||
|
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
|
||||||
|
"""
|
||||||
|
if isinstance(value, Mapping):
|
||||||
|
if len(value) != 1:
|
||||||
|
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
|
||||||
|
op, operand = next(iter(value.items()))
|
||||||
|
if op == "==" or op == "=":
|
||||||
|
return column == operand
|
||||||
|
if op == ">":
|
||||||
|
return column > operand
|
||||||
|
if op == "<":
|
||||||
|
return column < operand
|
||||||
|
if op == ">=":
|
||||||
|
return column >= operand
|
||||||
|
if op == "<=":
|
||||||
|
return column <= operand
|
||||||
|
raise ValueError(f"Unsupported operator '{op}' in where mapping")
|
||||||
|
return column == value
|
||||||
|
|
||||||
|
|
||||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -111,7 +136,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
|
|
||||||
if where:
|
if where:
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
query = query.where(table_obj.c[key] == value)
|
query = query.where(_build_where_expr(table_obj.c[key], value))
|
||||||
|
|
||||||
if where_sql:
|
if where_sql:
|
||||||
query = query.where(text(where_sql))
|
query = query.where(text(where_sql))
|
||||||
|
@ -222,7 +247,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
async with self.async_session() as session:
|
async with self.async_session() as session:
|
||||||
stmt = self.metadata.tables[table].update()
|
stmt = self.metadata.tables[table].update()
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||||
await session.execute(stmt, data)
|
await session.execute(stmt, data)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
@ -233,7 +258,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
async with self.async_session() as session:
|
async with self.async_session() as session:
|
||||||
stmt = self.metadata.tables[table].delete()
|
stmt = self.metadata.tables[table].delete()
|
||||||
for key, value in where.items():
|
for key, value in where.items():
|
||||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||||
await session.execute(stmt)
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
|
@ -332,6 +332,63 @@ async def test_sqlstore_pagination_error_handling():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_where_operator_gt_and_update_delete():
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_path = tmp_dir + "/test.db"
|
||||||
|
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
|
||||||
|
|
||||||
|
await store.create_table(
|
||||||
|
"items",
|
||||||
|
{
|
||||||
|
"id": ColumnType.INTEGER,
|
||||||
|
"value": ColumnType.INTEGER,
|
||||||
|
"name": ColumnType.STRING,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await store.insert("items", {"id": 1, "value": 10, "name": "one"})
|
||||||
|
await store.insert("items", {"id": 2, "value": 20, "name": "two"})
|
||||||
|
await store.insert("items", {"id": 3, "value": 30, "name": "three"})
|
||||||
|
|
||||||
|
result = await store.fetch_all("items", where={"value": {">": 15}})
|
||||||
|
assert {r["id"] for r in result.data} == {2, 3}
|
||||||
|
|
||||||
|
row = await store.fetch_one("items", where={"value": {">=": 30}})
|
||||||
|
assert row["id"] == 3
|
||||||
|
|
||||||
|
await store.update("items", {"name": "small"}, {"value": {"<": 25}})
|
||||||
|
rows = (await store.fetch_all("items")).data
|
||||||
|
names = {r["id"]: r["name"] for r in rows}
|
||||||
|
assert names[1] == "small"
|
||||||
|
assert names[2] == "small"
|
||||||
|
assert names[3] == "three"
|
||||||
|
|
||||||
|
await store.delete("items", {"id": {"==": 2}})
|
||||||
|
rows_after = (await store.fetch_all("items")).data
|
||||||
|
assert {r["id"] for r in rows_after} == {1, 3}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_where_operator_edge_cases():
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_path = tmp_dir + "/test.db"
|
||||||
|
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
|
||||||
|
|
||||||
|
await store.create_table(
|
||||||
|
"events",
|
||||||
|
{"id": ColumnType.STRING, "ts": ColumnType.INTEGER},
|
||||||
|
)
|
||||||
|
|
||||||
|
base = 1024
|
||||||
|
await store.insert("events", {"id": "a", "ts": base - 10})
|
||||||
|
await store.insert("events", {"id": "b", "ts": base + 10})
|
||||||
|
|
||||||
|
row = await store.fetch_one("events", where={"id": "a"})
|
||||||
|
assert row["id"] == "a"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unsupported operator"):
|
||||||
|
await store.fetch_all("events", where={"ts": {"!=": base}})
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_custom_key_column():
|
async def test_sqlstore_pagination_custom_key_column():
|
||||||
"""Test pagination with custom primary key column (not 'id')."""
|
"""Test pagination with custom primary key column (not 'id')."""
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue