adding batch inserts to sqlstore and authorized sqlstorealong with type adjustments and conversations

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-03 10:27:52 -04:00
parent d2ce672d4b
commit fe695ca475
5 changed files with 54 additions and 17 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Literal, Protocol
@ -41,9 +41,9 @@ class SqlStore(Protocol):
"""
pass
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""
Insert a row into a table.
Insert a row or batch of rows into a table.
"""
pass

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
]
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
"""Add access control attributes to a data item."""
enhanced = dict(item)
if current_user:
enhanced["owner_principal"] = current_user.principal
enhanced["access_attributes"] = current_user.attributes
else:
enhanced["owner_principal"] = None
enhanced["access_attributes"] = None
return enhanced
class SqlRecord(ProtectedResource):
def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}"
@ -102,18 +114,14 @@ class AuthorizedSqlStore:
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data)
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""Insert a row or batch of rows with automatic access control attribute capture."""
current_user = get_authenticated_user()
if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
if isinstance(data, Mapping):
enhanced_data = _enhance_item_with_access_control(data, current_user)
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
await self.sql_store.insert(table, enhanced_data)
async def fetch_all(

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from sqlalchemy import (
@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
async with self.async_session() as session:
await session.execute(self.metadata.tables[table].insert(), data)
await session.commit()