mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat: Add OpenAI Conversations API
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
0e13512dd7
commit
a74a7cc873
18 changed files with 3280 additions and 1088 deletions
290
llama_stack/core/conversations/conversations.py
Normal file
290
llama_stack/core/conversations/conversations.py
Normal file
|
@ -0,0 +1,290 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import NOT_GIVEN
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="openai::conversations")
|
||||
|
||||
|
||||
class ConversationServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in conversation service.
|
||||
|
||||
:param run_config: Stack run configuration containing distribution info
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the conversation service implementation."""
|
||||
impl = ConversationServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class ConversationServiceImpl(Conversations):
|
||||
"""Built-in conversation service implementation using AuthorizedSqlStore."""
|
||||
|
||||
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.policy: list[AccessRule] = []
|
||||
|
||||
conversations_store_config = config.run_config.conversations_store
|
||||
if conversations_store_config is None:
|
||||
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
|
||||
)
|
||||
else:
|
||||
sql_store_config = conversations_store_config
|
||||
|
||||
base_sql_store = sqlstore_impl(sql_store_config)
|
||||
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the store and create tables."""
|
||||
if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
|
||||
os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"openai_conversations",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"items": ColumnType.JSON,
|
||||
"metadata": ColumnType.JSON,
|
||||
},
|
||||
)
|
||||
|
||||
async def create_conversation(
|
||||
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||
) -> Conversation:
|
||||
"""Create a conversation."""
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
conversation_id = f"conv_{random_bytes.hex()}"
|
||||
created_at = int(time.time())
|
||||
|
||||
items_json = []
|
||||
for item in items or []:
|
||||
item_dict = item.model_dump() if hasattr(item, "model_dump") else item
|
||||
items_json.append(item_dict)
|
||||
|
||||
record_data = {
|
||||
"id": conversation_id,
|
||||
"created_at": created_at,
|
||||
"items": items_json,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
await self.sql_store.insert(
|
||||
table="openai_conversations",
|
||||
data=record_data,
|
||||
)
|
||||
|
||||
conversation = Conversation(
|
||||
id=conversation_id,
|
||||
created_at=created_at,
|
||||
metadata=metadata,
|
||||
object="conversation",
|
||||
)
|
||||
|
||||
logger.info(f"Created conversation {conversation_id}")
|
||||
return conversation
|
||||
|
||||
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Get a conversation with the given ID."""
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
if record is None:
|
||||
raise ValueError(f"Conversation {conversation_id} not found")
|
||||
|
||||
return Conversation(
|
||||
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
|
||||
)
|
||||
|
||||
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||
"""Update a conversation's metadata with the given ID"""
|
||||
await self.sql_store.update(
|
||||
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
|
||||
)
|
||||
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||
"""Delete a conversation with the given ID."""
|
||||
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
|
||||
|
||||
logger.info(f"Deleted conversation {conversation_id}")
|
||||
return ConversationDeletedResource(id=conversation_id)
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str) -> None:
|
||||
"""Validate conversation ID format."""
|
||||
if not conversation_id.startswith("conv_"):
|
||||
raise ValueError(
|
||||
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||
)
|
||||
|
||||
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
||||
"""Validate conversation ID and return the conversation if it exists."""
|
||||
self._validate_conversation_id(conversation_id)
|
||||
return await self.get_conversation(conversation_id)
|
||||
|
||||
async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||
"""Create items in the conversation."""
|
||||
await self._get_validated_conversation(conversation_id)
|
||||
|
||||
created_items = []
|
||||
|
||||
for item in items:
|
||||
# Generate item ID based on item type
|
||||
random_bytes = secrets.token_bytes(24)
|
||||
item_type = getattr(item, "type", None)
|
||||
if item_type == "message":
|
||||
item_id = f"msg_{random_bytes.hex()}"
|
||||
else:
|
||||
item_id = f"item_{random_bytes.hex()}"
|
||||
|
||||
# Create a copy of the item with the generated ID and completed status
|
||||
item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
|
||||
item_dict["id"] = item_id
|
||||
if "status" not in item_dict:
|
||||
item_dict["status"] = "completed"
|
||||
|
||||
created_items.append(item_dict)
|
||||
|
||||
# Get existing items from database
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
existing_items = record.get("items", []) if record else []
|
||||
|
||||
updated_items = existing_items + created_items
|
||||
await self.sql_store.update(
|
||||
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
|
||||
)
|
||||
|
||||
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
|
||||
|
||||
# Convert created items (dicts) to proper ConversationItem types
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=created_items[0]["id"] if created_items else None,
|
||||
last_id=created_items[-1]["id"] if created_items else None,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||
"""Retrieve a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
items = record.get("items", []) if record else []
|
||||
|
||||
for item in items:
|
||||
if isinstance(item, dict) and item.get("id") == item_id:
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
return adapter.validate_python(item)
|
||||
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
|
||||
"""List items in the conversation."""
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
items = record.get("items", []) if record else []
|
||||
|
||||
if order != NOT_GIVEN and order == "asc":
|
||||
items = items
|
||||
else:
|
||||
items = list(reversed(items))
|
||||
|
||||
actual_limit = 20
|
||||
if limit != NOT_GIVEN and isinstance(limit, int):
|
||||
actual_limit = limit
|
||||
|
||||
items = items[:actual_limit]
|
||||
|
||||
# Items from database are stored as dicts, convert them to ConversationItem
|
||||
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||
response_items: list[ConversationItem] = [
|
||||
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
|
||||
]
|
||||
|
||||
# Get first and last IDs safely
|
||||
first_id = None
|
||||
last_id = None
|
||||
if items:
|
||||
first_item = items[0]
|
||||
last_item = items[-1]
|
||||
|
||||
first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
|
||||
last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
|
||||
|
||||
return ConversationItemList(
|
||||
data=response_items,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
async def openai_delete_conversation_item(
|
||||
self, conversation_id: str, item_id: str
|
||||
) -> ConversationItemDeletedResource:
|
||||
"""Delete a conversation item."""
|
||||
if not conversation_id:
|
||||
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||
if not item_id:
|
||||
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||
|
||||
_ = await self._get_validated_conversation(conversation_id) # executes validation
|
||||
|
||||
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||
items = record.get("items", []) if record else []
|
||||
|
||||
updated_items = []
|
||||
item_found = False
|
||||
|
||||
for item in items:
|
||||
current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
|
||||
if current_item_id != item_id:
|
||||
updated_items.append(item)
|
||||
else:
|
||||
item_found = True
|
||||
|
||||
if not item_found:
|
||||
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||
|
||||
await self.sql_store.update(
|
||||
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
|
||||
)
|
||||
|
||||
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
|
||||
return ConversationItemDeletedResource(id=item_id)
|
Loading…
Add table
Add a link
Reference in a new issue