mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
feat: Add OpenAI Conversations API (#3429)
# What does this PR do? Initial implementation for `Conversations` and `ConversationItems` using `AuthorizedSqlStore` with endpoints to: - CREATE - UPDATE - GET/RETRIEVE/LIST - DELETE Set `level=LLAMA_STACK_API_V1`. NOTE: This does not currently incorporate changes for Responses, that'll be done in a subsequent PR. Closes https://github.com/llamastack/llama-stack/issues/3235 ## Test Plan - Unit tests - Integration tests Also comparison of [OpenAPI spec for OpenAI API](https://github.com/openai/openai-openapi/tree/manual_spec) ```bash oasdiff breaking --fail-on ERR docs/static/llama-stack-spec.yaml https://raw.githubusercontent.com/openai/openai-openapi/refs/heads/manual_spec/openapi.yaml --strip-prefix-base "/v1/openai/v1" \ --match-path '(^/v1/openai/v1/conversations.*|^/conversations.*)' ``` Note I still have some uncertainty about this, I borrowed this info from @cdoern on https://github.com/llamastack/llama-stack/pull/3514 but need to spend more time to confirm it's working, at the moment it suggests it does. UPDATE on `oasdiff`, I investigated the OpenAI spec further and it looks like currently the spec does not list Conversations, so that analysis is useless. Noting for future reference. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
a09e30bd87
commit
a20e8eac8c
24 changed files with 5704 additions and 2183 deletions
1902
docs/static/llama-stack-spec.html
vendored
1902
docs/static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
1519
docs/static/llama-stack-spec.yaml
vendored
1519
docs/static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
1902
docs/static/stainless-llama-stack-spec.html
vendored
1902
docs/static/stainless-llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
1519
docs/static/stainless-llama-stack-spec.yaml
vendored
1519
docs/static/stainless-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
31
llama_stack/apis/conversations/__init__.py
Normal file
31
llama_stack/apis/conversations/__init__.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .conversations import (
|
||||||
|
Conversation,
|
||||||
|
ConversationCreateRequest,
|
||||||
|
ConversationDeletedResource,
|
||||||
|
ConversationItem,
|
||||||
|
ConversationItemCreateRequest,
|
||||||
|
ConversationItemDeletedResource,
|
||||||
|
ConversationItemList,
|
||||||
|
Conversations,
|
||||||
|
ConversationUpdateRequest,
|
||||||
|
Metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Conversation",
|
||||||
|
"ConversationCreateRequest",
|
||||||
|
"ConversationDeletedResource",
|
||||||
|
"ConversationItem",
|
||||||
|
"ConversationItemCreateRequest",
|
||||||
|
"ConversationItemDeletedResource",
|
||||||
|
"ConversationItemList",
|
||||||
|
"Conversations",
|
||||||
|
"ConversationUpdateRequest",
|
||||||
|
"Metadata",
|
||||||
|
]
|
260
llama_stack/apis/conversations/conversations.py
Normal file
260
llama_stack/apis/conversations/conversations.py
Normal file
|
@ -0,0 +1,260 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from openai import NOT_GIVEN
|
||||||
|
from openai._types import NotGiven
|
||||||
|
from openai.types.responses.response_includable import ResponseIncludable
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseMessage,
|
||||||
|
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||||
|
OpenAIResponseOutputMessageFunctionToolCall,
|
||||||
|
OpenAIResponseOutputMessageMCPCall,
|
||||||
|
OpenAIResponseOutputMessageMCPListTools,
|
||||||
|
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||||
|
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
Metadata = dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class Conversation(BaseModel):
|
||||||
|
"""OpenAI-compatible conversation object."""
|
||||||
|
|
||||||
|
id: str = Field(..., description="The unique ID of the conversation.")
|
||||||
|
object: Literal["conversation"] = Field(
|
||||||
|
default="conversation", description="The object type, which is always conversation."
|
||||||
|
)
|
||||||
|
created_at: int = Field(
|
||||||
|
..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
|
||||||
|
)
|
||||||
|
metadata: Metadata | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
|
||||||
|
)
|
||||||
|
items: list[dict] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationMessage(BaseModel):
|
||||||
|
"""OpenAI-compatible message item for conversations."""
|
||||||
|
|
||||||
|
id: str = Field(..., description="unique identifier for this message")
|
||||||
|
content: list[dict] = Field(..., description="message content")
|
||||||
|
role: str = Field(..., description="message role")
|
||||||
|
status: str = Field(..., description="message status")
|
||||||
|
type: Literal["message"] = "message"
|
||||||
|
object: Literal["message"] = "message"
|
||||||
|
|
||||||
|
|
||||||
|
ConversationItem = Annotated[
|
||||||
|
OpenAIResponseMessage
|
||||||
|
| OpenAIResponseOutputMessageFunctionToolCall
|
||||||
|
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||||
|
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||||
|
| OpenAIResponseOutputMessageMCPCall
|
||||||
|
| OpenAIResponseOutputMessageMCPListTools,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(ConversationItem, name="ConversationItem")
|
||||||
|
|
||||||
|
# Using OpenAI types directly caused issues but some notes for reference:
|
||||||
|
# Note that ConversationItem is a Annotated Union of the types below:
|
||||||
|
# from openai.types.responses import *
|
||||||
|
# from openai.types.responses.response_item import *
|
||||||
|
# from openai.types.conversations import ConversationItem
|
||||||
|
# f = [
|
||||||
|
# ResponseFunctionToolCallItem,
|
||||||
|
# ResponseFunctionToolCallOutputItem,
|
||||||
|
# ResponseFileSearchToolCall,
|
||||||
|
# ResponseFunctionWebSearch,
|
||||||
|
# ImageGenerationCall,
|
||||||
|
# ResponseComputerToolCall,
|
||||||
|
# ResponseComputerToolCallOutputItem,
|
||||||
|
# ResponseReasoningItem,
|
||||||
|
# ResponseCodeInterpreterToolCall,
|
||||||
|
# LocalShellCall,
|
||||||
|
# LocalShellCallOutput,
|
||||||
|
# McpListTools,
|
||||||
|
# McpApprovalRequest,
|
||||||
|
# McpApprovalResponse,
|
||||||
|
# McpCall,
|
||||||
|
# ResponseCustomToolCall,
|
||||||
|
# ResponseCustomToolCallOutput
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationCreateRequest(BaseModel):
|
||||||
|
"""Request body for creating a conversation."""
|
||||||
|
|
||||||
|
items: list[ConversationItem] | None = Field(
|
||||||
|
default=[],
|
||||||
|
description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
|
||||||
|
max_length=20,
|
||||||
|
)
|
||||||
|
metadata: Metadata | None = Field(
|
||||||
|
default={},
|
||||||
|
description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
|
||||||
|
max_length=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationUpdateRequest(BaseModel):
|
||||||
|
"""Request body for updating a conversation."""
|
||||||
|
|
||||||
|
metadata: Metadata = Field(
|
||||||
|
...,
|
||||||
|
description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationDeletedResource(BaseModel):
|
||||||
|
"""Response for deleted conversation."""
|
||||||
|
|
||||||
|
id: str = Field(..., description="The deleted conversation identifier")
|
||||||
|
object: str = Field(default="conversation.deleted", description="Object type")
|
||||||
|
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationItemCreateRequest(BaseModel):
|
||||||
|
"""Request body for creating conversation items."""
|
||||||
|
|
||||||
|
items: list[ConversationItem] = Field(
|
||||||
|
...,
|
||||||
|
description="Items to include in the conversation context. You may add up to 20 items at a time.",
|
||||||
|
max_length=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationItemList(BaseModel):
|
||||||
|
"""List of conversation items with pagination."""
|
||||||
|
|
||||||
|
object: str = Field(default="list", description="Object type")
|
||||||
|
data: list[ConversationItem] = Field(..., description="List of conversation items")
|
||||||
|
first_id: str | None = Field(default=None, description="The ID of the first item in the list")
|
||||||
|
last_id: str | None = Field(default=None, description="The ID of the last item in the list")
|
||||||
|
has_more: bool = Field(default=False, description="Whether there are more items available")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ConversationItemDeletedResource(BaseModel):
|
||||||
|
"""Response for deleted conversation item."""
|
||||||
|
|
||||||
|
id: str = Field(..., description="The deleted item identifier")
|
||||||
|
object: str = Field(default="conversation.item.deleted", description="Object type")
|
||||||
|
deleted: bool = Field(default=True, description="Whether the object was deleted")
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
@trace_protocol
|
||||||
|
class Conversations(Protocol):
|
||||||
|
"""Protocol for conversation management operations."""
|
||||||
|
|
||||||
|
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
|
||||||
|
async def create_conversation(
|
||||||
|
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||||
|
) -> Conversation:
|
||||||
|
"""Create a conversation.
|
||||||
|
|
||||||
|
:param items: Initial items to include in the conversation context.
|
||||||
|
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||||
|
:returns: The created conversation object.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||||
|
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||||
|
"""Get a conversation with the given ID.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:returns: The conversation object.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
|
||||||
|
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||||
|
"""Update a conversation's metadata with the given ID.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:param metadata: Set of key-value pairs that can be attached to an object.
|
||||||
|
:returns: The updated conversation object.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||||
|
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||||
|
"""Delete a conversation with the given ID.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:returns: The deleted conversation resource.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
|
||||||
|
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||||
|
"""Create items in the conversation.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:param items: Items to include in the conversation context.
|
||||||
|
:returns: List of created items.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||||
|
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||||
|
"""Retrieve a conversation item.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:param item_id: The item identifier.
|
||||||
|
:returns: The conversation item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
|
||||||
|
async def list(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
after: str | NotGiven = NOT_GIVEN,
|
||||||
|
include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
|
||||||
|
limit: int | NotGiven = NOT_GIVEN,
|
||||||
|
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
|
||||||
|
) -> ConversationItemList:
|
||||||
|
"""List items in the conversation.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:param after: An item ID to list items after, used in pagination.
|
||||||
|
:param include: Specify additional output data to include in the response.
|
||||||
|
:param limit: A limit on the number of objects to be returned (1-100, default 20).
|
||||||
|
:param order: The order to return items in (asc or desc, default desc).
|
||||||
|
:returns: List of conversation items.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||||
|
async def openai_delete_conversation_item(
|
||||||
|
self, conversation_id: str, item_id: str
|
||||||
|
) -> ConversationItemDeletedResource:
|
||||||
|
"""Delete a conversation item.
|
||||||
|
|
||||||
|
:param conversation_id: The conversation identifier.
|
||||||
|
:param item_id: The item identifier.
|
||||||
|
:returns: The deleted item resource.
|
||||||
|
"""
|
||||||
|
...
|
|
@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
||||||
tool_groups = "tool_groups"
|
tool_groups = "tool_groups"
|
||||||
files = "files"
|
files = "files"
|
||||||
prompts = "prompts"
|
prompts = "prompts"
|
||||||
|
conversations = "conversations"
|
||||||
|
|
||||||
# built-in API
|
# built-in API
|
||||||
inspect = "inspect"
|
inspect = "inspect"
|
||||||
|
|
5
llama_stack/core/conversations/__init__.py
Normal file
5
llama_stack/core/conversations/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
306
llama_stack/core/conversations/conversations.py
Normal file
306
llama_stack/core/conversations/conversations.py
Normal file
|
@ -0,0 +1,306 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from openai import NOT_GIVEN
|
||||||
|
from pydantic import BaseModel, TypeAdapter
|
||||||
|
|
||||||
|
from llama_stack.apis.conversations.conversations import (
|
||||||
|
Conversation,
|
||||||
|
ConversationDeletedResource,
|
||||||
|
ConversationItem,
|
||||||
|
ConversationItemDeletedResource,
|
||||||
|
ConversationItemList,
|
||||||
|
Conversations,
|
||||||
|
Metadata,
|
||||||
|
)
|
||||||
|
from llama_stack.core.datatypes import AccessRule
|
||||||
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||||
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||||
|
SqliteSqlStoreConfig,
|
||||||
|
SqlStoreConfig,
|
||||||
|
sqlstore_impl,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="openai::conversations")
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationServiceConfig(BaseModel):
|
||||||
|
"""Configuration for the built-in conversation service.
|
||||||
|
|
||||||
|
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
|
||||||
|
:param policy: Access control rules
|
||||||
|
"""
|
||||||
|
|
||||||
|
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
|
||||||
|
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
|
||||||
|
)
|
||||||
|
policy: list[AccessRule] = []
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||||
|
"""Get the conversation service implementation."""
|
||||||
|
impl = ConversationServiceImpl(config, deps)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationServiceImpl(Conversations):
|
||||||
|
"""Built-in conversation service implementation using AuthorizedSqlStore."""
|
||||||
|
|
||||||
|
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||||
|
self.config = config
|
||||||
|
self.deps = deps
|
||||||
|
self.policy = config.policy
|
||||||
|
|
||||||
|
base_sql_store = sqlstore_impl(config.conversations_store)
|
||||||
|
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize the store and create tables."""
|
||||||
|
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
|
||||||
|
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
|
||||||
|
|
||||||
|
await self.sql_store.create_table(
|
||||||
|
"openai_conversations",
|
||||||
|
{
|
||||||
|
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||||
|
"created_at": ColumnType.INTEGER,
|
||||||
|
"items": ColumnType.JSON,
|
||||||
|
"metadata": ColumnType.JSON,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.sql_store.create_table(
|
||||||
|
"conversation_items",
|
||||||
|
{
|
||||||
|
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||||
|
"conversation_id": ColumnType.STRING,
|
||||||
|
"created_at": ColumnType.INTEGER,
|
||||||
|
"item_data": ColumnType.JSON,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_conversation(
|
||||||
|
self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
|
||||||
|
) -> Conversation:
|
||||||
|
"""Create a conversation."""
|
||||||
|
random_bytes = secrets.token_bytes(24)
|
||||||
|
conversation_id = f"conv_{random_bytes.hex()}"
|
||||||
|
created_at = int(time.time())
|
||||||
|
|
||||||
|
record_data = {
|
||||||
|
"id": conversation_id,
|
||||||
|
"created_at": created_at,
|
||||||
|
"items": [],
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.sql_store.insert(
|
||||||
|
table="openai_conversations",
|
||||||
|
data=record_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
if items:
|
||||||
|
item_records = []
|
||||||
|
for item in items:
|
||||||
|
item_dict = item.model_dump()
|
||||||
|
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||||
|
|
||||||
|
item_record = {
|
||||||
|
"id": item_id,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"created_at": created_at,
|
||||||
|
"item_data": item_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
item_records.append(item_record)
|
||||||
|
|
||||||
|
await self.sql_store.insert(table="conversation_items", data=item_records)
|
||||||
|
|
||||||
|
conversation = Conversation(
|
||||||
|
id=conversation_id,
|
||||||
|
created_at=created_at,
|
||||||
|
metadata=metadata,
|
||||||
|
object="conversation",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Created conversation {conversation_id}")
|
||||||
|
return conversation
|
||||||
|
|
||||||
|
async def get_conversation(self, conversation_id: str) -> Conversation:
|
||||||
|
"""Get a conversation with the given ID."""
|
||||||
|
record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
|
||||||
|
|
||||||
|
if record is None:
|
||||||
|
raise ValueError(f"Conversation {conversation_id} not found")
|
||||||
|
|
||||||
|
return Conversation(
|
||||||
|
id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
|
||||||
|
"""Update a conversation's metadata with the given ID"""
|
||||||
|
await self.sql_store.update(
|
||||||
|
table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.get_conversation(conversation_id)
|
||||||
|
|
||||||
|
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
|
||||||
|
"""Delete a conversation with the given ID."""
|
||||||
|
await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})
|
||||||
|
|
||||||
|
logger.info(f"Deleted conversation {conversation_id}")
|
||||||
|
return ConversationDeletedResource(id=conversation_id)
|
||||||
|
|
||||||
|
def _validate_conversation_id(self, conversation_id: str) -> None:
|
||||||
|
"""Validate conversation ID format."""
|
||||||
|
if not conversation_id.startswith("conv_"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
|
||||||
|
"""Get existing item ID or generate one if missing."""
|
||||||
|
if item.id is None:
|
||||||
|
random_bytes = secrets.token_bytes(24)
|
||||||
|
if item.type == "message":
|
||||||
|
item_id = f"msg_{random_bytes.hex()}"
|
||||||
|
else:
|
||||||
|
item_id = f"item_{random_bytes.hex()}"
|
||||||
|
item_dict["id"] = item_id
|
||||||
|
return item_id
|
||||||
|
return item.id
|
||||||
|
|
||||||
|
async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
|
||||||
|
"""Validate conversation ID and return the conversation if it exists."""
|
||||||
|
self._validate_conversation_id(conversation_id)
|
||||||
|
return await self.get_conversation(conversation_id)
|
||||||
|
|
||||||
|
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
|
||||||
|
"""Create (add) items to a conversation."""
|
||||||
|
await self._get_validated_conversation(conversation_id)
|
||||||
|
|
||||||
|
created_items = []
|
||||||
|
created_at = int(time.time())
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
item_dict = item.model_dump()
|
||||||
|
item_id = self._get_or_generate_item_id(item, item_dict)
|
||||||
|
|
||||||
|
item_record = {
|
||||||
|
"id": item_id,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"created_at": created_at,
|
||||||
|
"item_data": item_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
|
||||||
|
try:
|
||||||
|
await self.sql_store.insert(table="conversation_items", data=item_record)
|
||||||
|
except Exception:
|
||||||
|
# If insert fails due to ID conflict, update existing record
|
||||||
|
await self.sql_store.update(
|
||||||
|
table="conversation_items",
|
||||||
|
data={"created_at": created_at, "item_data": item_dict},
|
||||||
|
where={"id": item_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
created_items.append(item_dict)
|
||||||
|
|
||||||
|
logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")
|
||||||
|
|
||||||
|
# Convert created items (dicts) to proper ConversationItem types
|
||||||
|
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||||
|
response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]
|
||||||
|
|
||||||
|
return ConversationItemList(
|
||||||
|
data=response_items,
|
||||||
|
first_id=created_items[0]["id"] if created_items else None,
|
||||||
|
last_id=created_items[-1]["id"] if created_items else None,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
|
||||||
|
"""Retrieve a conversation item."""
|
||||||
|
if not conversation_id:
|
||||||
|
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||||
|
if not item_id:
|
||||||
|
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||||
|
|
||||||
|
# Get item from conversation_items table
|
||||||
|
record = await self.sql_store.fetch_one(
|
||||||
|
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if record is None:
|
||||||
|
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||||
|
|
||||||
|
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||||
|
return adapter.validate_python(record["item_data"])
|
||||||
|
|
||||||
|
async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
|
||||||
|
"""List items in the conversation."""
|
||||||
|
result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
|
||||||
|
records = result.data
|
||||||
|
|
||||||
|
if order != NOT_GIVEN and order == "asc":
|
||||||
|
records.sort(key=lambda x: x["created_at"])
|
||||||
|
else:
|
||||||
|
records.sort(key=lambda x: x["created_at"], reverse=True)
|
||||||
|
|
||||||
|
actual_limit = 20
|
||||||
|
if limit != NOT_GIVEN and isinstance(limit, int):
|
||||||
|
actual_limit = limit
|
||||||
|
|
||||||
|
records = records[:actual_limit]
|
||||||
|
items = [record["item_data"] for record in records]
|
||||||
|
|
||||||
|
adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
|
||||||
|
response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]
|
||||||
|
|
||||||
|
first_id = response_items[0].id if response_items else None
|
||||||
|
last_id = response_items[-1].id if response_items else None
|
||||||
|
|
||||||
|
return ConversationItemList(
|
||||||
|
data=response_items,
|
||||||
|
first_id=first_id,
|
||||||
|
last_id=last_id,
|
||||||
|
has_more=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_delete_conversation_item(
|
||||||
|
self, conversation_id: str, item_id: str
|
||||||
|
) -> ConversationItemDeletedResource:
|
||||||
|
"""Delete a conversation item."""
|
||||||
|
if not conversation_id:
|
||||||
|
raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
|
||||||
|
if not item_id:
|
||||||
|
raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")
|
||||||
|
|
||||||
|
_ = await self._get_validated_conversation(conversation_id)
|
||||||
|
|
||||||
|
record = await self.sql_store.fetch_one(
|
||||||
|
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if record is None:
|
||||||
|
raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")
|
||||||
|
|
||||||
|
await self.sql_store.delete(
|
||||||
|
table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
|
||||||
|
return ConversationItemDeletedResource(id=item_id)
|
|
@ -475,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
|
||||||
If not specified, a default SQLite store will be used.""",
|
If not specified, a default SQLite store will be used.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
conversations_store: SqlStoreConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="""
|
||||||
|
Configuration for the persistence store used by the conversations API.
|
||||||
|
If not specified, a default SQLite store will be used.""",
|
||||||
|
)
|
||||||
|
|
||||||
# registry of "resources" in the distribution
|
# registry of "resources" in the distribution
|
||||||
models: list[ModelInput] = Field(default_factory=list)
|
models: list[ModelInput] = Field(default_factory=list)
|
||||||
shields: list[ShieldInput] = Field(default_factory=list)
|
shields: list[ShieldInput] = Field(default_factory=list)
|
||||||
|
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
|
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
|
||||||
|
|
||||||
|
|
||||||
def stack_apis() -> list[Api]:
|
def stack_apis() -> list[Api]:
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.batches import Batches
|
from llama_stack.apis.batches import Batches
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
|
from llama_stack.apis.conversations import Conversations
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||||
|
@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
||||||
Api.tool_runtime: ToolRuntime,
|
Api.tool_runtime: ToolRuntime,
|
||||||
Api.files: Files,
|
Api.files: Files,
|
||||||
Api.prompts: Prompts,
|
Api.prompts: Prompts,
|
||||||
|
Api.conversations: Conversations,
|
||||||
}
|
}
|
||||||
|
|
||||||
if external_apis:
|
if external_apis:
|
||||||
|
|
|
@ -451,6 +451,7 @@ def create_app(
|
||||||
apis_to_serve.add("inspect")
|
apis_to_serve.add("inspect")
|
||||||
apis_to_serve.add("providers")
|
apis_to_serve.add("providers")
|
||||||
apis_to_serve.add("prompts")
|
apis_to_serve.add("prompts")
|
||||||
|
apis_to_serve.add("conversations")
|
||||||
for api_str in apis_to_serve:
|
for api_str in apis_to_serve:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ import yaml
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
|
from llama_stack.apis.conversations import Conversations
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
from llama_stack.apis.eval import Eval
|
from llama_stack.apis.eval import Eval
|
||||||
|
@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_dbs import VectorDBs
|
from llama_stack.apis.vector_dbs import VectorDBs
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||||
from llama_stack.core.distribution import get_provider_registry
|
from llama_stack.core.distribution import get_provider_registry
|
||||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||||
|
@ -73,6 +75,7 @@ class LlamaStack(
|
||||||
RAGToolRuntime,
|
RAGToolRuntime,
|
||||||
Files,
|
Files,
|
||||||
Prompts,
|
Prompts,
|
||||||
|
Conversations,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
||||||
)
|
)
|
||||||
impls[Api.prompts] = prompts_impl
|
impls[Api.prompts] = prompts_impl
|
||||||
|
|
||||||
|
conversations_impl = ConversationServiceImpl(
|
||||||
|
ConversationServiceConfig(run_config=run_config),
|
||||||
|
deps=impls,
|
||||||
|
)
|
||||||
|
impls[Api.conversations] = conversations_impl
|
||||||
|
|
||||||
|
|
||||||
class Stack:
|
class Stack:
|
||||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||||
|
@ -342,6 +351,8 @@ class Stack:
|
||||||
|
|
||||||
if Api.prompts in impls:
|
if Api.prompts in impls:
|
||||||
await impls[Api.prompts].initialize()
|
await impls[Api.prompts].initialize()
|
||||||
|
if Api.conversations in impls:
|
||||||
|
await impls[Api.conversations].initialize()
|
||||||
|
|
||||||
await register_resources(self.run_config, impls)
|
await register_resources(self.run_config, impls)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Literal, Protocol
|
from typing import Any, Literal, Protocol
|
||||||
|
|
||||||
|
@ -41,9 +41,9 @@ class SqlStore(Protocol):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||||
"""
|
"""
|
||||||
Insert a row into a table.
|
Insert a row or batch of rows into a table.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
|
||||||
|
@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
|
||||||
|
"""Add access control attributes to a data item."""
|
||||||
|
enhanced = dict(item)
|
||||||
|
if current_user:
|
||||||
|
enhanced["owner_principal"] = current_user.principal
|
||||||
|
enhanced["access_attributes"] = current_user.attributes
|
||||||
|
else:
|
||||||
|
enhanced["owner_principal"] = None
|
||||||
|
enhanced["access_attributes"] = None
|
||||||
|
return enhanced
|
||||||
|
|
||||||
|
|
||||||
class SqlRecord(ProtectedResource):
|
class SqlRecord(ProtectedResource):
|
||||||
def __init__(self, record_id: str, table_name: str, owner: User):
|
def __init__(self, record_id: str, table_name: str, owner: User):
|
||||||
self.type = f"sql_record::{table_name}"
|
self.type = f"sql_record::{table_name}"
|
||||||
|
@ -102,18 +114,14 @@ class AuthorizedSqlStore:
|
||||||
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
|
||||||
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||||
"""Insert a row with automatic access control attribute capture."""
|
"""Insert a row or batch of rows with automatic access control attribute capture."""
|
||||||
enhanced_data = dict(data)
|
|
||||||
|
|
||||||
current_user = get_authenticated_user()
|
current_user = get_authenticated_user()
|
||||||
if current_user:
|
enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
|
||||||
enhanced_data["owner_principal"] = current_user.principal
|
if isinstance(data, Mapping):
|
||||||
enhanced_data["access_attributes"] = current_user.attributes
|
enhanced_data = _enhance_item_with_access_control(data, current_user)
|
||||||
else:
|
else:
|
||||||
enhanced_data["owner_principal"] = None
|
enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
|
||||||
enhanced_data["access_attributes"] = None
|
|
||||||
|
|
||||||
await self.sql_store.insert(table, enhanced_data)
|
await self.sql_store.insert(table, enhanced_data)
|
||||||
|
|
||||||
async def fetch_all(
|
async def fetch_all(
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||||
|
|
||||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
|
||||||
async with self.async_session() as session:
|
async with self.async_session() as session:
|
||||||
await session.execute(self.metadata.tables[table].insert(), data)
|
await session.execute(self.metadata.tables[table].insert(), data)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
|
@ -484,12 +484,19 @@ class JsonSchemaGenerator:
|
||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
elif origin_type is Literal:
|
elif origin_type is Literal:
|
||||||
if len(typing.get_args(typ)) != 1:
|
literal_args = typing.get_args(typ)
|
||||||
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
|
if len(literal_args) == 1:
|
||||||
(literal_value,) = typing.get_args(typ) # unpack value of literal type
|
(literal_value,) = literal_args
|
||||||
schema = self.type_to_schema(type(literal_value))
|
schema = self.type_to_schema(type(literal_value))
|
||||||
schema["const"] = literal_value
|
schema["const"] = literal_value
|
||||||
return schema
|
return schema
|
||||||
|
elif len(literal_args) > 1:
|
||||||
|
first_value = literal_args[0]
|
||||||
|
schema = self.type_to_schema(type(first_value))
|
||||||
|
schema["enum"] = list(literal_args)
|
||||||
|
return schema
|
||||||
|
else:
|
||||||
|
return {"enum": []}
|
||||||
elif origin_type is type:
|
elif origin_type is type:
|
||||||
(concrete_type,) = typing.get_args(typ) # unpack single tuple element
|
(concrete_type,) = typing.get_args(typ) # unpack single tuple element
|
||||||
return {"const": self.type_to_schema(concrete_type, force_expand=True)}
|
return {"const": self.type_to_schema(concrete_type, force_expand=True)}
|
||||||
|
|
|
@ -32,7 +32,7 @@ dependencies = [
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"llama-stack-client>=0.2.23",
|
"llama-stack-client>=0.2.23",
|
||||||
"openai>=1.100.0", # for expires_after support
|
"openai>=1.107", # for expires_after support
|
||||||
"prompt-toolkit",
|
"prompt-toolkit",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"python-jose[cryptography]",
|
"python-jose[cryptography]",
|
||||||
|
@ -49,6 +49,7 @@ dependencies = [
|
||||||
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
||||||
"aiosqlite>=0.21.0", # server - for metadata store
|
"aiosqlite>=0.21.0", # server - for metadata store
|
||||||
"asyncpg", # for metadata store
|
"asyncpg", # for metadata store
|
||||||
|
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|
135
tests/integration/conversations/test_openai_conversations.py
Normal file
135
tests/integration/conversations/test_openai_conversations.py
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestOpenAIConversations:
|
||||||
|
# TODO: Update to compat_client after client-SDK is generated
|
||||||
|
def test_conversation_create(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create(
|
||||||
|
metadata={"topic": "demo"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conversation.id.startswith("conv_")
|
||||||
|
assert conversation.object == "conversation"
|
||||||
|
assert conversation.metadata["topic"] == "demo"
|
||||||
|
assert isinstance(conversation.created_at, int)
|
||||||
|
|
||||||
|
def test_conversation_retrieve(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create(metadata={"topic": "demo"})
|
||||||
|
|
||||||
|
retrieved = openai_client.conversations.retrieve(conversation.id)
|
||||||
|
|
||||||
|
assert retrieved.id == conversation.id
|
||||||
|
assert retrieved.object == "conversation"
|
||||||
|
assert retrieved.metadata["topic"] == "demo"
|
||||||
|
assert retrieved.created_at == conversation.created_at
|
||||||
|
|
||||||
|
def test_conversation_update(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create(metadata={"topic": "demo"})
|
||||||
|
|
||||||
|
updated = openai_client.conversations.update(conversation.id, metadata={"topic": "project-x"})
|
||||||
|
|
||||||
|
assert updated.id == conversation.id
|
||||||
|
assert updated.metadata["topic"] == "project-x"
|
||||||
|
assert updated.created_at == conversation.created_at
|
||||||
|
|
||||||
|
def test_conversation_delete(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create(metadata={"topic": "demo"})
|
||||||
|
|
||||||
|
deleted = openai_client.conversations.delete(conversation.id)
|
||||||
|
|
||||||
|
assert deleted.id == conversation.id
|
||||||
|
assert deleted.object == "conversation.deleted"
|
||||||
|
assert deleted.deleted is True
|
||||||
|
|
||||||
|
def test_conversation_items_create(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create()
|
||||||
|
|
||||||
|
items = openai_client.conversations.items.create(
|
||||||
|
conversation.id,
|
||||||
|
items=[
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]},
|
||||||
|
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "How are you?"}]},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert items.object == "list"
|
||||||
|
assert len(items.data) == 2
|
||||||
|
assert items.data[0].content[0].text == "Hello!"
|
||||||
|
assert items.data[1].content[0].text == "How are you?"
|
||||||
|
assert items.first_id == items.data[0].id
|
||||||
|
assert items.last_id == items.data[1].id
|
||||||
|
assert items.has_more is False
|
||||||
|
|
||||||
|
def test_conversation_items_list(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create()
|
||||||
|
|
||||||
|
openai_client.conversations.items.create(
|
||||||
|
conversation.id,
|
||||||
|
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
|
||||||
|
)
|
||||||
|
|
||||||
|
items = openai_client.conversations.items.list(conversation.id, limit=10)
|
||||||
|
|
||||||
|
assert items.object == "list"
|
||||||
|
assert len(items.data) >= 1
|
||||||
|
assert items.data[0].type == "message"
|
||||||
|
assert items.data[0].role == "user"
|
||||||
|
assert hasattr(items, "first_id")
|
||||||
|
assert hasattr(items, "last_id")
|
||||||
|
assert hasattr(items, "has_more")
|
||||||
|
|
||||||
|
def test_conversation_item_retrieve(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create()
|
||||||
|
|
||||||
|
created_items = openai_client.conversations.items.create(
|
||||||
|
conversation.id,
|
||||||
|
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
|
||||||
|
)
|
||||||
|
|
||||||
|
item_id = created_items.data[0].id
|
||||||
|
item = openai_client.conversations.items.retrieve(item_id, conversation_id=conversation.id)
|
||||||
|
|
||||||
|
assert item.id == item_id
|
||||||
|
assert item.type == "message"
|
||||||
|
assert item.role == "user"
|
||||||
|
assert item.content[0].text == "Hello!"
|
||||||
|
|
||||||
|
def test_conversation_item_delete(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create()
|
||||||
|
|
||||||
|
created_items = openai_client.conversations.items.create(
|
||||||
|
conversation.id,
|
||||||
|
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
|
||||||
|
)
|
||||||
|
|
||||||
|
item_id = created_items.data[0].id
|
||||||
|
deleted = openai_client.conversations.items.delete(item_id, conversation_id=conversation.id)
|
||||||
|
|
||||||
|
assert deleted.id == item_id
|
||||||
|
assert deleted.object == "conversation.item.deleted"
|
||||||
|
assert deleted.deleted is True
|
||||||
|
|
||||||
|
def test_full_workflow(self, openai_client):
|
||||||
|
conversation = openai_client.conversations.create(
|
||||||
|
metadata={"topic": "workflow-test"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
openai_client.conversations.items.create(
|
||||||
|
conversation.id,
|
||||||
|
items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Follow up"}]}],
|
||||||
|
)
|
||||||
|
|
||||||
|
all_items = openai_client.conversations.items.list(conversation.id)
|
||||||
|
assert len(all_items.data) >= 2
|
||||||
|
|
||||||
|
updated = openai_client.conversations.update(conversation.id, metadata={"topic": "workflow-complete"})
|
||||||
|
assert updated.metadata["topic"] == "workflow-complete"
|
||||||
|
|
||||||
|
openai_client.conversations.delete(conversation.id)
|
60
tests/unit/conversations/test_api_models.py
Normal file
60
tests/unit/conversations/test_api_models.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack.apis.conversations.conversations import (
|
||||||
|
Conversation,
|
||||||
|
ConversationCreateRequest,
|
||||||
|
ConversationItem,
|
||||||
|
ConversationItemList,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_create_request_defaults():
|
||||||
|
request = ConversationCreateRequest()
|
||||||
|
assert request.items == []
|
||||||
|
assert request.metadata == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_model_defaults():
|
||||||
|
conversation = Conversation(
|
||||||
|
id="conv_123456789",
|
||||||
|
created_at=1234567890,
|
||||||
|
metadata=None,
|
||||||
|
object="conversation",
|
||||||
|
)
|
||||||
|
assert conversation.id == "conv_123456789"
|
||||||
|
assert conversation.object == "conversation"
|
||||||
|
assert conversation.metadata is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_client_compatibility():
|
||||||
|
from openai.types.conversations.message import Message
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
openai_message = Message(
|
||||||
|
id="msg_123",
|
||||||
|
content=[{"type": "input_text", "text": "Hello"}],
|
||||||
|
role="user",
|
||||||
|
status="in_progress",
|
||||||
|
type="message",
|
||||||
|
object="message",
|
||||||
|
)
|
||||||
|
|
||||||
|
adapter = TypeAdapter(ConversationItem)
|
||||||
|
validated_item = adapter.validate_python(openai_message.model_dump())
|
||||||
|
|
||||||
|
assert validated_item.id == "msg_123"
|
||||||
|
assert validated_item.type == "message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_item_list():
|
||||||
|
item_list = ConversationItemList(data=[])
|
||||||
|
assert item_list.object == "list"
|
||||||
|
assert item_list.data == []
|
||||||
|
assert item_list.first_id is None
|
||||||
|
assert item_list.last_id is None
|
||||||
|
assert item_list.has_more is False
|
132
tests/unit/conversations/test_conversations.py
Normal file
132
tests/unit/conversations/test_conversations.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai.types.conversations.conversation import Conversation as OpenAIConversation
|
||||||
|
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
OpenAIResponseInputMessageContentText,
|
||||||
|
OpenAIResponseMessage,
|
||||||
|
)
|
||||||
|
from llama_stack.core.conversations.conversations import (
|
||||||
|
ConversationServiceConfig,
|
||||||
|
ConversationServiceImpl,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def service():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test_conversations.db"
|
||||||
|
|
||||||
|
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
|
||||||
|
service = ConversationServiceImpl(config, {})
|
||||||
|
await service.initialize()
|
||||||
|
yield service
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_lifecycle(service):
|
||||||
|
conversation = await service.create_conversation(metadata={"test": "data"})
|
||||||
|
|
||||||
|
assert conversation.id.startswith("conv_")
|
||||||
|
assert conversation.metadata == {"test": "data"}
|
||||||
|
|
||||||
|
retrieved = await service.get_conversation(conversation.id)
|
||||||
|
assert retrieved.id == conversation.id
|
||||||
|
|
||||||
|
deleted = await service.openai_delete_conversation(conversation.id)
|
||||||
|
assert deleted.id == conversation.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_items(service):
|
||||||
|
conversation = await service.create_conversation()
|
||||||
|
|
||||||
|
items = [
|
||||||
|
OpenAIResponseMessage(
|
||||||
|
type="message",
|
||||||
|
role="user",
|
||||||
|
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
|
||||||
|
id="msg_test123",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
item_list = await service.add_items(conversation.id, items)
|
||||||
|
|
||||||
|
assert len(item_list.data) == 1
|
||||||
|
assert item_list.data[0].id == "msg_test123"
|
||||||
|
|
||||||
|
items = await service.list(conversation.id)
|
||||||
|
assert len(items.data) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invalid_conversation_id(service):
|
||||||
|
with pytest.raises(ValueError, match="Expected an ID that begins with 'conv_'"):
|
||||||
|
await service._get_validated_conversation("invalid_id")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_empty_parameter_validation(service):
|
||||||
|
with pytest.raises(ValueError, match="Expected a non-empty value"):
|
||||||
|
await service.retrieve("", "item_123")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_type_compatibility(service):
|
||||||
|
conversation = await service.create_conversation(metadata={"test": "value"})
|
||||||
|
|
||||||
|
conversation_dict = conversation.model_dump()
|
||||||
|
openai_conversation = OpenAIConversation.model_validate(conversation_dict)
|
||||||
|
|
||||||
|
for attr in ["id", "object", "created_at", "metadata"]:
|
||||||
|
assert getattr(openai_conversation, attr) == getattr(conversation, attr)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
OpenAIResponseMessage(
|
||||||
|
type="message",
|
||||||
|
role="user",
|
||||||
|
content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
|
||||||
|
id="msg_test456",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
item_list = await service.add_items(conversation.id, items)
|
||||||
|
|
||||||
|
for attr in ["object", "data", "first_id", "last_id", "has_more"]:
|
||||||
|
assert hasattr(item_list, attr)
|
||||||
|
assert item_list.object == "list"
|
||||||
|
|
||||||
|
items = await service.list(conversation.id)
|
||||||
|
item = await service.retrieve(conversation.id, items.data[0].id)
|
||||||
|
item_dict = item.model_dump()
|
||||||
|
|
||||||
|
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
|
||||||
|
openai_item_adapter.validate_python(item_dict)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_policy_configuration():
|
||||||
|
from llama_stack.core.access_control.datatypes import Action, Scope
|
||||||
|
from llama_stack.core.datatypes import AccessRule
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test_conversations_policy.db"
|
||||||
|
|
||||||
|
restrictive_policy = [
|
||||||
|
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
|
||||||
|
]
|
||||||
|
|
||||||
|
config = ConversationServiceConfig(
|
||||||
|
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
|
||||||
|
)
|
||||||
|
service = ConversationServiceImpl(config, {})
|
||||||
|
await service.initialize()
|
||||||
|
|
||||||
|
assert service.policy == restrictive_policy
|
||||||
|
assert len(service.policy) == 1
|
||||||
|
assert service.policy[0].forbid is not None
|
|
@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
|
||||||
assert {r["id"] for r in rows_after} == {1, 3}
|
assert {r["id"] for r in rows_after} == {1, 3}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_batch_insert():
|
||||||
|
with TemporaryDirectory() as tmp_dir:
|
||||||
|
db_path = tmp_dir + "/test.db"
|
||||||
|
store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
|
||||||
|
|
||||||
|
await store.create_table(
|
||||||
|
"batch_test",
|
||||||
|
{
|
||||||
|
"id": ColumnType.INTEGER,
|
||||||
|
"name": ColumnType.STRING,
|
||||||
|
"value": ColumnType.INTEGER,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_data = [
|
||||||
|
{"id": 1, "name": "first", "value": 10},
|
||||||
|
{"id": 2, "name": "second", "value": 20},
|
||||||
|
{"id": 3, "name": "third", "value": 30},
|
||||||
|
]
|
||||||
|
|
||||||
|
await store.insert("batch_test", batch_data)
|
||||||
|
|
||||||
|
result = await store.fetch_all("batch_test", order_by=[("id", "asc")])
|
||||||
|
assert result.data == batch_data
|
||||||
|
|
||||||
|
|
||||||
async def test_where_operator_edge_cases():
|
async def test_where_operator_edge_cases():
|
||||||
with TemporaryDirectory() as tmp_dir:
|
with TemporaryDirectory() as tmp_dir:
|
||||||
db_path = tmp_dir + "/test.db"
|
db_path = tmp_dir + "/test.db"
|
||||||
|
|
4
uv.lock
generated
4
uv.lock
generated
|
@ -1773,6 +1773,7 @@ dependencies = [
|
||||||
{ name = "python-jose", extra = ["cryptography"] },
|
{ name = "python-jose", extra = ["cryptography"] },
|
||||||
{ name = "python-multipart" },
|
{ name = "python-multipart" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
|
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
{ name = "termcolor" },
|
{ name = "termcolor" },
|
||||||
{ name = "tiktoken" },
|
{ name = "tiktoken" },
|
||||||
|
@ -1887,7 +1888,7 @@ requires-dist = [
|
||||||
{ name = "jsonschema" },
|
{ name = "jsonschema" },
|
||||||
{ name = "llama-stack-client", specifier = ">=0.2.23" },
|
{ name = "llama-stack-client", specifier = ">=0.2.23" },
|
||||||
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
|
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
|
||||||
{ name = "openai", specifier = ">=1.100.0" },
|
{ name = "openai", specifier = ">=1.107" },
|
||||||
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
|
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
|
||||||
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
{ name = "opentelemetry-sdk", specifier = ">=1.30.0" },
|
||||||
{ name = "pandas", marker = "extra == 'ui'" },
|
{ name = "pandas", marker = "extra == 'ui'" },
|
||||||
|
@ -1898,6 +1899,7 @@ requires-dist = [
|
||||||
{ name = "python-jose", extras = ["cryptography"] },
|
{ name = "python-jose", extras = ["cryptography"] },
|
||||||
{ name = "python-multipart", specifier = ">=0.0.20" },
|
{ name = "python-multipart", specifier = ">=0.0.20" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
|
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
{ name = "streamlit", marker = "extra == 'ui'" },
|
{ name = "streamlit", marker = "extra == 'ui'" },
|
||||||
{ name = "streamlit-option-menu", marker = "extra == 'ui'" },
|
{ name = "streamlit-option-menu", marker = "extra == 'ui'" },
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue