open-mcp-auth-proxy-upstream/internal/util/jwks.go
2025-05-15 01:20:29 +05:30

185 lines
4.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package util
import (
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// TokenClaims carries the list of scopes associated with a token.
//
// NOTE(review): nothing in this file references this type; ValidateJWT
// returns *authz.TokenClaims instead. Confirm whether this declaration is
// still used elsewhere in the package.
type TokenClaims struct {
// Scopes is the list of scope strings.
Scopes []string
}
// JWKS is the JSON shape FetchJWKS decodes the key-set response into: an
// object with a "keys" array.
type JWKS struct {
// Keys holds each entry of the "keys" array as undecoded JSON so that every
// key can be unmarshalled (and skipped on failure) individually.
Keys []json.RawMessage `json:"keys"`
}
// publicKeys maps a key ID ("kid") to its RSA public key. It is populated by
// FetchJWKS and read by ValidateJWT when resolving a token's verification key.
// NOTE(review): no synchronization is visible in this file; confirm FetchJWKS
// is never called concurrently with ValidateJWT.
var publicKeys map[string]*rsa.PublicKey
// FetchJWKS downloads the JWKS document at jwksURL and stores the RSA keys it
// contains in the package-level publicKeys map, indexed by their "kid".
//
// The previous key set is replaced only after the response has been fetched
// with a 200 status and decoded successfully; on any error publicKeys is left
// untouched. Entries that cannot be unmarshalled, are not of type "RSA", or
// whose modulus/exponent cannot be parsed are skipped on a best-effort basis,
// so a successful call may still load zero keys.
//
// It returns an error when the request fails, the server answers with a
// non-200 status, or the body is not valid JSON.
func FetchJWKS(jwksURL string) error {
	// NOTE(review): http.Get goes through http.DefaultClient, which has no
	// timeout; consider a client with an explicit timeout if this endpoint
	// can hang.
	resp, err := http.Get(jwksURL)
	if err != nil {
		return fmt.Errorf("fetching JWKS: %w", err)
	}
	defer resp.Body.Close()

	// Without this check an error response carrying a JSON body would decode
	// into an empty key list and silently wipe the loaded keys.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("fetching JWKS from %q: unexpected status %s", jwksURL, resp.Status)
	}

	var jwks JWKS
	if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
		return fmt.Errorf("decoding JWKS response: %w", err)
	}

	// Build the new set locally and publish it in one assignment so readers
	// never observe a partially filled map.
	keys := make(map[string]*rsa.PublicKey, len(jwks.Keys))
	for _, keyData := range jwks.Keys {
		var parsed struct {
			Kid string `json:"kid"`
			N   string `json:"n"`
			E   string `json:"e"`
			Kty string `json:"kty"`
		}
		if err := json.Unmarshal(keyData, &parsed); err != nil {
			continue // best effort: ignore malformed entries
		}
		if parsed.Kty != "RSA" {
			continue // only RSA keys are supported
		}
		pk, err := parseRSAPublicKey(parsed.N, parsed.E)
		if err != nil {
			continue // best effort: ignore keys with unusable parameters
		}
		keys[parsed.Kid] = pk
	}
	publicKeys = keys

	logger.Info("Loaded %d public keys.", len(publicKeys))
	return nil
}
// parseRSAPublicKey builds an RSA public key from the modulus (nStr) and
// public exponent (eStr) of a JWK, both decoded with jwt.DecodeSegment and
// interpreted as big-endian unsigned integers.
//
// It returns an error when either value cannot be decoded, when the modulus
// is empty or zero, or when the exponent is outside [2, 2^31-1]. The range
// check prevents an over-long exponent from silently overflowing into an
// unrelated value.
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
	// Upper bound chosen so the exponent always fits in an int, including on
	// 32-bit platforms.
	const maxExponent = 1<<31 - 1

	nBytes, err := jwt.DecodeSegment(nStr)
	if err != nil {
		return nil, fmt.Errorf("decoding modulus: %w", err)
	}
	eBytes, err := jwt.DecodeSegment(eStr)
	if err != nil {
		return nil, fmt.Errorf("decoding exponent: %w", err)
	}

	n := new(big.Int).SetBytes(nBytes)
	if n.Sign() == 0 {
		return nil, errors.New("rsa modulus is empty or zero")
	}

	// Going through big.Int tolerates leading zero bytes and lets the range
	// be checked before narrowing to int.
	e := new(big.Int).SetBytes(eBytes)
	if !e.IsInt64() || e.Int64() < 2 || e.Int64() > maxExponent {
		return nil, errors.New("rsa public exponent out of range")
	}

	return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil
}
// ValidateJWT verifies the bearer token carried in authHeader and returns the
// claims relevant for authorization.
//
//   - isLatestSpec: when false, only the signature/validity checks run and
//     the result carries no scopes; when true, the "aud" claim is enforced
//     and the token's scopes are returned.
//   - authHeader: the full "Authorization" header; a leading "Bearer " is
//     stripped before parsing.
//   - audience: the resource identifier the "aud" claim must equal (string
//     form) or contain (array form).
//
// The token must be signed with an RSA method and name, via its "kid" header,
// a key present in publicKeys. An error is returned for an empty token, a
// parse/verification failure, or (latest spec only) a missing, mismatching
// or oddly typed "aud" claim.
func ValidateJWT(
	isLatestSpec bool,
	authHeader, audience string,
) (*authz.TokenClaims, error) {
	rawToken := strings.TrimPrefix(authHeader, "Bearer ")
	if len(rawToken) == 0 {
		return nil, errors.New("empty bearer token")
	}

	// resolveKey selects the verification key for the token being parsed.
	resolveKey := func(t *jwt.Token) (interface{}, error) {
		if _, isRSA := t.Method.(*jwt.SigningMethodRSA); !isRSA {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		kid, hasKid := t.Header["kid"].(string)
		if !hasKid {
			return nil, errors.New("kid header not found")
		}
		if pub, known := publicKeys[kid]; known {
			return pub, nil
		}
		return nil, fmt.Errorf("key not found for kid: %s", kid)
	}

	// Parse the token and verify its signature.
	parsed, err := jwt.Parse(rawToken, resolveKey)
	switch {
	case err != nil:
		return nil, fmt.Errorf("invalid token: %w", err)
	case !parsed.Valid:
		return nil, errors.New("token not valid")
	}

	claims, isMap := parsed.Claims.(jwt.MapClaims)
	if !isMap {
		return nil, errors.New("unexpected claim type")
	}

	// Older spec: no audience enforcement, and no scopes are reported.
	if !isLatestSpec {
		return &authz.TokenClaims{Scopes: nil}, nil
	}

	// Latest spec: "aud" must be present and match the expected audience.
	audClaim, present := claims["aud"]
	if !present {
		return nil, errors.New("aud claim missing")
	}
	if single, isString := audClaim.(string); isString {
		if single != audience {
			return nil, fmt.Errorf("aud %q does not match %q", single, audience)
		}
	} else if list, isList := audClaim.([]interface{}); isList {
		matched := false
		for i := 0; i < len(list) && !matched; i++ {
			entry, isStr := list[i].(string)
			matched = isStr && entry == audience
		}
		if !matched {
			return nil, fmt.Errorf("audience %v does not include %q", list, audience)
		}
	} else {
		return nil, errors.New("aud claim has unexpected type")
	}

	// A string "scope" claim is split on whitespace; any other shape yields
	// an empty list.
	scopes := []string{}
	if scopeStr, isString := claims["scope"].(string); isString {
		scopes = strings.Fields(scopeStr)
	}
	return &authz.TokenClaims{Scopes: scopes}, nil
}
// GetRequiredScopes returns the scopes required for method, derived from
// cfg.ScopesSupported.
//
// When ScopesSupported is a non-empty map[string]string, the scope mapped to
// the full method name is returned, falling back to the scope mapped to the
// part of method before its first "/"; nil is returned when neither key is
// present. When ScopesSupported is a non-empty []string, that slice is
// returned as is. In every other case the result is an empty, non-nil slice.
func GetRequiredScopes(cfg *config.Config, method string) []string {
	switch supported := cfg.ScopesSupported.(type) {
	case map[string]string:
		if len(supported) > 0 {
			if exact, found := supported[method]; found {
				return []string{exact}
			}
			// Fall back to the segment before the first "/" (the whole
			// method when it contains no "/").
			prefix := method
			if slash := strings.Index(method, "/"); slash >= 0 {
				prefix = method[:slash]
			}
			if byPrefix, found := supported[prefix]; found {
				return []string{byPrefix}
			}
			return nil
		}
	case []string:
		if len(supported) > 0 {
			return supported
		}
	}
	return []string{}
}