# Mirror of https://github.com/meta-llama/llama-stack.git
# (synced 2025-12-22 22:39:41 +00:00; 404 lines, no EOL, 16 KiB, Python)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
|
|
|
|
"""
Integration tests for OAuth2 scope-based authentication.

These tests verify the end-to-end flow of OAuth2 scope validation
from JWT token parsing to API access decisions.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api
|
|
from llama_stack.distribution.server.auth import extract_api_from_path
|
|
from llama_stack.distribution.datatypes import User
|
|
|
|
|
|
def get_user_scopes(user: User) -> set[str]:
    """Safely extract the OAuth2 scopes carried in a user's attributes.

    Args:
        user: Object exposing an ``attributes`` mapping (or ``None``). Scopes
            are read from its ``"scopes"`` entry.

    Returns:
        The user's scopes as a set. An empty set is returned when the user has
        no attributes, no ``"scopes"`` entry, or an empty/``None`` value.
    """
    attributes = user.attributes or {}
    scopes = attributes.get("scopes")
    if not scopes:
        # Covers a missing key, None, "" and [] without raising.
        return set()
    if isinstance(scopes, str):
        # A bare string would otherwise be exploded into single characters by
        # set(); treat it as a space-delimited scope list (RFC 6749 style).
        return set(scopes.split())
    return set(scopes)
|
|
|
|
|
|
class TestOAuth2IntegrationFlow:
    """Test the complete OAuth2 scope validation flow.

    Each test resolves a request path to an API name with
    ``extract_api_from_path``, asks ``get_required_scopes_for_api`` which
    scopes are accepted for that API and HTTP method, and treats "the user's
    scopes intersect the accepted scopes" as the access decision.
    User scopes are always read through ``get_user_scopes`` so that users
    without attributes or without a ``"scopes"`` entry are handled uniformly.
    """

    def test_inference_api_access_flow(self):
        """Test complete flow for inference API access."""
        # 1. Extract API from path
        api_name, method = extract_api_from_path("/v1/inference/chat-completion")
        assert api_name == "inference"

        # 2. Get required scopes for this API
        required_scopes = get_required_scopes_for_api(api_name, method)
        assert "llama:inference" in required_scopes
        assert "llama:admin" in required_scopes

        # 3. Test user with correct scope
        user_with_inference = User(
            principal="user1",
            attributes={"scopes": ["llama:inference"]},
        )
        user_scopes = get_user_scopes(user_with_inference)
        assert user_scopes.intersection(required_scopes)  # Should have access

        # 4. Test user without correct scope
        user_without_inference = User(
            principal="user2",
            attributes={"scopes": ["llama:models:read"]},
        )
        user_scopes = get_user_scopes(user_without_inference)
        assert not user_scopes.intersection(required_scopes)  # Should NOT have access

    def test_models_api_read_write_flow(self):
        """Test complete flow for models API with read/write distinction."""
        # Only the API name is needed; the HTTP method is supplied explicitly
        # below to compare read (GET) against write (POST) requirements.
        api_name, _ = extract_api_from_path("/v1/models")
        read_required = get_required_scopes_for_api(api_name, "GET")
        write_required = get_required_scopes_for_api(api_name, "POST")

        # User with only read access
        read_user = User(
            principal="read_user",
            attributes={"scopes": ["llama:models:read"]},
        )
        read_scopes = get_user_scopes(read_user)

        # Should have read access
        assert read_scopes.intersection(read_required)

        # Should NOT have write access
        assert not read_scopes.intersection(write_required)

        # User with write access
        write_user = User(
            principal="write_user",
            attributes={"scopes": ["llama:models:write"]},
        )
        write_scopes = get_user_scopes(write_user)

        # Should have write access
        assert write_scopes.intersection(write_required)

        # Should NOT have read access (write doesn't imply read)
        assert not write_scopes.intersection(read_required)

    def test_admin_scope_universal_access(self):
        """Test that admin scope grants access to all APIs."""
        admin_user = User(
            principal="admin",
            attributes={"scopes": ["llama:admin"]},
        )
        admin_scopes = get_user_scopes(admin_user)

        # Test various API endpoints
        test_cases = [
            ("/v1/inference/chat-completion", "POST"),
            ("/v1/models", "GET"),
            ("/v1/models/my-model", "DELETE"),
            ("/v1/agents/session", "POST"),
            ("/v1/tools/execute", "POST"),
            ("/v1/vector_dbs/query", "POST"),
            ("/v1/safety/shield", "POST"),
            ("/v1/eval/benchmark", "POST"),
        ]

        for path, method in test_cases:
            api_name, _ = extract_api_from_path(path)
            required_scopes = get_required_scopes_for_api(api_name, method)

            # Admin should always have access
            assert admin_scopes.intersection(required_scopes), f"Admin denied access to {path}"

    def test_multiple_scopes_user(self):
        """Test user with multiple scopes."""
        multi_scope_user = User(
            principal="power_user",
            attributes={
                "scopes": [
                    "llama:inference",
                    "llama:models:read",
                    "llama:agents:write",
                    "llama:tools",
                ]
            },
        )
        user_scopes = get_user_scopes(multi_scope_user)

        # Test various access scenarios: (path, method, expected access)
        access_tests = [
            ("/v1/inference/chat-completion", "POST", True),  # Has inference
            ("/v1/models", "GET", True),  # Has models:read
            ("/v1/models", "POST", False),  # Doesn't have models:write
            ("/v1/agents/session", "POST", True),  # Has agents:write
            ("/v1/agents", "GET", False),  # Doesn't have agents:read
            ("/v1/tools/execute", "POST", True),  # Has tools
            ("/v1/vector_dbs/query", "GET", False),  # Doesn't have vector_dbs:read
            ("/v1/safety/shield", "POST", False),  # Doesn't have safety
        ]

        for path, method, should_have_access in access_tests:
            api_name, _ = extract_api_from_path(path)
            required_scopes = get_required_scopes_for_api(api_name, method)
            has_access = bool(user_scopes.intersection(required_scopes))

            assert has_access == should_have_access, (
                f"Access mismatch for {path} {method}: "
                f"expected {should_have_access}, got {has_access}"
            )

    def test_openai_compatibility_scope_flow(self):
        """Test OAuth2 scope validation for OpenAI compatibility endpoints."""
        # OpenAI endpoints should map to inference API
        openai_paths = [
            "/v1/openai/v1/chat/completions",
            "/v1/openai/v1/completions",
            "/v1/openai/v1/embeddings",
        ]

        inference_user = User(
            principal="openai_user",
            attributes={"scopes": ["llama:inference"]},
        )
        user_scopes = get_user_scopes(inference_user)

        for path in openai_paths:
            api_name, method = extract_api_from_path(path)
            assert api_name == "inference"

            required_scopes = get_required_scopes_for_api(api_name, method)
            assert user_scopes.intersection(required_scopes), f"No access to {path}"

    def test_scope_validation_error_scenarios(self):
        """Test error scenarios in scope validation."""
        api_name, method = extract_api_from_path("/v1/inference/chat-completion")
        required_scopes = get_required_scopes_for_api(api_name, method)

        # User with no scopes: derive the scopes from the user itself rather
        # than hard-coding an empty set, so the "no scopes entry" path is
        # actually exercised.
        no_scope_user = User(principal="no_scope", attributes={})
        user_scopes = get_user_scopes(no_scope_user)
        assert not user_scopes.intersection(required_scopes)

        # User with empty scopes list
        empty_scope_user = User(
            principal="empty_scope",
            attributes={"scopes": []},
        )
        user_scopes = get_user_scopes(empty_scope_user)
        assert not user_scopes.intersection(required_scopes)

        # User with invalid scopes
        invalid_scope_user = User(
            principal="invalid_scope",
            attributes={"scopes": ["invalid:scope", "another:invalid"]},
        )
        user_scopes = get_user_scopes(invalid_scope_user)
        assert not user_scopes.intersection(required_scopes)
|
|
|
|
|
|
class TestScopeBasedAccessMatrix:
    """Access-matrix checks for a few representative user personas.

    Every persona is a ``User`` carrying a fixed list of scopes. Each test
    asks ``get_required_scopes_for_api`` which scopes an (API, method) pair
    accepts and checks whether the persona's scopes overlap with them.
    """

    def test_data_scientist_access_pattern(self):
        """Data scientist: inference, model reads, eval and safety only."""
        persona = User(
            principal="data_scientist",
            attributes={
                "scopes": [
                    "llama:inference",
                    "llama:models:read",
                    "llama:eval",
                    "llama:safety",
                ]
            },
        )
        granted = set(persona.attributes["scopes"])

        # (API, method) pairs that must be reachable.
        expected_allowed = (
            ("inference", "POST"),  # Run inference
            ("models", "GET"),  # List/inspect models
            ("eval", "POST"),  # Run evaluations
            ("safety", "POST"),  # Use safety shields
        )
        for api, method in expected_allowed:
            accepted = get_required_scopes_for_api(api, method)
            assert not granted.isdisjoint(accepted), f"Data scientist denied {api} {method}"

        # (API, method) pairs that must stay out of reach.
        expected_denied = (
            ("models", "POST"),  # Cannot register models
            ("agents", "POST"),  # Cannot create agents
            ("tools", "POST"),  # Cannot use tools
            ("vector_dbs", "GET"),  # Cannot access vector DBs
        )
        for api, method in expected_denied:
            accepted = get_required_scopes_for_api(api, method)
            assert granted.isdisjoint(accepted), f"Data scientist allowed {api} {method}"

    def test_ml_engineer_access_pattern(self):
        """ML engineer: inference, models, agents, tools and eval."""
        persona = User(
            principal="ml_engineer",
            attributes={
                "scopes": [
                    "llama:inference",
                    "llama:models:read",
                    "llama:models:write",
                    "llama:agents:read",
                    "llama:agents:write",
                    "llama:tools",
                    "llama:eval",
                ]
            },
        )
        granted = set(persona.attributes["scopes"])

        # Only positive expectations are checked for this persona.
        expected_allowed = (
            ("inference", "POST"),
            ("models", "GET"),
            ("models", "POST"),
            ("agents", "GET"),
            ("agents", "POST"),
            ("tools", "POST"),
            ("eval", "POST"),
        )
        for api, method in expected_allowed:
            accepted = get_required_scopes_for_api(api, method)
            assert not granted.isdisjoint(accepted), f"ML engineer denied {api} {method}"

    def test_application_developer_access_pattern(self):
        """App developer: inference, agents, tools and safety only."""
        persona = User(
            principal="app_developer",
            attributes={
                "scopes": [
                    "llama:inference",
                    "llama:agents:read",
                    "llama:agents:write",
                    "llama:tools",
                    "llama:safety",
                ]
            },
        )
        granted = set(persona.attributes["scopes"])

        # Application-building APIs that must be reachable.
        expected_allowed = (
            ("inference", "POST"),  # Use models for apps
            ("agents", "GET"),  # Inspect agents
            ("agents", "POST"),  # Create agent sessions
            ("tools", "POST"),  # Execute tools
            ("safety", "POST"),  # Apply safety
        )
        for api, method in expected_allowed:
            accepted = get_required_scopes_for_api(api, method)
            assert not granted.isdisjoint(accepted), f"App developer denied {api} {method}"

        # Model, eval and vector-DB management must stay out of reach.
        expected_denied = (
            ("models", "POST"),  # Cannot manage models
            ("eval", "POST"),  # Cannot run evaluations
            ("vector_dbs", "POST"),  # Cannot manage vector DBs
        )
        for api, method in expected_denied:
            accepted = get_required_scopes_for_api(api, method)
            assert granted.isdisjoint(accepted), f"App developer allowed {api} {method}"
|
|
|
|
|
|
class TestScopeHierarchyAndSeparation:
    """Test that scopes are properly separated and don't grant unintended access.

    Access is decided the same way throughout: a user may call an
    (API, method) pair when the scopes returned by ``get_user_scopes``
    intersect those returned by ``get_required_scopes_for_api``.
    """

    def test_read_write_separation(self):
        """Test that read scopes don't grant write access and vice versa."""
        # APIs whose scopes are split into ``:read`` and ``:write`` variants.
        read_write_apis = ["models", "agents", "vector_dbs"]

        for api in read_write_apis:
            # User with only read scope
            read_user = User(
                principal=f"{api}_reader",
                attributes={"scopes": [f"llama:{api}:read"]},
            )
            read_scopes = get_user_scopes(read_user)

            # User with only write scope
            write_user = User(
                principal=f"{api}_writer",
                attributes={"scopes": [f"llama:{api}:write"]},
            )
            write_scopes = get_user_scopes(write_user)

            # GET stands in for a read operation, POST for a write operation.
            read_required = get_required_scopes_for_api(api, "GET")
            write_required = get_required_scopes_for_api(api, "POST")

            # Read user should only have read access
            assert read_scopes.intersection(read_required), f"Read user denied read access to {api}"
            assert not read_scopes.intersection(write_required), f"Read user granted write access to {api}"

            # Write user should only have write access
            assert write_scopes.intersection(write_required), f"Write user denied write access to {api}"
            assert not write_scopes.intersection(read_required), f"Write user granted read access to {api}"

    def test_api_isolation(self):
        """Test that API scopes don't grant access to other APIs."""
        api_scopes = [
            "llama:inference",
            "llama:models:read",
            "llama:agents:write",
            "llama:tools",
            "llama:vector_dbs:read",
            "llama:safety",
            "llama:eval",
        ]

        # Endpoints probed for every scope; loop-invariant, so built once.
        test_apis = [
            ("inference", "POST"),
            ("models", "GET"),
            ("models", "POST"),
            ("agents", "GET"),
            ("agents", "POST"),
            ("tools", "POST"),
            ("vector_dbs", "GET"),
            ("vector_dbs", "POST"),
            ("safety", "POST"),
            ("eval", "POST"),
        ]

        for scope in api_scopes:
            user = User(
                principal="single_scope_user",
                attributes={"scopes": [scope]},
            )
            user_scopes = get_user_scopes(user)

            # Count how many probed endpoints this single scope unlocks.
            access_count = sum(
                1
                for api, method in test_apis
                if user_scopes.intersection(get_required_scopes_for_api(api, method))
            )

            # Should only have access to 1-2 endpoints (the intended API)
            # Allow 2 for APIs that have both read and write variants
            assert access_count <= 2, f"Scope {scope} grants too broad access ({access_count} APIs)"
            assert access_count >= 1, f"Scope {scope} grants no access"
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code; propagate it so a failing run does not exit with status 0.
    raise SystemExit(pytest.main([__file__]))