fix: harden storage semantics

Fixes issues in the storage system by guaranteeing immediate durability for responses and ensuring background writers stay alive.

Responses to the OpenAI-compatible API now write directly to Postgres/SQLite inside the request instead of detouring through an async queue that might never drain; this restores the expected read-after-write behavior and removes the "response not found" races reported by users.
The access-control shim was stamping owner_principal/access_attributes as SQL NULL, which Postgres interprets as non-public rows; fixing it to use the empty-string/JSON-null pattern means conversations and responses stored without an authenticated user stay queryable (matching SQLite).

The inference-store queue remains for batching, but its worker tasks now start lazily on the live event loop so server startup doesn't cancel them—writes keep flowing even when the stack is launched via llama stack run.
This commit is contained in:
Ashwin Bharambe 2025-11-10 13:59:53 -08:00
parent 4341c4c2ac
commit 8e18584c67
3 changed files with 58 additions and 83 deletions

View file

@ -67,11 +67,8 @@ class InferenceStore:
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
logger.debug(
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
f"Inference store write queue configured for {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
@ -94,10 +91,29 @@ class InferenceStore:
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def _ensure_workers_started(self) -> None:
"""Ensure the async write queue workers run on the current loop."""
if not self.enable_write_queue:
return
if self._queue is None:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
logger.debug(
f"Inference store write queue created with max size {self._max_write_queue_size} "
f"and {self._num_writers} writers"
)
if not self._worker_tasks:
loop = asyncio.get_running_loop()
for _ in range(self._num_writers):
task = loop.create_task(self._worker_loop())
self._worker_tasks.append(task)
async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.enable_write_queue:
await self._ensure_workers_started()
if self._queue is None:
raise ValueError("Inference store is not initialized")
try:

View file

@ -3,8 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.agents import (
Order,
@ -19,12 +17,12 @@ from llama_stack.apis.agents.openai_responses import (
)
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference, StorageBackendType
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqlStoreReference
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import _SQLSTORE_BACKENDS, sqlstore_impl
from ..sqlstore.sqlstore import sqlstore_impl
logger = get_logger(name=__name__, category="openai_responses")
@ -55,28 +53,12 @@ class ResponsesStore:
self.policy = policy
self.sql_store = None
self.enable_write_queue = True
# Async write queue and worker control
self._queue: (
asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput], list[OpenAIMessageParam]]] | None
) = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = self.reference.max_write_queue_size
self._num_writers: int = max(1, self.reference.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"openai_responses",
{
@ -95,33 +77,12 @@ class ResponsesStore:
},
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
logger.debug(
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
return
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
"""Maintained for compatibility; no-op now that writes are synchronous."""
return
async def store_response_object(
self,
@ -129,31 +90,8 @@ class ResponsesStore:
input: list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None:
if self.enable_write_queue:
if self._queue is None:
raise ValueError("Responses store is not initialized")
try:
self._queue.put_nowait((response_object, input, messages))
except asyncio.QueueFull:
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
await self._queue.put((response_object, input, messages))
else:
await self._write_response_object(response_object, input, messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
response_object, input, messages = item
try:
await self._write_response_object(response_object, input, messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing response object: {e}")
finally:
self._queue.task_done()
logger.info(f"💾 Writing response id={response_object.id}")
await self._write_response_object(response_object, input, messages)
async def _write_response_object(
self,
@ -312,6 +250,8 @@ class ResponsesStore:
if not self.sql_store:
raise ValueError("Responses store is not initialized")
logger.info(f"💬 Storing {len(messages)} messages for conversation {conversation_id}")
# Serialize messages to dict format for JSON storage
messages_data = [msg.model_dump() for msg in messages]
@ -321,13 +261,16 @@ class ResponsesStore:
table="conversation_messages",
data={"conversation_id": conversation_id, "messages": messages_data},
)
except Exception:
logger.info(f"✅ Inserted conversation messages for {conversation_id}")
except Exception as e:
logger.info(f"🔄 Insert failed, trying update for {conversation_id}: {e}")
# If insert fails due to ID conflict, update existing record
await self.sql_store.update(
table="conversation_messages",
data={"messages": messages_data},
where={"conversation_id": conversation_id},
)
logger.info(f"✅ Updated conversation messages for {conversation_id}")
logger.debug(f"Stored {len(messages)} messages for conversation {conversation_id}")

View file

@ -45,8 +45,13 @@ def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: Use
enhanced["owner_principal"] = current_user.principal
enhanced["access_attributes"] = current_user.attributes
else:
enhanced["owner_principal"] = None
enhanced["access_attributes"] = None
# IMPORTANT: Use empty string and null value (not None) to match public access filter
# The public access filter in _get_public_access_conditions() expects:
# - owner_principal = '' (empty string)
# - access_attributes = null (JSON null, which serializes to the string 'null')
# Setting them to None (SQL NULL) will cause rows to be filtered out on read.
enhanced["owner_principal"] = ""
enhanced["access_attributes"] = None # Pydantic/JSON will serialize this as JSON null
return enhanced
@ -188,8 +193,9 @@ class AuthorizedSqlStore:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes
else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None
# IMPORTANT: Use empty string for owner_principal to match public access filter
enhanced_data["owner_principal"] = ""
enhanced_data["access_attributes"] = None # Will serialize as JSON null
await self.sql_store.update(table, enhanced_data, where)
@ -245,14 +251,24 @@ class AuthorizedSqlStore:
raise ValueError(f"Unsupported database type: {self.database_type}")
def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
"""Get the SQL conditions for public access.
Public records are those with:
- owner_principal = '' (empty string)
- access_attributes is either SQL NULL or JSON null
Note: Different databases serialize None differently:
- SQLite: None JSON null (text = 'null')
- Postgres: None SQL NULL (IS NULL)
"""
conditions = ["owner_principal = ''"]
if self.database_type == StorageBackendType.SQL_POSTGRES.value:
# Postgres stores JSON null as 'null'
conditions.append("access_attributes::text = 'null'")
# Accept both SQL NULL and JSON null for Postgres compatibility
# This handles both old rows (SQL NULL) and new rows (JSON null)
conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
elif self.database_type == StorageBackendType.SQL_SQLITE.value:
conditions.append("access_attributes = 'null'")
# SQLite serializes None as JSON null
conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
else:
raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions