Merge branch 'main' into opengauss-add

This commit is contained in:
windy 2025-08-08 20:58:48 +08:00 committed by GitHub
commit 39e49ab97a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
807 changed files with 79555 additions and 26772 deletions

View file

@ -1,9 +1,17 @@
# Llama Stack Unit Tests
## Unit Tests
Unit tests verify individual components and functions in isolation. They are fast, reliable, and don't require external services.
### Prerequisites
1. **Python Environment**: Ensure you have Python 3.12+ installed
2. **uv Package Manager**: Install `uv` if not already installed
You can run the unit tests by running:
```bash
source .venv/bin/activate
./scripts/unit-tests.sh [PYTEST_ARGS]
```
@ -19,3 +27,21 @@ If you'd like to run for a non-default version of Python (currently 3.12), pass
source .venv/bin/activate
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
```
### Test Configuration
- **Test Discovery**: Tests are automatically discovered in the `tests/unit/` directory
- **Async Support**: Tests use `--asyncio-mode=auto` for automatic async test handling
- **Coverage**: Tests generate coverage reports in `htmlcov/` directory
- **Python Version**: Defaults to Python 3.12, but can be overridden with `PYTHON_VERSION` environment variable
### Coverage Reports
After running tests, you can view coverage reports:
```bash
# Open HTML coverage report in browser
open htmlcov/index.html # macOS
xdg-open htmlcov/index.html # Linux
start htmlcov/index.html # Windows
```

View file

@ -9,12 +9,43 @@ from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
from llama_stack.core.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,
)
@pytest.fixture
def config_with_image_name_int():
return yaml.safe_load(
f"""
version: {LLAMA_STACK_RUN_CONFIG_VERSION}
image_name: 1234
apis_to_serve: []
built_at: {datetime.now().isoformat()}
providers:
inference:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
safety:
- provider_id: provider1
provider_type: inline::meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
enable_prompt_guard: false
memory:
- provider_id: provider1
provider_type: inline::meta-reference
config: {{}}
"""
)
@pytest.fixture
def up_to_date_config():
return yaml.safe_load(
@ -125,3 +156,8 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
with pytest.raises(KeyError):
parse_and_maybe_upgrade_config(invalid_config)
def test_parse_and_maybe_upgrade_config_image_name_int(config_with_image_name_int):
result = parse_and_maybe_upgrade_config(config_with_image_name_int)
assert isinstance(result.image_name, str)

View file

@ -4,6 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest_socket
# We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401
def pytest_runtest_setup(item):
"""Setup for each test - check if network access should be allowed."""
if "allow_network" in item.keywords:
pytest_socket.enable_socket()
else:
# Allowing Unix sockets is necessary for some tests that use local servers and mocks
pytest_socket.disable_socket(allow_unix_socket=True)

View file

@ -16,14 +16,15 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.core.datatypes import RegistryEntrySource
from llama_stack.core.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.core.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.core.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.core.routing_tables.shields import ShieldsRoutingTable
from llama_stack.core.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl:
@ -47,6 +48,30 @@ class InferenceImpl(Impl):
async def unregister_model(self, model_id: str):
return model_id
async def should_refresh_models(self):
return False
async def list_models(self):
return [
Model(
identifier="provider-model-1",
provider_resource_id="provider-model-1",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
Model(
identifier="provider-model-2",
provider_resource_id="provider-model-2",
provider_id="test_provider",
metadata={"embedding_dimension": 512},
model_type=ModelType.embedding,
),
]
async def shutdown(self):
pass
class SafetyImpl(Impl):
def __init__(self):
@ -55,16 +80,8 @@ class SafetyImpl(Impl):
async def register_shield(self, shield: Shield):
return shield
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def unregister_shield(self, shield_id: str):
return shield_id
class DatasetsImpl(Impl):
@ -119,7 +136,17 @@ class ToolGroupsImpl(Impl):
)
@pytest.mark.asyncio
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -131,27 +158,27 @@ async def test_models_routing_table(cached_disk_dist_registry):
models = await table.list_models()
assert len(models.data) == 2
model_ids = {m.identifier for m in models.data}
assert "test-model" in model_ids
assert "test-model-2" in model_ids
assert "test_provider/test-model" in model_ids
assert "test_provider/test-model-2" in model_ids
# Test openai list models
openai_models = await table.openai_list_models()
assert len(openai_models.data) == 2
openai_model_ids = {m.id for m in openai_models.data}
assert "test-model" in openai_model_ids
assert "test-model-2" in openai_model_ids
assert "test_provider/test-model" in openai_model_ids
assert "test_provider/test-model-2" in openai_model_ids
# Test get_object_by_identifier
model = await table.get_object_by_identifier("model", "test-model")
model = await table.get_object_by_identifier("model", "test_provider/test-model")
assert model is not None
assert model.identifier == "test-model"
assert model.identifier == "test_provider/test-model"
# Test get_object_by_identifier on non-existent object
non_existent = await table.get_object_by_identifier("model", "non-existent-model")
assert non_existent is None
await table.unregister_model(model_id="test-model")
await table.unregister_model(model_id="test-model-2")
await table.unregister_model(model_id="test_provider/test-model")
await table.unregister_model(model_id="test_provider/test-model-2")
models = await table.list_models()
assert len(models.data) == 0
@ -161,7 +188,6 @@ async def test_models_routing_table(cached_disk_dist_registry):
assert len(openai_models.data) == 0
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -170,14 +196,43 @@ async def test_shields_routing_table(cached_disk_dist_registry):
await table.register_shield(shield_id="test-shield", provider_id="test_provider")
await table.register_shield(shield_id="test-shield-2", provider_id="test_provider")
shields = await table.list_shields()
assert len(shields.data) == 2
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" in shield_ids
assert "test-shield-2" in shield_ids
# Test get specific shield
test_shield = await table.get_shield(identifier="test-shield")
assert test_shield is not None
assert test_shield.identifier == "test-shield"
assert test_shield.provider_id == "test_provider"
assert test_shield.provider_resource_id == "test-shield"
assert test_shield.params == {}
# Test get non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.get_shield(identifier="non-existent")
# Test unregistering shields
await table.unregister_shield(identifier="test-shield")
shields = await table.list_shields()
assert len(shields.data) == 1
shield_ids = {s.identifier for s in shields.data}
assert "test-shield" not in shield_ids
assert "test-shield-2" in shield_ids
# Unregister the remaining shield
await table.unregister_shield(identifier="test-shield-2")
shields = await table.list_shields()
assert len(shields.data) == 0
# Test unregistering non-existent shield - should raise ValueError with specific message
with pytest.raises(ValueError, match="Shield 'non-existent' not found"):
await table.unregister_shield(identifier="non-existent")
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -192,8 +247,8 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
@ -233,7 +288,6 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
assert len(datasets.data) == 0
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -259,7 +313,6 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
assert "test-scoring-fn-2" in scoring_fn_ids
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -277,7 +330,6 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
assert "test-benchmark" in benchmark_ids
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -296,3 +348,260 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry):
await table.unregister_toolgroup(toolgroup_id="test-toolgroup")
tool_groups = await table.list_tool_groups()
assert len(tool_groups.data) == 0
async def test_models_alias_registration_and_lookup(cached_disk_dist_registry):
"""Test alias registration (model_id != provider_model_id) and lookup behavior."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model with alias (model_id different from provider_model_id)
await table.register_model(
model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider"
)
# Verify the model was registered with alias as identifier (not namespaced)
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "my-alias" # Uses alias as identifier
assert model.provider_resource_id == "actual-provider-model"
# Test lookup by alias works
retrieved_model = await table.get_model("my-alias")
assert retrieved_model.identifier == "my-alias"
assert retrieved_model.provider_resource_id == "actual-provider-model"
async def test_models_multi_provider_disambiguation(cached_disk_dist_registry):
"""Test registration and lookup with multiple providers having same provider_model_id."""
table = ModelsRoutingTable(
{"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {}
)
await table.initialize()
# Register same provider_model_id on both providers (no aliases)
await table.register_model(model_id="common-model", provider_id="provider1")
await table.register_model(model_id="common-model", provider_id="provider2")
# Verify both models get namespaced identifiers
models = await table.list_models()
assert len(models.data) == 2
identifiers = {m.identifier for m in models.data}
assert identifiers == {"provider1/common-model", "provider2/common-model"}
# Test lookup by full namespaced identifier works
model1 = await table.get_model("provider1/common-model")
assert model1.provider_id == "provider1"
assert model1.provider_resource_id == "common-model"
model2 = await table.get_model("provider2/common-model")
assert model2.provider_id == "provider2"
assert model2.provider_resource_id == "common-model"
# Test lookup by unscoped provider_model_id fails with multiple providers error
try:
await table.get_model("common-model")
raise AssertionError("Should have raised ValueError for multiple providers")
except ValueError as e:
assert "Multiple providers found" in str(e)
assert "provider1" in str(e) and "provider2" in str(e)
async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
"""Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model without alias (gets namespaced identifier)
await table.register_model(model_id="test-model", provider_id="test_provider")
# Verify namespaced identifier was created
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.identifier == "test_provider/test-model"
assert model.provider_resource_id == "test-model"
# Test lookup by full namespaced identifier (direct hit via get_object_by_identifier)
retrieved_model = await table.get_model("test_provider/test-model")
assert retrieved_model.identifier == "test_provider/test-model"
# Test lookup by unscoped provider_model_id (fallback via iteration)
retrieved_model = await table.get_model("test-model")
assert retrieved_model.identifier == "test_provider/test-model"
assert retrieved_model.provider_resource_id == "test-model"
# Test lookup of non-existent model fails
try:
await table.get_model("non-existent")
raise AssertionError("Should have raised ValueError for non-existent model")
except ValueError as e:
assert "not found" in str(e)
async def test_models_source_tracking_default(cached_disk_dist_registry):
"""Test that models registered via register_model get default source."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register model via register_model (should get default source)
await table.register_model(model_id="user-model", provider_id="test_provider")
models = await table.list_models()
assert len(models.data) == 1
model = models.data[0]
assert model.source == RegistryEntrySource.via_register_api
assert model.identifier == "test_provider/user-model"
# Cleanup
await table.shutdown()
async def test_models_source_tracking_provider(cached_disk_dist_registry):
"""Test that models registered via update_registered_models get provider source."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Simulate provider refresh by calling update_registered_models
provider_models = [
Model(
identifier="provider-model-1",
provider_resource_id="provider-model-1",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
Model(
identifier="provider-model-2",
provider_resource_id="provider-model-2",
provider_id="test_provider",
metadata={"embedding_dimension": 512},
model_type=ModelType.embedding,
),
]
await table.update_registered_models("test_provider", provider_models)
models = await table.list_models()
assert len(models.data) == 2
# All models should have provider source
for model in models.data:
assert model.source == RegistryEntrySource.listed_from_provider
assert model.provider_id == "test_provider"
# Cleanup
await table.shutdown()
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
"""Test that provider refresh preserves user-registered models with default source."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# First register a user model with same provider_resource_id as provider will later provide
await table.register_model(
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
)
# Verify user model is registered with default source
models = await table.list_models()
assert len(models.data) == 1
user_model = models.data[0]
assert user_model.source == RegistryEntrySource.via_register_api
assert user_model.identifier == "my-custom-alias"
assert user_model.provider_resource_id == "provider-model-1"
# Now simulate provider refresh
provider_models = [
Model(
identifier="provider-model-1",
provider_resource_id="provider-model-1",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
Model(
identifier="different-model",
provider_resource_id="different-model",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
]
await table.update_registered_models("test_provider", provider_models)
# Verify user model with alias is preserved, but provider added new model
models = await table.list_models()
assert len(models.data) == 2
# Find the user model and provider model
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
assert user_model is not None
assert user_model.source == RegistryEntrySource.via_register_api
assert user_model.provider_resource_id == "provider-model-1"
assert provider_model is not None
assert provider_model.source == RegistryEntrySource.listed_from_provider
assert provider_model.provider_resource_id == "different-model"
# Cleanup
await table.shutdown()
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
"""Test that provider refresh removes old provider models but keeps default ones."""
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register a user model
await table.register_model(model_id="user-model", provider_id="test_provider")
# Add some provider models
provider_models_v1 = [
Model(
identifier="provider-model-old",
provider_resource_id="provider-model-old",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
]
await table.update_registered_models("test_provider", provider_models_v1)
# Verify we have both user and provider models
models = await table.list_models()
assert len(models.data) == 2
# Now update with new provider models (should remove old provider models)
provider_models_v2 = [
Model(
identifier="provider-model-new",
provider_resource_id="provider-model-new",
provider_id="test_provider",
metadata={},
model_type=ModelType.llm,
),
]
await table.update_registered_models("test_provider", provider_models_v2)
# Should have user model + new provider model, old provider model gone
models = await table.list_models()
assert len(models.data) == 2
identifiers = {m.identifier for m in models.data}
assert "test_provider/user-model" in identifiers # User model preserved
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
# Verify sources are correct
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
assert user_model.source == RegistryEntrySource.via_register_api
assert provider_model.source == RegistryEntrySource.listed_from_provider
# Cleanup
await table.shutdown()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the routing tables vector_dbs
import time
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.access_control.datatypes import AccessRule, Scope
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.core.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)

View file

@ -9,15 +9,15 @@ from pathlib import Path
from llama_stack.cli.stack._build import (
_run_stack_build_command_from_build_config,
)
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.utils.image_types import LlamaStackImageType
def test_container_build_passes_path(monkeypatch, tmp_path):
called_with = {}
def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
called_with["path"] = template_or_config
def spy_build_image(build_config, image_name, distro_or_config, run_config=None):
called_with["path"] = distro_or_config
called_with["run_config"] = run_config
return 0

View file

@ -10,10 +10,9 @@ from contextvars import ContextVar
import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception():
# Create context variable
context_var = ContextVar("exception_var", default="initial")
@ -41,7 +40,6 @@ async def test_preserve_contexts_with_exception():
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator():
# Create context variable
context_var = ContextVar("empty_var", default="initial")
@ -66,7 +64,6 @@ async def test_preserve_contexts_empty_generator():
context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops():
"""
Test that context variables are preserved across event loop boundaries with nested generators.

View file

@ -11,8 +11,8 @@ import pytest
import yaml
from pydantic import BaseModel, Field, ValidationError
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.providers.datatypes import ProviderSpec
@ -106,6 +106,40 @@ def api_directories(tmp_path):
return remote_inference_dir, inline_inference_dir
def make_import_module_side_effect(
builtin_provider_spec=None,
external_module=None,
raise_for_external=False,
missing_get_provider_spec=False,
):
from types import SimpleNamespace
def import_module_side_effect(name):
if name == "llama_stack.providers.registry.inference":
mock_builtin = SimpleNamespace(
available_providers=lambda: [
builtin_provider_spec
or ProviderSpec(
api=Api.inference,
provider_type="test_provider",
config_class="test_provider.config.TestProviderConfig",
module="test_provider",
)
]
)
return mock_builtin
elif name == "external_test.provider":
if raise_for_external:
raise ModuleNotFoundError(name)
if missing_get_provider_spec:
return SimpleNamespace()
return external_module
else:
raise ModuleNotFoundError(name)
return import_module_side_effect
class TestProviderRegistry:
"""Test suite for provider registry functionality."""
@ -221,3 +255,122 @@ pip_packages:
with pytest.raises(KeyError) as exc_info:
get_provider_registry(base_config)
assert "config_class" in str(exc_info.value)
def test_external_provider_from_module_success(self, mock_providers):
"""Test loading an external provider from a module (success path)."""
from types import SimpleNamespace
from llama_stack.core.datatypes import Provider, StackRunConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
# Simulate a provider module with get_provider_spec
fake_spec = ProviderSpec(
api=Api.inference,
provider_type="external_test",
config_class="external_test.config.ExternalTestConfig",
module="external_test",
)
fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)
import_module_side_effect = make_import_module_side_effect(external_module=fake_module)
with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
)
registry = get_provider_registry(config)
assert Api.inference in registry
assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"]
assert provider.module == "external_test"
assert provider.config_class == "external_test.config.ExternalTestConfig"
mock_import.assert_any_call("llama_stack.providers.registry.inference")
mock_import.assert_any_call("external_test.provider")
def test_external_provider_from_module_not_found(self, mock_providers):
"""Test handling ModuleNotFoundError for missing provider module."""
from llama_stack.core.datatypes import Provider, StackRunConfig
import_module_side_effect = make_import_module_side_effect(raise_for_external=True)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
)
with pytest.raises(ValueError) as exc_info:
get_provider_registry(config)
assert "get_provider_spec not found" in str(exc_info.value)
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
"""Test handling missing get_provider_spec in provider module (should raise ValueError)."""
from llama_stack.core.datatypes import Provider, StackRunConfig
import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)
with patch("importlib.import_module", side_effect=import_module_side_effect):
config = StackRunConfig(
image_name="test_image",
providers={
"inference": [
Provider(
provider_id="external_test",
provider_type="external_test",
config={},
module="external_test",
)
]
},
)
with pytest.raises(AttributeError):
get_provider_registry(config)
def test_external_provider_from_module_building(self, mock_providers):
"""Test loading an external provider from a module during build (building=True, partial spec)."""
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.providers.datatypes import Api
# No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec
build_config = BuildConfig(
version=2,
image_type="container",
image_name="test_image",
distribution_spec=DistributionSpec(
description="test",
providers={
"inference": [
BuildProvider(
provider_type="external_test",
module="external_test",
)
]
},
),
)
registry = get_provider_registry(build_config)
assert Api.inference in registry
assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"]
assert provider.module == "external_test"
assert provider.is_external is True
# config_class is empty string in partial spec
assert provider.config_class == ""

View file

@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from openai import AsyncOpenAI
# Import the real Pydantic response types instead of using Mocks
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChoice,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
)
from llama_stack.testing.inference_recorder import (
InferenceMode,
ResponseStorage,
inference_recording,
normalize_request,
)
@pytest.fixture
def temp_storage_dir():
"""Create a temporary directory for test recordings."""
with tempfile.TemporaryDirectory() as temp_dir:
yield Path(temp_dir)
@pytest.fixture
def real_openai_chat_response():
"""Real OpenAI chat completion response using proper Pydantic objects."""
return OpenAIChatCompletion(
id="chatcmpl-test123",
choices=[
OpenAIChoice(
index=0,
message=OpenAIAssistantMessageParam(
role="assistant", content="Hello! I'm doing well, thank you for asking."
),
finish_reason="stop",
)
],
created=1234567890,
model="llama3.2:3b",
)
@pytest.fixture
def real_embeddings_response():
"""Real OpenAI embeddings response using proper Pydantic objects."""
return OpenAIEmbeddingsResponse(
object="list",
data=[
OpenAIEmbeddingData(object="embedding", embedding=[0.1, 0.2, 0.3], index=0),
OpenAIEmbeddingData(object="embedding", embedding=[0.4, 0.5, 0.6], index=1),
],
model="nomic-embed-text",
usage=OpenAIEmbeddingUsage(prompt_tokens=6, total_tokens=6),
)
class TestInferenceRecording:
"""Test the inference recording system."""
def test_request_normalization(self):
"""Test that request normalization produces consistent hashes."""
# Test basic normalization
hash1 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
# Same request should produce same hash
hash2 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello world"}], "temperature": 0.7},
)
assert hash1 == hash2
# Different content should produce different hash
hash3 = normalize_request(
"POST",
"http://localhost:11434/v1/chat/completions",
{},
{
"model": "llama3.2:3b",
"messages": [{"role": "user", "content": "Different message"}],
"temperature": 0.7,
},
)
assert hash1 != hash3
def test_request_normalization_edge_cases(self):
"""Test request normalization is precise about request content."""
# Test that different whitespace produces different hashes (no normalization)
hash1 = normalize_request(
"POST",
"http://test/v1/chat/completions",
{},
{"messages": [{"role": "user", "content": "Hello world\n\n"}]},
)
hash2 = normalize_request(
"POST", "http://test/v1/chat/completions", {}, {"messages": [{"role": "user", "content": "Hello world"}]}
)
assert hash1 != hash2 # Different whitespace should produce different hashes
# Test that different float precision produces different hashes (no rounding)
hash3 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7000001})
hash4 = normalize_request("POST", "http://test/v1/chat/completions", {}, {"temperature": 0.7})
assert hash3 != hash4 # Different precision should produce different hashes
def test_response_storage(self, temp_storage_dir):
"""Test the ResponseStorage class."""
temp_storage_dir = temp_storage_dir / "test_response_storage"
storage = ResponseStorage(temp_storage_dir)
# Test directory creation
assert storage.test_dir.exists()
assert storage.responses_dir.exists()
assert storage.db_path.exists()
# Test storing and retrieving a recording
request_hash = "test_hash_123"
request_data = {
"method": "POST",
"url": "http://localhost:11434/v1/chat/completions",
"endpoint": "/v1/chat/completions",
"model": "llama3.2:3b",
}
response_data = {"body": {"content": "test response"}, "is_streaming": False}
storage.store_recording(request_hash, request_data, response_data)
# Verify SQLite record
with sqlite3.connect(storage.db_path) as conn:
result = conn.execute("SELECT * FROM recordings WHERE request_hash = ?", (request_hash,)).fetchone()
assert result is not None
assert result[0] == request_hash # request_hash
assert result[2] == "/v1/chat/completions" # endpoint
assert result[3] == "llama3.2:3b" # model
# Verify file storage and retrieval
retrieved = storage.find_recording(request_hash)
assert retrieved is not None
assert retrieved["request"]["model"] == "llama3.2:3b"
assert retrieved["response"]["body"]["content"] == "test response"
async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that recording mode captures and stores responses."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_recording_mode"
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify the response was returned correctly
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify recording was stored
storage = ResponseStorage(temp_storage_dir)
with sqlite3.connect(storage.db_path) as conn:
recordings = conn.execute("SELECT COUNT(*) FROM recordings").fetchone()[0]
assert recordings == 1
async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
"""Test that replay mode returns stored responses without making real calls."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
temp_storage_dir = temp_storage_dir / "test_replay_mode"
# First, record a response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Now test replay mode - should not call the original method
with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=0.7,
max_tokens=50,
)
# Verify we got the recorded response
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
# Verify the original method was NOT called
mock_create_patch.assert_not_called()
async def test_replay_missing_recording(self, temp_storage_dir):
"""Test that replay mode fails when no recording is found."""
temp_storage_dir = temp_storage_dir / "test_replay_missing_recording"
with patch("openai.resources.chat.completions.AsyncCompletions.create"):
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
with pytest.raises(RuntimeError, match="No recorded response found"):
await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "This was never recorded"}]
)
async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
"""Test recording and replay of embeddings calls."""
async def mock_create(*args, **kwargs):
return real_embeddings_response
temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
# Record
with patch("openai.resources.embeddings.AsyncEmbeddings.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
assert len(response.data) == 2
# Replay
with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.embeddings.create(
model="nomic-embed-text", input=["Hello world", "Test embedding"]
)
# Verify we got the recorded response
assert len(response.data) == 2
assert response.data[0].embedding == [0.1, 0.2, 0.3]
# Verify original method was not called
mock_create_patch.assert_not_called()
async def test_live_mode(self, real_openai_chat_response):
"""Test that live mode passes through to original methods."""
async def mock_create(*args, **kwargs):
return real_openai_chat_response
with patch("openai.resources.chat.completions.AsyncCompletions.create", side_effect=mock_create):
with inference_recording(mode=InferenceMode.LIVE):
client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
response = await client.chat.completions.create(
model="llama3.2:3b", messages=[{"role": "user", "content": "Hello"}]
)
# Verify the response was returned
assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Unit tests for LlamaStackAsLibraryClient initialization error handling.
These tests ensure that users get proper error messages when they forget to call
initialize() on the library client, preventing AttributeError regressions.
"""
import pytest
from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)
class TestLlamaStackAsLibraryClientInitialization:
"""Test proper error handling for uninitialized library clients."""
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
lambda client: next(
client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
),
],
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
)
def test_sync_client_proper_error_without_initialization(self, api_call):
"""Test that sync client raises ValueError with helpful message when not initialized."""
client = LlamaStackAsLibraryClient("nvidia")
with pytest.raises(ValueError) as exc_info:
api_call(client)
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
],
ids=["models.list", "chat.completions.create"],
)
async def test_async_client_proper_error_without_initialization(self, api_call):
"""Test that async client raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
with pytest.raises(ValueError) as exc_info:
await api_call(client)
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
async def test_async_client_streaming_error_without_initialization(self):
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
with pytest.raises(ValueError) as exc_info:
stream = await client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
await anext(stream)
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def test_route_impls_initialized_to_none(self):
"""Test that route_impls is initialized to None to prevent AttributeError."""
# Test sync client
sync_client = LlamaStackAsLibraryClient("nvidia")
assert sync_client.async_client.route_impls is None
# Test async client directly
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
assert async_client.route_impls is None

View file

@ -6,10 +6,10 @@
import pytest
import pytest_asyncio
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.providers.inline.files.localfs import (
LocalfsFilesImpl,
LocalfsFilesImplConfig,
@ -29,7 +29,7 @@ class MockUploadFile:
return self.content
@pytest_asyncio.fixture
@pytest.fixture
async def files_provider(tmp_path):
"""Create a files provider with temporary storage for testing."""
storage_dir = tmp_path / "files"
@ -39,7 +39,7 @@ async def files_provider(tmp_path):
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
)
provider = LocalfsFilesImpl(config)
provider = LocalfsFilesImpl(config, default_policy())
await provider.initialize()
yield provider
@ -68,7 +68,6 @@ def large_file():
class TestOpenAIFilesAPI:
"""Test suite for OpenAI Files API endpoints."""
@pytest.mark.asyncio
async def test_upload_file_success(self, files_provider, sample_text_file):
"""Test successful file upload."""
# Upload file
@ -82,7 +81,6 @@ class TestOpenAIFilesAPI:
assert result.created_at > 0
assert result.expires_at > result.created_at
@pytest.mark.asyncio
async def test_upload_different_purposes(self, files_provider, sample_text_file):
"""Test uploading files with different purposes."""
purposes = list(OpenAIFilePurpose)
@ -93,7 +91,6 @@ class TestOpenAIFilesAPI:
uploaded_files.append(result)
assert result.purpose == purpose
@pytest.mark.asyncio
async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
"""Test uploading different types and sizes of files."""
files_to_test = [
@ -107,7 +104,6 @@ class TestOpenAIFilesAPI:
assert result.filename == expected_filename
assert result.bytes == len(file_obj.content)
@pytest.mark.asyncio
async def test_list_files_empty(self, files_provider):
"""Test listing files when no files exist."""
result = await files_provider.openai_list_files()
@ -117,7 +113,6 @@ class TestOpenAIFilesAPI:
assert result.first_id == ""
assert result.last_id == ""
@pytest.mark.asyncio
async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
"""Test listing files when files exist."""
# Upload multiple files
@ -132,7 +127,6 @@ class TestOpenAIFilesAPI:
assert file1.id in file_ids
assert file2.id in file_ids
@pytest.mark.asyncio
async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
"""Test listing files with purpose filtering."""
# Upload file with specific purpose
@ -146,7 +140,6 @@ class TestOpenAIFilesAPI:
assert result.data[0].id == uploaded_file.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
@pytest.mark.asyncio
async def test_list_files_with_limit(self, files_provider, sample_text_file):
"""Test listing files with limit parameter."""
# Upload multiple files
@ -157,7 +150,6 @@ class TestOpenAIFilesAPI:
result = await files_provider.openai_list_files(limit=3)
assert len(result.data) == 3
@pytest.mark.asyncio
async def test_list_files_with_order(self, files_provider, sample_text_file):
"""Test listing files with different order."""
# Upload multiple files
@ -178,7 +170,6 @@ class TestOpenAIFilesAPI:
# Oldest should be first
assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
@pytest.mark.asyncio
async def test_retrieve_file_success(self, files_provider, sample_text_file):
"""Test successful file retrieval."""
# Upload file
@ -197,13 +188,11 @@ class TestOpenAIFilesAPI:
assert retrieved_file.created_at == uploaded_file.created_at
assert retrieved_file.expires_at == uploaded_file.expires_at
@pytest.mark.asyncio
async def test_retrieve_file_not_found(self, files_provider):
"""Test retrieving a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file("file-nonexistent")
@pytest.mark.asyncio
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
"""Test successful file content retrieval."""
# Upload file
@ -217,13 +206,11 @@ class TestOpenAIFilesAPI:
# Verify content
assert content.body == sample_text_file.content
@pytest.mark.asyncio
async def test_retrieve_file_content_not_found(self, files_provider):
"""Test retrieving content of a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file_content("file-nonexistent")
@pytest.mark.asyncio
async def test_delete_file_success(self, files_provider, sample_text_file):
"""Test successful file deletion."""
# Upload file
@ -245,13 +232,11 @@ class TestOpenAIFilesAPI:
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
await files_provider.openai_retrieve_file(uploaded_file.id)
@pytest.mark.asyncio
async def test_delete_file_not_found(self, files_provider):
"""Test deleting a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_delete_file("file-nonexistent")
@pytest.mark.asyncio
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
"""Test that files persist correctly across multiple operations."""
# Upload file
@ -279,7 +264,6 @@ class TestOpenAIFilesAPI:
files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 0
@pytest.mark.asyncio
async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
"""Test operations with multiple files."""
# Upload multiple files
@ -302,7 +286,6 @@ class TestOpenAIFilesAPI:
content = await files_provider.openai_retrieve_file_content(file2.id)
assert content.body == sample_json_file.content
@pytest.mark.asyncio
async def test_file_id_uniqueness(self, files_provider, sample_text_file):
"""Test that each uploaded file gets a unique ID."""
file_ids = set()
@ -316,7 +299,6 @@ class TestOpenAIFilesAPI:
file_ids.add(uploaded_file.id)
assert uploaded_file.id.startswith("file-")
@pytest.mark.asyncio
async def test_file_no_filename_handling(self, files_provider):
"""Test handling files with no filename."""
file_without_name = MockUploadFile(b"content", None) # No filename
@ -327,7 +309,6 @@ class TestOpenAIFilesAPI:
assert uploaded_file.filename == "uploaded_file" # Default filename
@pytest.mark.asyncio
async def test_after_pagination_works(self, files_provider, sample_text_file):
"""Test that 'after' pagination works correctly."""
# Upload multiple files to test pagination

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest_asyncio
import pytest
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.core.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def sqlite_kvstore(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
yield kvstore
@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore):
registry = DiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()
yield registry
@pytest_asyncio.fixture(scope="function")
@pytest.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore):
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()

View file

@ -4,14 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import unittest
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
@ -25,264 +24,266 @@ from llama_stack.models.llama.datatypes import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
asyncio.get_running_loop().set_debug(False)
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_default(self):
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
async def test_system_builtin_only(self):
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only(self):
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_and_builtin(self):
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
self.assertTrue("Environment: ipython" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
async def test_completion_message_encoding(self):
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments={"param1": "value1"},
call_id="123",
)
],
),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('[custom1(param1="value1")]', prompt)
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn(
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
prompt,
)
async def test_user_provided_system_message(self):
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertEqual(messages[-1].content, content)
async def test_repalce_system_message_behavior_builtin_tools(self):
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
async def test_repalce_system_message_behavior_custom_tools(self):
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments={"param1": "value1"},
call_id="123",
)
],
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
async def test_replace_system_message_behavior_custom_tools_with_template(self):
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format="python_list",
system_message_behavior="replace",
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '[custom1(param1="value1")]' in prompt
self.assertEqual(len(messages), 2, messages)
self.assertIn("Environment: ipython", messages[0].content)
self.assertIn("You are a pirate", messages[0].content)
# function description is present in the system prompt
self.assertIn('"name": "custom1"', messages[0].content)
self.assertEqual(messages[-1].content, content)
request.model = MODEL
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert messages[-1].content == content
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
parameters={
"param1": ToolParamDefinition(
param_type="str",
description="param1 description",
required=True,
),
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content

View file

@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
import unittest
from datetime import datetime
from llama_stack.models.llama.llama3.prompt_templates import (
@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
)
class PromptTemplateTests(unittest.TestCase):
def check_generator_output(self, generator):
for example in generator.data_examples():
pt = generator.gen(example)
text = pt.render()
# print(text) # debugging
if not example:
continue
for tool in example:
assert tool.tool_name in text
def test_system_default(self):
generator = SystemDefaultGenerator()
today = datetime.now().strftime("%d %B %Y")
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_builtin_only(self):
generator = BuiltinToolGenerator()
expected_text = textwrap.dedent(
"""
Environment: ipython
Tools: brave_search, wolfram_alpha
"""
)
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_custom_only(self):
self.maxDiff = None
generator = JsonCustomToolGenerator()
self.check_generator_output(generator)
def test_system_custom_function_tag(self):
self.maxDiff = None
generator = FunctionTagCustomToolGenerator()
self.check_generator_output(generator)
def test_llama_3_2_system_zero_shot(self):
generator = PythonListCustomToolGenerator()
self.check_generator_output(generator)
def test_llama_3_2_provided_system_prompt(self):
generator = PythonListCustomToolGenerator()
user_system_prompt = textwrap.dedent(
"""
Overriding message.
{{ function_description }}
"""
)
example = generator.data_examples()[0]
pt = generator.gen(example, user_system_prompt)
def check_generator_output(generator):
for example in generator.data_examples():
pt = generator.gen(example)
text = pt.render()
assert "Overriding message." in text
assert '"name": "get_weather"' in text
if not example:
continue
for tool in example:
assert tool.tool_name in text
def test_system_default():
generator = SystemDefaultGenerator()
today = datetime.now().strftime("%d %B %Y")
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_builtin_only():
generator = BuiltinToolGenerator()
expected_text = textwrap.dedent(
"""
Environment: ipython
Tools: brave_search, wolfram_alpha
"""
)
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_custom_only():
generator = JsonCustomToolGenerator()
check_generator_output(generator)
def test_system_custom_function_tag():
generator = FunctionTagCustomToolGenerator()
check_generator_output(generator)
def test_llama_3_2_system_zero_shot():
generator = PythonListCustomToolGenerator()
check_generator_output(generator)
def test_llama_3_2_provided_system_prompt():
generator = PythonListCustomToolGenerator()
user_system_prompt = textwrap.dedent(
"""
Overriding message.
{{ function_description }}
"""
)
example = generator.data_examples()[0]
pt = generator.gen(example, user_system_prompt)
text = pt.render()
assert "Overriding message." in text
assert '"name": "get_weather"' in text

View file

@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
async def test_get_raw_document_text_supports_text_mime_types():
"""Test that the function accepts text/* mime types."""
document = Document(content="Sample text content", mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Sample text content"
async def test_get_raw_document_text_supports_yaml_mime_type():
"""Test that the function accepts application/yaml mime type."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
"""Test that the function accepts text/yaml but emits a deprecation warning."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that exactly one warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
assert "application/yaml" in str(w[0].message)
assert "deprecated" in str(w[0].message).lower()
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
"""Test that text/yaml works with URL content and emits warning."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
"""Test that text/yaml works with TextContentItem and emits warning."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_rejects_unsupported_mime_types():
"""Test that the function rejects unsupported mime types."""
document = Document(
content="Some content",
mime_type="application/json", # Not supported
)
with pytest.raises(ValueError, match="Unexpected document mime type: application/json"):
await get_raw_document_text(document)
async def test_get_raw_document_text_with_url_content():
"""Test that the function handles URL content correctly."""
mock_response = AsyncMock()
mock_response.text = "Content from URL"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = "Content from URL"
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Content from URL"
mock_load.assert_called_once_with("https://example.com/test.txt")
async def test_get_raw_document_text_with_yaml_url():
"""Test that the function handles YAML URLs correctly."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
async def test_get_raw_document_text_with_text_content_item():
"""Test that the function handles TextContentItem correctly."""
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Text content item"
async def test_get_raw_document_text_with_yaml_text_content_item():
"""Test that the function handles YAML TextContentItem correctly."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_rejects_unexpected_content_type():
"""Test that the function rejects unexpected document content types."""
# Create a mock document that bypasses Pydantic validation
mock_document = MagicMock(spec=Document)
mock_document.mime_type = "text/plain"
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
await get_raw_document_text(mock_document)

View file

@ -8,7 +8,6 @@ from datetime import datetime
from unittest.mock import AsyncMock
import pytest
import pytest_asyncio
from llama_stack.apis.agents import (
Agent,
@ -50,7 +49,7 @@ def config(tmp_path):
)
@pytest_asyncio.fixture
@pytest.fixture
async def agents_impl(config, mock_apis):
impl = MetaReferenceAgentsImpl(
config,
@ -117,7 +116,6 @@ def sample_agent_config():
)
@pytest.mark.asyncio
async def test_create_agent(agents_impl, sample_agent_config):
response = await agents_impl.create_agent(sample_agent_config)
@ -132,7 +130,6 @@ async def test_create_agent(agents_impl, sample_agent_config):
assert isinstance(agent_info.created_at, datetime)
@pytest.mark.asyncio
async def test_get_agent(agents_impl, sample_agent_config):
create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id
@ -146,7 +143,6 @@ async def test_get_agent(agents_impl, sample_agent_config):
assert isinstance(agent.created_at, datetime)
@pytest.mark.asyncio
async def test_list_agents(agents_impl, sample_agent_config):
agent1_response = await agents_impl.create_agent(sample_agent_config)
agent2_response = await agents_impl.create_agent(sample_agent_config)
@ -160,7 +156,6 @@ async def test_list_agents(agents_impl, sample_agent_config):
assert agent2_response.agent_id in agent_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
@ -188,7 +183,6 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
await agents_impl.get_agents_session(agent_id, session_response.session_id)
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting
@ -221,7 +215,6 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
assert session2.session_id in {s["session_id"] for s in sessions.data}
@pytest.mark.asyncio
async def test_delete_agent(agents_impl, sample_agent_config):
# Create an agent
response = await agents_impl.create_agent(sample_agent_config)

View file

@ -40,7 +40,7 @@ from llama_stack.apis.inference import (
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.distribution.access_control.access_control import default_policy
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
OpenAIResponsesImpl,
)
@ -122,7 +122,6 @@ async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
)
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input."""
# Setup
@ -155,7 +154,6 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert result.output[0].content[0].text == "Dublin"
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input and tools."""
# Setup
@ -224,7 +222,6 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
assert result.output[1].content[0].annotations == []
@pytest.mark.asyncio
async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a tool call response that has a type of None."""
# Setup
@ -294,7 +291,6 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[1].response.output[0].name == "get_weather"
@pytest.mark.asyncio
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with multiple messages."""
# Setup
@ -340,7 +336,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
@pytest.mark.asyncio
async def test_prepend_previous_response_none(openai_responses_impl):
"""Test prepending no previous response to a new response."""
@ -348,7 +343,6 @@ async def test_prepend_previous_response_none(openai_responses_impl):
assert input == "fake_input"
@pytest.mark.asyncio
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
"""Test prepending a basic previous response to a new response."""
@ -388,7 +382,6 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
assert input[2].content == "fake_input"
@pytest.mark.asyncio
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
"""Test prepending a web search previous response to a new response."""
input_item_message = OpenAIResponseMessage(
@ -434,7 +427,6 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
assert input[3].content == "fake_input"
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
# Setup
input_text = "What is the capital of Ireland?"
@ -463,7 +455,6 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
assert sent_messages[1].content == input_text
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions_and_multiple_messages(
openai_responses_impl, mock_inference_api
):
@ -508,7 +499,6 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
assert sent_messages[3].content == "Which is the largest?"
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions_and_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
@ -565,7 +555,6 @@ async def test_create_openai_response_with_instructions_and_previous_response(
assert sent_messages[3].content == "Which is the largest?"
@pytest.mark.asyncio
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
# Setup
@ -601,7 +590,6 @@ async def test_list_openai_response_input_items_delegation(openai_responses_impl
assert result.data[0].id == "msg_123"
@pytest.mark.asyncio
async def test_responses_store_list_input_items_logic():
"""Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
@ -680,7 +668,6 @@ async def test_responses_store_list_input_items_logic():
assert len(result.data) == 0 # Should return no items
@pytest.mark.asyncio
async def test_store_response_uses_rehydrated_input_with_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api
):
@ -747,7 +734,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"text_format, response_format",
[
@ -787,7 +773,6 @@ async def test_create_openai_response_with_text_format(
assert first_call.kwargs["response_format"] == response_format
@pytest.mark.asyncio
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with an invalid text format."""
# Setup

View file

@ -9,21 +9,19 @@ from datetime import datetime
from unittest.mock import patch
import pytest
import pytest_asyncio
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import User
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
@ -44,7 +42,6 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
@ -79,7 +76,6 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
@ -133,7 +129,6 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
await agent_persistence.get_session_turns(session_id)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from unittest.mock import MagicMock
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_openai_provider_openai_client_caching():
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter.client
assert openai_client.api_key == api_key
def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
together_client = inference_adapter._get_client()
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key

View file

@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
# Test fixtures and helper classes
class TestConfig(BaseModel):
api_key: str | None = Field(default=None)
class TestProviderDataValidator(BaseModel):
test_api_key: str | None = Field(default=None)
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: TestConfig):
super().__init__(
model_entries=[],
litellm_provider_name="test",
api_key_from_config=config.api_key,
provider_data_api_key_field="test_api_key",
openai_compat_api_base=None,
)
@pytest.fixture
def adapter_with_config_key():
"""Fixture to create adapter with API key in config"""
config = TestConfig(api_key="config-api-key")
adapter = TestLiteLLMAdapter(config)
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
)
return adapter
@pytest.fixture
def adapter_without_config_key():
"""Fixture to create adapter without API key in config"""
config = TestConfig(api_key=None)
adapter = TestLiteLLMAdapter(config)
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
)
return adapter
def test_api_key_from_config_when_no_provider_data(adapter_with_config_key):
"""Test that adapter uses config API key when no provider data is available"""
api_key = adapter_with_config_key.get_api_key()
assert api_key == "config-api-key"
def test_provider_data_takes_priority_over_config(adapter_with_config_key):
"""Test that provider data API key overrides config API key"""
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})}
):
api_key = adapter_with_config_key.get_api_key()
assert api_key == "provider-data-key"
def test_fallback_to_config_when_provider_data_missing_key(adapter_with_config_key):
"""Test fallback to config when provider data doesn't have the required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
api_key = adapter_with_config_key.get_api_key()
assert api_key == "config-api-key"
def test_error_when_no_api_key_available(adapter_without_config_key):
"""Test that ValueError is raised when neither config nor provider data have API key"""
with pytest.raises(ValueError, match="API key is not set"):
adapter_without_config_key.get_api_key()
def test_error_when_provider_data_has_wrong_key(adapter_without_config_key):
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
with pytest.raises(ValueError, match="API key is not set"):
adapter_without_config_key.get_api_key()
def test_provider_data_works_when_config_is_none(adapter_without_config_key):
"""Test that provider data works even when config has no API key"""
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-only-key"})}
):
api_key = adapter_without_config_key.get_api_key()
assert api_key == "provider-only-key"
def test_error_message_includes_correct_field_names(adapter_without_config_key):
"""Test that error message includes correct field name and header information"""
try:
adapter_without_config_key.get_api_key()
raise AssertionError("Should have raised ValueError")
except ValueError as e:
assert "test_api_key" in str(e) # Should mention the correct field name
assert "x-llamastack-provider-data" in str(e) # Should mention header name

View file

@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from unittest.mock import AsyncMock, MagicMock, patch
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
class TestOpenAIBaseURLConfig:
"""Test that OPENAI_BASE_URL environment variable properly configures the OpenAI adapter."""
def test_default_base_url_without_env_var(self):
"""Test that the adapter uses the default OpenAI base URL when no environment variable is set."""
config = OpenAIConfig(api_key="test-key")
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == "https://api.openai.com/v1"
def test_custom_base_url_from_config(self):
"""Test that the adapter uses a custom base URL when provided in config."""
custom_url = "https://custom.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == custom_url
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
def test_base_url_from_environment_variable(self):
"""Test that the adapter uses base URL from OPENAI_BASE_URL environment variable."""
# Use sample_run_config which has proper environment variable syntax
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
assert adapter.get_base_url() == "https://env.openai.com/v1"
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://env.openai.com/v1"})
def test_config_overrides_environment_variable(self):
"""Test that explicit config value overrides environment variable."""
custom_url = "https://config.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Config should take precedence over environment variable
assert adapter.get_base_url() == custom_url
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
def test_client_uses_configured_base_url(self, mock_openai_class):
"""Test that the OpenAI client is initialized with the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method since it's delegated to LiteLLMOpenAIMixin
adapter.get_api_key = MagicMock(return_value="test-key")
# Access the client property to trigger AsyncOpenAI initialization
_ = adapter.client
# Verify AsyncOpenAI was called with the correct base_url
mock_openai_class.assert_called_once_with(
api_key="test-key",
base_url=custom_url,
)
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
async def test_check_model_availability_uses_configured_url(self, mock_openai_class):
"""Test that check_model_availability uses the configured base URL."""
custom_url = "https://test.openai.com/v1"
config = OpenAIConfig(api_key="test-key", base_url=custom_url)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client and its models.retrieve method
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_openai_class.return_value = mock_client
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
# Verify the client was created with the custom URL
mock_openai_class.assert_called_with(
api_key="test-key",
base_url=custom_url,
)
# Verify the method was called and returned True
mock_client.models.retrieve.assert_called_once_with("gpt-4")
@patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
@patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
async def test_environment_variable_affects_model_availability_check(self, mock_openai_class):
"""Test that setting OPENAI_BASE_URL environment variable affects where model availability is checked."""
# Use sample_run_config which has proper environment variable syntax
config_data = OpenAIConfig.sample_run_config(api_key="test-key")
processed_config = replace_env_vars(config_data)
config = OpenAIConfig.model_validate(processed_config)
adapter = OpenAIInferenceAdapter(config)
# Mock the get_api_key method
adapter.get_api_key = MagicMock(return_value="test-key")
# Mock the AsyncOpenAI client
mock_client = MagicMock()
mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
mock_openai_class.return_value = mock_client
# Call check_model_availability and verify it returns True
assert await adapter.check_model_availability("gpt-4")
# Verify the client was created with the environment variable URL
mock_openai_class.assert_called_with(
api_key="test-key",
base_url="https://proxy.openai.com/v1",
)

View file

@ -14,7 +14,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
@ -103,7 +102,7 @@ def mock_openai_models_list():
yield mock_list
@pytest_asyncio.fixture(scope="module")
@pytest.fixture(scope="module")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config)
@ -112,7 +111,6 @@ async def vllm_inference_adapter():
return inference_adapter
@pytest.mark.asyncio
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
@ -125,7 +123,6 @@ async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inferenc
mock_openai_models_list.assert_called()
@pytest.mark.asyncio
async def test_old_vllm_tool_choice(vllm_inference_adapter):
"""
Test that we set tool_choice to none when no tools are in use
@ -149,7 +146,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
@ -192,7 +188,6 @@ async def test_tool_call_response(vllm_inference_adapter):
]
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf():
"""
Test that we don't generate extra chunks when processing a
@ -222,7 +217,6 @@ async def test_tool_call_delta_empty_tool_call_buf():
assert chunks[1].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
@ -297,7 +291,6 @@ async def test_tool_call_delta_streaming_arguments_dict():
assert chunks[2].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_multiple_tool_calls():
async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk(
@ -376,7 +369,6 @@ async def test_multiple_tool_calls():
assert chunks[3].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_choices():
"""
Test that we don't error out when vLLM returns no choices for a
@ -401,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
assert chunks[0].event.event_type.value == "start"
@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop()
loop.set_debug(True)
@ -453,7 +446,6 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
assert not asyncio_warnings
@pytest.mark.asyncio
async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest(
tools=[],
@ -464,7 +456,6 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
assert "tools" not in params
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
"""
Tests the edge case where the model returns the arguments for the tool call in the same chunk that
@ -543,7 +534,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
"""
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
@ -596,7 +586,6 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_without_args():
"""
Tests the edge case where no arguments are provided for the tool call.
@ -645,7 +634,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
assert chunks[-2].event.delta.tool_call.arguments == {}
@pytest.mark.asyncio
async def test_health_status_success(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection is successful.
@ -679,7 +667,6 @@ async def test_health_status_success(vllm_inference_adapter):
mock_models.list.assert_called_once()
@pytest.mark.asyncio
async def test_health_status_failure(vllm_inference_adapter):
"""
Test the health method of VLLM InferenceAdapter when the connection fails.

View file

@ -5,103 +5,110 @@
# the root directory of this source tree.
import os
import unittest
from unittest.mock import patch
import pytest
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.resource import ResourceType
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
class TestNvidiaDatastore(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
@pytest.fixture
def nvidia_adapter():
"""Fixture to set up NvidiaDatasetIOAdapter with mocked requests."""
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
config = NvidiaDatasetIOConfig(
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
)
self.adapter = NvidiaDatasetIOAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
config = NvidiaDatasetIOConfig(
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
)
adapter = NvidiaDatasetIOAdapter(config)
def tearDown(self):
self.make_request_patcher.stop()
with patch(
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
) as mock_make_request:
yield adapter, mock_make_request
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
def _assert_request(mock_call, expected_method, expected_path, expected_json=None):
"""Helper function to verify request details in mock calls."""
call_args = mock_call.call_args
assert call_args[0][0] == expected_method
assert call_args[0][1] == expected_path
assert call_args[0][0] == expected_method
assert call_args[0][1] == expected_path
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_register_dataset(self):
self.mock_make_request.return_value = {
"id": "dataset-123456",
def test_register_dataset(nvidia_adapter, run_async):
adapter, mock_make_request = nvidia_adapter
mock_make_request.return_value = {
"id": "dataset-123456",
"name": "test-dataset",
"namespace": "default",
}
dataset_def = Dataset(
identifier="test-dataset",
type=ResourceType.dataset,
provider_resource_id="",
provider_id="",
purpose=DatasetPurpose.post_training_messages,
source=URIDataSource(uri="https://example.com/data.jsonl"),
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
)
run_async(adapter.register_dataset(dataset_def))
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"POST",
"/v1/datasets",
expected_json={
"name": "test-dataset",
"namespace": "default",
}
"files_url": "https://example.com/data.jsonl",
"project": "default",
"format": "jsonl",
"description": "Test dataset description",
},
)
dataset_def = Dataset(
identifier="test-dataset",
type="dataset",
provider_resource_id="",
provider_id="",
purpose=DatasetPurpose.post_training_messages,
source=URIDataSource(uri="https://example.com/data.jsonl"),
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
)
self.run_async(self.adapter.register_dataset(dataset_def))
def test_unregister_dataset(nvidia_adapter, run_async):
adapter, mock_make_request = nvidia_adapter
mock_make_request.return_value = {
"message": "Resource deleted successfully.",
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
"deleted_at": None,
}
dataset_id = "test-dataset"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/datasets",
expected_json={
"name": "test-dataset",
"namespace": "default",
"files_url": "https://example.com/data.jsonl",
"project": "default",
"format": "jsonl",
"description": "Test dataset description",
},
)
run_async(adapter.unregister_dataset(dataset_id))
def test_unregister_dataset(self):
self.mock_make_request.return_value = {
"message": "Resource deleted successfully.",
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
"deleted_at": None,
}
dataset_id = "test-dataset"
mock_make_request.assert_called_once()
_assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
self.run_async(self.adapter.unregister_dataset(dataset_id))
self.mock_make_request.assert_called_once()
self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
def test_register_dataset_with_custom_namespace_project(run_async):
"""Test with custom namespace and project configuration."""
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
def test_register_dataset_with_custom_namespace_project(self):
custom_config = NvidiaDatasetIOConfig(
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
dataset_namespace="custom-namespace",
project_id="custom-project",
)
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
custom_config = NvidiaDatasetIOConfig(
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
dataset_namespace="custom-namespace",
project_id="custom-project",
)
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
self.mock_make_request.return_value = {
with patch(
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
) as mock_make_request:
mock_make_request.return_value = {
"id": "dataset-123456",
"name": "test-dataset",
"namespace": "custom-namespace",
@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase):
dataset_def = Dataset(
identifier="test-dataset",
type="dataset",
type=ResourceType.dataset,
provider_resource_id="",
provider_id="",
purpose=DatasetPurpose.post_training_messages,
@ -117,11 +124,11 @@ class TestNvidiaDatastore(unittest.TestCase):
metadata={"format": "jsonl"},
)
self.run_async(custom_adapter.register_dataset(dataset_def))
run_async(custom_adapter.register_dataset(dataset_def))
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"POST",
"/v1/datasets",
expected_json={
@ -132,7 +139,3 @@ class TestNvidiaDatastore(unittest.TestCase):
"format": "jsonl",
},
)
if __name__ == "__main__":
unittest.main()

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
@ -20,21 +19,20 @@ from llama_stack.apis.post_training.post_training import (
OptimizerType,
TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
class TestNvidiaParameters:
@pytest.fixture(autouse=True)
def setup_and_teardown(self):
"""Setup and teardown for each test method."""
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase):
"updated_at": "2025-03-04T13:07:47.543605",
}
def tearDown(self):
yield
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
job_uuid=required_job_uuid,
model=required_model,
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase):
data_config = DataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format=DatasetFormat.instruct,
validation_dataset_id="val-dataset",
@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase):
optimizer_config = OptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type=OptimizerType.adam,
num_warmup_steps=100,
)
efficiency_config = EfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
# Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase):
dtype="bf16",
)
# Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase):
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
checkpoint_dir="test-dir",
algorithm_config=LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase):
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=convert_pydantic_to_json_value(training_config),
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
logger_config={"test": "value"},
hyperparam_search_config={"test": "value"},
)
)
@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase):
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
# required unsupported parameters
"rank",
"apply_lora_to_output",
"lora_attn_modules",
@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase):
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()

View file

@ -5,321 +5,353 @@
# the root directory of this source tree.
import os
import unittest
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
class TestNVIDIASafetyAdapter(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
"""Test implementation that provides the required shield_store."""
# Initialize the adapter
self.config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
)
self.adapter = NVIDIASafetyAdapter(config=self.config)
self.shield_store = AsyncMock()
self.adapter.shield_store = self.shield_store
def __init__(self, config: NVIDIASafetyConfig, shield_store):
super().__init__(config)
self.shield_store = shield_store
# Mock the HTTP request methods
self.guardrails_post_patcher = patch(
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
)
self.mock_guardrails_post = self.guardrails_post_patcher.start()
self.mock_guardrails_post.return_value = {"status": "allowed"}
def tearDown(self):
"""Clean up after each test."""
self.guardrails_post_patcher.stop()
@pytest.fixture
def nvidia_adapter():
"""Set up the NVIDIASafetyAdapter for testing."""
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
# Initialize the adapter
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
)
def _assert_request(
self,
mock_call: MagicMock,
expected_url: str,
expected_headers: dict[str, str] | None = None,
expected_json: dict[str, Any] | None = None,
) -> None:
"""
Helper method to verify request details in mock API calls.
# Create a mock shield store that implements the ShieldStore protocol
shield_store = AsyncMock()
shield_store.get_shield = AsyncMock()
Args:
mock_call: The MagicMock object that was called
expected_url: The expected URL to which the request was made
expected_headers: Optional dictionary of expected request headers
expected_json: Optional dictionary of expected JSON payload
"""
call_args = mock_call.call_args
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
# Check URL
assert call_args[0][0] == expected_url
return adapter
# Check headers if provided
if expected_headers:
for key, value in expected_headers.items():
assert call_args[1]["headers"][key] == value
# Check JSON if provided
if expected_json:
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert call_args[1]["json"][key][nested_key] == nested_value
else:
assert call_args[1]["json"][key] == value
@pytest.fixture
def mock_guardrails_post():
"""Mock the HTTP request methods."""
with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post:
mock_post.return_value = {"status": "allowed"}
yield mock_post
def test_register_shield_with_valid_id(self):
shield = Shield(
provider_id="nvidia",
type="shield",
identifier="test-shield",
provider_resource_id="test-model",
)
# Register the shield
self.run_async(self.adapter.register_shield(shield))
def _assert_request(
mock_call: MagicMock,
expected_url: str,
expected_headers: dict[str, str] | None = None,
expected_json: dict[str, Any] | None = None,
) -> None:
"""
Helper method to verify request details in mock API calls.
def test_register_shield_without_id(self):
shield = Shield(
provider_id="nvidia",
type="shield",
identifier="test-shield",
provider_resource_id="",
)
Args:
mock_call: The MagicMock object that was called
expected_url: The expected URL to which the request was made
expected_headers: Optional dictionary of expected request headers
expected_json: Optional dictionary of expected JSON payload
"""
call_args = mock_call.call_args
# Register the shield should raise a ValueError
with self.assertRaises(ValueError):
self.run_async(self.adapter.register_shield(shield))
# Check URL
assert call_args[0][0] == expected_url
def test_run_shield_allowed(self):
# Set up the shield
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
type="shield",
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
# Check headers if provided
if expected_headers:
for key, value in expected_headers.items():
assert call_args[1]["headers"][key] == value
# Mock Guardrails API response
self.mock_guardrails_post.return_value = {"status": "allowed"}
# Check JSON if provided
if expected_json:
for key, value in expected_json.items():
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert call_args[1]["json"][key][nested_key] == nested_value
else:
assert call_args[1]["json"][key] == value
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason="end_of_message",
tool_calls=[],
),
]
result = self.run_async(self.adapter.run_shield(shield_id, messages))
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
async def test_register_shield_with_valid_id(nvidia_adapter):
adapter = nvidia_adapter
# Verify the Guardrails API was called correctly
self.mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
},
shield = Shield(
provider_id="nvidia",
type=ResourceType.shield,
identifier="test-shield",
provider_resource_id="test-model",
)
# Register the shield
await adapter.register_shield(shield)
async def test_register_shield_without_id(nvidia_adapter):
adapter = nvidia_adapter
shield = Shield(
provider_id="nvidia",
type=ResourceType.shield,
identifier="test-shield",
provider_resource_id="",
)
# Register the shield should raise a ValueError
with pytest.raises(ValueError):
await adapter.register_shield(shield)
async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
adapter = nvidia_adapter
# Set up the shield
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
type=ResourceType.shield,
identifier=shield_id,
provider_resource_id="test-model",
)
adapter.shield_store.get_shield.return_value = shield
# Mock Guardrails API response
mock_guardrails_post.return_value = {"status": "allowed"}
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
result = await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
},
)
},
)
# Verify the result
assert isinstance(result, RunShieldResponse)
assert result.violation is None
# Verify the result
assert isinstance(result, RunShieldResponse)
assert result.violation is None
def test_run_shield_blocked(self):
# Set up the shield
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
type="shield",
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
# Mock Guardrails API response
self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
adapter = nvidia_adapter
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason="end_of_message",
tool_calls=[],
),
]
result = self.run_async(self.adapter.run_shield(shield_id, messages))
# Set up the shield
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
type=ResourceType.shield,
identifier=shield_id,
provider_resource_id="test-model",
)
adapter.shield_store.get_shield.return_value = shield
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
# Mock Guardrails API response
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
# Verify the Guardrails API was called correctly
self.mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
},
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
result = await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
},
)
},
)
# Verify the result
assert result.violation is not None
assert isinstance(result, RunShieldResponse)
assert result.violation.user_message == "Sorry I cannot do this."
assert result.violation.violation_level == ViolationLevel.ERROR
assert result.violation.metadata == {"reason": "harmful_content"}
# Verify the result
assert result.violation is not None
assert isinstance(result, RunShieldResponse)
assert result.violation.user_message == "Sorry I cannot do this."
assert result.violation.violation_level == ViolationLevel.ERROR
assert result.violation.metadata == {"reason": "harmful_content"}
def test_run_shield_not_found(self):
# Set up shield store to return None
shield_id = "non-existent-shield"
self.shield_store.get_shield.return_value = None
messages = [
UserMessage(role="user", content="Hello, how are you?"),
]
async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
adapter = nvidia_adapter
with self.assertRaises(ValueError):
self.run_async(self.adapter.run_shield(shield_id, messages))
# Set up shield store to return None
shield_id = "non-existent-shield"
adapter.shield_store.get_shield.return_value = None
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
messages = [
UserMessage(role="user", content="Hello, how are you?"),
]
# Verify the Guardrails API was not called
self.mock_guardrails_post.assert_not_called()
with pytest.raises(ValueError):
await adapter.run_shield(shield_id, messages)
def test_run_shield_http_error(self):
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
type="shield",
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
# Verify the shield store was called
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
# Mock Guardrails API to raise an exception
error_msg = "API Error: 500 Internal Server Error"
self.mock_guardrails_post.side_effect = Exception(error_msg)
# Verify the Guardrails API was not called
mock_guardrails_post.assert_not_called()
# Running the shield should raise an exception
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason="end_of_message",
tool_calls=[],
),
]
with self.assertRaises(Exception) as context:
self.run_async(self.adapter.run_shield(shield_id, messages))
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
adapter = nvidia_adapter
# Verify the Guardrails API was called correctly
self.mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
},
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
type=ResourceType.shield,
identifier=shield_id,
provider_resource_id="test-model",
)
adapter.shield_store.get_shield.return_value = shield
# Mock Guardrails API to raise an exception
error_msg = "API Error: 500 Internal Server Error"
mock_guardrails_post.side_effect = Exception(error_msg)
# Running the shield should raise an exception
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
with pytest.raises(Exception) as exc_info:
await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
},
)
# Verify the exception message
assert error_msg in str(context.exception)
},
)
# Verify the exception message
assert error_msg in str(exc_info.value)
def test_init_nemo_guardrails(self):
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
test_config_id = "test-custom-config-id"
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id=test_config_id,
)
# Initialize with default parameters
test_model = "test-model"
guardrails = NeMoGuardrails(config, test_model)
def test_init_nemo_guardrails():
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
# Verify the attributes are set correctly
assert guardrails.config_id == test_config_id
assert guardrails.model == test_model
assert guardrails.threshold == 0.9 # Default value
assert guardrails.temperature == 1.0 # Default value
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
# Initialize with custom parameters
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
test_config_id = "test-custom-config-id"
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id=test_config_id,
)
# Initialize with default parameters
test_model = "test-model"
guardrails = NeMoGuardrails(config, test_model)
# Verify the attributes are set correctly
assert guardrails.config_id == test_config_id
assert guardrails.model == test_model
assert guardrails.threshold == 0.8
assert guardrails.temperature == 0.7
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
# Verify the attributes are set correctly
assert guardrails.config_id == test_config_id
assert guardrails.model == test_model
assert guardrails.threshold == 0.9 # Default value
assert guardrails.temperature == 1.0 # Default value
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
def test_init_nemo_guardrails_invalid_temperature(self):
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
# Initialize with custom parameters
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id="test-custom-config-id",
)
with self.assertRaises(ValueError):
NeMoGuardrails(config, "test-model", temperature=0)
# Verify the attributes are set correctly
assert guardrails.config_id == test_config_id
assert guardrails.model == test_model
assert guardrails.threshold == 0.8
assert guardrails.temperature == 0.7
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
def test_init_nemo_guardrails_invalid_temperature():
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id="test-custom-config-id",
)
with pytest.raises(ValueError):
NeMoGuardrails(config, "test-model", temperature=0)

View file

@ -5,13 +5,11 @@
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
import pytest
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.post_training.post_training import (
DataConfig,
DatasetFormat,
@ -21,8 +19,7 @@ from llama_stack.apis.post_training.post_training import (
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
@ -32,331 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
@pytest.fixture
def nvidia_post_training_adapter():
"""Fixture to create and configure the NVIDIA post training adapter."""
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
adapter = NvidiaPostTrainingAdapter(config)
with patch.object(adapter, "_make_request") as mock_make_request:
yield adapter, mock_make_request
def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
async def test_supervised_fine_tune(nvidia_post_training_adapter):
"""Test the supervised fine-tuning API call."""
adapter, mock_make_request = nvidia_post_training_adapter
mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"alpha": 16},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = await adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
# Mock the inference client
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_client = unittest.mock.MagicMock()
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
self.inference_mock_make_request = self.mock_client.chat.completions.create
self.inference_make_request_patcher = patch(
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
return_value=self.mock_client,
)
self.inference_make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
self.inference_make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"weight_decay": 0.01,
"lora": {"alpha": 16},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
},
)
async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter):
"""Test that QAT configuration raises NotImplementedError."""
adapter, mock_make_request = nvidia_post_training_adapter
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with pytest.raises(NotImplementedError):
await adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
async def test_get_training_job_status(nvidia_post_training_adapter):
"""Test getting training job status with different statuses."""
adapter, mock_make_request = nvidia_post_training_adapter
customizer_status_to_job_status = [
("running", "in_progress"),
("completed", "completed"),
("failed", "failed"),
("cancelled", "cancelled"),
("pending", "scheduled"),
("unknown", "scheduled"),
]
for customizer_status, expected_status in customizer_status_to_job_status:
mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status,
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"weight_decay": 0.01,
"lora": {"alpha": 16},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
customizer_status_to_job_status = [
("running", "in_progress"),
("completed", "completed"),
("failed", "failed"),
("cancelled", "cancelled"),
("pending", "scheduled"),
("unknown", "scheduled"),
]
for customizer_status, expected_status in customizer_status_to_job_status:
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status,
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == expected_status
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self._assert_request(
self.mock_make_request,
"GET",
f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id},
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
status = await adapter.get_training_job_status(job_uuid=job_id)
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == expected_status
# Note: The response object inherits extra fields via ConfigDict(extra="allow")
# So these attributes should be accessible using getattr with defaults
assert getattr(status, "steps_completed", None) == 1210
assert getattr(status, "epochs_completed", None) == 2
assert getattr(status, "percentage_done", None) == 100.0
assert getattr(status, "best_epoch", None) == 2
assert getattr(status, "train_loss", None) == 1.718016266822815
assert getattr(status, "val_loss", None) == 1.8661999702453613
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
_assert_request(
mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id},
)
def test_inference_register_model(self):
model_id = "default/job-1234"
model_type = ModelType.llm
model = Model(
identifier=model_id,
provider_id="nvidia",
provider_model_id=model_id,
provider_resource_id=model_id,
model_type=model_type,
)
result = self.run_async(self.inference_adapter.register_model(model))
assert result == model
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
self.run_async(
self.inference_adapter.chat_completion(
model_id=model_id,
messages=[{"role": "user", "content": "Hello, model"}],
)
)
mock_chat_completion.assert_called()
mock_make_request.reset_mock()
if __name__ == "__main__":
unittest.main()
async def test_get_training_jobs(nvidia_post_training_adapter):
"""Test getting list of training jobs."""
adapter, mock_make_request = nvidia_post_training_adapter
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = await adapter.get_training_jobs()
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
async def test_cancel_training_job(nvidia_post_training_adapter):
"""Test canceling a training job."""
adapter, mock_make_request = nvidia_post_training_adapter
mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = await adapter.cancel_training_job(job_uuid=job_id)
assert result is None
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)

View file

@ -7,8 +7,8 @@
import pytest
from pydantic import BaseModel
from llama_stack.distribution.distribution import get_provider_registry, providable_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.core.distribution import get_provider_registry, providable_apis
from llama_stack.core.utils.dynamic import instantiate_class_type
class TestProviderConfigurations:

View file

@ -5,13 +5,18 @@
# the root directory of this source tree.
import pytest
from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SystemMessage,
UserMessage,
@ -23,7 +28,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
@ -33,7 +37,6 @@ async def test_convert_message_to_openai_dict():
# Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
@ -54,7 +57,6 @@ async def test_convert_message_to_openai_dict_with_tool_call():
}
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage(
content="",
@ -80,7 +82,6 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
}
@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_str():
openai_messages = [
OpenAISystemMessageParam(content="system message"),
@ -98,7 +99,6 @@ async def test_openai_messages_to_messages_with_content_str():
assert llama_messages[2].content == "assistant message"
@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_list():
openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
@ -114,3 +114,71 @@ async def test_openai_messages_to_messages_with_content_list():
assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_string(message_class, kwargs):
"""Test that messages accept string text content."""
msg = message_class(content="Test message", **kwargs)
assert msg.content == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_list(message_class, kwargs):
"""Test that messages accept list of text content parts."""
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
msg = message_class(content=content_list, **kwargs)
assert len(msg.content) == 1
assert msg.content[0].text == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_rejects_images(message_class, kwargs):
"""Test that system, assistant, developer, and tool messages reject image content."""
with pytest.raises(ValidationError):
message_class(
content=[
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
],
**kwargs,
)
def test_user_message_accepts_images():
"""Test that user messages accept image content (unlike other message types)."""
# List with images should work
msg = OpenAIUserMessageParam(
content=[
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
]
)
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"

View file

@ -13,7 +13,6 @@ from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
@pytest.mark.asyncio
async def test_content_from_doc_with_url():
"""Test extracting content from RAGDocument with URL content."""
mock_url = URL(uri="https://example.com")
@ -33,7 +32,6 @@ async def test_content_from_doc_with_url():
mock_instance.get.assert_called_once_with(mock_url.uri)
@pytest.mark.asyncio
async def test_content_from_doc_with_pdf_url():
"""Test extracting content from RAGDocument with URL pointing to a PDF."""
mock_url = URL(uri="https://example.com/document.pdf")
@ -58,7 +56,6 @@ async def test_content_from_doc_with_pdf_url():
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
@pytest.mark.asyncio
async def test_content_from_doc_with_data_url():
"""Test extracting content from RAGDocument with data URL content."""
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
@ -74,7 +71,6 @@ async def test_content_from_doc_with_data_url():
mock_content_from_data.assert_called_once_with(data_url)
@pytest.mark.asyncio
async def test_content_from_doc_with_string():
"""Test extracting content from RAGDocument with string content."""
content_string = "This is plain text content"
@ -85,7 +81,6 @@ async def test_content_from_doc_with_string():
assert result == content_string
@pytest.mark.asyncio
async def test_content_from_doc_with_string_url():
"""Test extracting content from RAGDocument with string URL content."""
url_string = "https://example.com"
@ -105,7 +100,6 @@ async def test_content_from_doc_with_string_url():
mock_instance.get.assert_called_once_with(url_string)
@pytest.mark.asyncio
async def test_content_from_doc_with_string_pdf_url():
"""Test extracting content from RAGDocument with string URL pointing to a PDF."""
url_string = "https://example.com/document.pdf"
@ -130,7 +124,6 @@ async def test_content_from_doc_with_string_pdf_url():
mock_parse_pdf.assert_called_once_with(b"PDF binary data")
@pytest.mark.asyncio
async def test_content_from_doc_with_interleaved_content():
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]

View file

@ -87,18 +87,46 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: Prov
return ModelRegistryHelper([known_provider_model, known_provider_model2])
@pytest.mark.asyncio
class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
"""Test helper that simulates a provider with dynamically available models."""
def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
super().__init__(model_entries)
self._available_models = available_models
async def check_model_availability(self, model: str) -> bool:
return model in self._available_models
@pytest.fixture
def dynamic_model() -> Model:
"""A model that's not in static config but available dynamically."""
return Model(
provider_id="provider",
identifier="dynamic-model",
provider_resource_id="dynamic-provider-id",
)
@pytest.fixture
def helper_with_dynamic_models(
known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model
) -> MockModelRegistryHelperWithDynamicModels:
"""Helper that includes dynamically available models."""
return MockModelRegistryHelperWithDynamicModels(
[known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id]
)
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
assert helper.get_provider_model_id(unknown_model.model_id) is None
@pytest.mark.asyncio
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError):
await helper.register_model(unknown_model)
@pytest.mark.asyncio
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model(
provider_id=known_model.provider_id,
@ -110,7 +138,6 @@ async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model(
provider_id=known_model.provider_id,
@ -122,13 +149,11 @@ async def test_register_model_from_alias(helper: ModelRegistryHelper, known_mode
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model)
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing_different(
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
) -> None:
@ -137,27 +162,86 @@ async def test_register_model_existing_different(
await helper.register_model(known_model)
@pytest.mark.asyncio
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model) # duplicate entry
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
await helper.unregister_model(known_model.model_id)
assert helper.get_provider_model_id(known_model.model_id) is None
# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
# async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
# await helper.register_model(known_model) # duplicate entry
# assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
# await helper.unregister_model(known_model.model_id)
# assert helper.get_provider_model_id(known_model.model_id) is None
@pytest.mark.asyncio
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError):
await helper.unregister_model(unknown_model.model_id)
# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
# async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
# with pytest.raises(ValueError):
# await helper.unregister_model(unknown_model.model_id)
@pytest.mark.asyncio
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
@pytest.mark.asyncio
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
await helper.unregister_model(known_model.provider_resource_id)
assert helper.get_provider_model_id(known_model.provider_resource_id) is None
# TODO: unregister_model functionality was removed/disabled by https://github.com/meta-llama/llama-stack/pull/2916
# async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
# assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
# await helper.unregister_model(known_model.provider_resource_id)
# assert helper.get_provider_model_id(known_model.provider_resource_id) is None
async def test_register_model_from_check_model_availability(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None:
"""Test that models returned by check_model_availability can be registered."""
# Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
# But it should be available via check_model_availability
is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
assert is_available
# Registration should succeed
registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
assert registered_model == dynamic_model
# Model should now be registered and accessible
assert (
helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id
)
async def test_register_model_not_in_static_or_dynamic(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
) -> None:
"""Test that models not in static config or dynamic models are rejected."""
# Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
# And not available via check_model_availability
is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
assert not is_available
# Registration should fail with comprehensive error message
with pytest.raises(Exception) as exc_info: # UnsupportedModelError
await helper_with_dynamic_models.register_model(unknown_model)
# Error should include static models and "..." for dynamic models
error_str = str(exc_info.value)
assert "..." in error_str # "..." should be in error message
async def test_register_alias_for_dynamic_model(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None:
"""Test that we can register an alias that maps to a dynamically available model."""
# Create a model with a different identifier but same provider_resource_id
alias_model = Model(
provider_id=dynamic_model.provider_id,
identifier="dynamic-model-alias",
provider_resource_id=dynamic_model.provider_resource_id,
)
# Registration should succeed since the provider_resource_id is available dynamically
registered_model = await helper_with_dynamic_models.register_model(alias_model)
assert registered_model == alias_model
# Both the original provider_resource_id and the new alias should work
assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id

View file

@ -11,7 +11,6 @@ import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend():
with pytest.raises(ValueError):
Scheduler(backend="unknown")
@ -26,7 +25,6 @@ async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None:
raise TimeoutError(f"Job {job_id} did not complete in time.")
@pytest.mark.asyncio
async def test_scheduler_naive():
sched = Scheduler()
@ -87,7 +85,6 @@ async def test_scheduler_naive():
assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises():
sched = Scheduler()

View file

@ -8,20 +8,32 @@ import random
import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
def vector_provider(request):
return request.param
@pytest.fixture
def vector_db_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@ -90,11 +102,6 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
@pytest.fixture(params=["milvus", "sqlite_vec"])
def vector_provider(request):
return request.param
@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
class MockInferenceAPI:
@ -116,7 +123,7 @@ async def unique_kvstore_config(tmp_path_factory):
@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test.db")
db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
return db_path
@ -198,13 +205,145 @@ async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
return db_path
@pytest.fixture
async def faiss_vec_index(embedding_dimension):
index = FaissIndex(embedding_dimension)
yield index
await index.delete()
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def chroma_vec_db_path(tmp_path_factory):
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
return str(persist_dir)
@pytest.fixture
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
client = PersistentClient(path=chroma_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
collection = await maybe_await(client.get_or_create_collection(name))
index = ChromaIndex(client=client, collection=collection)
await index.initialize()
yield index
await index.delete()
@pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
config = ChromaVectorIOConfig(
db_path=chroma_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = ChromaVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def qdrant_vec_db_path(tmp_path_factory):
import uuid
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
return db_path
@pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, mock_inference_api, embedding_dimension):
import uuid
config = QdrantVectorIOConfig(
db_path=qdrant_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = QdrantVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
import uuid
from qdrant_client import AsyncQdrantClient
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
client = AsyncQdrantClient(path=qdrant_vec_db_path)
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
index = QdrantIndex(client, collection_name)
yield index
await index.delete()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
if vector_provider == "milvus":
return request.getfixturevalue("milvus_vec_adapter")
else:
return request.getfixturevalue("sqlite_vec_adapter")
vector_provider_dict = {
"milvus": "milvus_vec_adapter",
"faiss": "faiss_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter",
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])
@pytest.fixture

View file

@ -0,0 +1,326 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
@pytest.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with RRF reranker."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Test hybrid search with RRF reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
# Check that the request contains both vector and BM25 search requests
reqs = call_args[1]["reqs"]
assert len(reqs) == 2
assert reqs[0].anns_field == "vector"
assert reqs[1].anns_field == "sparse"
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with weighted reranker."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Test hybrid search with weighted reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.7},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
]
]
# Test hybrid search with default reranker (should be RRF)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="unknown_type", # Should default to RRF
reranker_params=None, # Should use default impact_factor
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 1
# Verify hybrid search was called with RRF reranker
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None

View file

@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference
@ -91,13 +90,13 @@ def faiss_config():
return config
@pytest_asyncio.fixture
@pytest.fixture
async def faiss_index(embedding_dimension):
index = await FaissIndex.create(dimension=embedding_dimension)
yield index
@pytest_asyncio.fixture
@pytest.fixture
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
# Create the adapter
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
@ -113,7 +112,6 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
yield adapter
@pytest.mark.asyncio
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
):
@ -136,7 +134,6 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[1] == sample_chunks[1]
@pytest.mark.asyncio
async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing
@ -160,7 +157,6 @@ async def test_health_success():
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
@pytest.mark.asyncio
async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing

View file

@ -10,7 +10,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import (
@ -24,6 +23,7 @@ from llama_stack.providers.inline.vector_io.qdrant.config import (
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@ -37,7 +37,8 @@ from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"))
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
@pytest.fixture(scope="session")
@ -51,6 +52,9 @@ def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
mock_vector_db.model_dump_json.return_value = (
'{"identifier": "' + vector_db_id + '", "embedding_model": "embedding_model", "embedding_dimension": 384}'
)
return mock_vector_db
@ -68,9 +72,9 @@ def mock_api_service(sample_embeddings):
return mock_api_service
@pytest_asyncio.fixture
@pytest.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter
@ -80,7 +84,6 @@ async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service,
__QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
@ -111,7 +114,6 @@ def _prepare_for_json(value: Any) -> str:
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,

View file

@ -8,7 +8,6 @@ import asyncio
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
@ -34,23 +33,21 @@ def loop():
return asyncio.new_event_loop()
@pytest_asyncio.fixture
@pytest.fixture
async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123")
yield index
await index.delete()
@pytest.mark.asyncio
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
@pytest.mark.asyncio
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_string = "Sentence 5"
@ -68,7 +65,6 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
@pytest.mark.asyncio
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -90,7 +86,6 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
# Re-initialize with a clean index
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -103,7 +98,6 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document
@ -116,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
cur = connection.cursor()
# Retrieve all chunk IDs to check for duplicates
cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}")
cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]")
chunk_ids = [row[0] for row in cur.fetchall()]
cur.close()
connection.close()
@ -134,7 +128,6 @@ async def sqlite_vec_adapter(sqlite_connection):
await adapter.shutdown()
@pytest.mark.asyncio
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -163,7 +156,6 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with a high score threshold."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -185,7 +177,6 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
assert len(response.chunks) == 0
@pytest.mark.asyncio
async def test_query_chunks_hybrid_different_embedding(
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
):
@ -211,7 +202,6 @@ async def test_query_chunks_hybrid_different_embedding(
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that RRF properly combines rankings when documents appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -236,7 +226,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -284,7 +273,6 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
@pytest.mark.asyncio
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with documents that appear in only one search method."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -313,7 +301,6 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
assert "document-2" in doc_ids # From keyword search
@pytest.mark.asyncio
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
sqlite_vec_index, sample_chunks, sample_embeddings
):
@ -369,7 +356,6 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test RRFReRanker with different impact factors."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -401,7 +387,6 @@ async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_ch
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -445,7 +430,6 @@ async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, s
assert len(response.chunks) <= 100
@pytest.mark.asyncio
async def test_query_chunks_hybrid_tie_breaking(
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
):

View file

@ -25,12 +25,10 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.mark.asyncio
async def test_initialize_index(vector_index):
await vector_index.initialize()
@pytest.mark.asyncio
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete()
vector_index.initialize()
@ -40,7 +38,6 @@ async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embed
vector_index.delete()
@pytest.mark.asyncio
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await vector_index.add_chunks(sample_chunks, embeddings)
@ -54,7 +51,6 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
assert len(contents) == len(set(contents))
@pytest.mark.asyncio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB(
@ -65,7 +61,6 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
await vector_io_adapter.initialize()
@pytest.mark.asyncio
async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize()
dummy = VectorDB(
@ -79,7 +74,6 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown()
@pytest.mark.asyncio
async def test_register_and_unregister_vector_db(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB(
@ -92,17 +86,19 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
assert dummy.identifier not in vector_io_adapter.cache
@pytest.mark.asyncio
async def test_query_unregistered_raises(vector_io_adapter):
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
fake_emb = np.zeros(8, dtype=np.float32)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
if vector_provider == "chroma":
with pytest.raises(AttributeError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
else:
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
@pytest.mark.asyncio
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
fake_index = AsyncMock()
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
vector_io_adapter.cache["db1"] = fake_index
chunks = ["chunk1", "chunk2"]
await vector_io_adapter.insert_chunks("db1", chunks)
@ -110,7 +106,6 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
fake_index.insert_chunks.assert_awaited_once_with(chunks)
@pytest.mark.asyncio
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
@ -118,11 +113,10 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
await vector_io_adapter.insert_chunks("db_not_exist", [])
@pytest.mark.asyncio
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
vector_io_adapter.cache["db1"] = fake_index
response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})
@ -130,7 +124,6 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
assert response is expected
@pytest.mark.asyncio
async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
@ -138,7 +131,6 @@ async def test_query_chunks_missing_db_raises(vector_io_adapter):
await vector_io_adapter.query_chunks("db_missing", "q", None)
@pytest.mark.asyncio
async def test_save_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
@ -155,7 +147,6 @@ async def test_save_openai_vector_store(vector_io_adapter):
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_update_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
@ -172,7 +163,6 @@ async def test_update_openai_vector_store(vector_io_adapter):
assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_delete_openai_vector_store(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
@ -188,7 +178,6 @@ async def test_delete_openai_vector_store(vector_io_adapter):
assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
@pytest.mark.asyncio
async def test_load_openai_vector_stores(vector_io_adapter):
store_id = "vs_1234"
openai_vector_store = {
@ -204,7 +193,6 @@ async def test_load_openai_vector_stores(vector_io_adapter):
assert loaded_stores[store_id] == openai_vector_store
@pytest.mark.asyncio
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -226,7 +214,6 @@ async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
@pytest.mark.asyncio
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -260,7 +247,6 @@ async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_facto
assert loaded_contents != file_info
@pytest.mark.asyncio
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -284,7 +270,6 @@ async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_pat
assert loaded_contents == file_contents
@pytest.mark.asyncio
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
store_id = "vs_1234"
file_id = "file_1234"
@ -305,5 +290,7 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
assert loaded_file_info == {}
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == []

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
# This test is a unit test for the chunk_utils.py helpers. This should only contain
# tests which are specific to this file. More general (API-level) tests should be placed in

View file

@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
@ -17,13 +18,11 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery:
@pytest.mark.asyncio
async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
@pytest.mark.asyncio
async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
content = "test query content"
@ -60,3 +59,21 @@ class TestRagQuery:
)
assert expected_metadata_string in result.content[1].text
assert result.content is not None
async def test_query_raises_incorrect_mode(self):
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")
async def test_query_accepts_valid_modes(self):
default_config = RAGQueryConfig() # Test default (vector)
assert default_config.mode == "vector"
vector_config = RAGQueryConfig(mode="vector") # Test vector
assert vector_config.mode == "vector"
keyword_config = RAGQueryConfig(mode="keyword") # Test keyword
assert keyword_config.mode == "keyword"
hybrid_config = RAGQueryConfig(mode="hybrid") # Test hybrid
assert hybrid_config.mode == "hybrid"
# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")

View file

@ -112,7 +112,6 @@ class TestValidateEmbedding:
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = RAGDocument(
@ -124,7 +123,7 @@ class TestVectorStore:
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.asyncio
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -137,7 +136,7 @@ class TestVectorStore:
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.asyncio
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -204,7 +203,6 @@ class TestVectorStore:
class TestVectorDBWithIndex:
@pytest.mark.asyncio
async def test_insert_chunks_without_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model without embeddings"
@ -230,7 +228,6 @@ class TestVectorDBWithIndex:
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_valid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with embeddings"
@ -255,7 +252,6 @@ class TestVectorDBWithIndex:
assert args[0] == chunks
assert np.array_equal(args[1], np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
@pytest.mark.asyncio
async def test_insert_chunks_with_invalid_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_dimension = 3
@ -295,7 +291,6 @@ class TestVectorDBWithIndex:
mock_inference_api.embeddings.assert_not_called()
mock_index.add_chunks.assert_not_called()
@pytest.mark.asyncio
async def test_insert_chunks_with_partially_precomputed_embeddings(self):
mock_vector_db = MagicMock()
mock_vector_db.embedding_model = "test-model with partial embeddings"

View file

@ -9,7 +9,7 @@ import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
from llama_stack.core.store.registry import (
KEY_FORMAT,
CachedDiskDistributionRegistry,
DiskDistributionRegistry,
@ -38,14 +38,12 @@ def sample_model():
)
@pytest.mark.asyncio
async def test_registry_initialization(disk_dist_registry):
# Test empty registry
result = await disk_dist_registry.get("nonexistent", "nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_model):
print(f"Registering {sample_vector_db}")
await disk_dist_registry.register(sample_vector_db)
@ -64,7 +62,6 @@ async def test_basic_registration(disk_dist_registry, sample_vector_db, sample_m
assert result_model.provider_id == sample_model.provider_id
@pytest.mark.asyncio
async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db, sample_model):
# First populate the disk registry
disk_registry = DiskDistributionRegistry(sqlite_kvstore)
@ -85,7 +82,6 @@ async def test_cached_registry_initialization(sqlite_kvstore, sample_vector_db,
assert result_vector_db.provider_id == sample_vector_db.provider_id
@pytest.mark.asyncio
async def test_cached_registry_updates(cached_disk_dist_registry):
new_vector_db = VectorDB(
identifier="test_vector_db_2",
@ -112,7 +108,6 @@ async def test_cached_registry_updates(cached_disk_dist_registry):
assert result_vector_db.provider_id == new_vector_db.provider_id
@pytest.mark.asyncio
async def test_duplicate_provider_registration(cached_disk_dist_registry):
original_vector_db = VectorDB(
identifier="test_vector_db_2",
@ -137,7 +132,6 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
@pytest.mark.asyncio
async def test_get_all_objects(cached_disk_dist_registry):
# Create multiple test banks
# Create multiple test banks
@ -170,7 +164,6 @@ async def test_get_all_objects(cached_disk_dist_registry):
assert stored_vector_db.embedding_dimension == original_vector_db.embedding_dimension
@pytest.mark.asyncio
async def test_parse_registry_values_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_vector_db",
@ -209,7 +202,6 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
assert invalid_obj is None
@pytest.mark.asyncio
async def test_cached_registry_error_handling(sqlite_kvstore):
valid_db = VectorDB(
identifier="valid_cached_db",

View file

@ -5,14 +5,11 @@
# the root directory of this source tree.
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
from llama_stack.core.datatypes import ModelWithOwner, User
from llama_stack.core.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithOwner(
identifier="model-acl",
@ -48,7 +45,6 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithOwner(
identifier="model-empty-acl",
@ -85,7 +81,6 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry):
attributes = {
"roles": ["admin", "researcher"],

View file

@ -7,15 +7,14 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import pytest_asyncio
import yaml
from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.core.routing_tables.models import ModelsRoutingTable
class AsyncMock(MagicMock):
@ -27,7 +26,7 @@ def _return_model(model):
return model
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
@ -41,8 +40,7 @@ async def test_setup(cached_disk_dist_registry):
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithOwner(
@ -106,8 +104,7 @@ async def test_access_control_with_cache(mock_get_authenticated_user, test_setup
await routing_table.get_model("model-admin")
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithOwner(
@ -145,8 +142,7 @@ async def test_access_control_and_updates(mock_get_authenticated_user, test_setu
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model = ModelWithOwner(
@ -170,8 +166,7 @@ async def test_access_control_empty_attributes(mock_get_authenticated_user, test
assert "model-empty-attrs" in model_ids
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithOwner(
@ -201,8 +196,7 @@ async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
assert all_models.data[0].identifier == "model-public-2"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup
@ -246,7 +240,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
assert model.identifier == "auto-access-model"
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
@ -281,8 +275,7 @@ async def test_setup_with_access_policy(cached_disk_dist_registry):
yield routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
@patch("llama_stack.core.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy
mock_get_authenticated_user.return_value = User(
@ -292,9 +285,15 @@ async def test_access_policy(mock_get_authenticated_user, test_setup_with_access
"projects": ["foo", "bar"],
},
)
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
await routing_table.register_model(
"model-1", provider_model_id="test_provider/model-1", provider_id="test_provider"
)
await routing_table.register_model(
"model-2", provider_model_id="test_provider/model-2", provider_id="test_provider"
)
await routing_table.register_model(
"model-3", provider_model_id="test_provider/model-3", provider_id="test_provider"
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
@ -562,6 +561,6 @@ def test_invalid_condition():
],
)
def test_condition_reprs(condition):
from llama_stack.distribution.access_control.conditions import parse_condition
from llama_stack.core.access_control.conditions import parse_condition
assert condition == str(parse_condition(condition))

View file

@ -11,7 +11,7 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import (
from llama_stack.core.datatypes import (
AuthenticationConfig,
AuthProviderType,
CustomAuthConfig,
@ -19,8 +19,9 @@ from llama_stack.distribution.datatypes import (
OAuth2JWKSConfig,
OAuth2TokenAuthConfig,
)
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.distribution.server.auth_providers import (
from llama_stack.core.request_headers import User
from llama_stack.core.server.auth import AuthenticationMiddleware, _has_required_scope
from llama_stack.core.server.auth_providers import (
get_attributes_from_claims,
)
@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint):
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint):
),
access_policy=[],
)
return AuthenticationMiddleware(mock_app, auth_config), mock_app
return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app
@pytest.fixture
def mock_impls():
"""Mock implementations for scope testing"""
return {}
@pytest.fixture
def scope_middleware_with_mocks(mock_auth_endpoint):
"""Create AuthenticationMiddleware with mocked route implementations"""
mock_app = AsyncMock()
auth_config = AuthenticationConfig(
provider_config=CustomAuthConfig(
type=AuthProviderType.CUSTOM,
endpoint=mock_auth_endpoint,
),
access_policy=[],
)
middleware = AuthenticationMiddleware(mock_app, auth_config, {})
# Mock the route_impls to simulate finding routes with required scopes
from llama_stack.schema_utils import WebMethod
scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")
public_webmethod = WebMethod(route="/test/public", method="GET")
# Mock the route finding logic
def mock_find_matching_route(method, path, route_impls):
if method == "POST" and path == "/test/scoped":
return None, {}, "/test/scoped", scoped_webmethod
elif method == "GET" and path == "/test/public":
return None, {}, "/test/public", public_webmethod
else:
raise ValueError("No matching route")
import llama_stack.core.server.auth
llama_stack.core.server.auth.find_matching_route = mock_find_matching_route
llama_stack.core.server.auth.initialize_route_impls = lambda impls: {}
return middleware, mock_app
async def mock_post_success(*args, **kwargs):
@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs):
raise Exception("Connection error")
async def mock_post_success_with_scope(*args, **kwargs):
"""Mock auth response for user with test.read scope"""
return MockResponse(
200,
{
"message": "Authentication successful",
"principal": "test-user",
"attributes": {
"scopes": ["test.read", "other.scope"],
"roles": ["user"],
},
},
)
async def mock_post_success_no_scope(*args, **kwargs):
"""Mock auth response for user without required scope"""
return MockResponse(
200,
{
"message": "Authentication successful",
"principal": "test-user",
"attributes": {
"scopes": ["other.scope"],
"roles": ["user"],
},
},
)
# HTTP Endpoint Tests
def test_missing_auth_header(http_client):
response = http_client.get("/test")
@ -202,7 +276,6 @@ def test_http_auth_request_payload(http_client, valid_api_key, mock_auth_endpoin
assert "param2" in payload["request"]["params"]
@pytest.mark.asyncio
async def test_http_middleware_with_access_attributes(mock_http_middleware, mock_scope):
"""Test HTTP middleware behavior with access attributes"""
middleware, mock_app = mock_http_middleware
@ -253,7 +326,7 @@ def oauth2_app():
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
@ -352,7 +425,7 @@ def oauth2_app_with_jwks_token():
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
@ -443,7 +516,7 @@ def introspection_app(mock_introspection_endpoint):
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
@ -473,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint):
),
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
@ -582,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication(
)
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
# Scope-based authorization tests
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
"""Test that user with required scope can access protected endpoint"""
middleware, mock_app = scope_middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
scope = {
"type": "http",
"path": "/test/scoped",
"method": "POST",
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
}
await middleware(scope, mock_receive, mock_send)
# Should call the downstream app (no 403 error sent)
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
mock_send.assert_not_called()
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
"""Test that user without required scope gets 403 access denied"""
middleware, mock_app = scope_middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
scope = {
"type": "http",
"path": "/test/scoped",
"method": "POST",
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
}
await middleware(scope, mock_receive, mock_send)
# Should send 403 error, not call downstream app
mock_app.assert_not_called()
assert mock_send.call_count == 2 # start + body
# Check the response
start_call = mock_send.call_args_list[0][0][0]
assert start_call["status"] == 403
body_call = mock_send.call_args_list[1][0][0]
body_text = body_call["body"].decode()
assert "Access denied" in body_text
assert "test.read" in body_text
@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
"""Test that public endpoints work without specific scopes"""
middleware, mock_app = scope_middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
scope = {
"type": "http",
"path": "/test/public",
"method": "GET",
"headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
}
await middleware(scope, mock_receive, mock_send)
# Should call the downstream app (no error)
mock_app.assert_called_once_with(scope, mock_receive, mock_send)
mock_send.assert_not_called()
async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
"""Test that when auth is disabled (no user), scope checks are bypassed"""
middleware, mock_app = scope_middleware_with_mocks
mock_receive = AsyncMock()
mock_send = AsyncMock()
scope = {
"type": "http",
"path": "/test/scoped",
"method": "POST",
"headers": [], # No authorization header
}
await middleware(scope, mock_receive, mock_send)
# Should send 401 auth error, not call downstream app
mock_app.assert_not_called()
assert mock_send.call_count == 2 # start + body
# Check the response
start_call = mock_send.call_args_list[0][0][0]
assert start_call["status"] == 401
body_call = mock_send.call_args_list[1][0][0]
body_text = body_call["body"].decode()
assert "Authentication required" in body_text
def test_has_required_scope_function():
"""Test the _has_required_scope function directly"""
# Test user with required scope
user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]})
assert _has_required_scope("test.read", user_with_scope)
# Test user without required scope
user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]})
assert not _has_required_scope("test.read", user_without_scope)
# Test user with no scopes attribute
user_no_scopes = User(principal="test-user", attributes={})
assert not _has_required_scope("test.read", user_no_scopes)
# Test no user (auth disabled)
assert _has_required_scope("test.read", None)

View file

@ -11,8 +11,8 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
from llama_stack.distribution.server.auth import AuthenticationMiddleware
from llama_stack.core.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
from llama_stack.core.server.auth import AuthenticationMiddleware
class MockResponse:
@ -49,7 +49,7 @@ def github_token_app():
)
# Add auth middleware
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():
@ -78,7 +78,7 @@ def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
assert "Invalid Authorization header format" in response.json()["error"]["message"]
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
"""Test accessing protected endpoint with valid GitHub token"""
# Mock the GitHub API responses
@ -118,7 +118,7 @@ def test_authenticated_endpoint_with_valid_github_token(mock_client_class, githu
assert calls[0][1]["headers"]["Authorization"] == "Bearer github_token_123"
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
"""Test accessing protected endpoint with invalid GitHub token"""
# Mock the GitHub API to return 401 Unauthorized
@ -135,7 +135,7 @@ def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, git
)
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
@patch("llama_stack.core.server.auth_providers.httpx.AsyncClient")
def test_github_enterprise_support(mock_client_class):
"""Test GitHub Enterprise support with custom API base URL"""
app = FastAPI()
@ -149,7 +149,7 @@ def test_github_enterprise_support(mock_client_class):
access_policy=[],
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})
@app.get("/test")
def test_endpoint():

View file

@ -9,8 +9,8 @@ from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware
from llama_stack.distribution.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.distribution.server.quota import QuotaMiddleware
from llama_stack.core.datatypes import QuotaConfig, QuotaPeriod
from llama_stack.core.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

View file

@ -5,73 +5,86 @@
# the root directory of this source tree.
import os
import unittest
from llama_stack.distribution.stack import replace_env_vars
import pytest
from llama_stack.core.stack import replace_env_vars
class TestReplaceEnvVars(unittest.TestCase):
def setUp(self):
# Clear any existing environment variables we'll use in tests
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
if var in os.environ:
del os.environ[var]
@pytest.fixture
def setup_env_vars():
# Clear any existing environment variables we'll use in tests
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
if var in os.environ:
del os.environ[var]
# Set up test environment variables
os.environ["TEST_VAR"] = "test_value"
os.environ["EMPTY_VAR"] = ""
os.environ["ZERO_VAR"] = "0"
# Set up test environment variables
os.environ["TEST_VAR"] = "test_value"
os.environ["EMPTY_VAR"] = ""
os.environ["ZERO_VAR"] = "0"
def test_simple_replacement(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
yield
def test_default_value_when_not_set(self):
self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default")
def test_default_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value")
def test_default_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default")
def test_none_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None)
def test_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value")
def test_empty_var_no_default(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None)
def test_conditional_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional")
def test_conditional_value_when_not_set(self):
self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None)
def test_conditional_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None)
def test_conditional_value_with_zero(self):
self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional")
def test_mixed_syntax(self):
self.assertEqual(
replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and "
)
self.assertEqual(
replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional"
)
def test_nested_structures(self):
data = {
"key1": "${env.TEST_VAR:=default}",
"key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
"key3": {"nested": "${env.NOT_SET:+conditional}"},
}
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
self.assertEqual(replace_env_vars(data), expected)
# Cleanup after test
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
if var in os.environ:
del os.environ[var]
if __name__ == "__main__":
unittest.main()
def test_simple_replacement(setup_env_vars):
assert replace_env_vars("${env.TEST_VAR}") == "test_value"
def test_default_value_when_not_set(setup_env_vars):
assert replace_env_vars("${env.NOT_SET:=default}") == "default"
def test_default_value_when_set(setup_env_vars):
assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"
def test_default_value_when_empty(setup_env_vars):
assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"
def test_none_value_when_empty(setup_env_vars):
assert replace_env_vars("${env.EMPTY_VAR:=}") is None
def test_value_when_set(setup_env_vars):
assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"
def test_empty_var_no_default(setup_env_vars):
assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
def test_conditional_value_when_set(setup_env_vars):
assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"
def test_conditional_value_when_not_set(setup_env_vars):
assert replace_env_vars("${env.NOT_SET:+conditional}") is None
def test_conditional_value_when_empty(setup_env_vars):
assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None
def test_conditional_value_with_zero(setup_env_vars):
assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"
def test_mixed_syntax(setup_env_vars):
assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
def test_nested_structures(setup_env_vars):
data = {
"key1": "${env.TEST_VAR:=default}",
"key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
"key3": {"nested": "${env.NOT_SET:+conditional}"},
}
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
assert replace_env_vars(data) == expected

View file

@ -9,18 +9,17 @@ import sys
from typing import Any, Protocol
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Inference
from llama_stack.distribution.datatypes import (
from llama_stack.core.datatypes import (
Api,
Provider,
StackRunConfig,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.routers.inference import InferenceRouter
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.core.resolver import resolve_impls
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
@ -66,7 +65,6 @@ class SampleImpl:
pass
@pytest.mark.asyncio
async def test_resolve_impls_basic():
# Create a real provider spec
provider_spec = InlineProviderSpec(

View file

@ -10,9 +10,9 @@ from fastapi import HTTPException
from openai import BadRequestError
from pydantic import ValidationError
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.server.server import translate_exception
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.core.server.server import translate_exception
class TestTranslateException:
@ -29,7 +29,7 @@ class TestTranslateException:
def test_translate_access_denied_error_with_context(self):
"""Test that AccessDeniedError with context includes detailed information."""
from llama_stack.distribution.datatypes import User
from llama_stack.core.datatypes import User
# Create mock user and resource
user = User("test-user", {"roles": ["user"], "teams": ["dev"]})

View file

@ -7,13 +7,10 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
from llama_stack.core.server.server import create_dynamic_typed_route, create_sse_event, sse_generator
@pytest.mark.asyncio
async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
@ -35,7 +32,6 @@ async def test_sse_generator_basic():
assert seen_events[1] == create_sse_event("Test event 2")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
@ -58,7 +54,6 @@ async def test_sse_generator_client_disconnected():
assert seen_events[0] == create_sse_event("Test event 1")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts():
# Disconnect before the response starts
async def async_event_gen():
@ -75,7 +70,6 @@ async def test_sse_generator_client_disconnected_before_response_starts():
assert len(seen_events) == 0
@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts():
# Raise an error before the response starts
async def async_event_gen():
@ -93,7 +87,6 @@ async def test_sse_generator_error_before_response_starts():
assert 'data: {"error":' in seen_events[0]
@pytest.mark.asyncio
async def test_paginated_response_url_setting():
"""Test that PaginatedResponse gets url set to route path."""

View file

@ -42,7 +42,6 @@ def create_test_chat_completion(
)
@pytest.mark.asyncio
async def test_inference_store_pagination_basic():
"""Test basic pagination functionality."""
with TemporaryDirectory() as tmp_dir:
@ -88,7 +87,6 @@ async def test_inference_store_pagination_basic():
assert result3.has_more is False
@pytest.mark.asyncio
async def test_inference_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
@ -123,7 +121,6 @@ async def test_inference_store_pagination_ascending():
assert result2.has_more is True
@pytest.mark.asyncio
async def test_inference_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
@ -161,7 +158,6 @@ async def test_inference_store_pagination_with_model_filter():
assert result2.has_more is False
@pytest.mark.asyncio
async def test_inference_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
@ -174,7 +170,6 @@ async def test_inference_store_pagination_invalid_after():
await store.list_chat_completions(after="non-existent", limit=2)
@pytest.mark.asyncio
async def test_inference_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:

View file

@ -44,7 +44,6 @@ def create_test_response_input(content: str, input_id: str) -> OpenAIResponseInp
)
@pytest.mark.asyncio
async def test_responses_store_pagination_basic():
"""Test basic pagination functionality for responses store."""
with TemporaryDirectory() as tmp_dir:
@ -90,7 +89,6 @@ async def test_responses_store_pagination_basic():
assert result3.has_more is False
@pytest.mark.asyncio
async def test_responses_store_pagination_ascending():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
@ -125,7 +123,6 @@ async def test_responses_store_pagination_ascending():
assert result2.has_more is True
@pytest.mark.asyncio
async def test_responses_store_pagination_with_model_filter():
"""Test pagination combined with model filtering."""
with TemporaryDirectory() as tmp_dir:
@ -163,7 +160,6 @@ async def test_responses_store_pagination_with_model_filter():
assert result2.has_more is False
@pytest.mark.asyncio
async def test_responses_store_pagination_invalid_after():
"""Test error handling for invalid 'after' parameter."""
with TemporaryDirectory() as tmp_dir:
@ -176,7 +172,6 @@ async def test_responses_store_pagination_invalid_after():
await store.list_responses(after="non-existent", limit=2)
@pytest.mark.asyncio
async def test_responses_store_pagination_no_limit():
"""Test pagination behavior when no limit is specified."""
with TemporaryDirectory() as tmp_dir:
@ -205,7 +200,6 @@ async def test_responses_store_pagination_no_limit():
assert result.has_more is False
@pytest.mark.asyncio
async def test_responses_store_get_response_object():
"""Test retrieving a single response object."""
with TemporaryDirectory() as tmp_dir:
@ -230,7 +224,6 @@ async def test_responses_store_get_response_object():
await store.get_response_object("non-existent")
@pytest.mark.asyncio
async def test_responses_store_input_items_pagination():
"""Test pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:
@ -308,7 +301,6 @@ async def test_responses_store_input_items_pagination():
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
@pytest.mark.asyncio
async def test_responses_store_input_items_before_pagination():
"""Test before pagination functionality for input items."""
with TemporaryDirectory() as tmp_dir:

View file

@ -14,7 +14,6 @@ from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemyS
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.mark.asyncio
async def test_sqlite_sqlstore():
with TemporaryDirectory() as tmp_dir:
db_name = "test.db"
@ -66,7 +65,6 @@ async def test_sqlite_sqlstore():
assert result.has_more is False
@pytest.mark.asyncio
async def test_sqlstore_pagination_basic():
"""Test basic pagination functionality at the SQL store level."""
with TemporaryDirectory() as tmp_dir:
@ -131,7 +129,6 @@ async def test_sqlstore_pagination_basic():
assert result3.has_more is False
@pytest.mark.asyncio
async def test_sqlstore_pagination_with_filter():
"""Test pagination with WHERE conditions."""
with TemporaryDirectory() as tmp_dir:
@ -184,7 +181,6 @@ async def test_sqlstore_pagination_with_filter():
assert result2.has_more is False
@pytest.mark.asyncio
async def test_sqlstore_pagination_ascending_order():
"""Test pagination with ascending order."""
with TemporaryDirectory() as tmp_dir:
@ -233,7 +229,6 @@ async def test_sqlstore_pagination_ascending_order():
assert result2.has_more is True
@pytest.mark.asyncio
async def test_sqlstore_pagination_multi_column_ordering_error():
"""Test that multi-column ordering raises an error when using cursor pagination."""
with TemporaryDirectory() as tmp_dir:
@ -271,7 +266,6 @@ async def test_sqlstore_pagination_multi_column_ordering_error():
assert result.data[0]["id"] == "task1"
@pytest.mark.asyncio
async def test_sqlstore_pagination_cursor_requires_order_by():
"""Test that cursor pagination requires order_by parameter."""
with TemporaryDirectory() as tmp_dir:
@ -289,7 +283,6 @@ async def test_sqlstore_pagination_cursor_requires_order_by():
)
@pytest.mark.asyncio
async def test_sqlstore_pagination_error_handling():
"""Test error handling for invalid columns and cursor IDs."""
with TemporaryDirectory() as tmp_dir:
@ -339,7 +332,6 @@ async def test_sqlstore_pagination_error_handling():
)
@pytest.mark.asyncio
async def test_sqlstore_pagination_custom_key_column():
"""Test pagination with custom primary key column (not 'id')."""
with TemporaryDirectory() as tmp_dir:

View file

@ -7,18 +7,15 @@
from tempfile import TemporaryDirectory
from unittest.mock import patch
import pytest
from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed
from llama_stack.distribution.access_control.datatypes import Action
from llama_stack.distribution.datatypes import User
from llama_stack.core.access_control.access_control import default_policy, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore, SqlRecord
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_fetch_with_where_sql_access_control(mock_get_authenticated_user):
"""Test that fetch_all works correctly with where_sql for access control"""
@ -81,7 +78,6 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
assert row["title"] == "User Document"
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_sql_policy_consistency(mock_get_authenticated_user):
"""Test that SQL WHERE clause logic exactly matches is_action_allowed policy logic"""
@ -168,7 +164,6 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
)
@pytest.mark.asyncio
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
"""Test that user attributes are properly captured during insert"""