chore: introduce write queue for inference_store

# What does this PR do?
Adds a write worker queue for writes to inference store. This avoids overwhelming request processing with slow inference writes.

## Test Plan

Benchmark:
```
cd /docs/source/distributions/k8s-benchmark
# start mock server
python openai-mock-server.py --port 8000
# start stack server
uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml
# run benchmark script
uv run python3 benchmark.py --duration 120 --concurrent 50 --base-url=http://localhost:8321/v1/openai/v1 --model=vllm-inference/meta-llama/Llama-3.2-3B-Instruct
```


Before:

============================================================
BENCHMARK RESULTS

Response Time Statistics:
  Mean: 1.111s
  Median: 0.982s
  Min: 0.466s
  Max: 15.190s
  Std Dev: 1.091s

Percentiles:
  P50: 0.982s
  P90: 1.281s
  P95: 1.439s
  P99: 5.476s

Time to First Token (TTFT) Statistics:
  Mean: 0.474s
  Median: 0.347s
  Min: 0.175s
  Max: 15.129s
  Std Dev: 0.819s

TTFT Percentiles:
  P50: 0.347s
  P90: 0.661s
  P95: 0.762s
  P99: 2.788s

Streaming Statistics:
  Mean chunks per response: 67.2
  Total chunks received: 122154
============================================================
Total time: 120.00s
Concurrent users: 50
Total requests: 1919
Successful requests: 1819
Failed requests: 100
Success rate: 94.8%
Requests per second: 15.16

Errors (showing first 5):
  Request error:
  Request error:
  Request error:
  Request error:
  Request error:
Benchmark completed.
Stopping server (PID: 679)...
Server stopped.


After:

============================================================
BENCHMARK RESULTS

Response Time Statistics:
  Mean: 1.085s
  Median: 1.089s
  Min: 0.451s
  Max: 2.002s
  Std Dev: 0.212s

Percentiles:
  P50: 1.089s
  P90: 1.343s
  P95: 1.409s
  P99: 1.617s

Time to First Token (TTFT) Statistics:
  Mean: 0.407s
  Median: 0.361s
  Min: 0.182s
  Max: 1.178s
  Std Dev: 0.175s

TTFT Percentiles:
  P50: 0.361s
  P90: 0.644s
  P95: 0.744s
  P99: 0.932s

Streaming Statistics:
  Mean chunks per response: 66.8
  Total chunks received: 367240
============================================================
Total time: 120.00s
Concurrent users: 50
Total requests: 5495
Successful requests: 5495
Failed requests: 0
Success rate: 100.0%
Requests per second: 45.79
Benchmark completed.
Stopping server (PID: 97169)...
Server stopped.
This commit is contained in:
Eric Huang 2025-09-10 11:43:26 -07:00
parent 935b8e28de
commit e721ca9730
7 changed files with 139 additions and 22 deletions

View file

@ -3,6 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletion,
@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
logger = get_logger(name=__name__, category="inference_store")
class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
def __init__(
self,
config: InferenceStoreConfig | SqlStoreConfig,
policy: list[AccessRule],
):
# Handle backward compatibility
if not isinstance(config, InferenceStoreConfig):
# Legacy: SqlStoreConfig passed directly as config
config = InferenceStoreConfig(
sql_store_config=config,
)
self.sql_store_config = sql_store_config
self.config = config
self.sql_store_config = config.sql_store_config
self.sql_store = None
self.policy = policy
# Disable write queue for SQLite to avoid concurrency issues
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = config.max_write_queue_size
self._num_writers: int = max(1, config.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@ -42,10 +64,68 @@ class InferenceStore:
},
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if not self.sql_store:
if self.enable_write_queue:
if self._queue is None:
raise ValueError("Inference store is not initialized")
try:
self._queue.put_nowait((chat_completion, input_messages))
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
await self._queue.put((chat_completion, input_messages))
else:
await self._write_chat_completion(chat_completion, input_messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
chat_completion, input_messages = item
try:
await self._write_chat_completion(chat_completion, input_messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally:
self._queue.task_done()
async def _write_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.sql_store is None:
raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump()