mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Refactor scope validation
This commit is contained in:
parent
ed525dc7b5
commit
7d64cc4093
7 changed files with 115 additions and 102 deletions
|
@ -50,9 +50,9 @@ demo:
|
||||||
resource_identifier: http://localhost:3000
|
resource_identifier: http://localhost:3000
|
||||||
audience: mcp_proxy
|
audience: mcp_proxy
|
||||||
scopes_supported:
|
scopes_supported:
|
||||||
"read:tools"
|
- "tools":"read:tools"
|
||||||
"read:resources"
|
- "resources":"read:resources"
|
||||||
"read:prompts"
|
- "prompts":"read:prompts"
|
||||||
authorization_servers:
|
authorization_servers:
|
||||||
- https://api.asgardeo.io/t/acme/
|
- https://api.asgardeo.io/t/acme/
|
||||||
jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks
|
jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenClaims struct {
|
type TokenClaims struct {
|
||||||
|
@ -16,30 +18,42 @@ type DefaultPolicyEngine struct{}
|
||||||
func (d *DefaultPolicyEngine) Evaluate(
|
func (d *DefaultPolicyEngine) Evaluate(
|
||||||
_ *http.Request,
|
_ *http.Request,
|
||||||
claims *TokenClaims,
|
claims *TokenClaims,
|
||||||
requiredScope string,
|
requiredScopes any,
|
||||||
) PolicyResult {
|
) PolicyResult {
|
||||||
if strings.TrimSpace(requiredScope) == "" {
|
|
||||||
|
logger.Info("Required scopes: %v", requiredScopes)
|
||||||
|
|
||||||
|
var scopeStr string
|
||||||
|
switch v := requiredScopes.(type) {
|
||||||
|
case string:
|
||||||
|
scopeStr = v
|
||||||
|
case []string:
|
||||||
|
scopeStr = strings.Join(v, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(scopeStr) == "" {
|
||||||
return PolicyResult{DecisionAllow, ""}
|
return PolicyResult{DecisionAllow, ""}
|
||||||
}
|
}
|
||||||
|
|
||||||
raw := strings.FieldsFunc(requiredScope, func(r rune) bool {
|
scopes := strings.FieldsFunc(scopeStr, func(r rune) bool {
|
||||||
return r == ' ' || r == ','
|
return r == ' ' || r == ','
|
||||||
})
|
})
|
||||||
want := make(map[string]struct{}, len(raw))
|
required := make(map[string]struct{}, len(scopes))
|
||||||
for _, s := range raw {
|
for _, s := range scopes {
|
||||||
if s = strings.TrimSpace(s); s != "" {
|
if s = strings.TrimSpace(s); s != "" {
|
||||||
want[s] = struct{}{}
|
required[s] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, have := range claims.Scopes {
|
logger.Info("Token scopes: %v", claims.Scopes)
|
||||||
if _, ok := want[have]; ok {
|
for _, tokenScope := range claims.Scopes {
|
||||||
|
if _, ok := required[tokenScope]; ok {
|
||||||
return PolicyResult{DecisionAllow, ""}
|
return PolicyResult{DecisionAllow, ""}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var list []string
|
var list []string
|
||||||
for s := range want {
|
for s := range required {
|
||||||
list = append(list, s)
|
list = append(list, s)
|
||||||
}
|
}
|
||||||
return PolicyResult{
|
return PolicyResult{
|
||||||
|
|
|
@ -15,5 +15,5 @@ type PolicyResult struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type PolicyEngine interface {
|
type PolicyEngine interface {
|
||||||
Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult
|
Evaluate(r *http.Request, claims *TokenClaims, requiredScopes any) PolicyResult
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,7 +105,7 @@ type Config struct {
|
||||||
// Protected resource metadata
|
// Protected resource metadata
|
||||||
Audience string `yaml:"audience"`
|
Audience string `yaml:"audience"`
|
||||||
ResourceIdentifier string `yaml:"resource_identifier"`
|
ResourceIdentifier string `yaml:"resource_identifier"`
|
||||||
ScopesSupported map[string]string `yaml:"scopes_supported"`
|
ScopesSupported any `yaml:"scopes_supported"`
|
||||||
AuthorizationServers []string `yaml:"authorization_servers"`
|
AuthorizationServers []string `yaml:"authorization_servers"`
|
||||||
JwksURI string `yaml:"jwks_uri,omitempty"`
|
JwksURI string `yaml:"jwks_uri,omitempty"`
|
||||||
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
||||||
|
|
|
@ -175,20 +175,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
}
|
}
|
||||||
isSSE = true
|
isSSE = true
|
||||||
} else {
|
} else {
|
||||||
claims, err := authorizeMCP(w, r, isLatestSpec, cfg)
|
if err := authorizeMCP(w, r, isLatestSpec, cfg, policyEngine); err != nil {
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusForbidden)
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLatestSpec {
|
|
||||||
scope := cfg.ScopesSupported[r.URL.Path]
|
|
||||||
pr := policyEngine.Evaluate(r, claims, scope)
|
|
||||||
if pr.Decision == authz.DecisionDeny {
|
|
||||||
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL = mcpBase
|
targetURL = mcpBase
|
||||||
|
@ -310,15 +300,7 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||||
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, policyEngine authz.PolicyEngine) error {
|
||||||
// Parse JSON-RPC request if present
|
|
||||||
if env, err := util.ParseRPCRequest(r); err != nil {
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
|
||||||
return nil, err
|
|
||||||
} else if env != nil {
|
|
||||||
logger.Info("JSON-RPC method = %q", env.Method)
|
|
||||||
}
|
|
||||||
|
|
||||||
authzHeader := r.Header.Get("Authorization")
|
authzHeader := r.Header.Get("Authorization")
|
||||||
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
|
@ -329,35 +311,43 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
}
|
}
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
return nil, fmt.Errorf("missing or invalid Authorization header")
|
return fmt.Errorf("missing or invalid Authorization header")
|
||||||
}
|
}
|
||||||
|
|
||||||
requiredScope := ""
|
claims, err := util.ValidateJWT(isLatestSpec, authzHeader, cfg.Audience)
|
||||||
if isLatestSpec {
|
|
||||||
requiredScope = cfg.ScopesSupported[r.URL.Path]
|
|
||||||
}
|
|
||||||
claims, err := util.ValidateJWT(
|
|
||||||
isLatestSpec,
|
|
||||||
authzHeader,
|
|
||||||
cfg.Audience,
|
|
||||||
requiredScope,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
|
||||||
`Bearer realm=%q, error="insufficient_scope", scope=%q`,
|
`Bearer realm=%q`,
|
||||||
realm, requiredScope,
|
realm,
|
||||||
))
|
))
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
} else {
|
} else {
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
if isLatestSpec {
|
||||||
|
env, err := util.ParseRPCRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
requiredScopes := util.GetRequiredScopes(cfg, env.Method)
|
||||||
|
if len(requiredScopes) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pr := policyEngine.Evaluate(r, claims, requiredScopes)
|
||||||
|
if pr.Decision == authz.DecisionDeny {
|
||||||
|
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
||||||
|
return fmt.Errorf("forbidden — %s", pr.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
|
|
|
@ -11,7 +11,8 @@ import (
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenClaims struct {
|
type TokenClaims struct {
|
||||||
|
@ -80,19 +81,20 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||||
// - versionHeader: the raw value of the "Mcp-Protocol-Version" header
|
// - isLatestSpec: whether to use the latest spec validation
|
||||||
// - authHeader: the full "Authorization" header
|
// - authHeader: the full "Authorization" header
|
||||||
// - audience: the resource identifier to check "aud" against
|
// - audience: the resource identifier to check "aud" against
|
||||||
// - requiredScope: the single scope required (empty ⇒ skip scope check)
|
// - requiredScopes: the scopes required (empty ⇒ skip scope check)
|
||||||
func ValidateJWT(
|
func ValidateJWT(
|
||||||
isLatestSpec bool, authHeader, audience, requiredScope string,
|
isLatestSpec bool,
|
||||||
|
authHeader, audience string,
|
||||||
) (*authz.TokenClaims, error) {
|
) (*authz.TokenClaims, error) {
|
||||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
return nil, errors.New("empty bearer token")
|
return nil, errors.New("empty bearer token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) parse & verify signature
|
// --- parse & verify signature ---
|
||||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
@ -107,7 +109,6 @@ func ValidateJWT(
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid token: %w", err)
|
return nil, fmt.Errorf("invalid token: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -115,18 +116,19 @@ func ValidateJWT(
|
||||||
return nil, errors.New("token not valid")
|
return nil, errors.New("token not valid")
|
||||||
}
|
}
|
||||||
|
|
||||||
// always extract claims
|
// --- extract raw claims ---
|
||||||
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("unexpected claim type")
|
return nil, errors.New("unexpected claim type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if older than cutover, skip audience+scope
|
// --- v1: skip audience check entirely ---
|
||||||
if !isLatestSpec {
|
if !isLatestSpec {
|
||||||
|
// we still want to return an empty set of scopes for policy to see
|
||||||
return &authz.TokenClaims{Scopes: nil}, nil
|
return &authz.TokenClaims{Scopes: nil}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- new spec flow: enforce audience ---
|
// --- v2: enforce audience ---
|
||||||
audRaw, exists := claimsMap["aud"]
|
audRaw, exists := claimsMap["aud"]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, errors.New("aud claim missing")
|
return nil, errors.New("aud claim missing")
|
||||||
|
@ -151,25 +153,33 @@ func ValidateJWT(
|
||||||
return nil, errors.New("aud claim has unexpected type")
|
return nil, errors.New("aud claim has unexpected type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if no scope required, we're done
|
// --- collect all scopes from the token, if any ---
|
||||||
if requiredScope == "" {
|
rawScope := claimsMap["scope"]
|
||||||
return &authz.TokenClaims{Scopes: nil}, nil
|
scopeList := []string{}
|
||||||
|
if s, ok := rawScope.(string); ok {
|
||||||
|
scopeList = strings.Fields(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enforce scope
|
return &authz.TokenClaims{Scopes: scopeList}, nil
|
||||||
rawScope, exists := claimsMap["scope"]
|
}
|
||||||
if !exists {
|
|
||||||
return nil, errors.New("scope claim missing")
|
// Process the required scopes
|
||||||
}
|
func GetRequiredScopes(cfg *config.Config, method string) []string {
|
||||||
scopeStr, ok := rawScope.(string)
|
if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
|
||||||
if !ok {
|
if scope, ok := scopes[method]; ok {
|
||||||
return nil, errors.New("scope claim not a string")
|
return []string{scope}
|
||||||
}
|
}
|
||||||
scopes := strings.Fields(scopeStr)
|
if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
|
||||||
for _, s := range scopes {
|
if scope, ok := scopes[parts[0]]; ok {
|
||||||
if s == requiredScope {
|
return []string{scope}
|
||||||
return &authz.TokenClaims{Scopes: scopes}, nil
|
}
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
|
}
|
||||||
|
|
||||||
|
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,5 @@ func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("JSON-RPC method = %q", env.Method)
|
|
||||||
return &env, nil
|
return &env, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue