Update scope validation implementation

This commit is contained in:
NipuniBhagya 2025-05-21 10:00:01 +05:30
parent 5c22f36ddc
commit 64caaa0f7c
7 changed files with 202 additions and 138 deletions

View file

@ -1,6 +1,11 @@
package authz
import "net/http"
import (
"net/http"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
)
type Decision int
@ -15,5 +20,5 @@ type AccessControlResult struct {
}
type AccessControl interface {
ValidateAccess(r *http.Request, claims *TokenClaims, requiredScopes any) AccessControlResult
ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult
}

View file

@ -4,54 +4,68 @@ import (
"fmt"
"net/http"
"strings"
)
type TokenClaims struct {
Scopes []string
}
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
type ScopeValidator struct{}
// Evaluate and checks the token claims against one or more required scopes.
func (d *ScopeValidator) ValidateAccess(
_ *http.Request,
claims *TokenClaims,
requiredScopes any,
r *http.Request,
claims *jwt.MapClaims,
config *config.Config,
) AccessControlResult {
var scopeStr string
switch v := requiredScopes.(type) {
case string:
scopeStr = v
case []string:
scopeStr = strings.Join(v, " ")
env, err := util.ParseRPCRequest(r)
if err != nil {
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
}
requiredScopes := util.GetRequiredScopes(config, env.Method)
if len(requiredScopes) == 0 {
return AccessControlResult{DecisionAllow, ""}
}
required := make(map[string]struct{}, len(requiredScopes))
for _, s := range requiredScopes {
s = strings.TrimSpace(s)
if s != "" {
required[s] = struct{}{}
}
}
var tokenScopes []string
if claims, ok := (*claims)["scope"]; ok {
switch v := claims.(type) {
case string:
tokenScopes = strings.Fields(v)
case []interface{}:
for _, x := range v {
if s, ok := x.(string); ok && s != "" {
tokenScopes = append(tokenScopes, s)
}
}
}
}
tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
for _, s := range tokenScopes {
tokenScopeSet[s] = struct{}{}
}
if strings.TrimSpace(scopeStr) == "" {
return AccessControlResult{DecisionAllow, ""}
}
scopes := strings.FieldsFunc(scopeStr, func(r rune) bool {
return r == ' ' || r == ','
})
required := make(map[string]struct{}, len(scopes))
for _, s := range scopes {
if s = strings.TrimSpace(s); s != "" {
required[s] = struct{}{}
}
}
for _, tokenScope := range claims.Scopes {
if _, ok := required[tokenScope]; ok {
return AccessControlResult{DecisionAllow, ""}
}
}
var list []string
var missing []string
for s := range required {
list = append(list, s)
if _, ok := tokenScopeSet[s]; !ok {
missing = append(missing, s)
}
}
if len(missing) == 0 {
return AccessControlResult{DecisionAllow, ""}
}
return AccessControlResult{
DecisionDeny,
fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")),
fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")),
}
}

View file

@ -302,6 +302,7 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
// Handles both v1 (just signature) and v2 (aud + scope) flows
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error {
authzHeader := r.Header.Get("Authorization")
accessToken, _ := util.ExtractAccessToken(authzHeader)
if !strings.HasPrefix(authzHeader, "Bearer ") {
if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
@ -314,7 +315,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
return fmt.Errorf("missing or invalid Authorization header")
}
claims, err := util.ValidateJWT(isLatestSpec, authzHeader, cfg.Audience)
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience)
if err != nil {
if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
@ -331,16 +332,19 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
}
if isLatestSpec {
env, err := util.ParseRPCRequest(r)
_, err := util.ParseRPCRequest(r)
if err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return err
}
requiredScopes := util.GetRequiredScopes(cfg, env.Method)
if len(requiredScopes) == 0 {
return nil
claimsMap, err := util.ParseJWT(accessToken)
if err != nil {
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
return fmt.Errorf("invalid token claims")
}
pr := accessController.ValidateAccess(r, claims, requiredScopes)
pr := accessController.ValidateAccess(r, &claimsMap, cfg)
if pr.Decision == authz.DecisionDeny {
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
return fmt.Errorf("forbidden — %s", pr.Message)

View file

@ -10,7 +10,6 @@ import (
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
@ -83,15 +82,12 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
func ValidateJWT(
isLatestSpec bool,
authHeader, audience string,
) (*authz.TokenClaims, error) {
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
if tokenStr == "" {
return nil, errors.New("empty bearer token")
}
accessToken string,
audience string,
) error {
logger.Warn("isLatestSpec: %s", isLatestSpec)
// Parse & verify the signature
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
@ -106,29 +102,31 @@ func ValidateJWT(
return key, nil
})
if err != nil {
return nil, fmt.Errorf("invalid token: %w", err)
logger.Warn("Error detected, returning early")
return fmt.Errorf("invalid token: %w", err)
}
if !token.Valid {
return nil, errors.New("token not valid")
logger.Warn("Token invalid, returning early")
return errors.New("token not valid")
}
claimsMap, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("unexpected claim type")
return errors.New("unexpected claim type")
}
if !isLatestSpec {
return &authz.TokenClaims{Scopes: nil}, nil
return nil
}
audRaw, exists := claimsMap["aud"]
if !exists {
return nil, errors.New("aud claim missing")
return errors.New("aud claim missing")
}
switch v := audRaw.(type) {
case string:
if v != audience {
return nil, fmt.Errorf("aud %q does not match %q", v, audience)
return fmt.Errorf("aud %q does not match %q", v, audience)
}
case []interface{}:
var found bool
@ -139,38 +137,72 @@ func ValidateJWT(
}
}
if !found {
return nil, fmt.Errorf("audience %v does not include %q", v, audience)
return fmt.Errorf("audience %v does not include %q", v, audience)
}
default:
return nil, errors.New("aud claim has unexpected type")
return errors.New("aud claim has unexpected type")
}
rawScope := claimsMap["scope"]
scopeList := []string{}
if s, ok := rawScope.(string); ok {
scopeList = strings.Fields(s)
}
return nil
}
return &authz.TokenClaims{Scopes: scopeList}, nil
// Parses the JWT token and returns the claims
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
if tokenStr == "" {
return nil, fmt.Errorf("empty JWT")
}
var claims jwt.MapClaims
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", err)
}
return claims, nil
}
// Process the required scopes
func GetRequiredScopes(cfg *config.Config, method string) []string {
if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
if scope, ok := scopes[method]; ok {
return []string{scope}
}
if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
if scope, ok := scopes[parts[0]]; ok {
return []string{scope}
}
}
return nil
}
switch raw := cfg.ScopesSupported.(type) {
case map[string]string:
if scope, ok := raw[method]; ok {
return []string{scope}
}
parts := strings.SplitN(method, "/", 2)
if len(parts) > 0 {
if scope, ok := raw[parts[0]]; ok {
return []string{scope}
}
}
return nil
case []interface{}:
out := make([]string, 0, len(raw))
for _, v := range raw {
if s, ok := v.(string); ok && s != "" {
out = append(out, s)
}
}
return out
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
return scopes
}
case []string:
return raw
}
return []string{}
return nil
}
// Extracts the Bearer token from the Authorization header
func ExtractAccessToken(authHeader string) (string, error) {
if authHeader == "" {
return "", errors.New("empty authorization header")
}
if !strings.HasPrefix(authHeader, "Bearer ") {
return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
}
tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
if tokenStr == "" {
return "", errors.New("empty bearer token")
}
return tokenStr, nil
}