feat: Add OpenAI Conversations API

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-12 21:13:51 -04:00
parent 0e13512dd7
commit a74a7cc873
18 changed files with 3280 additions and 1088 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
__all__ = [
"Conversation",
"ConversationCreateRequest",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"ConversationUpdateRequest",
"Metadata",
]

View file

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal, Protocol, runtime_checkable
from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@json_schema_type
class Conversation(BaseModel):
"""OpenAI-compatible conversation object."""
id: str = Field(..., description="The unique ID of the conversation.")
object: Literal["conversation"] = Field(
default="conversation", description="The object type, which is always conversation."
)
created_at: int = Field(
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
)
metadata: Metadata | None = Field(
default=None,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
)
items: list[dict] | None = Field(
default=None,
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
)
@json_schema_type
class ConversationMessage(BaseModel):
"""OpenAI-compatible message item for conversations."""
id: str = Field(..., description="unique identifier for this message")
content: list[dict] = Field(..., description="message content")
role: str = Field(..., description="message role")
status: str = Field(..., description="message status")
type: Literal["message"] = "message"
object: Literal["message"] = "message"
ConversationItem = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is a Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
"""Request body for creating a conversation."""
items: list[ConversationItem] | None = Field(
default=[],
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
metadata: Metadata | None = Field(
default={},
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
max_length=16,
)
@json_schema_type
class ConversationUpdateRequest(BaseModel):
"""Request body for updating a conversation."""
metadata: Metadata = Field(
...,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
)
@json_schema_type
class ConversationDeletedResource(BaseModel):
"""Response for deleted conversation."""
id: str = Field(..., description="The deleted conversation identifier")
object: str = Field(default="conversation.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
"""Request body for creating conversation items."""
items: list[ConversationItem] = Field(
...,
description="Items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
@json_schema_type
class ConversationItemList(BaseModel):
"""List of conversation items with pagination."""
object: str = Field(default="list", description="Object type")
data: list[ConversationItem] = Field(..., description="List of conversation items")
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
has_more: bool = Field(default=False, description="Whether there are more items available")
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
"""Response for deleted conversation item."""
id: str = Field(..., description="The deleted item identifier")
object: str = Field(default="conversation.item.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The updated conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
:returns: List of created items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The conversation item.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list(
self,
conversation_id: str,
after: str | NotGiven = NOT_GIVEN,
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
) -> ConversationItemList:
"""List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
:param include: Specify additional output data to include in the response.
:param limit: A limit on the number of objects to be returned (1-100, default 20).
:param order: The order to return items in (asc or desc, default desc).
:returns: List of conversation items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The deleted item resource.
"""
...

View file

@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
tool_groups = "tool_groups"
files = "files"
prompts = "prompts"
conversations = "conversations"
# built-in API
inspect = "inspect"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import secrets
import time
from typing import Any
from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param run_config: Stack run configuration containing distribution info
"""
run_config: StackRunConfig
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
"""Get the conversation service implementation."""
impl = ConversationServiceImpl(config, deps)
await impl.initialize()
return impl
class ConversationServiceImpl(Conversations):
"""Built-in conversation service implementation using AuthorizedSqlStore."""
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy: list[AccessRule] = []
conversations_store_config = config.run_config.conversations_store
if conversations_store_config is None:
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
)
else:
sql_store_config = conversations_store_config
base_sql_store = sqlstore_impl(sql_store_config)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)
await self.sql_store.create_table(
"openai_conversations",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"items": ColumnType.JSON,
"metadata": ColumnType.JSON,
},
)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
random_bytes = secrets.token_bytes(24)
conversation_id = f"conv_{random_bytes.hex()}"
created_at = int(time.time())
items_json = []
for item in items or []:
item_dict = item.model_dump() if hasattr(item, "model_dump") else item
items_json.append(item_dict)
record_data = {
"id": conversation_id,
"created_at": created_at,
"items": items_json,
"metadata": metadata,
}
await self.sql_store.insert(
table="openai_conversations",
data=record_data,
)
conversation = Conversation(
id=conversation_id,
created_at=created_at,
metadata=metadata,
object="conversation",
)
logger.info(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
if record is None:
raise ValueError(f"Conversation {conversation_id} not found")
return Conversation(
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID"""
await self.sql_store.update(
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
)
return await self.get_conversation(conversation_id)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.info(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
"""Validate conversation ID format."""
if not conversation_id.startswith("conv_"):
raise ValueError(
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
)
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
"""Validate conversation ID and return the conversation if it exists."""
self._validate_conversation_id(conversation_id)
return await self.get_conversation(conversation_id)
async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation."""
await self._get_validated_conversation(conversation_id)
created_items = []
for item in items:
# Generate item ID based on item type
random_bytes = secrets.token_bytes(24)
item_type = getattr(item, "type", None)
if item_type == "message":
item_id = f"msg_{random_bytes.hex()}"
else:
item_id = f"item_{random_bytes.hex()}"
# Create a copy of the item with the generated ID and completed status
item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
item_dict["id"] = item_id
if "status" not in item_dict:
item_dict["status"] = "completed"
created_items.append(item_dict)
# Get existing items from database
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
existing_items = record.get("items", []) if record else []
updated_items = existing_items + created_items
await self.sql_store.update(
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
)
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
return ConversationItemList(
data=response_items,
first_id=created_items[0]["id"] if created_items else None,
last_id=created_items[-1]["id"] if created_items else None,
has_more=False,
)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
items = record.get("items", []) if record else []
for item in items:
if isinstance(item, dict) and item.get("id") == item_id:
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(item)
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
"""List items in the conversation."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
items = record.get("items", []) if record else []
if order != NOT_GIVEN and order == "asc":
items = items
else:
items = list(reversed(items))
actual_limit = 20
if limit != NOT_GIVEN and isinstance(limit, int):
actual_limit = limit
items = items[:actual_limit]
# Items from database are stored as dicts, convert them to ConversationItem
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
]
# Get first and last IDs safely
first_id = None
last_id = None
if items:
first_item = items[0]
last_item = items[-1]
first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
return ConversationItemList(
data=response_items,
first_id=first_id,
last_id=last_id,
has_more=False,
)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
_ = await self._get_validated_conversation(conversation_id) # executes validation
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
items = record.get("items", []) if record else []
updated_items = []
item_found = False
for item in items:
current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
if current_item_id != item_id:
updated_items.append(item)
else:
item_found = True
if not item_found:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
await self.sql_store.update(
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
)
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)

View file

@ -480,6 +480,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
If not specified, a default SQLite store will be used.""",
)
conversations_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the conversations API.
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)

View file

@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]:

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.tool_runtime: ToolRuntime,
Api.files: Files,
Api.prompts: Prompts,
Api.conversations: Conversations,
}
if external_apis:

View file

@ -451,6 +451,7 @@ def create_app(
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve:
api = Api(api_str)

View file

@ -15,6 +15,7 @@ import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -73,6 +75,7 @@ class LlamaStack(
RAGToolRuntime,
Files,
Prompts,
Conversations,
):
pass
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
)
impls[Api.prompts] = prompts_impl
conversations_impl = ConversationServiceImpl(
ConversationServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.conversations] = conversations_impl
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@ -342,6 +351,8 @@ class Stack:
if Api.prompts in impls:
await impls[Api.prompts].initialize()
if Api.conversations in impls:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)

View file

@ -484,12 +484,19 @@ class JsonSchemaGenerator:
}
return ret
elif origin_type is Literal:
if len(typing.get_args(typ)) != 1:
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
(literal_value,) = typing.get_args(typ) # unpack value of literal type
schema = self.type_to_schema(type(literal_value))
schema["const"] = literal_value
return schema
literal_args = typing.get_args(typ)
if len(literal_args) == 1:
(literal_value,) = literal_args
schema = self.type_to_schema(type(literal_value))
schema["const"] = literal_value
return schema
elif len(literal_args) > 1:
first_value = literal_args[0]
schema = self.type_to_schema(type(first_value))
schema["enum"] = list(literal_args)
return schema
else:
return {"enum": []}
elif origin_type is type:
(concrete_type,) = typing.get_args(typ) # unpack single tuple element
return {"const": self.type_to_schema(concrete_type, force_expand=True)}

View file

@ -32,7 +32,7 @@ dependencies = [
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.2.23",
"openai>=1.100.0", # for expires_after support
"openai>=1.107", # for expires_after support
"prompt-toolkit",
"python-dotenv",
"python-jose[cryptography]",
@ -49,6 +49,7 @@ dependencies = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations
]
[project.optional-dependencies]

View file

@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
@pytest.mark.integration
class TestOpenAIConversations:
def test_conversation_create(self, openai_client):
conversation = openai_client.conversations.create(
metadata={"topic": "demo"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
)
assert conversation.id.startswith("conv_")
assert conversation.object == "conversation"
assert conversation.metadata["topic"] == "demo"
assert isinstance(conversation.created_at, int)
def test_conversation_retrieve(self, openai_client):
conversation = openai_client.conversations.create(metadata={"topic": "demo"})
retrieved = openai_client.conversations.retrieve(conversation.id)
assert retrieved.id == conversation.id
assert retrieved.object == "conversation"
assert retrieved.metadata["topic"] == "demo"
assert retrieved.created_at == conversation.created_at
def test_conversation_update(self, openai_client):
conversation = openai_client.conversations.create(metadata={"topic": "demo"})
updated = openai_client.conversations.update(conversation.id, metadata={"topic": "project-x"})
assert updated.id == conversation.id
assert updated.metadata["topic"] == "project-x"
assert updated.created_at == conversation.created_at
def test_conversation_delete(self, openai_client):
conversation = openai_client.conversations.create(metadata={"topic": "demo"})
deleted = openai_client.conversations.delete(conversation.id)
assert deleted.id == conversation.id
assert deleted.object == "conversation.deleted"
assert deleted.deleted is True
def test_conversation_items_create(self, openai_client):
conversation = openai_client.conversations.create()
items = openai_client.conversations.items.create(
conversation.id,
items=[
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]},
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "How are you?"}]},
],
)
assert items.object == "list"
assert len(items.data) == 2
assert items.data[0].content[0].text == "Hello!"
assert items.data[1].content[0].text == "How are you?"
assert items.first_id == items.data[0].id
assert items.last_id == items.data[1].id
assert items.has_more is False
def test_conversation_items_list(self, openai_client):
conversation = openai_client.conversations.create()
openai_client.conversations.items.create(
conversation.id,
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
)
items = openai_client.conversations.items.list(conversation.id, limit=10)
assert items.object == "list"
assert len(items.data) >= 1
assert items.data[0].type == "message"
assert items.data[0].role == "user"
assert hasattr(items, "first_id")
assert hasattr(items, "last_id")
assert hasattr(items, "has_more")
def test_conversation_item_retrieve(self, openai_client):
conversation = openai_client.conversations.create()
created_items = openai_client.conversations.items.create(
conversation.id,
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
)
item_id = created_items.data[0].id
item = openai_client.conversations.items.retrieve(item_id, conversation_id=conversation.id)
assert item.id == item_id
assert item.type == "message"
assert item.role == "user"
assert item.content[0].text == "Hello!"
def test_conversation_item_delete(self, openai_client):
conversation = openai_client.conversations.create()
created_items = openai_client.conversations.items.create(
conversation.id,
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
)
item_id = created_items.data[0].id
deleted = openai_client.conversations.items.delete(item_id, conversation_id=conversation.id)
assert deleted.id == item_id
assert deleted.object == "conversation.item.deleted"
assert deleted.deleted is True
def test_full_workflow(self, openai_client):
conversation = openai_client.conversations.create(
metadata={"topic": "workflow-test"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
)
openai_client.conversations.items.create(
conversation.id,
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Follow up"}]}],
)
all_items = openai_client.conversations.items.list(conversation.id)
assert len(all_items.data) >= 2
updated = openai_client.conversations.update(conversation.id, metadata={"topic": "workflow-complete"})
assert updated.metadata["topic"] == "workflow-complete"
openai_client.conversations.delete(conversation.id)

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationCreateRequest,
ConversationItem,
ConversationItemList,
)
def test_conversation_create_request_defaults():
request = ConversationCreateRequest()
assert request.items == []
assert request.metadata == {}
def test_conversation_model_defaults():
conversation = Conversation(
id="conv_123456789",
created_at=1234567890,
metadata=None,
object="conversation",
)
assert conversation.id == "conv_123456789"
assert conversation.object == "conversation"
assert conversation.metadata is None
def test_openai_client_compatibility():
from openai.types.conversations.message import Message
from pydantic import TypeAdapter
openai_message = Message(
id="msg_123",
content=[{"type": "input_text", "text": "Hello"}],
role="user",
status="in_progress",
type="message",
object="message",
)
adapter = TypeAdapter(ConversationItem)
validated_item = adapter.validate_python(openai_message.model_dump())
assert validated_item.id == "msg_123"
assert validated_item.type == "message"
def test_conversation_item_list():
item_list = ConversationItemList(data=[])
assert item_list.object == "list"
assert item_list.data == []
assert item_list.first_id is None
assert item_list.last_id is None
assert item_list.has_more is False

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
import pytest
from openai.types.conversations.conversation import Conversation as OpenAIConversation
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentText,
OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.fixture
async def service():
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig(
run_config=StackRunConfig(
image_name="test",
providers={},
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)),
)
)
service = ConversationServiceImpl(config, {})
await service.initialize()
yield service
async def test_conversation_lifecycle(service):
conversation = await service.create_conversation(metadata={"test": "data"})
assert conversation.id.startswith("conv_")
assert conversation.metadata == {"test": "data"}
retrieved = await service.get_conversation(conversation.id)
assert retrieved.id == conversation.id
deleted = await service.openai_delete_conversation(conversation.id)
assert deleted.id == conversation.id
async def test_conversation_items(service):
conversation = await service.create_conversation()
items = [
OpenAIResponseMessage(
type="message",
role="user",
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
id="msg_test123",
status="completed",
)
]
item_list = await service.create(conversation.id, items)
assert len(item_list.data) == 1
assert item_list.data[0].id.startswith("msg_")
items = await service.list(conversation.id)
assert len(items.data) == 1
async def test_invalid_conversation_id(service):
with pytest.raises(ValueError, match="Expected an ID that begins with 'conv_'"):
await service._get_validated_conversation("invalid_id")
async def test_empty_parameter_validation(service):
with pytest.raises(ValueError, match="Expected a non-empty value"):
await service.retrieve("", "item_123")
async def test_openai_type_compatibility(service):
conversation = await service.create_conversation(metadata={"test": "value"})
conversation_dict = conversation.model_dump()
openai_conversation = OpenAIConversation.model_validate(conversation_dict)
for attr in ["id", "object", "created_at", "metadata"]:
assert getattr(openai_conversation, attr) == getattr(conversation, attr)
items = [
OpenAIResponseMessage(
type="message",
role="user",
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
id="msg_test456",
status="completed",
)
]
item_list = await service.create(conversation.id, items)
for attr in ["object", "data", "first_id", "last_id", "has_more"]:
assert hasattr(item_list, attr)
assert item_list.object == "list"
items = await service.list(conversation.id)
item = await service.retrieve(conversation.id, items.data[0].id)
item_dict = item.model_dump()
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
openai_item_adapter.validate_python(item_dict)

4
uv.lock generated
View file

@ -1773,6 +1773,7 @@ dependencies = [
{ name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" },
{ name = "rich" },
{ name = "sqlalchemy", extra = ["asyncio"] },
{ name = "starlette" },
{ name = "termcolor" },
{ name = "tiktoken" },
@ -1887,7 +1888,7 @@ requires-dist = [
{ name = "jsonschema" },
{ name = "llama-stack-client", specifier = ">=0.2.23" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
{ name = "openai", specifier = ">=1.100.0" },
{ name = "openai", specifier = ">=1.107" },
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
{ name = "pandas", marker = "extra == 'ui'" },
@ -1898,6 +1899,7 @@ requires-dist = [
{ name = "python-jose", extras = ["cryptography"] },
{ name = "python-multipart", specifier = ">=0.0.20" },
{ name = "rich" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
{ name = "starlette" },
{ name = "streamlit", marker = "extra == 'ui'" },
{ name = "streamlit-option-menu", marker = "extra == 'ui'" },