feat: adding scope-based authorization

This commit is contained in:
Lance Galletti 2025-07-17 11:46:07 -04:00
parent 51b179e1c5
commit fe093918c2
6 changed files with 1082 additions and 9 deletions

View file

@ -10,11 +10,36 @@ import httpx
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
def extract_api_from_path(path: str) -> tuple[str, str]:
"""Extract API name and method from request path for scope validation"""
# Remove leading/trailing slashes and split
path = path.strip("/")
parts = path.split("/")
# Handle common API path patterns
if len(parts) >= 2 and parts[0] == "v1":
api_name = parts[1]
# Handle nested paths like /v1/models/{id} or /v1/inference/chat-completion
if api_name in ["inference", "models", "agents", "tools", "vector_dbs", "safety", "eval", "scoring"]:
return api_name, "POST" # Default to POST for scope checking
elif api_name == "openai":
# Handle OpenAI compatibility endpoints like /v1/openai/v1/chat/completions
if len(parts) >= 4:
return "inference", "POST" # OpenAI endpoints are typically inference
# Fallback - try to extract from first path component
if parts:
return parts[0], "POST"
return "unknown", "GET"
class AuthenticationMiddleware:
"""Middleware that authenticates requests using configured authentication provider.
@ -109,6 +134,34 @@ class AuthenticationMiddleware:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Validate OAuth2 scopes for the requested API endpoint
path = scope.get("path", "")
method = scope.get("method", "GET")
api_name, _ = extract_api_from_path(path)
# Get required scopes for this API endpoint
required_scopes = get_required_scopes_for_api(api_name, method)
# Check if user has any of the required scopes
user_scopes = set()
if validation_result.attributes and "scopes" in validation_result.attributes:
user_scopes = set(validation_result.attributes["scopes"])
# Verify user has at least one required scope
if not user_scopes.intersection(required_scopes):
logger.warning(
f"Access denied for {validation_result.principal} to {api_name} API. "
f"Required scopes: {required_scopes}, User scopes: {user_scopes}"
)
return await self._send_auth_error(
send, f"Insufficient OAuth2 scopes for {api_name} API. Required: {', '.join(required_scopes)}"
)
logger.debug(
f"OAuth2 scope validation passed for {validation_result.principal} "
f"on {api_name} API with scopes: {user_scopes.intersection(required_scopes)}"
)
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
@ -117,9 +170,8 @@ class AuthenticationMiddleware:
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
attr_count = len(validation_result.attributes) if validation_result.attributes else 0
logger.debug(f"Authentication successful: {validation_result.principal} with {attr_count} attributes")
return await self.app(scope, receive, send)

View file

@ -21,6 +21,10 @@ from llama_stack.distribution.datatypes import (
OAuth2TokenAuthConfig,
User,
)
from llama_stack.distribution.server.oauth2_scopes import (
scope_grants_admin_access,
validate_scopes,
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -128,13 +132,34 @@ class OAuth2TokenAuthProvider(AuthProvider):
except Exception as exc:
raise ValueError("Invalid JWT token") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
# Extract and validate OAuth2 scopes - deny by default if no valid scopes
token_scopes = set(claims.get("scope", "").split()) if claims.get("scope") else set()
# Validate scopes against standard Llama Stack scopes
valid_scopes = validate_scopes(token_scopes)
logger.info(f"User {claims['sub']} authenticated with scopes: {valid_scopes}")
# Convert scopes to user attributes for the existing ABAC system
attributes = {"scopes": list(valid_scopes)}
# Add admin role if user has admin scope
if scope_grants_admin_access(valid_scopes):
attributes["roles"] = ["admin"]
# Maintain backward compatibility with existing claims mapping
legacy_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
if legacy_attributes:
for key, values in legacy_attributes.items():
if key not in attributes:
attributes[key] = values
else:
# Merge lists, avoiding duplicates
attributes[key] = list(set(attributes[key] + values))
return User(
principal=principal,
attributes=access_attributes,
principal=claims["sub"],
attributes=attributes,
)
async def introspect_token(self, token: str, scope: dict | None = None) -> User:

View file

@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
OAuth2 scope definitions and validation for Llama Stack APIs.
This module defines the standard OAuth2 scopes that are built-in to Llama Stack
and provides utilities for scope validation. These scopes are not configurable
and provide a standardized way to control access to different API endpoints.
"""
from typing import Set
# Standard OAuth2 scopes for Llama Stack APIs (built-in, not configurable)
STANDARD_OAUTH2_SCOPES = {
# Inference API
"llama:inference": "Access to inference APIs (chat completion, embeddings)",
# Models API
"llama:models:read": "Read access to models (list, get model details)",
"llama:models:write": "Write access to models (register, unregister)",
# Agents API
"llama:agents:read": "Read access to agents (list sessions, get agent details)",
"llama:agents:write": "Write access to agents (create sessions, send messages)",
# Tools API
"llama:tools": "Access to tool runtime and execution",
# Vector DB API
"llama:vector_dbs:read": "Read access to vector databases",
"llama:vector_dbs:write": "Write access to vector databases",
# Safety API
"llama:safety": "Access to safety shields and content filtering",
# Eval API
"llama:eval": "Access to evaluation and benchmarking",
# Administrative access
"llama:admin": "Full administrative access to all APIs",
}
def get_required_scopes_for_api(api_name: str, method: str = "GET") -> Set[str]:
"""Get required OAuth2 scopes for accessing a specific API endpoint.
Args:
api_name: The name of the API (e.g., 'models', 'inference', 'agents')
method: The HTTP method (GET, POST, PUT, DELETE)
Returns:
Set of scope strings that would grant access to this endpoint.
Always includes 'llama:admin' as it grants access to everything.
"""
# Admin scope grants access to everything
required_scopes = {"llama:admin"}
# Map API names to required scopes
if api_name in ["inference", "chat", "completion", "embeddings"]:
required_scopes.add("llama:inference")
elif api_name == "models":
if method in ["POST", "PUT", "DELETE"]:
required_scopes.add("llama:models:write")
else:
required_scopes.add("llama:models:read")
elif api_name == "agents":
if method in ["POST", "PUT", "DELETE"]:
required_scopes.add("llama:agents:write")
else:
required_scopes.add("llama:agents:read")
elif api_name in ["tools", "tool_runtime"]:
required_scopes.add("llama:tools")
elif api_name == "vector_dbs":
if method in ["POST", "PUT", "DELETE"]:
required_scopes.add("llama:vector_dbs:write")
else:
required_scopes.add("llama:vector_dbs:read")
elif api_name == "safety":
required_scopes.add("llama:safety")
elif api_name in ["eval", "benchmarks", "scoring"]:
required_scopes.add("llama:eval")
return required_scopes
def validate_scopes(token_scopes: Set[str]) -> Set[str]:
"""Validate OAuth2 scopes against standard Llama Stack scopes.
Args:
token_scopes: Set of scopes from the OAuth2 token
Returns:
Set of valid scopes that are recognized by Llama Stack
Raises:
ValueError: If no valid scopes are found
"""
valid_scopes = token_scopes.intersection(STANDARD_OAUTH2_SCOPES.keys())
if not valid_scopes:
raise ValueError("Token lacks required OAuth2 scopes for Llama Stack access")
return valid_scopes
def scope_grants_admin_access(scopes: Set[str]) -> bool:
"""Check if the provided scopes include administrative access.
Args:
scopes: Set of OAuth2 scopes
Returns:
True if scopes include administrative access, False otherwise
"""
return "llama:admin" in scopes
def get_all_scope_descriptions() -> dict[str, str]:
"""Get all standard OAuth2 scopes with their descriptions.
Returns:
Dictionary mapping scope names to their descriptions
"""
return STANDARD_OAUTH2_SCOPES.copy()