mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-24 14:40:00 +00:00
feat: adding scope-based authorization
This commit is contained in:
parent
51b179e1c5
commit
fe093918c2
6 changed files with 1082 additions and 9 deletions
|
|
@ -10,11 +10,36 @@ import httpx
|
|||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||
from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
def extract_api_from_path(path: str) -> tuple[str, str]:
|
||||
"""Extract API name and method from request path for scope validation"""
|
||||
# Remove leading/trailing slashes and split
|
||||
path = path.strip("/")
|
||||
parts = path.split("/")
|
||||
|
||||
# Handle common API path patterns
|
||||
if len(parts) >= 2 and parts[0] == "v1":
|
||||
api_name = parts[1]
|
||||
# Handle nested paths like /v1/models/{id} or /v1/inference/chat-completion
|
||||
if api_name in ["inference", "models", "agents", "tools", "vector_dbs", "safety", "eval", "scoring"]:
|
||||
return api_name, "POST" # Default to POST for scope checking
|
||||
elif api_name == "openai":
|
||||
# Handle OpenAI compatibility endpoints like /v1/openai/v1/chat/completions
|
||||
if len(parts) >= 4:
|
||||
return "inference", "POST" # OpenAI endpoints are typically inference
|
||||
|
||||
# Fallback - try to extract from first path component
|
||||
if parts:
|
||||
return parts[0], "POST"
|
||||
|
||||
return "unknown", "GET"
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using configured authentication provider.
|
||||
|
||||
|
|
@ -109,6 +134,34 @@ class AuthenticationMiddleware:
|
|||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Validate OAuth2 scopes for the requested API endpoint
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", "GET")
|
||||
api_name, _ = extract_api_from_path(path)
|
||||
|
||||
# Get required scopes for this API endpoint
|
||||
required_scopes = get_required_scopes_for_api(api_name, method)
|
||||
|
||||
# Check if user has any of the required scopes
|
||||
user_scopes = set()
|
||||
if validation_result.attributes and "scopes" in validation_result.attributes:
|
||||
user_scopes = set(validation_result.attributes["scopes"])
|
||||
|
||||
# Verify user has at least one required scope
|
||||
if not user_scopes.intersection(required_scopes):
|
||||
logger.warning(
|
||||
f"Access denied for {validation_result.principal} to {api_name} API. "
|
||||
f"Required scopes: {required_scopes}, User scopes: {user_scopes}"
|
||||
)
|
||||
return await self._send_auth_error(
|
||||
send, f"Insufficient OAuth2 scopes for {api_name} API. Required: {', '.join(required_scopes)}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"OAuth2 scope validation passed for {validation_result.principal} "
|
||||
f"on {api_name} API with scopes: {user_scopes.intersection(required_scopes)}"
|
||||
)
|
||||
|
||||
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
|
||||
# can identify the requester and enforce per-client rate limits.
|
||||
scope["authenticated_client_id"] = token
|
||||
|
|
@ -117,9 +170,8 @@ class AuthenticationMiddleware:
|
|||
scope["principal"] = validation_result.principal
|
||||
if validation_result.attributes:
|
||||
scope["user_attributes"] = validation_result.attributes
|
||||
logger.debug(
|
||||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||
)
|
||||
attr_count = len(validation_result.attributes) if validation_result.attributes else 0
|
||||
logger.debug(f"Authentication successful: {validation_result.principal} with {attr_count} attributes")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ from llama_stack.distribution.datatypes import (
|
|||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
from llama_stack.distribution.server.oauth2_scopes import (
|
||||
scope_grants_admin_access,
|
||||
validate_scopes,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
|
@ -128,13 +132,34 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
except Exception as exc:
|
||||
raise ValueError("Invalid JWT token") from exc
|
||||
|
||||
# There are other standard claims, the most relevant of which is `scope`.
|
||||
# We should incorporate these into the access attributes.
|
||||
principal = claims["sub"]
|
||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
# Extract and validate OAuth2 scopes - deny by default if no valid scopes
|
||||
token_scopes = set(claims.get("scope", "").split()) if claims.get("scope") else set()
|
||||
|
||||
# Validate scopes against standard Llama Stack scopes
|
||||
valid_scopes = validate_scopes(token_scopes)
|
||||
|
||||
logger.info(f"User {claims['sub']} authenticated with scopes: {valid_scopes}")
|
||||
|
||||
# Convert scopes to user attributes for the existing ABAC system
|
||||
attributes = {"scopes": list(valid_scopes)}
|
||||
|
||||
# Add admin role if user has admin scope
|
||||
if scope_grants_admin_access(valid_scopes):
|
||||
attributes["roles"] = ["admin"]
|
||||
|
||||
# Maintain backward compatibility with existing claims mapping
|
||||
legacy_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
if legacy_attributes:
|
||||
for key, values in legacy_attributes.items():
|
||||
if key not in attributes:
|
||||
attributes[key] = values
|
||||
else:
|
||||
# Merge lists, avoiding duplicates
|
||||
attributes[key] = list(set(attributes[key] + values))
|
||||
|
||||
return User(
|
||||
principal=principal,
|
||||
attributes=access_attributes,
|
||||
principal=claims["sub"],
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
async def introspect_token(self, token: str, scope: dict | None = None) -> User:
|
||||
|
|
|
|||
126
llama_stack/distribution/server/oauth2_scopes.py
Normal file
126
llama_stack/distribution/server/oauth2_scopes.py
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
OAuth2 scope definitions and validation for Llama Stack APIs.
|
||||
|
||||
This module defines the standard OAuth2 scopes that are built-in to Llama Stack
|
||||
and provides utilities for scope validation. These scopes are not configurable
|
||||
and provide a standardized way to control access to different API endpoints.
|
||||
"""
|
||||
|
||||
from typing import Set
|
||||
|
||||
# Standard OAuth2 scopes for Llama Stack APIs (built-in, not configurable)
|
||||
STANDARD_OAUTH2_SCOPES = {
|
||||
# Inference API
|
||||
"llama:inference": "Access to inference APIs (chat completion, embeddings)",
|
||||
|
||||
# Models API
|
||||
"llama:models:read": "Read access to models (list, get model details)",
|
||||
"llama:models:write": "Write access to models (register, unregister)",
|
||||
|
||||
# Agents API
|
||||
"llama:agents:read": "Read access to agents (list sessions, get agent details)",
|
||||
"llama:agents:write": "Write access to agents (create sessions, send messages)",
|
||||
|
||||
# Tools API
|
||||
"llama:tools": "Access to tool runtime and execution",
|
||||
|
||||
# Vector DB API
|
||||
"llama:vector_dbs:read": "Read access to vector databases",
|
||||
"llama:vector_dbs:write": "Write access to vector databases",
|
||||
|
||||
# Safety API
|
||||
"llama:safety": "Access to safety shields and content filtering",
|
||||
|
||||
# Eval API
|
||||
"llama:eval": "Access to evaluation and benchmarking",
|
||||
|
||||
# Administrative access
|
||||
"llama:admin": "Full administrative access to all APIs",
|
||||
}
|
||||
|
||||
|
||||
def get_required_scopes_for_api(api_name: str, method: str = "GET") -> Set[str]:
|
||||
"""Get required OAuth2 scopes for accessing a specific API endpoint.
|
||||
|
||||
Args:
|
||||
api_name: The name of the API (e.g., 'models', 'inference', 'agents')
|
||||
method: The HTTP method (GET, POST, PUT, DELETE)
|
||||
|
||||
Returns:
|
||||
Set of scope strings that would grant access to this endpoint.
|
||||
Always includes 'llama:admin' as it grants access to everything.
|
||||
"""
|
||||
# Admin scope grants access to everything
|
||||
required_scopes = {"llama:admin"}
|
||||
|
||||
# Map API names to required scopes
|
||||
if api_name in ["inference", "chat", "completion", "embeddings"]:
|
||||
required_scopes.add("llama:inference")
|
||||
elif api_name == "models":
|
||||
if method in ["POST", "PUT", "DELETE"]:
|
||||
required_scopes.add("llama:models:write")
|
||||
else:
|
||||
required_scopes.add("llama:models:read")
|
||||
elif api_name == "agents":
|
||||
if method in ["POST", "PUT", "DELETE"]:
|
||||
required_scopes.add("llama:agents:write")
|
||||
else:
|
||||
required_scopes.add("llama:agents:read")
|
||||
elif api_name in ["tools", "tool_runtime"]:
|
||||
required_scopes.add("llama:tools")
|
||||
elif api_name == "vector_dbs":
|
||||
if method in ["POST", "PUT", "DELETE"]:
|
||||
required_scopes.add("llama:vector_dbs:write")
|
||||
else:
|
||||
required_scopes.add("llama:vector_dbs:read")
|
||||
elif api_name == "safety":
|
||||
required_scopes.add("llama:safety")
|
||||
elif api_name in ["eval", "benchmarks", "scoring"]:
|
||||
required_scopes.add("llama:eval")
|
||||
|
||||
return required_scopes
|
||||
|
||||
|
||||
def validate_scopes(token_scopes: Set[str]) -> Set[str]:
|
||||
"""Validate OAuth2 scopes against standard Llama Stack scopes.
|
||||
|
||||
Args:
|
||||
token_scopes: Set of scopes from the OAuth2 token
|
||||
|
||||
Returns:
|
||||
Set of valid scopes that are recognized by Llama Stack
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid scopes are found
|
||||
"""
|
||||
valid_scopes = token_scopes.intersection(STANDARD_OAUTH2_SCOPES.keys())
|
||||
if not valid_scopes:
|
||||
raise ValueError("Token lacks required OAuth2 scopes for Llama Stack access")
|
||||
return valid_scopes
|
||||
|
||||
|
||||
def scope_grants_admin_access(scopes: Set[str]) -> bool:
|
||||
"""Check if the provided scopes include administrative access.
|
||||
|
||||
Args:
|
||||
scopes: Set of OAuth2 scopes
|
||||
|
||||
Returns:
|
||||
True if scopes include administrative access, False otherwise
|
||||
"""
|
||||
return "llama:admin" in scopes
|
||||
|
||||
|
||||
def get_all_scope_descriptions() -> dict[str, str]:
|
||||
"""Get all standard OAuth2 scopes with their descriptions.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping scope names to their descriptions
|
||||
"""
|
||||
return STANDARD_OAUTH2_SCOPES.copy()
|
||||
Loading…
Add table
Add a link
Reference in a new issue