Fix audience validation issues

This commit is contained in:
NipuniBhagya 2025-05-14 21:47:15 +05:30
parent 331cc281c6
commit 312a5557f0
8 changed files with 163 additions and 76 deletions

View file

@ -93,7 +93,7 @@ func main() {
}
// 5. (Optional) Build the policy engine
engine := &authz.DefaulPolicyEngine{}
engine := &authz.DefaultPolicyEngine{}
// 6. Build the main router
mux := proxy.NewRouter(cfg, provider, engine)

View file

@ -99,6 +99,7 @@ func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
meta := map[string]interface{}{
"audience": p.cfg.Audience,
"resource": p.cfg.ResourceIdentifier,
"scopes_supported": p.cfg.ScopesSupported,
"authorization_servers": p.cfg.AuthorizationServers,

View file

@ -1,20 +1,49 @@
package authz
import (
"fmt"
"net/http"
"strings"
)
type TokenClaims struct {
Scopes []string
}
type DefaulPolicyEngine struct{}
type DefaultPolicyEngine struct{}
func (d *DefaulPolicyEngine) Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult {
for _, scope := range claims.Scopes {
if scope == requiredScope {
// Evaluate and checks the token claims against one or more required scopes.
func (d *DefaultPolicyEngine) Evaluate(
_ *http.Request,
claims *TokenClaims,
requiredScope string,
) PolicyResult {
if strings.TrimSpace(requiredScope) == "" {
return PolicyResult{DecisionAllow, ""}
}
raw := strings.FieldsFunc(requiredScope, func(r rune) bool {
return r == ' ' || r == ','
})
want := make(map[string]struct{}, len(raw))
for _, s := range raw {
if s = strings.TrimSpace(s); s != "" {
want[s] = struct{}{}
}
}
for _, have := range claims.Scopes {
if _, ok := want[have]; ok {
return PolicyResult{DecisionAllow, ""}
}
}
return PolicyResult{DecisionDeny, "missing scope '" + requiredScope + "'"}
var list []string
for s := range want {
list = append(list, s)
}
return PolicyResult{
DecisionDeny,
fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")),
}
}

View file

@ -103,6 +103,7 @@ type Config struct {
Default DefaultConfig `yaml:"default"`
// Protected resource metadata
Audience string `yaml:"audience"`
ResourceIdentifier string `yaml:"resource_identifier"`
ScopesSupported map[string]string `yaml:"scopes_supported"`
AuthorizationServers []string `yaml:"authorization_servers"`

View file

@ -11,7 +11,6 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
@ -157,9 +156,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
// Add CORS headers to all responses
addCORSHeaders(w, cfg, allowedOrigin, "")
versionRaw := r.Header.Get("MCP-Protocol-Version")
ver, err := time.Parse(constants.TimeLayout, versionRaw)
isLatestSpec := err == nil && !ver.Before(constants.SpecCutoverDate)
// Check if the request is for the latest spec
specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version"))
ver, err := util.ParseVersionDate(specVersion)
isLatestSpec := util.IsLatestSpec(ver, err)
// Decide whether the request should go to the auth server or MCP
var targetURL *url.URL
@ -311,35 +311,53 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
// Handles both v1 (just signature) and v2 (aud + scope) flows
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
logger.Info("authorizeMCP")
h := r.Header.Get("Authorization")
audience := cfg.ResourceIdentifier
if isLatestSpec {
required := cfg.ScopesSupported[r.URL.Path]
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, required)
logger.Info("claims: %v", claims)
logger.Info("err: %v", err)
if err != nil {
w.Header().Set(
"WWW-Authenticate",
fmt.Sprintf(
`Bearer realm="%s", error="insufficient_scope", scope="%s"`,
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
required,
),
)
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
return nil, fmt.Errorf("forbidden — insufficient scope")
}
return claims, nil
}
// v1: only check signature, then continue
if err := util.ValidateJWTLegacy(h); err != nil {
// Parse JSON-RPC request if present
if env, err := util.ParseRPCRequest(r); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return nil, err
} else if env != nil {
logger.Info("JSON-RPC method = %q", env.Method)
}
return &authz.TokenClaims{}, nil
authzHeader := r.Header.Get("Authorization")
if !strings.HasPrefix(authzHeader, "Bearer ") {
if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer resource_metadata=%q`, realm,
))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
}
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return nil, fmt.Errorf("missing or invalid Authorization header")
}
requiredScope := ""
if isLatestSpec {
requiredScope = cfg.ScopesSupported[r.URL.Path]
}
claims, err := util.ValidateJWT(
isLatestSpec,
authzHeader,
cfg.Audience,
requiredScope,
)
if err != nil {
if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer realm=%q, error="insufficient_scope", scope=%q`,
realm, requiredScope,
))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
http.Error(w, "Forbidden", http.StatusForbidden)
} else {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
return nil, err
}
return claims, nil
}
func getAllowedOrigin(origin string, cfg *config.Config) string {

View file

@ -8,12 +8,10 @@ import (
"math/big"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type TokenClaims struct {
@ -87,7 +85,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
// - audience: the resource identifier to check "aud" against
// - requiredScope: the single scope required (empty ⇒ skip scope check)
func ValidateJWT(
versionHeader, authHeader, audience, requiredScope string,
isLatestSpec bool, authHeader, audience, requiredScope string,
) (*authz.TokenClaims, error) {
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
if tokenStr == "" {
@ -96,17 +94,20 @@ func ValidateJWT(
// 2) parse & verify signature
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
kid, _ := token.Header["kid"].(string)
pk, ok := publicKeys[kid]
if !ok {
return nil, fmt.Errorf("unknown kid %q", kid)
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return pk, nil
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, errors.New("kid header not found")
}
key, ok := publicKeys[kid]
if !ok {
return nil, fmt.Errorf("key not found for kid: %s", kid)
}
return key, nil
})
logger.Info("token: %v", token)
logger.Info("err: %v", err)
if err != nil {
return nil, fmt.Errorf("invalid token: %w", err)
}
@ -120,15 +121,8 @@ func ValidateJWT(
return nil, errors.New("unexpected claim type")
}
// parse version date
verDate, err := time.Parse("2006-01-02", versionHeader)
if err != nil {
// if unparsable or missing, assume _old_ spec
verDate = time.Time{} // zero time ⇒ before cutover
}
// if older than cutover, skip audience+scope
if verDate.Before(constants.SpecCutoverDate) {
if !isLatestSpec {
return &authz.TokenClaims{Scopes: nil}, nil
}
@ -179,23 +173,3 @@ func ValidateJWT(
}
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
}
// Performs basic JWT validation
func ValidateJWTLegacy(authHeader string) error {
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
_, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
kid, ok := token.Header["kid"].(string)
if !ok {
return nil, errors.New("kid header not found")
}
key, ok := publicKeys[kid]
if !ok {
return nil, fmt.Errorf("key not found for kid: %s", kid)
}
return key, nil
})
return err
}

39
internal/util/rpc.go Normal file
View file

@ -0,0 +1,39 @@
package util
import (
"bytes"
"encoding/json"
"io"
"net/http"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type RPCEnvelope struct {
Method string `json:"method"`
Params any `json:"params"`
ID any `json:"id"`
}
// This function parses a JSON-RPC request from an HTTP request body
func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
return nil, err
}
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if len(bodyBytes) == 0 {
return nil, nil
}
var env RPCEnvelope
dec := json.NewDecoder(bytes.NewReader(bodyBytes))
if err := dec.Decode(&env); err != nil && err != io.EOF {
logger.Warn("Error parsing JSON-RPC envelope: %v", err)
return nil, err
}
logger.Info("JSON-RPC method = %q", env.Method)
return &env, nil
}

25
internal/util/version.go Normal file
View file

@ -0,0 +1,25 @@
package util
import (
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
)
// This function checks if the given version date is after the spec cutover date
func IsLatestSpec(versionDate time.Time, err error) bool {
return err == nil && !versionDate.Before(constants.SpecCutoverDate)
}
// This function parses a version string into a time.Time
func ParseVersionDate(version string) (time.Time, error) {
return time.Parse("2006-01-02", version)
}
// This function returns the version string, using the cutover date if empty
func GetVersionWithDefault(version string) string {
if version == "" {
return constants.SpecCutoverDate.Format("2006-01-02")
}
return version
}