mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
updated conversation items DB model to have a row for each item
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
b38e6df982
commit
0dbc522bcb
2 changed files with 80 additions and 59 deletions
|
@ -81,6 +81,16 @@ class ConversationServiceImpl(Conversations):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.sql_store.create_table(
|
||||||
|
"conversation_items",
|
||||||
|
{
|
||||||
|
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||||
|
"conversation_id": ColumnType.STRING,
|
||||||
|
"created_at": ColumnType.INTEGER,
|
||||||
|
"item_data": ColumnType.JSON,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def create_conversation(
|
async def create_conversation(
|
||||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||||
) -> Conversation:
|
) -> Conversation:
|
||||||
|
@ -89,14 +99,10 @@ class ConversationServiceImpl(Conversations):
|
||||||
conversation_id = f"conv_{random_bytes.hex()}"
|
conversation_id = f"conv_{random_bytes.hex()}"
|
||||||
created_at = int(time.time())
|
created_at = int(time.time())
|
||||||
|
|
||||||
items_json = []
|
|
||||||
for item in items or []:
|
|
||||||
items_json.append(item.model_dump())
|
|
||||||
|
|
||||||
record_data = {
|
record_data = {
|
||||||
"id": conversation_id,
|
"id": conversation_id,
|
||||||
"created_at": created_at,
|
"created_at": created_at,
|
||||||
"items": items_json,
|
"items": [],
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,6 +111,20 @@ class ConversationServiceImpl(Conversations):
|
||||||
data=record_data,
|
data=record_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if items:
|
||||||
|
for item in items:
|
||||||
|
item_dict = item.model_dump()
|
||||||
|
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||||
|
|
||||||
|
item_record = {
|
||||||
|
"id": item_id,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"created_at": created_at,
|
||||||
|
"item_data": item_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||||
|
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
id=conversation_id,
|
id=conversation_id,
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
|
@ -148,6 +168,18 @@ class ConversationServiceImpl(Conversations):
|
||||||
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
|
||||||
|
"""Get existing item ID or generate one if missing."""
|
||||||
|
if item.id is None:
|
||||||
|
random_bytes = secrets.token_bytes(24)
|
||||||
|
if item.type == "message":
|
||||||
|
item_id = f"msg_{random_bytes.hex()}"
|
||||||
|
else:
|
||||||
|
item_id = f"item_{random_bytes.hex()}"
|
||||||
|
item_dict["id"] = item_id
|
||||||
|
return item_id
|
||||||
|
return item.id
|
||||||
|
|
||||||
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
||||||
"""Validate conversation ID and return the conversation if it exists."""
|
"""Validate conversation ID and return the conversation if it exists."""
|
||||||
self._validate_conversation_id(conversation_id)
|
self._validate_conversation_id(conversation_id)
|
||||||
|
@ -158,32 +190,32 @@ class ConversationServiceImpl(Conversations):
|
||||||
await self._get_validated_conversation(conversation_id)
|
await self._get_validated_conversation(conversation_id)
|
||||||
|
|
||||||
created_items = []
|
created_items = []
|
||||||
|
created_at = int(time.time())
|
||||||
|
|
||||||
for item in items:
|
for item in items:
|
||||||
# Generate item ID based on item type
|
|
||||||
random_bytes = secrets.token_bytes(24)
|
|
||||||
if item.type == "message":
|
|
||||||
item_id = f"msg_{random_bytes.hex()}"
|
|
||||||
else:
|
|
||||||
item_id = f"item_{random_bytes.hex()}"
|
|
||||||
|
|
||||||
# Create a copy of the item with the generated ID and completed status
|
|
||||||
item_dict = item.model_dump()
|
item_dict = item.model_dump()
|
||||||
item_dict["id"] = item_id
|
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||||
if "status" not in item_dict:
|
|
||||||
item_dict["status"] = "completed"
|
item_record = {
|
||||||
|
"id": item_id,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"created_at": created_at,
|
||||||
|
"item_data": item_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||||
|
try:
|
||||||
|
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||||
|
except Exception:
|
||||||
|
# If insert fails due to ID conflict, update existing record
|
||||||
|
await self.sql_store.update(
|
||||||
|
table="conversation_items",
|
||||||
|
data={"created_at": created_at, "item_data": item_dict},
|
||||||
|
where={"id": item_id},
|
||||||
|
)
|
||||||
|
|
||||||
created_items.append(item_dict)
|
created_items.append(item_dict)
|
||||||
|
|
||||||
# Get existing items from database
|
|
||||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
|
||||||
existing_items = record.get("items", []) if record else []
|
|
||||||
|
|
||||||
updated_items = existing_items + created_items
|
|
||||||
await self.sql_store.update(
|
|
||||||
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
|
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
|
||||||
|
|
||||||
# Convert created items (dicts) to proper ConversationItem types
|
# Convert created items (dicts) to proper ConversationItem types
|
||||||
|
@ -204,39 +236,37 @@ class ConversationServiceImpl(Conversations):
|
||||||
if not item_id:
|
if not item_id:
|
||||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||||
|
|
||||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
# Get item from conversation_items table
|
||||||
items = record.get("items", []) if record else []
|
record = await self.sql_store.fetch_one(
|
||||||
|
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||||
for item in items:
|
)
|
||||||
if isinstance(item, dict) and item.get("id") == item_id:
|
|
||||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
|
||||||
return adapter.validate_python(item)
|
|
||||||
|
|
||||||
|
if record is None:
|
||||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||||
|
|
||||||
|
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||||
|
return adapter.validate_python(record["item_data"])
|
||||||
|
|
||||||
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
|
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
|
||||||
"""List items in the conversation."""
|
"""List items in the conversation."""
|
||||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
|
||||||
items = record.get("items", []) if record else []
|
records = result.data
|
||||||
|
|
||||||
if order != NOT_GIVEN and order == "asc":
|
if order != NOT_GIVEN and order == "asc":
|
||||||
items = items
|
records.sort(key=lambda x: x["created_at"])
|
||||||
else:
|
else:
|
||||||
items = list(reversed(items))
|
records.sort(key=lambda x: x["created_at"], reverse=True)
|
||||||
|
|
||||||
actual_limit = 20
|
actual_limit = 20
|
||||||
if limit != NOT_GIVEN and isinstance(limit, int):
|
if limit != NOT_GIVEN and isinstance(limit, int):
|
||||||
actual_limit = limit
|
actual_limit = limit
|
||||||
|
|
||||||
items = items[:actual_limit]
|
records = records[:actual_limit]
|
||||||
|
items = [record["item_data"] for record in records]
|
||||||
|
|
||||||
# Items from database are stored as dicts, convert them to ConversationItem
|
|
||||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||||
response_items: list[ConversationItem] = [
|
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
|
||||||
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get first and last IDs from converted response items
|
|
||||||
first_id = response_items[0].id if response_items else None
|
first_id = response_items[0].id if response_items else None
|
||||||
last_id = response_items[-1].id if response_items else None
|
last_id = response_items[-1].id if response_items else None
|
||||||
|
|
||||||
|
@ -256,26 +286,17 @@ class ConversationServiceImpl(Conversations):
|
||||||
if not item_id:
|
if not item_id:
|
||||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||||
|
|
||||||
_ = await self._get_validated_conversation(conversation_id) # executes validation
|
_ = await self._get_validated_conversation(conversation_id)
|
||||||
|
|
||||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
record = await self.sql_store.fetch_one(
|
||||||
items = record.get("items", []) if record else []
|
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||||
|
)
|
||||||
|
|
||||||
updated_items = []
|
if record is None:
|
||||||
item_found = False
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
|
|
||||||
if current_item_id != item_id:
|
|
||||||
updated_items.append(item)
|
|
||||||
else:
|
|
||||||
item_found = True
|
|
||||||
|
|
||||||
if not item_found:
|
|
||||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||||
|
|
||||||
await self.sql_store.update(
|
await self.sql_store.delete(
|
||||||
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
|
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
|
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
|
||||||
|
|
|
@ -62,7 +62,7 @@ async def test_conversation_items(service):
|
||||||
item_list = await service.create(conversation.id, items)
|
item_list = await service.create(conversation.id, items)
|
||||||
|
|
||||||
assert len(item_list.data) == 1
|
assert len(item_list.data) == 1
|
||||||
assert item_list.data[0].id.startswith("msg_")
|
assert item_list.data[0].id == "msg_test123"
|
||||||
|
|
||||||
items = await service.list(conversation.id)
|
items = await service.list(conversation.id)
|
||||||
assert len(items.data) == 1
|
assert len(items.data) == 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue