Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
chore: introduce write queue for inference_store (#3383)
# What does this PR do?
Adds a write worker queue for writes to the inference store. This avoids overwhelming request processing with slow inference writes.

## Test Plan
Benchmark:
```
cd /docs/source/distributions/k8s-benchmark
# start mock server
python openai-mock-server.py --port 8000
# start stack server
LLAMA_STACK_LOGGING="all=WARNING" uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml
# run benchmark script
uv run python3 benchmark.py --duration 120 --concurrent 50 --base-url=http://localhost:8321/v1/openai/v1 --model=vllm-inference/meta-llama/Llama-3.2-3B-Instruct
```
## RPS from 21 -> 57
parent e6edc1f934
commit e980436a2e
7 changed files with 139 additions and 22 deletions
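The change boils down to a bounded `asyncio.Queue` drained by a few background writer tasks, so the request path only enqueues the completed chat completion instead of awaiting the database write. Below is a minimal, self-contained sketch of that pattern, with hypothetical names and values; it is illustrative only and is not the `InferenceStore` code from the diffs that follow.

```python
# Minimal sketch of the write-queue pattern (hypothetical names, not llama-stack code).
import asyncio


async def fake_write(item: dict) -> None:
    # Stand-in for the real SQL insert that used to run on the request path.
    await asyncio.sleep(0.01)


async def writer_loop(queue: asyncio.Queue) -> None:
    # Background worker: drain the queue and persist each item.
    while True:
        item = await queue.get()
        try:
            await fake_write(item)
        finally:
            queue.task_done()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue(maxsize=10000)  # bounded, like max_write_queue_size
    workers = [asyncio.create_task(writer_loop(queue)) for _ in range(4)]  # like num_writers

    # Request path: enqueue and return immediately instead of awaiting the write.
    for i in range(10):
        queue.put_nowait({"id": f"chatcmpl-{i}"})

    await queue.join()  # drain everything (what InferenceStore.flush() does)
    for w in workers:
        w.cancel()


asyncio.run(main())
```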
```diff
@@ -58,14 +58,6 @@ class BenchmarkStats:
         print(f"\n{'='*60}")
         print(f"BENCHMARK RESULTS")
-        print(f"{'='*60}")
-        print(f"Total time: {total_time:.2f}s")
-        print(f"Concurrent users: {self.concurrent_users}")
-        print(f"Total requests: {self.total_requests}")
-        print(f"Successful requests: {self.success_count}")
-        print(f"Failed requests: {len(self.errors)}")
-        print(f"Success rate: {success_rate:.1f}%")
-        print(f"Requests per second: {self.success_count / total_time:.2f}")
 
         print(f"\nResponse Time Statistics:")
         print(f"  Mean: {statistics.mean(self.response_times):.3f}s")
@@ -106,6 +98,15 @@ class BenchmarkStats:
         print(f"  Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
         print(f"  Total chunks received: {sum(self.chunks_received)}")
 
+        print(f"{'='*60}")
+        print(f"Total time: {total_time:.2f}s")
+        print(f"Concurrent users: {self.concurrent_users}")
+        print(f"Total requests: {self.total_requests}")
+        print(f"Successful requests: {self.success_count}")
+        print(f"Failed requests: {len(self.errors)}")
+        print(f"Success rate: {success_rate:.1f}%")
+        print(f"Requests per second: {self.success_count / total_time:.2f}")
+
         if self.errors:
             print(f"\nErrors (showing first 5):")
             for error in self.errors[:5]:
@@ -215,7 +216,7 @@ class LlamaStackBenchmark:
                 await asyncio.sleep(1)  # Report every second
                 if time.time() >= last_report_time + 10:  # Report every 10 seconds
                     elapsed = time.time() - stats.start_time
-                    print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
+                    print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}")
                     last_report_time = time.time()
             except asyncio.CancelledError:
                 break
```
```diff
@@ -2,6 +2,7 @@ version: '2'
 image_name: kubernetes-benchmark-demo
 apis:
 - agents
+- files
 - inference
 - files
 - safety
@@ -20,6 +21,14 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   vector_io:
   - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
     provider_type: remote::chromadb
```
```diff
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
     )
 
 
+class InferenceStoreConfig(BaseModel):
+    sql_store_config: SqlStoreConfig
+    max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+    num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
 class StackRunConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION
 
@@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If not specified,
 a default SQLite store will be used.""",
     )
 
-    inference_store: SqlStoreConfig | None = Field(
+    inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
         default=None,
         description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
+Configuration for the persistence store used by the inference API. Can be either a
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
     )
 
     # registry of "resources" in the distribution
```
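With the new `InferenceStoreConfig`, the `inference_store` entry in a run config can still be a bare `SqlStoreConfig` (now deprecated), or the richer config when the write queue needs tuning. Here is a small sketch of building it programmatically; the `SqliteSqlStoreConfig` import path is inferred from the relative import in `inference_store.py` further down and is an assumption, as is the db_path value.

```python
from llama_stack.core.datatypes import InferenceStoreConfig
# Import path assumed from `from ..sqlstore.sqlstore import ...` in inference_store.py.
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

store_config = InferenceStoreConfig(
    sql_store_config=SqliteSqlStoreConfig(db_path="/tmp/inference_store.db"),  # hypothetical path
    max_write_queue_size=10000,  # default from the Field() definition above
    num_writers=4,               # default from the Field() definition above
)
# Note: for SQLite the store disables the write queue anyway (see inference_store.py below),
# so the queue parameters only matter for non-SQLite backends.
```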
```diff
@@ -78,7 +78,10 @@ async def get_auto_router_impl(
 
     # TODO: move pass configs to routers instead
     if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(run_config.inference_store, policy)
+        inference_store = InferenceStore(
+            config=run_config.inference_store,
+            policy=policy,
+        )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
 
```
```diff
@@ -90,6 +90,11 @@ class InferenceRouter(Inference):
 
     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")
 
     async def register_model(
         self,
```
```diff
@@ -3,6 +3,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
+from typing import Any
+
 from llama_stack.apis.inference import (
     ListOpenAIChatCompletionResponse,
     OpenAIChatCompletion,
@@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     Order,
 )
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
+from llama_stack.log import get_logger
 
 from ..sqlstore.api import ColumnDefinition, ColumnType
 from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
+from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
 
+logger = get_logger(name=__name__, category="inference_store")
+
 
 class InferenceStore:
-    def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
-        if not sql_store_config:
-            sql_store_config = SqliteSqlStoreConfig(
-                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
+    def __init__(
+        self,
+        config: InferenceStoreConfig | SqlStoreConfig,
+        policy: list[AccessRule],
+    ):
+        # Handle backward compatibility
+        if not isinstance(config, InferenceStoreConfig):
+            # Legacy: SqlStoreConfig passed directly as config
+            config = InferenceStoreConfig(
+                sql_store_config=config,
             )
-        self.sql_store_config = sql_store_config
+
+        self.config = config
+        self.sql_store_config = config.sql_store_config
         self.sql_store = None
         self.policy = policy
+
+        # Disable write queue for SQLite to avoid concurrency issues
+        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
+
+        # Async write queue and worker control
+        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+        self._worker_tasks: list[asyncio.Task[Any]] = []
+        self._max_write_queue_size: int = config.max_write_queue_size
+        self._num_writers: int = max(1, config.num_writers)
 
     async def initialize(self):
         """Create the necessary tables if they don't exist."""
         self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@@ -42,10 +64,68 @@ class InferenceStore:
             },
         )
 
+        if self.enable_write_queue:
+            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+            for _ in range(self._num_writers):
+                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
+        else:
+            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+
+    async def shutdown(self) -> None:
+        if not self._worker_tasks:
+            return
+        if self._queue is not None:
+            await self._queue.join()
+        for t in self._worker_tasks:
+            if not t.done():
+                t.cancel()
+        for t in self._worker_tasks:
+            try:
+                await t
+            except asyncio.CancelledError:
+                pass
+        self._worker_tasks.clear()
+
+    async def flush(self) -> None:
+        """Wait for all queued writes to complete. Useful for testing."""
+        if self.enable_write_queue and self._queue is not None:
+            await self._queue.join()
+
     async def store_chat_completion(
         self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
     ) -> None:
-        if not self.sql_store:
+        if self.enable_write_queue:
+            if self._queue is None:
+                raise ValueError("Inference store is not initialized")
+            try:
+                self._queue.put_nowait((chat_completion, input_messages))
+            except asyncio.QueueFull:
+                logger.warning(
+                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
+                )
+                await self._queue.put((chat_completion, input_messages))
+        else:
+            await self._write_chat_completion(chat_completion, input_messages)
+
+    async def _worker_loop(self) -> None:
+        assert self._queue is not None
+        while True:
+            try:
+                item = await self._queue.get()
+            except asyncio.CancelledError:
+                break
+            chat_completion, input_messages = item
+            try:
+                await self._write_chat_completion(chat_completion, input_messages)
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Error writing chat completion: {e}")
+            finally:
+                self._queue.task_done()
+
+    async def _write_chat_completion(
+        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+    ) -> None:
+        if self.sql_store is None:
             raise ValueError("Inference store is not initialized")
 
         data = chat_completion.model_dump()
```
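For callers, the visible behavior change is that `store_chat_completion()` normally just enqueues the write, so anything that needs read-after-write (tests in particular) should call `flush()` first. A rough usage sketch, assuming an already-initialized store and pre-built completion/message objects:

```python
from llama_stack.apis.inference import Order  # same import used by inference_store.py


async def save_and_list(store, completion, messages):
    # `store` is an initialized InferenceStore; `completion`/`messages` are assumed inputs.
    await store.store_chat_completion(completion, messages)  # enqueued (written inline for SQLite)
    await store.flush()  # wait for the background writers to drain the queue
    page = await store.list_chat_completions(limit=10, order=Order.desc)
    await store.shutdown()  # drain remaining writes and stop the worker tasks
    return page
```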
```diff
@@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test 1: First page with limit=2, descending order (default)
     result = await store.list_chat_completions(limit=2, order=Order.desc)
     assert len(result.data) == 2
@@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test ascending order pagination
     result = await store.list_chat_completions(limit=1, order=Order.asc)
     assert len(result.data) == 1
@@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test pagination with model filter
     result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
     assert len(result.data) == 1
@@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test without limit
     result = await store.list_chat_completions(order=Order.desc)
     assert len(result.data) == 2
```