mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
adding batch inserts to sqlstore and authorized sqlstorealong with type adjustments and conversations
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
d2ce672d4b
commit
fe695ca475
5 changed files with 54 additions and 17 deletions
|
@ -112,6 +112,7 @@ class ConversationServiceImpl(Conversations):
|
||||||
)
|
)
|
||||||
|
|
||||||
if items:
|
if items:
|
||||||
|
item_records = []
|
||||||
for item in items:
|
for item in items:
|
||||||
item_dict = item.model_dump()
|
item_dict = item.model_dump()
|
||||||
item_id = self._get_or_generate_item_id(item, item_dict)
|
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||||
|
@ -123,7 +124,9 @@ class ConversationServiceImpl(Conversations):
|
||||||
"item_data": item_dict,
|
"item_data": item_dict,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.sql_store.insert(table="conversation_items", data=item_record)
|
item_records.append(item_record)
|
||||||
|
|
||||||
|
await self.sql_store.insert(table="conversation_items", data=item_records)
|
||||||
|
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
id=conversation_id,
|
id=conversation_id,
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Literal, Protocol
|
from typing import Any, Literal, Protocol
|
||||||
|
|
||||||
|
@ -41,9 +41,9 @@ class SqlStore(Protocol):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert a row into a table.
|
Insert a row or batch of rows into a table.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
||||||
|
@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
|
||||||
|
"""Add access control attributes to a data item."""
|
||||||
|
enhanced = dict(item)
|
||||||
|
if current_user:
|
||||||
|
enhanced["owner_principal"] = current_user.principal
|
||||||
|
enhanced["access_attributes"] = current_user.attributes
|
||||||
|
else:
|
||||||
|
enhanced["owner_principal"] = None
|
||||||
|
enhanced["access_attributes"] = None
|
||||||
|
return enhanced
|
||||||
|
|
||||||
|
|
||||||
class SqlRecord(ProtectedResource):
|
class SqlRecord(ProtectedResource):
|
||||||
def __init__(self, record_id: str, table_name: str, owner: User):
|
def __init__(self, record_id: str, table_name: str, owner: User):
|
||||||
self.type = f"sql_record::{table_name}"
|
self.type = f"sql_record::{table_name}"
|
||||||
|
@ -102,18 +114,14 @@ class AuthorizedSqlStore:
|
||||||
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||||
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||||
"""Insert a row with automatic access control attribute capture."""
|
"""Insert a row or batch of rows with automatic access control attribute capture."""
|
||||||
enhanced_data = dict(data)
|
|
||||||
|
|
||||||
current_user = get_authenticated_user()
|
current_user = get_authenticated_user()
|
||||||
if current_user:
|
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
|
||||||
enhanced_data["owner_principal"] = current_user.principal
|
if isinstance(data, Mapping):
|
||||||
enhanced_data["access_attributes"] = current_user.attributes
|
enhanced_data = _enhance_item_with_access_control(data, current_user)
|
||||||
else:
|
else:
|
||||||
enhanced_data["owner_principal"] = None
|
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
|
||||||
enhanced_data["access_attributes"] = None
|
|
||||||
|
|
||||||
await self.sql_store.insert(table, enhanced_data)
|
await self.sql_store.insert(table, enhanced_data)
|
||||||
|
|
||||||
async def fetch_all(
|
async def fetch_all(
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||||
async with self.async_session() as session:
|
async with self.async_session() as session:
|
||||||
await session.execute(self.metadata.tables[table].insert(), data)
|
await session.execute(self.metadata.tables[table].insert(), data)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
|
@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
|
||||||
assert {r["id"] for r in rows_after} == {1, 3}
|
assert {r["id"] for r in rows_after} == {1, 3}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_batch_insert():
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_path = tmp_dir + "/test.db"
|
||||||
|
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
|
||||||
|
|
||||||
|
await store.create_table(
|
||||||
|
"batch_test",
|
||||||
|
{
|
||||||
|
"id": ColumnType.INTEGER,
|
||||||
|
"name": ColumnType.STRING,
|
||||||
|
"value": ColumnType.INTEGER,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_data = [
|
||||||
|
{"id": 1, "name": "first", "value": 10},
|
||||||
|
{"id": 2, "name": "second", "value": 20},
|
||||||
|
{"id": 3, "name": "third", "value": 30},
|
||||||
|
]
|
||||||
|
|
||||||
|
await store.insert("batch_test", batch_data)
|
||||||
|
|
||||||
|
result = await store.fetch_all("batch_test", order_by=[("id", "asc")])
|
||||||
|
assert result.data == batch_data
|
||||||
|
|
||||||
|
|
||||||
async def test_where_operator_edge_cases():
|
async def test_where_operator_edge_cases():
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue