fix unit tests

This commit is contained in:
Ashwin Bharambe 2025-03-19 13:49:23 -07:00
parent b937a49436
commit 01fac67e33
6 changed files with 37 additions and 40 deletions

View file

@ -18,7 +18,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture
@pytest.fixture(scope="function")
async def kvstore():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_registry_acl.db")
@ -29,13 +29,14 @@ async def kvstore():
shutil.rmtree(temp_dir)
@pytest.fixture
@pytest.fixture(scope="function")
async def registry(kvstore):
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
return registry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(registry):
model = ModelWithACL(
identifier="model-acl",
@ -68,7 +69,7 @@ async def test_registry_cache_with_acl(registry):
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
new_registry = CachedDiskDistributionRegistry(registry._kvstore)
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
await new_registry.initialize()
new_model = await new_registry.get("model", "model-acl")
@ -79,6 +80,7 @@ async def test_registry_cache_with_acl(registry):
assert new_model.access_attributes.teams is None
@pytest.mark.asyncio
async def test_registry_empty_acl(registry):
model = ModelWithACL(
identifier="model-empty-acl",
@ -118,24 +120,30 @@ async def test_registry_empty_acl(registry):
assert len(all_models) == 2
@pytest.mark.asyncio
async def test_registry_serialization(registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
)
model = ModelWithACL(
identifier="model-serialize",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
),
access_attributes=attributes,
)
await registry.register(model)
registry.cache.clear()
loaded_model = await registry.get("model", "model-serialize")
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
await new_registry.initialize()
loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]

View file

@ -10,8 +10,8 @@ import tempfile
from unittest.mock import MagicMock, Mock, patch
import pytest
import pytest_asyncio
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
@ -29,18 +29,19 @@ def _return_model(model):
return model
@pytest_asyncio.fixture
@pytest.fixture
async def test_setup():
temp_dir = tempfile.mkdtemp()
db_path = os.path.join(temp_dir, "test_access_control.db")
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
kvstore = SqliteKVStoreImpl(kvstore_config)
await kvstore.initialize()
registry = CachedDiskDistributionRegistry(kvstore)
await registry.initialize()
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = "inference"
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
@ -80,14 +81,14 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
all_models = await routing_table.list_models()
assert len(all_models.data) == 3
assert len(all_models.data) == 2
model = await routing_table.get_model("model-public")
assert model.identifier == "model-public"
model = await routing_table.get_model("model-admin")
assert model.identifier == "model-admin"
model = await routing_table.get_model("model-data-scientist")
assert model.identifier == "model-data-scientist"
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
all_models = await routing_table.list_models()
@ -157,7 +158,9 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
access_attributes=AccessAttributes(),
)
await registry.register(model)
mock_get_auth_attributes.return_value = {}
mock_get_auth_attributes.return_value = {
"roles": [],
}
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models()
@ -231,6 +234,7 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
mock_get_auth_attributes.return_value = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
}
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"

View file

@ -75,15 +75,11 @@ def mock_middleware(mock_auth_endpoint):
async def mock_post_success(*args, **kwargs):
mock_response = AsyncMock()
mock_response.status_code = 200
return mock_response
return MockResponse(200, {"message": "Authentication successful"})
async def mock_post_failure(*args, **kwargs):
mock_response = AsyncMock()
mock_response.status_code = 401
return mock_response
return MockResponse(401, {"message": "Authentication failed"})
async def mock_post_exception(*args, **kwargs):
@ -125,8 +121,7 @@ def test_auth_service_error(client, valid_api_key):
def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
with patch("httpx.AsyncClient.post") as mock_post:
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response = MockResponse(200, {"message": "Authentication successful"})
mock_post.return_value = mock_response
client.get(
@ -148,7 +143,7 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
payload = kwargs["json"]
assert payload["api_key"] == valid_api_key
assert payload["request"]["path"] == "/test"
assert "authorization" in payload["request"]["headers"]
assert "authorization" not in payload["request"]["headers"]
assert "param1" in payload["request"]["params"]
assert "param2" in payload["request"]["params"]