diff --git a/docs/_static/css/my_theme.css b/docs/_static/css/my_theme.css
index d078ec057..7dcd97c9b 100644
--- a/docs/_static/css/my_theme.css
+++ b/docs/_static/css/my_theme.css
@@ -1,5 +1,106 @@
 @import url("theme.css");
 
+/* Horizontal Navigation Bar */
+.horizontal-nav {
+    background-color: #ffffff;
+    border-bottom: 1px solid #e5e5e5;
+    padding: 0;
+    position: fixed;
+    top: 0;
+    left: 0;
+    right: 0;
+    z-index: 1050;
+    height: 50px;
+    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+
+[data-theme="dark"] .horizontal-nav {
+    background-color: #1a1a1a;
+    border-bottom: 1px solid #333;
+}
+
+.horizontal-nav .nav-container {
+    max-width: 1200px;
+    margin: 0 auto;
+    display: flex;
+    align-items: center;
+    justify-content: space-between;
+    padding: 0 20px;
+    height: 100%;
+}
+
+.horizontal-nav .nav-brand {
+    font-size: 18px;
+    font-weight: 600;
+    color: #333;
+    text-decoration: none;
+}
+
+[data-theme="dark"] .horizontal-nav .nav-brand {
+    color: #fff;
+}
+
+.horizontal-nav .nav-links {
+    display: flex;
+    align-items: center;
+    gap: 30px;
+    list-style: none;
+    margin: 0;
+    padding: 0;
+}
+
+.horizontal-nav .nav-links a {
+    color: #666;
+    text-decoration: none;
+    font-size: 14px;
+    font-weight: 500;
+    padding: 8px 12px;
+    border-radius: 6px;
+    transition: all 0.2s ease;
+}
+
+.horizontal-nav .nav-links a:hover,
+.horizontal-nav .nav-links a.active {
+    color: #333;
+    background-color: #f5f5f5;
+}
+
+.horizontal-nav .nav-links a.active {
+    font-weight: 600;
+}
+
+[data-theme="dark"] .horizontal-nav .nav-links a {
+    color: #ccc;
+}
+
+[data-theme="dark"] .horizontal-nav .nav-links a:hover,
+[data-theme="dark"] .horizontal-nav .nav-links a.active {
+    color: #fff;
+    background-color: #333;
+}
+
+.horizontal-nav .nav-links .github-link {
+    display: flex;
+    align-items: center;
+    gap: 6px;
+}
+
+.horizontal-nav .nav-links .github-icon {
+    width: 16px;
+    height: 16px;
+    fill: currentColor;
+}
+
+/* Adjust main content to account for fixed nav */
+.wy-nav-side {
+    top: 50px;
+    height: calc(100vh - 50px);
+}
+
+.wy-nav-content-wrap {
+    margin-top: 50px;
+}
+
 .wy-nav-content {
     max-width: 90%;
 }
diff --git a/docs/_static/js/horizontal_nav.js b/docs/_static/js/horizontal_nav.js
new file mode 100644
index 000000000..c2384f9d5
--- /dev/null
+++ b/docs/_static/js/horizontal_nav.js
@@ -0,0 +1,44 @@
+// Horizontal Navigation Bar for Llama Stack Documentation
+document.addEventListener('DOMContentLoaded', function() {
+    // Create the horizontal navigation HTML
+    const navHTML = `
+
+    `;
+
+    // Insert the navigation at the beginning of the body
+    document.body.insertAdjacentHTML('afterbegin', navHTML);
+
+    // Update navigation links based on current page
+    updateActiveNav();
+});
+
+function updateActiveNav() {
+    const currentPath = window.location.pathname;
+    const navLinks = document.querySelectorAll('.horizontal-nav .nav-links a');
+
+    navLinks.forEach(link => {
+        // Remove any existing active classes
+        link.classList.remove('active');
+
+        // Add active class based on current path
+        if (currentPath === '/' && link.getAttribute('href') === '/') {
+            link.classList.add('active');
+        } else if (currentPath.includes('/references/api_reference/') && link.getAttribute('href').includes('api_reference')) {
+            link.classList.add('active');
+        }
+    });
+}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3f84d1310..0cbddef31 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -131,6 +131,7 @@ html_static_path = ["../_static"]
 def setup(app):
     app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js") + app.add_js_file("js/horizontal_nav.js") app.add_js_file("js/keyboard_shortcuts.js") def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]): diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py index 3d0d18150..83ba9602a 100644 --- a/docs/source/distributions/k8s-benchmark/benchmark.py +++ b/docs/source/distributions/k8s-benchmark/benchmark.py @@ -58,14 +58,6 @@ class BenchmarkStats: print(f"\n{'='*60}") print(f"BENCHMARK RESULTS") - print(f"{'='*60}") - print(f"Total time: {total_time:.2f}s") - print(f"Concurrent users: {self.concurrent_users}") - print(f"Total requests: {self.total_requests}") - print(f"Successful requests: {self.success_count}") - print(f"Failed requests: {len(self.errors)}") - print(f"Success rate: {success_rate:.1f}%") - print(f"Requests per second: {self.success_count / total_time:.2f}") print(f"\nResponse Time Statistics:") print(f" Mean: {statistics.mean(self.response_times):.3f}s") @@ -106,6 +98,15 @@ class BenchmarkStats: print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}") print(f" Total chunks received: {sum(self.chunks_received)}") + print(f"{'='*60}") + print(f"Total time: {total_time:.2f}s") + print(f"Concurrent users: {self.concurrent_users}") + print(f"Total requests: {self.total_requests}") + print(f"Successful requests: {self.success_count}") + print(f"Failed requests: {len(self.errors)}") + print(f"Success rate: {success_rate:.1f}%") + print(f"Requests per second: {self.success_count / total_time:.2f}") + if self.errors: print(f"\nErrors (showing first 5):") for error in self.errors[:5]: @@ -215,7 +216,7 @@ class LlamaStackBenchmark: await asyncio.sleep(1) # Report every second if time.time() >= last_report_time + 10: # Report every 10 seconds elapsed = time.time() - stats.start_time - print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s") + print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}") last_report_time = time.time() except asyncio.CancelledError: break diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml index f8ff7811b..5a9e2ae4f 100644 --- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -2,6 +2,7 @@ version: '2' image_name: kubernetes-benchmark-demo apis: - agents +- files - inference - files - safety @@ -20,6 +21,14 @@ providers: - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db vector_io: - provider_id: ${env.ENABLE_CHROMADB:+chromadb} provider_type: remote::chromadb diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index 0f348b067..faaeefd01 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -431,6 +431,12 @@ class ServerConfig(BaseModel): ) +class InferenceStoreConfig(BaseModel): + sql_store_config: SqlStoreConfig + max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store") + num_writers: int = Field(default=4, 
description="Number of concurrent background writers") + + class StackRunConfig(BaseModel): version: int = LLAMA_STACK_RUN_CONFIG_VERSION @@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no a default SQLite store will be used.""", ) - inference_store: SqlStoreConfig | None = Field( + inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field( default=None, description=""" -Configuration for the persistence store used by the inference API. If not specified, -a default SQLite store will be used.""", +Configuration for the persistence store used by the inference API. Can be either a +InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated). +If not specified, a default SQLite store will be used.""", ) # registry of "resources" in the distribution diff --git a/llama_stack/core/routers/__init__.py b/llama_stack/core/routers/__init__.py index 1faace34a..f129f8ede 100644 --- a/llama_stack/core/routers/__init__.py +++ b/llama_stack/core/routers/__init__.py @@ -78,7 +78,10 @@ async def get_auto_router_impl( # TODO: move pass configs to routers instead if api == Api.inference and run_config.inference_store: - inference_store = InferenceStore(run_config.inference_store, policy) + inference_store = InferenceStore( + config=run_config.inference_store, + policy=policy, + ) await inference_store.initialize() api_to_dep_impl["store"] = inference_store diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 045093fe0..762d7073e 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.telemetry.tracing import get_current_span +from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span logger = get_logger(name=__name__, category="core::routers") @@ -90,6 +90,11 @@ class InferenceRouter(Inference): async def shutdown(self) -> None: logger.debug("InferenceRouter.shutdown") + if self.store: + try: + await self.store.shutdown() + except Exception as e: + logger.warning(f"Error during InferenceStore shutdown: {e}") async def register_model( self, @@ -160,7 +165,7 @@ class InferenceRouter(Inference): metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) if self.telemetry: for metric in metrics: - await self.telemetry.log_event(metric) + enqueue_event(metric) return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] async def _count_tokens( @@ -431,7 +436,7 @@ class InferenceRouter(Inference): model=model_obj, ) for metric in metrics: - await self.telemetry.log_event(metric) + enqueue_event(metric) # these metrics will show up in the client response. response.metrics = ( @@ -537,7 +542,7 @@ class InferenceRouter(Inference): model=model_obj, ) for metric in metrics: - await self.telemetry.log_event(metric) + enqueue_event(metric) # these metrics will show up in the client response. 
             response.metrics = (
                 metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@@ -664,7 +669,7 @@ class InferenceRouter(Inference):
                 "completion_tokens",
                 "total_tokens",
             ]:  # Only log completion and total tokens
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
 
         # Return metrics in response
         async_metrics = [
@@ -710,7 +715,7 @@ class InferenceRouter(Inference):
         )
         for metric in completion_metrics:
             if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
 
         # Return metrics in response
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@@ -806,7 +811,7 @@ class InferenceRouter(Inference):
                         model=model,
                     )
                     for metric in metrics:
-                        await self.telemetry.log_event(metric)
+                        enqueue_event(metric)
 
                 yield chunk
             finally:
diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py
index 43006cfd5..8c69b1683 100644
--- a/llama_stack/providers/utils/inference/inference_store.py
+++ b/llama_stack/providers/utils/inference/inference_store.py
@@ -3,6 +3,9 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import asyncio
+from typing import Any
+
 from llama_stack.apis.inference import (
     ListOpenAIChatCompletionResponse,
     OpenAIChatCompletion,
@@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     Order,
 )
-from llama_stack.core.datatypes import AccessRule
-from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
+from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
+from llama_stack.log import get_logger
 
 from ..sqlstore.api import ColumnDefinition, ColumnType
 from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
+from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
+
+logger = get_logger(name=__name__, category="inference_store")
 
 
 class InferenceStore:
-    def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
-        if not sql_store_config:
-            sql_store_config = SqliteSqlStoreConfig(
-                db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
+    def __init__(
+        self,
+        config: InferenceStoreConfig | SqlStoreConfig,
+        policy: list[AccessRule],
+    ):
+        # Handle backward compatibility
+        if not isinstance(config, InferenceStoreConfig):
+            # Legacy: SqlStoreConfig passed directly as config
+            config = InferenceStoreConfig(
+                sql_store_config=config,
             )
-        self.sql_store_config = sql_store_config
+
+        self.config = config
+        self.sql_store_config = config.sql_store_config
         self.sql_store = None
         self.policy = policy
+        # Disable write queue for SQLite to avoid concurrency issues
+        self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
+
+        # Async write queue and worker control
+        self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
+        self._worker_tasks: list[asyncio.Task[Any]] = []
+        self._max_write_queue_size: int = config.max_write_queue_size
+        self._num_writers: int = max(1, config.num_writers)
+
     async def initialize(self):
         """Create the necessary tables if they don't exist."""
         self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@@ -42,10 +64,68 @@ class InferenceStore:
             },
         )
 
+        if self.enable_write_queue:
+            self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
+            for _ in range(self._num_writers):
+                self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
+        else:
+            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+
+    async def shutdown(self) -> None:
+        if not self._worker_tasks:
+            return
+        if self._queue is not None:
+            await self._queue.join()
+        for t in self._worker_tasks:
+            if not t.done():
+                t.cancel()
+        for t in self._worker_tasks:
+            try:
+                await t
+            except asyncio.CancelledError:
+                pass
+        self._worker_tasks.clear()
+
+    async def flush(self) -> None:
+        """Wait for all queued writes to complete. Useful for testing."""
+        if self.enable_write_queue and self._queue is not None:
+            await self._queue.join()
+
     async def store_chat_completion(
         self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
     ) -> None:
-        if not self.sql_store:
+        if self.enable_write_queue:
+            if self._queue is None:
+                raise ValueError("Inference store is not initialized")
+            try:
+                self._queue.put_nowait((chat_completion, input_messages))
+            except asyncio.QueueFull:
+                logger.warning(
+                    f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '')}"
+                )
+                await self._queue.put((chat_completion, input_messages))
+        else:
+            await self._write_chat_completion(chat_completion, input_messages)
+
+    async def _worker_loop(self) -> None:
+        assert self._queue is not None
+        while True:
+            try:
+                item = await self._queue.get()
+            except asyncio.CancelledError:
+                break
+            chat_completion, input_messages = item
+            try:
+                await self._write_chat_completion(chat_completion, input_messages)
+            except Exception as e:  # noqa: BLE001
+                logger.error(f"Error writing chat completion: {e}")
+            finally:
+                self._queue.task_done()
+
+    async def _write_chat_completion(
+        self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
+    ) -> None:
+        if self.sql_store is None:
             raise ValueError("Inference store is not initialized")
 
         data = chat_completion.model_dump()
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index 7694003b5..9969b1055 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -18,6 +18,7 @@ from functools import wraps
 from typing import Any
 
 from llama_stack.apis.telemetry import (
+    Event,
     LogSeverity,
     Span,
     SpanEndPayload,
@@ -98,7 +99,7 @@ class BackgroundLogger:
     def __init__(self, api: Telemetry, capacity: int = 100000):
         self.api = api
         self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
-        self.worker_thread = threading.Thread(target=self._process_logs, daemon=True)
+        self.worker_thread = threading.Thread(target=self._worker, daemon=True)
         self.worker_thread.start()
         self._last_queue_full_log_time: float = 0.0
         self._dropped_since_last_notice: int = 0
@@ -118,12 +119,16 @@ class BackgroundLogger:
             self._last_queue_full_log_time = current_time
             self._dropped_since_last_notice = 0
 
-    def _process_logs(self):
+    def _worker(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(self._process_logs())
+
+    async def _process_logs(self):
         while True:
             try:
                 event = self.log_queue.get()
-                # figure out how to use a thread's native loop
-                asyncio.run(self.api.log_event(event))
+                await self.api.log_event(event)
             except Exception:
                 import traceback
 
@@ -136,6 +141,19 @@ class BackgroundLogger:
         self.log_queue.join()
 
 
+def enqueue_event(event: Event) -> None:
+    """Enqueue a telemetry event to the background logger if available.
+
+    This provides a non-blocking path for routers and other hot paths to
+    submit telemetry without awaiting the Telemetry API, reducing contention
+    with the main event loop.
+    """
+    global BACKGROUND_LOGGER
+    if BACKGROUND_LOGGER is None:
+        raise RuntimeError("Telemetry API not initialized")
+    BACKGROUND_LOGGER.log_event(event)
+
+
 class TraceContext:
     spans: list[Span] = []
 
@@ -256,11 +274,7 @@ class TelemetryHandler(logging.Handler):
         if record.module in ("asyncio", "selector_events"):
             return
 
-        global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
-
-        if BACKGROUND_LOGGER is None:
-            raise RuntimeError("Telemetry API not initialized")
-
+        global CURRENT_TRACE_CONTEXT
         context = CURRENT_TRACE_CONTEXT.get()
         if context is None:
             return
@@ -269,7 +283,7 @@ class TelemetryHandler(logging.Handler):
         if span is None:
             return
 
-        BACKGROUND_LOGGER.log_event(
+        enqueue_event(
             UnstructuredLogEvent(
                 trace_id=span.trace_id,
                 span_id=span.span_id,
diff --git a/scripts/github/schedule-record-workflow.sh b/scripts/github/schedule-record-workflow.sh
index c292e53e6..44b0947b6 100755
--- a/scripts/github/schedule-record-workflow.sh
+++ b/scripts/github/schedule-record-workflow.sh
@@ -239,8 +239,9 @@ echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
 echo ""
 
 # Prepare inputs for gh workflow run
+INPUTS=
 if [[ -n "$TEST_SUBDIRS" ]]; then
-  INPUTS="-f subdirs='$TEST_SUBDIRS'"
+  INPUTS="$INPUTS -f subdirs='$TEST_SUBDIRS'"
 fi
 if [[ -n "$TEST_SETUP" ]]; then
   INPUTS="$INPUTS -f test-setup='$TEST_SETUP'"
diff --git a/tests/unit/utils/inference/test_inference_store.py b/tests/unit/utils/inference/test_inference_store.py
index 730f54a05..f6d63490a 100644
--- a/tests/unit/utils/inference/test_inference_store.py
+++ b/tests/unit/utils/inference/test_inference_store.py
@@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test 1: First page with limit=2, descending order (default)
     result = await store.list_chat_completions(limit=2, order=Order.desc)
     assert len(result.data) == 2
@@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test ascending order pagination
     result = await store.list_chat_completions(limit=1, order=Order.asc)
     assert len(result.data) == 1
@@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test pagination with model filter
     result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
     assert len(result.data) == 1
@@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
         input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
         await store.store_chat_completion(completion, input_messages)
 
+    # Wait for all queued writes to complete
+    await store.flush()
+
     # Test without limit
     result = await store.list_chat_completions(order=Order.desc)
     assert len(result.data) == 2
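
Usage note: with the new InferenceStoreConfig from llama_stack/core/datatypes.py, the run config's inference_store entry can carry queue tuning parameters instead of a bare SqlStoreConfig. A minimal sketch, assuming the SQLite sql_store_config shape used elsewhere in this patch; the db_path value is illustrative only:

inference_store:
  sql_store_config:
    type: sqlite
    db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
  max_write_queue_size: 10000
  num_writers: 4

Per the InferenceStore changes above, the write queue is disabled when the store type is sqlite, so max_write_queue_size and num_writers only take effect with a non-SQLite SQL store backend.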