mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 10:42:39 +00:00
fix unit tests
This commit is contained in:
parent
b937a49436
commit
01fac67e33
6 changed files with 37 additions and 40 deletions
|
@ -23,10 +23,7 @@ class RequestProviderDataContext(ContextManager):
|
|||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
):
|
||||
# Initialize with either provider_data or create a new dict
|
||||
self.provider_data = provider_data or {}
|
||||
|
||||
# Add auth attributes under a special key if provided
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
|
||||
|
|
|
@ -274,17 +274,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||
return True
|
||||
|
||||
# Get user attributes from context
|
||||
# Get user attributes from request context
|
||||
user_attributes = get_auth_attributes()
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
# Convert AccessAttributes to dictionary for checking
|
||||
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||
|
||||
# If the model_dump is empty (all fields are None), allow access
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
|
@ -292,14 +289,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for attr_key, required_values in obj_attributes.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
# No values for this category in user attributes
|
||||
if not user_values:
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||
)
|
||||
return False
|
||||
|
||||
# None of the values in this category match (need at least one match per category)
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||
|
|
|
@ -166,18 +166,16 @@ class AuthenticationMiddleware:
|
|||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
scope["user_attributes"] = user_attributes
|
||||
else:
|
||||
logger.warning("Authentication response did not contain any attributes")
|
||||
scope["user_attributes"] = {}
|
||||
user_attributes = {}
|
||||
|
||||
# Log authentication success with attribute details
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
|
|
|
@ -18,7 +18,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|||
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="function")
|
||||
async def kvstore():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_registry_acl.db")
|
||||
|
@ -29,13 +29,14 @@ async def kvstore():
|
|||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="function")
|
||||
async def registry(kvstore):
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_cache_with_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-acl",
|
||||
|
@ -68,7 +69,7 @@ async def test_registry_cache_with_acl(registry):
|
|||
assert updated_cached.access_attributes.projects == ["project-x"]
|
||||
assert updated_cached.access_attributes.teams is None
|
||||
|
||||
new_registry = CachedDiskDistributionRegistry(registry._kvstore)
|
||||
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
|
||||
await new_registry.initialize()
|
||||
|
||||
new_model = await new_registry.get("model", "model-acl")
|
||||
|
@ -79,6 +80,7 @@ async def test_registry_cache_with_acl(registry):
|
|||
assert new_model.access_attributes.teams is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_empty_acl(registry):
|
||||
model = ModelWithACL(
|
||||
identifier="model-empty-acl",
|
||||
|
@ -118,24 +120,30 @@ async def test_registry_empty_acl(registry):
|
|||
assert len(all_models) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_serialization(registry):
|
||||
attributes = AccessAttributes(
|
||||
roles=["admin", "researcher"],
|
||||
teams=["ai-team", "ml-team"],
|
||||
projects=["project-a", "project-b"],
|
||||
namespaces=["prod", "staging"],
|
||||
)
|
||||
|
||||
model = ModelWithACL(
|
||||
identifier="model-serialize",
|
||||
provider_id="test-provider",
|
||||
provider_resource_id="model-resource",
|
||||
model_type=ModelType.llm,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=["admin", "researcher"],
|
||||
teams=["ai-team", "ml-team"],
|
||||
projects=["project-a", "project-b"],
|
||||
namespaces=["prod", "staging"],
|
||||
),
|
||||
access_attributes=attributes,
|
||||
)
|
||||
|
||||
await registry.register(model)
|
||||
registry.cache.clear()
|
||||
|
||||
loaded_model = await registry.get("model", "model-serialize")
|
||||
new_registry = CachedDiskDistributionRegistry(registry.kvstore)
|
||||
await new_registry.initialize()
|
||||
|
||||
loaded_model = await new_registry.get("model", "model-serialize")
|
||||
assert loaded_model is not None
|
||||
|
||||
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
|
||||
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
|
||||
|
|
|
@ -10,8 +10,8 @@ import tempfile
|
|||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
|
||||
from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
|
||||
|
@ -29,18 +29,19 @@ def _return_model(model):
|
|||
return model
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest.fixture
|
||||
async def test_setup():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
db_path = os.path.join(temp_dir, "test_access_control.db")
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=db_path)
|
||||
kvstore = SqliteKVStoreImpl(kvstore_config)
|
||||
await kvstore.initialize()
|
||||
registry = CachedDiskDistributionRegistry(kvstore)
|
||||
await registry.initialize()
|
||||
|
||||
mock_inference = Mock()
|
||||
mock_inference.__provider_spec__ = MagicMock()
|
||||
mock_inference.__provider_spec__.api = "inference"
|
||||
mock_inference.__provider_spec__.api = Api.inference
|
||||
mock_inference.register_model = AsyncMock(side_effect=_return_model)
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"test_provider": mock_inference},
|
||||
|
@ -80,14 +81,14 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
|
|||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
|
||||
all_models = await routing_table.list_models()
|
||||
assert len(all_models.data) == 3
|
||||
assert len(all_models.data) == 2
|
||||
|
||||
model = await routing_table.get_model("model-public")
|
||||
assert model.identifier == "model-public"
|
||||
model = await routing_table.get_model("model-admin")
|
||||
assert model.identifier == "model-admin"
|
||||
model = await routing_table.get_model("model-data-scientist")
|
||||
assert model.identifier == "model-data-scientist"
|
||||
with pytest.raises(ValueError):
|
||||
await routing_table.get_model("model-data-scientist")
|
||||
|
||||
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
|
||||
all_models = await routing_table.list_models()
|
||||
|
@ -157,7 +158,9 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
|
|||
access_attributes=AccessAttributes(),
|
||||
)
|
||||
await registry.register(model)
|
||||
mock_get_auth_attributes.return_value = {}
|
||||
mock_get_auth_attributes.return_value = {
|
||||
"roles": [],
|
||||
}
|
||||
result = await routing_table.get_model("model-empty-attrs")
|
||||
assert result.identifier == "model-empty-attrs"
|
||||
all_models = await routing_table.list_models()
|
||||
|
@ -231,6 +234,7 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
|
|||
mock_get_auth_attributes.return_value = {
|
||||
"roles": ["data-scientist", "engineer"],
|
||||
"teams": ["ml-team", "platform-team"],
|
||||
"projects": ["llama-3"],
|
||||
}
|
||||
model = await routing_table.get_model("auto-access-model")
|
||||
assert model.identifier == "auto-access-model"
|
||||
|
|
|
@ -75,15 +75,11 @@ def mock_middleware(mock_auth_endpoint):
|
|||
|
||||
|
||||
async def mock_post_success(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
return mock_response
|
||||
return MockResponse(200, {"message": "Authentication successful"})
|
||||
|
||||
|
||||
async def mock_post_failure(*args, **kwargs):
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 401
|
||||
return mock_response
|
||||
return MockResponse(401, {"message": "Authentication failed"})
|
||||
|
||||
|
||||
async def mock_post_exception(*args, **kwargs):
|
||||
|
@ -125,8 +121,7 @@ def test_auth_service_error(client, valid_api_key):
|
|||
|
||||
def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
||||
with patch("httpx.AsyncClient.post") as mock_post:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response = MockResponse(200, {"message": "Authentication successful"})
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client.get(
|
||||
|
@ -148,7 +143,7 @@ def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint):
|
|||
payload = kwargs["json"]
|
||||
assert payload["api_key"] == valid_api_key
|
||||
assert payload["request"]["path"] == "/test"
|
||||
assert "authorization" in payload["request"]["headers"]
|
||||
assert "authorization" not in payload["request"]["headers"]
|
||||
assert "param1" in payload["request"]["params"]
|
||||
assert "param2" in payload["request"]["params"]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue