From 01fac67e334ebb1fc6a458a2f5819448067aece2 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 19 Mar 2025 13:49:23 -0700 Subject: [PATCH] fix unit tests --- llama_stack/distribution/request_headers.py | 3 -- .../distribution/routers/routing_tables.py | 7 +---- llama_stack/distribution/server/auth.py | 6 ++-- tests/unit/registry/test_registry_acl.py | 30 ++++++++++++------- tests/unit/server/test_access_control.py | 18 ++++++----- tests/unit/server/test_auth.py | 13 +++----- 6 files changed, 37 insertions(+), 40 deletions(-) diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 3d13621a3..f9cde2cdf 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -23,10 +23,7 @@ class RequestProviderDataContext(ContextManager): def __init__( self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None ): - # Initialize with either provider_data or create a new dict self.provider_data = provider_data or {} - - # Add auth attributes under a special key if provided if auth_attributes: self.provider_data["__auth_attributes"] = auth_attributes diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d2cc0c3d0..560fa92b9 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -274,17 +274,14 @@ class CommonRoutingTableImpl(RoutingTable): if not hasattr(obj, "access_attributes") or not obj.access_attributes: return True - # Get user attributes from context + # Get user attributes from request context user_attributes = get_auth_attributes() # If no user attributes, deny access to objects with access control if not user_attributes: return False - # Convert AccessAttributes to dictionary for checking obj_attributes = obj.access_attributes.model_dump(exclude_none=True) - - # If the model_dump is empty (all fields are None), allow access if not obj_attributes: return True @@ -292,14 +289,12 @@ class CommonRoutingTableImpl(RoutingTable): for attr_key, required_values in obj_attributes.items(): user_values = user_attributes.get(attr_key, []) - # No values for this category in user attributes if not user_values: logger.debug( f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'" ) return False - # None of the values in this category match (need at least one match per category) if not any(val in user_values for val in required_values): logger.debug( f"Access denied to {obj.type} '{obj.identifier}': " diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index d8add22ea..12e342c92 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -166,18 +166,16 @@ class AuthenticationMiddleware: # Parse and validate the auth response try: response_data = response.json() - auth_response = AuthResponse(**response_data) # Store attributes in request scope for access control if auth_response.access_attributes: user_attributes = auth_response.access_attributes.model_dump(exclude_none=True) - scope["user_attributes"] = user_attributes else: logger.warning("Authentication response did not contain any attributes") - scope["user_attributes"] = {} + user_attributes = {} - # Log authentication success with attribute details + scope["user_attributes"] = user_attributes logger.debug(f"Authentication successful: {len(user_attributes)} attributes") except Exception: logger.exception("Error parsing authentication response") diff --git a/tests/unit/registry/test_registry_acl.py b/tests/unit/registry/test_registry_acl.py index bd8035cfc..ee8f28176 100644 --- a/tests/unit/registry/test_registry_acl.py +++ b/tests/unit/registry/test_registry_acl.py @@ -18,7 +18,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl -@pytest.fixture +@pytest.fixture(scope="function") async def kvstore(): temp_dir = tempfile.mkdtemp() db_path = os.path.join(temp_dir, "test_registry_acl.db") @@ -29,13 +29,14 @@ async def kvstore(): shutil.rmtree(temp_dir) -@pytest.fixture +@pytest.fixture(scope="function") async def registry(kvstore): registry = CachedDiskDistributionRegistry(kvstore) await registry.initialize() return registry +@pytest.mark.asyncio async def test_registry_cache_with_acl(registry): model = ModelWithACL( identifier="model-acl", @@ -68,7 +69,7 @@ async def test_registry_cache_with_acl(registry): assert updated_cached.access_attributes.projects == ["project-x"] assert updated_cached.access_attributes.teams is None - new_registry = CachedDiskDistributionRegistry(registry._kvstore) + new_registry = CachedDiskDistributionRegistry(registry.kvstore) await new_registry.initialize() new_model = await new_registry.get("model", "model-acl") @@ -79,6 +80,7 @@ async def test_registry_cache_with_acl(registry): assert new_model.access_attributes.teams is None +@pytest.mark.asyncio async def test_registry_empty_acl(registry): model = ModelWithACL( identifier="model-empty-acl", @@ -118,24 +120,30 @@ async def test_registry_empty_acl(registry): assert len(all_models) == 2 +@pytest.mark.asyncio async def test_registry_serialization(registry): + attributes = AccessAttributes( + roles=["admin", "researcher"], + teams=["ai-team", "ml-team"], + projects=["project-a", "project-b"], + namespaces=["prod", "staging"], + ) + model = ModelWithACL( identifier="model-serialize", provider_id="test-provider", provider_resource_id="model-resource", model_type=ModelType.llm, - access_attributes=AccessAttributes( - roles=["admin", "researcher"], - teams=["ai-team", "ml-team"], - projects=["project-a", "project-b"], - namespaces=["prod", "staging"], - ), + access_attributes=attributes, ) await registry.register(model) - registry.cache.clear() - loaded_model = await registry.get("model", "model-serialize") + new_registry = CachedDiskDistributionRegistry(registry.kvstore) + await new_registry.initialize() + + loaded_model = await new_registry.get("model", "model-serialize") + assert loaded_model is not None assert loaded_model.access_attributes.roles == ["admin", "researcher"] assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"] diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index 45f30a2c8..ab0feb1a9 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -10,8 +10,8 @@ import tempfile from unittest.mock import MagicMock, Mock, patch import pytest -import pytest_asyncio +from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable @@ -29,18 +29,19 @@ def _return_model(model): return model -@pytest_asyncio.fixture +@pytest.fixture async def test_setup(): temp_dir = tempfile.mkdtemp() db_path = os.path.join(temp_dir, "test_access_control.db") kvstore_config = SqliteKVStoreConfig(db_path=db_path) kvstore = SqliteKVStoreImpl(kvstore_config) + await kvstore.initialize() registry = CachedDiskDistributionRegistry(kvstore) await registry.initialize() mock_inference = Mock() mock_inference.__provider_spec__ = MagicMock() - mock_inference.__provider_spec__.api = "inference" + mock_inference.__provider_spec__.api = Api.inference mock_inference.register_model = AsyncMock(side_effect=_return_model) routing_table = ModelsRoutingTable( impls_by_provider_id={"test_provider": mock_inference}, @@ -80,14 +81,14 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]} all_models = await routing_table.list_models() - assert len(all_models.data) == 3 + assert len(all_models.data) == 2 model = await routing_table.get_model("model-public") assert model.identifier == "model-public" model = await routing_table.get_model("model-admin") assert model.identifier == "model-admin" - model = await routing_table.get_model("model-data-scientist") - assert model.identifier == "model-data-scientist" + with pytest.raises(ValueError): + await routing_table.get_model("model-data-scientist") mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]} all_models = await routing_table.list_models() @@ -157,7 +158,9 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se access_attributes=AccessAttributes(), ) await registry.register(model) - mock_get_auth_attributes.return_value = {} + mock_get_auth_attributes.return_value = { + "roles": [], + } result = await routing_table.get_model("model-empty-attrs") assert result.identifier == "model-empty-attrs" all_models = await routing_table.list_models() @@ -231,6 +234,7 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup) mock_get_auth_attributes.return_value = { "roles": ["data-scientist", "engineer"], "teams": ["ml-team", "platform-team"], + "projects": ["llama-3"], } model = await routing_table.get_model("auto-access-model") assert model.identifier == "auto-access-model" diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index fc4a8ecae..b078448a2 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -75,15 +75,11 @@ def mock_middleware(mock_auth_endpoint): async def mock_post_success(*args, **kwargs): - mock_response = AsyncMock() - mock_response.status_code = 200 - return mock_response + return MockResponse(200, {"message": "Authentication successful"}) async def mock_post_failure(*args, **kwargs): - mock_response = AsyncMock() - mock_response.status_code = 401 - return mock_response + return MockResponse(401, {"message": "Authentication failed"}) async def mock_post_exception(*args, **kwargs): @@ -125,8 +121,7 @@ def test_auth_service_error(client, valid_api_key): def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint): with patch("httpx.AsyncClient.post") as mock_post: - mock_response = AsyncMock() - mock_response.status_code = 200 + mock_response = MockResponse(200, {"message": "Authentication successful"}) mock_post.return_value = mock_response client.get( @@ -148,7 +143,7 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint): payload = kwargs["json"] assert payload["api_key"] == valid_api_key assert payload["request"]["path"] == "/test" - assert "authorization" in payload["request"]["headers"] + assert "authorization" not in payload["request"]["headers"] assert "param1" in payload["request"]["params"] assert "param2" in payload["request"]["params"]