From 312a5557f084a36c4f6b5e9acc453cfa86f70c45 Mon Sep 17 00:00:00 2001 From: NipuniBhagya Date: Wed, 14 May 2025 21:47:15 +0530 Subject: [PATCH] Fix audience validation issues --- cmd/proxy/main.go | 2 +- internal/authz/default.go | 1 + internal/authz/default_policy_engine.go | 39 +++++++++++-- internal/config/config.go | 1 + internal/proxy/proxy.go | 78 +++++++++++++++---------- internal/util/jwks.go | 54 +++++------------ internal/util/rpc.go | 39 +++++++++++++ internal/util/version.go | 25 ++++++++ 8 files changed, 163 insertions(+), 76 deletions(-) create mode 100644 internal/util/rpc.go create mode 100644 internal/util/version.go diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index f24c21d..562e7aa 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -93,7 +93,7 @@ func main() { } // 5. (Optional) Build the policy engine - engine := &authz.DefaulPolicyEngine{} + engine := &authz.DefaultPolicyEngine{} // 6. Build the main router mux := proxy.NewRouter(cfg, provider, engine) diff --git a/internal/authz/default.go b/internal/authz/default.go index f4d640d..dc8900d 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -99,6 +99,7 @@ func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") meta := map[string]interface{}{ + "audience": p.cfg.Audience, "resource": p.cfg.ResourceIdentifier, "scopes_supported": p.cfg.ScopesSupported, "authorization_servers": p.cfg.AuthorizationServers, diff --git a/internal/authz/default_policy_engine.go b/internal/authz/default_policy_engine.go index efc23d2..d9efe12 100644 --- a/internal/authz/default_policy_engine.go +++ b/internal/authz/default_policy_engine.go @@ -1,20 +1,49 @@ package authz import ( + "fmt" "net/http" + "strings" ) type TokenClaims struct { Scopes []string } -type DefaulPolicyEngine struct{} +type DefaultPolicyEngine struct{} -func (d *DefaulPolicyEngine) Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult { - for _, scope := range claims.Scopes { - if scope == requiredScope { +// Evaluate and checks the token claims against one or more required scopes. +func (d *DefaultPolicyEngine) Evaluate( + _ *http.Request, + claims *TokenClaims, + requiredScope string, +) PolicyResult { + if strings.TrimSpace(requiredScope) == "" { + return PolicyResult{DecisionAllow, ""} + } + + raw := strings.FieldsFunc(requiredScope, func(r rune) bool { + return r == ' ' || r == ',' + }) + want := make(map[string]struct{}, len(raw)) + for _, s := range raw { + if s = strings.TrimSpace(s); s != "" { + want[s] = struct{}{} + } + } + + for _, have := range claims.Scopes { + if _, ok := want[have]; ok { return PolicyResult{DecisionAllow, ""} } } - return PolicyResult{DecisionDeny, "missing scope '" + requiredScope + "'"} + + var list []string + for s := range want { + list = append(list, s) + } + return PolicyResult{ + DecisionDeny, + fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")), + } } diff --git a/internal/config/config.go b/internal/config/config.go index 47778d0..aba479e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -103,6 +103,7 @@ type Config struct { Default DefaultConfig `yaml:"default"` // Protected resource metadata + Audience string `yaml:"audience"` ResourceIdentifier string `yaml:"resource_identifier"` ScopesSupported map[string]string `yaml:"scopes_supported"` AuthorizationServers []string `yaml:"authorization_servers"` diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 220f13e..9eb8729 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -11,7 +11,6 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/constants" logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -157,9 +156,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, // Add CORS headers to all responses addCORSHeaders(w, cfg, allowedOrigin, "") - versionRaw := r.Header.Get("MCP-Protocol-Version") - ver, err := time.Parse(constants.TimeLayout, versionRaw) - isLatestSpec := err == nil && !ver.Before(constants.SpecCutoverDate) + // Check if the request is for the latest spec + specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version")) + ver, err := util.ParseVersionDate(specVersion) + isLatestSpec := util.IsLatestSpec(ver, err) // Decide whether the request should go to the auth server or MCP var targetURL *url.URL @@ -311,35 +311,53 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res // Handles both v1 (just signature) and v2 (aud + scope) flows func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) { - logger.Info("authorizeMCP") - h := r.Header.Get("Authorization") - audience := cfg.ResourceIdentifier - if isLatestSpec { - required := cfg.ScopesSupported[r.URL.Path] - claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, required) - logger.Info("claims: %v", claims) - logger.Info("err: %v", err) - if err != nil { - w.Header().Set( - "WWW-Authenticate", - fmt.Sprintf( - `Bearer realm="%s", error="insufficient_scope", scope="%s"`, - cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource", - required, - ), - ) - w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") - return nil, fmt.Errorf("forbidden — insufficient scope") - } - return claims, nil - } - - // v1: only check signature, then continue - if err := util.ValidateJWTLegacy(h); err != nil { + // Parse JSON-RPC request if present + if env, err := util.ParseRPCRequest(r); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) return nil, err + } else if env != nil { + logger.Info("JSON-RPC method = %q", env.Method) } - return &authz.TokenClaims{}, nil + authzHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authzHeader, "Bearer ") { + if isLatestSpec { + realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer resource_metadata=%q`, realm, + )) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return nil, fmt.Errorf("missing or invalid Authorization header") + } + + requiredScope := "" + if isLatestSpec { + requiredScope = cfg.ScopesSupported[r.URL.Path] + } + claims, err := util.ValidateJWT( + isLatestSpec, + authzHeader, + cfg.Audience, + requiredScope, + ) + if err != nil { + if isLatestSpec { + realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer realm=%q, error="insufficient_scope", scope=%q`, + realm, requiredScope, + )) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + http.Error(w, "Forbidden", http.StatusForbidden) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + return nil, err + } + + return claims, nil } func getAllowedOrigin(origin string, cfg *config.Config) string { diff --git a/internal/util/jwks.go b/internal/util/jwks.go index a050427..40d72bf 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -8,12 +8,10 @@ import ( "math/big" "net/http" "strings" - "time" "github.com/golang-jwt/jwt/v4" "github.com/wso2/open-mcp-auth-proxy/internal/authz" - "github.com/wso2/open-mcp-auth-proxy/internal/constants" - logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type TokenClaims struct { @@ -87,7 +85,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { // - audience: the resource identifier to check "aud" against // - requiredScope: the single scope required (empty ⇒ skip scope check) func ValidateJWT( - versionHeader, authHeader, audience, requiredScope string, + isLatestSpec bool, authHeader, audience, requiredScope string, ) (*authz.TokenClaims, error) { tokenStr := strings.TrimPrefix(authHeader, "Bearer ") if tokenStr == "" { @@ -96,17 +94,20 @@ func ValidateJWT( // 2) parse & verify signature token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { - kid, _ := token.Header["kid"].(string) - pk, ok := publicKeys[kid] - if !ok { - return nil, fmt.Errorf("unknown kid %q", kid) + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - return pk, nil + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, errors.New("kid header not found") + } + key, ok := publicKeys[kid] + if !ok { + return nil, fmt.Errorf("key not found for kid: %s", kid) + } + return key, nil }) - logger.Info("token: %v", token) - logger.Info("err: %v", err) - if err != nil { return nil, fmt.Errorf("invalid token: %w", err) } @@ -120,15 +121,8 @@ func ValidateJWT( return nil, errors.New("unexpected claim type") } - // parse version date - verDate, err := time.Parse("2006-01-02", versionHeader) - if err != nil { - // if unparsable or missing, assume _old_ spec - verDate = time.Time{} // zero time ⇒ before cutover - } - // if older than cutover, skip audience+scope - if verDate.Before(constants.SpecCutoverDate) { + if !isLatestSpec { return &authz.TokenClaims{Scopes: nil}, nil } @@ -179,23 +173,3 @@ func ValidateJWT( } return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes) } - -// Performs basic JWT validation -func ValidateJWTLegacy(authHeader string) error { - tokenString := strings.TrimPrefix(authHeader, "Bearer ") - _, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - kid, ok := token.Header["kid"].(string) - if !ok { - return nil, errors.New("kid header not found") - } - key, ok := publicKeys[kid] - if !ok { - return nil, fmt.Errorf("key not found for kid: %s", kid) - } - return key, nil - }) - return err -} diff --git a/internal/util/rpc.go b/internal/util/rpc.go new file mode 100644 index 0000000..896e9b2 --- /dev/null +++ b/internal/util/rpc.go @@ -0,0 +1,39 @@ +package util + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" +) + +type RPCEnvelope struct { + Method string `json:"method"` + Params any `json:"params"` + ID any `json:"id"` +} + +// This function parses a JSON-RPC request from an HTTP request body +func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + if len(bodyBytes) == 0 { + return nil, nil + } + + var env RPCEnvelope + dec := json.NewDecoder(bytes.NewReader(bodyBytes)) + if err := dec.Decode(&env); err != nil && err != io.EOF { + logger.Warn("Error parsing JSON-RPC envelope: %v", err) + return nil, err + } + + logger.Info("JSON-RPC method = %q", env.Method) + return &env, nil +} diff --git a/internal/util/version.go b/internal/util/version.go new file mode 100644 index 0000000..acf381a --- /dev/null +++ b/internal/util/version.go @@ -0,0 +1,25 @@ +package util + +import ( + "time" + + "github.com/wso2/open-mcp-auth-proxy/internal/constants" +) + +// This function checks if the given version date is after the spec cutover date +func IsLatestSpec(versionDate time.Time, err error) bool { + return err == nil && !versionDate.Before(constants.SpecCutoverDate) +} + +// This function parses a version string into a time.Time +func ParseVersionDate(version string) (time.Time, error) { + return time.Parse("2006-01-02", version) +} + +// This function returns the version string, using the cutover date if empty +func GetVersionWithDefault(version string) string { + if version == "" { + return constants.SpecCutoverDate.Format("2006-01-02") + } + return version +}