mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Fix audience validation issues
This commit is contained in:
parent
331cc281c6
commit
312a5557f0
8 changed files with 163 additions and 76 deletions
|
@ -93,7 +93,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. (Optional) Build the policy engine
|
// 5. (Optional) Build the policy engine
|
||||||
engine := &authz.DefaulPolicyEngine{}
|
engine := &authz.DefaultPolicyEngine{}
|
||||||
|
|
||||||
// 6. Build the main router
|
// 6. Build the main router
|
||||||
mux := proxy.NewRouter(cfg, provider, engine)
|
mux := proxy.NewRouter(cfg, provider, engine)
|
||||||
|
|
|
@ -99,6 +99,7 @@ func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
meta := map[string]interface{}{
|
meta := map[string]interface{}{
|
||||||
|
"audience": p.cfg.Audience,
|
||||||
"resource": p.cfg.ResourceIdentifier,
|
"resource": p.cfg.ResourceIdentifier,
|
||||||
"scopes_supported": p.cfg.ScopesSupported,
|
"scopes_supported": p.cfg.ScopesSupported,
|
||||||
"authorization_servers": p.cfg.AuthorizationServers,
|
"authorization_servers": p.cfg.AuthorizationServers,
|
||||||
|
|
|
@ -1,20 +1,49 @@
|
||||||
package authz
|
package authz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenClaims struct {
|
type TokenClaims struct {
|
||||||
Scopes []string
|
Scopes []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaulPolicyEngine struct{}
|
type DefaultPolicyEngine struct{}
|
||||||
|
|
||||||
func (d *DefaulPolicyEngine) Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult {
|
// Evaluate and checks the token claims against one or more required scopes.
|
||||||
for _, scope := range claims.Scopes {
|
func (d *DefaultPolicyEngine) Evaluate(
|
||||||
if scope == requiredScope {
|
_ *http.Request,
|
||||||
|
claims *TokenClaims,
|
||||||
|
requiredScope string,
|
||||||
|
) PolicyResult {
|
||||||
|
if strings.TrimSpace(requiredScope) == "" {
|
||||||
|
return PolicyResult{DecisionAllow, ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := strings.FieldsFunc(requiredScope, func(r rune) bool {
|
||||||
|
return r == ' ' || r == ','
|
||||||
|
})
|
||||||
|
want := make(map[string]struct{}, len(raw))
|
||||||
|
for _, s := range raw {
|
||||||
|
if s = strings.TrimSpace(s); s != "" {
|
||||||
|
want[s] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, have := range claims.Scopes {
|
||||||
|
if _, ok := want[have]; ok {
|
||||||
return PolicyResult{DecisionAllow, ""}
|
return PolicyResult{DecisionAllow, ""}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return PolicyResult{DecisionDeny, "missing scope '" + requiredScope + "'"}
|
|
||||||
|
var list []string
|
||||||
|
for s := range want {
|
||||||
|
list = append(list, s)
|
||||||
|
}
|
||||||
|
return PolicyResult{
|
||||||
|
DecisionDeny,
|
||||||
|
fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,6 +103,7 @@ type Config struct {
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
|
|
||||||
// Protected resource metadata
|
// Protected resource metadata
|
||||||
|
Audience string `yaml:"audience"`
|
||||||
ResourceIdentifier string `yaml:"resource_identifier"`
|
ResourceIdentifier string `yaml:"resource_identifier"`
|
||||||
ScopesSupported map[string]string `yaml:"scopes_supported"`
|
ScopesSupported map[string]string `yaml:"scopes_supported"`
|
||||||
AuthorizationServers []string `yaml:"authorization_servers"`
|
AuthorizationServers []string `yaml:"authorization_servers"`
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
|
||||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
@ -157,9 +156,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
// Add CORS headers to all responses
|
// Add CORS headers to all responses
|
||||||
addCORSHeaders(w, cfg, allowedOrigin, "")
|
addCORSHeaders(w, cfg, allowedOrigin, "")
|
||||||
|
|
||||||
versionRaw := r.Header.Get("MCP-Protocol-Version")
|
// Check if the request is for the latest spec
|
||||||
ver, err := time.Parse(constants.TimeLayout, versionRaw)
|
specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version"))
|
||||||
isLatestSpec := err == nil && !ver.Before(constants.SpecCutoverDate)
|
ver, err := util.ParseVersionDate(specVersion)
|
||||||
|
isLatestSpec := util.IsLatestSpec(ver, err)
|
||||||
|
|
||||||
// Decide whether the request should go to the auth server or MCP
|
// Decide whether the request should go to the auth server or MCP
|
||||||
var targetURL *url.URL
|
var targetURL *url.URL
|
||||||
|
@ -311,35 +311,53 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
|
||||||
|
|
||||||
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||||
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
||||||
logger.Info("authorizeMCP")
|
// Parse JSON-RPC request if present
|
||||||
h := r.Header.Get("Authorization")
|
if env, err := util.ParseRPCRequest(r); err != nil {
|
||||||
audience := cfg.ResourceIdentifier
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
if isLatestSpec {
|
|
||||||
required := cfg.ScopesSupported[r.URL.Path]
|
|
||||||
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, required)
|
|
||||||
logger.Info("claims: %v", claims)
|
|
||||||
logger.Info("err: %v", err)
|
|
||||||
if err != nil {
|
|
||||||
w.Header().Set(
|
|
||||||
"WWW-Authenticate",
|
|
||||||
fmt.Sprintf(
|
|
||||||
`Bearer realm="%s", error="insufficient_scope", scope="%s"`,
|
|
||||||
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
|
|
||||||
required,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
|
||||||
return nil, fmt.Errorf("forbidden — insufficient scope")
|
|
||||||
}
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// v1: only check signature, then continue
|
|
||||||
if err := util.ValidateJWTLegacy(h); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
|
} else if env != nil {
|
||||||
|
logger.Info("JSON-RPC method = %q", env.Method)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &authz.TokenClaims{}, nil
|
authzHeader := r.Header.Get("Authorization")
|
||||||
|
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
||||||
|
if isLatestSpec {
|
||||||
|
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||||
|
`Bearer resource_metadata=%q`, realm,
|
||||||
|
))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
}
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return nil, fmt.Errorf("missing or invalid Authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredScope := ""
|
||||||
|
if isLatestSpec {
|
||||||
|
requiredScope = cfg.ScopesSupported[r.URL.Path]
|
||||||
|
}
|
||||||
|
claims, err := util.ValidateJWT(
|
||||||
|
isLatestSpec,
|
||||||
|
authzHeader,
|
||||||
|
cfg.Audience,
|
||||||
|
requiredScope,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if isLatestSpec {
|
||||||
|
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||||
|
`Bearer realm=%q, error="insufficient_scope", scope=%q`,
|
||||||
|
realm, requiredScope,
|
||||||
|
))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
|
|
|
@ -8,12 +8,10 @@ import (
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenClaims struct {
|
type TokenClaims struct {
|
||||||
|
@ -87,7 +85,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||||
// - audience: the resource identifier to check "aud" against
|
// - audience: the resource identifier to check "aud" against
|
||||||
// - requiredScope: the single scope required (empty ⇒ skip scope check)
|
// - requiredScope: the single scope required (empty ⇒ skip scope check)
|
||||||
func ValidateJWT(
|
func ValidateJWT(
|
||||||
versionHeader, authHeader, audience, requiredScope string,
|
isLatestSpec bool, authHeader, audience, requiredScope string,
|
||||||
) (*authz.TokenClaims, error) {
|
) (*authz.TokenClaims, error) {
|
||||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
|
@ -96,17 +94,20 @@ func ValidateJWT(
|
||||||
|
|
||||||
// 2) parse & verify signature
|
// 2) parse & verify signature
|
||||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
kid, _ := token.Header["kid"].(string)
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
pk, ok := publicKeys[kid]
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unknown kid %q", kid)
|
|
||||||
}
|
}
|
||||||
return pk, nil
|
kid, ok := token.Header["kid"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("kid header not found")
|
||||||
|
}
|
||||||
|
key, ok := publicKeys[kid]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.Info("token: %v", token)
|
|
||||||
logger.Info("err: %v", err)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid token: %w", err)
|
return nil, fmt.Errorf("invalid token: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -120,15 +121,8 @@ func ValidateJWT(
|
||||||
return nil, errors.New("unexpected claim type")
|
return nil, errors.New("unexpected claim type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse version date
|
|
||||||
verDate, err := time.Parse("2006-01-02", versionHeader)
|
|
||||||
if err != nil {
|
|
||||||
// if unparsable or missing, assume _old_ spec
|
|
||||||
verDate = time.Time{} // zero time ⇒ before cutover
|
|
||||||
}
|
|
||||||
|
|
||||||
// if older than cutover, skip audience+scope
|
// if older than cutover, skip audience+scope
|
||||||
if verDate.Before(constants.SpecCutoverDate) {
|
if !isLatestSpec {
|
||||||
return &authz.TokenClaims{Scopes: nil}, nil
|
return &authz.TokenClaims{Scopes: nil}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,23 +173,3 @@ func ValidateJWT(
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
|
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Performs basic JWT validation
|
|
||||||
func ValidateJWTLegacy(authHeader string) error {
|
|
||||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
||||||
_, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
||||||
}
|
|
||||||
kid, ok := token.Header["kid"].(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("kid header not found")
|
|
||||||
}
|
|
||||||
key, ok := publicKeys[kid]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
39
internal/util/rpc.go
Normal file
39
internal/util/rpc.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RPCEnvelope struct {
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params any `json:"params"`
|
||||||
|
ID any `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function parses a JSON-RPC request from an HTTP request body
|
||||||
|
func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
|
||||||
|
bodyBytes, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
if len(bodyBytes) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var env RPCEnvelope
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(bodyBytes))
|
||||||
|
if err := dec.Decode(&env); err != nil && err != io.EOF {
|
||||||
|
logger.Warn("Error parsing JSON-RPC envelope: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("JSON-RPC method = %q", env.Method)
|
||||||
|
return &env, nil
|
||||||
|
}
|
25
internal/util/version.go
Normal file
25
internal/util/version.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This function checks if the given version date is after the spec cutover date
|
||||||
|
func IsLatestSpec(versionDate time.Time, err error) bool {
|
||||||
|
return err == nil && !versionDate.Before(constants.SpecCutoverDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function parses a version string into a time.Time
|
||||||
|
func ParseVersionDate(version string) (time.Time, error) {
|
||||||
|
return time.Parse("2006-01-02", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function returns the version string, using the cutover date if empty
|
||||||
|
func GetVersionWithDefault(version string) string {
|
||||||
|
if version == "" {
|
||||||
|
return constants.SpecCutoverDate.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue