Mirror of https://github.com/meta-llama/llama-stack.git
# What does this PR do?

Closes #3268, closes #3498.

When resuming from a previous response ID, we currently attempt to convert the stored responses input back into chat completion messages. This conversion is not always possible: for tool calls, some data is lost once a chat completion message has been converted into the responses input format. This PR instead stores the chat completion messages corresponding to the _last_ call to chat completion, which is sufficient to resume from in the next responses API call; there, we load these saved messages and skip conversion entirely. A sketch of the resume path follows below.

Separate issue to optimize storage: https://github.com/llamastack/llama-stack/issues/3646

## Test Plan

Existing CI tests.
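
A minimal sketch of the resume path this change enables (the function and the helper name `convert_response_input_to_messages` are illustrative assumptions, not the actual implementation):

```python
from llama_stack.apis.inference import OpenAIMessageParam


async def load_conversation_context(store, previous_response_id: str) -> list[OpenAIMessageParam]:
    """Reuse stored chat completion messages when available; fall back to lossy conversion."""
    stored = await store.get_response_object(previous_response_id)
    if stored.messages is not None:
        # New path: the exact messages from the last chat completion call were
        # persisted alongside the response, so no conversion is needed.
        return stored.messages
    # Fallback for responses stored before this change (hypothetical helper):
    return convert_response_input_to_messages(stored.input)
```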
296 lines · 11 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any

from llama_stack.apis.agents import (
    Order,
)
from llama_stack.apis.agents.openai_responses import (
    ListOpenAIResponseInputItem,
    ListOpenAIResponseObject,
    OpenAIDeleteResponseObject,
    OpenAIResponseInput,
    OpenAIResponseObject,
    OpenAIResponseObjectWithInput,
)
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.log import get_logger

from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl

logger = get_logger(name=__name__, category="openai_responses")

class _OpenAIResponseObjectWithInputAndMessages(OpenAIResponseObjectWithInput):
    """Internal class for storing responses with chat completion messages.

    This extends the public OpenAIResponseObjectWithInput with a messages field
    for internal storage. The messages field is not exposed in the public API.

    The messages field is optional for backward compatibility with responses
    stored before this feature was added.
    """

    messages: list[OpenAIMessageParam] | None = None


class ResponsesStore:
    def __init__(
        self,
        config: ResponsesStoreConfig | SqlStoreConfig,
        policy: list[AccessRule],
    ):
        # Handle backward compatibility
        if not isinstance(config, ResponsesStoreConfig):
            # Legacy: SqlStoreConfig passed directly as config
            config = ResponsesStoreConfig(
                sql_store_config=config,
            )

        self.config = config
        self.sql_store_config = config.sql_store_config
        if not self.sql_store_config:
            self.sql_store_config = SqliteSqlStoreConfig(
                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
            )
        self.sql_store = None
        self.policy = policy

        # Disable write queue for SQLite to avoid concurrency issues
        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite

        # Async write queue and worker control
        self._queue: (
            asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
        ) = None
        self._worker_tasks: list[asyncio.Task[Any]] = []
        self._max_write_queue_size: int = config.max_write_queue_size
        self._num_writers: int = max(1, config.num_writers)

    async def initialize(self):
        """Create the necessary tables if they don't exist."""
        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
        await self.sql_store.create_table(
            "openai_responses",
            {
                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                "created_at": ColumnType.INTEGER,
                "response_object": ColumnType.JSON,
                "model": ColumnType.STRING,
            },
        )

        if self.enable_write_queue:
            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
            for _ in range(self._num_writers):
                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
        else:
            logger.info("Write queue disabled for SQLite to avoid concurrency issues")

    async def shutdown(self) -> None:
        """Drain any queued writes, then cancel and await the worker tasks."""
        if not self._worker_tasks:
            return
        if self._queue is not None:
            await self._queue.join()
        for t in self._worker_tasks:
            if not t.done():
                t.cancel()
        for t in self._worker_tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass
        self._worker_tasks.clear()

    async def flush(self) -> None:
        """Wait for all queued writes to complete. Useful for testing."""
        if self.enable_write_queue and self._queue is not None:
            await self._queue.join()

    async def store_response_object(
        self,
        response_object: OpenAIResponseObject,
        input: list[OpenAIResponseInput],
        messages: list[OpenAIMessageParam],
    ) -> None:
        if self.enable_write_queue:
            if self._queue is None:
                raise ValueError("Responses store is not initialized")
            try:
                self._queue.put_nowait((response_object, input, messages))
            except asyncio.QueueFull:
                # Queue is full: log, then block until a worker frees a slot,
                # applying backpressure rather than dropping the write.
                logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
                await self._queue.put((response_object, input, messages))
        else:
            await self._write_response_object(response_object, input, messages)

    async def _worker_loop(self) -> None:
        assert self._queue is not None
        while True:
            try:
                item = await self._queue.get()
            except asyncio.CancelledError:
                break
            response_object, input, messages = item
            try:
                await self._write_response_object(response_object, input, messages)
            except Exception as e:  # noqa: BLE001
                logger.error(f"Error writing response object: {e}")
            finally:
                self._queue.task_done()

    async def _write_response_object(
        self,
        response_object: OpenAIResponseObject,
        input: list[OpenAIResponseInput],
        messages: list[OpenAIMessageParam],
    ) -> None:
        if self.sql_store is None:
            raise ValueError("Responses store is not initialized")

        data = response_object.model_dump()
        data["input"] = [input_item.model_dump() for input_item in input]
        data["messages"] = [msg.model_dump() for msg in messages]

        await self.sql_store.insert(
            "openai_responses",
            {
                "id": data["id"],
                "created_at": data["created_at"],
                "model": data["model"],
                "response_object": data,
            },
        )
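
    # Illustrative shape of a persisted row, for readability; the field values
    # below are invented examples, not taken from a real database:
    #   {
    #       "id": "resp_abc123",          # hypothetical response id
    #       "created_at": 1700000000,
    #       "model": "llama-3.3-70b",     # hypothetical model name
    #       "response_object": {..., "input": [...], "messages": [...]},
    #   }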

    async def list_responses(
        self,
        after: str | None = None,
        limit: int | None = 50,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseObject:
        """
        List responses from the database.

        :param after: The ID of the last response to return.
        :param limit: The maximum number of responses to return.
        :param model: The model to filter by.
        :param order: The order to sort the responses by.
        """
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        if not order:
            order = Order.desc

        where_conditions = {}
        if model:
            where_conditions["model"] = model

        paginated_result = await self.sql_store.fetch_all(
            table="openai_responses",
            where=where_conditions if where_conditions else None,
            order_by=[("created_at", order.value)],
            cursor=("id", after) if after else None,
            limit=limit,
        )

        data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
        return ListOpenAIResponseObject(
            data=data,
            has_more=paginated_result.has_more,
            first_id=data[0].id if data else "",
            last_id=data[-1].id if data else "",
        )
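
    # A minimal pagination sketch using the method above (assumption: `store`
    # is an initialized ResponsesStore; this loop is illustrative only):
    #   page = await store.list_responses(limit=50)
    #   while page.has_more:
    #       page = await store.list_responses(after=page.last_id, limit=50)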

    async def get_response_object(self, response_id: str) -> _OpenAIResponseObjectWithInputAndMessages:
        """
        Get a response object with automatic access control checking.
        """
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        row = await self.sql_store.fetch_one(
            "openai_responses",
            where={"id": response_id},
        )

        if not row:
            # SecureSqlStore will return None if record doesn't exist OR access is denied
            # This provides security by not revealing whether the record exists
            raise ValueError(f"Response with id {response_id} not found") from None

        return _OpenAIResponseObjectWithInputAndMessages(**row["response_object"])

    async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
        if not self.sql_store:
            raise ValueError("Responses store is not initialized")

        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
        if not row:
            raise ValueError(f"Response with id {response_id} not found")
        await self.sql_store.delete("openai_responses", where={"id": response_id})
        return OpenAIDeleteResponseObject(id=response_id)

    async def list_response_input_items(
        self,
        response_id: str,
        after: str | None = None,
        before: str | None = None,
        include: list[str] | None = None,
        limit: int | None = 20,
        order: Order | None = Order.desc,
    ) -> ListOpenAIResponseInputItem:
        """
        List input items for a given response.

        :param response_id: The ID of the response to retrieve input items for.
        :param after: An item ID to list items after, used for pagination.
        :param before: An item ID to list items before, used for pagination.
        :param include: Additional fields to include in the response.
        :param limit: A limit on the number of objects to be returned.
        :param order: The order to return the input items in.
        """
        if include:
            raise NotImplementedError("Include is not supported yet")
        if before and after:
            raise ValueError("Cannot specify both 'before' and 'after' parameters")

        response_with_input_and_messages = await self.get_response_object(response_id)
        items = response_with_input_and_messages.input

        if order == Order.desc:
            items = list(reversed(items))

        start_index = 0
        end_index = len(items)

        if after or before:
            for i, item in enumerate(items):
                item_id = getattr(item, "id", None)
                if after and item_id == after:
                    start_index = i + 1
                if before and item_id == before:
                    end_index = i
                    break

            if after and start_index == 0:
                raise ValueError(f"Input item with id '{after}' not found for response '{response_id}'")
            if before and end_index == len(items):
                raise ValueError(f"Input item with id '{before}' not found for response '{response_id}'")

        items = items[start_index:end_index]

        # Apply limit
        if limit is not None:
            items = items[:limit]

        return ListOpenAIResponseInputItem(data=items)
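
A hedged end-to-end usage sketch of `ResponsesStore`. The import paths, database path, and empty access policy below are illustrative assumptions inferred from the relative imports in the module above, not confirmed against the repository:

```python
import asyncio

# Assumed import paths; verify against your checkout.
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.responses.responses_store import ResponsesStore


async def main() -> None:
    store = ResponsesStore(
        ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="/tmp/responses.db")),
        policy=[],  # assumption: an empty access policy is acceptable here
    )
    await store.initialize()

    # After each chat completion turn, persist the response together with the
    # raw chat completion messages so a later turn can resume without lossy
    # conversion (the point of this PR):
    #   await store.store_response_object(response, input_items, messages)

    await store.flush()  # waits for queued writes; a no-op for SQLite, which writes inline
    page = await store.list_responses(limit=10)
    print([r.id for r in page.data])
    await store.shutdown()


asyncio.run(main())
```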