From 0dbc522bcbac79a0b8af603706f885876f5309e7 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Thu, 2 Oct 2025 21:25:04 -0400 Subject: [PATCH] updated conversation items DB model to have a row for each item Signed-off-by: Francisco Javier Arceo --- .../core/conversations/conversations.py | 137 ++++++++++-------- .../unit/conversations/test_conversations.py | 2 +- 2 files changed, 80 insertions(+), 59 deletions(-) diff --git a/llama_stack/core/conversations/conversations.py b/llama_stack/core/conversations/conversations.py index 4587f0864..a41a60a6e 100644 --- a/llama_stack/core/conversations/conversations.py +++ b/llama_stack/core/conversations/conversations.py @@ -81,6 +81,16 @@ class ConversationServiceImpl(Conversations): }, ) + await self.sql_store.create_table( + "conversation_items", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "conversation_id": ColumnType.STRING, + "created_at": ColumnType.INTEGER, + "item_data": ColumnType.JSON, + }, + ) + async def create_conversation( self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None ) -> Conversation: @@ -89,14 +99,10 @@ class ConversationServiceImpl(Conversations): conversation_id = f"conv_{random_bytes.hex()}" created_at = int(time.time()) - items_json = [] - for item in items or []: - items_json.append(item.model_dump()) - record_data = { "id": conversation_id, "created_at": created_at, - "items": items_json, + "items": [], "metadata": metadata, } @@ -105,6 +111,20 @@ class ConversationServiceImpl(Conversations): data=record_data, ) + if items: + for item in items: + item_dict = item.model_dump() + item_id = self._get_or_generate_item_id(item, item_dict) + + item_record = { + "id": item_id, + "conversation_id": conversation_id, + "created_at": created_at, + "item_data": item_dict, + } + + await self.sql_store.insert(table="conversation_items", data=item_record) + conversation = Conversation( id=conversation_id, created_at=created_at, @@ -148,6 +168,18 @@ class ConversationServiceImpl(Conversations): f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'." ) + def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str: + """Get existing item ID or generate one if missing.""" + if item.id is None: + random_bytes = secrets.token_bytes(24) + if item.type == "message": + item_id = f"msg_{random_bytes.hex()}" + else: + item_id = f"item_{random_bytes.hex()}" + item_dict["id"] = item_id + return item_id + return item.id + async def _get_validated_conversation(self, conversation_id: str) -> Conversation: """Validate conversation ID and return the conversation if it exists.""" self._validate_conversation_id(conversation_id) @@ -158,32 +190,32 @@ class ConversationServiceImpl(Conversations): await self._get_validated_conversation(conversation_id) created_items = [] + created_at = int(time.time()) for item in items: - # Generate item ID based on item type - random_bytes = secrets.token_bytes(24) - if item.type == "message": - item_id = f"msg_{random_bytes.hex()}" - else: - item_id = f"item_{random_bytes.hex()}" - - # Create a copy of the item with the generated ID and completed status item_dict = item.model_dump() - item_dict["id"] = item_id - if "status" not in item_dict: - item_dict["status"] = "completed" + item_id = self._get_or_generate_item_id(item, item_dict) + + item_record = { + "id": item_id, + "conversation_id": conversation_id, + "created_at": created_at, + "item_data": item_dict, + } + + # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update + try: + await self.sql_store.insert(table="conversation_items", data=item_record) + except Exception: + # If insert fails due to ID conflict, update existing record + await self.sql_store.update( + table="conversation_items", + data={"created_at": created_at, "item_data": item_dict}, + where={"id": item_id}, + ) created_items.append(item_dict) - # Get existing items from database - record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id}) - existing_items = record.get("items", []) if record else [] - - updated_items = existing_items + created_items - await self.sql_store.update( - table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id} - ) - logger.info(f"Created {len(created_items)} items in conversation {conversation_id}") # Convert created items (dicts) to proper ConversationItem types @@ -204,39 +236,37 @@ class ConversationServiceImpl(Conversations): if not item_id: raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}") - record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id}) - items = record.get("items", []) if record else [] + # Get item from conversation_items table + record = await self.sql_store.fetch_one( + table="conversation_items", where={"id": item_id, "conversation_id": conversation_id} + ) - for item in items: - if isinstance(item, dict) and item.get("id") == item_id: - adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem) - return adapter.validate_python(item) + if record is None: + raise ValueError(f"Item {item_id} not found in conversation {conversation_id}") - raise ValueError(f"Item {item_id} not found in conversation {conversation_id}") + adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem) + return adapter.validate_python(record["item_data"]) async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN): """List items in the conversation.""" - record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id}) - items = record.get("items", []) if record else [] + result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id}) + records = result.data if order != NOT_GIVEN and order == "asc": - items = items + records.sort(key=lambda x: x["created_at"]) else: - items = list(reversed(items)) + records.sort(key=lambda x: x["created_at"], reverse=True) actual_limit = 20 if limit != NOT_GIVEN and isinstance(limit, int): actual_limit = limit - items = items[:actual_limit] + records = records[:actual_limit] + items = [record["item_data"] for record in records] - # Items from database are stored as dicts, convert them to ConversationItem adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem) - response_items: list[ConversationItem] = [ - adapter.validate_python(item) if isinstance(item, dict) else item for item in items - ] + response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items] - # Get first and last IDs from converted response items first_id = response_items[0].id if response_items else None last_id = response_items[-1].id if response_items else None @@ -256,26 +286,17 @@ class ConversationServiceImpl(Conversations): if not item_id: raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}") - _ = await self._get_validated_conversation(conversation_id) # executes validation + _ = await self._get_validated_conversation(conversation_id) - record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id}) - items = record.get("items", []) if record else [] + record = await self.sql_store.fetch_one( + table="conversation_items", where={"id": item_id, "conversation_id": conversation_id} + ) - updated_items = [] - item_found = False - - for item in items: - current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None) - if current_item_id != item_id: - updated_items.append(item) - else: - item_found = True - - if not item_found: + if record is None: raise ValueError(f"Item {item_id} not found in conversation {conversation_id}") - await self.sql_store.update( - table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id} + await self.sql_store.delete( + table="conversation_items", where={"id": item_id, "conversation_id": conversation_id} ) logger.info(f"Deleted item {item_id} from conversation {conversation_id}") diff --git a/tests/unit/conversations/test_conversations.py b/tests/unit/conversations/test_conversations.py index 74f9ba07c..9ea47947a 100644 --- a/tests/unit/conversations/test_conversations.py +++ b/tests/unit/conversations/test_conversations.py @@ -62,7 +62,7 @@ async def test_conversation_items(service): item_list = await service.create(conversation.id, items) assert len(item_list.data) == 1 - assert item_list.data[0].id.startswith("msg_") + assert item_list.data[0].id == "msg_test123" items = await service.list(conversation.id) assert len(items.data) == 1