chore: introduce write queue for inference_store (#3383)

# What does this PR do?
Adds a write worker queue for writes to inference store. This avoids
overwhelming request processing with slow inference writes.

## Test Plan

Benchmark:
```
cd /docs/source/distributions/k8s-benchmark
# start mock server
python openai-mock-server.py --port 8000
# start stack server
LLAMA_STACK_LOGGING="all=WARNING" uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml
# run benchmark script
uv run python3 benchmark.py --duration 120 --concurrent 50 --base-url=http://localhost:8321/v1/openai/v1 --model=vllm-inference/meta-llama/Llama-3.2-3B-Instruct
```
## RPS from 21 -> 57
This commit is contained in:
ehhuang 2025-09-10 11:57:42 -07:00 committed by GitHub
parent e6edc1f934
commit e980436a2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 139 additions and 22 deletions

View file

@ -58,14 +58,6 @@ class BenchmarkStats:
print(f"\n{'='*60}")
print(f"BENCHMARK RESULTS")
print(f"{'='*60}")
print(f"Total time: {total_time:.2f}s")
print(f"Concurrent users: {self.concurrent_users}")
print(f"Total requests: {self.total_requests}")
print(f"Successful requests: {self.success_count}")
print(f"Failed requests: {len(self.errors)}")
print(f"Success rate: {success_rate:.1f}%")
print(f"Requests per second: {self.success_count / total_time:.2f}")
print(f"\nResponse Time Statistics:")
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
@ -106,6 +98,15 @@ class BenchmarkStats:
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
print(f" Total chunks received: {sum(self.chunks_received)}")
print(f"{'='*60}")
print(f"Total time: {total_time:.2f}s")
print(f"Concurrent users: {self.concurrent_users}")
print(f"Total requests: {self.total_requests}")
print(f"Successful requests: {self.success_count}")
print(f"Failed requests: {len(self.errors)}")
print(f"Success rate: {success_rate:.1f}%")
print(f"Requests per second: {self.success_count / total_time:.2f}")
if self.errors:
print(f"\nErrors (showing first 5):")
for error in self.errors[:5]:
@ -215,7 +216,7 @@ class LlamaStackBenchmark:
await asyncio.sleep(1) # Report every second
if time.time() >= last_report_time + 10: # Report every 10 seconds
elapsed = time.time() - stats.start_time
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}")
last_report_time = time.time()
except asyncio.CancelledError:
break

View file

@ -2,6 +2,7 @@ version: '2'
image_name: kubernetes-benchmark-demo
apis:
- agents
- files
- inference
- files
- safety
@ -20,6 +21,14 @@ providers:
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb

View file

@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
)
class InferenceStoreConfig(BaseModel):
sql_store_config: SqlStoreConfig
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""",
)
inference_store: SqlStoreConfig | None = Field(
inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the inference API. If not specified,
a default SQLite store will be used.""",
Configuration for the persistence store used by the inference API. Can be either a
InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution

View file

@ -78,7 +78,10 @@ async def get_auto_router_impl(
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store, policy)
inference_store = InferenceStore(
config=run_config.inference_store,
policy=policy,
)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store

View file

@ -90,6 +90,11 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
if self.store:
try:
await self.store.shutdown()
except Exception as e:
logger.warning(f"Error during InferenceStore shutdown: {e}")
async def register_model(
self,

View file

@ -3,6 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletion,
@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
logger = get_logger(name=__name__, category="inference_store")
class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
if not sql_store_config:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
def __init__(
self,
config: InferenceStoreConfig | SqlStoreConfig,
policy: list[AccessRule],
):
# Handle backward compatibility
if not isinstance(config, InferenceStoreConfig):
# Legacy: SqlStoreConfig passed directly as config
config = InferenceStoreConfig(
sql_store_config=config,
)
self.sql_store_config = sql_store_config
self.config = config
self.sql_store_config = config.sql_store_config
self.sql_store = None
self.policy = policy
# Disable write queue for SQLite to avoid concurrency issues
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = config.max_write_queue_size
self._num_writers: int = max(1, config.num_writers)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@ -42,10 +64,68 @@ class InferenceStore:
},
)
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if not self.sql_store:
if self.enable_write_queue:
if self._queue is None:
raise ValueError("Inference store is not initialized")
try:
self._queue.put_nowait((chat_completion, input_messages))
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
await self._queue.put((chat_completion, input_messages))
else:
await self._write_chat_completion(chat_completion, input_messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
chat_completion, input_messages = item
try:
await self._write_chat_completion(chat_completion, input_messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally:
self._queue.task_done()
async def _write_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.sql_store is None:
raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump()

View file

@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2