diff --git a/CHANGELOG.md b/CHANGELOG.md index d3718e5bc..0a2cfedcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,43 @@ # Changelog +# v0.3.0 (Upcoming) +Published on: TBD + +## 🚨 **BREAKING CHANGES** + +### OAuth2 Scope-Based Authentication +* **BREAKING:** JWT tokens now REQUIRE OAuth2 scopes for API access +* **BREAKING:** Tokens without valid `scope` claim will be rejected (401 Unauthorized) +* **BREAKING:** Legacy attribute-based access control replaced with OAuth2 scopes + +#### New Standard OAuth2 Scopes: +- `llama:inference` - Access to inference APIs (`/v1/inference/*`, OpenAI compatibility) +- `llama:models:read` - Read access to models (`GET /v1/models/*`) +- `llama:models:write` - Write access to models (`POST/PUT/DELETE /v1/models/*`) +- `llama:agents:read` - Read access to agents (`GET /v1/agents/*`) +- `llama:agents:write` - Write access to agents (`POST/PUT/DELETE /v1/agents/*`) +- `llama:tools` - Access to tool runtime (`/v1/tools/*`) +- `llama:vector_dbs:read` - Read access to vector databases +- `llama:vector_dbs:write` - Write access to vector databases +- `llama:safety` - Access to safety shields (`/v1/safety/*`) +- `llama:eval` - Access to evaluation APIs (`/v1/eval/*`, `/v1/benchmarks/*`) +- `llama:admin` - Full administrative access to all APIs + +#### Migration Required: +1. **Update OAuth2 Provider:** Configure your OAuth2/OIDC provider to include Llama Stack scopes in JWT tokens +2. **Update Client Applications:** Request appropriate scopes when obtaining tokens +3. 
def extract_api_from_path(path: str) -> tuple[str, str]:
    """Map a request path to the API name used for OAuth2 scope validation.

    Args:
        path: The request path, e.g. ``/v1/models/my-model`` or
            ``/v1/openai/v1/chat/completions``.

    Returns:
        A ``(api_name, method)`` tuple. The method element is a placeholder:
        it is ``"POST"`` whenever an API name could be derived and ``"GET"``
        only in the ``("unknown", "GET")`` result returned for paths that
        name no API at all (``""``, ``"/"``, or a bare ``"/v1/"``).
    """
    trimmed = path.strip("/")

    # str.split() never returns an empty list ("".split("/") == [""]), so an
    # empty path has to be detected before splitting; otherwise the fallback
    # below would report an API literally named "". A bare version prefix
    # identifies no API either.
    if not trimmed or trimmed == "v1":
        return "unknown", "GET"

    parts = trimmed.split("/")

    if len(parts) >= 2 and parts[0] == "v1":
        api_name = parts[1]
        # Nested paths such as /v1/models/{id} or /v1/inference/chat-completion
        # resolve to their top-level API. "benchmarks" is listed because the
        # documented llama:eval scope covers /v1/benchmarks/*.
        if api_name in (
            "inference",
            "models",
            "agents",
            "tools",
            "vector_dbs",
            "safety",
            "eval",
            "benchmarks",
            "scoring",
        ):
            return api_name, "POST"  # Default to POST for scope checking
        # OpenAI compatibility endpoints (/v1/openai/v1/...) are treated as
        # inference; shorter /v1/openai paths fall through to the fallback.
        if api_name == "openai" and len(parts) >= 4:
            return "inference", "POST"

    # Fallback: use the first path component as the API name.
    return parts[0], "POST"
scope["authenticated_client_id"] = token @@ -117,9 +170,8 @@ class AuthenticationMiddleware: scope["principal"] = validation_result.principal if validation_result.attributes: scope["user_attributes"] = validation_result.attributes - logger.debug( - f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" - ) + attr_count = len(validation_result.attributes) if validation_result.attributes else 0 + logger.debug(f"Authentication successful: {validation_result.principal} with {attr_count} attributes") return await self.app(scope, receive, send) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 9b0e182f5..5c8095580 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -21,6 +21,10 @@ from llama_stack.distribution.datatypes import ( OAuth2TokenAuthConfig, User, ) +from llama_stack.distribution.server.oauth2_scopes import ( + scope_grants_admin_access, + validate_scopes, +) from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -128,13 +132,34 @@ class OAuth2TokenAuthProvider(AuthProvider): except Exception as exc: raise ValueError("Invalid JWT token") from exc - # There are other standard claims, the most relevant of which is `scope`. - # We should incorporate these into the access attributes. 
- principal = claims["sub"] - access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) + # Extract and validate OAuth2 scopes - deny by default if no valid scopes + token_scopes = set(claims.get("scope", "").split()) if claims.get("scope") else set() + + # Validate scopes against standard Llama Stack scopes + valid_scopes = validate_scopes(token_scopes) + + logger.info(f"User {claims['sub']} authenticated with scopes: {valid_scopes}") + + # Convert scopes to user attributes for the existing ABAC system + attributes = {"scopes": list(valid_scopes)} + + # Add admin role if user has admin scope + if scope_grants_admin_access(valid_scopes): + attributes["roles"] = ["admin"] + + # Maintain backward compatibility with existing claims mapping + legacy_attributes = get_attributes_from_claims(claims, self.config.claims_mapping) + if legacy_attributes: + for key, values in legacy_attributes.items(): + if key not in attributes: + attributes[key] = values + else: + # Merge lists, avoiding duplicates + attributes[key] = list(set(attributes[key] + values)) + return User( - principal=principal, - attributes=access_attributes, + principal=claims["sub"], + attributes=attributes, ) async def introspect_token(self, token: str, scope: dict | None = None) -> User: diff --git a/llama_stack/distribution/server/oauth2_scopes.py b/llama_stack/distribution/server/oauth2_scopes.py new file mode 100644 index 000000000..66ff39efb --- /dev/null +++ b/llama_stack/distribution/server/oauth2_scopes.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +OAuth2 scope definitions and validation for Llama Stack APIs. + +This module defines the standard OAuth2 scopes that are built-in to Llama Stack +and provides utilities for scope validation. 
# HTTP methods that never modify state; every other method is treated as a
# write so that PATCH (or any unexpected verb) cannot slip through with only a
# read scope.
_READ_ONLY_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})

# APIs guarded by a single scope regardless of HTTP method.
_SINGLE_SCOPE_APIS = {
    "inference": "llama:inference",
    "chat": "llama:inference",
    "completion": "llama:inference",
    "embeddings": "llama:inference",
    "tools": "llama:tools",
    "tool_runtime": "llama:tools",
    "safety": "llama:safety",
    "eval": "llama:eval",
    "benchmarks": "llama:eval",
    "scoring": "llama:eval",
}

# APIs that distinguish `llama:<api>:read` from `llama:<api>:write`.
_READ_WRITE_APIS = frozenset({"models", "agents", "vector_dbs"})


def get_required_scopes_for_api(api_name: str, method: str = "GET") -> Set[str]:
    """Get the OAuth2 scopes that grant access to a specific API endpoint.

    Args:
        api_name: The name of the API (e.g., 'models', 'inference', 'agents').
        method: The HTTP method. Compared case-insensitively; GET, HEAD and
            OPTIONS are reads, every other method is a write.

    Returns:
        Set of scope strings, any one of which grants access. Always includes
        'llama:admin'; for an unrecognised ``api_name`` that is the only
        member, so unknown APIs are admin-only.
    """
    # Admin scope grants access to everything
    required_scopes = {"llama:admin"}

    if api_name in _READ_WRITE_APIS:
        access = "read" if method.upper() in _READ_ONLY_METHODS else "write"
        required_scopes.add(f"llama:{api_name}:{access}")
    elif api_name in _SINGLE_SCOPE_APIS:
        required_scopes.add(_SINGLE_SCOPE_APIS[api_name])

    return required_scopes
+ + Returns: + Dictionary mapping scope names to their descriptions + """ + return STANDARD_OAUTH2_SCOPES.copy() \ No newline at end of file diff --git a/tests/unit/server/test_oauth2_integration.py b/tests/unit/server/test_oauth2_integration.py new file mode 100644 index 000000000..d77378e81 --- /dev/null +++ b/tests/unit/server/test_oauth2_integration.py @@ -0,0 +1,404 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Integration tests for OAuth2 scope-based authentication. + +These tests verify the end-to-end flow of OAuth2 scope validation +from JWT token parsing to API access decisions. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api +from llama_stack.distribution.server.auth import extract_api_from_path +from llama_stack.distribution.datatypes import User + + +def get_user_scopes(user: User) -> set[str]: + """Safely extract scopes from user attributes""" + if user.attributes and "scopes" in user.attributes: + return set(user.attributes["scopes"]) + return set() + + +class TestOAuth2IntegrationFlow: + """Test the complete OAuth2 scope validation flow""" + + def test_inference_api_access_flow(self): + """Test complete flow for inference API access""" + # 1. Extract API from path + api_name, method = extract_api_from_path("/v1/inference/chat-completion") + assert api_name == "inference" + + # 2. Get required scopes for this API + required_scopes = get_required_scopes_for_api(api_name, method) + assert "llama:inference" in required_scopes + assert "llama:admin" in required_scopes + + # 3. 
Test user with correct scope + user_with_inference = User( + principal="user1", + attributes={"scopes": ["llama:inference"]} + ) + user_scopes = set(user_with_inference.attributes["scopes"]) if user_with_inference.attributes else set() + assert user_scopes.intersection(required_scopes) # Should have access + + # 4. Test user without correct scope + user_without_inference = User( + principal="user2", + attributes={"scopes": ["llama:models:read"]} + ) + user_scopes = set(user_without_inference.attributes["scopes"]) + assert not user_scopes.intersection(required_scopes) # Should NOT have access + + def test_models_api_read_write_flow(self): + """Test complete flow for models API with read/write distinction""" + # Test read operation + api_name, method = extract_api_from_path("/v1/models") + read_required = get_required_scopes_for_api(api_name, "GET") + + # Test write operation + write_required = get_required_scopes_for_api(api_name, "POST") + + # User with only read access + read_user = User( + principal="read_user", + attributes={"scopes": ["llama:models:read"]} + ) + read_scopes = set(read_user.attributes["scopes"]) + + # Should have read access + assert read_scopes.intersection(read_required) + + # Should NOT have write access + assert not read_scopes.intersection(write_required) + + # User with write access + write_user = User( + principal="write_user", + attributes={"scopes": ["llama:models:write"]} + ) + write_scopes = set(write_user.attributes["scopes"]) + + # Should have write access + assert write_scopes.intersection(write_required) + + # Should NOT have read access (write doesn't imply read) + assert not write_scopes.intersection(read_required) + + def test_admin_scope_universal_access(self): + """Test that admin scope grants access to all APIs""" + admin_user = User( + principal="admin", + attributes={"scopes": ["llama:admin"]} + ) + admin_scopes = set(admin_user.attributes["scopes"]) + + # Test various API endpoints + test_cases = [ + 
("/v1/inference/chat-completion", "POST"), + ("/v1/models", "GET"), + ("/v1/models/my-model", "DELETE"), + ("/v1/agents/session", "POST"), + ("/v1/tools/execute", "POST"), + ("/v1/vector_dbs/query", "POST"), + ("/v1/safety/shield", "POST"), + ("/v1/eval/benchmark", "POST"), + ] + + for path, method in test_cases: + api_name, _ = extract_api_from_path(path) + required_scopes = get_required_scopes_for_api(api_name, method) + + # Admin should always have access + assert admin_scopes.intersection(required_scopes), f"Admin denied access to {path}" + + def test_multiple_scopes_user(self): + """Test user with multiple scopes""" + multi_scope_user = User( + principal="power_user", + attributes={ + "scopes": [ + "llama:inference", + "llama:models:read", + "llama:agents:write", + "llama:tools" + ] + } + ) + user_scopes = set(multi_scope_user.attributes["scopes"]) + + # Test various access scenarios + access_tests = [ + ("/v1/inference/chat-completion", "POST", True), # Has inference + ("/v1/models", "GET", True), # Has models:read + ("/v1/models", "POST", False), # Doesn't have models:write + ("/v1/agents/session", "POST", True), # Has agents:write + ("/v1/agents", "GET", False), # Doesn't have agents:read + ("/v1/tools/execute", "POST", True), # Has tools + ("/v1/vector_dbs/query", "GET", False), # Doesn't have vector_dbs:read + ("/v1/safety/shield", "POST", False), # Doesn't have safety + ] + + for path, method, should_have_access in access_tests: + api_name, _ = extract_api_from_path(path) + required_scopes = get_required_scopes_for_api(api_name, method) + has_access = bool(user_scopes.intersection(required_scopes)) + + assert has_access == should_have_access, ( + f"Access mismatch for {path} {method}: " + f"expected {should_have_access}, got {has_access}" + ) + + def test_openai_compatibility_scope_flow(self): + """Test OAuth2 scope validation for OpenAI compatibility endpoints""" + # OpenAI endpoints should map to inference API + openai_paths = [ + 
"/v1/openai/v1/chat/completions", + "/v1/openai/v1/completions", + "/v1/openai/v1/embeddings", + ] + + inference_user = User( + principal="openai_user", + attributes={"scopes": ["llama:inference"]} + ) + user_scopes = set(inference_user.attributes["scopes"]) + + for path in openai_paths: + api_name, method = extract_api_from_path(path) + assert api_name == "inference" + + required_scopes = get_required_scopes_for_api(api_name, method) + assert user_scopes.intersection(required_scopes), f"No access to {path}" + + def test_scope_validation_error_scenarios(self): + """Test error scenarios in scope validation""" + # User with no scopes + no_scope_user = User(principal="no_scope", attributes={}) + + api_name, method = extract_api_from_path("/v1/inference/chat-completion") + required_scopes = get_required_scopes_for_api(api_name, method) + + # Should not have access + user_scopes = set() + assert not user_scopes.intersection(required_scopes) + + # User with empty scopes list + empty_scope_user = User( + principal="empty_scope", + attributes={"scopes": []} + ) + user_scopes = set(empty_scope_user.attributes["scopes"]) + assert not user_scopes.intersection(required_scopes) + + # User with invalid scopes + invalid_scope_user = User( + principal="invalid_scope", + attributes={"scopes": ["invalid:scope", "another:invalid"]} + ) + user_scopes = set(invalid_scope_user.attributes["scopes"]) + assert not user_scopes.intersection(required_scopes) + + +class TestScopeBasedAccessMatrix: + """Test access matrix for different user types and API combinations""" + + def test_data_scientist_access_pattern(self): + """Test typical data scientist access pattern""" + data_scientist = User( + principal="data_scientist", + attributes={ + "scopes": [ + "llama:inference", + "llama:models:read", + "llama:eval", + "llama:safety" + ] + } + ) + user_scopes = set(data_scientist.attributes["scopes"]) + + # Should have access to + allowed_apis = [ + ("inference", "POST"), # Run inference + ("models", 
"GET"), # List/inspect models + ("eval", "POST"), # Run evaluations + ("safety", "POST"), # Use safety shields + ] + + for api, method in allowed_apis: + required = get_required_scopes_for_api(api, method) + assert user_scopes.intersection(required), f"Data scientist denied {api} {method}" + + # Should NOT have access to + denied_apis = [ + ("models", "POST"), # Cannot register models + ("agents", "POST"), # Cannot create agents + ("tools", "POST"), # Cannot use tools + ("vector_dbs", "GET"), # Cannot access vector DBs + ] + + for api, method in denied_apis: + required = get_required_scopes_for_api(api, method) + assert not user_scopes.intersection(required), f"Data scientist allowed {api} {method}" + + def test_ml_engineer_access_pattern(self): + """Test typical ML engineer access pattern""" + ml_engineer = User( + principal="ml_engineer", + attributes={ + "scopes": [ + "llama:inference", + "llama:models:read", + "llama:models:write", + "llama:agents:read", + "llama:agents:write", + "llama:tools", + "llama:eval" + ] + } + ) + user_scopes = set(ml_engineer.attributes["scopes"]) + + # Should have broad access except admin-only operations + allowed_apis = [ + ("inference", "POST"), + ("models", "GET"), + ("models", "POST"), + ("agents", "GET"), + ("agents", "POST"), + ("tools", "POST"), + ("eval", "POST"), + ] + + for api, method in allowed_apis: + required = get_required_scopes_for_api(api, method) + assert user_scopes.intersection(required), f"ML engineer denied {api} {method}" + + def test_application_developer_access_pattern(self): + """Test typical application developer access pattern""" + app_developer = User( + principal="app_developer", + attributes={ + "scopes": [ + "llama:inference", + "llama:agents:read", + "llama:agents:write", + "llama:tools", + "llama:safety" + ] + } + ) + user_scopes = set(app_developer.attributes["scopes"]) + + # Should focus on application-building APIs + allowed_apis = [ + ("inference", "POST"), # Use models for apps + ("agents", 
"GET"), # Inspect agents + ("agents", "POST"), # Create agent sessions + ("tools", "POST"), # Execute tools + ("safety", "POST"), # Apply safety + ] + + for api, method in allowed_apis: + required = get_required_scopes_for_api(api, method) + assert user_scopes.intersection(required), f"App developer denied {api} {method}" + + # Should NOT have model or eval management access + denied_apis = [ + ("models", "POST"), # Cannot manage models + ("eval", "POST"), # Cannot run evaluations + ("vector_dbs", "POST"), # Cannot manage vector DBs + ] + + for api, method in denied_apis: + required = get_required_scopes_for_api(api, method) + assert not user_scopes.intersection(required), f"App developer allowed {api} {method}" + + +class TestScopeHierarchyAndSeparation: + """Test that scopes are properly separated and don't grant unintended access""" + + def test_read_write_separation(self): + """Test that read scopes don't grant write access and vice versa""" + read_only_apis = ["models", "agents", "vector_dbs"] + + for api in read_only_apis: + # User with only read scope + read_user = User( + principal=f"{api}_reader", + attributes={"scopes": [f"llama:{api}:read"]} + ) + read_scopes = set(read_user.attributes["scopes"]) + + # User with only write scope + write_user = User( + principal=f"{api}_writer", + attributes={"scopes": [f"llama:{api}:write"]} + ) + write_scopes = set(write_user.attributes["scopes"]) + + # Read user should only have read access + read_required = get_required_scopes_for_api(api, "GET") + write_required = get_required_scopes_for_api(api, "POST") + + assert read_scopes.intersection(read_required), f"Read user denied read access to {api}" + assert not read_scopes.intersection(write_required), f"Read user granted write access to {api}" + + # Write user should only have write access + assert write_scopes.intersection(write_required), f"Write user denied write access to {api}" + assert not write_scopes.intersection(read_required), f"Write user granted read access 
to {api}" + + def test_api_isolation(self): + """Test that API scopes don't grant access to other APIs""" + api_scopes = [ + "llama:inference", + "llama:models:read", + "llama:agents:write", + "llama:tools", + "llama:vector_dbs:read", + "llama:safety", + "llama:eval" + ] + + for scope in api_scopes: + user = User( + principal=f"single_scope_user", + attributes={"scopes": [scope]} + ) + user_scopes = set(user.attributes["scopes"]) + + # Test that this scope only grants access to its intended API + test_apis = [ + ("inference", "POST"), + ("models", "GET"), + ("models", "POST"), + ("agents", "GET"), + ("agents", "POST"), + ("tools", "POST"), + ("vector_dbs", "GET"), + ("vector_dbs", "POST"), + ("safety", "POST"), + ("eval", "POST") + ] + + access_count = 0 + for api, method in test_apis: + required = get_required_scopes_for_api(api, method) + if user_scopes.intersection(required): + access_count += 1 + + # Should only have access to 1-2 endpoints (the intended API) + # Allow 2 for APIs that have both read and write variants + assert access_count <= 2, f"Scope {scope} grants too broad access ({access_count} APIs)" + assert access_count >= 1, f"Scope {scope} grants no access" + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/server/test_oauth2_scopes.py b/tests/unit/server/test_oauth2_scopes.py new file mode 100644 index 000000000..e9b35dcbe --- /dev/null +++ b/tests/unit/server/test_oauth2_scopes.py @@ -0,0 +1,428 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from llama_stack.distribution.server.oauth2_scopes import ( + STANDARD_OAUTH2_SCOPES, + get_required_scopes_for_api, + validate_scopes, + scope_grants_admin_access, + get_all_scope_descriptions, +) +from llama_stack.distribution.server.auth import extract_api_from_path, AuthenticationMiddleware +from llama_stack.distribution.datatypes import AuthenticationConfig, OAuth2TokenAuthConfig, AuthProviderType + + +class TestOAuth2Scopes: + """Test OAuth2 scope definitions and validation""" + + def test_standard_scopes_exist(self): + """Test that all expected standard scopes are defined""" + expected_scopes = { + "llama:inference", + "llama:models:read", + "llama:models:write", + "llama:agents:read", + "llama:agents:write", + "llama:tools", + "llama:vector_dbs:read", + "llama:vector_dbs:write", + "llama:safety", + "llama:eval", + "llama:admin", + } + + assert set(STANDARD_OAUTH2_SCOPES.keys()) == expected_scopes + + # Verify all scopes have descriptions + for scope, description in STANDARD_OAUTH2_SCOPES.items(): + assert isinstance(description, str) + assert len(description) > 10 # Reasonable description length + + def test_get_all_scope_descriptions(self): + """Test getting all scope descriptions""" + descriptions = get_all_scope_descriptions() + + assert descriptions == STANDARD_OAUTH2_SCOPES + assert len(descriptions) == len(STANDARD_OAUTH2_SCOPES) + + # Verify it's a copy, not the original + descriptions["test"] = "should not affect original" + assert "test" not in STANDARD_OAUTH2_SCOPES + + def test_validate_scopes_valid(self): + """Test scope validation with valid scopes""" + # Single valid scope + token_scopes = {"llama:inference"} + result = validate_scopes(token_scopes) + assert result == {"llama:inference"} + + # Multiple valid scopes + token_scopes = {"llama:inference", "llama:models:read", "llama:admin"} + result = validate_scopes(token_scopes) + assert result == {"llama:inference", 
"llama:models:read", "llama:admin"} + + # Mix of valid and invalid scopes + token_scopes = {"llama:inference", "invalid:scope", "llama:admin"} + result = validate_scopes(token_scopes) + assert result == {"llama:inference", "llama:admin"} + + def test_validate_scopes_invalid(self): + """Test scope validation with invalid scopes""" + # No valid scopes + token_scopes = {"invalid:scope", "another:invalid"} + with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"): + validate_scopes(token_scopes) + + # Empty scopes + token_scopes = set() + with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"): + validate_scopes(token_scopes) + + def test_scope_grants_admin_access(self): + """Test admin scope detection""" + # Admin scope present + assert scope_grants_admin_access({"llama:admin"}) + assert scope_grants_admin_access({"llama:admin", "llama:inference"}) + + # Admin scope not present + assert not scope_grants_admin_access({"llama:inference"}) + assert not scope_grants_admin_access({"llama:models:read", "llama:agents:write"}) + assert not scope_grants_admin_access(set()) + + +class TestScopeRequirements: + """Test API endpoint scope requirements""" + + def test_inference_api_scopes(self): + """Test inference API scope requirements""" + apis = ["inference", "chat", "completion", "embeddings"] + for api in apis: + scopes = get_required_scopes_for_api(api, "POST") + assert "llama:inference" in scopes + assert "llama:admin" in scopes + + def test_models_api_scopes(self): + """Test models API scope requirements""" + # Read operations + read_scopes = get_required_scopes_for_api("models", "GET") + assert "llama:models:read" in read_scopes + assert "llama:admin" in read_scopes + assert "llama:models:write" not in read_scopes + + # Write operations + for method in ["POST", "PUT", "DELETE"]: + write_scopes = get_required_scopes_for_api("models", method) + assert "llama:models:write" in write_scopes + assert "llama:admin" in write_scopes + assert 
"llama:models:read" not in write_scopes + + def test_agents_api_scopes(self): + """Test agents API scope requirements""" + # Read operations + read_scopes = get_required_scopes_for_api("agents", "GET") + assert "llama:agents:read" in read_scopes + assert "llama:admin" in read_scopes + + # Write operations + for method in ["POST", "PUT", "DELETE"]: + write_scopes = get_required_scopes_for_api("agents", method) + assert "llama:agents:write" in write_scopes + assert "llama:admin" in write_scopes + + def test_tools_api_scopes(self): + """Test tools API scope requirements""" + for api in ["tools", "tool_runtime"]: + scopes = get_required_scopes_for_api(api, "POST") + assert "llama:tools" in scopes + assert "llama:admin" in scopes + + def test_vector_dbs_api_scopes(self): + """Test vector databases API scope requirements""" + # Read operations + read_scopes = get_required_scopes_for_api("vector_dbs", "GET") + assert "llama:vector_dbs:read" in read_scopes + assert "llama:admin" in read_scopes + + # Write operations + for method in ["POST", "PUT", "DELETE"]: + write_scopes = get_required_scopes_for_api("vector_dbs", method) + assert "llama:vector_dbs:write" in write_scopes + assert "llama:admin" in write_scopes + + def test_safety_api_scopes(self): + """Test safety API scope requirements""" + scopes = get_required_scopes_for_api("safety", "POST") + assert "llama:safety" in scopes + assert "llama:admin" in scopes + + def test_eval_api_scopes(self): + """Test evaluation API scope requirements""" + for api in ["eval", "benchmarks", "scoring"]: + scopes = get_required_scopes_for_api(api, "POST") + assert "llama:eval" in scopes + assert "llama:admin" in scopes + + def test_unknown_api_scopes(self): + """Test unknown API only requires admin scope""" + scopes = get_required_scopes_for_api("unknown_api", "POST") + assert scopes == {"llama:admin"} + + def test_admin_always_included(self): + """Test that admin scope is always included in required scopes""" + test_apis = 
["inference", "models", "agents", "tools", "safety", "eval", "unknown"] + test_methods = ["GET", "POST", "PUT", "DELETE"] + + for api in test_apis: + for method in test_methods: + scopes = get_required_scopes_for_api(api, method) + assert "llama:admin" in scopes + + +class TestAPIPathExtraction: + """Test API path extraction for scope validation""" + + def test_v1_api_paths(self): + """Test extraction from v1 API paths""" + test_cases = [ + ("/v1/inference/chat-completion", ("inference", "POST")), + ("/v1/models", ("models", "POST")), + ("/v1/models/my-model", ("models", "POST")), + ("/v1/agents/session", ("agents", "POST")), + ("/v1/tools/execute", ("tools", "POST")), + ("/v1/vector_dbs/query", ("vector_dbs", "POST")), + ("/v1/safety/shield", ("safety", "POST")), + ("/v1/eval/benchmark", ("eval", "POST")), + ] + + for path, expected in test_cases: + result = extract_api_from_path(path) + assert result == expected + + def test_openai_compatibility_paths(self): + """Test extraction from OpenAI compatibility paths""" + openai_paths = [ + "/v1/openai/v1/chat/completions", + "/v1/openai/v1/completions", + "/v1/openai/v1/embeddings", + ] + + for path in openai_paths: + api, method = extract_api_from_path(path) + assert api == "inference" + assert method == "POST" + + def test_edge_case_paths(self): + """Test edge cases in path extraction""" + test_cases = [ + ("/", ("unknown", "GET")), + ("/health", ("health", "POST")), + ("/v1/", ("unknown", "GET")), + ("", ("unknown", "GET")), + ("/some/nested/path", ("some", "POST")), + ] + + for path, expected in test_cases: + result = extract_api_from_path(path) + assert result == expected + + +class TestScopeBasedAuth: + """Test scope-based authentication integration""" + + @pytest.fixture + def mock_auth_config(self): + """Create mock OAuth2 authentication config""" + return AuthenticationConfig( + provider_config=OAuth2TokenAuthConfig( + type=AuthProviderType.OAUTH2_TOKEN, + issuer="https://test-issuer.com", + 
@pytest.fixture
def mock_app(self):
    """Provide a MagicMock standing in for the FastAPI app."""
    return MagicMock()

def test_scope_validation_in_middleware(self, mock_auth_config, mock_app):
    """Constructing the middleware must leave it with an auth provider."""
    # Simplified check: only construction is exercised, not the full ASGI flow.
    auth_middleware = AuthenticationMiddleware(mock_app, mock_auth_config)
    assert auth_middleware.auth_provider is not None

@pytest.mark.asyncio
async def test_token_validation_with_scopes(self):
    """A space-separated "scope" claim must be exposed under user.attributes["scopes"]."""
    decoded_claims = {
        "sub": "test-user",
        "scope": "llama:inference llama:models:read",
        "iss": "test-issuer",
        "aud": "llama-stack",
    }

    # Replace JWT decoding so the claims above are returned for any token.
    with patch(
        "llama_stack.distribution.server.auth_providers.jwt.decode",
        return_value=decoded_claims,
    ):
        from llama_stack.distribution.server.auth_providers import OAuth2TokenAuthProvider
        from llama_stack.distribution.datatypes import OAuth2TokenAuthConfig, AuthProviderType

        provider = OAuth2TokenAuthProvider(
            OAuth2TokenAuthConfig(
                type=AuthProviderType.OAUTH2_TOKEN,
                issuer="test-issuer",
                audience="llama-stack",
            )
        )

        # Replace key retrieval on this provider instance with a fixed dummy key.
        with patch.object(provider, "_get_public_key", return_value="mock-key"):
            user = await provider.validate_token("mock-token", {})

            assert user.principal == "test-user"
            assert user.attributes is not None
            assert "scopes" in user.attributes
            granted = set(user.attributes["scopes"])
            assert granted == {"llama:inference", "llama:models:read"}
class TestScopeValidationErrors:
    """Test error cases in scope validation"""

    def test_missing_scope_claim(self):
        """An empty scope set must be rejected with ValueError."""
        # Stands in for a token whose scope claim is absent or empty.
        token_scopes = set()
        with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
            validate_scopes(token_scopes)

    def test_malformed_scope_string(self):
        """A scope string padded with extra whitespace must still validate cleanly."""
        scope_string = " llama:inference llama:models:read "

        # str.split() with no separator already discards leading, trailing and
        # repeated whitespace and never yields empty strings, so the parsed
        # tokens need no further stripping or filtering.
        scopes = set(scope_string.split())

        result = validate_scopes(scopes)
        assert result == {"llama:inference", "llama:models:read"}

    def test_case_sensitive_scopes(self):
        """Scopes that differ from the expected names only by case must be rejected."""
        token_scopes = {"LLAMA:INFERENCE", "llama:Models:Read"}

        with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
            validate_scopes(token_scopes)

    def test_partial_scope_matches(self):
        """Fragments of scope names must not be accepted as scopes."""
        token_scopes = {"llama:model", "llama", "inference"}

        with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
            validate_scopes(token_scopes)
if __name__ == "__main__":
    # pytest.main() returns the exit code of the test run instead of exiting.
    # Raise it as the process status so that running this file directly
    # reports failure when the tests fail.
    raise SystemExit(pytest.main([__file__]))