mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 17:34:19 +00:00
Refactor scope validation
This commit is contained in:
parent
ed525dc7b5
commit
7d64cc4093
7 changed files with 115 additions and 102 deletions
|
@ -11,7 +11,8 @@ import (
|
|||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
|
@ -80,19 +81,20 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
|||
}
|
||||
|
||||
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||
// - versionHeader: the raw value of the "Mcp-Protocol-Version" header
|
||||
// - isLatestSpec: whether to use the latest spec validation
|
||||
// - authHeader: the full "Authorization" header
|
||||
// - audience: the resource identifier to check "aud" against
|
||||
// - requiredScope: the single scope required (empty ⇒ skip scope check)
|
||||
// - requiredScopes: the scopes required (empty ⇒ skip scope check)
|
||||
func ValidateJWT(
|
||||
isLatestSpec bool, authHeader, audience, requiredScope string,
|
||||
isLatestSpec bool,
|
||||
authHeader, audience string,
|
||||
) (*authz.TokenClaims, error) {
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenStr == "" {
|
||||
return nil, errors.New("empty bearer token")
|
||||
}
|
||||
|
||||
// 2) parse & verify signature
|
||||
// --- parse & verify signature ---
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
|
@ -107,7 +109,6 @@ func ValidateJWT(
|
|||
}
|
||||
return key, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
@ -115,18 +116,19 @@ func ValidateJWT(
|
|||
return nil, errors.New("token not valid")
|
||||
}
|
||||
|
||||
// always extract claims
|
||||
// --- extract raw claims ---
|
||||
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected claim type")
|
||||
}
|
||||
|
||||
// if older than cutover, skip audience+scope
|
||||
// --- v1: skip audience check entirely ---
|
||||
if !isLatestSpec {
|
||||
// we still want to return an empty set of scopes for policy to see
|
||||
return &authz.TokenClaims{Scopes: nil}, nil
|
||||
}
|
||||
|
||||
// --- new spec flow: enforce audience ---
|
||||
// --- v2: enforce audience ---
|
||||
audRaw, exists := claimsMap["aud"]
|
||||
if !exists {
|
||||
return nil, errors.New("aud claim missing")
|
||||
|
@ -151,25 +153,33 @@ func ValidateJWT(
|
|||
return nil, errors.New("aud claim has unexpected type")
|
||||
}
|
||||
|
||||
// if no scope required, we're done
|
||||
if requiredScope == "" {
|
||||
return &authz.TokenClaims{Scopes: nil}, nil
|
||||
// --- collect all scopes from the token, if any ---
|
||||
rawScope := claimsMap["scope"]
|
||||
scopeList := []string{}
|
||||
if s, ok := rawScope.(string); ok {
|
||||
scopeList = strings.Fields(s)
|
||||
}
|
||||
|
||||
// enforce scope
|
||||
rawScope, exists := claimsMap["scope"]
|
||||
if !exists {
|
||||
return nil, errors.New("scope claim missing")
|
||||
}
|
||||
scopeStr, ok := rawScope.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("scope claim not a string")
|
||||
}
|
||||
scopes := strings.Fields(scopeStr)
|
||||
for _, s := range scopes {
|
||||
if s == requiredScope {
|
||||
return &authz.TokenClaims{Scopes: scopes}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
|
||||
return &authz.TokenClaims{Scopes: scopeList}, nil
|
||||
}
|
||||
|
||||
// Process the required scopes
|
||||
func GetRequiredScopes(cfg *config.Config, method string) []string {
|
||||
if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
|
||||
if scope, ok := scopes[method]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
|
||||
if scope, ok := scopes[parts[0]]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
|
||||
return scopes
|
||||
}
|
||||
|
||||
return []string{}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue