mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
chore: standardize unsupported database error
Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent
19123ca957
commit
2510759cdf
2 changed files with 22 additions and 11 deletions
|
@ -9,6 +9,8 @@
|
|||
# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically
|
||||
# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)'
|
||||
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreType
|
||||
|
||||
|
||||
class ResourceNotFoundError(ValueError):
|
||||
"""generic exception for a missing Llama Stack resource"""
|
||||
|
@ -28,6 +30,16 @@ class UnsupportedModelError(ValueError):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
class UnsupportedSqlStoreError(ValueError):
|
||||
"""raised when SQL store is not present in the list of supported SQL stores"""
|
||||
|
||||
def __init__(self, sqlstore_type: str):
|
||||
message = (
|
||||
f"'{sqlstore_type}' SQL store is not supported. Supported SQL stores are: {', '.join(list(SqlStoreType))}"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelNotFoundError(ResourceNotFoundError):
|
||||
"""raised when Llama Stack cannot find a referenced model"""
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Literal
|
||||
|
||||
from llama_stack.apis.common.errors import UnsupportedSqlStoreError
|
||||
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
||||
from llama_stack.core.access_control.conditions import ProtectedResource
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope
|
||||
|
@ -59,7 +60,7 @@ class AuthorizedSqlStore:
|
|||
|
||||
:param sql_store: Base SqlStore implementation to wrap
|
||||
"""
|
||||
self.sql_store = sql_store
|
||||
self.sql_store: SqlStore = sql_store
|
||||
self._detect_database_type()
|
||||
self._validate_sql_optimized_policy()
|
||||
|
||||
|
@ -69,8 +70,8 @@ class AuthorizedSqlStore:
|
|||
raise ValueError("SqlStore must have a config attribute to be used with AuthorizedSqlStore")
|
||||
|
||||
self.database_type = self.sql_store.config.type
|
||||
if self.database_type not in [SqlStoreType.postgres, SqlStoreType.sqlite]:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
if self.database_type not in list(SqlStoreType):
|
||||
raise UnsupportedSqlStoreError(self.database_type)
|
||||
|
||||
def _validate_sql_optimized_policy(self) -> None:
|
||||
"""Validate that SQL_OPTIMIZED_POLICY matches the actual default_policy().
|
||||
|
@ -201,10 +202,10 @@ class AuthorizedSqlStore:
|
|||
"""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
return f"{column}->'{path}'"
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
# this case is when self.database_type == SqlStoreType.sqlite
|
||||
# validity detection already occurs in _detect_database_type
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
|
||||
def _json_extract_text(self, column: str, path: str) -> str:
|
||||
"""Extract JSON value as text.
|
||||
|
@ -218,10 +219,10 @@ class AuthorizedSqlStore:
|
|||
"""
|
||||
if self.database_type == SqlStoreType.postgres:
|
||||
return f"{column}->>'{path}'"
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
# this case is when self.database_type == SqlStoreType.sqlite
|
||||
# validity detection already occurs in _detect_database_type
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
return f"JSON_EXTRACT({column}, '$.{path}')"
|
||||
|
||||
def _get_public_access_conditions(self) -> list[str]:
|
||||
"""Get the SQL conditions for public access."""
|
||||
|
@ -232,8 +233,6 @@ class AuthorizedSqlStore:
|
|||
conditions.append("access_attributes::text = 'null'")
|
||||
elif self.database_type == SqlStoreType.sqlite:
|
||||
conditions.append("access_attributes = 'null'")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database type: {self.database_type}")
|
||||
return conditions
|
||||
|
||||
def _build_default_policy_where_clause(self, current_user: User | None) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue