mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 20:50:52 +00:00
chore: simplify authorized sqlstore
# What does this PR do? ## Test Plan
This commit is contained in:
parent
d3600b92d1
commit
b4974d411d
7 changed files with 32 additions and 37 deletions
|
@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
|
|||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
|
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -137,7 +137,7 @@ class S3FilesImpl(Files):
|
|||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
|
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
|
|||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -54,7 +54,7 @@ class InferenceStore:
|
|||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
@ -202,7 +202,6 @@ class InferenceStore:
|
|||
order_by=[("created", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [
|
||||
|
@ -229,7 +228,6 @@ class InferenceStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
table="chat_completions",
|
||||
where={"id": completion_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
|
|
@ -28,8 +28,7 @@ class ResponsesStore:
|
|||
sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
||||
self.policy = policy
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
|
@ -87,7 +86,6 @@ class ResponsesStore:
|
|||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
|
||||
|
@ -105,7 +103,6 @@ class ResponsesStore:
|
|||
row = await self.sql_store.fetch_one(
|
||||
"openai_responses",
|
||||
where={"id": response_id},
|
||||
policy=self.policy,
|
||||
)
|
||||
|
||||
if not row:
|
||||
|
@ -116,7 +113,7 @@ class ResponsesStore:
|
|||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found")
|
||||
await self.sql_store.delete("openai_responses", where={"id": response_id})
|
||||
|
|
|
@ -53,13 +53,15 @@ class AuthorizedSqlStore:
|
|||
access control policies, user attribute capture, and SQL filtering optimization.
|
||||
"""
|
||||
|
||||
def __init__(self, sql_store: SqlStore):
|
||||
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
|
||||
"""
|
||||
Initialize the authorization layer.
|
||||
|
||||
:param sql_store: Base SqlStore implementation to wrap
|
||||
:param policy: Access control policy to use for authorization
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
self.policy = policy
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
|
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
|
|||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
cursor: tuple[str, str] | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""Fetch all rows with automatic access control filtering."""
|
||||
access_where = self._build_access_control_where_clause(policy)
|
||||
access_where = self._build_access_control_where_clause(self.policy)
|
||||
rows = await self.sql_store.fetch_all(
|
||||
table=table,
|
||||
where=where,
|
||||
|
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
|
|||
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
|
||||
)
|
||||
|
||||
if is_action_allowed(policy, Action.READ, sql_record, current_user):
|
||||
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
|
||||
filtered_rows.append(row)
|
||||
|
||||
return PaginatedResponse(
|
||||
|
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
|
|||
async def fetch_one(
|
||||
self,
|
||||
table: str,
|
||||
policy: list[AccessRule],
|
||||
where: Mapping[str, Any] | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch one row with automatic access control checking."""
|
||||
results = await self.fetch_all(
|
||||
table=table,
|
||||
policy=policy,
|
||||
where=where,
|
||||
limit=1,
|
||||
order_by=order_by,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue