feat: adding scope-based authorization

This commit is contained in:
Lance Galletti 2025-07-17 11:46:07 -04:00
parent 51b179e1c5
commit fe093918c2
6 changed files with 1082 additions and 9 deletions

View file

@ -10,11 +10,36 @@ import httpx
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.oauth2_scopes import get_required_scopes_for_api
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
def extract_api_from_path(path: str) -> tuple[str, str]:
"""Extract API name and method from request path for scope validation"""
# Remove leading/trailing slashes and split
path = path.strip("/")
parts = path.split("/")
# Handle common API path patterns
if len(parts) >= 2 and parts[0] == "v1":
api_name = parts[1]
# Handle nested paths like /v1/models/{id} or /v1/inference/chat-completion
if api_name in ["inference", "models", "agents", "tools", "vector_dbs", "safety", "eval", "scoring"]:
return api_name, "POST" # Default to POST for scope checking
elif api_name == "openai":
# Handle OpenAI compatibility endpoints like /v1/openai/v1/chat/completions
if len(parts) >= 4:
return "inference", "POST" # OpenAI endpoints are typically inference
# Fallback - try to extract from first path component
if parts:
return parts[0], "POST"
return "unknown", "GET"
class AuthenticationMiddleware:
"""Middleware that authenticates requests using configured authentication provider.
@ -109,6 +134,34 @@ class AuthenticationMiddleware:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Validate OAuth2 scopes for the requested API endpoint
path = scope.get("path", "")
method = scope.get("method", "GET")
api_name, _ = extract_api_from_path(path)
# Get required scopes for this API endpoint
required_scopes = get_required_scopes_for_api(api_name, method)
# Check if user has any of the required scopes
user_scopes = set()
if validation_result.attributes and "scopes" in validation_result.attributes:
user_scopes = set(validation_result.attributes["scopes"])
# Verify user has at least one required scope
if not user_scopes.intersection(required_scopes):
logger.warning(
f"Access denied for {validation_result.principal} to {api_name} API. "
f"Required scopes: {required_scopes}, User scopes: {user_scopes}"
)
return await self._send_auth_error(
send, f"Insufficient OAuth2 scopes for {api_name} API. Required: {', '.join(required_scopes)}"
)
logger.debug(
f"OAuth2 scope validation passed for {validation_result.principal} "
f"on {api_name} API with scopes: {user_scopes.intersection(required_scopes)}"
)
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
@ -117,9 +170,8 @@ class AuthenticationMiddleware:
scope["principal"] = validation_result.principal
if validation_result.attributes:
scope["user_attributes"] = validation_result.attributes
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
attr_count = len(validation_result.attributes) if validation_result.attributes else 0
logger.debug(f"Authentication successful: {validation_result.principal} with {attr_count} attributes")
return await self.app(scope, receive, send)