mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-07-01 10:48:38 +00:00
Fix audience validation issues
This commit is contained in:
parent
331cc281c6
commit
312a5557f0
8 changed files with 163 additions and 76 deletions
|
@ -8,12 +8,10 @@ import (
|
|||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
|
@ -87,7 +85,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
|||
// - audience: the resource identifier to check "aud" against
|
||||
// - requiredScope: the single scope required (empty ⇒ skip scope check)
|
||||
func ValidateJWT(
|
||||
versionHeader, authHeader, audience, requiredScope string,
|
||||
isLatestSpec bool, authHeader, audience, requiredScope string,
|
||||
) (*authz.TokenClaims, error) {
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenStr == "" {
|
||||
|
@ -96,17 +94,20 @@ func ValidateJWT(
|
|||
|
||||
// 2) parse & verify signature
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
pk, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown kid %q", kid)
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return pk, nil
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("kid header not found")
|
||||
}
|
||||
key, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
||||
}
|
||||
return key, nil
|
||||
})
|
||||
|
||||
logger.Info("token: %v", token)
|
||||
logger.Info("err: %v", err)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
@ -120,15 +121,8 @@ func ValidateJWT(
|
|||
return nil, errors.New("unexpected claim type")
|
||||
}
|
||||
|
||||
// parse version date
|
||||
verDate, err := time.Parse("2006-01-02", versionHeader)
|
||||
if err != nil {
|
||||
// if unparsable or missing, assume _old_ spec
|
||||
verDate = time.Time{} // zero time ⇒ before cutover
|
||||
}
|
||||
|
||||
// if older than cutover, skip audience+scope
|
||||
if verDate.Before(constants.SpecCutoverDate) {
|
||||
if !isLatestSpec {
|
||||
return &authz.TokenClaims{Scopes: nil}, nil
|
||||
}
|
||||
|
||||
|
@ -179,23 +173,3 @@ func ValidateJWT(
|
|||
}
|
||||
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
|
||||
}
|
||||
|
||||
// Performs basic JWT validation
|
||||
func ValidateJWTLegacy(authHeader string) error {
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
_, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("kid header not found")
|
||||
}
|
||||
key, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
||||
}
|
||||
return key, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
|
39
internal/util/rpc.go
Normal file
39
internal/util/rpc.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type RPCEnvelope struct {
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params"`
|
||||
ID any `json:"id"`
|
||||
}
|
||||
|
||||
// This function parses a JSON-RPC request from an HTTP request body
|
||||
func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
if len(bodyBytes) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var env RPCEnvelope
|
||||
dec := json.NewDecoder(bytes.NewReader(bodyBytes))
|
||||
if err := dec.Decode(&env); err != nil && err != io.EOF {
|
||||
logger.Warn("Error parsing JSON-RPC envelope: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("JSON-RPC method = %q", env.Method)
|
||||
return &env, nil
|
||||
}
|
25
internal/util/version.go
Normal file
25
internal/util/version.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
)
|
||||
|
||||
// This function checks if the given version date is after the spec cutover date
|
||||
func IsLatestSpec(versionDate time.Time, err error) bool {
|
||||
return err == nil && !versionDate.Before(constants.SpecCutoverDate)
|
||||
}
|
||||
|
||||
// This function parses a version string into a time.Time
|
||||
func ParseVersionDate(version string) (time.Time, error) {
|
||||
return time.Parse("2006-01-02", version)
|
||||
}
|
||||
|
||||
// This function returns the version string, using the cutover date if empty
|
||||
func GetVersionWithDefault(version string) string {
|
||||
if version == "" {
|
||||
return constants.SpecCutoverDate.Format("2006-01-02")
|
||||
}
|
||||
return version
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue