mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-29 09:54:19 +00:00
Update scope validation implementation
This commit is contained in:
parent
5c22f36ddc
commit
64caaa0f7c
7 changed files with 202 additions and 138 deletions
|
@ -10,7 +10,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
@ -83,15 +82,12 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
|||
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||
func ValidateJWT(
|
||||
isLatestSpec bool,
|
||||
authHeader, audience string,
|
||||
) (*authz.TokenClaims, error) {
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenStr == "" {
|
||||
return nil, errors.New("empty bearer token")
|
||||
}
|
||||
|
||||
accessToken string,
|
||||
audience string,
|
||||
) error {
|
||||
logger.Warn("isLatestSpec: %s", isLatestSpec)
|
||||
// Parse & verify the signature
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
@ -106,29 +102,31 @@ func ValidateJWT(
|
|||
return key, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid token: %w", err)
|
||||
logger.Warn("Error detected, returning early")
|
||||
return fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("token not valid")
|
||||
logger.Warn("Token invalid, returning early")
|
||||
return errors.New("token not valid")
|
||||
}
|
||||
|
||||
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected claim type")
|
||||
return errors.New("unexpected claim type")
|
||||
}
|
||||
|
||||
if !isLatestSpec {
|
||||
return &authz.TokenClaims{Scopes: nil}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
audRaw, exists := claimsMap["aud"]
|
||||
if !exists {
|
||||
return nil, errors.New("aud claim missing")
|
||||
return errors.New("aud claim missing")
|
||||
}
|
||||
switch v := audRaw.(type) {
|
||||
case string:
|
||||
if v != audience {
|
||||
return nil, fmt.Errorf("aud %q does not match %q", v, audience)
|
||||
return fmt.Errorf("aud %q does not match %q", v, audience)
|
||||
}
|
||||
case []interface{}:
|
||||
var found bool
|
||||
|
@ -139,38 +137,72 @@ func ValidateJWT(
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("audience %v does not include %q", v, audience)
|
||||
return fmt.Errorf("audience %v does not include %q", v, audience)
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("aud claim has unexpected type")
|
||||
return errors.New("aud claim has unexpected type")
|
||||
}
|
||||
|
||||
rawScope := claimsMap["scope"]
|
||||
scopeList := []string{}
|
||||
if s, ok := rawScope.(string); ok {
|
||||
scopeList = strings.Fields(s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return &authz.TokenClaims{Scopes: scopeList}, nil
|
||||
// Parses the JWT token and returns the claims
|
||||
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
||||
if tokenStr == "" {
|
||||
return nil, fmt.Errorf("empty JWT")
|
||||
}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Process the required scopes
|
||||
func GetRequiredScopes(cfg *config.Config, method string) []string {
|
||||
if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
|
||||
if scope, ok := scopes[method]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
|
||||
if scope, ok := scopes[parts[0]]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
switch raw := cfg.ScopesSupported.(type) {
|
||||
case map[string]string:
|
||||
if scope, ok := raw[method]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
parts := strings.SplitN(method, "/", 2)
|
||||
if len(parts) > 0 {
|
||||
if scope, ok := raw[parts[0]]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
|
||||
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
|
||||
return scopes
|
||||
}
|
||||
case []string:
|
||||
return raw
|
||||
}
|
||||
|
||||
return []string{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extracts the Bearer token from the Authorization header
|
||||
func ExtractAccessToken(authHeader string) (string, error) {
|
||||
if authHeader == "" {
|
||||
return "", errors.New("empty authorization header")
|
||||
}
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||
if tokenStr == "" {
|
||||
return "", errors.New("empty bearer token")
|
||||
}
|
||||
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue