feat: Add OpenAI Conversations API

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-12 21:13:51 -04:00
parent 0e13512dd7
commit a74a7cc873
18 changed files with 3280 additions and 1088 deletions

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
__all__ = [
"Conversation",
"ConversationCreateRequest",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"ConversationUpdateRequest",
"Metadata",
]

View file

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal, Protocol, runtime_checkable
from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
Metadata = dict[str, str]
@json_schema_type
class Conversation(BaseModel):
"""OpenAI-compatible conversation object."""
id: str = Field(..., description="The unique ID of the conversation.")
object: Literal["conversation"] = Field(
default="conversation", description="The object type, which is always conversation."
)
created_at: int = Field(
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
)
metadata: Metadata | None = Field(
default=None,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
)
items: list[dict] | None = Field(
default=None,
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
)
@json_schema_type
class ConversationMessage(BaseModel):
"""OpenAI-compatible message item for conversations."""
id: str = Field(..., description="unique identifier for this message")
content: list[dict] = Field(..., description="message content")
role: str = Field(..., description="message role")
status: str = Field(..., description="message status")
type: Literal["message"] = "message"
object: Literal["message"] = "message"
ConversationItem = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is a Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
"""Request body for creating a conversation."""
items: list[ConversationItem] | None = Field(
default=[],
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
metadata: Metadata | None = Field(
default={},
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
max_length=16,
)
@json_schema_type
class ConversationUpdateRequest(BaseModel):
"""Request body for updating a conversation."""
metadata: Metadata = Field(
...,
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
)
@json_schema_type
class ConversationDeletedResource(BaseModel):
"""Response for deleted conversation."""
id: str = Field(..., description="The deleted conversation identifier")
object: str = Field(default="conversation.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
"""Request body for creating conversation items."""
items: list[ConversationItem] = Field(
...,
description="Items to include in the conversation context. You may add up to 20 items at a time.",
max_length=20,
)
@json_schema_type
class ConversationItemList(BaseModel):
"""List of conversation items with pagination."""
object: str = Field(default="list", description="Object type")
data: list[ConversationItem] = Field(..., description="List of conversation items")
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
has_more: bool = Field(default=False, description="Whether there are more items available")
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
"""Response for deleted conversation item."""
id: str = Field(..., description="The deleted item identifier")
object: str = Field(default="conversation.item.deleted", description="Object type")
deleted: bool = Field(default=True, description="Whether the object was deleted")
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The updated conversation object.
"""
...
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
:returns: List of created items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The conversation item.
"""
...
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
async def list(
self,
conversation_id: str,
after: str | NotGiven = NOT_GIVEN,
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
) -> ConversationItemList:
"""List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
:param include: Specify additional output data to include in the response.
:param limit: A limit on the number of objects to be returned (1-100, default 20).
:param order: The order to return items in (asc or desc, default desc).
:returns: List of conversation items.
"""
...
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
:returns: The deleted item resource.
"""
...

View file

@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
tool_groups = "tool_groups"
files = "files"
prompts = "prompts"
conversations = "conversations"
# built-in API
inspect = "inspect"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import secrets
import time
from typing import Any
from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param run_config: Stack run configuration containing distribution info
"""
run_config: StackRunConfig
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
"""Get the conversation service implementation."""
impl = ConversationServiceImpl(config, deps)
await impl.initialize()
return impl
class ConversationServiceImpl(Conversations):
"""Built-in conversation service implementation using AuthorizedSqlStore."""
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy: list[AccessRule] = []
conversations_store_config = config.run_config.conversations_store
if conversations_store_config is None:
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
)
else:
sql_store_config = conversations_store_config
base_sql_store = sqlstore_impl(sql_store_config)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)
await self.sql_store.create_table(
"openai_conversations",
{
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
"created_at": ColumnType.INTEGER,
"items": ColumnType.JSON,
"metadata": ColumnType.JSON,
},
)
async def create_conversation(
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
) -> Conversation:
"""Create a conversation."""
random_bytes = secrets.token_bytes(24)
conversation_id = f"conv_{random_bytes.hex()}"
created_at = int(time.time())
items_json = []
for item in items or []:
item_dict = item.model_dump() if hasattr(item, "model_dump") else item
items_json.append(item_dict)
record_data = {
"id": conversation_id,
"created_at": created_at,
"items": items_json,
"metadata": metadata,
}
await self.sql_store.insert(
table="openai_conversations",
data=record_data,
)
conversation = Conversation(
id=conversation_id,
created_at=created_at,
metadata=metadata,
object="conversation",
)
logger.info(f"Created conversation {conversation_id}")
return conversation
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
if record is None:
raise ValueError(f"Conversation {conversation_id} not found")
return Conversation(
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID"""
await self.sql_store.update(
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
)
return await self.get_conversation(conversation_id)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID."""
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
logger.info(f"Deleted conversation {conversation_id}")
return ConversationDeletedResource(id=conversation_id)
def _validate_conversation_id(self, conversation_id: str) -> None:
"""Validate conversation ID format."""
if not conversation_id.startswith("conv_"):
raise ValueError(
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
)
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
"""Validate conversation ID and return the conversation if it exists."""
self._validate_conversation_id(conversation_id)
return await self.get_conversation(conversation_id)
async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation."""
await self._get_validated_conversation(conversation_id)
created_items = []
for item in items:
# Generate item ID based on item type
random_bytes = secrets.token_bytes(24)
item_type = getattr(item, "type", None)
if item_type == "message":
item_id = f"msg_{random_bytes.hex()}"
else:
item_id = f"item_{random_bytes.hex()}"
# Create a copy of the item with the generated ID and completed status
item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
item_dict["id"] = item_id
if "status" not in item_dict:
item_dict["status"] = "completed"
created_items.append(item_dict)
# Get existing items from database
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
existing_items = record.get("items", []) if record else []
updated_items = existing_items + created_items
await self.sql_store.update(
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
)
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
# Convert created items (dicts) to proper ConversationItem types
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
return ConversationItemList(
data=response_items,
first_id=created_items[0]["id"] if created_items else None,
last_id=created_items[-1]["id"] if created_items else None,
has_more=False,
)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
items = record.get("items", []) if record else []
for item in items:
if isinstance(item, dict) and item.get("id") == item_id:
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
return adapter.validate_python(item)
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
"""List items in the conversation."""
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
items = record.get("items", []) if record else []
if order != NOT_GIVEN and order == "asc":
items = items
else:
items = list(reversed(items))
actual_limit = 20
if limit != NOT_GIVEN and isinstance(limit, int):
actual_limit = limit
items = items[:actual_limit]
# Items from database are stored as dicts, convert them to ConversationItem
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
response_items: list[ConversationItem] = [
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
]
# Get first and last IDs safely
first_id = None
last_id = None
if items:
first_item = items[0]
last_item = items[-1]
first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
return ConversationItemList(
data=response_items,
first_id=first_id,
last_id=last_id,
has_more=False,
)
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item."""
if not conversation_id:
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
if not item_id:
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
_ = await self._get_validated_conversation(conversation_id) # executes validation
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
items = record.get("items", []) if record else []
updated_items = []
item_found = False
for item in items:
current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
if current_item_id != item_id:
updated_items.append(item)
else:
item_found = True
if not item_found:
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
await self.sql_store.update(
table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
)
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
return ConversationItemDeletedResource(id=item_id)

View file

@ -480,6 +480,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
If not specified, a default SQLite store will be used.""",
)
conversations_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the conversations API.
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)

View file

@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]:

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.tool_runtime: ToolRuntime,
Api.files: Files,
Api.prompts: Prompts,
Api.conversations: Conversations,
}
if external_apis:

View file

@ -451,6 +451,7 @@ def create_app(
apis_to_serve.add("inspect")
apis_to_serve.add("providers")
apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve:
api = Api(api_str)

View file

@ -15,6 +15,7 @@ import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -73,6 +75,7 @@ class LlamaStack(
RAGToolRuntime,
Files,
Prompts,
Conversations,
):
pass
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
)
impls[Api.prompts] = prompts_impl
conversations_impl = ConversationServiceImpl(
ConversationServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.conversations] = conversations_impl
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@ -342,6 +351,8 @@ class Stack:
if Api.prompts in impls:
await impls[Api.prompts].initialize()
if Api.conversations in impls:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls)

View file

@ -484,12 +484,19 @@ class JsonSchemaGenerator:
}
return ret
elif origin_type is Literal:
if len(typing.get_args(typ)) != 1:
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
(literal_value,) = typing.get_args(typ) # unpack value of literal type
schema = self.type_to_schema(type(literal_value))
schema["const"] = literal_value
return schema
literal_args = typing.get_args(typ)
if len(literal_args) == 1:
(literal_value,) = literal_args
schema = self.type_to_schema(type(literal_value))
schema["const"] = literal_value
return schema
elif len(literal_args) > 1:
first_value = literal_args[0]
schema = self.type_to_schema(type(first_value))
schema["enum"] = list(literal_args)
return schema
else:
return {"enum": []}
elif origin_type is type:
(concrete_type,) = typing.get_args(typ) # unpack single tuple element
return {"const": self.type_to_schema(concrete_type, force_expand=True)}