mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
feat: adding scope-based authorization
This commit is contained in:
parent
51b179e1c5
commit
fe093918c2
6 changed files with 1082 additions and 9 deletions
38
CHANGELOG.md
38
CHANGELOG.md
|
|
@ -1,5 +1,43 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
# v0.3.0 (Upcoming)
|
||||||
|
Published on: TBD
|
||||||
|
|
||||||
|
## 🚨 **BREAKING CHANGES**
|
||||||
|
|
||||||
|
### OAuth2 Scope-Based Authentication
|
||||||
|
* **BREAKING:** JWT tokens now REQUIRE OAuth2 scopes for API access
|
||||||
|
* **BREAKING:** Tokens without valid `scope` claim will be rejected (401 Unauthorized)
|
||||||
|
* **BREAKING:** Legacy attribute-based access control replaced with OAuth2 scopes
|
||||||
|
|
||||||
|
#### New Standard OAuth2 Scopes:
|
||||||
|
- `llama:inference` - Access to inference APIs (`/v1/inference/*`, OpenAI compatibility)
|
||||||
|
- `llama:models:read` - Read access to models (`GET /v1/models/*`)
|
||||||
|
- `llama:models:write` - Write access to models (`POST/PUT/DELETE /v1/models/*`)
|
||||||
|
- `llama:agents:read` - Read access to agents (`GET /v1/agents/*`)
|
||||||
|
- `llama:agents:write` - Write access to agents (`POST/PUT/DELETE /v1/agents/*`)
|
||||||
|
- `llama:tools` - Access to tool runtime (`/v1/tools/*`)
|
||||||
|
- `llama:vector_dbs:read` - Read access to vector databases
|
||||||
|
- `llama:vector_dbs:write` - Write access to vector databases
|
||||||
|
- `llama:safety` - Access to safety shields (`/v1/safety/*`)
|
||||||
|
- `llama:eval` - Access to evaluation APIs (`/v1/eval/*`, `/v1/benchmarks/*`)
|
||||||
|
- `llama:admin` - Full administrative access to all APIs
|
||||||
|
|
||||||
|
#### Migration Required:
|
||||||
|
1. **Update OAuth2 Provider:** Configure your OAuth2/OIDC provider to include Llama Stack scopes in JWT tokens
|
||||||
|
2. **Update Client Applications:** Request appropriate scopes when obtaining tokens
|
||||||
|
3. **Test Token Format:** Ensure JWT tokens include `"scope": "llama:inference llama:models:read"` claim
|
||||||
|
|
||||||
|
#### Security Benefits:
|
||||||
|
- **Principle of Least Privilege:** Granular access control per API
|
||||||
|
- **Deny by Default:** No access without explicit scope grants
|
||||||
|
- **OAuth2.0 Compliance:** Follows industry standard specifications
|
||||||
|
- **Enhanced Audit Trail:** Clear permission tracking
|
||||||
|
|
||||||
|
See [OAuth2 Scope Migration Guide](docs/source/concepts/oauth2_scopes.md) for detailed migration instructions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
# v0.2.12
|
# v0.2.12
|
||||||
Published on: 2025-06-20T22:52:12Z
|
Published on: 2025-06-20T22:52:12Z
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,36 @@ import httpx
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||||
|
from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_api_from_path(path: str) -> tuple[str, str]:
|
||||||
|
"""Extract API name and method from request path for scope validation"""
|
||||||
|
# Remove leading/trailing slashes and split
|
||||||
|
path = path.strip("/")
|
||||||
|
parts = path.split("/")
|
||||||
|
|
||||||
|
# Handle common API path patterns
|
||||||
|
if len(parts) >= 2 and parts[0] == "v1":
|
||||||
|
api_name = parts[1]
|
||||||
|
# Handle nested paths like /v1/models/{id} or /v1/inference/chat-completion
|
||||||
|
if api_name in ["inference", "models", "agents", "tools", "vector_dbs", "safety", "eval", "scoring"]:
|
||||||
|
return api_name, "POST" # Default to POST for scope checking
|
||||||
|
elif api_name == "openai":
|
||||||
|
# Handle OpenAI compatibility endpoints like /v1/openai/v1/chat/completions
|
||||||
|
if len(parts) >= 4:
|
||||||
|
return "inference", "POST" # OpenAI endpoints are typically inference
|
||||||
|
|
||||||
|
# Fallback - try to extract from first path component
|
||||||
|
if parts:
|
||||||
|
return parts[0], "POST"
|
||||||
|
|
||||||
|
return "unknown", "GET"
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware:
|
class AuthenticationMiddleware:
|
||||||
"""Middleware that authenticates requests using configured authentication provider.
|
"""Middleware that authenticates requests using configured authentication provider.
|
||||||
|
|
||||||
|
|
@ -109,6 +134,34 @@ class AuthenticationMiddleware:
|
||||||
logger.exception("Error during authentication")
|
logger.exception("Error during authentication")
|
||||||
return await self._send_auth_error(send, "Authentication service error")
|
return await self._send_auth_error(send, "Authentication service error")
|
||||||
|
|
||||||
|
# Validate OAuth2 scopes for the requested API endpoint
|
||||||
|
path = scope.get("path", "")
|
||||||
|
method = scope.get("method", "GET")
|
||||||
|
api_name, _ = extract_api_from_path(path)
|
||||||
|
|
||||||
|
# Get required scopes for this API endpoint
|
||||||
|
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||||
|
|
||||||
|
# Check if user has any of the required scopes
|
||||||
|
user_scopes = set()
|
||||||
|
if validation_result.attributes and "scopes" in validation_result.attributes:
|
||||||
|
user_scopes = set(validation_result.attributes["scopes"])
|
||||||
|
|
||||||
|
# Verify user has at least one required scope
|
||||||
|
if not user_scopes.intersection(required_scopes):
|
||||||
|
logger.warning(
|
||||||
|
f"Access denied for {validation_result.principal} to {api_name} API. "
|
||||||
|
f"Required scopes: {required_scopes}, User scopes: {user_scopes}"
|
||||||
|
)
|
||||||
|
return await self._send_auth_error(
|
||||||
|
send, f"Insufficient OAuth2 scopes for {api_name} API. Required: {', '.join(required_scopes)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"OAuth2 scope validation passed for {validation_result.principal} "
|
||||||
|
f"on {api_name} API with scopes: {user_scopes.intersection(required_scopes)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||||
# can identify the requester and enforce per-client rate limits.
|
# can identify the requester and enforce per-client rate limits.
|
||||||
scope["authenticated_client_id"] = token
|
scope["authenticated_client_id"] = token
|
||||||
|
|
@ -117,9 +170,8 @@ class AuthenticationMiddleware:
|
||||||
scope["principal"] = validation_result.principal
|
scope["principal"] = validation_result.principal
|
||||||
if validation_result.attributes:
|
if validation_result.attributes:
|
||||||
scope["user_attributes"] = validation_result.attributes
|
scope["user_attributes"] = validation_result.attributes
|
||||||
logger.debug(
|
attr_count = len(validation_result.attributes) if validation_result.attributes else 0
|
||||||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
logger.debug(f"Authentication successful: {validation_result.principal} with {attr_count} attributes")
|
||||||
)
|
|
||||||
|
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,10 @@ from llama_stack.distribution.datatypes import (
|
||||||
OAuth2TokenAuthConfig,
|
OAuth2TokenAuthConfig,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
from llama_stack.distribution.server.oauth2_scopes import (
|
||||||
|
scope_grants_admin_access,
|
||||||
|
validate_scopes,
|
||||||
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
@ -128,13 +132,34 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise ValueError("Invalid JWT token") from exc
|
raise ValueError("Invalid JWT token") from exc
|
||||||
|
|
||||||
# There are other standard claims, the most relevant of which is `scope`.
|
# Extract and validate OAuth2 scopes - deny by default if no valid scopes
|
||||||
# We should incorporate these into the access attributes.
|
token_scopes = set(claims.get("scope", "").split()) if claims.get("scope") else set()
|
||||||
principal = claims["sub"]
|
|
||||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
# Validate scopes against standard Llama Stack scopes
|
||||||
|
valid_scopes = validate_scopes(token_scopes)
|
||||||
|
|
||||||
|
logger.info(f"User {claims['sub']} authenticated with scopes: {valid_scopes}")
|
||||||
|
|
||||||
|
# Convert scopes to user attributes for the existing ABAC system
|
||||||
|
attributes = {"scopes": list(valid_scopes)}
|
||||||
|
|
||||||
|
# Add admin role if user has admin scope
|
||||||
|
if scope_grants_admin_access(valid_scopes):
|
||||||
|
attributes["roles"] = ["admin"]
|
||||||
|
|
||||||
|
# Maintain backward compatibility with existing claims mapping
|
||||||
|
legacy_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||||
|
if legacy_attributes:
|
||||||
|
for key, values in legacy_attributes.items():
|
||||||
|
if key not in attributes:
|
||||||
|
attributes[key] = values
|
||||||
|
else:
|
||||||
|
# Merge lists, avoiding duplicates
|
||||||
|
attributes[key] = list(set(attributes[key] + values))
|
||||||
|
|
||||||
return User(
|
return User(
|
||||||
principal=principal,
|
principal=claims["sub"],
|
||||||
attributes=access_attributes,
|
attributes=attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
||||||
|
|
|
||||||
126
llama_stack/distribution/server/oauth2_scopes.py
Normal file
126
llama_stack/distribution/server/oauth2_scopes.py
Normal file
|
|
@ -0,0 +1,126 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
OAuth2 scope definitions and validation for Llama Stack APIs.
|
||||||
|
|
||||||
|
This module defines the standard OAuth2 scopes that are built-in to Llama Stack
|
||||||
|
and provides utilities for scope validation. These scopes are not configurable
|
||||||
|
and provide a standardized way to control access to different API endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
|
# Standard OAuth2 scopes for Llama Stack APIs (built-in, not configurable)
|
||||||
|
STANDARD_OAUTH2_SCOPES = {
|
||||||
|
# Inference API
|
||||||
|
"llama:inference": "Access to inference APIs (chat completion, embeddings)",
|
||||||
|
|
||||||
|
# Models API
|
||||||
|
"llama:models:read": "Read access to models (list, get model details)",
|
||||||
|
"llama:models:write": "Write access to models (register, unregister)",
|
||||||
|
|
||||||
|
# Agents API
|
||||||
|
"llama:agents:read": "Read access to agents (list sessions, get agent details)",
|
||||||
|
"llama:agents:write": "Write access to agents (create sessions, send messages)",
|
||||||
|
|
||||||
|
# Tools API
|
||||||
|
"llama:tools": "Access to tool runtime and execution",
|
||||||
|
|
||||||
|
# Vector DB API
|
||||||
|
"llama:vector_dbs:read": "Read access to vector databases",
|
||||||
|
"llama:vector_dbs:write": "Write access to vector databases",
|
||||||
|
|
||||||
|
# Safety API
|
||||||
|
"llama:safety": "Access to safety shields and content filtering",
|
||||||
|
|
||||||
|
# Eval API
|
||||||
|
"llama:eval": "Access to evaluation and benchmarking",
|
||||||
|
|
||||||
|
# Administrative access
|
||||||
|
"llama:admin": "Full administrative access to all APIs",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_required_scopes_for_api(api_name: str, method: str = "GET") -> Set[str]:
|
||||||
|
"""Get required OAuth2 scopes for accessing a specific API endpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_name: The name of the API (e.g., 'models', 'inference', 'agents')
|
||||||
|
method: The HTTP method (GET, POST, PUT, DELETE)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of scope strings that would grant access to this endpoint.
|
||||||
|
Always includes 'llama:admin' as it grants access to everything.
|
||||||
|
"""
|
||||||
|
# Admin scope grants access to everything
|
||||||
|
required_scopes = {"llama:admin"}
|
||||||
|
|
||||||
|
# Map API names to required scopes
|
||||||
|
if api_name in ["inference", "chat", "completion", "embeddings"]:
|
||||||
|
required_scopes.add("llama:inference")
|
||||||
|
elif api_name == "models":
|
||||||
|
if method in ["POST", "PUT", "DELETE"]:
|
||||||
|
required_scopes.add("llama:models:write")
|
||||||
|
else:
|
||||||
|
required_scopes.add("llama:models:read")
|
||||||
|
elif api_name == "agents":
|
||||||
|
if method in ["POST", "PUT", "DELETE"]:
|
||||||
|
required_scopes.add("llama:agents:write")
|
||||||
|
else:
|
||||||
|
required_scopes.add("llama:agents:read")
|
||||||
|
elif api_name in ["tools", "tool_runtime"]:
|
||||||
|
required_scopes.add("llama:tools")
|
||||||
|
elif api_name == "vector_dbs":
|
||||||
|
if method in ["POST", "PUT", "DELETE"]:
|
||||||
|
required_scopes.add("llama:vector_dbs:write")
|
||||||
|
else:
|
||||||
|
required_scopes.add("llama:vector_dbs:read")
|
||||||
|
elif api_name == "safety":
|
||||||
|
required_scopes.add("llama:safety")
|
||||||
|
elif api_name in ["eval", "benchmarks", "scoring"]:
|
||||||
|
required_scopes.add("llama:eval")
|
||||||
|
|
||||||
|
return required_scopes
|
||||||
|
|
||||||
|
|
||||||
|
def validate_scopes(token_scopes: Set[str]) -> Set[str]:
|
||||||
|
"""Validate OAuth2 scopes against standard Llama Stack scopes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_scopes: Set of scopes from the OAuth2 token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of valid scopes that are recognized by Llama Stack
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no valid scopes are found
|
||||||
|
"""
|
||||||
|
valid_scopes = token_scopes.intersection(STANDARD_OAUTH2_SCOPES.keys())
|
||||||
|
if not valid_scopes:
|
||||||
|
raise ValueError("Token lacks required OAuth2 scopes for Llama Stack access")
|
||||||
|
return valid_scopes
|
||||||
|
|
||||||
|
|
||||||
|
def scope_grants_admin_access(scopes: Set[str]) -> bool:
|
||||||
|
"""Check if the provided scopes include administrative access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scopes: Set of OAuth2 scopes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if scopes include administrative access, False otherwise
|
||||||
|
"""
|
||||||
|
return "llama:admin" in scopes
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_scope_descriptions() -> dict[str, str]:
|
||||||
|
"""Get all standard OAuth2 scopes with their descriptions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping scope names to their descriptions
|
||||||
|
"""
|
||||||
|
return STANDARD_OAUTH2_SCOPES.copy()
|
||||||
404
tests/unit/server/test_oauth2_integration.py
Normal file
404
tests/unit/server/test_oauth2_integration.py
Normal file
|
|
@ -0,0 +1,404 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Integration tests for OAuth2 scope-based authentication.
|
||||||
|
|
||||||
|
These tests verify the end-to-end flow of OAuth2 scope validation
|
||||||
|
from JWT token parsing to API access decisions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api
|
||||||
|
from llama_stack.distribution.server.auth import extract_api_from_path
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_scopes(user: User) -> set[str]:
|
||||||
|
"""Safely extract scopes from user attributes"""
|
||||||
|
if user.attributes and "scopes" in user.attributes:
|
||||||
|
return set(user.attributes["scopes"])
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuth2IntegrationFlow:
|
||||||
|
"""Test the complete OAuth2 scope validation flow"""
|
||||||
|
|
||||||
|
def test_inference_api_access_flow(self):
|
||||||
|
"""Test complete flow for inference API access"""
|
||||||
|
# 1. Extract API from path
|
||||||
|
api_name, method = extract_api_from_path("/v1/inference/chat-completion")
|
||||||
|
assert api_name == "inference"
|
||||||
|
|
||||||
|
# 2. Get required scopes for this API
|
||||||
|
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||||
|
assert "llama:inference" in required_scopes
|
||||||
|
assert "llama:admin" in required_scopes
|
||||||
|
|
||||||
|
# 3. Test user with correct scope
|
||||||
|
user_with_inference = User(
|
||||||
|
principal="user1",
|
||||||
|
attributes={"scopes": ["llama:inference"]}
|
||||||
|
)
|
||||||
|
user_scopes = set(user_with_inference.attributes["scopes"]) if user_with_inference.attributes else set()
|
||||||
|
assert user_scopes.intersection(required_scopes) # Should have access
|
||||||
|
|
||||||
|
# 4. Test user without correct scope
|
||||||
|
user_without_inference = User(
|
||||||
|
principal="user2",
|
||||||
|
attributes={"scopes": ["llama:models:read"]}
|
||||||
|
)
|
||||||
|
user_scopes = set(user_without_inference.attributes["scopes"])
|
||||||
|
assert not user_scopes.intersection(required_scopes) # Should NOT have access
|
||||||
|
|
||||||
|
def test_models_api_read_write_flow(self):
|
||||||
|
"""Test complete flow for models API with read/write distinction"""
|
||||||
|
# Test read operation
|
||||||
|
api_name, method = extract_api_from_path("/v1/models")
|
||||||
|
read_required = get_required_scopes_for_api(api_name, "GET")
|
||||||
|
|
||||||
|
# Test write operation
|
||||||
|
write_required = get_required_scopes_for_api(api_name, "POST")
|
||||||
|
|
||||||
|
# User with only read access
|
||||||
|
read_user = User(
|
||||||
|
principal="read_user",
|
||||||
|
attributes={"scopes": ["llama:models:read"]}
|
||||||
|
)
|
||||||
|
read_scopes = set(read_user.attributes["scopes"])
|
||||||
|
|
||||||
|
# Should have read access
|
||||||
|
assert read_scopes.intersection(read_required)
|
||||||
|
|
||||||
|
# Should NOT have write access
|
||||||
|
assert not read_scopes.intersection(write_required)
|
||||||
|
|
||||||
|
# User with write access
|
||||||
|
write_user = User(
|
||||||
|
principal="write_user",
|
||||||
|
attributes={"scopes": ["llama:models:write"]}
|
||||||
|
)
|
||||||
|
write_scopes = set(write_user.attributes["scopes"])
|
||||||
|
|
||||||
|
# Should have write access
|
||||||
|
assert write_scopes.intersection(write_required)
|
||||||
|
|
||||||
|
# Should NOT have read access (write doesn't imply read)
|
||||||
|
assert not write_scopes.intersection(read_required)
|
||||||
|
|
||||||
|
def test_admin_scope_universal_access(self):
|
||||||
|
"""Test that admin scope grants access to all APIs"""
|
||||||
|
admin_user = User(
|
||||||
|
principal="admin",
|
||||||
|
attributes={"scopes": ["llama:admin"]}
|
||||||
|
)
|
||||||
|
admin_scopes = set(admin_user.attributes["scopes"])
|
||||||
|
|
||||||
|
# Test various API endpoints
|
||||||
|
test_cases = [
|
||||||
|
("/v1/inference/chat-completion", "POST"),
|
||||||
|
("/v1/models", "GET"),
|
||||||
|
("/v1/models/my-model", "DELETE"),
|
||||||
|
("/v1/agents/session", "POST"),
|
||||||
|
("/v1/tools/execute", "POST"),
|
||||||
|
("/v1/vector_dbs/query", "POST"),
|
||||||
|
("/v1/safety/shield", "POST"),
|
||||||
|
("/v1/eval/benchmark", "POST"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for path, method in test_cases:
|
||||||
|
api_name, _ = extract_api_from_path(path)
|
||||||
|
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||||
|
|
||||||
|
# Admin should always have access
|
||||||
|
assert admin_scopes.intersection(required_scopes), f"Admin denied access to {path}"
|
||||||
|
|
||||||
|
def test_multiple_scopes_user(self):
|
||||||
|
"""Test user with multiple scopes"""
|
||||||
|
multi_scope_user = User(
|
||||||
|
principal="power_user",
|
||||||
|
attributes={
|
||||||
|
"scopes": [
|
||||||
|
"llama:inference",
|
||||||
|
"llama:models:read",
|
||||||
|
"llama:agents:write",
|
||||||
|
"llama:tools"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
user_scopes = set(multi_scope_user.attributes["scopes"])
|
||||||
|
|
||||||
|
# Test various access scenarios
|
||||||
|
access_tests = [
|
||||||
|
("/v1/inference/chat-completion", "POST", True), # Has inference
|
||||||
|
("/v1/models", "GET", True), # Has models:read
|
||||||
|
("/v1/models", "POST", False), # Doesn't have models:write
|
||||||
|
("/v1/agents/session", "POST", True), # Has agents:write
|
||||||
|
("/v1/agents", "GET", False), # Doesn't have agents:read
|
||||||
|
("/v1/tools/execute", "POST", True), # Has tools
|
||||||
|
("/v1/vector_dbs/query", "GET", False), # Doesn't have vector_dbs:read
|
||||||
|
("/v1/safety/shield", "POST", False), # Doesn't have safety
|
||||||
|
]
|
||||||
|
|
||||||
|
for path, method, should_have_access in access_tests:
|
||||||
|
api_name, _ = extract_api_from_path(path)
|
||||||
|
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||||
|
has_access = bool(user_scopes.intersection(required_scopes))
|
||||||
|
|
||||||
|
assert has_access == should_have_access, (
|
||||||
|
f"Access mismatch for {path} {method}: "
|
||||||
|
f"expected {should_have_access}, got {has_access}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_openai_compatibility_scope_flow(self):
|
||||||
|
"""Test OAuth2 scope validation for OpenAI compatibility endpoints"""
|
||||||
|
# OpenAI endpoints should map to inference API
|
||||||
|
openai_paths = [
|
||||||
|
"/v1/openai/v1/chat/completions",
|
||||||
|
"/v1/openai/v1/completions",
|
||||||
|
"/v1/openai/v1/embeddings",
|
||||||
|
]
|
||||||
|
|
||||||
|
inference_user = User(
|
||||||
|
principal="openai_user",
|
||||||
|
attributes={"scopes": ["llama:inference"]}
|
||||||
|
)
|
||||||
|
user_scopes = set(inference_user.attributes["scopes"])
|
||||||
|
|
||||||
|
for path in openai_paths:
|
||||||
|
api_name, method = extract_api_from_path(path)
|
||||||
|
assert api_name == "inference"
|
||||||
|
|
||||||
|
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||||
|
assert user_scopes.intersection(required_scopes), f"No access to {path}"
|
||||||
|
|
||||||
|
def test_scope_validation_error_scenarios(self):
|
||||||
|
"""Test error scenarios in scope validation"""
|
||||||
|
# User with no scopes
|
||||||
|
no_scope_user = User(principal="no_scope", attributes={})
|
||||||
|
|
||||||
|
api_name, method = extract_api_from_path("/v1/inference/chat-completion")
|
||||||
|
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||||
|
|
||||||
|
# Should not have access
|
||||||
|
user_scopes = set()
|
||||||
|
assert not user_scopes.intersection(required_scopes)
|
||||||
|
|
||||||
|
# User with empty scopes list
|
||||||
|
empty_scope_user = User(
|
||||||
|
principal="empty_scope",
|
||||||
|
attributes={"scopes": []}
|
||||||
|
)
|
||||||
|
user_scopes = set(empty_scope_user.attributes["scopes"])
|
||||||
|
assert not user_scopes.intersection(required_scopes)
|
||||||
|
|
||||||
|
# User with invalid scopes
|
||||||
|
invalid_scope_user = User(
|
||||||
|
principal="invalid_scope",
|
||||||
|
attributes={"scopes": ["invalid:scope", "another:invalid"]}
|
||||||
|
)
|
||||||
|
user_scopes = set(invalid_scope_user.attributes["scopes"])
|
||||||
|
assert not user_scopes.intersection(required_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScopeBasedAccessMatrix:
|
||||||
|
"""Test access matrix for different user types and API combinations"""
|
||||||
|
|
||||||
|
def test_data_scientist_access_pattern(self):
|
||||||
|
"""Test typical data scientist access pattern"""
|
||||||
|
data_scientist = User(
|
||||||
|
principal="data_scientist",
|
||||||
|
attributes={
|
||||||
|
"scopes": [
|
||||||
|
"llama:inference",
|
||||||
|
"llama:models:read",
|
||||||
|
"llama:eval",
|
||||||
|
"llama:safety"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
user_scopes = set(data_scientist.attributes["scopes"])
|
||||||
|
|
||||||
|
# Should have access to
|
||||||
|
allowed_apis = [
|
||||||
|
("inference", "POST"), # Run inference
|
||||||
|
("models", "GET"), # List/inspect models
|
||||||
|
("eval", "POST"), # Run evaluations
|
||||||
|
("safety", "POST"), # Use safety shields
|
||||||
|
]
|
||||||
|
|
||||||
|
for api, method in allowed_apis:
|
||||||
|
required = get_required_scopes_for_api(api, method)
|
||||||
|
assert user_scopes.intersection(required), f"Data scientist denied {api} {method}"
|
||||||
|
|
||||||
|
# Should NOT have access to
|
||||||
|
denied_apis = [
|
||||||
|
("models", "POST"), # Cannot register models
|
||||||
|
("agents", "POST"), # Cannot create agents
|
||||||
|
("tools", "POST"), # Cannot use tools
|
||||||
|
("vector_dbs", "GET"), # Cannot access vector DBs
|
||||||
|
]
|
||||||
|
|
||||||
|
for api, method in denied_apis:
|
||||||
|
required = get_required_scopes_for_api(api, method)
|
||||||
|
assert not user_scopes.intersection(required), f"Data scientist allowed {api} {method}"
|
||||||
|
|
||||||
|
def test_ml_engineer_access_pattern(self):
|
||||||
|
"""Test typical ML engineer access pattern"""
|
||||||
|
ml_engineer = User(
|
||||||
|
principal="ml_engineer",
|
||||||
|
attributes={
|
||||||
|
"scopes": [
|
||||||
|
"llama:inference",
|
||||||
|
"llama:models:read",
|
||||||
|
"llama:models:write",
|
||||||
|
"llama:agents:read",
|
||||||
|
"llama:agents:write",
|
||||||
|
"llama:tools",
|
||||||
|
"llama:eval"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
user_scopes = set(ml_engineer.attributes["scopes"])
|
||||||
|
|
||||||
|
# Should have broad access except admin-only operations
|
||||||
|
allowed_apis = [
|
||||||
|
("inference", "POST"),
|
||||||
|
("models", "GET"),
|
||||||
|
("models", "POST"),
|
||||||
|
("agents", "GET"),
|
||||||
|
("agents", "POST"),
|
||||||
|
("tools", "POST"),
|
||||||
|
("eval", "POST"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for api, method in allowed_apis:
|
||||||
|
required = get_required_scopes_for_api(api, method)
|
||||||
|
assert user_scopes.intersection(required), f"ML engineer denied {api} {method}"
|
||||||
|
|
||||||
|
def test_application_developer_access_pattern(self):
|
||||||
|
"""Test typical application developer access pattern"""
|
||||||
|
app_developer = User(
|
||||||
|
principal="app_developer",
|
||||||
|
attributes={
|
||||||
|
"scopes": [
|
||||||
|
"llama:inference",
|
||||||
|
"llama:agents:read",
|
||||||
|
"llama:agents:write",
|
||||||
|
"llama:tools",
|
||||||
|
"llama:safety"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
user_scopes = set(app_developer.attributes["scopes"])
|
||||||
|
|
||||||
|
# Should focus on application-building APIs
|
||||||
|
allowed_apis = [
|
||||||
|
("inference", "POST"), # Use models for apps
|
||||||
|
("agents", "GET"), # Inspect agents
|
||||||
|
("agents", "POST"), # Create agent sessions
|
||||||
|
("tools", "POST"), # Execute tools
|
||||||
|
("safety", "POST"), # Apply safety
|
||||||
|
]
|
||||||
|
|
||||||
|
for api, method in allowed_apis:
|
||||||
|
required = get_required_scopes_for_api(api, method)
|
||||||
|
assert user_scopes.intersection(required), f"App developer denied {api} {method}"
|
||||||
|
|
||||||
|
# Should NOT have model or eval management access
|
||||||
|
denied_apis = [
|
||||||
|
("models", "POST"), # Cannot manage models
|
||||||
|
("eval", "POST"), # Cannot run evaluations
|
||||||
|
("vector_dbs", "POST"), # Cannot manage vector DBs
|
||||||
|
]
|
||||||
|
|
||||||
|
for api, method in denied_apis:
|
||||||
|
required = get_required_scopes_for_api(api, method)
|
||||||
|
assert not user_scopes.intersection(required), f"App developer allowed {api} {method}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestScopeHierarchyAndSeparation:
|
||||||
|
"""Test that scopes are properly separated and don't grant unintended access"""
|
||||||
|
|
||||||
|
def test_read_write_separation(self):
|
||||||
|
"""Test that read scopes don't grant write access and vice versa"""
|
||||||
|
read_only_apis = ["models", "agents", "vector_dbs"]
|
||||||
|
|
||||||
|
for api in read_only_apis:
|
||||||
|
# User with only read scope
|
||||||
|
read_user = User(
|
||||||
|
principal=f"{api}_reader",
|
||||||
|
attributes={"scopes": [f"llama:{api}:read"]}
|
||||||
|
)
|
||||||
|
read_scopes = set(read_user.attributes["scopes"])
|
||||||
|
|
||||||
|
# User with only write scope
|
||||||
|
write_user = User(
|
||||||
|
principal=f"{api}_writer",
|
||||||
|
attributes={"scopes": [f"llama:{api}:write"]}
|
||||||
|
)
|
||||||
|
write_scopes = set(write_user.attributes["scopes"])
|
||||||
|
|
||||||
|
# Read user should only have read access
|
||||||
|
read_required = get_required_scopes_for_api(api, "GET")
|
||||||
|
write_required = get_required_scopes_for_api(api, "POST")
|
||||||
|
|
||||||
|
assert read_scopes.intersection(read_required), f"Read user denied read access to {api}"
|
||||||
|
assert not read_scopes.intersection(write_required), f"Read user granted write access to {api}"
|
||||||
|
|
||||||
|
# Write user should only have write access
|
||||||
|
assert write_scopes.intersection(write_required), f"Write user denied write access to {api}"
|
||||||
|
assert not write_scopes.intersection(read_required), f"Write user granted read access to {api}"
|
||||||
|
|
||||||
|
def test_api_isolation(self):
|
||||||
|
"""Test that API scopes don't grant access to other APIs"""
|
||||||
|
api_scopes = [
|
||||||
|
"llama:inference",
|
||||||
|
"llama:models:read",
|
||||||
|
"llama:agents:write",
|
||||||
|
"llama:tools",
|
||||||
|
"llama:vector_dbs:read",
|
||||||
|
"llama:safety",
|
||||||
|
"llama:eval"
|
||||||
|
]
|
||||||
|
|
||||||
|
for scope in api_scopes:
|
||||||
|
user = User(
|
||||||
|
principal=f"single_scope_user",
|
||||||
|
attributes={"scopes": [scope]}
|
||||||
|
)
|
||||||
|
user_scopes = set(user.attributes["scopes"])
|
||||||
|
|
||||||
|
# Test that this scope only grants access to its intended API
|
||||||
|
test_apis = [
|
||||||
|
("inference", "POST"),
|
||||||
|
("models", "GET"),
|
||||||
|
("models", "POST"),
|
||||||
|
("agents", "GET"),
|
||||||
|
("agents", "POST"),
|
||||||
|
("tools", "POST"),
|
||||||
|
("vector_dbs", "GET"),
|
||||||
|
("vector_dbs", "POST"),
|
||||||
|
("safety", "POST"),
|
||||||
|
("eval", "POST")
|
||||||
|
]
|
||||||
|
|
||||||
|
access_count = 0
|
||||||
|
for api, method in test_apis:
|
||||||
|
required = get_required_scopes_for_api(api, method)
|
||||||
|
if user_scopes.intersection(required):
|
||||||
|
access_count += 1
|
||||||
|
|
||||||
|
# Should only have access to 1-2 endpoints (the intended API)
|
||||||
|
# Allow 2 for APIs that have both read and write variants
|
||||||
|
assert access_count <= 2, f"Scope {scope} grants too broad access ({access_count} APIs)"
|
||||||
|
assert access_count >= 1, f"Scope {scope} grants no access"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
428
tests/unit/server/test_oauth2_scopes.py
Normal file
428
tests/unit/server/test_oauth2_scopes.py
Normal file
|
|
@ -0,0 +1,428 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from llama_stack.distribution.server.oauth2_scopes import (
|
||||||
|
STANDARD_OAUTH2_SCOPES,
|
||||||
|
get_required_scopes_for_api,
|
||||||
|
validate_scopes,
|
||||||
|
scope_grants_admin_access,
|
||||||
|
get_all_scope_descriptions,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.server.auth import extract_api_from_path, AuthenticationMiddleware
|
||||||
|
from llama_stack.distribution.datatypes import AuthenticationConfig, OAuth2TokenAuthConfig, AuthProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuth2Scopes:
|
||||||
|
"""Test OAuth2 scope definitions and validation"""
|
||||||
|
|
||||||
|
def test_standard_scopes_exist(self):
|
||||||
|
"""Test that all expected standard scopes are defined"""
|
||||||
|
expected_scopes = {
|
||||||
|
"llama:inference",
|
||||||
|
"llama:models:read",
|
||||||
|
"llama:models:write",
|
||||||
|
"llama:agents:read",
|
||||||
|
"llama:agents:write",
|
||||||
|
"llama:tools",
|
||||||
|
"llama:vector_dbs:read",
|
||||||
|
"llama:vector_dbs:write",
|
||||||
|
"llama:safety",
|
||||||
|
"llama:eval",
|
||||||
|
"llama:admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert set(STANDARD_OAUTH2_SCOPES.keys()) == expected_scopes
|
||||||
|
|
||||||
|
# Verify all scopes have descriptions
|
||||||
|
for scope, description in STANDARD_OAUTH2_SCOPES.items():
|
||||||
|
assert isinstance(description, str)
|
||||||
|
assert len(description) > 10 # Reasonable description length
|
||||||
|
|
||||||
|
def test_get_all_scope_descriptions(self):
|
||||||
|
"""Test getting all scope descriptions"""
|
||||||
|
descriptions = get_all_scope_descriptions()
|
||||||
|
|
||||||
|
assert descriptions == STANDARD_OAUTH2_SCOPES
|
||||||
|
assert len(descriptions) == len(STANDARD_OAUTH2_SCOPES)
|
||||||
|
|
||||||
|
# Verify it's a copy, not the original
|
||||||
|
descriptions["test"] = "should not affect original"
|
||||||
|
assert "test" not in STANDARD_OAUTH2_SCOPES
|
||||||
|
|
||||||
|
def test_validate_scopes_valid(self):
|
||||||
|
"""Test scope validation with valid scopes"""
|
||||||
|
# Single valid scope
|
||||||
|
token_scopes = {"llama:inference"}
|
||||||
|
result = validate_scopes(token_scopes)
|
||||||
|
assert result == {"llama:inference"}
|
||||||
|
|
||||||
|
# Multiple valid scopes
|
||||||
|
token_scopes = {"llama:inference", "llama:models:read", "llama:admin"}
|
||||||
|
result = validate_scopes(token_scopes)
|
||||||
|
assert result == {"llama:inference", "llama:models:read", "llama:admin"}
|
||||||
|
|
||||||
|
# Mix of valid and invalid scopes
|
||||||
|
token_scopes = {"llama:inference", "invalid:scope", "llama:admin"}
|
||||||
|
result = validate_scopes(token_scopes)
|
||||||
|
assert result == {"llama:inference", "llama:admin"}
|
||||||
|
|
||||||
|
def test_validate_scopes_invalid(self):
|
||||||
|
"""Test scope validation with invalid scopes"""
|
||||||
|
# No valid scopes
|
||||||
|
token_scopes = {"invalid:scope", "another:invalid"}
|
||||||
|
with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
|
||||||
|
validate_scopes(token_scopes)
|
||||||
|
|
||||||
|
# Empty scopes
|
||||||
|
token_scopes = set()
|
||||||
|
with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
|
||||||
|
validate_scopes(token_scopes)
|
||||||
|
|
||||||
|
def test_scope_grants_admin_access(self):
|
||||||
|
"""Test admin scope detection"""
|
||||||
|
# Admin scope present
|
||||||
|
assert scope_grants_admin_access({"llama:admin"})
|
||||||
|
assert scope_grants_admin_access({"llama:admin", "llama:inference"})
|
||||||
|
|
||||||
|
# Admin scope not present
|
||||||
|
assert not scope_grants_admin_access({"llama:inference"})
|
||||||
|
assert not scope_grants_admin_access({"llama:models:read", "llama:agents:write"})
|
||||||
|
assert not scope_grants_admin_access(set())
|
||||||
|
|
||||||
|
|
||||||
|
class TestScopeRequirements:
|
||||||
|
"""Test API endpoint scope requirements"""
|
||||||
|
|
||||||
|
def test_inference_api_scopes(self):
|
||||||
|
"""Test inference API scope requirements"""
|
||||||
|
apis = ["inference", "chat", "completion", "embeddings"]
|
||||||
|
for api in apis:
|
||||||
|
scopes = get_required_scopes_for_api(api, "POST")
|
||||||
|
assert "llama:inference" in scopes
|
||||||
|
assert "llama:admin" in scopes
|
||||||
|
|
||||||
|
def test_models_api_scopes(self):
|
||||||
|
"""Test models API scope requirements"""
|
||||||
|
# Read operations
|
||||||
|
read_scopes = get_required_scopes_for_api("models", "GET")
|
||||||
|
assert "llama:models:read" in read_scopes
|
||||||
|
assert "llama:admin" in read_scopes
|
||||||
|
assert "llama:models:write" not in read_scopes
|
||||||
|
|
||||||
|
# Write operations
|
||||||
|
for method in ["POST", "PUT", "DELETE"]:
|
||||||
|
write_scopes = get_required_scopes_for_api("models", method)
|
||||||
|
assert "llama:models:write" in write_scopes
|
||||||
|
assert "llama:admin" in write_scopes
|
||||||
|
assert "llama:models:read" not in write_scopes
|
||||||
|
|
||||||
|
def test_agents_api_scopes(self):
|
||||||
|
"""Test agents API scope requirements"""
|
||||||
|
# Read operations
|
||||||
|
read_scopes = get_required_scopes_for_api("agents", "GET")
|
||||||
|
assert "llama:agents:read" in read_scopes
|
||||||
|
assert "llama:admin" in read_scopes
|
||||||
|
|
||||||
|
# Write operations
|
||||||
|
for method in ["POST", "PUT", "DELETE"]:
|
||||||
|
write_scopes = get_required_scopes_for_api("agents", method)
|
||||||
|
assert "llama:agents:write" in write_scopes
|
||||||
|
assert "llama:admin" in write_scopes
|
||||||
|
|
||||||
|
def test_tools_api_scopes(self):
|
||||||
|
"""Test tools API scope requirements"""
|
||||||
|
for api in ["tools", "tool_runtime"]:
|
||||||
|
scopes = get_required_scopes_for_api(api, "POST")
|
||||||
|
assert "llama:tools" in scopes
|
||||||
|
assert "llama:admin" in scopes
|
||||||
|
|
||||||
|
def test_vector_dbs_api_scopes(self):
|
||||||
|
"""Test vector databases API scope requirements"""
|
||||||
|
# Read operations
|
||||||
|
read_scopes = get_required_scopes_for_api("vector_dbs", "GET")
|
||||||
|
assert "llama:vector_dbs:read" in read_scopes
|
||||||
|
assert "llama:admin" in read_scopes
|
||||||
|
|
||||||
|
# Write operations
|
||||||
|
for method in ["POST", "PUT", "DELETE"]:
|
||||||
|
write_scopes = get_required_scopes_for_api("vector_dbs", method)
|
||||||
|
assert "llama:vector_dbs:write" in write_scopes
|
||||||
|
assert "llama:admin" in write_scopes
|
||||||
|
|
||||||
|
def test_safety_api_scopes(self):
|
||||||
|
"""Test safety API scope requirements"""
|
||||||
|
scopes = get_required_scopes_for_api("safety", "POST")
|
||||||
|
assert "llama:safety" in scopes
|
||||||
|
assert "llama:admin" in scopes
|
||||||
|
|
||||||
|
def test_eval_api_scopes(self):
|
||||||
|
"""Test evaluation API scope requirements"""
|
||||||
|
for api in ["eval", "benchmarks", "scoring"]:
|
||||||
|
scopes = get_required_scopes_for_api(api, "POST")
|
||||||
|
assert "llama:eval" in scopes
|
||||||
|
assert "llama:admin" in scopes
|
||||||
|
|
||||||
|
def test_unknown_api_scopes(self):
|
||||||
|
"""Test unknown API only requires admin scope"""
|
||||||
|
scopes = get_required_scopes_for_api("unknown_api", "POST")
|
||||||
|
assert scopes == {"llama:admin"}
|
||||||
|
|
||||||
|
def test_admin_always_included(self):
|
||||||
|
"""Test that admin scope is always included in required scopes"""
|
||||||
|
test_apis = ["inference", "models", "agents", "tools", "safety", "eval", "unknown"]
|
||||||
|
test_methods = ["GET", "POST", "PUT", "DELETE"]
|
||||||
|
|
||||||
|
for api in test_apis:
|
||||||
|
for method in test_methods:
|
||||||
|
scopes = get_required_scopes_for_api(api, method)
|
||||||
|
assert "llama:admin" in scopes
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIPathExtraction:
|
||||||
|
"""Test API path extraction for scope validation"""
|
||||||
|
|
||||||
|
def test_v1_api_paths(self):
|
||||||
|
"""Test extraction from v1 API paths"""
|
||||||
|
test_cases = [
|
||||||
|
("/v1/inference/chat-completion", ("inference", "POST")),
|
||||||
|
("/v1/models", ("models", "POST")),
|
||||||
|
("/v1/models/my-model", ("models", "POST")),
|
||||||
|
("/v1/agents/session", ("agents", "POST")),
|
||||||
|
("/v1/tools/execute", ("tools", "POST")),
|
||||||
|
("/v1/vector_dbs/query", ("vector_dbs", "POST")),
|
||||||
|
("/v1/safety/shield", ("safety", "POST")),
|
||||||
|
("/v1/eval/benchmark", ("eval", "POST")),
|
||||||
|
]
|
||||||
|
|
||||||
|
for path, expected in test_cases:
|
||||||
|
result = extract_api_from_path(path)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
def test_openai_compatibility_paths(self):
|
||||||
|
"""Test extraction from OpenAI compatibility paths"""
|
||||||
|
openai_paths = [
|
||||||
|
"/v1/openai/v1/chat/completions",
|
||||||
|
"/v1/openai/v1/completions",
|
||||||
|
"/v1/openai/v1/embeddings",
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in openai_paths:
|
||||||
|
api, method = extract_api_from_path(path)
|
||||||
|
assert api == "inference"
|
||||||
|
assert method == "POST"
|
||||||
|
|
||||||
|
def test_edge_case_paths(self):
|
||||||
|
"""Test edge cases in path extraction"""
|
||||||
|
test_cases = [
|
||||||
|
("/", ("unknown", "GET")),
|
||||||
|
("/health", ("health", "POST")),
|
||||||
|
("/v1/", ("unknown", "GET")),
|
||||||
|
("", ("unknown", "GET")),
|
||||||
|
("/some/nested/path", ("some", "POST")),
|
||||||
|
]
|
||||||
|
|
||||||
|
for path, expected in test_cases:
|
||||||
|
result = extract_api_from_path(path)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestScopeBasedAuth:
|
||||||
|
"""Test scope-based authentication integration"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_auth_config(self):
|
||||||
|
"""Create mock OAuth2 authentication config"""
|
||||||
|
return AuthenticationConfig(
|
||||||
|
provider_config=OAuth2TokenAuthConfig(
|
||||||
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
|
issuer="https://test-issuer.com",
|
||||||
|
audience="llama-stack",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_app(self):
|
||||||
|
"""Create mock FastAPI app"""
|
||||||
|
app = MagicMock()
|
||||||
|
return app
|
||||||
|
|
||||||
|
def test_scope_validation_in_middleware(self, mock_auth_config, mock_app):
|
||||||
|
"""Test that middleware validates scopes correctly"""
|
||||||
|
middleware = AuthenticationMiddleware(mock_app, mock_auth_config)
|
||||||
|
|
||||||
|
# This is a simplified test - in practice you'd need to mock the full ASGI flow
|
||||||
|
assert middleware.auth_provider is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_token_validation_with_scopes(self):
|
||||||
|
"""Test JWT token validation with OAuth2 scopes"""
|
||||||
|
# Mock JWT claims with scopes
|
||||||
|
mock_claims = {
|
||||||
|
"sub": "test-user",
|
||||||
|
"scope": "llama:inference llama:models:read",
|
||||||
|
"iss": "test-issuer",
|
||||||
|
"aud": "llama-stack",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("llama_stack.distribution.server.auth_providers.jwt.decode") as mock_jwt_decode:
|
||||||
|
mock_jwt_decode.return_value = mock_claims
|
||||||
|
|
||||||
|
# Mock the auth provider
|
||||||
|
from llama_stack.distribution.server.auth_providers import OAuth2TokenAuthProvider
|
||||||
|
from llama_stack.distribution.datatypes import OAuth2TokenAuthConfig, AuthProviderType
|
||||||
|
|
||||||
|
config = OAuth2TokenAuthConfig(
|
||||||
|
type=AuthProviderType.OAUTH2_TOKEN,
|
||||||
|
issuer="test-issuer",
|
||||||
|
audience="llama-stack",
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = OAuth2TokenAuthProvider(config)
|
||||||
|
|
||||||
|
# Mock the key retrieval
|
||||||
|
with patch.object(provider, "_get_public_key") as mock_get_key:
|
||||||
|
mock_get_key.return_value = "mock-key"
|
||||||
|
|
||||||
|
# Test token validation
|
||||||
|
user = await provider.validate_token("mock-token", {})
|
||||||
|
|
||||||
|
assert user.principal == "test-user"
|
||||||
|
assert user.attributes is not None
|
||||||
|
assert "scopes" in user.attributes
|
||||||
|
assert set(user.attributes["scopes"]) == {"llama:inference", "llama:models:read"}
|
||||||
|
|
||||||
|
def test_scope_intersection_logic(self):
|
||||||
|
"""Test scope intersection for access control"""
|
||||||
|
# User has inference and read scopes
|
||||||
|
user_scopes = {"llama:inference", "llama:models:read"}
|
||||||
|
|
||||||
|
# Test various API requirements
|
||||||
|
inference_required = {"llama:admin", "llama:inference"}
|
||||||
|
models_read_required = {"llama:admin", "llama:models:read"}
|
||||||
|
models_write_required = {"llama:admin", "llama:models:write"}
|
||||||
|
|
||||||
|
# Should have access to inference
|
||||||
|
assert user_scopes.intersection(inference_required)
|
||||||
|
|
||||||
|
# Should have access to model reading
|
||||||
|
assert user_scopes.intersection(models_read_required)
|
||||||
|
|
||||||
|
# Should NOT have access to model writing
|
||||||
|
assert not user_scopes.intersection(models_write_required)
|
||||||
|
|
||||||
|
def test_admin_scope_access(self):
|
||||||
|
"""Test that admin scope grants access to everything"""
|
||||||
|
admin_scopes = {"llama:admin"}
|
||||||
|
|
||||||
|
# Test various API requirements
|
||||||
|
test_requirements = [
|
||||||
|
{"llama:admin", "llama:inference"},
|
||||||
|
{"llama:admin", "llama:models:write"},
|
||||||
|
{"llama:admin", "llama:agents:write"},
|
||||||
|
{"llama:admin", "llama:tools"},
|
||||||
|
{"llama:admin", "llama:vector_dbs:write"},
|
||||||
|
{"llama:admin", "llama:safety"},
|
||||||
|
{"llama:admin", "llama:eval"},
|
||||||
|
]
|
||||||
|
|
||||||
|
for required_scopes in test_requirements:
|
||||||
|
assert admin_scopes.intersection(required_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScopeValidationErrors:
|
||||||
|
"""Test error cases in scope validation"""
|
||||||
|
|
||||||
|
def test_missing_scope_claim(self):
|
||||||
|
"""Test handling of missing scope claim in JWT"""
|
||||||
|
# Empty scope claim
|
||||||
|
token_scopes = set()
|
||||||
|
with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
|
||||||
|
validate_scopes(token_scopes)
|
||||||
|
|
||||||
|
def test_malformed_scope_string(self):
|
||||||
|
"""Test handling of malformed scope strings"""
|
||||||
|
# Scopes with extra whitespace should be handled gracefully
|
||||||
|
scope_string = " llama:inference llama:models:read "
|
||||||
|
scopes = set(scope_string.split())
|
||||||
|
|
||||||
|
# Filter out empty strings that might result from split()
|
||||||
|
scopes = {s.strip() for s in scopes if s.strip()}
|
||||||
|
|
||||||
|
result = validate_scopes(scopes)
|
||||||
|
assert result == {"llama:inference", "llama:models:read"}
|
||||||
|
|
||||||
|
def test_case_sensitive_scopes(self):
|
||||||
|
"""Test that scopes are case-sensitive"""
|
||||||
|
# Wrong case should not match
|
||||||
|
token_scopes = {"LLAMA:INFERENCE", "llama:Models:Read"}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
|
||||||
|
validate_scopes(token_scopes)
|
||||||
|
|
||||||
|
def test_partial_scope_matches(self):
|
||||||
|
"""Test that partial scope matches don't work"""
|
||||||
|
# Partial matches should not be accepted
|
||||||
|
token_scopes = {"llama:model", "llama", "inference"}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Token lacks required OAuth2 scopes"):
|
||||||
|
validate_scopes(token_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScopeDocumentation:
|
||||||
|
"""Test scope documentation and descriptions"""
|
||||||
|
|
||||||
|
def test_scope_naming_convention(self):
|
||||||
|
"""Test that scope names follow consistent naming convention"""
|
||||||
|
for scope in STANDARD_OAUTH2_SCOPES.keys():
|
||||||
|
# All scopes should start with 'llama:'
|
||||||
|
assert scope.startswith("llama:")
|
||||||
|
|
||||||
|
# Should not contain spaces
|
||||||
|
assert " " not in scope
|
||||||
|
|
||||||
|
# Should use colons as separators, not dots or slashes
|
||||||
|
parts = scope.split(":")
|
||||||
|
assert len(parts) >= 2
|
||||||
|
assert all(part.replace("_", "").isalnum() for part in parts)
|
||||||
|
|
||||||
|
def test_scope_descriptions_quality(self):
|
||||||
|
"""Test that scope descriptions are meaningful"""
|
||||||
|
for scope, description in STANDARD_OAUTH2_SCOPES.items():
|
||||||
|
# Should be non-empty strings
|
||||||
|
assert isinstance(description, str)
|
||||||
|
assert len(description.strip()) > 5
|
||||||
|
|
||||||
|
# Should not just be the scope name
|
||||||
|
assert scope not in description
|
||||||
|
|
||||||
|
# Should contain descriptive words
|
||||||
|
descriptive_words = ["access", "read", "write", "manage", "create", "delete", "execute"]
|
||||||
|
assert any(word in description.lower() for word in descriptive_words)
|
||||||
|
|
||||||
|
def test_read_write_scope_pairs(self):
|
||||||
|
"""Test that read/write scope pairs are consistent"""
|
||||||
|
read_write_apis = ["models", "agents", "vector_dbs"]
|
||||||
|
|
||||||
|
for api in read_write_apis:
|
||||||
|
read_scope = f"llama:{api}:read"
|
||||||
|
write_scope = f"llama:{api}:write"
|
||||||
|
|
||||||
|
assert read_scope in STANDARD_OAUTH2_SCOPES
|
||||||
|
assert write_scope in STANDARD_OAUTH2_SCOPES
|
||||||
|
|
||||||
|
# Read description should mention "read"
|
||||||
|
assert "read" in STANDARD_OAUTH2_SCOPES[read_scope].lower()
|
||||||
|
|
||||||
|
# Write description should mention write operations
|
||||||
|
write_desc = STANDARD_OAUTH2_SCOPES[write_scope].lower()
|
||||||
|
assert any(word in write_desc for word in ["write", "manage", "register", "create"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Loading…
Add table
Add a link
Reference in a new issue