feat: Add OpenAI Conversations API (#3429)

# What does this PR do?

Initial implementation for `Conversations` and `ConversationItems` using
`AuthorizedSqlStore` with endpoints to:
- CREATE
- UPDATE
- GET/RETRIEVE/LIST
- DELETE

Set `level=LLAMA_STACK_API_V1`.

NOTE: This does not currently incorporate changes for Responses; that will
be done in a subsequent PR.

Closes https://github.com/llamastack/llama-stack/issues/3235

## Test Plan
- Unit tests
- Integration tests

Also comparison of [OpenAPI spec for OpenAI
API](https://github.com/openai/openai-openapi/tree/manual_spec)
```bash
oasdiff breaking --fail-on ERR docs/static/llama-stack-spec.yaml https://raw.githubusercontent.com/openai/openai-openapi/refs/heads/manual_spec/openapi.yaml --strip-prefix-base "/v1/openai/v1" \
--match-path '(^/v1/openai/v1/conversations.*|^/conversations.*)'
```

Note: I still have some uncertainty about this check. I borrowed the command from
@cdoern on https://github.com/llamastack/llama-stack/pull/3514, but I need
to spend more time to confirm it is working. At the moment it suggests it
does.

UPDATE on `oasdiff`: I investigated the OpenAI spec further, and it looks
like the spec does not currently list Conversations, so that analysis is
useless. Noting this for future reference.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Arceo 2025-10-03 11:47:18 -04:00 committed by GitHub
parent a09e30bd87
commit a20e8eac8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 5704 additions and 2183 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationCreateRequest,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
ConversationUpdateRequest,
Metadata,
)
# Public names of this package; each one is imported from `.conversations` above.
__all__ = [
    "Conversation",
    "ConversationCreateRequest",
    "ConversationDeletedResource",
    "ConversationItem",
    "ConversationItemCreateRequest",
    "ConversationItemDeletedResource",
    "ConversationItemList",
    "Conversations",
    "ConversationUpdateRequest",
    "Metadata",
]

View file

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Literal, Protocol, runtime_checkable
from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Alias for metadata attached to conversations: string keys mapped to string values.
Metadata = dict[str, str]
@json_schema_type
class Conversation(BaseModel):
    """OpenAI-compatible conversation object."""

    # Identity and creation time are required; everything else has a default.
    id: str = Field(..., description="The unique ID of the conversation.")
    object: Literal["conversation"] = Field(
        description="The object type, which is always conversation.",
        default="conversation",
    )
    created_at: int = Field(
        ...,
        description="The time at which the conversation was created, measured in seconds since the Unix epoch.",
    )

    # Optional payload: free-form metadata and the initial items, both defaulting to None.
    metadata: Metadata | None = Field(
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
        default=None,
    )
    items: list[dict] | None = Field(
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
        default=None,
    )
@json_schema_type
class ConversationMessage(BaseModel):
    """OpenAI-compatible message item for conversations."""

    # Required fields describing the message itself.
    id: str = Field(..., description="unique identifier for this message")
    content: list[dict] = Field(..., description="message content")
    role: str = Field(..., description="message role")
    status: str = Field(..., description="message status")

    # Fixed tags: both are constrained to the single literal value "message".
    type: Literal["message"] = "message"
    object: Literal["message"] = "message"
# Union of every item model a conversation can hold. The `type` field is the
# discriminator used to pick the concrete model.
ConversationItem = Annotated[
    OpenAIResponseMessage
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseOutputMessageFileSearchToolCall
    | OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageMCPCall
    | OpenAIResponseOutputMessageMCPListTools,
    Field(discriminator="type"),
]
# Register the union under the explicit schema name "ConversationItem".
# NOTE(review): presumably so the generated API schema refers to it by name — confirm in schema_utils.
register_schema(ConversationItem, name="ConversationItem")
# Using OpenAI types directly caused issues, but here are some notes for reference.
# Note that ConversationItem is an Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
# ResponseFunctionToolCallItem,
# ResponseFunctionToolCallOutputItem,
# ResponseFileSearchToolCall,
# ResponseFunctionWebSearch,
# ImageGenerationCall,
# ResponseComputerToolCall,
# ResponseComputerToolCallOutputItem,
# ResponseReasoningItem,
# ResponseCodeInterpreterToolCall,
# LocalShellCall,
# LocalShellCallOutput,
# McpListTools,
# McpApprovalRequest,
# McpApprovalResponse,
# McpCall,
# ResponseCustomToolCall,
# ResponseCustomToolCallOutput
# ]
@json_schema_type
class ConversationCreateRequest(BaseModel):
    """Request body for creating a conversation."""

    # Both fields are optional; the length limits are enforced through `max_length`.
    items: list[ConversationItem] | None = Field(
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
        default=[],
        max_length=20,
    )
    metadata: Metadata | None = Field(
        description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
        default={},
        max_length=16,
    )
@json_schema_type
class ConversationUpdateRequest(BaseModel):
    """Request body for updating a conversation."""

    # The only field: metadata is required (no default).
    metadata: Metadata = Field(
        ...,
        description=(
            "Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters."
        ),
    )
@json_schema_type
class ConversationDeletedResource(BaseModel):
    """Response for deleted conversation."""

    # Only `id` must be supplied; `object` and `deleted` default to the values for a successful deletion.
    id: str = Field(..., description="The deleted conversation identifier")
    object: str = Field(description="Object type", default="conversation.deleted")
    deleted: bool = Field(description="Whether the object was deleted", default=True)
@json_schema_type
class ConversationItemCreateRequest(BaseModel):
    """Request body for creating conversation items."""

    # Required list of items, capped at 20 entries through `max_length`.
    items: list[ConversationItem] = Field(
        ...,
        max_length=20,
        description="Items to include in the conversation context. You may add up to 20 items at a time.",
    )
@json_schema_type
class ConversationItemList(BaseModel):
    """List of conversation items with pagination."""

    # The page of items itself; `data` is the only required field.
    object: str = Field(description="Object type", default="list")
    data: list[ConversationItem] = Field(..., description="List of conversation items")

    # Pagination markers describing the page boundaries.
    first_id: str | None = Field(description="The ID of the first item in the list", default=None)
    last_id: str | None = Field(description="The ID of the last item in the list", default=None)
    has_more: bool = Field(description="Whether there are more items available", default=False)
@json_schema_type
class ConversationItemDeletedResource(BaseModel):
    """Response for deleted conversation item."""

    # Only `id` must be supplied; `object` and `deleted` default to the values for a successful deletion.
    id: str = Field(..., description="The deleted item identifier")
    object: str = Field(description="Object type", default="conversation.item.deleted")
    deleted: bool = Field(description="Whether the object was deleted", default=True)
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
    """Protocol for conversation management operations."""

    # ---- Conversation lifecycle -------------------------------------------------

    @webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
    async def create_conversation(
        self,
        items: list[ConversationItem] | None = None,
        metadata: Metadata | None = None,
    ) -> Conversation:
        """Create a conversation.

        :param items: Initial items to include in the conversation context.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The created conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The conversation object.
        """
        ...

    # Updates share the route of `get_conversation` but are sent as POST.
    @webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
    async def update_conversation(
        self,
        conversation_id: str,
        metadata: Metadata,
    ) -> Conversation:
        """Update a conversation's metadata with the given ID.

        :param conversation_id: The conversation identifier.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The updated conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The deleted conversation resource.
        """
        ...

    # ---- Conversation items -----------------------------------------------------

    @webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
    async def add_items(
        self,
        conversation_id: str,
        items: list[ConversationItem],
    ) -> ConversationItemList:
        """Create items in the conversation.

        :param conversation_id: The conversation identifier.
        :param items: Items to include in the conversation context.
        :returns: List of created items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve(
        self,
        conversation_id: str,
        item_id: str,
    ) -> ConversationItem:
        """Retrieve a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The conversation item.
        """
        ...

    # NOTE: once this method is defined, the name `list` in the class body refers to it
    # rather than the builtin, so signatures declared below must not use `list[...]`.
    @webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
    async def list(
        self,
        conversation_id: str,
        after: str | NotGiven = NOT_GIVEN,
        include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
    ) -> ConversationItemList:
        """List items in the conversation.

        :param conversation_id: The conversation identifier.
        :param after: An item ID to list items after, used in pagination.
        :param include: Specify additional output data to include in the response.
        :param limit: A limit on the number of objects to be returned (1-100, default 20).
        :param order: The order to return items in (asc or desc, default desc).
        :returns: List of conversation items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation_item(
        self,
        conversation_id: str,
        item_id: str,
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The deleted item resource.
        """
        ...

View file

@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
tool_groups = "tool_groups" tool_groups = "tool_groups"
files = "files" files = "files"
prompts = "prompts" prompts = "prompts"
conversations = "conversations"
# built-in API # built-in API
inspect = "inspect" inspect = "inspect"

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import secrets
import time
from typing import Any
from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import (
SqliteSqlStoreConfig,
SqlStoreConfig,
sqlstore_impl,
)
logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel):
    """Configuration for the built-in conversation service.

    :param conversations_store: SQL store configuration for conversations (defaults to SQLite)
    :param policy: Access control rules
    """

    # Default store: a SQLite database file named "conversations.db" under DISTRIBS_BASE_DIR.
    conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
        db_path=DISTRIBS_BASE_DIR.joinpath("conversations.db").as_posix(),
    )
    # No access rules are configured by default.
    policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
    """Construct the conversation service and initialize it before returning it.

    :param config: Configuration passed to the service constructor.
    :param deps: Dependencies passed to the service constructor.
    :returns: The initialized ``ConversationServiceImpl``.
    """
    service = ConversationServiceImpl(config, deps)
    # `initialize()` creates the backing tables, so the instance is ready once this returns.
    await service.initialize()
    return service
class ConversationServiceImpl(Conversations):
    """Built-in conversation service implementation using AuthorizedSqlStore.

    Conversations are stored in the ``openai_conversations`` table and their
    items in ``conversation_items``. Every read and write goes through the
    ``AuthorizedSqlStore`` that wraps the configured SQL store with the
    configured access policy.
    """

    def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
        """Build the authorized SQL store from the service configuration.

        :param config: Store configuration and access policy for the service.
        :param deps: Dependency mapping; stored on the instance but not used by this class.
        """
        self.config = config
        self.deps = deps
        self.policy = config.policy

        base_sql_store = sqlstore_impl(config.conversations_store)
        self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)

    async def initialize(self) -> None:
        """Initialize the store and create tables."""
        if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
            db_dir = os.path.dirname(self.config.conversations_store.db_path)
            # A bare filename has an empty dirname, and os.makedirs("") raises
            # FileNotFoundError; in that case there is no directory to create.
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)

        await self.sql_store.create_table(
            "openai_conversations",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "items": ColumnType.JSON,
                "metadata": ColumnType.JSON,
            },
        )

        await self.sql_store.create_table(
            "conversation_items",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "conversation_id": ColumnType.STRING,
                "created_at": ColumnType.INTEGER,
                "item_data": ColumnType.JSON,
            },
        )

    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation.

        :param items: Optional initial items, stored in the ``conversation_items`` table.
        :param metadata: Optional metadata stored with the conversation.
        :returns: The created conversation (the items are not echoed back).
        """
        conversation_id = f"conv_{secrets.token_bytes(24).hex()}"
        created_at = int(time.time())

        await self.sql_store.insert(
            table="openai_conversations",
            data={
                "id": conversation_id,
                "created_at": created_at,
                # Items live in their own table; this column is always written empty.
                "items": [],
                "metadata": metadata,
            },
        )

        if items:
            item_records = [self._build_item_record(conversation_id, item, created_at) for item in items]
            await self.sql_store.insert(table="conversation_items", data=item_records)

        conversation = Conversation(
            id=conversation_id,
            created_at=created_at,
            metadata=metadata,
            object="conversation",
        )

        logger.info(f"Created conversation {conversation_id}")
        return conversation

    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The stored conversation.
        :raises ValueError: If no conversation with this ID is found.
        """
        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})

        if record is None:
            raise ValueError(f"Conversation {conversation_id} not found")

        return Conversation(
            id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
        )

    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation's metadata with the given ID.

        :param conversation_id: The conversation identifier.
        :param metadata: Replacement metadata for the conversation.
        :returns: The conversation as re-read from the store after the update.
        :raises ValueError: If the conversation is not found when it is re-read.
        """
        await self.sql_store.update(
            table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
        )

        return await self.get_conversation(conversation_id)

    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID, together with its items.

        :param conversation_id: The conversation identifier.
        :returns: The deletion marker for the conversation.
        """
        # Remove the items first. Otherwise they would be orphaned and still be
        # readable through `retrieve` / `list`, which only filter on conversation_id.
        await self.sql_store.delete(table="conversation_items", where={"conversation_id": conversation_id})
        await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})

        logger.info(f"Deleted conversation {conversation_id}")
        return ConversationDeletedResource(id=conversation_id)

    def _validate_conversation_id(self, conversation_id: str) -> None:
        """Validate conversation ID format.

        :raises ValueError: If the ID does not start with ``conv_``.
        """
        if not conversation_id.startswith("conv_"):
            raise ValueError(
                f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
            )

    def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
        """Get existing item ID or generate one if missing.

        A generated ID is also written into ``item_dict["id"]`` so that the stored
        payload carries it. Messages get a ``msg_`` prefix, other items ``item_``.
        """
        if item.id is not None:
            return item.id

        prefix = "msg" if item.type == "message" else "item"
        item_id = f"{prefix}_{secrets.token_bytes(24).hex()}"
        item_dict["id"] = item_id
        return item_id

    def _build_item_record(self, conversation_id: str, item: ConversationItem, created_at: int) -> dict[str, Any]:
        """Build the ``conversation_items`` row for one item, assigning an ID if it has none."""
        item_dict = item.model_dump()
        item_id = self._get_or_generate_item_id(item, item_dict)
        return {
            "id": item_id,
            "conversation_id": conversation_id,
            "created_at": created_at,
            "item_data": item_dict,
        }

    async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
        """Validate conversation ID and return the conversation if it exists."""
        self._validate_conversation_id(conversation_id)
        return await self.get_conversation(conversation_id)

    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create (add) items to a conversation.

        An item whose ID already exists in this conversation is overwritten;
        any other item is inserted.

        :param conversation_id: The conversation identifier.
        :param items: Items to store.
        :returns: The stored items, in input order.
        :raises ValueError: If the conversation ID is malformed or the conversation is not found.
        """
        await self._get_validated_conversation(conversation_id)

        created_items = []
        created_at = int(time.time())

        for item in items:
            record = self._build_item_record(conversation_id, item, created_at)
            item_dict = record["item_data"]
            # Scope the lookup and the update to this conversation so an ID that
            # collides with another conversation's item can never overwrite it.
            item_key = {"id": record["id"], "conversation_id": conversation_id}

            # TODO: Add support for upsert in sql_store. Until then, check for the
            # row explicitly instead of treating any insert failure as a conflict,
            # so that real storage errors propagate to the caller.
            existing = await self.sql_store.fetch_one(table="conversation_items", where=item_key)
            if existing is None:
                await self.sql_store.insert(table="conversation_items", data=record)
            else:
                await self.sql_store.update(
                    table="conversation_items",
                    data={"created_at": created_at, "item_data": item_dict},
                    where=item_key,
                )

            created_items.append(item_dict)

        logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")

        # Convert created items (dicts) to proper ConversationItem types
        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]

        return ConversationItemList(
            data=response_items,
            first_id=created_items[0]["id"] if created_items else None,
            last_id=created_items[-1]["id"] if created_items else None,
            has_more=False,
        )

    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The stored item, validated against the ``ConversationItem`` union.
        :raises ValueError: If either ID is empty or the item is not found in the conversation.
        """
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        # Get item from conversation_items table
        record = await self.sql_store.fetch_one(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        if record is None:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        return adapter.validate_python(record["item_data"])

    # NOTE: this method shadows the builtin `list` for the rest of the class body,
    # so signatures defined below it must not use `list[...]` annotations.
    async def list(
        self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN
    ) -> ConversationItemList:
        """List items in the conversation.

        :param conversation_id: The conversation identifier.
        :param after: Item ID to start after (exclusive cursor) in the requested order.
        :param include: Accepted for interface compatibility; currently unused.
        :param limit: Maximum number of items to return; defaults to 20.
        :param order: ``"asc"`` for oldest first; anything else yields newest first.
        :returns: One page of items; ``has_more`` reports whether further items remain.
        :raises ValueError: If ``after`` names an item that is not in the conversation.
        """
        result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
        records = result.data

        # Newest first unless ascending order is explicitly requested. The sort is
        # stable, so items sharing a timestamp keep the order the store returned.
        records.sort(key=lambda record: record["created_at"], reverse=(order != "asc"))

        # Apply the pagination cursor. The isinstance check skips the NOT_GIVEN
        # sentinel as well as None.
        if isinstance(after, str) and after:
            for position, record in enumerate(records):
                if record["id"] == after:
                    records = records[position + 1 :]
                    break
            else:
                raise ValueError(f"Item {after} not found in conversation {conversation_id}")

        actual_limit = limit if isinstance(limit, int) else 20
        # A negative limit would otherwise slice from the end of the list.
        actual_limit = max(actual_limit, 0)

        has_more = len(records) > actual_limit
        records = records[:actual_limit]

        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items = [adapter.validate_python(record["item_data"]) for record in records]

        first_id = response_items[0].id if response_items else None
        last_id = response_items[-1].id if response_items else None

        return ConversationItemList(
            data=response_items,
            first_id=first_id,
            last_id=last_id,
            has_more=has_more,
        )

    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The deletion marker for the item.
        :raises ValueError: If an ID is empty or malformed, or the conversation or item is not found.
        """
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        # Called only for validation; the returned conversation is not needed.
        await self._get_validated_conversation(conversation_id)

        item_key = {"id": item_id, "conversation_id": conversation_id}
        record = await self.sql_store.fetch_one(table="conversation_items", where=item_key)

        if record is None:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        await self.sql_store.delete(table="conversation_items", where=item_key)

        logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
        return ConversationItemDeletedResource(id=item_id)

View file

@ -475,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
If not specified, a default SQLite store will be used.""", If not specified, a default SQLite store will be used.""",
) )
conversations_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the conversations API.
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution # registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list) models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list)

View file

@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts} INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
def stack_apis() -> list[Api]: def stack_apis() -> list[Api]:

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec from llama_stack.apis.datatypes import ExternalApiSpec
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.tool_runtime: ToolRuntime, Api.tool_runtime: ToolRuntime,
Api.files: Files, Api.files: Files,
Api.prompts: Prompts, Api.prompts: Prompts,
Api.conversations: Conversations,
} }
if external_apis: if external_apis:

View file

@ -451,6 +451,7 @@ def create_app(
apis_to_serve.add("inspect") apis_to_serve.add("inspect")
apis_to_serve.add("providers") apis_to_serve.add("providers")
apis_to_serve.add("prompts") apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)

View file

@ -15,6 +15,7 @@ import yaml
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, StackRunConfig from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@ -73,6 +75,7 @@ class LlamaStack(
RAGToolRuntime, RAGToolRuntime,
Files, Files,
Prompts, Prompts,
Conversations,
): ):
pass pass
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
) )
impls[Api.prompts] = prompts_impl impls[Api.prompts] = prompts_impl
conversations_impl = ConversationServiceImpl(
ConversationServiceConfig(run_config=run_config),
deps=impls,
)
impls[Api.conversations] = conversations_impl
class Stack: class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None): def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@ -342,6 +351,8 @@ class Stack:
if Api.prompts in impls: if Api.prompts in impls:
await impls[Api.prompts].initialize() await impls[Api.prompts].initialize()
if Api.conversations in impls:
await impls[Api.conversations].initialize()
await register_resources(self.run_config, impls) await register_resources(self.run_config, impls)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Literal, Protocol from typing import Any, Literal, Protocol
@ -41,9 +41,9 @@ class SqlStore(Protocol):
""" """
pass pass
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
""" """
Insert a row into a table. Insert a row or batch of rows into a table.
""" """
pass pass

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
] ]
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
    """Return a copy of ``item`` carrying the access-control columns.

    ``owner_principal`` and ``access_attributes`` are taken from ``current_user``
    when one is given and are both ``None`` otherwise. ``item`` is not modified.
    """
    if current_user:
        principal, attributes = current_user.principal, current_user.attributes
    else:
        principal = attributes = None

    # Later keys win, so any ownership values already present in `item` are replaced.
    return {**item, "owner_principal": principal, "access_attributes": attributes}
class SqlRecord(ProtectedResource): class SqlRecord(ProtectedResource):
def __init__(self, record_id: str, table_name: str, owner: User): def __init__(self, record_id: str, table_name: str, owner: User):
self.type = f"sql_record::{table_name}" self.type = f"sql_record::{table_name}"
@ -102,18 +114,14 @@ class AuthorizedSqlStore:
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON) await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING) await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
"""Insert a row with automatic access control attribute capture.""" """Insert a row or batch of rows with automatic access control attribute capture."""
enhanced_data = dict(data)
current_user = get_authenticated_user() current_user = get_authenticated_user()
if current_user: enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
enhanced_data["owner_principal"] = current_user.principal if isinstance(data, Mapping):
enhanced_data["access_attributes"] = current_user.attributes enhanced_data = _enhance_item_with_access_control(data, current_user)
else: else:
enhanced_data["owner_principal"] = None enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data) await self.sql_store.insert(table, enhanced_data)
async def fetch_all( async def fetch_all(

View file

@ -3,7 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal
from sqlalchemy import ( from sqlalchemy import (
@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
async with self.async_session() as session: async with self.async_session() as session:
await session.execute(self.metadata.tables[table].insert(), data) await session.execute(self.metadata.tables[table].insert(), data)
await session.commit() await session.commit()

View file

@ -484,12 +484,19 @@ class JsonSchemaGenerator:
} }
return ret return ret
elif origin_type is Literal: elif origin_type is Literal:
if len(typing.get_args(typ)) != 1: literal_args = typing.get_args(typ)
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments") if len(literal_args) == 1:
(literal_value,) = typing.get_args(typ) # unpack value of literal type (literal_value,) = literal_args
schema = self.type_to_schema(type(literal_value)) schema = self.type_to_schema(type(literal_value))
schema["const"] = literal_value schema["const"] = literal_value
return schema return schema
elif len(literal_args) > 1:
first_value = literal_args[0]
schema = self.type_to_schema(type(first_value))
schema["enum"] = list(literal_args)
return schema
else:
return {"enum": []}
elif origin_type is type: elif origin_type is type:
(concrete_type,) = typing.get_args(typ) # unpack single tuple element (concrete_type,) = typing.get_args(typ) # unpack single tuple element
return {"const": self.type_to_schema(concrete_type, force_expand=True)} return {"const": self.type_to_schema(concrete_type, force_expand=True)}

View file

@ -32,7 +32,7 @@ dependencies = [
"jinja2>=3.1.6", "jinja2>=3.1.6",
"jsonschema", "jsonschema",
"llama-stack-client>=0.2.23", "llama-stack-client>=0.2.23",
"openai>=1.100.0", # for expires_after support "openai>=1.107", # for expires_after support
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
"python-jose[cryptography]", "python-jose[cryptography]",
@ -49,6 +49,7 @@ dependencies = [
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store "aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store "asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations
] ]
[project.optional-dependencies] [project.optional-dependencies]

View file

@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
@pytest.mark.integration
class TestOpenAIConversations:
    """Integration tests driving the Conversations endpoints through ``openai_client``."""

    # TODO: Update to compat_client after client-SDK is generated

    @staticmethod
    def _text_message(text):
        """Build a user message item carrying a single ``input_text`` content part."""
        return {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]}

    def test_conversation_create(self, openai_client):
        """Create a conversation with metadata and one seed item, then check its fields."""
        seed_items = [{"type": "message", "role": "user", "content": "Hello!"}]
        created = openai_client.conversations.create(metadata={"topic": "demo"}, items=seed_items)

        assert created.id.startswith("conv_")
        assert created.object == "conversation"
        assert created.metadata["topic"] == "demo"
        assert isinstance(created.created_at, int)

    def test_conversation_retrieve(self, openai_client):
        """Retrieve a freshly created conversation and compare it with the original."""
        original = openai_client.conversations.create(metadata={"topic": "demo"})
        fetched = openai_client.conversations.retrieve(original.id)

        assert fetched.id == original.id
        assert fetched.object == "conversation"
        assert fetched.metadata["topic"] == "demo"
        assert fetched.created_at == original.created_at

    def test_conversation_update(self, openai_client):
        """Update a conversation's metadata and check id and creation time are unchanged."""
        original = openai_client.conversations.create(metadata={"topic": "demo"})
        changed = openai_client.conversations.update(original.id, metadata={"topic": "project-x"})

        assert changed.id == original.id
        assert changed.metadata["topic"] == "project-x"
        assert changed.created_at == original.created_at

    def test_conversation_delete(self, openai_client):
        """Delete a conversation and check the deletion response."""
        original = openai_client.conversations.create(metadata={"topic": "demo"})
        removal = openai_client.conversations.delete(original.id)

        assert removal.id == original.id
        assert removal.object == "conversation.deleted"
        assert removal.deleted is True

    def test_conversation_items_create(self, openai_client):
        """Add two items to an empty conversation and check the returned list."""
        parent = openai_client.conversations.create()
        new_items = [self._text_message("Hello!"), self._text_message("How are you?")]
        page = openai_client.conversations.items.create(parent.id, items=new_items)

        assert page.object == "list"
        assert len(page.data) == 2
        first, second = page.data[0], page.data[1]
        assert first.content[0].text == "Hello!"
        assert second.content[0].text == "How are you?"
        assert page.first_id == first.id
        assert page.last_id == second.id
        assert page.has_more is False

    def test_conversation_items_list(self, openai_client):
        """List items of a conversation holding one message and check the list shape."""
        parent = openai_client.conversations.create()
        openai_client.conversations.items.create(parent.id, items=[self._text_message("Hello!")])

        page = openai_client.conversations.items.list(parent.id, limit=10)

        assert page.object == "list"
        assert len(page.data) >= 1
        head = page.data[0]
        assert head.type == "message"
        assert head.role == "user"
        for pagination_field in ("first_id", "last_id", "has_more"):
            assert hasattr(page, pagination_field)

    def test_conversation_item_retrieve(self, openai_client):
        """Retrieve one item by id and check its type, role and text."""
        parent = openai_client.conversations.create()
        added = openai_client.conversations.items.create(parent.id, items=[self._text_message("Hello!")])
        target_id = added.data[0].id

        fetched = openai_client.conversations.items.retrieve(target_id, conversation_id=parent.id)

        assert fetched.id == target_id
        assert fetched.type == "message"
        assert fetched.role == "user"
        assert fetched.content[0].text == "Hello!"

    def test_conversation_item_delete(self, openai_client):
        """Delete one item by id and check the deletion response."""
        parent = openai_client.conversations.create()
        added = openai_client.conversations.items.create(parent.id, items=[self._text_message("Hello!")])
        target_id = added.data[0].id

        removal = openai_client.conversations.items.delete(target_id, conversation_id=parent.id)

        assert removal.id == target_id
        assert removal.object == "conversation.item.deleted"
        assert removal.deleted is True

    def test_full_workflow(self, openai_client):
        """Run create, add item, list, update and delete against a single conversation."""
        seed_items = [{"type": "message", "role": "user", "content": "Hello!"}]
        parent = openai_client.conversations.create(metadata={"topic": "workflow-test"}, items=seed_items)

        openai_client.conversations.items.create(parent.id, items=[self._text_message("Follow up")])

        listing = openai_client.conversations.items.list(parent.id)
        assert len(listing.data) >= 2

        changed = openai_client.conversations.update(parent.id, metadata={"topic": "workflow-complete"})
        assert changed.metadata["topic"] == "workflow-complete"

        openai_client.conversations.delete(parent.id)

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.conversations.conversations import (
Conversation,
ConversationCreateRequest,
ConversationItem,
ConversationItemList,
)
def test_conversation_create_request_defaults():
    """A request built without arguments has an empty ``items`` list and empty ``metadata``."""
    empty_request = ConversationCreateRequest()

    assert empty_request.items == []
    assert empty_request.metadata == {}
def test_conversation_model_defaults():
    """A ``Conversation`` built with explicit fields reports them back, including ``metadata=None``."""
    field_values = {
        "id": "conv_123456789",
        "created_at": 1234567890,
        "metadata": None,
        "object": "conversation",
    }
    model = Conversation(**field_values)

    assert model.id == "conv_123456789"
    assert model.object == "conversation"
    assert model.metadata is None
def test_openai_client_compatibility():
    """An OpenAI SDK ``Message`` dumped to a dict validates as a ``ConversationItem``."""
    from openai.types.conversations.message import Message
    from pydantic import TypeAdapter

    sdk_message = Message(
        id="msg_123",
        content=[{"type": "input_text", "text": "Hello"}],
        role="user",
        status="in_progress",
        type="message",
        object="message",
    )

    # Round-trip through a plain dict so validation sees only serialized data.
    payload = sdk_message.model_dump()
    parsed = TypeAdapter(ConversationItem).validate_python(payload)

    assert parsed.id == "msg_123"
    assert parsed.type == "message"
def test_conversation_item_list():
    """A ``ConversationItemList`` built from an empty ``data`` list exposes the expected field values."""
    empty_page = ConversationItemList(data=[])

    assert empty_page.object == "list"
    assert empty_page.data == []
    # The cursor fields are asserted to be unset for the empty list.
    assert empty_page.first_id is None
    assert empty_page.last_id is None
    assert empty_page.has_more is False

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from pathlib import Path
import pytest
from openai.types.conversations.conversation import Conversation as OpenAIConversation
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentText,
OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.fixture
async def service():
    """Yield an initialized ``ConversationServiceImpl`` backed by a SQLite file in a temp directory."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        sqlite_file = Path(scratch_dir) / "test_conversations.db"
        store_config = SqliteSqlStoreConfig(db_path=str(sqlite_file))
        impl = ConversationServiceImpl(
            ConversationServiceConfig(conversations_store=store_config, policy=[]),
            {},
        )
        await impl.initialize()
        # Yield inside the ``with`` block so the directory exists while the test runs.
        yield impl
async def test_conversation_lifecycle(service):
    """Create, fetch and delete one conversation through the service."""
    created = await service.create_conversation(metadata={"test": "data"})
    assert created.id.startswith("conv_")
    assert created.metadata == {"test": "data"}

    fetched = await service.get_conversation(created.id)
    assert fetched.id == created.id

    removal = await service.openai_delete_conversation(created.id)
    assert removal.id == created.id
async def test_conversation_items(service):
    """Add one message to a conversation and check it is returned by add and by list."""
    parent = await service.create_conversation()

    text_part = OpenAIResponseInputMessageContentText(type="input_text", text="Hello")
    message = OpenAIResponseMessage(
        type="message",
        role="user",
        content=[text_part],
        id="msg_test123",
        status="completed",
    )

    added = await service.add_items(parent.id, [message])
    assert len(added.data) == 1
    assert added.data[0].id == "msg_test123"

    listed = await service.list(parent.id)
    assert len(listed.data) == 1
async def test_invalid_conversation_id(service):
    """An id without the ``conv_`` prefix makes ``_get_validated_conversation`` raise ``ValueError``."""
    expected_message = "Expected an ID that begins with 'conv_'"
    with pytest.raises(ValueError, match=expected_message):
        await service._get_validated_conversation("invalid_id")
async def test_empty_parameter_validation(service):
    """Calling ``retrieve`` with an empty first argument raises ``ValueError``."""
    expected_message = "Expected a non-empty value"
    with pytest.raises(ValueError, match=expected_message):
        await service.retrieve("", "item_123")
async def test_openai_type_compatibility(service):
    """Service results validate against the OpenAI SDK conversation and item types."""
    created = await service.create_conversation(metadata={"test": "value"})

    # The dumped conversation must load into the SDK model with matching fields.
    sdk_conversation = OpenAIConversation.model_validate(created.model_dump())
    for field_name in ("id", "object", "created_at", "metadata"):
        assert getattr(sdk_conversation, field_name) == getattr(created, field_name)

    text_part = OpenAIResponseInputMessageContentText(type="input_text", text="Hello")
    message = OpenAIResponseMessage(
        type="message",
        role="user",
        content=[text_part],
        id="msg_test456",
        status="completed",
    )
    added = await service.add_items(created.id, [message])

    for field_name in ("object", "data", "first_id", "last_id", "has_more"):
        assert hasattr(added, field_name)
    assert added.object == "list"

    listed = await service.list(created.id)
    fetched = await service.retrieve(created.id, listed.data[0].id)

    # No explicit assert: ``validate_python`` raises if the dumped item does not fit the SDK type.
    sdk_item_adapter = TypeAdapter(OpenAIConversationItem)
    sdk_item_adapter.validate_python(fetched.model_dump())
async def test_policy_configuration():
    """A policy passed through the config is exposed unchanged as ``service.policy``."""
    from llama_stack.core.access_control.datatypes import Action, Scope
    from llama_stack.core.datatypes import AccessRule

    with tempfile.TemporaryDirectory() as scratch_dir:
        sqlite_file = Path(scratch_dir) / "test_conversations_policy.db"

        forbidden_scope = Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*")
        rules = [AccessRule(forbid=forbidden_scope)]

        store_config = SqliteSqlStoreConfig(db_path=str(sqlite_file))
        impl = ConversationServiceImpl(
            ConversationServiceConfig(conversations_store=store_config, policy=rules),
            {},
        )
        await impl.initialize()

        assert impl.policy == rules
        assert len(impl.policy) == 1
        assert impl.policy[0].forbid is not None

View file

@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
assert {r["id"] for r in rows_after} == {1, 3} assert {r["id"] for r in rows_after} == {1, 3}
async def test_batch_insert():
    """Insert a list of rows in one ``insert`` call and read them back ordered by ``id``."""
    with TemporaryDirectory() as scratch_dir:
        sqlite_path = scratch_dir + "/test.db"
        sql_store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=sqlite_path))

        schema = {
            "id": ColumnType.INTEGER,
            "name": ColumnType.STRING,
            "value": ColumnType.INTEGER,
        }
        await sql_store.create_table("batch_test", schema)

        rows = [
            {"id": 1, "name": "first", "value": 10},
            {"id": 2, "name": "second", "value": 20},
            {"id": 3, "name": "third", "value": 30},
        ]
        # A list is passed where single-row callers pass one mapping.
        await sql_store.insert("batch_test", rows)

        fetched = await sql_store.fetch_all("batch_test", order_by=[("id", "asc")])
        assert fetched.data == rows
async def test_where_operator_edge_cases(): async def test_where_operator_edge_cases():
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
db_path = tmp_dir + "/test.db" db_path = tmp_dir + "/test.db"

4
uv.lock generated
View file

@ -1773,6 +1773,7 @@ dependencies = [
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "python-multipart" }, { name = "python-multipart" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extra = ["asyncio"] },
{ name = "starlette" }, { name = "starlette" },
{ name = "termcolor" }, { name = "termcolor" },
{ name = "tiktoken" }, { name = "tiktoken" },
@ -1887,7 +1888,7 @@ requires-dist = [
{ name = "jsonschema" }, { name = "jsonschema" },
{ name = "llama-stack-client", specifier = ">=0.2.23" }, { name = "llama-stack-client", specifier = ">=0.2.23" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
{ name = "openai", specifier = ">=1.100.0" }, { name = "openai", specifier = ">=1.107" },
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
{ name = "pandas", marker = "extra == 'ui'" }, { name = "pandas", marker = "extra == 'ui'" },
@ -1898,6 +1899,7 @@ requires-dist = [
{ name = "python-jose", extras = ["cryptography"] }, { name = "python-jose", extras = ["cryptography"] },
{ name = "python-multipart", specifier = ">=0.0.20" }, { name = "python-multipart", specifier = ">=0.0.20" },
{ name = "rich" }, { name = "rich" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
{ name = "starlette" }, { name = "starlette" },
{ name = "streamlit", marker = "extra == 'ui'" }, { name = "streamlit", marker = "extra == 'ui'" },
{ name = "streamlit-option-menu", marker = "extra == 'ui'" }, { name = "streamlit-option-menu", marker = "extra == 'ui'" },