This commit is contained in:
Nathan Weinberg 2025-08-14 13:56:41 -04:00 committed by GitHub
commit 0c15a4053f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 22 additions and 11 deletions

View file

@ -9,6 +9,8 @@
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically # 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)' # 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreType
class ResourceNotFoundError(ValueError): class ResourceNotFoundError(ValueError):
"""generic exception for a missing Llama Stack resource""" """generic exception for a missing Llama Stack resource"""
@ -28,6 +30,16 @@ class UnsupportedModelError(ValueError):
super().__init__(message) super().__init__(message)
class UnsupportedSqlStoreError(ValueError):
"""raised when SQL store is not present in the list of supported SQL stores"""
def __init__(self, sqlstore_type: str):
message = (
f"'{sqlstore_type}' SQL store is not supported. Supported SQL stores are: {', '.join(list(SqlStoreType))}"
)
super().__init__(message)
class ModelNotFoundError(ResourceNotFoundError): class ModelNotFoundError(ResourceNotFoundError):
"""raised when Llama Stack cannot find a referenced model""" """raised when Llama Stack cannot find a referenced model"""

View file

@ -7,6 +7,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal from typing import Any, Literal
from llama_stack.apis.common.errors import UnsupportedSqlStoreError
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.conditions import ProtectedResource from llama_stack.core.access_control.conditions import ProtectedResource
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
@ -59,7 +60,7 @@ class AuthorizedSqlStore:
:param sql_store: Base SqlStore implementation to wrap :param sql_store: Base SqlStore implementation to wrap
""" """
self.sql_store = sql_store self.sql_store: SqlStore = sql_store
self._detect_database_type() self._detect_database_type()
self._validate_sql_optimized_policy() self._validate_sql_optimized_policy()
@ -69,8 +70,8 @@ class AuthorizedSqlStore:
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore") raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
self.database_type = self.sql_store.config.type self.database_type = self.sql_store.config.type
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]: if self.database_type not in list(SqlStoreType):
raise ValueError(f"Unsupported database type: {self.database_type}") raise UnsupportedSqlStoreError(self.database_type)
def _validate_sql_optimized_policy(self) -> None: def _validate_sql_optimized_policy(self) -> None:
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy(). """Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
@ -201,10 +202,10 @@ class AuthorizedSqlStore:
""" """
if self.database_type == SqlStoreType.postgres: if self.database_type == SqlStoreType.postgres:
return f"{column}->'{path}'" return f"{column}->'{path}'"
elif self.database_type == SqlStoreType.sqlite: # this case is when self.database_type == SqlStoreType.sqlite
return f"JSON_EXTRACT({column}, '$.{path}')" # validity detection already occurs in _detect_database_type
else: else:
raise ValueError(f"Unsupported database type: {self.database_type}") return f"JSON_EXTRACT({column}, '$.{path}')"
def _json_extract_text(self, column: str, path: str) -> str: def _json_extract_text(self, column: str, path: str) -> str:
"""Extract JSON value as text. """Extract JSON value as text.
@ -218,10 +219,10 @@ class AuthorizedSqlStore:
""" """
if self.database_type == SqlStoreType.postgres: if self.database_type == SqlStoreType.postgres:
return f"{column}->>'{path}'" return f"{column}->>'{path}'"
elif self.database_type == SqlStoreType.sqlite: # this case is when self.database_type == SqlStoreType.sqlite
return f"JSON_EXTRACT({column}, '$.{path}')" # validity detection already occurs in _detect_database_type
else: else:
raise ValueError(f"Unsupported database type: {self.database_type}") return f"JSON_EXTRACT({column}, '$.{path}')"
def _get_public_access_conditions(self) -> list[str]: def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access.""" """Get the SQL conditions for public access."""
@ -232,8 +233,6 @@ class AuthorizedSqlStore:
conditions.append("access_attributes::text = 'null'") conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite: elif self.database_type == SqlStoreType.sqlite:
conditions.append("access_attributes = 'null'") conditions.append("access_attributes = 'null'")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions return conditions
def _build_default_policy_where_clause(self, current_user: User | None) -> str: def _build_default_policy_where_clause(self, current_user: User | None) -> str: