diff --git a/docs/_static/css/my_theme.css b/docs/_static/css/my_theme.css
index d078ec057..7dcd97c9b 100644
--- a/docs/_static/css/my_theme.css
+++ b/docs/_static/css/my_theme.css
@@ -1,5 +1,106 @@
@import url("theme.css");
+/* Horizontal Navigation Bar */
+.horizontal-nav {
+ background-color: #ffffff;
+ border-bottom: 1px solid #e5e5e5;
+ padding: 0;
+ position: fixed;
+ top: 0;
+ left: 0;
+ right: 0;
+ z-index: 1050;
+ height: 50px;
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+[data-theme="dark"] .horizontal-nav {
+ background-color: #1a1a1a;
+ border-bottom: 1px solid #333;
+}
+
+.horizontal-nav .nav-container {
+ max-width: 1200px;
+ margin: 0 auto;
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+ padding: 0 20px;
+ height: 100%;
+}
+
+.horizontal-nav .nav-brand {
+ font-size: 18px;
+ font-weight: 600;
+ color: #333;
+ text-decoration: none;
+}
+
+[data-theme="dark"] .horizontal-nav .nav-brand {
+ color: #fff;
+}
+
+.horizontal-nav .nav-links {
+ display: flex;
+ align-items: center;
+ gap: 30px;
+ list-style: none;
+ margin: 0;
+ padding: 0;
+}
+
+.horizontal-nav .nav-links a {
+ color: #666;
+ text-decoration: none;
+ font-size: 14px;
+ font-weight: 500;
+ padding: 8px 12px;
+ border-radius: 6px;
+ transition: all 0.2s ease;
+}
+
+.horizontal-nav .nav-links a:hover,
+.horizontal-nav .nav-links a.active {
+ color: #333;
+ background-color: #f5f5f5;
+}
+
+.horizontal-nav .nav-links a.active {
+ font-weight: 600;
+}
+
+[data-theme="dark"] .horizontal-nav .nav-links a {
+ color: #ccc;
+}
+
+[data-theme="dark"] .horizontal-nav .nav-links a:hover,
+[data-theme="dark"] .horizontal-nav .nav-links a.active {
+ color: #fff;
+ background-color: #333;
+}
+
+.horizontal-nav .nav-links .github-link {
+ display: flex;
+ align-items: center;
+ gap: 6px;
+}
+
+.horizontal-nav .nav-links .github-icon {
+ width: 16px;
+ height: 16px;
+ fill: currentColor;
+}
+
+/* Adjust main content to account for fixed nav */
+.wy-nav-side {
+ top: 50px;
+ height: calc(100vh - 50px);
+}
+
+.wy-nav-content-wrap {
+ margin-top: 50px;
+}
+
.wy-nav-content {
max-width: 90%;
}
diff --git a/docs/_static/js/horizontal_nav.js b/docs/_static/js/horizontal_nav.js
new file mode 100644
index 000000000..c2384f9d5
--- /dev/null
+++ b/docs/_static/js/horizontal_nav.js
@@ -0,0 +1,44 @@
+// Horizontal Navigation Bar for Llama Stack Documentation
+document.addEventListener('DOMContentLoaded', function() {
+ // Create the horizontal navigation HTML
+ const navHTML = `
+
+ `;
+
+ // Insert the navigation at the beginning of the body
+ document.body.insertAdjacentHTML('afterbegin', navHTML);
+
+ // Update navigation links based on current page
+ updateActiveNav();
+});
+
+function updateActiveNav() {
+ const currentPath = window.location.pathname;
+ const navLinks = document.querySelectorAll('.horizontal-nav .nav-links a');
+
+ navLinks.forEach(link => {
+ // Remove any existing active classes
+ link.classList.remove('active');
+
+ // Add active class based on current path
+ if (currentPath === '/' && link.getAttribute('href') === '/') {
+ link.classList.add('active');
+ } else if (currentPath.includes('/references/api_reference/') && link.getAttribute('href').includes('api_reference')) {
+ link.classList.add('active');
+ }
+ });
+}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3f84d1310..0cbddef31 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -131,6 +131,7 @@ html_static_path = ["../_static"]
def setup(app):
app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js")
+ app.add_js_file("js/horizontal_nav.js")
app.add_js_file("js/keyboard_shortcuts.js")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py
index 3d0d18150..83ba9602a 100644
--- a/docs/source/distributions/k8s-benchmark/benchmark.py
+++ b/docs/source/distributions/k8s-benchmark/benchmark.py
@@ -58,14 +58,6 @@ class BenchmarkStats:
print(f"\n{'='*60}")
print(f"BENCHMARK RESULTS")
- print(f"{'='*60}")
- print(f"Total time: {total_time:.2f}s")
- print(f"Concurrent users: {self.concurrent_users}")
- print(f"Total requests: {self.total_requests}")
- print(f"Successful requests: {self.success_count}")
- print(f"Failed requests: {len(self.errors)}")
- print(f"Success rate: {success_rate:.1f}%")
- print(f"Requests per second: {self.success_count / total_time:.2f}")
print(f"\nResponse Time Statistics:")
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
@@ -106,6 +98,15 @@ class BenchmarkStats:
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
print(f" Total chunks received: {sum(self.chunks_received)}")
+ print(f"{'='*60}")
+ print(f"Total time: {total_time:.2f}s")
+ print(f"Concurrent users: {self.concurrent_users}")
+ print(f"Total requests: {self.total_requests}")
+ print(f"Successful requests: {self.success_count}")
+ print(f"Failed requests: {len(self.errors)}")
+ print(f"Success rate: {success_rate:.1f}%")
+ print(f"Requests per second: {self.success_count / total_time:.2f}")
+
if self.errors:
print(f"\nErrors (showing first 5):")
for error in self.errors[:5]:
@@ -215,7 +216,7 @@ class LlamaStackBenchmark:
await asyncio.sleep(1) # Report every second
if time.time() >= last_report_time + 10: # Report every 10 seconds
elapsed = time.time() - stats.start_time
- print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
+ print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}")
last_report_time = time.time()
except asyncio.CancelledError:
break
diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
index f8ff7811b..5a9e2ae4f 100644
--- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
+++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
@@ -2,6 +2,7 @@ version: '2'
image_name: kubernetes-benchmark-demo
apis:
- agents
+- files
- inference
- safety
@@ -20,6 +21,14 @@ providers:
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
+ files:
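+  # Local filesystem-backed files provider; file metadata is tracked in SQLite.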
+ - provider_id: meta-reference-files
+ provider_type: inline::localfs
+ config:
+ storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+ metadata_store:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py
index 0f348b067..faaeefd01 100644
--- a/llama_stack/core/datatypes.py
+++ b/llama_stack/core/datatypes.py
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
)
+class InferenceStoreConfig(BaseModel):
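+    """Configuration for the inference store: the backing SQL store plus tuning
+    knobs for the asynchronous write queue."""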
+ sql_store_config: SqlStoreConfig
+ max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+ num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
@@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""",
)
- inference_store: SqlStoreConfig | None = Field(
+ inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
default=None,
description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
+Configuration for the persistence store used by the inference API. Can be either an
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py
index 1faace34a..f129f8ede 100644
--- a/llama_stack/core/routers/__init__.py
+++ b/llama_stack/core/routers/__init__.py
@@ -78,7 +78,10 @@ async def get_auto_router_impl(
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
- inference_store = InferenceStore(run_config.inference_store, policy)
+ inference_store = InferenceStore(
+ config=run_config.inference_store,
+ policy=policy,
+ )
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index 045093fe0..762d7073e 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import get_current_span
+from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
logger = get_logger(name=__name__, category="core::routers")
@@ -90,6 +90,11 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown")
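+        # Flush the inference store's write queue and stop its background workers.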
+ if self.store:
+ try:
+ await self.store.shutdown()
+ except Exception as e:
+ logger.warning(f"Error during InferenceStore shutdown: {e}")
async def register_model(
self,
@@ -160,7 +165,7 @@ class InferenceRouter(Inference):
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
- await self.telemetry.log_event(metric)
+ enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
@@ -431,7 +436,7 @@ class InferenceRouter(Inference):
model=model_obj,
)
for metric in metrics:
- await self.telemetry.log_event(metric)
+ enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
@@ -537,7 +542,7 @@ class InferenceRouter(Inference):
model=model_obj,
)
for metric in metrics:
- await self.telemetry.log_event(metric)
+ enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@@ -664,7 +669,7 @@ class InferenceRouter(Inference):
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
- await self.telemetry.log_event(metric)
+ enqueue_event(metric)
# Return metrics in response
async_metrics = [
@@ -710,7 +715,7 @@ class InferenceRouter(Inference):
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
- await self.telemetry.log_event(metric)
+ enqueue_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@@ -806,7 +811,7 @@ class InferenceRouter(Inference):
model=model,
)
for metric in metrics:
- await self.telemetry.log_event(metric)
+ enqueue_event(metric)
yield chunk
finally:
diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py
index 43006cfd5..8c69b1683 100644
--- a/llama_stack/providers/utils/inference/inference_store.py
+++ b/llama_stack/providers/utils/inference/inference_store.py
@@ -3,6 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import asyncio
+from typing import Any
+
from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
OpenAIChatCompletion,
@@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
)
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
+from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
+from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
+
+logger = get_logger(name=__name__, category="inference_store")
class InferenceStore:
- def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
- if not sql_store_config:
- sql_store_config = SqliteSqlStoreConfig(
- db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
+ def __init__(
+ self,
+ config: InferenceStoreConfig | SqlStoreConfig,
+ policy: list[AccessRule],
+ ):
+ # Handle backward compatibility
+ if not isinstance(config, InferenceStoreConfig):
+ # Legacy: SqlStoreConfig passed directly as config
+ config = InferenceStoreConfig(
+ sql_store_config=config,
)
- self.sql_store_config = sql_store_config
+
+ self.config = config
+ self.sql_store_config = config.sql_store_config
self.sql_store = None
self.policy = policy
+ # Disable write queue for SQLite to avoid concurrency issues
+ self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
+
+ # Async write queue and worker control
+ self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+ self._worker_tasks: list[asyncio.Task[Any]] = []
+ self._max_write_queue_size: int = config.max_write_queue_size
+ self._num_writers: int = max(1, config.num_writers)
+
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@@ -42,10 +64,68 @@ class InferenceStore:
},
)
+ if self.enable_write_queue:
+ self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+ for _ in range(self._num_writers):
+ self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
+ else:
+ logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+
+ async def shutdown(self) -> None:
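+        """Drain queued writes, then stop the background worker tasks."""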
+ if not self._worker_tasks:
+ return
+ if self._queue is not None:
+ await self._queue.join()
+ for t in self._worker_tasks:
+ if not t.done():
+ t.cancel()
+ for t in self._worker_tasks:
+ try:
+ await t
+ except asyncio.CancelledError:
+ pass
+ self._worker_tasks.clear()
+
+ async def flush(self) -> None:
+ """Wait for all queued writes to complete. Useful for testing."""
+ if self.enable_write_queue and self._queue is not None:
+ await self._queue.join()
+
async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
- if not self.sql_store:
+ if self.enable_write_queue:
+ if self._queue is None:
+ raise ValueError("Inference store is not initialized")
+ try:
+ self._queue.put_nowait((chat_completion, input_messages))
+ except asyncio.QueueFull:
+ logger.warning(
+ f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '')}"
+ )
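+                # Queue is full: fall back to a blocking put so the completion is not dropped.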
+ await self._queue.put((chat_completion, input_messages))
+ else:
+ await self._write_chat_completion(chat_completion, input_messages)
+
+ async def _worker_loop(self) -> None:
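+        """Consume queued chat completions and persist them until cancelled."""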
+ assert self._queue is not None
+ while True:
+ try:
+ item = await self._queue.get()
+ except asyncio.CancelledError:
+ break
+ chat_completion, input_messages = item
+ try:
+ await self._write_chat_completion(chat_completion, input_messages)
+ except Exception as e: # noqa: BLE001
+ logger.error(f"Error writing chat completion: {e}")
+ finally:
+ self._queue.task_done()
+
+ async def _write_chat_completion(
+ self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+ ) -> None:
+ if self.sql_store is None:
raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump()
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index 7694003b5..9969b1055 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -18,6 +18,7 @@ from functools import wraps
from typing import Any
from llama_stack.apis.telemetry import (
+ Event,
LogSeverity,
Span,
SpanEndPayload,
@@ -98,7 +99,7 @@ class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
- self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
+ self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._last_queue_full_log_time: float = 0.0
self._dropped_since_last_notice: int = 0
@@ -118,12 +119,16 @@ class BackgroundLogger:
self._last_queue_full_log_time = current_time
self._dropped_since_last_notice = 0
- def _process_logs(self):
+ def _worker(self):
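+        # Run a single long-lived event loop in this daemon thread so telemetry
+        # events are awaited here instead of spinning up a new loop per event.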
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.run_until_complete(self._process_logs())
+
+ async def _process_logs(self):
while True:
try:
event = self.log_queue.get()
- # figure out how to use a thread's native loop
- asyncio.run(self.api.log_event(event))
+ await self.api.log_event(event)
except Exception:
import traceback
@@ -136,6 +141,19 @@ class BackgroundLogger:
self.log_queue.join()
+def enqueue_event(event: Event) -> None:
+ """Enqueue a telemetry event to the background logger if available.
+
+ This provides a non-blocking path for routers and other hot paths to
+ submit telemetry without awaiting the Telemetry API, reducing contention
+ with the main event loop.
+ """
+ global BACKGROUND_LOGGER
+ if BACKGROUND_LOGGER is None:
+ raise RuntimeError("Telemetry API not initialized")
+ BACKGROUND_LOGGER.log_event(event)
+
+
class TraceContext:
spans: list[Span] = []
@@ -256,11 +274,7 @@ class TelemetryHandler(logging.Handler):
if record.module in ("asyncio", "selector_events"):
return
- global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
-
- if BACKGROUND_LOGGER is None:
- raise RuntimeError("Telemetry API not initialized")
-
+ global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
@@ -269,7 +283,7 @@ class TelemetryHandler(logging.Handler):
if span is None:
return
- BACKGROUND_LOGGER.log_event(
+ enqueue_event(
UnstructuredLogEvent(
trace_id=span.trace_id,
span_id=span.span_id,
diff --git a/scripts/github/schedule-record-workflow.sh b/scripts/github/schedule-record-workflow.sh
index c292e53e6..44b0947b6 100755
--- a/scripts/github/schedule-record-workflow.sh
+++ b/scripts/github/schedule-record-workflow.sh
@@ -239,8 +239,9 @@ echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
echo ""
# Prepare inputs for gh workflow run
+INPUTS=
if [[ -n "$TEST_SUBDIRS" ]]; then
- INPUTS="-f subdirs='$TEST_SUBDIRS'"
+ INPUTS="$INPUTS -f subdirs='$TEST_SUBDIRS'"
fi
if [[ -n "$TEST_SETUP" ]]; then
INPUTS="$INPUTS -f test-setup='$TEST_SETUP'"
diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py
index 730f54a05..f6d63490a 100644
--- a/tests/unit/utils/inference/test_inference_store.py
+++ b/tests/unit/utils/inference/test_inference_store.py
@@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
+ # Wait for all queued writes to complete
+ await store.flush()
+
# Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2
@@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
+ # Wait for all queued writes to complete
+ await store.flush()
+
# Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1
@@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
+ # Wait for all queued writes to complete
+ await store.flush()
+
# Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1
@@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages)
+ # Wait for all queued writes to complete
+ await store.flush()
+
# Test without limit
result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2