Merge branch 'main' into ragtool-migration

This commit is contained in:
Francisco Arceo 2025-09-10 13:58:57 -06:00 committed by GitHub
commit 9b27f81b75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 318 additions and 40 deletions

View file

@ -1,5 +1,106 @@
@import url("theme.css"); @import url("theme.css");
/* Horizontal Navigation Bar */
.horizontal-nav {
background-color: #ffffff;
border-bottom: 1px solid #e5e5e5;
padding: 0;
position: fixed;
top: 0;
left: 0;
right: 0;
z-index: 1050;
height: 50px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
[data-theme="dark"] .horizontal-nav {
background-color: #1a1a1a;
border-bottom: 1px solid #333;
}
.horizontal-nav .nav-container {
max-width: 1200px;
margin: 0 auto;
display: flex;
align-items: center;
justify-content: space-between;
padding: 0 20px;
height: 100%;
}
.horizontal-nav .nav-brand {
font-size: 18px;
font-weight: 600;
color: #333;
text-decoration: none;
}
[data-theme="dark"] .horizontal-nav .nav-brand {
color: #fff;
}
.horizontal-nav .nav-links {
display: flex;
align-items: center;
gap: 30px;
list-style: none;
margin: 0;
padding: 0;
}
.horizontal-nav .nav-links a {
color: #666;
text-decoration: none;
font-size: 14px;
font-weight: 500;
padding: 8px 12px;
border-radius: 6px;
transition: all 0.2s ease;
}
.horizontal-nav .nav-links a:hover,
.horizontal-nav .nav-links a.active {
color: #333;
background-color: #f5f5f5;
}
.horizontal-nav .nav-links a.active {
font-weight: 600;
}
[data-theme="dark"] .horizontal-nav .nav-links a {
color: #ccc;
}
[data-theme="dark"] .horizontal-nav .nav-links a:hover,
[data-theme="dark"] .horizontal-nav .nav-links a.active {
color: #fff;
background-color: #333;
}
.horizontal-nav .nav-links .github-link {
display: flex;
align-items: center;
gap: 6px;
}
.horizontal-nav .nav-links .github-icon {
width: 16px;
height: 16px;
fill: currentColor;
}
/* Adjust main content to account for fixed nav */
.wy-nav-side {
top: 50px;
height: calc(100vh - 50px);
}
.wy-nav-content-wrap {
margin-top: 50px;
}
.wy-nav-content { .wy-nav-content {
max-width: 90%; max-width: 90%;
} }

44
docs/_static/js/horizontal_nav.js vendored Normal file
View file

@ -0,0 +1,44 @@
// Horizontal Navigation Bar for Llama Stack Documentation
document.addEventListener('DOMContentLoaded', function() {
// Create the horizontal navigation HTML
const navHTML = `
<nav class="horizontal-nav">
<div class="nav-container">
<a href="/" class="nav-brand">Llama Stack</a>
<ul class="nav-links">
<li><a href="/">Docs</a></li>
<li><a href="/references/api_reference/">API Reference</a></li>
<li><a href="https://github.com/meta-llama/llama-stack" target="_blank" class="github-link">
<svg class="github-icon" viewBox="0 0 16 16" aria-hidden="true">
<path d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0016 8c0-4.42-3.58-8-8-8z"/>
</svg>
GitHub
</a></li>
</ul>
</div>
</nav>
`;
// Insert the navigation at the beginning of the body
document.body.insertAdjacentHTML('afterbegin', navHTML);
// Update navigation links based on current page
updateActiveNav();
});
function updateActiveNav() {
const currentPath = window.location.pathname;
const navLinks = document.querySelectorAll('.horizontal-nav .nav-links a');
navLinks.forEach(link => {
// Remove any existing active classes
link.classList.remove('active');
// Add active class based on current path
if (currentPath === '/' && link.getAttribute('href') === '/') {
link.classList.add('active');
} else if (currentPath.includes('/references/api_reference/') && link.getAttribute('href').includes('api_reference')) {
link.classList.add('active');
}
});
}

View file

@ -131,6 +131,7 @@ html_static_path = ["../_static"]
def setup(app): def setup(app):
app.add_css_file("css/my_theme.css") app.add_css_file("css/my_theme.css")
app.add_js_file("js/detect_theme.js") app.add_js_file("js/detect_theme.js")
app.add_js_file("js/horizontal_nav.js")
app.add_js_file("js/keyboard_shortcuts.js") app.add_js_file("js/keyboard_shortcuts.js")
def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]): def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):

View file

@ -58,14 +58,6 @@ class BenchmarkStats:
print(f"\n{'='*60}") print(f"\n{'='*60}")
print(f"BENCHMARK RESULTS") print(f"BENCHMARK RESULTS")
print(f"{'='*60}")
print(f"Total time: {total_time:.2f}s")
print(f"Concurrent users: {self.concurrent_users}")
print(f"Total requests: {self.total_requests}")
print(f"Successful requests: {self.success_count}")
print(f"Failed requests: {len(self.errors)}")
print(f"Success rate: {success_rate:.1f}%")
print(f"Requests per second: {self.success_count / total_time:.2f}")
print(f"\nResponse Time Statistics:") print(f"\nResponse Time Statistics:")
print(f" Mean: {statistics.mean(self.response_times):.3f}s") print(f" Mean: {statistics.mean(self.response_times):.3f}s")
@ -106,6 +98,15 @@ class BenchmarkStats:
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}") print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
print(f" Total chunks received: {sum(self.chunks_received)}") print(f" Total chunks received: {sum(self.chunks_received)}")
print(f"{'='*60}")
print(f"Total time: {total_time:.2f}s")
print(f"Concurrent users: {self.concurrent_users}")
print(f"Total requests: {self.total_requests}")
print(f"Successful requests: {self.success_count}")
print(f"Failed requests: {len(self.errors)}")
print(f"Success rate: {success_rate:.1f}%")
print(f"Requests per second: {self.success_count / total_time:.2f}")
if self.errors: if self.errors:
print(f"\nErrors (showing first 5):") print(f"\nErrors (showing first 5):")
for error in self.errors[:5]: for error in self.errors[:5]:
@ -215,7 +216,7 @@ class LlamaStackBenchmark:
await asyncio.sleep(1) # Report every second await asyncio.sleep(1) # Report every second
if time.time() >= last_report_time + 10: # Report every 10 seconds if time.time() >= last_report_time + 10: # Report every 10 seconds
elapsed = time.time() - stats.start_time elapsed = time.time() - stats.start_time
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s") print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s, RPS: {stats.total_requests / elapsed:.1f}")
last_report_time = time.time() last_report_time = time.time()
except asyncio.CancelledError: except asyncio.CancelledError:
break break

View file

@ -2,6 +2,7 @@ version: '2'
image_name: kubernetes-benchmark-demo image_name: kubernetes-benchmark-demo
apis: apis:
- agents - agents
- files
- inference - inference
- files - files
- safety - safety
@ -20,6 +21,14 @@ providers:
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
vector_io: vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb} - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb

View file

@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
) )
class InferenceStoreConfig(BaseModel):
sql_store_config: SqlStoreConfig
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION version: int = LLAMA_STACK_RUN_CONFIG_VERSION
@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""", a default SQLite store will be used.""",
) )
inference_store: SqlStoreConfig | None = Field( inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
default=None, default=None,
description=""" description="""
Configuration for the persistence store used by the inference API. If not specified, Configuration for the persistence store used by the inference API. Can be either a
a default SQLite store will be used.""", InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
If not specified, a default SQLite store will be used.""",
) )
# registry of "resources" in the distribution # registry of "resources" in the distribution

View file

@ -78,7 +78,10 @@ async def get_auto_router_impl(
# TODO: move pass configs to routers instead # TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store: if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store, policy) inference_store = InferenceStore(
config=run_config.inference_store,
policy=policy,
)
await inference_store.initialize() await inference_store.initialize()
api_to_dep_impl["store"] = inference_store api_to_dep_impl["store"] = inference_store

View file

@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
logger = get_logger(name=__name__, category="core::routers") logger = get_logger(name=__name__, category="core::routers")
@ -90,6 +90,11 @@ class InferenceRouter(Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
logger.debug("InferenceRouter.shutdown") logger.debug("InferenceRouter.shutdown")
if self.store:
try:
await self.store.shutdown()
except Exception as e:
logger.warning(f"Error during InferenceStore shutdown: {e}")
async def register_model( async def register_model(
self, self,
@ -160,7 +165,7 @@ class InferenceRouter(Inference):
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry: if self.telemetry:
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens( async def _count_tokens(
@ -431,7 +436,7 @@ class InferenceRouter(Inference):
model=model_obj, model=model_obj,
) )
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
# these metrics will show up in the client response. # these metrics will show up in the client response.
response.metrics = ( response.metrics = (
@ -537,7 +542,7 @@ class InferenceRouter(Inference):
model=model_obj, model=model_obj,
) )
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
# these metrics will show up in the client response. # these metrics will show up in the client response.
response.metrics = ( response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@ -664,7 +669,7 @@ class InferenceRouter(Inference):
"completion_tokens", "completion_tokens",
"total_tokens", "total_tokens",
]: # Only log completion and total tokens ]: # Only log completion and total tokens
await self.telemetry.log_event(metric) enqueue_event(metric)
# Return metrics in response # Return metrics in response
async_metrics = [ async_metrics = [
@ -710,7 +715,7 @@ class InferenceRouter(Inference):
) )
for metric in completion_metrics: for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
await self.telemetry.log_event(metric) enqueue_event(metric)
# Return metrics in response # Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@ -806,7 +811,7 @@ class InferenceRouter(Inference):
model=model, model=model,
) )
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) enqueue_event(metric)
yield chunk yield chunk
finally: finally:

View file

@ -3,6 +3,9 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
from typing import Any
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse, ListOpenAIChatCompletionResponse,
OpenAIChatCompletion, OpenAIChatCompletion,
@ -10,24 +13,43 @@ from llama_stack.apis.inference import (
OpenAIMessageParam, OpenAIMessageParam,
Order, Order,
) )
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule, InferenceStoreConfig
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR from llama_stack.log import get_logger
from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.api import ColumnDefinition, ColumnType
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl from ..sqlstore.sqlstore import SqlStoreConfig, SqlStoreType, sqlstore_impl
logger = get_logger(name=__name__, category="inference_store")
class InferenceStore: class InferenceStore:
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]): def __init__(
if not sql_store_config: self,
sql_store_config = SqliteSqlStoreConfig( config: InferenceStoreConfig | SqlStoreConfig,
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), policy: list[AccessRule],
):
# Handle backward compatibility
if not isinstance(config, InferenceStoreConfig):
# Legacy: SqlStoreConfig passed directly as config
config = InferenceStoreConfig(
sql_store_config=config,
) )
self.sql_store_config = sql_store_config
self.config = config
self.sql_store_config = config.sql_store_config
self.sql_store = None self.sql_store = None
self.policy = policy self.policy = policy
# Disable write queue for SQLite to avoid concurrency issues
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
self._worker_tasks: list[asyncio.Task[Any]] = []
self._max_write_queue_size: int = config.max_write_queue_size
self._num_writers: int = max(1, config.num_writers)
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
@ -42,10 +64,68 @@ class InferenceStore:
}, },
) )
if self.enable_write_queue:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
async def shutdown(self) -> None:
if not self._worker_tasks:
return
if self._queue is not None:
await self._queue.join()
for t in self._worker_tasks:
if not t.done():
t.cancel()
for t in self._worker_tasks:
try:
await t
except asyncio.CancelledError:
pass
self._worker_tasks.clear()
async def flush(self) -> None:
"""Wait for all queued writes to complete. Useful for testing."""
if self.enable_write_queue and self._queue is not None:
await self._queue.join()
async def store_chat_completion( async def store_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam] self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None: ) -> None:
if not self.sql_store: if self.enable_write_queue:
if self._queue is None:
raise ValueError("Inference store is not initialized")
try:
self._queue.put_nowait((chat_completion, input_messages))
except asyncio.QueueFull:
logger.warning(
f"Write queue full; adding chat completion id={getattr(chat_completion, 'id', '<unknown>')}"
)
await self._queue.put((chat_completion, input_messages))
else:
await self._write_chat_completion(chat_completion, input_messages)
async def _worker_loop(self) -> None:
assert self._queue is not None
while True:
try:
item = await self._queue.get()
except asyncio.CancelledError:
break
chat_completion, input_messages = item
try:
await self._write_chat_completion(chat_completion, input_messages)
except Exception as e: # noqa: BLE001
logger.error(f"Error writing chat completion: {e}")
finally:
self._queue.task_done()
async def _write_chat_completion(
self, chat_completion: OpenAIChatCompletion, input_messages: list[OpenAIMessageParam]
) -> None:
if self.sql_store is None:
raise ValueError("Inference store is not initialized") raise ValueError("Inference store is not initialized")
data = chat_completion.model_dump() data = chat_completion.model_dump()

View file

@ -18,6 +18,7 @@ from functools import wraps
from typing import Any from typing import Any
from llama_stack.apis.telemetry import ( from llama_stack.apis.telemetry import (
Event,
LogSeverity, LogSeverity,
Span, Span,
SpanEndPayload, SpanEndPayload,
@ -98,7 +99,7 @@ class BackgroundLogger:
def __init__(self, api: Telemetry, capacity: int = 100000): def __init__(self, api: Telemetry, capacity: int = 100000):
self.api = api self.api = api
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity) self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
self.worker_thread = threading.Thread(target=self._process_logs, daemon=True) self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start() self.worker_thread.start()
self._last_queue_full_log_time: float = 0.0 self._last_queue_full_log_time: float = 0.0
self._dropped_since_last_notice: int = 0 self._dropped_since_last_notice: int = 0
@ -118,12 +119,16 @@ class BackgroundLogger:
self._last_queue_full_log_time = current_time self._last_queue_full_log_time = current_time
self._dropped_since_last_notice = 0 self._dropped_since_last_notice = 0
def _process_logs(self): def _worker(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_logs())
async def _process_logs(self):
while True: while True:
try: try:
event = self.log_queue.get() event = self.log_queue.get()
# figure out how to use a thread's native loop await self.api.log_event(event)
asyncio.run(self.api.log_event(event))
except Exception: except Exception:
import traceback import traceback
@ -136,6 +141,19 @@ class BackgroundLogger:
self.log_queue.join() self.log_queue.join()
def enqueue_event(event: Event) -> None:
"""Enqueue a telemetry event to the background logger if available.
This provides a non-blocking path for routers and other hot paths to
submit telemetry without awaiting the Telemetry API, reducing contention
with the main event loop.
"""
global BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
BACKGROUND_LOGGER.log_event(event)
class TraceContext: class TraceContext:
spans: list[Span] = [] spans: list[Span] = []
@ -256,11 +274,7 @@ class TelemetryHandler(logging.Handler):
if record.module in ("asyncio", "selector_events"): if record.module in ("asyncio", "selector_events"):
return return
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER global CURRENT_TRACE_CONTEXT
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT.get() context = CURRENT_TRACE_CONTEXT.get()
if context is None: if context is None:
return return
@ -269,7 +283,7 @@ class TelemetryHandler(logging.Handler):
if span is None: if span is None:
return return
BACKGROUND_LOGGER.log_event( enqueue_event(
UnstructuredLogEvent( UnstructuredLogEvent(
trace_id=span.trace_id, trace_id=span.trace_id,
span_id=span.span_id, span_id=span.span_id,

View file

@ -239,8 +239,9 @@ echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
echo "" echo ""
# Prepare inputs for gh workflow run # Prepare inputs for gh workflow run
INPUTS=
if [[ -n "$TEST_SUBDIRS" ]]; then if [[ -n "$TEST_SUBDIRS" ]]; then
INPUTS="-f subdirs='$TEST_SUBDIRS'" INPUTS="$INPUTS -f subdirs='$TEST_SUBDIRS'"
fi fi
if [[ -n "$TEST_SETUP" ]]; then if [[ -n "$TEST_SETUP" ]]; then
INPUTS="$INPUTS -f test-setup='$TEST_SETUP'" INPUTS="$INPUTS -f test-setup='$TEST_SETUP'"

View file

@ -65,6 +65,9 @@ async def test_inference_store_pagination_basic():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages) await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test 1: First page with limit=2, descending order (default) # Test 1: First page with limit=2, descending order (default)
result = await store.list_chat_completions(limit=2, order=Order.desc) result = await store.list_chat_completions(limit=2, order=Order.desc)
assert len(result.data) == 2 assert len(result.data) == 2
@ -108,6 +111,9 @@ async def test_inference_store_pagination_ascending():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages) await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test ascending order pagination # Test ascending order pagination
result = await store.list_chat_completions(limit=1, order=Order.asc) result = await store.list_chat_completions(limit=1, order=Order.asc)
assert len(result.data) == 1 assert len(result.data) == 1
@ -143,6 +149,9 @@ async def test_inference_store_pagination_with_model_filter():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages) await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test pagination with model filter # Test pagination with model filter
result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc) result = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
assert len(result.data) == 1 assert len(result.data) == 1
@ -190,6 +199,9 @@ async def test_inference_store_pagination_no_limit():
input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")] input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
await store.store_chat_completion(completion, input_messages) await store.store_chat_completion(completion, input_messages)
# Wait for all queued writes to complete
await store.flush()
# Test without limit # Test without limit
result = await store.list_chat_completions(order=Order.desc) result = await store.list_chat_completions(order=Order.desc)
assert len(result.data) == 2 assert len(result.data) == 2