# Mirror of https://github.com/meta-llama/llama-stack.git
# (synced 2025-12-22 22:42:25 +00:00; 428 lines, no EOL, 16 KiB, Python)
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from llama_stack.distribution.server.oauth2_scopes import (
|
|
STANDARD_OAUTH2_SCOPES,
|
|
get_required_scopes_for_api,
|
|
validate_scopes,
|
|
scope_grants_admin_access,
|
|
get_all_scope_descriptions,
|
|
)
|
|
from llama_stack.distribution.server.auth import extract_api_from_path, AuthenticationMiddleware
|
|
from llama_stack.distribution.datatypes import AuthenticationConfig, OAuth2TokenAuthConfig, AuthProviderType
|
|
|
|
|
|
class TestOAuth2Scopes:
    """Exercise the standard OAuth2 scope table and the helpers built on it."""

    def test_standard_scopes_exist(self):
        """The scope table holds exactly the expected names, each with a string description."""
        wanted = {
            "llama:admin",
            "llama:eval",
            "llama:safety",
            "llama:vector_dbs:write",
            "llama:vector_dbs:read",
            "llama:tools",
            "llama:agents:write",
            "llama:agents:read",
            "llama:models:write",
            "llama:models:read",
            "llama:inference",
        }
        assert set(STANDARD_OAUTH2_SCOPES) == wanted

        # Only the description text is inspected; very short text is treated as a placeholder.
        for text in STANDARD_OAUTH2_SCOPES.values():
            assert isinstance(text, str)
            assert len(text) > 10

    def test_get_all_scope_descriptions(self):
        """get_all_scope_descriptions() returns an equal mapping that is independent of the table."""
        copied = get_all_scope_descriptions()

        assert copied == STANDARD_OAUTH2_SCOPES
        assert len(copied) == len(STANDARD_OAUTH2_SCOPES)

        # Mutating the returned mapping must leave the module-level table untouched.
        copied["test"] = "should not affect original"
        assert "test" not in STANDARD_OAUTH2_SCOPES

    def test_validate_scopes_valid(self):
        """validate_scopes() returns the recognised scopes and leaves out unrecognised ones."""
        cases = [
            # A single recognised scope.
            ({"llama:inference"}, {"llama:inference"}),
            # Several recognised scopes.
            (
                {"llama:inference", "llama:models:read", "llama:admin"},
                {"llama:inference", "llama:models:read", "llama:admin"},
            ),
            # Recognised scopes mixed with an unrecognised one.
            (
                {"llama:inference", "invalid:scope", "llama:admin"},
                {"llama:inference", "llama:admin"},
            ),
        ]
        for granted, accepted in cases:
            assert validate_scopes(granted) == accepted

    def test_validate_scopes_invalid(self):
        """validate_scopes() raises ValueError when no recognised scope is present."""
        # First input: only unrecognised scopes; second input: no scopes at all.
        for rejected in ({"invalid:scope", "another:invalid"}, set()):
            with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
                validate_scopes(rejected)

    def test_scope_grants_admin_access(self):
        """scope_grants_admin_access() is truthy only when the admin scope is in the set."""
        with_admin = [
            {"llama:admin"},
            {"llama:admin", "llama:inference"},
        ]
        without_admin = [
            {"llama:inference"},
            {"llama:models:read", "llama:agents:write"},
            set(),
        ]

        for granted in with_admin:
            assert scope_grants_admin_access(granted)

        for granted in without_admin:
            assert not scope_grants_admin_access(granted)
|
|
|
|
|
|
class TestScopeRequirements:
    """Check the scopes reported by get_required_scopes_for_api() per API and HTTP method."""

    def test_inference_api_scopes(self):
        """Each inference-style API lists both the inference scope and the admin scope."""
        for api_name in ("inference", "chat", "completion", "embeddings"):
            required = get_required_scopes_for_api(api_name, "POST")
            assert "llama:inference" in required
            assert "llama:admin" in required

    def test_models_api_scopes(self):
        """Model reads list the read scope but not write; model writes list write but not read."""
        on_read = get_required_scopes_for_api("models", "GET")
        assert "llama:models:read" in on_read
        assert "llama:admin" in on_read
        assert "llama:models:write" not in on_read

        for verb in ("POST", "PUT", "DELETE"):
            on_write = get_required_scopes_for_api("models", verb)
            assert "llama:models:write" in on_write
            assert "llama:admin" in on_write
            assert "llama:models:read" not in on_write

    def test_agents_api_scopes(self):
        """Agent reads list the agents read scope; mutating methods list the agents write scope."""
        on_read = get_required_scopes_for_api("agents", "GET")
        assert "llama:agents:read" in on_read
        assert "llama:admin" in on_read

        for verb in ("POST", "PUT", "DELETE"):
            on_write = get_required_scopes_for_api("agents", verb)
            assert "llama:agents:write" in on_write
            assert "llama:admin" in on_write

    def test_tools_api_scopes(self):
        """Both tool-related APIs list the tools scope and the admin scope."""
        for api_name in ("tools", "tool_runtime"):
            required = get_required_scopes_for_api(api_name, "POST")
            assert "llama:tools" in required
            assert "llama:admin" in required

    def test_vector_dbs_api_scopes(self):
        """Vector DB reads list the read scope; mutating methods list the write scope."""
        on_read = get_required_scopes_for_api("vector_dbs", "GET")
        assert "llama:vector_dbs:read" in on_read
        assert "llama:admin" in on_read

        for verb in ("POST", "PUT", "DELETE"):
            on_write = get_required_scopes_for_api("vector_dbs", verb)
            assert "llama:vector_dbs:write" in on_write
            assert "llama:admin" in on_write

    def test_safety_api_scopes(self):
        """The safety API lists the safety scope and the admin scope."""
        required = get_required_scopes_for_api("safety", "POST")
        assert "llama:safety" in required
        assert "llama:admin" in required

    def test_eval_api_scopes(self):
        """Evaluation-related APIs list the eval scope and the admin scope."""
        for api_name in ("eval", "benchmarks", "scoring"):
            required = get_required_scopes_for_api(api_name, "POST")
            assert "llama:eval" in required
            assert "llama:admin" in required

    def test_unknown_api_scopes(self):
        """An unrecognised API name yields exactly the admin scope."""
        required = get_required_scopes_for_api("unknown_api", "POST")
        assert required == {"llama:admin"}

    def test_admin_always_included(self):
        """The admin scope appears in the result for every API/method combination tried."""
        api_names = ("inference", "models", "agents", "tools", "safety", "eval", "unknown")
        verbs = ("GET", "POST", "PUT", "DELETE")

        # Same visiting order as a nested loop: all verbs for one API, then the next API.
        combos = [(api_name, verb) for api_name in api_names for verb in verbs]
        for api_name, verb in combos:
            required = get_required_scopes_for_api(api_name, verb)
            assert "llama:admin" in required
|
|
|
|
|
|
class TestAPIPathExtraction:
    """Check how extract_api_from_path() maps request paths to (api, method) pairs."""

    def test_v1_api_paths(self):
        """Paths under /v1/ resolve to the API named by the segment after the version."""
        expected_by_path = {
            "/v1/inference/chat-completion": ("inference", "POST"),
            "/v1/models": ("models", "POST"),
            "/v1/models/my-model": ("models", "POST"),
            "/v1/agents/session": ("agents", "POST"),
            "/v1/tools/execute": ("tools", "POST"),
            "/v1/vector_dbs/query": ("vector_dbs", "POST"),
            "/v1/safety/shield": ("safety", "POST"),
            "/v1/eval/benchmark": ("eval", "POST"),
        }

        for raw_path, wanted in expected_by_path.items():
            assert extract_api_from_path(raw_path) == wanted

    def test_openai_compatibility_paths(self):
        """OpenAI-compatible endpoints all resolve to the inference API with POST."""
        compat_paths = (
            "/v1/openai/v1/chat/completions",
            "/v1/openai/v1/completions",
            "/v1/openai/v1/embeddings",
        )

        for raw_path in compat_paths:
            api, method = extract_api_from_path(raw_path)
            assert api == "inference"
            assert method == "POST"

    def test_edge_case_paths(self):
        """Pin the results for empty, root-only, version-only and non-versioned paths."""
        expected_by_path = {
            "/": ("unknown", "GET"),
            "/health": ("health", "POST"),
            "/v1/": ("unknown", "GET"),
            "": ("unknown", "GET"),
            "/some/nested/path": ("some", "POST"),
        }

        for raw_path, wanted in expected_by_path.items():
            assert extract_api_from_path(raw_path) == wanted
|
|
|
|
|
|
class TestScopeBasedAuth:
    """Test scope-based authentication integration"""

    @pytest.fixture
    def mock_auth_config(self):
        """Build an AuthenticationConfig that wraps an OAuth2 token provider config."""
        return AuthenticationConfig(
            provider_config=OAuth2TokenAuthConfig(
                type=AuthProviderType.OAUTH2_TOKEN,
                issuer="https://test-issuer.com",
                audience="llama-stack",
            )
        )

    @pytest.fixture
    def mock_app(self):
        """Provide a MagicMock standing in for the application the middleware wraps."""
        return MagicMock()

    def test_scope_validation_in_middleware(self, mock_auth_config, mock_app):
        """Constructing the middleware with an OAuth2 config yields an auth provider."""
        middleware = AuthenticationMiddleware(mock_app, mock_auth_config)

        # Simplified check: a complete test would need to mock the full ASGI flow.
        assert middleware.auth_provider is not None

    @pytest.mark.asyncio
    async def test_token_validation_with_scopes(self):
        """validate_token() exposes the JWT's space-separated scope claim as user scopes."""
        # The provider class is not imported at module level, so it is imported here.
        # OAuth2TokenAuthConfig and AuthProviderType come from the module-level imports.
        from llama_stack.distribution.server.auth_providers import OAuth2TokenAuthProvider

        # Claims the patched jwt.decode will hand back to the provider.
        mock_claims = {
            "sub": "test-user",
            "scope": "llama:inference llama:models:read",
            "iss": "test-issuer",
            "aud": "llama-stack",
        }

        config = OAuth2TokenAuthConfig(
            type=AuthProviderType.OAUTH2_TOKEN,
            issuer="test-issuer",
            audience="llama-stack",
        )

        with patch("llama_stack.distribution.server.auth_providers.jwt.decode") as mock_jwt_decode:
            mock_jwt_decode.return_value = mock_claims

            provider = OAuth2TokenAuthProvider(config)

            # Stub key retrieval so the provider's own key lookup is not invoked.
            with patch.object(provider, "_get_public_key") as mock_get_key:
                mock_get_key.return_value = "mock-key"

                user = await provider.validate_token("mock-token", {})

                assert user.principal == "test-user"
                assert user.attributes is not None
                assert "scopes" in user.attributes
                assert set(user.attributes["scopes"]) == {"llama:inference", "llama:models:read"}

    def test_scope_intersection_logic(self):
        """A user is allowed when their scopes overlap the required set, denied otherwise.

        NOTE(review): this only exercises set intersection on literal sets; it does
        not call the server's access-control code.
        """
        # User holds the inference scope and the models read scope.
        user_scopes = {"llama:inference", "llama:models:read"}

        inference_required = {"llama:admin", "llama:inference"}
        models_read_required = {"llama:admin", "llama:models:read"}
        models_write_required = {"llama:admin", "llama:models:write"}

        # Overlap exists for inference and for model reads.
        assert user_scopes.intersection(inference_required)
        assert user_scopes.intersection(models_read_required)

        # No overlap for model writes.
        assert not user_scopes.intersection(models_write_required)

    def test_admin_scope_access(self):
        """The admin scope alone overlaps every requirement set that includes it.

        NOTE(review): the requirement sets are hard-coded literals here rather than
        being obtained from get_required_scopes_for_api().
        """
        admin_scopes = {"llama:admin"}

        test_requirements = [
            {"llama:admin", "llama:inference"},
            {"llama:admin", "llama:models:write"},
            {"llama:admin", "llama:agents:write"},
            {"llama:admin", "llama:tools"},
            {"llama:admin", "llama:vector_dbs:write"},
            {"llama:admin", "llama:safety"},
            {"llama:admin", "llama:eval"},
        ]

        for required_scopes in test_requirements:
            assert admin_scopes.intersection(required_scopes)
|
|
|
|
|
|
class TestScopeValidationErrors:
    """Test error cases in scope validation"""

    def test_missing_scope_claim(self):
        """An empty scope set is rejected with ValueError."""
        token_scopes = set()
        with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
            validate_scopes(token_scopes)

    def test_malformed_scope_string(self):
        """A scope string with stray leading/trailing whitespace still parses to its scopes."""
        scope_string = " llama:inference llama:models:read "

        # str.split() with no separator collapses whitespace runs and never yields
        # empty or whitespace-padded tokens, so no further cleanup is required.
        scopes = set(scope_string.split())

        result = validate_scopes(scopes)
        assert result == {"llama:inference", "llama:models:read"}

    def test_case_sensitive_scopes(self):
        """Scopes that differ from the standard names only by letter case are rejected."""
        token_scopes = {"LLAMA:INFERENCE", "llama:Models:Read"}

        with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
            validate_scopes(token_scopes)

    def test_partial_scope_matches(self):
        """Fragments or prefixes of standard scope names are rejected."""
        token_scopes = {"llama:model", "llama", "inference"}

        with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
            validate_scopes(token_scopes)
|
|
|
|
|
|
class TestScopeDocumentation:
    """Test scope documentation and descriptions"""

    def test_scope_naming_convention(self):
        """Every scope name is 'llama:'-prefixed, space-free, and colon-separated."""
        for scope in STANDARD_OAUTH2_SCOPES:
            # All scopes should start with 'llama:'
            assert scope.startswith("llama:")

            # Should not contain spaces
            assert " " not in scope

            # Colons are the only separators: each colon-delimited part must be
            # alphanumeric once underscores are removed (so no dots or slashes).
            parts = scope.split(":")
            assert len(parts) >= 2
            assert all(part.replace("_", "").isalnum() for part in parts)

    def test_scope_descriptions_quality(self):
        """Every description is non-trivial, omits the scope name, and uses an action word."""
        # Loop-invariant vocabulary, built once for all descriptions.
        descriptive_words = ("access", "read", "write", "manage", "create", "delete", "execute")

        for scope, description in STANDARD_OAUTH2_SCOPES.items():
            # Should be non-empty strings
            assert isinstance(description, str)
            assert len(description.strip()) > 5

            # Should not just be the scope name
            assert scope not in description

            # Should contain descriptive words (case-insensitive match)
            lowered = description.lower()
            assert any(word in lowered for word in descriptive_words)

    def test_read_write_scope_pairs(self):
        """APIs with split permissions define both a read and a write scope, described as such."""
        read_write_apis = ("models", "agents", "vector_dbs")

        for api in read_write_apis:
            read_scope = f"llama:{api}:read"
            write_scope = f"llama:{api}:write"

            assert read_scope in STANDARD_OAUTH2_SCOPES
            assert write_scope in STANDARD_OAUTH2_SCOPES

            # Read description should mention "read"
            assert "read" in STANDARD_OAUTH2_SCOPES[read_scope].lower()

            # Write description should mention some write-like operation
            write_desc = STANDARD_OAUTH2_SCOPES[write_scope].lower()
            assert any(word in write_desc for word in ("write", "manage", "register", "create"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly reports
    # failure to the calling shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))