mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
chore: standardize unsupported database error
Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent
19123ca957
commit
2510759cdf
2 changed files with 22 additions and 11 deletions
|
@ -9,6 +9,8 @@
|
||||||
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
|
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
|
||||||
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
|
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreType
|
||||||
|
|
||||||
|
|
||||||
class ResourceNotFoundError(ValueError):
|
class ResourceNotFoundError(ValueError):
|
||||||
"""generic exception for a missing Llama Stack resource"""
|
"""generic exception for a missing Llama Stack resource"""
|
||||||
|
@ -28,6 +30,16 @@ class UnsupportedModelError(ValueError):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedSqlStoreError(ValueError):
|
||||||
|
"""raised when SQL store is not present in the list of supported SQL stores"""
|
||||||
|
|
||||||
|
def __init__(self, sqlstore_type: str):
|
||||||
|
message = (
|
||||||
|
f"'{sqlstore_type}' SQL store is not supported. Supported SQL stores are: {', '.join(list(SqlStoreType))}"
|
||||||
|
)
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundError(ResourceNotFoundError):
|
class ModelNotFoundError(ResourceNotFoundError):
|
||||||
"""raised when Llama Stack cannot find a referenced model"""
|
"""raised when Llama Stack cannot find a referenced model"""
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from llama_stack.apis.common.errors import UnsupportedSqlStoreError
|
||||||
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
||||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
from llama_stack.core.access_control.conditions import ProtectedResource
|
||||||
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
|
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
|
||||||
|
@ -59,7 +60,7 @@ class AuthorizedSqlStore:
|
||||||
|
|
||||||
:param sql_store: Base SqlStore implementation to wrap
|
:param sql_store: Base SqlStore implementation to wrap
|
||||||
"""
|
"""
|
||||||
self.sql_store = sql_store
|
self.sql_store: SqlStore = sql_store
|
||||||
self._detect_database_type()
|
self._detect_database_type()
|
||||||
self._validate_sql_optimized_policy()
|
self._validate_sql_optimized_policy()
|
||||||
|
|
||||||
|
@ -69,8 +70,8 @@ class AuthorizedSqlStore:
|
||||||
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||||
|
|
||||||
self.database_type = self.sql_store.config.type
|
self.database_type = self.sql_store.config.type
|
||||||
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
|
if self.database_type not in list(SqlStoreType):
|
||||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
raise UnsupportedSqlStoreError(self.database_type)
|
||||||
|
|
||||||
def _validate_sql_optimized_policy(self) -> None:
|
def _validate_sql_optimized_policy(self) -> None:
|
||||||
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||||
|
@ -201,10 +202,10 @@ class AuthorizedSqlStore:
|
||||||
"""
|
"""
|
||||||
if self.database_type == SqlStoreType.postgres:
|
if self.database_type == SqlStoreType.postgres:
|
||||||
return f"{column}->'{path}'"
|
return f"{column}->'{path}'"
|
||||||
elif self.database_type == SqlStoreType.sqlite:
|
# this case is when self.database_type == SqlStoreType.sqlite
|
||||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
# validity detection already occurs in _detect_database_type
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||||
|
|
||||||
def _json_extract_text(self, column: str, path: str) -> str:
|
def _json_extract_text(self, column: str, path: str) -> str:
|
||||||
"""Extract JSON value as text.
|
"""Extract JSON value as text.
|
||||||
|
@ -218,10 +219,10 @@ class AuthorizedSqlStore:
|
||||||
"""
|
"""
|
||||||
if self.database_type == SqlStoreType.postgres:
|
if self.database_type == SqlStoreType.postgres:
|
||||||
return f"{column}->>'{path}'"
|
return f"{column}->>'{path}'"
|
||||||
elif self.database_type == SqlStoreType.sqlite:
|
# this case is when self.database_type == SqlStoreType.sqlite
|
||||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
# validity detection already occurs in _detect_database_type
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||||
|
|
||||||
def _get_public_access_conditions(self) -> list[str]:
|
def _get_public_access_conditions(self) -> list[str]:
|
||||||
"""Get the SQL conditions for public access."""
|
"""Get the SQL conditions for public access."""
|
||||||
|
@ -232,8 +233,6 @@ class AuthorizedSqlStore:
|
||||||
conditions.append("access_attributes::text = 'null'")
|
conditions.append("access_attributes::text = 'null'")
|
||||||
elif self.database_type == SqlStoreType.sqlite:
|
elif self.database_type == SqlStoreType.sqlite:
|
||||||
conditions.append("access_attributes = 'null'")
|
conditions.append("access_attributes = 'null'")
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
|
||||||
return conditions
|
return conditions
|
||||||
|
|
||||||
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
|
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue