From fe695ca475537a5b722741ddeff5d51468775e12 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Fri, 3 Oct 2025 10:27:52 -0400 Subject: [PATCH] adding batch inserts to sqlstore and authorized sqlstorealong with type adjustments and conversations Signed-off-by: Francisco Javier Arceo --- .../core/conversations/conversations.py | 5 +++- llama_stack/providers/utils/sqlstore/api.py | 6 ++-- .../utils/sqlstore/authorized_sqlstore.py | 30 ++++++++++++------- .../utils/sqlstore/sqlalchemy_sqlstore.py | 4 +-- tests/unit/utils/sqlstore/test_sqlstore.py | 26 ++++++++++++++++ 5 files changed, 54 insertions(+), 17 deletions(-) diff --git a/llama_stack/core/conversations/conversations.py b/llama_stack/core/conversations/conversations.py index 40173a0be..bef138e69 100644 --- a/llama_stack/core/conversations/conversations.py +++ b/llama_stack/core/conversations/conversations.py @@ -112,6 +112,7 @@ class ConversationServiceImpl(Conversations): ) if items: + item_records = [] for item in items: item_dict = item.model_dump() item_id = self._get_or_generate_item_id(item, item_dict) @@ -123,7 +124,9 @@ class ConversationServiceImpl(Conversations): "item_data": item_dict, } - await self.sql_store.insert(table="conversation_items", data=item_record) + item_records.append(item_record) + + await self.sql_store.insert(table="conversation_items", data=item_records) conversation = Conversation( id=conversation_id, diff --git a/llama_stack/providers/utils/sqlstore/api.py b/llama_stack/providers/utils/sqlstore/api.py index 6bb85ea0c..a61fd1090 100644 --- a/llama_stack/providers/utils/sqlstore/api.py +++ b/llama_stack/providers/utils/sqlstore/api.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Literal, Protocol @@ -41,9 +41,9 @@ class SqlStore(Protocol): """ pass - async def insert(self, table: str, data: Mapping[str, Any]) -> None: + async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: """ - Insert a row into a table. + Insert a row or batch of rows into a table. """ pass diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ab67f7052..e1da4db6e 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, Literal from llama_stack.core.access_control.access_control import default_policy, is_action_allowed @@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [ ] +def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]: + """Add access control attributes to a data item.""" + enhanced = dict(item) + if current_user: + enhanced["owner_principal"] = current_user.principal + enhanced["access_attributes"] = current_user.attributes + else: + enhanced["owner_principal"] = None + enhanced["access_attributes"] = None + return enhanced + + class SqlRecord(ProtectedResource): def __init__(self, record_id: str, table_name: str, owner: User): self.type = f"sql_record::{table_name}" @@ -102,18 +114,14 @@ class AuthorizedSqlStore: await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING) - async def insert(self, table: str, data: Mapping[str, Any]) -> None: - """Insert a row with automatic access control attribute capture.""" - enhanced_data = dict(data) - + async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: + """Insert a row or batch of rows with automatic access control attribute capture.""" current_user = get_authenticated_user() - if current_user: - enhanced_data["owner_principal"] = current_user.principal - enhanced_data["access_attributes"] = current_user.attributes + enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]] + if isinstance(data, Mapping): + enhanced_data = _enhance_item_with_access_control(data, current_user) else: - enhanced_data["owner_principal"] = None - enhanced_data["access_attributes"] = None - + enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data] await self.sql_store.insert(table, enhanced_data) async def fetch_all( diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 46ed8c1d1..23cd6444e 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any, Literal from sqlalchemy import ( @@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) - async def insert(self, table: str, data: Mapping[str, Any]) -> None: + async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None: async with self.async_session() as session: await session.execute(self.metadata.tables[table].insert(), data) await session.commit() diff --git a/tests/unit/utils/sqlstore/test_sqlstore.py b/tests/unit/utils/sqlstore/test_sqlstore.py index ba59ec7ec..00669b698 100644 --- a/tests/unit/utils/sqlstore/test_sqlstore.py +++ b/tests/unit/utils/sqlstore/test_sqlstore.py @@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete(): assert {r["id"] for r in rows_after} == {1, 3} +async def test_batch_insert(): + with TemporaryDirectory() as tmp_dir: + db_path = tmp_dir + "/test.db" + store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path)) + + await store.create_table( + "batch_test", + { + "id": ColumnType.INTEGER, + "name": ColumnType.STRING, + "value": ColumnType.INTEGER, + }, + ) + + batch_data = [ + {"id": 1, "name": "first", "value": 10}, + {"id": 2, "name": "second", "value": 20}, + {"id": 3, "name": "third", "value": 30}, + ] + + await store.insert("batch_test", batch_data) + + result = await store.fetch_all("batch_test", order_by=[("id", "asc")]) + assert result.data == batch_data + + async def test_where_operator_edge_cases(): with TemporaryDirectory() as tmp_dir: db_path = tmp_dir + "/test.db"