From 331cc281c6a70f4f7c07a5c6517181c3fc087ab8 Mon Sep 17 00:00:00 2001 From: NipuniBhagya Date: Wed, 14 May 2025 15:39:02 +0530 Subject: [PATCH] Refactor proxy builder --- config.yaml | 13 +++ internal/authz/asgardeo.go | 20 ++++ internal/authz/default_policy_engine.go | 20 ++++ internal/proxy/proxy.go | 40 ++++--- internal/util/jwks.go | 142 ++++++++++++++++++++---- 5 files changed, 200 insertions(+), 35 deletions(-) create mode 100644 internal/authz/default_policy_engine.go diff --git a/config.yaml b/config.yaml index 5621195..7d7520d 100644 --- a/config.yaml +++ b/config.yaml @@ -45,3 +45,16 @@ demo: org_name: "openmcpauthdemo" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +# Protected resource metadata +resource_identifier: http://localhost:3000 +scopes_supported: + - get-alerts + - get-forecast +authorization_servers: + - https://idp.example.com +jwks_uri: https://idp.example.com/.well-known/jwks.json +bearer_methods_supported: + - header + - body + - query \ No newline at end of file diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index a3c812c..dc433b4 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -336,3 +336,23 @@ func randomString(n int) string { } return string(b) } + +func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + meta := map[string]interface{}{ + "resource": p.cfg.ResourceIdentifier, + "scopes_supported": p.cfg.ScopesSupported, + "authorization_servers": p.cfg.AuthorizationServers, + } + if p.cfg.JwksURI != "" { + meta["jwks_uri"] = p.cfg.JwksURI + } + if len(p.cfg.BearerMethodsSupported) > 0 { + meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported + } + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, "failed to encode metadata", http.StatusInternalServerError) + } + } +} diff --git a/internal/authz/default_policy_engine.go b/internal/authz/default_policy_engine.go new file mode 100644 index 0000000..efc23d2 --- /dev/null +++ b/internal/authz/default_policy_engine.go @@ -0,0 +1,20 @@ +package authz + +import ( + "net/http" +) + +type TokenClaims struct { + Scopes []string +} + +type DefaulPolicyEngine struct{} + +func (d *DefaulPolicyEngine) Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult { + for _, scope := range claims.Scopes { + if scope == requiredScope { + return PolicyResult{DecisionAllow, ""} + } + } + return PolicyResult{DecisionDeny, "missing scope '" + requiredScope + "'"} +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 682867e..220f13e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -177,7 +177,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, } else { claims, err := authorizeMCP(w, r, isLatestSpec, cfg) if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) + http.Error(w, err.Error(), http.StatusForbidden) return } @@ -227,13 +227,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, req.Host = targetURL.Host cleanHeaders := http.Header{} - + // Set proper origin header to match the target if isSSE { // For SSE, ensure origin matches the target req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) } - + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -277,7 +277,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, proxyHost: r.Host, targetHost: targetURL.Host, } - + // Set SSE-specific headers w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("Cache-Control", "no-cache") @@ -296,15 +296,14 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, // Check if the request is for SSE handshake and authorize it func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error { - h := r.Header.Get("Authorization") - if !strings.HasPrefix(h, "Bearer ") { + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { if isLatestSpec { realm := resourceID + "/.well-known/oauth-protected-resource" w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm)) w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") } - - return fmt.Errorf("missing bearer token") + return fmt.Errorf("missing or invalid Authorization header") } return nil @@ -312,23 +311,31 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res // Handles both v1 (just signature) and v2 (aud + scope) flows func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) { + logger.Info("authorizeMCP") h := r.Header.Get("Authorization") audience := cfg.ResourceIdentifier if isLatestSpec { - scope := cfg.ScopesSupported[r.URL.Path] - claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, scope) + required := cfg.ScopesSupported[r.URL.Path] + claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, required) + logger.Info("claims: %v", claims) + logger.Info("err: %v", err) if err != nil { - realm := audience + "/.well-known/oauth-protected-resource" - w.Header().Set("WWW-Authenticate", - fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope", scope="%s"`, realm, scope)) + w.Header().Set( + "WWW-Authenticate", + fmt.Sprintf( + `Bearer realm="%s", error="insufficient_scope", scope="%s"`, + cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource", + required, + ), + ) w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") - return nil, err + return nil, fmt.Errorf("forbidden — insufficient scope") } return claims, nil } // v1: only check signature, then continue - if err := util.ValidateJWTOld(h); err != nil { + if err := util.ValidateJWTLegacy(h); err != nil { return nil, err } @@ -352,7 +359,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) - w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version") if requestHeaders != "" { w.Header().Set("Access-Control-Allow-Headers", requestHeaders) } else { @@ -360,6 +367,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re } if cfg.CORSConfig.AllowCredentials { w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("MCP-Protocol-Version", ", ") } w.Header().Set("Vary", "Origin") w.Header().Set("X-Accel-Buffering", "no") diff --git a/internal/util/jwks.go b/internal/util/jwks.go index f80d82e..a050427 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -4,21 +4,29 @@ import ( "crypto/rsa" "encoding/json" "errors" + "fmt" "math/big" "net/http" "strings" + "time" "github.com/golang-jwt/jwt/v4" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/authz" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) +type TokenClaims struct { + Scopes []string +} + type JWKS struct { Keys []json.RawMessage `json:"keys"` } var publicKeys map[string]*rsa.PublicKey -// FetchJWKS downloads JWKS and stores in a package-level map +// FetchJWKS downloads JWKS and stores in a package‐level map func FetchJWKS(jwksURL string) error { resp, err := http.Get(jwksURL) if err != nil { @@ -31,23 +39,23 @@ func FetchJWKS(jwksURL string) error { return err } - publicKeys = make(map[string]*rsa.PublicKey) + publicKeys = make(map[string]*rsa.PublicKey, len(jwks.Keys)) for _, keyData := range jwks.Keys { - var parsedKey struct { + var parsed struct { Kid string `json:"kid"` N string `json:"n"` E string `json:"e"` Kty string `json:"kty"` } - if err := json.Unmarshal(keyData, &parsedKey); err != nil { + if err := json.Unmarshal(keyData, &parsed); err != nil { continue } - if parsedKey.Kty != "RSA" { + if parsed.Kty != "RSA" { continue } - pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E) + pk, err := parseRSAPublicKey(parsed.N, parsed.E) if err == nil { - publicKeys[parsedKey.Kid] = pubKey + publicKeys[parsed.Kid] = pk } } logger.Info("Loaded %d public keys.", len(publicKeys)) @@ -73,25 +81,121 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { return &rsa.PublicKey{N: n, E: e}, nil } -// ValidateJWT checks the Authorization: Bearer token using stored JWKS -func ValidateJWT(authHeader string) error { - if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { - return errors.New("missing or invalid Authorization header") - } +// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. +// - versionHeader: the raw value of the "Mcp-Protocol-Version" header +// - authHeader: the full "Authorization" header +// - audience: the resource identifier to check "aud" against +// - requiredScope: the single scope required (empty ⇒ skip scope check) +func ValidateJWT( + versionHeader, authHeader, audience, requiredScope string, +) (*authz.TokenClaims, error) { tokenStr := strings.TrimPrefix(authHeader, "Bearer ") + if tokenStr == "" { + return nil, errors.New("empty bearer token") + } + + // 2) parse & verify signature token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { kid, _ := token.Header["kid"].(string) - pubKey, ok := publicKeys[kid] + pk, ok := publicKeys[kid] if !ok { - return nil, errors.New("unknown or missing kid in token header") + return nil, fmt.Errorf("unknown kid %q", kid) } - return pubKey, nil + return pk, nil }) + + logger.Info("token: %v", token) + logger.Info("err: %v", err) + if err != nil { - return errors.New("invalid token: " + err.Error()) + return nil, fmt.Errorf("invalid token: %w", err) } if !token.Valid { - return errors.New("invalid token: token not valid") + return nil, errors.New("token not valid") } - return nil + + // always extract claims + claimsMap, ok := token.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("unexpected claim type") + } + + // parse version date + verDate, err := time.Parse("2006-01-02", versionHeader) + if err != nil { + // if unparsable or missing, assume _old_ spec + verDate = time.Time{} // zero time ⇒ before cutover + } + + // if older than cutover, skip audience+scope + if verDate.Before(constants.SpecCutoverDate) { + return &authz.TokenClaims{Scopes: nil}, nil + } + + // --- new spec flow: enforce audience --- + audRaw, exists := claimsMap["aud"] + if !exists { + return nil, errors.New("aud claim missing") + } + switch v := audRaw.(type) { + case string: + if v != audience { + return nil, fmt.Errorf("aud %q does not match %q", v, audience) + } + case []interface{}: + var found bool + for _, a := range v { + if s, ok := a.(string); ok && s == audience { + found = true + break + } + } + if !found { + return nil, fmt.Errorf("audience %v does not include %q", v, audience) + } + default: + return nil, errors.New("aud claim has unexpected type") + } + + // if no scope required, we're done + if requiredScope == "" { + return &authz.TokenClaims{Scopes: nil}, nil + } + + // enforce scope + rawScope, exists := claimsMap["scope"] + if !exists { + return nil, errors.New("scope claim missing") + } + scopeStr, ok := rawScope.(string) + if !ok { + return nil, errors.New("scope claim not a string") + } + scopes := strings.Fields(scopeStr) + for _, s := range scopes { + if s == requiredScope { + return &authz.TokenClaims{Scopes: scopes}, nil + } + } + return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes) +} + +// Performs basic JWT validation +func ValidateJWTLegacy(authHeader string) error { + tokenString := strings.TrimPrefix(authHeader, "Bearer ") + _, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, errors.New("kid header not found") + } + key, ok := publicKeys[kid] + if !ok { + return nil, fmt.Errorf("key not found for kid: %s", kid) + } + return key, nil + }) + return err }