mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore: introduce write queue for response_store
# What does this PR do? ## Test Plan
This commit is contained in:
parent
d3600b92d1
commit
f0da887e79
3 changed files with 123 additions and 6 deletions
|
@ -433,6 +433,12 @@ class InferenceStoreConfig(BaseModel):
|
||||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||||
|
|
||||||
|
|
||||||
|
class ResponsesStoreConfig(BaseModel):
|
||||||
|
sql_store_config: SqlStoreConfig
|
||||||
|
max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store")
|
||||||
|
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||||
|
|
||||||
|
|
||||||
class StackRunConfig(BaseModel):
|
class StackRunConfig(BaseModel):
|
||||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,9 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
Order,
|
Order,
|
||||||
)
|
)
|
||||||
|
@ -14,25 +17,51 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
OpenAIResponseObjectWithInput,
|
OpenAIResponseObjectWithInput,
|
||||||
)
|
)
|
||||||
from llama_stack.core.datatypes import AccessRule
|
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
||||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="responses_store")
|
||||||
|
|
||||||
|
|
||||||
class ResponsesStore:
|
class ResponsesStore:
|
||||||
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
|
def __init__(
|
||||||
if not sql_store_config:
|
self,
|
||||||
sql_store_config = SqliteSqlStoreConfig(
|
config: ResponsesStoreConfig | SqlStoreConfig,
|
||||||
|
policy: list[AccessRule],
|
||||||
|
):
|
||||||
|
# Handle backward compatibility
|
||||||
|
if not isinstance(config, ResponsesStoreConfig):
|
||||||
|
# Legacy: SqlStoreConfig passed directly as config
|
||||||
|
config = ResponsesStoreConfig(
|
||||||
|
sql_store_config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.sql_store_config = config.sql_store_config
|
||||||
|
if not self.sql_store_config:
|
||||||
|
self.sql_store_config = SqliteSqlStoreConfig(
|
||||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||||
)
|
)
|
||||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
|
self.sql_store = None
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
|
|
||||||
|
# Disable write queue for SQLite to avoid concurrency issues
|
||||||
|
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||||
|
|
||||||
|
# Async write queue and worker control
|
||||||
|
self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
|
||||||
|
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||||
|
self._max_write_queue_size: int = config.max_write_queue_size
|
||||||
|
self._num_writers: int = max(1, config.num_writers)
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Create the necessary tables if they don't exist."""
|
"""Create the necessary tables if they don't exist."""
|
||||||
|
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"openai_responses",
|
"openai_responses",
|
||||||
{
|
{
|
||||||
|
@ -43,9 +72,70 @@ class ResponsesStore:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.enable_write_queue:
|
||||||
|
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||||
|
for _ in range(self._num_writers):
|
||||||
|
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||||
|
else:
|
||||||
|
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
if not self._worker_tasks:
|
||||||
|
return
|
||||||
|
if self._queue is not None:
|
||||||
|
await self._queue.join()
|
||||||
|
for t in self._worker_tasks:
|
||||||
|
if not t.done():
|
||||||
|
t.cancel()
|
||||||
|
for t in self._worker_tasks:
|
||||||
|
try:
|
||||||
|
await t
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._worker_tasks.clear()
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
"""Wait for all queued writes to complete. Useful for testing."""
|
||||||
|
if self.enable_write_queue and self._queue is not None:
|
||||||
|
await self._queue.join()
|
||||||
|
|
||||||
async def store_response_object(
|
async def store_response_object(
|
||||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if self.enable_write_queue:
|
||||||
|
if self._queue is None:
|
||||||
|
raise ValueError("Responses store is not initialized")
|
||||||
|
try:
|
||||||
|
self._queue.put_nowait((response_object, input))
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.warning(
|
||||||
|
f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}"
|
||||||
|
)
|
||||||
|
await self._queue.put((response_object, input))
|
||||||
|
else:
|
||||||
|
await self._write_response_object(response_object, input)
|
||||||
|
|
||||||
|
async def _worker_loop(self) -> None:
|
||||||
|
assert self._queue is not None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
item = await self._queue.get()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
response_object, input = item
|
||||||
|
try:
|
||||||
|
await self._write_response_object(response_object, input)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
logger.error(f"Error writing response object: {e}")
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _write_response_object(
|
||||||
|
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||||
|
) -> None:
|
||||||
|
if self.sql_store is None:
|
||||||
|
raise ValueError("Responses store is not initialized")
|
||||||
|
|
||||||
data = response_object.model_dump()
|
data = response_object.model_dump()
|
||||||
data["input"] = [input_item.model_dump() for input_item in input]
|
data["input"] = [input_item.model_dump() for input_item in input]
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,9 @@ async def test_responses_store_pagination_basic():
|
||||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Test 1: First page with limit=2, descending order (default)
|
# Test 1: First page with limit=2, descending order (default)
|
||||||
result = await store.list_responses(limit=2, order=Order.desc)
|
result = await store.list_responses(limit=2, order=Order.desc)
|
||||||
assert len(result.data) == 2
|
assert len(result.data) == 2
|
||||||
|
@ -110,6 +113,9 @@ async def test_responses_store_pagination_ascending():
|
||||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Test ascending order pagination
|
# Test ascending order pagination
|
||||||
result = await store.list_responses(limit=1, order=Order.asc)
|
result = await store.list_responses(limit=1, order=Order.asc)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
|
@ -145,6 +151,9 @@ async def test_responses_store_pagination_with_model_filter():
|
||||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Test pagination with model filter
|
# Test pagination with model filter
|
||||||
result = await store.list_responses(limit=1, model="model-a", order=Order.desc)
|
result = await store.list_responses(limit=1, model="model-a", order=Order.desc)
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
|
@ -192,6 +201,9 @@ async def test_responses_store_pagination_no_limit():
|
||||||
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
input_list = [create_test_response_input(f"Input for {response_id}", f"input-{response_id}")]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Test without limit (should use default of 50)
|
# Test without limit (should use default of 50)
|
||||||
result = await store.list_responses(order=Order.desc)
|
result = await store.list_responses(order=Order.desc)
|
||||||
assert len(result.data) == 2
|
assert len(result.data) == 2
|
||||||
|
@ -212,6 +224,9 @@ async def test_responses_store_get_response_object():
|
||||||
input_list = [create_test_response_input("Test input content", "input-test-resp")]
|
input_list = [create_test_response_input("Test input content", "input-test-resp")]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Retrieve the response
|
# Retrieve the response
|
||||||
retrieved = await store.get_response_object("test-resp")
|
retrieved = await store.get_response_object("test-resp")
|
||||||
assert retrieved.id == "test-resp"
|
assert retrieved.id == "test-resp"
|
||||||
|
@ -242,6 +257,9 @@ async def test_responses_store_input_items_pagination():
|
||||||
]
|
]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Verify all items are stored correctly with explicit IDs
|
# Verify all items are stored correctly with explicit IDs
|
||||||
all_items = await store.list_response_input_items("test-resp", order=Order.desc)
|
all_items = await store.list_response_input_items("test-resp", order=Order.desc)
|
||||||
assert len(all_items.data) == 5
|
assert len(all_items.data) == 5
|
||||||
|
@ -319,6 +337,9 @@ async def test_responses_store_input_items_before_pagination():
|
||||||
]
|
]
|
||||||
await store.store_response_object(response, input_list)
|
await store.store_response_object(response, input_list)
|
||||||
|
|
||||||
|
# Wait for all queued writes to complete
|
||||||
|
await store.flush()
|
||||||
|
|
||||||
# Test before pagination with descending order
|
# Test before pagination with descending order
|
||||||
# In desc order: [Fifth, Fourth, Third, Second, First]
|
# In desc order: [Fifth, Fourth, Third, Second, First]
|
||||||
# before="before-3" should return [Fifth, Fourth]
|
# before="before-3" should return [Fifth, Fourth]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue