diff --git a/README.md b/README.md index 7e0b35e..f15aca8 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,7 @@ A lightweight authorization proxy for Model Context Protocol (MCP) servers that | Version | Behavior | | :-------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | 2025-03-26 | Only signature check of Bearer JWT on both `/sse` and `/message`
No scope or audience enforcement | -| Latest(draft) | Read `MCP-Protocol-Version` from client header
SSE handshake returns `WWW-Authenticate: Bearer resource_metadata="…"`
`/message` enforces:
1. `aud` claim == `ResourceIdentifier`
2. `scope` claim contains per-path `requiredScope`
3. PolicyEngine decision
Rich `WWW-Authenticate` on 401s
Serves `/​.well-known/oauth-protected-resource` JSON | - -> ⚠️ **Note:** MCP v2 support is available **only in SSE mode**. The stdio mode supports only v1. +| Latest(draft) | Read `MCP-Protocol-Version` from client header
SSE handshake returns `WWW-Authenticate: Bearer resource_metadata="…"`
`/message` enforces:
`aud` claim == `ResourceIdentifier`
`scope` claim contains `requiredScope`
Scope based access control
Rich `WWW-Authenticate` on 401s
Serves `/​.well-known/oauth-protected-resource` JSON | ## 🛠️ Quick Start @@ -104,26 +102,17 @@ To enable authorization through your Asgardeo organization: 3. Update `config.yaml` with the following parameters. ```yaml -base_url: "http://localhost:8000" # URL of your MCP server -listen_port: 8080 # Address where the proxy will listen +base_url: "http://localhost:8000" # URL of your MCP server +listen_port: 8080 # Address where the proxy will listen -asgardeo: - org_name: "" # Your Asgardeo org name - client_id: "" # Client ID of the M2M app - client_secret: "" # Client secret of the M2M app - - resource_identifier: "http://localhost:8080" - scopes_supported: - - "read:tools" - - "read:resources" - audience: "" - authorization_servers: - - "https://api.asgardeo.io/t/acme" - jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" - bearer_methods_supported: - - header - - body - - query +resource_identifier: "http://localhost:8080" # Proxy server URL +scopes_supported: # Scopes required to access the MCP server +- "read:tools" +- "read:resources" +audience: "" # Access token audience +authorization_servers: # Authorization server URL +- "https://api.asgardeo.io/t/acme" +jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" # JWKS URL of the Authorization server ``` 4. Start the proxy with Asgardeo integration: @@ -246,22 +235,14 @@ demo: client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" # Asgardeo configuration (used with --asgardeo flag) -asgardeo: - org_name: "" - client_id: "" - client_secret: "" - resource_identifier: "http://localhost:8080" - scopes_supported: - - "read:tools" - - "read:resources" - audience: "" - authorization_servers: - - "https://api.asgardeo.io/t/acme" - jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" - bearer_methods_supported: - - header - - body - - query +resource_identifier: "http://localhost:8080" +scopes_supported: +- "read:tools" +- "read:resources" +audience: "" +authorization_servers: +- "https://api.asgardeo.io/t/acme" +jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" ``` ## Build from Source diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 2583b75..0208ead 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -11,7 +11,6 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" @@ -68,23 +67,7 @@ func main() { } // 3. Create the chosen provider - var provider authz.Provider - if *demoMode { - cfg.Mode = "demo" - cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2" - cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks" - provider = authz.NewAsgardeoProvider(cfg) - } else if *asgardeoMode { - cfg.Mode = "asgardeo" - cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2" - cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks" - provider = authz.NewAsgardeoProvider(cfg) - } else { - cfg.Mode = "default" - cfg.JWKSURL = cfg.Default.JWKSURL - cfg.AuthServerBaseURL = cfg.Default.BaseURL - provider = authz.NewDefaultProvider(cfg) - } + var provider authz.Provider = MakeProvider(cfg, *demoMode, *asgardeoMode) // 4. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { diff --git a/cmd/proxy/provider.go b/cmd/proxy/provider.go new file mode 100644 index 0000000..be4ee21 --- /dev/null +++ b/cmd/proxy/provider.go @@ -0,0 +1,45 @@ +package main + +import ( + "github.com/wso2/open-mcp-auth-proxy/internal/authz" + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" +) + +func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider { + var mode, orgName string + switch { + case demoMode: + mode = "demo" + orgName = cfg.Demo.OrgName + case asgardeoMode: + mode = "asgardeo" + orgName = cfg.Asgardeo.OrgName + default: + mode = "default" + } + cfg.Mode = mode + + switch mode { + case "demo", "asgardeo": + if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" { + base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2" + cfg.AuthServerBaseURL = base + cfg.JWKSURL = base + "/jwks" + } else { + cfg.AuthServerBaseURL = cfg.AuthorizationServers[0] + cfg.JWKSURL = cfg.JwksURI + } + return authz.NewAsgardeoProvider(cfg) + + default: + if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" { + cfg.AuthServerBaseURL = cfg.Default.BaseURL + cfg.JWKSURL = cfg.Default.JWKSURL + } else if len(cfg.AuthorizationServers) > 0 { + cfg.AuthServerBaseURL = cfg.AuthorizationServers[0] + cfg.JWKSURL = cfg.JwksURI + } + return authz.NewDefaultProvider(cfg) + } +} \ No newline at end of file diff --git a/internal/authz/access_control.go b/internal/authz/access_control.go index 2e321c3..1f7ce7b 100644 --- a/internal/authz/access_control.go +++ b/internal/authz/access_control.go @@ -1,6 +1,11 @@ package authz -import "net/http" +import ( + "net/http" + + "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) type Decision int @@ -15,5 +20,5 @@ type AccessControlResult struct { } type AccessControl interface { - ValidateAccess(r *http.Request, claims *TokenClaims, requiredScopes any) AccessControlResult + ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult } diff --git a/internal/authz/scope_validator.go b/internal/authz/scope_validator.go index 03ef3bf..004fd80 100644 --- a/internal/authz/scope_validator.go +++ b/internal/authz/scope_validator.go @@ -4,54 +4,68 @@ import ( "fmt" "net/http" "strings" -) -type TokenClaims struct { - Scopes []string -} + "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/util" +) type ScopeValidator struct{} // Evaluate and checks the token claims against one or more required scopes. func (d *ScopeValidator) ValidateAccess( - _ *http.Request, - claims *TokenClaims, - requiredScopes any, + r *http.Request, + claims *jwt.MapClaims, + config *config.Config, ) AccessControlResult { - var scopeStr string - switch v := requiredScopes.(type) { - case string: - scopeStr = v - case []string: - scopeStr = strings.Join(v, " ") + env, err := util.ParseRPCRequest(r) + if err != nil { + return AccessControlResult{DecisionDeny, "bad JSON-RPC request"} + } + requiredScopes := util.GetRequiredScopes(config, env.Method) + if len(requiredScopes) == 0 { + return AccessControlResult{DecisionAllow, ""} + } + + required := make(map[string]struct{}, len(requiredScopes)) + for _, s := range requiredScopes { + s = strings.TrimSpace(s) + if s != "" { + required[s] = struct{}{} + } + } + + var tokenScopes []string + if claims, ok := (*claims)["scope"]; ok { + switch v := claims.(type) { + case string: + tokenScopes = strings.Fields(v) + case []interface{}: + for _, x := range v { + if s, ok := x.(string); ok && s != "" { + tokenScopes = append(tokenScopes, s) + } + } + } + } + + tokenScopeSet := make(map[string]struct{}, len(tokenScopes)) + for _, s := range tokenScopes { + tokenScopeSet[s] = struct{}{} } - if strings.TrimSpace(scopeStr) == "" { - return AccessControlResult{DecisionAllow, ""} - } - - scopes := strings.FieldsFunc(scopeStr, func(r rune) bool { - return r == ' ' || r == ',' - }) - required := make(map[string]struct{}, len(scopes)) - for _, s := range scopes { - if s = strings.TrimSpace(s); s != "" { - required[s] = struct{}{} - } - } - - for _, tokenScope := range claims.Scopes { - if _, ok := required[tokenScope]; ok { - return AccessControlResult{DecisionAllow, ""} - } - } - - var list []string + var missing []string for s := range required { - list = append(list, s) + if _, ok := tokenScopeSet[s]; !ok { + missing = append(missing, s) + } + } + + if len(missing) == 0 { + return AccessControlResult{DecisionAllow, ""} } return AccessControlResult{ DecisionDeny, - fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")), + fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")), } } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 83aeb6e..b880f99 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -302,6 +302,7 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res // Handles both v1 (just signature) and v2 (aud + scope) flows func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error { authzHeader := r.Header.Get("Authorization") + accessToken, _ := util.ExtractAccessToken(authzHeader) if !strings.HasPrefix(authzHeader, "Bearer ") { if isLatestSpec { realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" @@ -314,7 +315,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg return fmt.Errorf("missing or invalid Authorization header") } - claims, err := util.ValidateJWT(isLatestSpec, authzHeader, cfg.Audience) + err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience) if err != nil { if isLatestSpec { realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" @@ -331,16 +332,19 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg } if isLatestSpec { - env, err := util.ParseRPCRequest(r) + _, err := util.ParseRPCRequest(r) if err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return err } - requiredScopes := util.GetRequiredScopes(cfg, env.Method) - if len(requiredScopes) == 0 { - return nil + + claimsMap, err := util.ParseJWT(accessToken) + if err != nil { + http.Error(w, "Invalid token claims", http.StatusUnauthorized) + return fmt.Errorf("invalid token claims") } - pr := accessController.ValidateAccess(r, claims, requiredScopes) + + pr := accessController.ValidateAccess(r, &claimsMap, cfg) if pr.Decision == authz.DecisionDeny { http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) return fmt.Errorf("forbidden — %s", pr.Message) diff --git a/internal/util/jwks.go b/internal/util/jwks.go index 54ca735..b1afb6f 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -10,7 +10,6 @@ import ( "strings" "github.com/golang-jwt/jwt/v4" - "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) @@ -83,15 +82,12 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { // ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. func ValidateJWT( isLatestSpec bool, - authHeader, audience string, -) (*authz.TokenClaims, error) { - tokenStr := strings.TrimPrefix(authHeader, "Bearer ") - if tokenStr == "" { - return nil, errors.New("empty bearer token") - } - + accessToken string, + audience string, +) error { + logger.Warn("isLatestSpec: %s", isLatestSpec) // Parse & verify the signature - token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } @@ -106,29 +102,31 @@ func ValidateJWT( return key, nil }) if err != nil { - return nil, fmt.Errorf("invalid token: %w", err) + logger.Warn("Error detected, returning early") + return fmt.Errorf("invalid token: %w", err) } if !token.Valid { - return nil, errors.New("token not valid") + logger.Warn("Token invalid, returning early") + return errors.New("token not valid") } claimsMap, ok := token.Claims.(jwt.MapClaims) if !ok { - return nil, errors.New("unexpected claim type") + return errors.New("unexpected claim type") } if !isLatestSpec { - return &authz.TokenClaims{Scopes: nil}, nil + return nil } audRaw, exists := claimsMap["aud"] if !exists { - return nil, errors.New("aud claim missing") + return errors.New("aud claim missing") } switch v := audRaw.(type) { case string: if v != audience { - return nil, fmt.Errorf("aud %q does not match %q", v, audience) + return fmt.Errorf("aud %q does not match %q", v, audience) } case []interface{}: var found bool @@ -139,38 +137,72 @@ func ValidateJWT( } } if !found { - return nil, fmt.Errorf("audience %v does not include %q", v, audience) + return fmt.Errorf("audience %v does not include %q", v, audience) } default: - return nil, errors.New("aud claim has unexpected type") + return errors.New("aud claim has unexpected type") } - rawScope := claimsMap["scope"] - scopeList := []string{} - if s, ok := rawScope.(string); ok { - scopeList = strings.Fields(s) - } + return nil +} - return &authz.TokenClaims{Scopes: scopeList}, nil +// Parses the JWT token and returns the claims +func ParseJWT(tokenStr string) (jwt.MapClaims, error) { + if tokenStr == "" { + return nil, fmt.Errorf("empty JWT") + } + + var claims jwt.MapClaims + _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + return claims, nil } // Process the required scopes func GetRequiredScopes(cfg *config.Config, method string) []string { - if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 { - if scope, ok := scopes[method]; ok { - return []string{scope} - } - if parts := strings.SplitN(method, "/", 2); len(parts) > 0 { - if scope, ok := scopes[parts[0]]; ok { - return []string{scope} - } - } - return nil - } + switch raw := cfg.ScopesSupported.(type) { + case map[string]string: + if scope, ok := raw[method]; ok { + return []string{scope} + } + parts := strings.SplitN(method, "/", 2) + if len(parts) > 0 { + if scope, ok := raw[parts[0]]; ok { + return []string{scope} + } + } + return nil + case []interface{}: + out := make([]string, 0, len(raw)) + for _, v := range raw { + if s, ok := v.(string); ok && s != "" { + out = append(out, s) + } + } + return out - if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 { - return scopes - } + case []string: + return raw + } - return []string{} + return nil +} + +// Extracts the Bearer token from the Authorization header +func ExtractAccessToken(authHeader string) (string, error) { + if authHeader == "" { + return "", errors.New("empty authorization header") + } + if !strings.HasPrefix(authHeader, "Bearer ") { + return "", fmt.Errorf("invalid authorization header format: %s", authHeader) + } + + tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + if tokenStr == "" { + return "", errors.New("empty bearer token") + } + + return tokenStr, nil }