mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
improve readme
This commit is contained in:
parent
7b727c03a3
commit
4e957e93a2
11 changed files with 889 additions and 1 deletions
97
internal/util/jwks.go
Normal file
97
internal/util/jwks.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
type JWKS struct {
|
||||
Keys []json.RawMessage `json:"keys"`
|
||||
}
|
||||
|
||||
var publicKeys map[string]*rsa.PublicKey
|
||||
|
||||
// FetchJWKS downloads JWKS and stores in a package-level map
|
||||
func FetchJWKS(jwksURL string) error {
|
||||
resp, err := http.Get(jwksURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var jwks JWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKeys = make(map[string]*rsa.PublicKey)
|
||||
for _, keyData := range jwks.Keys {
|
||||
var parsedKey struct {
|
||||
Kid string `json:"kid"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Kty string `json:"kty"`
|
||||
}
|
||||
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
|
||||
continue
|
||||
}
|
||||
if parsedKey.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
|
||||
if err == nil {
|
||||
publicKeys[parsedKey.Kid] = pubKey
|
||||
}
|
||||
}
|
||||
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||
nBytes, err := jwt.DecodeSegment(nStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eBytes, err := jwt.DecodeSegment(eStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
|
||||
return &rsa.PublicKey{N: n, E: e}, nil
|
||||
}
|
||||
|
||||
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
|
||||
func ValidateJWT(authHeader string) error {
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return errors.New("missing or invalid Authorization header")
|
||||
}
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
pubKey, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, errors.New("unknown or missing kid in token header")
|
||||
}
|
||||
return pubKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("invalid token: " + err.Error())
|
||||
}
|
||||
if !token.Valid {
|
||||
return errors.New("invalid token: token not valid")
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue