mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Refactor proxy builder
This commit is contained in:
parent
85e5fe1c1d
commit
331cc281c6
5 changed files with 200 additions and 35 deletions
13
config.yaml
13
config.yaml
|
@ -45,3 +45,16 @@ demo:
|
||||||
org_name: "openmcpauthdemo"
|
org_name: "openmcpauthdemo"
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
|
# Protected resource metadata
|
||||||
|
resource_identifier: http://localhost:3000
|
||||||
|
scopes_supported:
|
||||||
|
- get-alerts
|
||||||
|
- get-forecast
|
||||||
|
authorization_servers:
|
||||||
|
- https://idp.example.com
|
||||||
|
jwks_uri: https://idp.example.com/.well-known/jwks.json
|
||||||
|
bearer_methods_supported:
|
||||||
|
- header
|
||||||
|
- body
|
||||||
|
- query
|
|
@ -336,3 +336,23 @@ func randomString(n int) string {
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
meta := map[string]interface{}{
|
||||||
|
"resource": p.cfg.ResourceIdentifier,
|
||||||
|
"scopes_supported": p.cfg.ScopesSupported,
|
||||||
|
"authorization_servers": p.cfg.AuthorizationServers,
|
||||||
|
}
|
||||||
|
if p.cfg.JwksURI != "" {
|
||||||
|
meta["jwks_uri"] = p.cfg.JwksURI
|
||||||
|
}
|
||||||
|
if len(p.cfg.BearerMethodsSupported) > 0 {
|
||||||
|
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||||
|
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
20
internal/authz/default_policy_engine.go
Normal file
20
internal/authz/default_policy_engine.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenClaims struct {
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaulPolicyEngine struct{}
|
||||||
|
|
||||||
|
func (d *DefaulPolicyEngine) Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult {
|
||||||
|
for _, scope := range claims.Scopes {
|
||||||
|
if scope == requiredScope {
|
||||||
|
return PolicyResult{DecisionAllow, ""}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return PolicyResult{DecisionDeny, "missing scope '" + requiredScope + "'"}
|
||||||
|
}
|
|
@ -177,7 +177,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
} else {
|
} else {
|
||||||
claims, err := authorizeMCP(w, r, isLatestSpec, cfg)
|
claims, err := authorizeMCP(w, r, isLatestSpec, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,13 +227,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
req.Host = targetURL.Host
|
req.Host = targetURL.Host
|
||||||
|
|
||||||
cleanHeaders := http.Header{}
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
// Set proper origin header to match the target
|
// Set proper origin header to match the target
|
||||||
if isSSE {
|
if isSSE {
|
||||||
// For SSE, ensure origin matches the target
|
// For SSE, ensure origin matches the target
|
||||||
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if skipHeader(k) {
|
if skipHeader(k) {
|
||||||
|
@ -277,7 +277,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
proxyHost: r.Host,
|
proxyHost: r.Host,
|
||||||
targetHost: targetURL.Host,
|
targetHost: targetURL.Host,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set SSE-specific headers
|
// Set SSE-specific headers
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
@ -296,15 +296,14 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
|
|
||||||
// Check if the request is for SSE handshake and authorize it
|
// Check if the request is for SSE handshake and authorize it
|
||||||
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
|
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
|
||||||
h := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
if !strings.HasPrefix(h, "Bearer ") {
|
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
realm := resourceID + "/.well-known/oauth-protected-resource"
|
realm := resourceID + "/.well-known/oauth-protected-resource"
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
}
|
}
|
||||||
|
return fmt.Errorf("missing or invalid Authorization header")
|
||||||
return fmt.Errorf("missing bearer token")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -312,23 +311,31 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
|
||||||
|
|
||||||
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||||
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
||||||
|
logger.Info("authorizeMCP")
|
||||||
h := r.Header.Get("Authorization")
|
h := r.Header.Get("Authorization")
|
||||||
audience := cfg.ResourceIdentifier
|
audience := cfg.ResourceIdentifier
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
scope := cfg.ScopesSupported[r.URL.Path]
|
required := cfg.ScopesSupported[r.URL.Path]
|
||||||
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, scope)
|
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, required)
|
||||||
|
logger.Info("claims: %v", claims)
|
||||||
|
logger.Info("err: %v", err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
realm := audience + "/.well-known/oauth-protected-resource"
|
w.Header().Set(
|
||||||
w.Header().Set("WWW-Authenticate",
|
"WWW-Authenticate",
|
||||||
fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope", scope="%s"`, realm, scope))
|
fmt.Sprintf(
|
||||||
|
`Bearer realm="%s", error="insufficient_scope", scope="%s"`,
|
||||||
|
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
|
||||||
|
required,
|
||||||
|
),
|
||||||
|
)
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
return nil, err
|
return nil, fmt.Errorf("forbidden — insufficient scope")
|
||||||
}
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// v1: only check signature, then continue
|
// v1: only check signature, then continue
|
||||||
if err := util.ValidateJWTOld(h); err != nil {
|
if err := util.ValidateJWTLegacy(h); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -352,7 +359,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version")
|
||||||
if requestHeaders != "" {
|
if requestHeaders != "" {
|
||||||
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||||
} else {
|
} else {
|
||||||
|
@ -360,6 +367,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
||||||
}
|
}
|
||||||
if cfg.CORSConfig.AllowCredentials {
|
if cfg.CORSConfig.AllowCredentials {
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
w.Header().Set("MCP-Protocol-Version", ", ")
|
||||||
}
|
}
|
||||||
w.Header().Set("Vary", "Origin")
|
w.Header().Set("Vary", "Origin")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
|
|
@ -4,21 +4,29 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type TokenClaims struct {
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
type JWKS struct {
|
type JWKS struct {
|
||||||
Keys []json.RawMessage `json:"keys"`
|
Keys []json.RawMessage `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var publicKeys map[string]*rsa.PublicKey
|
var publicKeys map[string]*rsa.PublicKey
|
||||||
|
|
||||||
// FetchJWKS downloads JWKS and stores in a package-level map
|
// FetchJWKS downloads JWKS and stores in a package‐level map
|
||||||
func FetchJWKS(jwksURL string) error {
|
func FetchJWKS(jwksURL string) error {
|
||||||
resp, err := http.Get(jwksURL)
|
resp, err := http.Get(jwksURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -31,23 +39,23 @@ func FetchJWKS(jwksURL string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKeys = make(map[string]*rsa.PublicKey)
|
publicKeys = make(map[string]*rsa.PublicKey, len(jwks.Keys))
|
||||||
for _, keyData := range jwks.Keys {
|
for _, keyData := range jwks.Keys {
|
||||||
var parsedKey struct {
|
var parsed struct {
|
||||||
Kid string `json:"kid"`
|
Kid string `json:"kid"`
|
||||||
N string `json:"n"`
|
N string `json:"n"`
|
||||||
E string `json:"e"`
|
E string `json:"e"`
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
|
if err := json.Unmarshal(keyData, &parsed); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if parsedKey.Kty != "RSA" {
|
if parsed.Kty != "RSA" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
|
pk, err := parseRSAPublicKey(parsed.N, parsed.E)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
publicKeys[parsedKey.Kid] = pubKey
|
publicKeys[parsed.Kid] = pk
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.Info("Loaded %d public keys.", len(publicKeys))
|
logger.Info("Loaded %d public keys.", len(publicKeys))
|
||||||
|
@ -73,25 +81,121 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||||
return &rsa.PublicKey{N: n, E: e}, nil
|
return &rsa.PublicKey{N: n, E: e}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
|
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||||
func ValidateJWT(authHeader string) error {
|
// - versionHeader: the raw value of the "Mcp-Protocol-Version" header
|
||||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
// - authHeader: the full "Authorization" header
|
||||||
return errors.New("missing or invalid Authorization header")
|
// - audience: the resource identifier to check "aud" against
|
||||||
}
|
// - requiredScope: the single scope required (empty ⇒ skip scope check)
|
||||||
|
func ValidateJWT(
|
||||||
|
versionHeader, authHeader, audience, requiredScope string,
|
||||||
|
) (*authz.TokenClaims, error) {
|
||||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
if tokenStr == "" {
|
||||||
|
return nil, errors.New("empty bearer token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) parse & verify signature
|
||||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
kid, _ := token.Header["kid"].(string)
|
kid, _ := token.Header["kid"].(string)
|
||||||
pubKey, ok := publicKeys[kid]
|
pk, ok := publicKeys[kid]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("unknown or missing kid in token header")
|
return nil, fmt.Errorf("unknown kid %q", kid)
|
||||||
}
|
}
|
||||||
return pubKey, nil
|
return pk, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
logger.Info("token: %v", token)
|
||||||
|
logger.Info("err: %v", err)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("invalid token: " + err.Error())
|
return nil, fmt.Errorf("invalid token: %w", err)
|
||||||
}
|
}
|
||||||
if !token.Valid {
|
if !token.Valid {
|
||||||
return errors.New("invalid token: token not valid")
|
return nil, errors.New("token not valid")
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
// always extract claims
|
||||||
|
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("unexpected claim type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse version date
|
||||||
|
verDate, err := time.Parse("2006-01-02", versionHeader)
|
||||||
|
if err != nil {
|
||||||
|
// if unparsable or missing, assume _old_ spec
|
||||||
|
verDate = time.Time{} // zero time ⇒ before cutover
|
||||||
|
}
|
||||||
|
|
||||||
|
// if older than cutover, skip audience+scope
|
||||||
|
if verDate.Before(constants.SpecCutoverDate) {
|
||||||
|
return &authz.TokenClaims{Scopes: nil}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- new spec flow: enforce audience ---
|
||||||
|
audRaw, exists := claimsMap["aud"]
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.New("aud claim missing")
|
||||||
|
}
|
||||||
|
switch v := audRaw.(type) {
|
||||||
|
case string:
|
||||||
|
if v != audience {
|
||||||
|
return nil, fmt.Errorf("aud %q does not match %q", v, audience)
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
var found bool
|
||||||
|
for _, a := range v {
|
||||||
|
if s, ok := a.(string); ok && s == audience {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("audience %v does not include %q", v, audience)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, errors.New("aud claim has unexpected type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no scope required, we're done
|
||||||
|
if requiredScope == "" {
|
||||||
|
return &authz.TokenClaims{Scopes: nil}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// enforce scope
|
||||||
|
rawScope, exists := claimsMap["scope"]
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.New("scope claim missing")
|
||||||
|
}
|
||||||
|
scopeStr, ok := rawScope.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("scope claim not a string")
|
||||||
|
}
|
||||||
|
scopes := strings.Fields(scopeStr)
|
||||||
|
for _, s := range scopes {
|
||||||
|
if s == requiredScope {
|
||||||
|
return &authz.TokenClaims{Scopes: scopes}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("insufficient scope: %q not in %v", requiredScope, scopes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Performs basic JWT validation
|
||||||
|
func ValidateJWTLegacy(authHeader string) error {
|
||||||
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
_, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
kid, ok := token.Header["kid"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("kid header not found")
|
||||||
|
}
|
||||||
|
key, ok := publicKeys[kid]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
})
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue