adding batch inserts to sqlstore and authorized sqlstorealong with type adjustments and conversations

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-03 10:27:52 -04:00
parent d2ce672d4b
commit fe695ca475
5 changed files with 54 additions and 17 deletions

View file

@ -112,6 +112,7 @@ class ConversationServiceImpl(Conversations):
) )
if items: if items:
item_records = []
for item in items: for item in items:
item_dict = item.model_dump() item_dict = item.model_dump()
item_id = self._get_or_generate_item_id(item, item_dict) item_id = self._get_or_generate_item_id(item, item_dict)
@ -123,7 +124,9 @@ class ConversationServiceImpl(Conversations):
"item_data": item_dict, "item_data": item_dict,
} }
await self.sql_store.insert(table="conversation_items", data=item_record) item_records.append(item_record)
await self.sql_store.insert(table="conversation_items", data=item_records)
conversation = Conversation( conversation = Conversation(
id=conversation_id, id=conversation_id,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Literal, Protocol from typing import Any, Literal, Protocol
@ -41,9 +41,9 @@ class SqlStore(Protocol):
""" """
pass pass
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
""" """
Insert a row into a table. Insert a row or batch of rows into a table.
""" """
pass pass

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
] ]
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
"""Add access control attributes to a data item."""
enhanced = dict(item)
if current_user:
enhanced["owner_principal"] = current_user.principal
enhanced["access_attributes"] = current_user.attributes
else:
enhanced["owner_principal"] = None
enhanced["access_attributes"] = None
return enhanced
class SqlRecord(ProtectedResource): class SqlRecord(ProtectedResource):
def __init__(self, record_id: str, table_name: str, owner: User): def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}" self.type = f"sql_record::{table_name}"
@ -102,18 +114,14 @@ class AuthorizedSqlStore:
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING) await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""Insert a row with automatic access control attribute capture.""" """Insert a row or batch of rows with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user() current_user = get_authenticated_user()
if current_user: enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
enhanced_data["owner_principal"] = current_user.principal if isinstance(data, Mapping):
enhanced_data["access_attributes"] = current_user.attributes enhanced_data = _enhance_item_with_access_control(data, current_user)
else: else:
enhanced_data["owner_principal"] = None enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data) await self.sql_store.insert(table, enhanced_data)
async def fetch_all( async def fetch_all(

View file

@ -3,7 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal
from sqlalchemy import ( from sqlalchemy import (
@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
async with self.async_session() as session: async with self.async_session() as session:
await session.execute(self.metadata.tables[table].insert(), data) await session.execute(self.metadata.tables[table].insert(), data)
await session.commit() await session.commit()

View file

@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
assert {r["id"] for r in rows_after} == {1, 3} assert {r["id"] for r in rows_after} == {1, 3}
async def test_batch_insert():
with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db"
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
await store.create_table(
"batch_test",
{
"id": ColumnType.INTEGER,
"name": ColumnType.STRING,
"value": ColumnType.INTEGER,
},
)
batch_data = [
{"id": 1, "name": "first", "value": 10},
{"id": 2, "name": "second", "value": 20},
{"id": 3, "name": "third", "value": 30},
]
await store.insert("batch_test", batch_data)
result = await store.fetch_all("batch_test", order_by=[("id", "asc")])
assert result.data == batch_data
async def test_where_operator_edge_cases(): async def test_where_operator_edge_cases():
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"