Update scope validation implementation

This commit is contained in:
NipuniBhagya 2025-05-21 10:00:01 +05:30
parent 5c22f36ddc
commit 64caaa0f7c
7 changed files with 202 additions and 138 deletions

View file

@ -10,7 +10,6 @@ import (
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
@ -83,15 +82,12 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
func ValidateJWT(
isLatestSpec bool,
authHeader, audience string,
) (*authz.TokenClaims, error) {
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
if tokenStr == "" {
return nil, errors.New("empty bearer token")
}
accessToken string,
audience string,
) error {
logger.Warn("isLatestSpec: %s", isLatestSpec)
// Parse & verify the signature
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
@ -106,29 +102,31 @@ func ValidateJWT(
return key, nil
})
if err != nil {
return nil, fmt.Errorf("invalid token: %w", err)
logger.Warn("Error detected, returning early")
return fmt.Errorf("invalid token: %w", err)
}
if !token.Valid {
return nil, errors.New("token not valid")
logger.Warn("Token invalid, returning early")
return errors.New("token not valid")
}
claimsMap, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("unexpected claim type")
return errors.New("unexpected claim type")
}
if !isLatestSpec {
return &authz.TokenClaims{Scopes: nil}, nil
return nil
}
audRaw, exists := claimsMap["aud"]
if !exists {
return nil, errors.New("aud claim missing")
return errors.New("aud claim missing")
}
switch v := audRaw.(type) {
case string:
if v != audience {
return nil, fmt.Errorf("aud %q does not match %q", v, audience)
return fmt.Errorf("aud %q does not match %q", v, audience)
}
case []interface{}:
var found bool
@ -139,38 +137,72 @@ func ValidateJWT(
}
}
if !found {
return nil, fmt.Errorf("audience %v does not include %q", v, audience)
return fmt.Errorf("audience %v does not include %q", v, audience)
}
default:
return nil, errors.New("aud claim has unexpected type")
return errors.New("aud claim has unexpected type")
}
rawScope := claimsMap["scope"]
scopeList := []string{}
if s, ok := rawScope.(string); ok {
scopeList = strings.Fields(s)
}
return nil
}
return &authz.TokenClaims{Scopes: scopeList}, nil
// Parses the JWT token and returns the claims
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
if tokenStr == "" {
return nil, fmt.Errorf("empty JWT")
}
var claims jwt.MapClaims
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", err)
}
return claims, nil
}
// Process the required scopes
func GetRequiredScopes(cfg *config.Config, method string) []string {
if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
if scope, ok := scopes[method]; ok {
return []string{scope}
}
if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
if scope, ok := scopes[parts[0]]; ok {
return []string{scope}
}
}
return nil
}
switch raw := cfg.ScopesSupported.(type) {
case map[string]string:
if scope, ok := raw[method]; ok {
return []string{scope}
}
parts := strings.SplitN(method, "/", 2)
if len(parts) > 0 {
if scope, ok := raw[parts[0]]; ok {
return []string{scope}
}
}
return nil
case []interface{}:
out := make([]string, 0, len(raw))
for _, v := range raw {
if s, ok := v.(string); ok && s != "" {
out = append(out, s)
}
}
return out
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
return scopes
}
case []string:
return raw
}
return []string{}
return nil
}
// Extracts the Bearer token from the Authorization header
func ExtractAccessToken(authHeader string) (string, error) {
if authHeader == "" {
return "", errors.New("empty authorization header")
}
if !strings.HasPrefix(authHeader, "Bearer ") {
return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
}
tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
if tokenStr == "" {
return "", errors.New("empty bearer token")
}
return tokenStr, nil
}