llama-stack-mirror/tests/unit/utils/inference/test_inference_store.py
Eric Huang e721ca9730 chore: introduce write queue for inference_store
# What does this PR do?
Adds a write-worker queue for writes to the inference store, so that slow inference-store writes no longer hold up request processing.

## Test Plan

Benchmark:
```
cd /docs/source/distributions/k8s-benchmark
# start mock server
python openai-mock-server.py --port 8000
# start stack server
uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml
# run benchmark script
uv run python3 benchmark.py --duration 120 --concurrent 50 --base-url=http://localhost:8321/v1/openai/v1 --model=vllm-inference/meta-llama/Llama-3.2-3B-Instruct
```


Before:

============================================================
BENCHMARK RESULTS

Response Time Statistics:
  Mean: 1.111s
  Median: 0.982s
  Min: 0.466s
  Max: 15.190s
  Std Dev: 1.091s

Percentiles:
  P50: 0.982s
  P90: 1.281s
  P95: 1.439s
  P99: 5.476s

Time to First Token (TTFT) Statistics:
  Mean: 0.474s
  Median: 0.347s
  Min: 0.175s
  Max: 15.129s
  Std Dev: 0.819s

TTFT Percentiles:
  P50: 0.347s
  P90: 0.661s
  P95: 0.762s
  P99: 2.788s

Streaming Statistics:
  Mean chunks per response: 67.2
  Total chunks received: 122154
============================================================
Total time: 120.00s
Concurrent users: 50
Total requests: 1919
Successful requests: 1819
Failed requests: 100
Success rate: 94.8%
Requests per second: 15.16

Errors (showing first 5):
  Request error:
  Request error:
  Request error:
  Request error:
  Request error:
Benchmark completed.
Stopping server (PID: 679)...
Server stopped.


After:

============================================================
BENCHMARK RESULTS

Response Time Statistics:
  Mean: 1.085s
  Median: 1.089s
  Min: 0.451s
  Max: 2.002s
  Std Dev: 0.212s

Percentiles:
  P50: 1.089s
  P90: 1.343s
  P95: 1.409s
  P99: 1.617s

Time to First Token (TTFT) Statistics:
  Mean: 0.407s
  Median: 0.361s
  Min: 0.182s
  Max: 1.178s
  Std Dev: 0.175s

TTFT Percentiles:
  P50: 0.361s
  P90: 0.644s
  P95: 0.744s
  P99: 0.932s

Streaming Statistics:
  Mean chunks per response: 66.8
  Total chunks received: 367240
============================================================
Total time: 120.00s
Concurrent users: 50
Total requests: 5495
Successful requests: 5495
Failed requests: 0
Success rate: 100.0%
Requests per second: 45.79
Benchmark completed.
Stopping server (PID: 97169)...
Server stopped.
2025-09-10 11:50:06 -07:00

210 lines
8.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from tempfile import TemporaryDirectory
import pytest
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIUserMessageParam,
Order,
)
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
def create_test_chat_completion(
    completion_id: str, created_timestamp: int, model: str = "test-model"
) -> OpenAIChatCompletion:
    """Build a minimal single-choice chat completion for the store tests.

    Args:
        completion_id: ID given to the completion. It is also embedded in the
            assistant reply text, so each stored record is distinguishable.
        created_timestamp: Value stored in the ``created`` field. The
            pagination tests give every record a distinct, increasing value
            and expect results ordered by it.
        model: Model name recorded on the completion.

    Returns:
        A ``chat.completion`` object with exactly one assistant choice (index 0)
        whose ``finish_reason`` is ``"stop"``.
    """
    assistant_reply = OpenAIAssistantMessageParam(
        role="assistant",
        content=f"Response for {completion_id}",
    )
    only_choice = OpenAIChoice(index=0, message=assistant_reply, finish_reason="stop")
    return OpenAIChatCompletion(
        id=completion_id,
        created=created_timestamp,
        model=model,
        object="chat.completion",
        choices=[only_choice],
    )
async def test_inference_store_pagination_basic():
    """Walk five stored completions two per page, newest first, until exhausted."""
    with TemporaryDirectory() as scratch_dir:
        store = InferenceStore(SqliteSqlStoreConfig(db_path=f"{scratch_dir}/test.db"), policy=[])
        await store.initialize()

        # Timestamps grow down this list, so "car-exec" is the newest record and
        # "zebra-task" the oldest. The IDs are not in alphabetical order, so a
        # correct result shows ordering by timestamp rather than by ID.
        now = int(time.time())
        ids_oldest_first = ["zebra-task", "apple-job", "moon-work", "banana-run", "car-exec"]
        for offset, completion_id in enumerate(ids_oldest_first, start=1):
            await store.store_chat_completion(
                create_test_chat_completion(completion_id, now + offset),
                [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")],
            )

        # Writes are queued; block until they have all been persisted before reading.
        await store.flush()

        # Page 1: the two newest records, with more remaining.
        page1 = await store.list_chat_completions(limit=2, order=Order.desc)
        assert [c.id for c in page1.data] == ["car-exec", "banana-run"]
        assert page1.has_more is True
        assert page1.last_id == "banana-run"

        # Page 2: resume from page 1's last ID.
        page2 = await store.list_chat_completions(after="banana-run", limit=2, order=Order.desc)
        assert [c.id for c in page2.data] == ["moon-work", "apple-job"]
        assert page2.has_more is True

        # Page 3: only the oldest record is left.
        page3 = await store.list_chat_completions(after="apple-job", limit=2, order=Order.desc)
        assert [c.id for c in page3.data] == ["zebra-task"]
        assert page3.has_more is False
async def test_inference_store_pagination_ascending():
    """Test pagination with ascending order, walking every page to the end.

    Three completions are stored with increasing timestamps and read back one
    per page, oldest first. The final page must contain the newest record and
    report ``has_more`` as False.
    """
    with TemporaryDirectory() as tmp_dir:
        db_path = tmp_dir + "/test.db"
        store = InferenceStore(SqliteSqlStoreConfig(db_path=db_path), policy=[])
        await store.initialize()

        # Timestamps increase down the list, so "delta-item" is the oldest record
        # and "alpha-work" the newest (the reverse of their alphabetical order).
        base_time = int(time.time())
        test_data = [
            ("delta-item", base_time + 1),
            ("charlie-task", base_time + 2),
            ("alpha-work", base_time + 3),
        ]

        for completion_id, timestamp in test_data:
            completion = create_test_chat_completion(completion_id, timestamp)
            input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
            await store.store_chat_completion(completion, input_messages)

        # Writes are queued; wait for all of them to complete before reading.
        await store.flush()

        # Page 1: the oldest record first; last_id is the cursor for the next page.
        result = await store.list_chat_completions(limit=1, order=Order.asc)
        assert len(result.data) == 1
        assert result.data[0].id == "delta-item"
        assert result.has_more is True
        assert result.last_id == "delta-item"

        # Page 2: resume after the previous page's last ID.
        result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
        assert len(result2.data) == 1
        assert result2.data[0].id == "charlie-task"
        assert result2.has_more is True

        # Page 3 (final): the newest record, with nothing left after it. This
        # exercises the ascending-order end boundary, which page 2 alone never reaches.
        result3 = await store.list_chat_completions(after="charlie-task", limit=1, order=Order.asc)
        assert len(result3.data) == 1
        assert result3.data[0].id == "alpha-work"
        assert result3.has_more is False
async def test_inference_store_pagination_with_model_filter():
    """Paginate with a ``model`` filter; records of other models must never appear."""
    with TemporaryDirectory() as tmp_dir:
        store = InferenceStore(SqliteSqlStoreConfig(db_path=f"{tmp_dir}/test.db"), policy=[])
        await store.initialize()

        # Records alternate between two models. Within "model-a", "xyz-task" is
        # the older record and "pqr-job" the newer one (timestamps grow down the list).
        start = int(time.time())
        fixtures = [
            ("xyz-task", "model-a"),
            ("def-work", "model-b"),
            ("pqr-job", "model-a"),
            ("abc-run", "model-b"),
        ]
        for position, (completion_id, model_name) in enumerate(fixtures, start=1):
            record = create_test_chat_completion(completion_id, start + position, model_name)
            prompt = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
            await store.store_chat_completion(record, prompt)

        # Writes are queued; wait until every one has been persisted.
        await store.flush()

        # Page 1: only the newest "model-a" record, with another still to come.
        first_page = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
        assert [(c.id, c.model) for c in first_page.data] == [("pqr-job", "model-a")]
        assert first_page.has_more is True

        # Page 2: the remaining "model-a" record ends the filtered listing.
        second_page = await store.list_chat_completions(
            after="pqr-job", limit=1, model="model-a", order=Order.desc
        )
        assert [(c.id, c.model) for c in second_page.data] == [("xyz-task", "model-a")]
        assert second_page.has_more is False
async def test_inference_store_pagination_invalid_after():
    """An unknown ``after`` cursor must raise ``ValueError`` naming the missing ID."""
    with TemporaryDirectory() as workdir:
        config = SqliteSqlStoreConfig(db_path=f"{workdir}/test.db")
        store = InferenceStore(config, policy=[])
        await store.initialize()

        # Nothing has been stored in this fresh database, so the cursor cannot resolve.
        # ``match`` is applied as a regex search against the exception message.
        expected_error = "Record with id='non-existent' not found in table 'chat_completions'"
        with pytest.raises(ValueError, match=expected_error):
            await store.list_chat_completions(after="non-existent", limit=2)
async def test_inference_store_pagination_no_limit():
    """Omitting ``limit`` returns every stored completion (here, two) in one page."""
    with TemporaryDirectory() as tmp_dir:
        store = InferenceStore(SqliteSqlStoreConfig(db_path=f"{tmp_dir}/test.db"), policy=[])
        await store.initialize()

        # "omega-first" gets the earlier timestamp, "beta-second" the later one.
        t0 = int(time.time())
        for step, completion_id in enumerate(("omega-first", "beta-second"), start=1):
            chat = create_test_chat_completion(completion_id, t0 + step)
            user_turn = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
            await store.store_chat_completion(chat, user_turn)

        # Writes are queued; wait for them to land before listing.
        await store.flush()

        everything = await store.list_chat_completions(order=Order.desc)
        assert [c.id for c in everything.data] == ["beta-second", "omega-first"]
        assert everything.has_more is False