llama-stack-mirror/tests/unit/utils/inference/test_inference_store.py
Charlie Doern a078f089d9
Some checks failed
Integration Tests (Replay) / generate-matrix (push) Successful in 3s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Test Llama Stack Build / generate-matrix (push) Successful in 5s
Python Package Build Test / build (3.12) (push) Failing after 4s
API Conformance Tests / check-schema-compatibility (push) Successful in 12s
Test llama stack list-deps / generate-matrix (push) Successful in 29s
Test Llama Stack Build / build-single-provider (push) Successful in 33s
Test llama stack list-deps / list-deps-from-config (push) Successful in 32s
UI Tests / ui-tests (22) (push) Successful in 39s
Test Llama Stack Build / build (push) Successful in 39s
Test llama stack list-deps / show-single-provider (push) Successful in 46s
Python Package Build Test / build (3.13) (push) Failing after 44s
Test External API and Providers / test-external (venv) (push) Failing after 44s
Vector IO Integration Tests / test-matrix (push) Failing after 56s
Test llama stack list-deps / list-deps (push) Failing after 47s
Unit Tests / unit-tests (3.12) (push) Failing after 1m42s
Unit Tests / unit-tests (3.13) (push) Failing after 1m55s
Test Llama Stack Build / build-ubi9-container-distribution (push) Successful in 2m0s
Test Llama Stack Build / build-custom-container-distribution (push) Successful in 2m2s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 2m42s
Pre-commit / pre-commit (push) Successful in 5m17s
fix: rename llama_stack_api dir (#4155)
# What does this PR do?

The directory structure was `src/llama-stack-api/llama_stack_api`.

Instead, it should just be `src/llama_stack_api`, to match the other
packages.

Update the directory structure and the pyproject/linting configuration accordingly.

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-11-13 15:04:36 -08:00

212 lines
7.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import pytest
from llama_stack.core.storage.datatypes import InferenceStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from llama_stack_api import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIUserMessageParam,
Order,
)
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
    """Register the ``sql_default`` SQL store backend against a per-test SQLite file.

    Because the fixture is ``autouse``, it runs for every test in this module.
    Each test therefore gets its own database file under pytest's ``tmp_path``
    directory.
    """
    sqlite_config = SqliteSqlStoreConfig(db_path=str(tmp_path / "test.db"))
    register_sqlstore_backends({"sql_default": sqlite_config})


def create_test_chat_completion(
    completion_id: str, created_timestamp: int, model: str = "test-model"
) -> OpenAIChatCompletion:
    """Build a minimal single-choice chat completion for the storage tests.

    Args:
        completion_id: Used as the completion's ``id``. It is also embedded in
            the assistant reply text, so each record's content is distinct.
        created_timestamp: Used as the ``created`` field. The pagination tests
            pass increasing values and expect listings ordered by them.
        model: Model name recorded on the completion; the model-filter test
            varies it.

    Returns:
        An ``OpenAIChatCompletion`` holding one assistant choice whose
        ``finish_reason`` is ``"stop"``.
    """
    assistant_reply = OpenAIAssistantMessageParam(
        role="assistant",
        content=f"Response for {completion_id}",
    )
    only_choice = OpenAIChoice(index=0, message=assistant_reply, finish_reason="stop")
    return OpenAIChatCompletion(
        id=completion_id,
        created=created_timestamp,
        model=model,
        object="chat.completion",
        choices=[only_choice],
    )


async def test_inference_store_pagination_basic():
    """Walk every page of a descending listing using the store's own cursors.

    Five completions with strictly increasing ``created`` timestamps are
    stored, then listed two at a time, newest first. Each follow-up page is
    requested with ``after=<previous page's last_id>``, the way a client would
    page. This checks both the page contents and the cursor the store returns.
    """
    reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
    store = InferenceStore(reference, policy=[])
    await store.initialize()

    # The IDs are not in alphabetical order. Ordering by id instead of by
    # creation time would therefore fail the assertions below.
    base_time = int(time.time())
    test_data = [
        ("zebra-task", base_time + 1),
        ("apple-job", base_time + 2),
        ("moon-work", base_time + 3),
        ("banana-run", base_time + 4),
        ("car-exec", base_time + 5),
    ]
    for completion_id, timestamp in test_data:
        completion = create_test_chat_completion(completion_id, timestamp)
        input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
        await store.store_chat_completion(completion, input_messages)

    # Wait for all queued writes to complete before reading back.
    await store.flush()

    # Page 1: the two most recent records.
    result = await store.list_chat_completions(limit=2, order=Order.desc)
    assert [c.id for c in result.data] == ["car-exec", "banana-run"]
    assert result.has_more is True
    assert result.last_id == "banana-run"

    # Page 2: continue from the cursor returned by page 1.
    result2 = await store.list_chat_completions(after=result.last_id, limit=2, order=Order.desc)
    assert [c.id for c in result2.data] == ["moon-work", "apple-job"]
    assert result2.has_more is True
    assert result2.last_id == "apple-job"

    # Page 3: only one record remains, so the page is short and has_more is False.
    result3 = await store.list_chat_completions(after=result2.last_id, limit=2, order=Order.desc)
    assert [c.id for c in result3.data] == ["zebra-task"]
    assert result3.has_more is False
    assert result3.last_id == "zebra-task"


async def test_inference_store_pagination_ascending():
    """Page through completions oldest-first, one record per page, to the end.

    Three completions are stored with increasing timestamps. Listing them with
    ``order=Order.asc`` and ``limit=1`` must yield them in creation order. The
    final page must report ``has_more`` as False.
    """
    reference = InferenceStoreReference(backend="sql_default", table_name="chat_completions")
    store = InferenceStore(reference, policy=[])
    await store.initialize()

    # IDs run reverse-alphabetically against creation time, so sorting by id
    # would not produce the expected order.
    base_time = int(time.time())
    test_data = [
        ("delta-item", base_time + 1),
        ("charlie-task", base_time + 2),
        ("alpha-work", base_time + 3),
    ]
    for completion_id, timestamp in test_data:
        completion = create_test_chat_completion(completion_id, timestamp)
        input_messages = [OpenAIUserMessageParam(role="user", content=f"Test message for {completion_id}")]
        await store.store_chat_completion(completion, input_messages)

    # Wait for all queued writes to complete before reading back.
    await store.flush()

    # Page 1: the oldest record.
    result = await store.list_chat_completions(limit=1, order=Order.asc)
    assert [c.id for c in result.data] == ["delta-item"]
    assert result.has_more is True

    # Page 2: the next-oldest record.
    result2 = await store.list_chat_completions(after="delta-item", limit=1, order=Order.asc)
    assert [c.id for c in result2.data] == ["charlie-task"]
    assert result2.has_more is True

    # Page 3: the last record. Exactly `limit` rows remain, so there must be
    # no further page.
    result3 = await store.list_chat_completions(after="charlie-task", limit=1, order=Order.asc)
    assert [c.id for c in result3.data] == ["alpha-work"]
    assert result3.has_more is False


async def test_inference_store_pagination_with_model_filter():
    """Pagination combined with a ``model`` filter only walks matching records.

    Completions for ``model-a`` and ``model-b`` are interleaved by timestamp.
    Paging through ``model-a`` one record at a time must skip every
    ``model-b`` entry. It must report no further pages after the last match.
    """
    store = InferenceStore(
        InferenceStoreReference(backend="sql_default", table_name="chat_completions"),
        policy=[],
    )
    await store.initialize()

    now = int(time.time())
    rows = [
        ("xyz-task", now + 1, "model-a"),
        ("def-work", now + 2, "model-b"),
        ("pqr-job", now + 3, "model-a"),
        ("abc-run", now + 4, "model-b"),
    ]
    for cid, created, model_name in rows:
        await store.store_chat_completion(
            create_test_chat_completion(cid, created, model_name),
            [OpenAIUserMessageParam(role="user", content=f"Test message for {cid}")],
        )
    # Drain any queued writes so the listings below see every row.
    await store.flush()

    # Newest model-a record first.
    first_page = await store.list_chat_completions(limit=1, model="model-a", order=Order.desc)
    assert len(first_page.data) == 1
    newest = first_page.data[0]
    assert (newest.id, newest.model) == ("pqr-job", "model-a")
    assert first_page.has_more is True

    # Continuing past it must skip the model-b row in between.
    second_page = await store.list_chat_completions(after="pqr-job", limit=1, model="model-a", order=Order.desc)
    assert len(second_page.data) == 1
    oldest = second_page.data[0]
    assert (oldest.id, oldest.model) == ("xyz-task", "model-a")
    assert second_page.has_more is False


async def test_inference_store_pagination_invalid_after():
    """An ``after`` cursor naming an unknown record is rejected with ValueError."""
    store = InferenceStore(
        InferenceStoreReference(backend="sql_default", table_name="chat_completions"),
        policy=[],
    )
    await store.initialize()

    # Nothing is stored in this test, so no record can match the cursor.
    expected_error = "Record with id='non-existent' not found in table 'chat_completions'"
    with pytest.raises(ValueError, match=expected_error):
        await store.list_chat_completions(after="non-existent", limit=2)


async def test_inference_store_pagination_no_limit():
    """Listing without ``limit`` returns both stored completions in one page.

    Only two records are stored. The test checks that omitting ``limit`` does
    not truncate this small result set and reports no further pages.
    """
    store = InferenceStore(
        InferenceStoreReference(backend="sql_default", table_name="chat_completions"),
        policy=[],
    )
    await store.initialize()

    now = int(time.time())
    for offset, cid in enumerate(("omega-first", "beta-second"), start=1):
        completion = create_test_chat_completion(cid, now + offset)
        prompt = [OpenAIUserMessageParam(role="user", content=f"Test message for {cid}")]
        await store.store_chat_completion(completion, prompt)
    # Block until the queued writes have been persisted.
    await store.flush()

    page = await store.list_chat_completions(order=Order.desc)
    # Newest first: "beta-second" was created one second after "omega-first".
    assert [c.id for c in page.data] == ["beta-second", "omega-first"]
    assert page.has_more is False