diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py
index 3d0d18150..83ba9602a 100644
--- a/docs/source/distributions/k8s-benchmark/benchmark.py
+++ b/docs/source/distributions/k8s-benchmark/benchmark.py
@@ -58,14 +58,6 @@ class BenchmarkStats:
 
         print(f"\n{'='*60}")
         print(f"BENCHMARK RESULTS")
-        print(f"{'='*60}")
-        print(f"Total time: {total_time:.2f}s")
-        print(f"Concurrent users: {self.concurrent_users}")
-        print(f"Total requests: {self.total_requests}")
-        print(f"Successful requests: {self.success_count}")
-        print(f"Failed requests: {len(self.errors)}")
-        print(f"Success rate: {success_rate:.1f}%")
-        print(f"Requests per second: {self.success_count / total_time:.2f}")
 
         print(f"\nResponse Time Statistics:")
         print(f"  Mean: {statistics.mean(self.response_times):.3f}s")
@@ -106,6 +98,15 @@ class BenchmarkStats:
             print(f"  Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
             print(f"  Total chunks received: {sum(self.chunks_received)}")
 
+        print(f"{'='*60}")
+        print(f"Total time: {total_time:.2f}s")
+        print(f"Concurrent users: {self.concurrent_users}")
+        print(f"Total requests: {self.total_requests}")
+        print(f"Successful requests: {self.success_count}")
+        print(f"Failed requests: {len(self.errors)}")
+        print(f"Success rate: {success_rate:.1f}%")
+        print(f"Requests per second: {self.success_count / total_time:.2f}")
+
         if self.errors:
             print(f"\nErrors (showing first 5):")
             for error in self.errors[:5]:
@@ -215,7 +216,7 @@ class LlamaStackBenchmark:
                 await asyncio.sleep(1)  # Report every second
                 if time.time() >= last_report_time + 10:  # Report every 10 seconds
                     elapsed = time.time() - stats.start_time
-                    print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
+                    print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}")
                     last_report_time = time.time()
             except asyncio.CancelledError:
                 break
diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
index f8ff7811b..5a9e2ae4f 100644
--- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
+++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
@@ -2,6 +2,7 @@ version: '2'
 image_name: kubernetes-benchmark-demo
 apis:
 - agents
+- files
 - inference
 - files
 - safety
@@ -20,6 +21,14 @@ providers:
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   vector_io:
   - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
     provider_type: remote::chromadb
diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py
index 0f348b067..faaeefd01 100644
--- a/llama_stack/core/datatypes.py
+++ b/llama_stack/core/datatypes.py
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
     )
 
 
+class InferenceStoreConfig(BaseModel):
+    sql_store_config: SqlStoreConfig
+    max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+    num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
 class StackRunConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION
 
@@ -464,11 +470,12 @@ Configuration for the persistence store used by the
 distribution registry. If not specified, a default SQLite store will be used.""",
     )
 
-    inference_store: SqlStoreConfig | None = Field(
+    inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
         default=None,
         description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
+Configuration for the persistence store used by the inference API. Can be either a
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
     )
 
     # registry of "resources" in the distribution
diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py
index 1faace34a..f129f8ede 100644
--- a/llama_stack/core/routers/__init__.py
+++ b/llama_stack/core/routers/__init__.py
@@ -78,7 +78,10 @@ async def get_auto_router_impl(
 
     # TODO: move pass configs to routers instead
     if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(run_config.inference_store, policy)
+        inference_store = InferenceStore(
+            config=run_config.inference_store,
+            policy=policy,
+        )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store
 
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index 2ed2d0439..762d7073e 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -90,6 +90,11 @@ class InferenceRouter(Inference):
 
     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")
 
     async def register_model(
         self,
diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py
index 43006cfd5..8c69b1683 100644
--- a/llama_stack/providers/utils/inference/inference_store.py
+++ b/llama_stack/providers/utils/inference/inference_store.py
@@ -3,6 +3,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
+from typing import Any
+
 from llama_stack.apis.inference import (
     ListOpenAIChatCompletionResponse,
     OpenAIChatCompletion,
@@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     Order,
 )
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
+from llama_stack.log import get_logger
 
 from ..sqlstore.api import ColumnDefinition, ColumnType
 from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
+from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
+
+logger = get_logger(name=__name__, category="inference_store")
 
 
 class InferenceStore:
-    def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
-        if not sql_store_config:
-            sql_store_config = SqliteSqlStoreConfig(
-                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
+    def __init__(
+        self,
+        config: InferenceStoreConfig | SqlStoreConfig,
+        policy: list[AccessRule],
+    ):
+        # Handle backward compatibility
+        if not isinstance(config, InferenceStoreConfig):
+            # Legacy: SqlStoreConfig passed directly as config
+            config = InferenceStoreConfig(
+                sql_store_config=config,
             )
-        self.sql_store_config = sql_store_config
+
+        self.config = config
+        self.sql_store_config = config.sql_store_config
         self.sql_store = None
         self.policy = policy
 
+        # Disable write queue for SQLite to avoid concurrency issues
+        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
+
+        # Async write queue and worker control
+        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+        self._worker_tasks: list[asyncio.Task[Any]] = []
+        self._max_write_queue_size: int = config.max_write_queue_size
+        self._num_writers: int = max(1, config.num_writers)
+
     async def initialize(self):
         """Create the necessary tables if they don't exist."""
         self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@@ -42,10 +64,68 @@ class InferenceStore:
             },
         )
 
+        if self.enable_write_queue:
+            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+            for _ in range(self._num_writers):
+                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
+        else:
+            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+
+    async def shutdown(self) -> None:
+        if not self._worker_tasks:
+            return
+        if self._queue is not None:
+            await self._queue.join()
+        for t in self._worker_tasks:
+            if not t.done():
+                t.cancel()
+        for t in self._worker_tasks:
+            try:
+                await t
+            except asyncio.CancelledError:
+                pass
+        self._worker_tasks.clear()
+
+    async def flush(self) -> None:
+        """Wait for all queued writes to complete. Useful for testing."""
+        if self.enable_write_queue and self._queue is not None:
+            await self._queue.join()
+
     async def store_chat_completion(
         self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
     ) -> None:
-        if not self.sql_store:
+        if self.enable_write_queue:
+            if self._queue is None:
+                raise ValueError("Inference store is not initialized")
+            try:
+                self._queue.put_nowait((chat_completion, input_messages))
+            except asyncio.QueueFull:
+                logger.warning(
+                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '')}"
+                )
+                await self._queue.put((chat_completion, input_messages))
+        else:
+            await self._write_chat_completion(chat_completion, input_messages)
+
+    async def _worker_loop(self) -> None:
+        assert self._queue is not None
+        while True:
+            try:
+                item = await self._queue.get()
+            except asyncio.CancelledError:
+                break
+            chat_completion, input_messages = item
+            try:
+                await self._write_chat_completion(chat_completion, input_messages)
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Error writing chat completion: {e}")
+            finally:
+                self._queue.task_done()
+
+    async def _write_chat_completion(
+        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+    ) -> None:
+        if self.sql_store is None:
             raise ValueError("Inference store is not initialized")
 
         data = chat_completion.model_dump()
diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py
index 730f54a05..f6d63490a 100644
--- a/tests/unit/utils/inference/test_inference_store.py
+++ b/tests/unit/utils/inference/test_inference_store.py
@@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test 1: First page with limit=2, descending order (default)
     result = await store.list_chat_completions(limit=2, order=Order.desc)
     assert len(result.data) == 2
@@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test ascending order pagination
     result = await store.list_chat_completions(limit=1, order=Order.asc)
     assert len(result.data) == 1
@@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test pagination with model filter
     result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
     assert len(result.data) == 1
@@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test without limit
     result = await store.list_chat_completions(order=Order.desc)
     assert len(result.data) == 2
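
Note: the sketch below is an illustrative, self-contained rendering of the bounded write-queue pattern the InferenceStore changes above rely on; it is not llama_stack code. The names AsyncWriteQueue, Record, and write_to_db are hypothetical placeholders, and only the asyncio standard library is used. put_nowait() keeps the caller's fast path non-blocking until the queue is full, the awaited put() then applies backpressure, and join()/task_done() give flush() and shutdown() a completion barrier, mirroring how the patch drains the queue before cancelling workers.

# Hypothetical names throughout; only stdlib asyncio is real.
import asyncio
from dataclasses import dataclass


@dataclass
class Record:
    payload: str


async def write_to_db(record: Record) -> None:
    # Stand-in for the real database insert.
    await asyncio.sleep(0.01)


class AsyncWriteQueue:
    def __init__(self, max_size: int = 10000, num_writers: int = 4) -> None:
        self._queue: asyncio.Queue[Record] = asyncio.Queue(maxsize=max_size)
        # create_task needs a running event loop, so construct this inside async code.
        self._workers = [asyncio.create_task(self._worker()) for _ in range(num_writers)]

    async def submit(self, record: Record) -> None:
        try:
            self._queue.put_nowait(record)  # fast path: never blocks the caller
        except asyncio.QueueFull:
            await self._queue.put(record)  # backpressure: wait for a free slot

    async def _worker(self) -> None:
        while True:
            record = await self._queue.get()
            try:
                await write_to_db(record)
            except Exception as exc:  # keep the worker alive on write errors
                print(f"write failed: {exc}")
            finally:
                self._queue.task_done()

    async def flush(self) -> None:
        await self._queue.join()  # returns once every queued record is written

    async def shutdown(self) -> None:
        await self.flush()
        for task in self._workers:
            task.cancel()
        await asyncio.gather(*self._workers, return_exceptions=True)


async def main() -> None:
    queue = AsyncWriteQueue(max_size=100, num_writers=2)
    for i in range(10):
        await queue.submit(Record(payload=f"completion-{i}"))
    await queue.flush()
    await queue.shutdown()


asyncio.run(main())

On the design choice: as the patch's own comment notes, the queue is disabled for SQLite, which allows only a single writer at a time, so concurrent background writers would mostly contend on the database lock; for server-backed SQL stores the queue lets the request handler return without waiting on the insert.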