diff --git a/config.yaml b/config.yaml index 294b9e8..4d1e2aa 100644 --- a/config.yaml +++ b/config.yaml @@ -50,9 +50,9 @@ demo: resource_identifier: http://localhost:3000 audience: mcp_proxy scopes_supported: - "read:tools" - "read:resources" - "read:prompts" + - "tools":"read:tools" + - "resources":"read:resources" + - "prompts":"read:prompts" authorization_servers: - https://api.asgardeo.io/t/acme/ jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks diff --git a/internal/authz/default_policy_engine.go b/internal/authz/default_policy_engine.go index d9efe12..6f002e6 100644 --- a/internal/authz/default_policy_engine.go +++ b/internal/authz/default_policy_engine.go @@ -4,6 +4,8 @@ import ( "fmt" "net/http" "strings" + + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type TokenClaims struct { @@ -16,30 +18,42 @@ type DefaultPolicyEngine struct{} func (d *DefaultPolicyEngine) Evaluate( _ *http.Request, claims *TokenClaims, - requiredScope string, + requiredScopes any, ) PolicyResult { - if strings.TrimSpace(requiredScope) == "" { + + logger.Info("Required scopes: %v", requiredScopes) + + var scopeStr string + switch v := requiredScopes.(type) { + case string: + scopeStr = v + case []string: + scopeStr = strings.Join(v, " ") + } + + if strings.TrimSpace(scopeStr) == "" { return PolicyResult{DecisionAllow, ""} } - raw := strings.FieldsFunc(requiredScope, func(r rune) bool { + scopes := strings.FieldsFunc(scopeStr, func(r rune) bool { return r == ' ' || r == ',' }) - want := make(map[string]struct{}, len(raw)) - for _, s := range raw { + required := make(map[string]struct{}, len(scopes)) + for _, s := range scopes { if s = strings.TrimSpace(s); s != "" { - want[s] = struct{}{} + required[s] = struct{}{} } } - for _, have := range claims.Scopes { - if _, ok := want[have]; ok { + logger.Info("Token scopes: %v", claims.Scopes) + for _, tokenScope := range claims.Scopes { + if _, ok := required[tokenScope]; ok { return PolicyResult{DecisionAllow, ""} } } var list []string - for s := range want { + for s := range required { list = append(list, s) } return PolicyResult{ diff --git a/internal/authz/policy.go b/internal/authz/policy.go index 793e7bc..5995250 100644 --- a/internal/authz/policy.go +++ b/internal/authz/policy.go @@ -5,15 +5,15 @@ import "net/http" type Decision int const ( - DecisionAllow Decision = iota - DecisionDeny + DecisionAllow Decision = iota + DecisionDeny ) type PolicyResult struct { - Decision Decision - Message string + Decision Decision + Message string } type PolicyEngine interface { - Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult + Evaluate(r *http.Request, claims *TokenClaims, requiredScopes any) PolicyResult } diff --git a/internal/config/config.go b/internal/config/config.go index aba479e..8c47d8e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -105,7 +105,7 @@ type Config struct { // Protected resource metadata Audience string `yaml:"audience"` ResourceIdentifier string `yaml:"resource_identifier"` - ScopesSupported map[string]string `yaml:"scopes_supported"` + ScopesSupported any `yaml:"scopes_supported"` AuthorizationServers []string `yaml:"authorization_servers"` JwksURI string `yaml:"jwks_uri,omitempty"` BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"` diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 9eb8729..377f164 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -175,20 +175,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, } isSSE = true } else { - claims, err := authorizeMCP(w, r, isLatestSpec, cfg) - if err != nil { + if err := authorizeMCP(w, r, isLatestSpec, cfg, policyEngine); err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } - - if isLatestSpec { - scope := cfg.ScopesSupported[r.URL.Path] - pr := policyEngine.Evaluate(r, claims, scope) - if pr.Decision == authz.DecisionDeny { - http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) - return - } - } } targetURL = mcpBase @@ -310,54 +300,54 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res } // Handles both v1 (just signature) and v2 (aud + scope) flows -func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) { - // Parse JSON-RPC request if present - if env, err := util.ParseRPCRequest(r); err != nil { - http.Error(w, "Bad request", http.StatusBadRequest) - return nil, err - } else if env != nil { - logger.Info("JSON-RPC method = %q", env.Method) +func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, policyEngine authz.PolicyEngine) error { + authzHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authzHeader, "Bearer ") { + if isLatestSpec { + realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer resource_metadata=%q`, realm, + )) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return fmt.Errorf("missing or invalid Authorization header") } - authzHeader := r.Header.Get("Authorization") - if !strings.HasPrefix(authzHeader, "Bearer ") { - if isLatestSpec { - realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" - w.Header().Set("WWW-Authenticate", fmt.Sprintf( - `Bearer resource_metadata=%q`, realm, - )) - w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") - } - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return nil, fmt.Errorf("missing or invalid Authorization header") - } + claims, err := util.ValidateJWT(isLatestSpec, authzHeader, cfg.Audience) + if err != nil { + if isLatestSpec { + realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(), + `Bearer realm=%q`, + realm, + )) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + http.Error(w, "Forbidden", http.StatusForbidden) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + return err + } - requiredScope := "" - if isLatestSpec { - requiredScope = cfg.ScopesSupported[r.URL.Path] - } - claims, err := util.ValidateJWT( - isLatestSpec, - authzHeader, - cfg.Audience, - requiredScope, - ) - if err != nil { - if isLatestSpec { - realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" - w.Header().Set("WWW-Authenticate", fmt.Sprintf( - `Bearer realm=%q, error="insufficient_scope", scope=%q`, - realm, requiredScope, - )) - w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") - http.Error(w, "Forbidden", http.StatusForbidden) - } else { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - } - return nil, err - } + if isLatestSpec { + env, err := util.ParseRPCRequest(r) + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return err + } + requiredScopes := util.GetRequiredScopes(cfg, env.Method) + if len(requiredScopes) == 0 { + return nil + } + pr := policyEngine.Evaluate(r, claims, requiredScopes) + if pr.Decision == authz.DecisionDeny { + http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) + return fmt.Errorf("forbidden — %s", pr.Message) + } + } - return claims, nil + return nil } func getAllowedOrigin(origin string, cfg *config.Config) string { diff --git a/internal/util/jwks.go b/internal/util/jwks.go index 40d72bf..0692057 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -11,7 +11,8 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/wso2/open-mcp-auth-proxy/internal/authz" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/config" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type TokenClaims struct { @@ -80,19 +81,20 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { } // ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. -// - versionHeader: the raw value of the "Mcp-Protocol-Version" header +// - isLatestSpec: whether to use the latest spec validation // - authHeader: the full "Authorization" header // - audience: the resource identifier to check "aud" against -// - requiredScope: the single scope required (empty ⇒ skip scope check) +// - requiredScopes: the scopes required (empty ⇒ skip scope check) func ValidateJWT( - isLatestSpec bool, authHeader, audience, requiredScope string, + isLatestSpec bool, + authHeader, audience string, ) (*authz.TokenClaims, error) { tokenStr := strings.TrimPrefix(authHeader, "Bearer ") if tokenStr == "" { return nil, errors.New("empty bearer token") } - // 2) parse & verify signature + // --- parse & verify signature --- token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -107,7 +109,6 @@ func ValidateJWT( } return key, nil }) - if err != nil { return nil, fmt.Errorf("invalid token: %w", err) } @@ -115,18 +116,19 @@ func ValidateJWT( return nil, errors.New("token not valid") } - // always extract claims + // --- extract raw claims --- claimsMap, ok := token.Claims.(jwt.MapClaims) if !ok { return nil, errors.New("unexpected claim type") } - // if older than cutover, skip audience+scope + // --- v1: skip audience check entirely --- if !isLatestSpec { + // we still want to return an empty set of scopes for policy to see return &authz.TokenClaims{Scopes: nil}, nil } - // --- new spec flow: enforce audience --- + // --- v2: enforce audience --- audRaw, exists := claimsMap["aud"] if !exists { return nil, errors.New("aud claim missing") @@ -151,25 +153,33 @@ func ValidateJWT( return nil, errors.New("aud claim has unexpected type") } - // if no scope required, we're done - if requiredScope == "" { - return &authz.TokenClaims{Scopes: nil}, nil + // --- collect all scopes from the token, if any --- + rawScope := claimsMap["scope"] + scopeList := []string{} + if s, ok := rawScope.(string); ok { + scopeList = strings.Fields(s) } - // enforce scope - rawScope, exists := claimsMap["scope"] - if !exists { - return nil, errors.New("scope claim missing") - } - scopeStr, ok := rawScope.(string) - if !ok { - return nil, errors.New("scope claim not a string") - } - scopes := strings.Fields(scopeStr) - for _, s := range scopes { - if s == requiredScope { - return &authz.TokenClaims{Scopes: scopes}, nil - } - } - return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes) + return &authz.TokenClaims{Scopes: scopeList}, nil +} + +// Process the required scopes +func GetRequiredScopes(cfg *config.Config, method string) []string { + if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 { + if scope, ok := scopes[method]; ok { + return []string{scope} + } + if parts := strings.SplitN(method, "/", 2); len(parts) > 0 { + if scope, ok := scopes[parts[0]]; ok { + return []string{scope} + } + } + return nil + } + + if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 { + return scopes + } + + return []string{} } diff --git a/internal/util/rpc.go b/internal/util/rpc.go index 896e9b2..5338437 100644 --- a/internal/util/rpc.go +++ b/internal/util/rpc.go @@ -34,6 +34,5 @@ func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) { return nil, err } - logger.Info("JSON-RPC method = %q", env.Method) return &env, nil }