diff --git a/llama_stack/core/conversations/conversations.py b/llama_stack/core/conversations/conversations.py index c513b4347..afd1d74da 100644 --- a/llama_stack/core/conversations/conversations.py +++ b/llama_stack/core/conversations/conversations.py @@ -21,12 +21,16 @@ from llama_stack.apis.conversations.conversations import ( Conversations, Metadata, ) -from llama_stack.core.datatypes import AccessRule, StackRunConfig +from llama_stack.core.datatypes import AccessRule from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore -from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl +from llama_stack.providers.utils.sqlstore.sqlstore import ( + SqliteSqlStoreConfig, + SqlStoreConfig, + sqlstore_impl, +) logger = get_logger(name=__name__, category="openai::conversations") @@ -34,10 +38,14 @@ logger = get_logger(name=__name__, category="openai::conversations") class ConversationServiceConfig(BaseModel): """Configuration for the built-in conversation service. - :param run_config: Stack run configuration containing distribution info + :param conversations_store: SQL store configuration for conversations (defaults to SQLite) + :param policy: Access control rules """ - run_config: StackRunConfig + conversations_store: SqlStoreConfig = SqliteSqlStoreConfig( + db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix() + ) + policy: list[AccessRule] = [] async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]): @@ -53,23 +61,15 @@ class ConversationServiceImpl(Conversations): def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]): self.config = config self.deps = deps - self.policy: list[AccessRule] = [] + self.policy = config.policy - conversations_store_config = config.run_config.conversations_store - if conversations_store_config is None: - sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig( - db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix() - ) - else: - sql_store_config = conversations_store_config - - base_sql_store = sqlstore_impl(sql_store_config) + base_sql_store = sqlstore_impl(config.conversations_store) self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy) async def initialize(self) -> None: """Initialize the store and create tables.""" - if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"): - os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True) + if isinstance(self.config.conversations_store, SqliteSqlStoreConfig): + os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True) await self.sql_store.create_table( "openai_conversations", @@ -91,8 +91,7 @@ class ConversationServiceImpl(Conversations): items_json = [] for item in items or []: - item_dict = item.model_dump() if hasattr(item, "model_dump") else item - items_json.append(item_dict) + items_json.append(item.model_dump()) record_data = { "id": conversation_id, @@ -170,7 +169,7 @@ class ConversationServiceImpl(Conversations): item_id = f"item_{random_bytes.hex()}" # Create a copy of the item with the generated ID and completed status - item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item) + item_dict = item.model_dump() item_dict["id"] = item_id if "status" not in item_dict: item_dict["status"] = "completed" @@ -238,15 +237,9 @@ class ConversationServiceImpl(Conversations): adapter.validate_python(item) if isinstance(item, dict) else item for item in items ] - # Get first and last IDs safely - first_id = None - last_id = None - if items: - first_item = items[0] - last_item = items[-1] - - first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None) - last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None) + # Get first and last IDs from converted response items + first_id = response_items[0].id if response_items else None + last_id = response_items[-1].id if response_items else None return ConversationItemList( data=response_items, diff --git a/tests/unit/conversations/test_conversations.py b/tests/unit/conversations/test_conversations.py index efb3c6351..74f9ba07c 100644 --- a/tests/unit/conversations/test_conversations.py +++ b/tests/unit/conversations/test_conversations.py @@ -20,7 +20,6 @@ from llama_stack.core.conversations.conversations import ( ConversationServiceConfig, ConversationServiceImpl, ) -from llama_stack.core.datatypes import StackRunConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig @@ -29,13 +28,7 @@ async def service(): with tempfile.TemporaryDirectory() as tmpdir: db_path = Path(tmpdir) / "test_conversations.db" - config = ConversationServiceConfig( - run_config=StackRunConfig( - image_name="test", - providers={}, - conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), - ) - ) + config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[]) service = ConversationServiceImpl(config, {}) await service.initialize() yield service @@ -115,3 +108,25 @@ async def test_openai_type_compatibility(service): openai_item_adapter = TypeAdapter(OpenAIConversationItem) openai_item_adapter.validate_python(item_dict) + + +async def test_policy_configuration(): + from llama_stack.core.access_control.datatypes import Action, Scope + from llama_stack.core.datatypes import AccessRule + + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test_conversations_policy.db" + + restrictive_policy = [ + AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*")) + ] + + config = ConversationServiceConfig( + conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy + ) + service = ConversationServiceImpl(config, {}) + await service.initialize() + + assert service.policy == restrictive_policy + assert len(service.policy) == 1 + assert service.policy[0].forbid is not None