Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-03 19:57:35 +00:00

feat: Add OpenAI Conversations API

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

Parent: 0e13512dd7
Commit: a74a7cc873

18 changed files with 3280 additions and 1088 deletions
docs/static/llama-stack-spec.html (vendored, 1902 lines changed)
File diff suppressed because it is too large
docs/static/llama-stack-spec.yaml (vendored, 1519 lines changed)
File diff suppressed because it is too large
llama_stack/apis/conversations/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .conversations import (
    Conversation,
    ConversationCreateRequest,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemCreateRequest,
    ConversationItemDeletedResource,
    ConversationItemList,
    Conversations,
    ConversationUpdateRequest,
    Metadata,
)

__all__ = [
    "Conversation",
    "ConversationCreateRequest",
    "ConversationDeletedResource",
    "ConversationItem",
    "ConversationItemCreateRequest",
    "ConversationItemDeletedResource",
    "ConversationItemList",
    "Conversations",
    "ConversationUpdateRequest",
    "Metadata",
]
llama_stack/apis/conversations/conversations.py (new file, 260 lines)
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Literal, Protocol, runtime_checkable

from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageFileSearchToolCall,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

Metadata = dict[str, str]


@json_schema_type
class Conversation(BaseModel):
    """OpenAI-compatible conversation object."""

    id: str = Field(..., description="The unique ID of the conversation.")
    object: Literal["conversation"] = Field(
        default="conversation", description="The object type, which is always conversation."
    )
    created_at: int = Field(
        ..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
    )
    metadata: Metadata | None = Field(
        default=None,
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
    )
    items: list[dict] | None = Field(
        default=None,
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
    )


@json_schema_type
class ConversationMessage(BaseModel):
    """OpenAI-compatible message item for conversations."""

    id: str = Field(..., description="Unique identifier for this message")
    content: list[dict] = Field(..., description="Message content")
    role: str = Field(..., description="Message role")
    status: str = Field(..., description="Message status")
    type: Literal["message"] = "message"
    object: Literal["message"] = "message"


ConversationItem = Annotated[
    OpenAIResponseMessage
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseOutputMessageFileSearchToolCall
    | OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageMCPCall
    | OpenAIResponseOutputMessageMCPListTools,
    Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")

# Using the OpenAI types directly caused issues, but some notes for reference:
# Note that ConversationItem is an Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
#     ResponseFunctionToolCallItem,
#     ResponseFunctionToolCallOutputItem,
#     ResponseFileSearchToolCall,
#     ResponseFunctionWebSearch,
#     ImageGenerationCall,
#     ResponseComputerToolCall,
#     ResponseComputerToolCallOutputItem,
#     ResponseReasoningItem,
#     ResponseCodeInterpreterToolCall,
#     LocalShellCall,
#     LocalShellCallOutput,
#     McpListTools,
#     McpApprovalRequest,
#     McpApprovalResponse,
#     McpCall,
#     ResponseCustomToolCall,
#     ResponseCustomToolCallOutput
# ]


@json_schema_type
class ConversationCreateRequest(BaseModel):
    """Request body for creating a conversation."""

    items: list[ConversationItem] | None = Field(
        default=[],
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
        max_length=20,
    )
    metadata: Metadata | None = Field(
        default={},
        description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information.",
        max_length=16,
    )


@json_schema_type
class ConversationUpdateRequest(BaseModel):
    """Request body for updating a conversation."""

    metadata: Metadata = Field(
        ...,
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
    )


@json_schema_type
class ConversationDeletedResource(BaseModel):
    """Response for a deleted conversation."""

    id: str = Field(..., description="The deleted conversation identifier")
    object: str = Field(default="conversation.deleted", description="Object type")
    deleted: bool = Field(default=True, description="Whether the object was deleted")


@json_schema_type
class ConversationItemCreateRequest(BaseModel):
    """Request body for creating conversation items."""

    items: list[ConversationItem] = Field(
        ...,
        description="Items to include in the conversation context. You may add up to 20 items at a time.",
        max_length=20,
    )


@json_schema_type
class ConversationItemList(BaseModel):
    """List of conversation items with pagination."""

    object: str = Field(default="list", description="Object type")
    data: list[ConversationItem] = Field(..., description="List of conversation items")
    first_id: str | None = Field(default=None, description="The ID of the first item in the list")
    last_id: str | None = Field(default=None, description="The ID of the last item in the list")
    has_more: bool = Field(default=False, description="Whether there are more items available")


@json_schema_type
class ConversationItemDeletedResource(BaseModel):
    """Response for a deleted conversation item."""

    id: str = Field(..., description="The deleted item identifier")
    object: str = Field(default="conversation.item.deleted", description="Object type")
    deleted: bool = Field(default=True, description="Whether the object was deleted")


@runtime_checkable
@trace_protocol
class Conversations(Protocol):
    """Protocol for conversation management operations."""

    @webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation.

        :param items: Initial items to include in the conversation context.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The created conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation's metadata with the given ID.

        :param conversation_id: The conversation identifier.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The updated conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The deleted conversation resource.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
    async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create items in the conversation.

        :param conversation_id: The conversation identifier.
        :param items: Items to include in the conversation context.
        :returns: List of created items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The conversation item.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
    async def list(
        self,
        conversation_id: str,
        after: str | NotGiven = NOT_GIVEN,
        include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
    ) -> ConversationItemList:
        """List items in the conversation.

        :param conversation_id: The conversation identifier.
        :param after: An item ID to list items after, used in pagination.
        :param include: Specify additional output data to include in the response.
        :param limit: A limit on the number of objects to be returned (1-100, default 20).
        :param order: The order to return items in (asc or desc, default desc).
        :returns: List of conversation items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The deleted item resource.
        """
        ...
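For orientation, the routes declared above correspond one-to-one with the OpenAI Conversations client surface, so the protocol can be exercised with the official openai client pointed at a running stack. A minimal usage sketch, not part of this commit; the base URL, port, and api_key value are assumptions:

# Usage sketch only. Assumes a locally running Llama Stack server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# POST /conversations
conversation = client.conversations.create(
    metadata={"topic": "demo"},
    items=[{"type": "message", "role": "user", "content": "Hello!"}],
)

# GET /conversations/{conversation_id}
retrieved = client.conversations.retrieve(conversation.id)
assert retrieved.id.startswith("conv_")

# POST /conversations/{conversation_id} updates metadata
client.conversations.update(conversation.id, metadata={"topic": "project-x"})

# POST /conversations/{conversation_id}/items, then GET the list back
client.conversations.items.create(
    conversation.id,
    items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Follow up"}]}],
)
print(client.conversations.items.list(conversation.id, limit=10).data)

# DELETE /conversations/{conversation_id}
client.conversations.delete(conversation.id)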
@@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
     tool_groups = "tool_groups"
     files = "files"
     prompts = "prompts"
+    conversations = "conversations"
 
     # built-in API
     inspect = "inspect"
llama_stack/core/conversations/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/core/conversations/conversations.py (new file, 290 lines)
@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import secrets
import time
from typing import Any

from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemDeletedResource,
    ConversationItemList,
    Conversations,
    Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl

logger = get_logger(name=__name__, category="openai::conversations")


class ConversationServiceConfig(BaseModel):
    """Configuration for the built-in conversation service.

    :param run_config: Stack run configuration containing distribution info
    """

    run_config: StackRunConfig


async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
    """Get the conversation service implementation."""
    impl = ConversationServiceImpl(config, deps)
    await impl.initialize()
    return impl


class ConversationServiceImpl(Conversations):
    """Built-in conversation service implementation using AuthorizedSqlStore."""

    def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
        self.config = config
        self.deps = deps
        self.policy: list[AccessRule] = []

        conversations_store_config = config.run_config.conversations_store
        if conversations_store_config is None:
            sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
                db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
            )
        else:
            sql_store_config = conversations_store_config

        base_sql_store = sqlstore_impl(sql_store_config)
        self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)

    async def initialize(self) -> None:
        """Initialize the store and create tables."""
        if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
            os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)

        await self.sql_store.create_table(
            "openai_conversations",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "items": ColumnType.JSON,
                "metadata": ColumnType.JSON,
            },
        )

    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation."""
        random_bytes = secrets.token_bytes(24)
        conversation_id = f"conv_{random_bytes.hex()}"
        created_at = int(time.time())

        items_json = []
        for item in items or []:
            item_dict = item.model_dump() if hasattr(item, "model_dump") else item
            items_json.append(item_dict)

        record_data = {
            "id": conversation_id,
            "created_at": created_at,
            "items": items_json,
            "metadata": metadata,
        }

        await self.sql_store.insert(
            table="openai_conversations",
            data=record_data,
        )

        conversation = Conversation(
            id=conversation_id,
            created_at=created_at,
            metadata=metadata,
            object="conversation",
        )

        logger.info(f"Created conversation {conversation_id}")
        return conversation

    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID."""
        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})

        if record is None:
            raise ValueError(f"Conversation {conversation_id} not found")

        return Conversation(
            id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
        )

    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation's metadata with the given ID."""
        await self.sql_store.update(
            table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
        )

        return await self.get_conversation(conversation_id)

    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID."""
        await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})

        logger.info(f"Deleted conversation {conversation_id}")
        return ConversationDeletedResource(id=conversation_id)

    def _validate_conversation_id(self, conversation_id: str) -> None:
        """Validate the conversation ID format."""
        if not conversation_id.startswith("conv_"):
            raise ValueError(
                f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
            )

    async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
        """Validate the conversation ID and return the conversation if it exists."""
        self._validate_conversation_id(conversation_id)
        return await self.get_conversation(conversation_id)

    async def create(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create items in the conversation."""
        await self._get_validated_conversation(conversation_id)

        created_items = []

        for item in items:
            # Generate an item ID based on the item type
            random_bytes = secrets.token_bytes(24)
            item_type = getattr(item, "type", None)
            if item_type == "message":
                item_id = f"msg_{random_bytes.hex()}"
            else:
                item_id = f"item_{random_bytes.hex()}"

            # Create a copy of the item with the generated ID and completed status
            item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
            item_dict["id"] = item_id
            if "status" not in item_dict:
                item_dict["status"] = "completed"

            created_items.append(item_dict)

        # Get existing items from the database
        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
        existing_items = record.get("items", []) if record else []

        updated_items = existing_items + created_items
        await self.sql_store.update(
            table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
        )

        logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")

        # Convert created items (dicts) to proper ConversationItem types
        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]

        return ConversationItemList(
            data=response_items,
            first_id=created_items[0]["id"] if created_items else None,
            last_id=created_items[-1]["id"] if created_items else None,
            has_more=False,
        )

    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve a conversation item."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
        items = record.get("items", []) if record else []

        for item in items:
            if isinstance(item, dict) and item.get("id") == item_id:
                adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
                return adapter.validate_python(item)

        raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

    async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
        """List items in the conversation."""
        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
        items = record.get("items", []) if record else []

        # Items are stored in insertion order; the default listing order is desc (newest first)
        if order != "asc":
            items = list(reversed(items))

        actual_limit = 20
        if limit != NOT_GIVEN and isinstance(limit, int):
            actual_limit = limit

        items = items[:actual_limit]

        # Items from the database are stored as dicts; convert them to ConversationItem
        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [
            adapter.validate_python(item) if isinstance(item, dict) else item for item in items
        ]

        # Get first and last IDs safely
        first_id = None
        last_id = None
        if items:
            first_item = items[0]
            last_item = items[-1]

            first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
            last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)

        return ConversationItemList(
            data=response_items,
            first_id=first_id,
            last_id=last_id,
            has_more=False,
        )

    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        _ = await self._get_validated_conversation(conversation_id)  # executes validation

        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})
        items = record.get("items", []) if record else []

        updated_items = []
        item_found = False

        for item in items:
            current_item_id = item.get("id") if isinstance(item, dict) else getattr(item, "id", None)
            if current_item_id != item_id:
                updated_items.append(item)
            else:
                item_found = True

        if not item_found:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        await self.sql_store.update(
            table="openai_conversations", data={"items": updated_items}, where={"id": conversation_id}
        )

        logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
        return ConversationItemDeletedResource(id=item_id)
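To make the persistence model above concrete: each conversation is a single row in the openai_conversations table, and the complete item history lives in one JSON column that create() appends to. The following is an illustrative sketch of a stored record after create_conversation() plus one create() call; all IDs and values here are invented (real IDs are 24 random bytes, hex-encoded):

# Illustrative record shape only; not code from this commit.
record = {
    "id": "conv_a1b2c3...",            # STRING primary key
    "created_at": 1234567890,          # INTEGER, Unix seconds
    "items": [                         # JSON column: full item history
        {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": "Hello!"}],
            "id": "msg_d4e5f6...",     # assigned server-side in create()
            "status": "completed",     # defaulted when the item omits it
        }
    ],
    "metadata": {"topic": "demo"},     # JSON column, nullable
}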
@@ -480,6 +480,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
 If not specified, a default SQLite store will be used.""",
     )
 
+    conversations_store: SqlStoreConfig | None = Field(
+        default=None,
+        description="""
+Configuration for the persistence store used by the conversations API.
+If not specified, a default SQLite store will be used.""",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
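A short sketch of how the new field is used, mirroring the unit-test fixture added later in this commit; all StackRunConfig fields not shown are omitted for brevity:

# Sketch: pointing the conversations API at an explicit SQLite file.
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

run_config = StackRunConfig(
    image_name="test",
    providers={},
    conversations_store=SqliteSqlStoreConfig(db_path="/tmp/conversations.db"),
)
# With conversations_store=None (the default), the service falls back to
# DISTRIBS_BASE_DIR/<image_name>/conversations.db, as shown in
# ConversationServiceImpl.__init__ above.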
@@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
 logger = get_logger(name=__name__, category="core")
 
 
-INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
+INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}
 
 
 def stack_apis() -> list[Api]:
@@ -10,6 +10,7 @@ from typing import Any
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batches import Batches
 from llama_stack.apis.benchmarks import Benchmarks
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.datatypes import ExternalApiSpec

@@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
         Api.tool_runtime: ToolRuntime,
         Api.files: Files,
         Api.prompts: Prompts,
+        Api.conversations: Conversations,
     }
 
     if external_apis:
@@ -451,6 +451,7 @@ def create_app(
     apis_to_serve.add("inspect")
     apis_to_serve.add("providers")
     apis_to_serve.add("prompts")
+    apis_to_serve.add("conversations")
     for api_str in apis_to_serve:
         api = Api(api_str)
@@ -15,6 +15,7 @@ import yaml
 
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval

@@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
 from llama_stack.core.datatypes import Provider, StackRunConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl

@@ -73,6 +75,7 @@ class LlamaStack(
     RAGToolRuntime,
     Files,
     Prompts,
+    Conversations,
 ):
     pass
 

@@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig
     )
     impls[Api.prompts] = prompts_impl
 
+    conversations_impl = ConversationServiceImpl(
+        ConversationServiceConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.conversations] = conversations_impl
+
 
 class Stack:
     def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):

@@ -342,6 +351,8 @@ class Stack:
 
         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
+        if Api.conversations in impls:
+            await impls[Api.conversations].initialize()
 
         await register_resources(self.run_config, impls)
@@ -484,12 +484,19 @@ class JsonSchemaGenerator:
             }
             return ret
         elif origin_type is Literal:
-            if len(typing.get_args(typ)) != 1:
-                raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
-            (literal_value,) = typing.get_args(typ)  # unpack value of literal type
-            schema = self.type_to_schema(type(literal_value))
-            schema["const"] = literal_value
-            return schema
+            literal_args = typing.get_args(typ)
+            if len(literal_args) == 1:
+                (literal_value,) = literal_args
+                schema = self.type_to_schema(type(literal_value))
+                schema["const"] = literal_value
+                return schema
+            elif len(literal_args) > 1:
+                first_value = literal_args[0]
+                schema = self.type_to_schema(type(first_value))
+                schema["enum"] = list(literal_args)
+                return schema
+            else:
+                return {"enum": []}
         elif origin_type is type:
             (concrete_type,) = typing.get_args(typ)  # unpack single tuple element
             return {"const": self.type_to_schema(concrete_type, force_expand=True)}
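This schema-generator change is what allows multi-valued Literals, such as the order: Literal["asc", "desc"] parameter on the new Conversations.list() method, to be rendered in JSON Schema; previously anything other than a single-valued Literal raised a ValueError. A standalone sketch of the new branch behavior, with the generator's type_to_schema lookup simplified to a bare type name:

# Illustrative sketch only; mirrors the branches added above, not the real
# JsonSchemaGenerator (whose type_to_schema call is simplified away here).
import typing
from typing import Literal


def literal_schema_sketch(typ) -> dict:
    literal_args = typing.get_args(typ)
    if len(literal_args) == 1:
        return {"type": type(literal_args[0]).__name__, "const": literal_args[0]}
    elif len(literal_args) > 1:
        return {"type": type(literal_args[0]).__name__, "enum": list(literal_args)}
    else:
        return {"enum": []}


assert literal_schema_sketch(Literal["conversation"]) == {"type": "str", "const": "conversation"}
assert literal_schema_sketch(Literal["asc", "desc"]) == {"type": "str", "enum": ["asc", "desc"]}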
@@ -32,7 +32,7 @@ dependencies = [
     "jinja2>=3.1.6",
     "jsonschema",
     "llama-stack-client>=0.2.23",
-    "openai>=1.100.0", # for expires_after support
+    "openai>=1.107", # for expires_after support
     "prompt-toolkit",
     "python-dotenv",
     "python-jose[cryptography]",

@@ -49,6 +49,7 @@ dependencies = [
     "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
     "aiosqlite>=0.21.0", # server - for metadata store
     "asyncpg", # for metadata store
+    "sqlalchemy[asyncio]>=2.0.41", # server - for conversations
 ]
 
 [project.optional-dependencies]
tests/integration/conversations/test_openai_conversations.py (new file, 134 lines)
@@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest


@pytest.mark.integration
class TestOpenAIConversations:
    def test_conversation_create(self, openai_client):
        conversation = openai_client.conversations.create(
            metadata={"topic": "demo"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
        )

        assert conversation.id.startswith("conv_")
        assert conversation.object == "conversation"
        assert conversation.metadata["topic"] == "demo"
        assert isinstance(conversation.created_at, int)

    def test_conversation_retrieve(self, openai_client):
        conversation = openai_client.conversations.create(metadata={"topic": "demo"})

        retrieved = openai_client.conversations.retrieve(conversation.id)

        assert retrieved.id == conversation.id
        assert retrieved.object == "conversation"
        assert retrieved.metadata["topic"] == "demo"
        assert retrieved.created_at == conversation.created_at

    def test_conversation_update(self, openai_client):
        conversation = openai_client.conversations.create(metadata={"topic": "demo"})

        updated = openai_client.conversations.update(conversation.id, metadata={"topic": "project-x"})

        assert updated.id == conversation.id
        assert updated.metadata["topic"] == "project-x"
        assert updated.created_at == conversation.created_at

    def test_conversation_delete(self, openai_client):
        conversation = openai_client.conversations.create(metadata={"topic": "demo"})

        deleted = openai_client.conversations.delete(conversation.id)

        assert deleted.id == conversation.id
        assert deleted.object == "conversation.deleted"
        assert deleted.deleted is True

    def test_conversation_items_create(self, openai_client):
        conversation = openai_client.conversations.create()

        items = openai_client.conversations.items.create(
            conversation.id,
            items=[
                {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]},
                {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "How are you?"}]},
            ],
        )

        assert items.object == "list"
        assert len(items.data) == 2
        assert items.data[0].content[0].text == "Hello!"
        assert items.data[1].content[0].text == "How are you?"
        assert items.first_id == items.data[0].id
        assert items.last_id == items.data[1].id
        assert items.has_more is False

    def test_conversation_items_list(self, openai_client):
        conversation = openai_client.conversations.create()

        openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
        )

        items = openai_client.conversations.items.list(conversation.id, limit=10)

        assert items.object == "list"
        assert len(items.data) >= 1
        assert items.data[0].type == "message"
        assert items.data[0].role == "user"
        assert hasattr(items, "first_id")
        assert hasattr(items, "last_id")
        assert hasattr(items, "has_more")

    def test_conversation_item_retrieve(self, openai_client):
        conversation = openai_client.conversations.create()

        created_items = openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
        )

        item_id = created_items.data[0].id
        item = openai_client.conversations.items.retrieve(item_id, conversation_id=conversation.id)

        assert item.id == item_id
        assert item.type == "message"
        assert item.role == "user"
        assert item.content[0].text == "Hello!"

    def test_conversation_item_delete(self, openai_client):
        conversation = openai_client.conversations.create()

        created_items = openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
        )

        item_id = created_items.data[0].id
        deleted = openai_client.conversations.items.delete(item_id, conversation_id=conversation.id)

        assert deleted.id == item_id
        assert deleted.object == "conversation.item.deleted"
        assert deleted.deleted is True

    def test_full_workflow(self, openai_client):
        conversation = openai_client.conversations.create(
            metadata={"topic": "workflow-test"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
        )

        openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Follow up"}]}],
        )

        all_items = openai_client.conversations.items.list(conversation.id)
        assert len(all_items.data) >= 2

        updated = openai_client.conversations.update(conversation.id, metadata={"topic": "workflow-complete"})
        assert updated.metadata["topic"] == "workflow-complete"

        openai_client.conversations.delete(conversation.id)
tests/unit/conversations/test_api_models.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationCreateRequest,
    ConversationItem,
    ConversationItemList,
)


def test_conversation_create_request_defaults():
    request = ConversationCreateRequest()
    assert request.items == []
    assert request.metadata == {}


def test_conversation_model_defaults():
    conversation = Conversation(
        id="conv_123456789",
        created_at=1234567890,
        metadata=None,
        object="conversation",
    )
    assert conversation.id == "conv_123456789"
    assert conversation.object == "conversation"
    assert conversation.metadata is None


def test_openai_client_compatibility():
    from openai.types.conversations.message import Message
    from pydantic import TypeAdapter

    openai_message = Message(
        id="msg_123",
        content=[{"type": "input_text", "text": "Hello"}],
        role="user",
        status="in_progress",
        type="message",
        object="message",
    )

    adapter = TypeAdapter(ConversationItem)
    validated_item = adapter.validate_python(openai_message.model_dump())

    assert validated_item.id == "msg_123"
    assert validated_item.type == "message"


def test_conversation_item_list():
    item_list = ConversationItemList(data=[])
    assert item_list.object == "list"
    assert item_list.data == []
    assert item_list.first_id is None
    assert item_list.last_id is None
    assert item_list.has_more is False
tests/unit/conversations/test_conversations.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from pathlib import Path

import pytest
from openai.types.conversations.conversation import Conversation as OpenAIConversation
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputMessageContentText,
    OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
    ConversationServiceConfig,
    ConversationServiceImpl,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


@pytest.fixture
async def service():
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test_conversations.db"

        config = ConversationServiceConfig(
            run_config=StackRunConfig(
                image_name="test",
                providers={},
                conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)),
            )
        )
        service = ConversationServiceImpl(config, {})
        await service.initialize()
        yield service


async def test_conversation_lifecycle(service):
    conversation = await service.create_conversation(metadata={"test": "data"})

    assert conversation.id.startswith("conv_")
    assert conversation.metadata == {"test": "data"}

    retrieved = await service.get_conversation(conversation.id)
    assert retrieved.id == conversation.id

    deleted = await service.openai_delete_conversation(conversation.id)
    assert deleted.id == conversation.id


async def test_conversation_items(service):
    conversation = await service.create_conversation()

    items = [
        OpenAIResponseMessage(
            type="message",
            role="user",
            content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
            id="msg_test123",
            status="completed",
        )
    ]
    item_list = await service.create(conversation.id, items)

    assert len(item_list.data) == 1
    assert item_list.data[0].id.startswith("msg_")

    items = await service.list(conversation.id)
    assert len(items.data) == 1


async def test_invalid_conversation_id(service):
    with pytest.raises(ValueError, match="Expected an ID that begins with 'conv_'"):
        await service._get_validated_conversation("invalid_id")


async def test_empty_parameter_validation(service):
    with pytest.raises(ValueError, match="Expected a non-empty value"):
        await service.retrieve("", "item_123")


async def test_openai_type_compatibility(service):
    conversation = await service.create_conversation(metadata={"test": "value"})

    conversation_dict = conversation.model_dump()
    openai_conversation = OpenAIConversation.model_validate(conversation_dict)

    for attr in ["id", "object", "created_at", "metadata"]:
        assert getattr(openai_conversation, attr) == getattr(conversation, attr)

    items = [
        OpenAIResponseMessage(
            type="message",
            role="user",
            content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
            id="msg_test456",
            status="completed",
        )
    ]
    item_list = await service.create(conversation.id, items)

    for attr in ["object", "data", "first_id", "last_id", "has_more"]:
        assert hasattr(item_list, attr)
    assert item_list.object == "list"

    items = await service.list(conversation.id)
    item = await service.retrieve(conversation.id, items.data[0].id)
    item_dict = item.model_dump()

    openai_item_adapter = TypeAdapter(OpenAIConversationItem)
    openai_item_adapter.validate_python(item_dict)
uv.lock (generated, 4 lines changed)
@@ -1773,6 +1773,7 @@ dependencies = [
    { name = "python-jose", extra = ["cryptography"] },
    { name = "python-multipart" },
    { name = "rich" },
+   { name = "sqlalchemy", extra = ["asyncio"] },
    { name = "starlette" },
    { name = "termcolor" },
    { name = "tiktoken" },

@@ -1887,7 +1888,7 @@ requires-dist = [
    { name = "jsonschema" },
    { name = "llama-stack-client", specifier = ">=0.2.23" },
    { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
-   { name = "openai", specifier = ">=1.100.0" },
+   { name = "openai", specifier = ">=1.107" },
    { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
    { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
    { name = "pandas", marker = "extra == 'ui'" },

@@ -1898,6 +1899,7 @@ requires-dist = [
    { name = "python-jose", extras = ["cryptography"] },
    { name = "python-multipart", specifier = ">=0.0.20" },
    { name = "rich" },
+   { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
    { name = "starlette" },
    { name = "streamlit", marker = "extra == 'ui'" },
    { name = "streamlit-option-menu", marker = "extra == 'ui'" },