Refactor scope validation

This commit is contained in:
NipuniBhagya 2025-05-15 01:20:29 +05:30
parent ed525dc7b5
commit 7d64cc4093
7 changed files with 115 additions and 102 deletions

View file

@ -50,9 +50,9 @@ demo:
resource_identifier: http://localhost:3000 resource_identifier: http://localhost:3000
audience: mcp_proxy audience: mcp_proxy
scopes_supported: scopes_supported:
"read:tools" - "tools":"read:tools"
"read:resources" - "resources":"read:resources"
"read:prompts" - "prompts":"read:prompts"
authorization_servers: authorization_servers:
- https://api.asgardeo.io/t/acme/ - https://api.asgardeo.io/t/acme/
jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks

View file

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type TokenClaims struct { type TokenClaims struct {
@ -16,30 +18,42 @@ type DefaultPolicyEngine struct{}
func (d *DefaultPolicyEngine) Evaluate( func (d *DefaultPolicyEngine) Evaluate(
_ *http.Request, _ *http.Request,
claims *TokenClaims, claims *TokenClaims,
requiredScope string, requiredScopes any,
) PolicyResult { ) PolicyResult {
if strings.TrimSpace(requiredScope) == "" {
logger.Info("Required scopes: %v", requiredScopes)
var scopeStr string
switch v := requiredScopes.(type) {
case string:
scopeStr = v
case []string:
scopeStr = strings.Join(v, " ")
}
if strings.TrimSpace(scopeStr) == "" {
return PolicyResult{DecisionAllow, ""} return PolicyResult{DecisionAllow, ""}
} }
raw := strings.FieldsFunc(requiredScope, func(r rune) bool { scopes := strings.FieldsFunc(scopeStr, func(r rune) bool {
return r == ' ' || r == ',' return r == ' ' || r == ','
}) })
want := make(map[string]struct{}, len(raw)) required := make(map[string]struct{}, len(scopes))
for _, s := range raw { for _, s := range scopes {
if s = strings.TrimSpace(s); s != "" { if s = strings.TrimSpace(s); s != "" {
want[s] = struct{}{} required[s] = struct{}{}
} }
} }
for _, have := range claims.Scopes { logger.Info("Token scopes: %v", claims.Scopes)
if _, ok := want[have]; ok { for _, tokenScope := range claims.Scopes {
if _, ok := required[tokenScope]; ok {
return PolicyResult{DecisionAllow, ""} return PolicyResult{DecisionAllow, ""}
} }
} }
var list []string var list []string
for s := range want { for s := range required {
list = append(list, s) list = append(list, s)
} }
return PolicyResult{ return PolicyResult{

View file

@ -5,15 +5,15 @@ import "net/http"
type Decision int type Decision int
const ( const (
DecisionAllow Decision = iota DecisionAllow Decision = iota
DecisionDeny DecisionDeny
) )
type PolicyResult struct { type PolicyResult struct {
Decision Decision Decision Decision
Message string Message string
} }
type PolicyEngine interface { type PolicyEngine interface {
Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult Evaluate(r *http.Request, claims *TokenClaims, requiredScopes any) PolicyResult
} }

View file

@ -105,7 +105,7 @@ type Config struct {
// Protected resource metadata // Protected resource metadata
Audience string `yaml:"audience"` Audience string `yaml:"audience"`
ResourceIdentifier string `yaml:"resource_identifier"` ResourceIdentifier string `yaml:"resource_identifier"`
ScopesSupported map[string]string `yaml:"scopes_supported"` ScopesSupported any `yaml:"scopes_supported"`
AuthorizationServers []string `yaml:"authorization_servers"` AuthorizationServers []string `yaml:"authorization_servers"`
JwksURI string `yaml:"jwks_uri,omitempty"` JwksURI string `yaml:"jwks_uri,omitempty"`
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"` BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`

View file

@ -175,20 +175,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
} }
isSSE = true isSSE = true
} else { } else {
claims, err := authorizeMCP(w, r, isLatestSpec, cfg) if err := authorizeMCP(w, r, isLatestSpec, cfg, policyEngine); err != nil {
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden) http.Error(w, err.Error(), http.StatusForbidden)
return return
} }
if isLatestSpec {
scope := cfg.ScopesSupported[r.URL.Path]
pr := policyEngine.Evaluate(r, claims, scope)
if pr.Decision == authz.DecisionDeny {
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
return
}
}
} }
targetURL = mcpBase targetURL = mcpBase
@ -310,54 +300,54 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
} }
// Handles both v1 (just signature) and v2 (aud + scope) flows // Handles both v1 (just signature) and v2 (aud + scope) flows
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) { func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, policyEngine authz.PolicyEngine) error {
// Parse JSON-RPC request if present authzHeader := r.Header.Get("Authorization")
if env, err := util.ParseRPCRequest(r); err != nil { if !strings.HasPrefix(authzHeader, "Bearer ") {
http.Error(w, "Bad request", http.StatusBadRequest) if isLatestSpec {
return nil, err realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
} else if env != nil { w.Header().Set("WWW-Authenticate", fmt.Sprintf(
logger.Info("JSON-RPC method = %q", env.Method) `Bearer resource_metadata=%q`, realm,
))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
}
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return fmt.Errorf("missing or invalid Authorization header")
} }
authzHeader := r.Header.Get("Authorization") claims, err := util.ValidateJWT(isLatestSpec, authzHeader, cfg.Audience)
if !strings.HasPrefix(authzHeader, "Bearer ") { if err != nil {
if isLatestSpec { if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
w.Header().Set("WWW-Authenticate", fmt.Sprintf( w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
`Bearer resource_metadata=%q`, realm, `Bearer realm=%q`,
)) realm,
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") ))
} w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Forbidden", http.StatusForbidden)
return nil, fmt.Errorf("missing or invalid Authorization header") } else {
} http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
return err
}
requiredScope := "" if isLatestSpec {
if isLatestSpec { env, err := util.ParseRPCRequest(r)
requiredScope = cfg.ScopesSupported[r.URL.Path] if err != nil {
} http.Error(w, "Bad request", http.StatusBadRequest)
claims, err := util.ValidateJWT( return err
isLatestSpec, }
authzHeader, requiredScopes := util.GetRequiredScopes(cfg, env.Method)
cfg.Audience, if len(requiredScopes) == 0 {
requiredScope, return nil
) }
if err != nil { pr := policyEngine.Evaluate(r, claims, requiredScopes)
if isLatestSpec { if pr.Decision == authz.DecisionDeny {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
w.Header().Set("WWW-Authenticate", fmt.Sprintf( return fmt.Errorf("forbidden — %s", pr.Message)
`Bearer realm=%q, error="insufficient_scope", scope=%q`, }
realm, requiredScope, }
))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
http.Error(w, "Forbidden", http.StatusForbidden)
} else {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
return nil, err
}
return claims, nil return nil
} }
func getAllowedOrigin(origin string, cfg *config.Config) string { func getAllowedOrigin(origin string, cfg *config.Config) string {

View file

@ -11,7 +11,8 @@ import (
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/config"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type TokenClaims struct { type TokenClaims struct {
@ -80,19 +81,20 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
} }
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. // ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
// - versionHeader: the raw value of the "Mcp-Protocol-Version" header // - isLatestSpec: whether to use the latest spec validation
// - authHeader: the full "Authorization" header // - authHeader: the full "Authorization" header
// - audience: the resource identifier to check "aud" against // - audience: the resource identifier to check "aud" against
// - requiredScope: the single scope required (empty ⇒ skip scope check) // - requiredScopes: the scopes required (empty ⇒ skip scope check)
func ValidateJWT( func ValidateJWT(
isLatestSpec bool, authHeader, audience, requiredScope string, isLatestSpec bool,
authHeader, audience string,
) (*authz.TokenClaims, error) { ) (*authz.TokenClaims, error) {
tokenStr := strings.TrimPrefix(authHeader, "Bearer ") tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
if tokenStr == "" { if tokenStr == "" {
return nil, errors.New("empty bearer token") return nil, errors.New("empty bearer token")
} }
// 2) parse & verify signature // --- parse & verify signature ---
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
@ -107,7 +109,6 @@ func ValidateJWT(
} }
return key, nil return key, nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid token: %w", err) return nil, fmt.Errorf("invalid token: %w", err)
} }
@ -115,18 +116,19 @@ func ValidateJWT(
return nil, errors.New("token not valid") return nil, errors.New("token not valid")
} }
// always extract claims // --- extract raw claims ---
claimsMap, ok := token.Claims.(jwt.MapClaims) claimsMap, ok := token.Claims.(jwt.MapClaims)
if !ok { if !ok {
return nil, errors.New("unexpected claim type") return nil, errors.New("unexpected claim type")
} }
// if older than cutover, skip audience+scope // --- v1: skip audience check entirely ---
if !isLatestSpec { if !isLatestSpec {
// we still want to return an empty set of scopes for policy to see
return &authz.TokenClaims{Scopes: nil}, nil return &authz.TokenClaims{Scopes: nil}, nil
} }
// --- new spec flow: enforce audience --- // --- v2: enforce audience ---
audRaw, exists := claimsMap["aud"] audRaw, exists := claimsMap["aud"]
if !exists { if !exists {
return nil, errors.New("aud claim missing") return nil, errors.New("aud claim missing")
@ -151,25 +153,33 @@ func ValidateJWT(
return nil, errors.New("aud claim has unexpected type") return nil, errors.New("aud claim has unexpected type")
} }
// if no scope required, we're done // --- collect all scopes from the token, if any ---
if requiredScope == "" { rawScope := claimsMap["scope"]
return &authz.TokenClaims{Scopes: nil}, nil scopeList := []string{}
if s, ok := rawScope.(string); ok {
scopeList = strings.Fields(s)
} }
// enforce scope return &authz.TokenClaims{Scopes: scopeList}, nil
rawScope, exists := claimsMap["scope"] }
if !exists {
return nil, errors.New("scope claim missing") // Process the required scopes
} func GetRequiredScopes(cfg *config.Config, method string) []string {
scopeStr, ok := rawScope.(string) if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
if !ok { if scope, ok := scopes[method]; ok {
return nil, errors.New("scope claim not a string") return []string{scope}
} }
scopes := strings.Fields(scopeStr) if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
for _, s := range scopes { if scope, ok := scopes[parts[0]]; ok {
if s == requiredScope { return []string{scope}
return &authz.TokenClaims{Scopes: scopes}, nil }
} }
} return nil
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes) }
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
return scopes
}
return []string{}
} }

View file

@ -34,6 +34,5 @@ func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
return nil, err return nil, err
} }
logger.Info("JSON-RPC method = %q", env.Method)
return &env, nil return &env, nil
} }