Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
feat: Add OpenAI Conversations API (#3429)
# What does this PR do?

Initial implementation for `Conversations` and `ConversationItems` using `AuthorizedSqlStore`, with endpoints to:

- CREATE
- UPDATE
- GET/RETRIEVE/LIST
- DELETE

Set `level=LLAMA_STACK_API_V1`.

NOTE: This does not currently incorporate changes for Responses; that will be done in a subsequent PR.

Closes https://github.com/llamastack/llama-stack/issues/3235

## Test Plan

- Unit tests
- Integration tests
- Comparison against the [OpenAPI spec for the OpenAI API](https://github.com/openai/openai-openapi/tree/manual_spec):

```bash
oasdiff breaking --fail-on ERR docs/static/llama-stack-spec.yaml \
  https://raw.githubusercontent.com/openai/openai-openapi/refs/heads/manual_spec/openapi.yaml \
  --strip-prefix-base "/v1/openai/v1" \
  --match-path '(^/v1/openai/v1/conversations.*|^/conversations.*)'
```

Note: I still have some uncertainty about this check. I borrowed it from @cdoern on https://github.com/llamastack/llama-stack/pull/3514 and need to spend more time confirming it works; at the moment it suggests that it does.

UPDATE on `oasdiff`: I investigated the OpenAI spec further, and it currently does not list Conversations, so that analysis is not useful. Noting for future reference.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
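The new endpoints mirror OpenAI's Conversations API, so they can be exercised with the stock `openai` Python client pointed at a Llama Stack server. A minimal sketch; the base URL, port, and API key are assumptions for a local deployment:

```python
from openai import OpenAI

# Assumed local Llama Stack server exposing the OpenAI-compatible surface
# under /v1/openai/v1; adjust base_url/api_key for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# CREATE a conversation with an initial item and metadata
conversation = client.conversations.create(
    metadata={"topic": "demo"},
    items=[{"type": "message", "role": "user", "content": "Hello!"}],
)
assert conversation.id.startswith("conv_")

# UPDATE metadata, then RETRIEVE
client.conversations.update(conversation.id, metadata={"topic": "project-x"})
retrieved = client.conversations.retrieve(conversation.id)

# LIST items, then DELETE the conversation
items = client.conversations.items.list(conversation.id, limit=10)
client.conversations.delete(conversation.id)
```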
This commit is contained in: parent a09e30bd87, commit a20e8eac8c
24 changed files with 5704 additions and 2183 deletions
docs/static/llama-stack-spec.html (vendored, 1902 lines changed; file diff suppressed because it is too large)
docs/static/llama-stack-spec.yaml (vendored, 1519 lines changed; file diff suppressed because it is too large)
docs/static/stainless-llama-stack-spec.html (vendored, 1902 lines changed; file diff suppressed because it is too large)
docs/static/stainless-llama-stack-spec.yaml (vendored, 1519 lines changed; file diff suppressed because it is too large)
llama_stack/apis/conversations/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .conversations import (
    Conversation,
    ConversationCreateRequest,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemCreateRequest,
    ConversationItemDeletedResource,
    ConversationItemList,
    Conversations,
    ConversationUpdateRequest,
    Metadata,
)

__all__ = [
    "Conversation",
    "ConversationCreateRequest",
    "ConversationDeletedResource",
    "ConversationItem",
    "ConversationItemCreateRequest",
    "ConversationItemDeletedResource",
    "ConversationItemList",
    "Conversations",
    "ConversationUpdateRequest",
    "Metadata",
]
llama_stack/apis/conversations/conversations.py (new file, 260 lines)
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Annotated, Literal, Protocol, runtime_checkable

from openai import NOT_GIVEN
from openai._types import NotGiven
from openai.types.responses.response_includable import ResponseIncludable
from pydantic import BaseModel, Field

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseMessage,
    OpenAIResponseOutputMessageFileSearchToolCall,
    OpenAIResponseOutputMessageFunctionToolCall,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
    OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

Metadata = dict[str, str]


@json_schema_type
class Conversation(BaseModel):
    """OpenAI-compatible conversation object."""

    id: str = Field(..., description="The unique ID of the conversation.")
    object: Literal["conversation"] = Field(
        default="conversation", description="The object type, which is always conversation."
    )
    created_at: int = Field(
        ..., description="The time at which the conversation was created, measured in seconds since the Unix epoch."
    )
    metadata: Metadata | None = Field(
        default=None,
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.",
    )
    items: list[dict] | None = Field(
        default=None,
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
    )


@json_schema_type
class ConversationMessage(BaseModel):
    """OpenAI-compatible message item for conversations."""

    id: str = Field(..., description="unique identifier for this message")
    content: list[dict] = Field(..., description="message content")
    role: str = Field(..., description="message role")
    status: str = Field(..., description="message status")
    type: Literal["message"] = "message"
    object: Literal["message"] = "message"


ConversationItem = Annotated[
    OpenAIResponseMessage
    | OpenAIResponseOutputMessageFunctionToolCall
    | OpenAIResponseOutputMessageFileSearchToolCall
    | OpenAIResponseOutputMessageWebSearchToolCall
    | OpenAIResponseOutputMessageMCPCall
    | OpenAIResponseOutputMessageMCPListTools,
    Field(discriminator="type"),
]
register_schema(ConversationItem, name="ConversationItem")

# Using OpenAI types directly caused issues but some notes for reference:
# Note that ConversationItem is an Annotated Union of the types below:
# from openai.types.responses import *
# from openai.types.responses.response_item import *
# from openai.types.conversations import ConversationItem
# f = [
#     ResponseFunctionToolCallItem,
#     ResponseFunctionToolCallOutputItem,
#     ResponseFileSearchToolCall,
#     ResponseFunctionWebSearch,
#     ImageGenerationCall,
#     ResponseComputerToolCall,
#     ResponseComputerToolCallOutputItem,
#     ResponseReasoningItem,
#     ResponseCodeInterpreterToolCall,
#     LocalShellCall,
#     LocalShellCallOutput,
#     McpListTools,
#     McpApprovalRequest,
#     McpApprovalResponse,
#     McpCall,
#     ResponseCustomToolCall,
#     ResponseCustomToolCallOutput
# ]


@json_schema_type
class ConversationCreateRequest(BaseModel):
    """Request body for creating a conversation."""

    items: list[ConversationItem] | None = Field(
        default=[],
        description="Initial items to include in the conversation context. You may add up to 20 items at a time.",
        max_length=20,
    )
    metadata: Metadata | None = Field(
        default={},
        description="Set of 16 key-value pairs that can be attached to an object. Useful for storing additional information",
        max_length=16,
    )


@json_schema_type
class ConversationUpdateRequest(BaseModel):
    """Request body for updating a conversation."""

    metadata: Metadata = Field(
        ...,
        description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.",
    )


@json_schema_type
class ConversationDeletedResource(BaseModel):
    """Response for deleted conversation."""

    id: str = Field(..., description="The deleted conversation identifier")
    object: str = Field(default="conversation.deleted", description="Object type")
    deleted: bool = Field(default=True, description="Whether the object was deleted")


@json_schema_type
class ConversationItemCreateRequest(BaseModel):
    """Request body for creating conversation items."""

    items: list[ConversationItem] = Field(
        ...,
        description="Items to include in the conversation context. You may add up to 20 items at a time.",
        max_length=20,
    )


@json_schema_type
class ConversationItemList(BaseModel):
    """List of conversation items with pagination."""

    object: str = Field(default="list", description="Object type")
    data: list[ConversationItem] = Field(..., description="List of conversation items")
    first_id: str | None = Field(default=None, description="The ID of the first item in the list")
    last_id: str | None = Field(default=None, description="The ID of the last item in the list")
    has_more: bool = Field(default=False, description="Whether there are more items available")


@json_schema_type
class ConversationItemDeletedResource(BaseModel):
    """Response for deleted conversation item."""

    id: str = Field(..., description="The deleted item identifier")
    object: str = Field(default="conversation.item.deleted", description="Object type")
    deleted: bool = Field(default=True, description="Whether the object was deleted")


@runtime_checkable
@trace_protocol
class Conversations(Protocol):
    """Protocol for conversation management operations."""

    @webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation.

        :param items: Initial items to include in the conversation context.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The created conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation's metadata with the given ID.

        :param conversation_id: The conversation identifier.
        :param metadata: Set of key-value pairs that can be attached to an object.
        :returns: The updated conversation object.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID.

        :param conversation_id: The conversation identifier.
        :returns: The deleted conversation resource.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create items in the conversation.

        :param conversation_id: The conversation identifier.
        :param items: Items to include in the conversation context.
        :returns: List of created items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The conversation item.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items", method="GET", level=LLAMA_STACK_API_V1)
    async def list(
        self,
        conversation_id: str,
        after: str | NotGiven = NOT_GIVEN,
        include: list[ResponseIncludable] | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
    ) -> ConversationItemList:
        """List items in the conversation.

        :param conversation_id: The conversation identifier.
        :param after: An item ID to list items after, used in pagination.
        :param include: Specify additional output data to include in the response.
        :param limit: A limit on the number of objects to be returned (1-100, default 20).
        :param order: The order to return items in (asc or desc, default desc).
        :returns: List of conversation items.
        """
        ...

    @webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="DELETE", level=LLAMA_STACK_API_V1)
    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item.

        :param conversation_id: The conversation identifier.
        :param item_id: The item identifier.
        :returns: The deleted item resource.
        """
        ...
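Since `ConversationItem` is a union discriminated on `type`, a plain dict round-trips to the right concrete model via a Pydantic `TypeAdapter`; the service implementation below relies on the same pattern. A minimal sketch (field values are illustrative):

```python
from pydantic import TypeAdapter

from llama_stack.apis.conversations.conversations import ConversationItem

adapter = TypeAdapter(ConversationItem)

# The "type": "message" discriminator routes this dict to OpenAIResponseMessage.
item = adapter.validate_python(
    {
        "id": "msg_123",
        "type": "message",
        "role": "user",
        "status": "completed",
        "content": [{"type": "input_text", "text": "Hello"}],
    }
)
print(type(item).__name__)  # OpenAIResponseMessage
```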
@@ -129,6 +129,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
     tool_groups = "tool_groups"
     files = "files"
     prompts = "prompts"
+    conversations = "conversations"

     # built-in API
     inspect = "inspect"
llama_stack/core/conversations/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/core/conversations/conversations.py (new file, 306 lines)
@@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import secrets
import time
from typing import Any

from openai import NOT_GIVEN
from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemDeletedResource,
    ConversationItemList,
    Conversations,
    Metadata,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import (
    SqliteSqlStoreConfig,
    SqlStoreConfig,
    sqlstore_impl,
)

logger = get_logger(name=__name__, category="openai::conversations")


class ConversationServiceConfig(BaseModel):
    """Configuration for the built-in conversation service.

    :param conversations_store: SQL store configuration for conversations (defaults to SQLite)
    :param policy: Access control rules
    """

    conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
        db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
    )
    policy: list[AccessRule] = []


async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
    """Get the conversation service implementation."""
    impl = ConversationServiceImpl(config, deps)
    await impl.initialize()
    return impl


class ConversationServiceImpl(Conversations):
    """Built-in conversation service implementation using AuthorizedSqlStore."""

    def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
        self.config = config
        self.deps = deps
        self.policy = config.policy

        base_sql_store = sqlstore_impl(config.conversations_store)
        self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)

    async def initialize(self) -> None:
        """Initialize the store and create tables."""
        if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
            os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)

        await self.sql_store.create_table(
            "openai_conversations",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "items": ColumnType.JSON,
                "metadata": ColumnType.JSON,
            },
        )

        await self.sql_store.create_table(
            "conversation_items",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "conversation_id": ColumnType.STRING,
                "created_at": ColumnType.INTEGER,
                "item_data": ColumnType.JSON,
            },
        )

    async def create_conversation(
        self, items: list[ConversationItem] | None = None, metadata: Metadata | None = None
    ) -> Conversation:
        """Create a conversation."""
        random_bytes = secrets.token_bytes(24)
        conversation_id = f"conv_{random_bytes.hex()}"
        created_at = int(time.time())

        record_data = {
            "id": conversation_id,
            "created_at": created_at,
            "items": [],
            "metadata": metadata,
        }

        await self.sql_store.insert(
            table="openai_conversations",
            data=record_data,
        )

        if items:
            item_records = []
            for item in items:
                item_dict = item.model_dump()
                item_id = self._get_or_generate_item_id(item, item_dict)

                item_record = {
                    "id": item_id,
                    "conversation_id": conversation_id,
                    "created_at": created_at,
                    "item_data": item_dict,
                }

                item_records.append(item_record)

            await self.sql_store.insert(table="conversation_items", data=item_records)

        conversation = Conversation(
            id=conversation_id,
            created_at=created_at,
            metadata=metadata,
            object="conversation",
        )

        logger.info(f"Created conversation {conversation_id}")
        return conversation

    async def get_conversation(self, conversation_id: str) -> Conversation:
        """Get a conversation with the given ID."""
        record = await self.sql_store.fetch_one(table="openai_conversations", where={"id": conversation_id})

        if record is None:
            raise ValueError(f"Conversation {conversation_id} not found")

        return Conversation(
            id=record["id"], created_at=record["created_at"], metadata=record.get("metadata"), object="conversation"
        )

    async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
        """Update a conversation's metadata with the given ID."""
        await self.sql_store.update(
            table="openai_conversations", data={"metadata": metadata}, where={"id": conversation_id}
        )

        return await self.get_conversation(conversation_id)

    async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
        """Delete a conversation with the given ID."""
        await self.sql_store.delete(table="openai_conversations", where={"id": conversation_id})

        logger.info(f"Deleted conversation {conversation_id}")
        return ConversationDeletedResource(id=conversation_id)

    def _validate_conversation_id(self, conversation_id: str) -> None:
        """Validate conversation ID format."""
        if not conversation_id.startswith("conv_"):
            raise ValueError(
                f"Invalid 'conversation_id': '{conversation_id}'. Expected an ID that begins with 'conv_'."
            )

    def _get_or_generate_item_id(self, item: ConversationItem, item_dict: dict) -> str:
        """Get existing item ID or generate one if missing."""
        if item.id is None:
            random_bytes = secrets.token_bytes(24)
            if item.type == "message":
                item_id = f"msg_{random_bytes.hex()}"
            else:
                item_id = f"item_{random_bytes.hex()}"
            item_dict["id"] = item_id
            return item_id
        return item.id

    async def _get_validated_conversation(self, conversation_id: str) -> Conversation:
        """Validate conversation ID and return the conversation if it exists."""
        self._validate_conversation_id(conversation_id)
        return await self.get_conversation(conversation_id)

    async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
        """Create (add) items to a conversation."""
        await self._get_validated_conversation(conversation_id)

        created_items = []
        created_at = int(time.time())

        for item in items:
            item_dict = item.model_dump()
            item_id = self._get_or_generate_item_id(item, item_dict)

            item_record = {
                "id": item_id,
                "conversation_id": conversation_id,
                "created_at": created_at,
                "item_data": item_dict,
            }

            # TODO: Add support for upsert in sql_store; this will first fail if the ID exists and then update
            try:
                await self.sql_store.insert(table="conversation_items", data=item_record)
            except Exception:
                # If insert fails due to ID conflict, update the existing record
                await self.sql_store.update(
                    table="conversation_items",
                    data={"created_at": created_at, "item_data": item_dict},
                    where={"id": item_id},
                )

            created_items.append(item_dict)

        logger.info(f"Created {len(created_items)} items in conversation {conversation_id}")

        # Convert created items (dicts) to proper ConversationItem types
        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [adapter.validate_python(item_dict) for item_dict in created_items]

        return ConversationItemList(
            data=response_items,
            first_id=created_items[0]["id"] if created_items else None,
            last_id=created_items[-1]["id"] if created_items else None,
            has_more=False,
        )

    async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
        """Retrieve a conversation item."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        # Get the item from the conversation_items table
        record = await self.sql_store.fetch_one(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        if record is None:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        return adapter.validate_python(record["item_data"])

    async def list(self, conversation_id: str, after=NOT_GIVEN, include=NOT_GIVEN, limit=NOT_GIVEN, order=NOT_GIVEN):
        """List items in the conversation."""
        result = await self.sql_store.fetch_all(table="conversation_items", where={"conversation_id": conversation_id})
        records = result.data

        if order != NOT_GIVEN and order == "asc":
            records.sort(key=lambda x: x["created_at"])
        else:
            records.sort(key=lambda x: x["created_at"], reverse=True)

        actual_limit = 20
        if limit != NOT_GIVEN and isinstance(limit, int):
            actual_limit = limit

        records = records[:actual_limit]
        items = [record["item_data"] for record in records]

        adapter: TypeAdapter[ConversationItem] = TypeAdapter(ConversationItem)
        response_items: list[ConversationItem] = [adapter.validate_python(item) for item in items]

        first_id = response_items[0].id if response_items else None
        last_id = response_items[-1].id if response_items else None

        return ConversationItemList(
            data=response_items,
            first_id=first_id,
            last_id=last_id,
            has_more=False,
        )

    async def openai_delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> ConversationItemDeletedResource:
        """Delete a conversation item."""
        if not conversation_id:
            raise ValueError(f"Expected a non-empty value for `conversation_id` but received {conversation_id!r}")
        if not item_id:
            raise ValueError(f"Expected a non-empty value for `item_id` but received {item_id!r}")

        _ = await self._get_validated_conversation(conversation_id)

        record = await self.sql_store.fetch_one(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        if record is None:
            raise ValueError(f"Item {item_id} not found in conversation {conversation_id}")

        await self.sql_store.delete(
            table="conversation_items", where={"id": item_id, "conversation_id": conversation_id}
        )

        logger.info(f"Deleted item {item_id} from conversation {conversation_id}")
        return ConversationItemDeletedResource(id=item_id)
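For local experimentation, the service can also be stood up directly against a temporary SQLite store, mirroring the unit-test fixture further down. A minimal sketch (the temp path is a hypothetical choice):

```python
import asyncio
import tempfile
from pathlib import Path

from llama_stack.core.conversations.conversations import (
    ConversationServiceConfig,
    ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


async def main() -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
        # Hypothetical throwaway database path for the sketch
        db_path = Path(tmpdir) / "conversations.db"
        config = ConversationServiceConfig(
            conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[]
        )
        service = ConversationServiceImpl(config, {})
        await service.initialize()

        conversation = await service.create_conversation(metadata={"topic": "demo"})
        print(conversation.id)  # e.g. conv_<48 hex chars>
        await service.openai_delete_conversation(conversation.id)


asyncio.run(main())
```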
@@ -475,6 +475,13 @@ InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (depreca
 If not specified, a default SQLite store will be used.""",
     )

+    conversations_store: SqlStoreConfig | None = Field(
+        default=None,
+        description="""
+Configuration for the persistence store used by the conversations API.
+If not specified, a default SQLite store will be used.""",
+    )
+
     # registry of "resources" in the distribution
     models: list[ModelInput] = Field(default_factory=list)
     shields: list[ShieldInput] = Field(default_factory=list)
@@ -25,7 +25,7 @@ from llama_stack.providers.datatypes import (
 logger = get_logger(name=__name__, category="core")


-INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
+INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts, Api.conversations}


 def stack_apis() -> list[Api]:
@@ -10,6 +10,7 @@ from typing import Any
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.batches import Batches
 from llama_stack.apis.benchmarks import Benchmarks
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.datatypes import ExternalApiSpec
@@ -96,6 +97,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
     Api.tool_runtime: ToolRuntime,
     Api.files: Files,
     Api.prompts: Prompts,
+    Api.conversations: Conversations,
 }

 if external_apis:
@@ -451,6 +451,7 @@ def create_app(
     apis_to_serve.add("inspect")
     apis_to_serve.add("providers")
     apis_to_serve.add("prompts")
+    apis_to_serve.add("conversations")

     for api_str in apis_to_serve:
         api = Api(api_str)
@@ -15,6 +15,7 @@ import yaml

 from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
+from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
@@ -34,6 +35,7 @@ from llama_stack.apis.telemetry import Telemetry
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
 from llama_stack.core.datatypes import Provider, StackRunConfig
 from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
@@ -73,6 +75,7 @@ class LlamaStack(
     RAGToolRuntime,
     Files,
     Prompts,
+    Conversations,
 ):
     pass

@@ -312,6 +315,12 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
     )
     impls[Api.prompts] = prompts_impl

+    conversations_impl = ConversationServiceImpl(
+        ConversationServiceConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.conversations] = conversations_impl
+

 class Stack:
     def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
@@ -342,6 +351,8 @@ class Stack:

         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
+        if Api.conversations in impls:
+            await impls[Api.conversations].initialize()

         await register_resources(self.run_config, impls)

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from enum import Enum
 from typing import Any, Literal, Protocol

@@ -41,9 +41,9 @@ class SqlStore(Protocol):
         """
         pass

-    async def insert(self, table: str, data: Mapping[str, Any]) -> None:
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
         """
-        Insert a row into a table.
+        Insert a row or batch of rows into a table.
         """
         pass

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, Literal

 from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
@@ -38,6 +38,18 @@ SQL_OPTIMIZED_POLICY = [
 ]


+def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: User | None) -> Mapping[str, Any]:
+    """Add access control attributes to a data item."""
+    enhanced = dict(item)
+    if current_user:
+        enhanced["owner_principal"] = current_user.principal
+        enhanced["access_attributes"] = current_user.attributes
+    else:
+        enhanced["owner_principal"] = None
+        enhanced["access_attributes"] = None
+    return enhanced
+
+
 class SqlRecord(ProtectedResource):
     def __init__(self, record_id: str, table_name: str, owner: User):
         self.type = f"sql_record::{table_name}"
@@ -102,18 +114,14 @@ class AuthorizedSqlStore:
         await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
         await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)

-    async def insert(self, table: str, data: Mapping[str, Any]) -> None:
-        """Insert a row with automatic access control attribute capture."""
-        enhanced_data = dict(data)
-
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
+        """Insert a row or batch of rows with automatic access control attribute capture."""
         current_user = get_authenticated_user()
-        if current_user:
-            enhanced_data["owner_principal"] = current_user.principal
-            enhanced_data["access_attributes"] = current_user.attributes
+        enhanced_data: Mapping[str, Any] | Sequence[Mapping[str, Any]]
+        if isinstance(data, Mapping):
+            enhanced_data = _enhance_item_with_access_control(data, current_user)
         else:
-            enhanced_data["owner_principal"] = None
-            enhanced_data["access_attributes"] = None
+            enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]

         await self.sql_store.insert(table, enhanced_data)

     async def fetch_all(
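The reworked `insert` therefore accepts either a single row or a batch and stamps every row with the caller's `owner_principal` and `access_attributes`. A minimal sketch of the call pattern; the table name and rows are hypothetical, and `store` is assumed to be an `AuthorizedSqlStore` wired up as in `ConversationServiceImpl`:

```python
from collections.abc import Mapping, Sequence
from typing import Any

from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore


async def save_rows(store: AuthorizedSqlStore, rows: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
    # A single mapping inserts one attributed row; a sequence inserts a batch.
    # Access-control columns are filled from the authenticated user either way.
    await store.insert(table="conversation_items", data=rows)
```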
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, Literal

 from sqlalchemy import (
@@ -116,7 +116,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         async with engine.begin() as conn:
             await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)

-    async def insert(self, table: str, data: Mapping[str, Any]) -> None:
+    async def insert(self, table: str, data: Mapping[str, Any] | Sequence[Mapping[str, Any]]) -> None:
         async with self.async_session() as session:
             await session.execute(self.metadata.tables[table].insert(), data)
             await session.commit()
@@ -484,12 +484,19 @@ class JsonSchemaGenerator:
             }
             return ret
         elif origin_type is Literal:
-            if len(typing.get_args(typ)) != 1:
-                raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
-            (literal_value,) = typing.get_args(typ)  # unpack value of literal type
-            schema = self.type_to_schema(type(literal_value))
-            schema["const"] = literal_value
-            return schema
+            literal_args = typing.get_args(typ)
+            if len(literal_args) == 1:
+                (literal_value,) = literal_args
+                schema = self.type_to_schema(type(literal_value))
+                schema["const"] = literal_value
+                return schema
+            elif len(literal_args) > 1:
+                first_value = literal_args[0]
+                schema = self.type_to_schema(type(first_value))
+                schema["enum"] = list(literal_args)
+                return schema
+            else:
+                return {"enum": []}
         elif origin_type is type:
             (concrete_type,) = typing.get_args(typ)  # unpack single tuple element
             return {"const": self.type_to_schema(concrete_type, force_expand=True)}
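This change relaxes the generator from rejecting multi-valued `Literal` types to emitting an `enum` schema, while a single-valued literal still becomes a `const`. Roughly, under the assumption that `type_to_schema(str)` yields `{"type": "string"}`:

```python
from typing import Literal

# Single-argument Literal -> const, e.g. {"type": "string", "const": "conversation"}
ObjectType = Literal["conversation"]

# Multi-argument Literal -> enum, e.g. {"type": "string", "enum": ["asc", "desc"]}
SortOrder = Literal["asc", "desc"]
```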
@@ -32,7 +32,7 @@ dependencies = [
     "jinja2>=3.1.6",
     "jsonschema",
     "llama-stack-client>=0.2.23",
-    "openai>=1.100.0",  # for expires_after support
+    "openai>=1.107",  # for expires_after support
     "prompt-toolkit",
     "python-dotenv",
     "python-jose[cryptography]",
@@ -49,6 +49,7 @@ dependencies = [
     "opentelemetry-exporter-otlp-proto-http>=1.30.0",  # server
     "aiosqlite>=0.21.0",  # server - for metadata store
     "asyncpg",  # for metadata store
+    "sqlalchemy[asyncio]>=2.0.41",  # server - for conversations
 ]

 [project.optional-dependencies]
tests/integration/conversations/test_openai_conversations.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest


@pytest.mark.integration
class TestOpenAIConversations:
    # TODO: Update to compat_client after client-SDK is generated
    def test_conversation_create(self, openai_client):
        conversation = openai_client.conversations.create(
            metadata={"topic": "demo"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
        )

        assert conversation.id.startswith("conv_")
        assert conversation.object == "conversation"
        assert conversation.metadata["topic"] == "demo"
        assert isinstance(conversation.created_at, int)

    def test_conversation_retrieve(self, openai_client):
        conversation = openai_client.conversations.create(metadata={"topic": "demo"})

        retrieved = openai_client.conversations.retrieve(conversation.id)

        assert retrieved.id == conversation.id
        assert retrieved.object == "conversation"
        assert retrieved.metadata["topic"] == "demo"
        assert retrieved.created_at == conversation.created_at

    def test_conversation_update(self, openai_client):
        conversation = openai_client.conversations.create(metadata={"topic": "demo"})

        updated = openai_client.conversations.update(conversation.id, metadata={"topic": "project-x"})

        assert updated.id == conversation.id
        assert updated.metadata["topic"] == "project-x"
        assert updated.created_at == conversation.created_at

    def test_conversation_delete(self, openai_client):
        conversation = openai_client.conversations.create(metadata={"topic": "demo"})

        deleted = openai_client.conversations.delete(conversation.id)

        assert deleted.id == conversation.id
        assert deleted.object == "conversation.deleted"
        assert deleted.deleted is True

    def test_conversation_items_create(self, openai_client):
        conversation = openai_client.conversations.create()

        items = openai_client.conversations.items.create(
            conversation.id,
            items=[
                {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]},
                {"type": "message", "role": "user", "content": [{"type": "input_text", "text": "How are you?"}]},
            ],
        )

        assert items.object == "list"
        assert len(items.data) == 2
        assert items.data[0].content[0].text == "Hello!"
        assert items.data[1].content[0].text == "How are you?"
        assert items.first_id == items.data[0].id
        assert items.last_id == items.data[1].id
        assert items.has_more is False

    def test_conversation_items_list(self, openai_client):
        conversation = openai_client.conversations.create()

        openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
        )

        items = openai_client.conversations.items.list(conversation.id, limit=10)

        assert items.object == "list"
        assert len(items.data) >= 1
        assert items.data[0].type == "message"
        assert items.data[0].role == "user"
        assert hasattr(items, "first_id")
        assert hasattr(items, "last_id")
        assert hasattr(items, "has_more")

    def test_conversation_item_retrieve(self, openai_client):
        conversation = openai_client.conversations.create()

        created_items = openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
        )

        item_id = created_items.data[0].id
        item = openai_client.conversations.items.retrieve(item_id, conversation_id=conversation.id)

        assert item.id == item_id
        assert item.type == "message"
        assert item.role == "user"
        assert item.content[0].text == "Hello!"

    def test_conversation_item_delete(self, openai_client):
        conversation = openai_client.conversations.create()

        created_items = openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}],
        )

        item_id = created_items.data[0].id
        deleted = openai_client.conversations.items.delete(item_id, conversation_id=conversation.id)

        assert deleted.id == item_id
        assert deleted.object == "conversation.item.deleted"
        assert deleted.deleted is True

    def test_full_workflow(self, openai_client):
        conversation = openai_client.conversations.create(
            metadata={"topic": "workflow-test"}, items=[{"type": "message", "role": "user", "content": "Hello!"}]
        )

        openai_client.conversations.items.create(
            conversation.id,
            items=[{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "Follow up"}]}],
        )

        all_items = openai_client.conversations.items.list(conversation.id)
        assert len(all_items.data) >= 2

        updated = openai_client.conversations.update(conversation.id, metadata={"topic": "workflow-complete"})
        assert updated.metadata["topic"] == "workflow-complete"

        openai_client.conversations.delete(conversation.id)
tests/unit/conversations/test_api_models.py (new file, 60 lines)
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.apis.conversations.conversations import (
    Conversation,
    ConversationCreateRequest,
    ConversationItem,
    ConversationItemList,
)


def test_conversation_create_request_defaults():
    request = ConversationCreateRequest()
    assert request.items == []
    assert request.metadata == {}


def test_conversation_model_defaults():
    conversation = Conversation(
        id="conv_123456789",
        created_at=1234567890,
        metadata=None,
        object="conversation",
    )
    assert conversation.id == "conv_123456789"
    assert conversation.object == "conversation"
    assert conversation.metadata is None


def test_openai_client_compatibility():
    from openai.types.conversations.message import Message
    from pydantic import TypeAdapter

    openai_message = Message(
        id="msg_123",
        content=[{"type": "input_text", "text": "Hello"}],
        role="user",
        status="in_progress",
        type="message",
        object="message",
    )

    adapter = TypeAdapter(ConversationItem)
    validated_item = adapter.validate_python(openai_message.model_dump())

    assert validated_item.id == "msg_123"
    assert validated_item.type == "message"


def test_conversation_item_list():
    item_list = ConversationItemList(data=[])
    assert item_list.object == "list"
    assert item_list.data == []
    assert item_list.first_id is None
    assert item_list.last_id is None
    assert item_list.has_more is False
tests/unit/conversations/test_conversations.py (new file, 132 lines)
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import tempfile
from pathlib import Path

import pytest
from openai.types.conversations.conversation import Conversation as OpenAIConversation
from openai.types.conversations.conversation_item import ConversationItem as OpenAIConversationItem
from pydantic import TypeAdapter

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseInputMessageContentText,
    OpenAIResponseMessage,
)
from llama_stack.core.conversations.conversations import (
    ConversationServiceConfig,
    ConversationServiceImpl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


@pytest.fixture
async def service():
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test_conversations.db"

        config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
        service = ConversationServiceImpl(config, {})
        await service.initialize()
        yield service


async def test_conversation_lifecycle(service):
    conversation = await service.create_conversation(metadata={"test": "data"})

    assert conversation.id.startswith("conv_")
    assert conversation.metadata == {"test": "data"}

    retrieved = await service.get_conversation(conversation.id)
    assert retrieved.id == conversation.id

    deleted = await service.openai_delete_conversation(conversation.id)
    assert deleted.id == conversation.id


async def test_conversation_items(service):
    conversation = await service.create_conversation()

    items = [
        OpenAIResponseMessage(
            type="message",
            role="user",
            content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
            id="msg_test123",
            status="completed",
        )
    ]
    item_list = await service.add_items(conversation.id, items)

    assert len(item_list.data) == 1
    assert item_list.data[0].id == "msg_test123"

    items = await service.list(conversation.id)
    assert len(items.data) == 1


async def test_invalid_conversation_id(service):
    with pytest.raises(ValueError, match="Expected an ID that begins with 'conv_'"):
        await service._get_validated_conversation("invalid_id")


async def test_empty_parameter_validation(service):
    with pytest.raises(ValueError, match="Expected a non-empty value"):
        await service.retrieve("", "item_123")


async def test_openai_type_compatibility(service):
    conversation = await service.create_conversation(metadata={"test": "value"})

    conversation_dict = conversation.model_dump()
    openai_conversation = OpenAIConversation.model_validate(conversation_dict)

    for attr in ["id", "object", "created_at", "metadata"]:
        assert getattr(openai_conversation, attr) == getattr(conversation, attr)

    items = [
        OpenAIResponseMessage(
            type="message",
            role="user",
            content=[OpenAIResponseInputMessageContentText(type="input_text", text="Hello")],
            id="msg_test456",
            status="completed",
        )
    ]
    item_list = await service.add_items(conversation.id, items)

    for attr in ["object", "data", "first_id", "last_id", "has_more"]:
        assert hasattr(item_list, attr)
    assert item_list.object == "list"

    items = await service.list(conversation.id)
    item = await service.retrieve(conversation.id, items.data[0].id)
    item_dict = item.model_dump()

    openai_item_adapter = TypeAdapter(OpenAIConversationItem)
    openai_item_adapter.validate_python(item_dict)


async def test_policy_configuration():
    from llama_stack.core.access_control.datatypes import Action, Scope
    from llama_stack.core.datatypes import AccessRule

    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test_conversations_policy.db"

        restrictive_policy = [
            AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
        ]

        config = ConversationServiceConfig(
            conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
        )
        service = ConversationServiceImpl(config, {})
        await service.initialize()

        assert service.policy == restrictive_policy
        assert len(service.policy) == 1
        assert service.policy[0].forbid is not None
@@ -368,6 +368,32 @@ async def test_where_operator_gt_and_update_delete():
     assert {r["id"] for r in rows_after} == {1, 3}


+async def test_batch_insert():
+    with TemporaryDirectory() as tmp_dir:
+        db_path = tmp_dir + "/test.db"
+        store = SqlAlchemySqlStoreImpl(SqliteSqlStoreConfig(db_path=db_path))
+
+        await store.create_table(
+            "batch_test",
+            {
+                "id": ColumnType.INTEGER,
+                "name": ColumnType.STRING,
+                "value": ColumnType.INTEGER,
+            },
+        )
+
+        batch_data = [
+            {"id": 1, "name": "first", "value": 10},
+            {"id": 2, "name": "second", "value": 20},
+            {"id": 3, "name": "third", "value": 30},
+        ]
+
+        await store.insert("batch_test", batch_data)
+
+        result = await store.fetch_all("batch_test", order_by=[("id", "asc")])
+        assert result.data == batch_data
+
+
 async def test_where_operator_edge_cases():
     with TemporaryDirectory() as tmp_dir:
         db_path = tmp_dir + "/test.db"
uv.lock (generated, 4 lines changed)
@@ -1773,6 +1773,7 @@ dependencies = [
     { name = "python-jose", extra = ["cryptography"] },
     { name = "python-multipart" },
     { name = "rich" },
+    { name = "sqlalchemy", extra = ["asyncio"] },
     { name = "starlette" },
     { name = "termcolor" },
     { name = "tiktoken" },
@@ -1887,7 +1888,7 @@ requires-dist = [
     { name = "jsonschema" },
     { name = "llama-stack-client", specifier = ">=0.2.23" },
     { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.23" },
-    { name = "openai", specifier = ">=1.100.0" },
+    { name = "openai", specifier = ">=1.107" },
     { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
     { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
     { name = "pandas", marker = "extra == 'ui'" },
@@ -1898,6 +1899,7 @@ requires-dist = [
     { name = "python-jose", extras = ["cryptography"] },
     { name = "python-multipart", specifier = ">=0.0.20" },
     { name = "rich" },
+    { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },
     { name = "starlette" },
     { name = "streamlit", marker = "extra == 'ui'" },
     { name = "streamlit-option-menu", marker = "extra == 'ui'" },