Update MCP proxy to adhere to the latest draft of MCP specification

This commit is contained in:
NipuniBhagya 2025-05-13 23:58:06 +05:30
parent 9c2d37e2df
commit 85e5fe1c1d
7 changed files with 191 additions and 41 deletions

View file

@ -92,12 +92,15 @@ func main() {
os.Exit(1) os.Exit(1)
} }
// 5. Build the main router // 5. (Optional) Build the policy engine
mux := proxy.NewRouter(cfg, provider) engine := &authz.DefaulPolicyEngine{}
// 6. Build the main router
mux := proxy.NewRouter(cfg, provider, engine)
listen_address := fmt.Sprintf(":%d", cfg.ListenPort) listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
// 6. Start the server // 7. Start the server
srv := &http.Server{ srv := &http.Server{
Addr: listen_address, Addr: listen_address,
Handler: mux, Handler: mux,
@ -111,18 +114,18 @@ func main() {
} }
}() }()
// 7. Wait for shutdown signal // 8. Wait for shutdown signal
stop := make(chan os.Signal, 1) stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM) signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop <-stop
logger.Info("Shutting down...") logger.Info("Shutting down...")
// 8. First terminate subprocess if running // 9. First terminate subprocess if running
if procManager != nil && procManager.IsRunning() { if procManager != nil && procManager.IsRunning() {
procManager.Shutdown() procManager.Shutdown()
} }
// 9. Then shutdown the server // 10. Then shutdown the server
logger.Info("Shutting down HTTP server...") logger.Info("Shutting down HTTP server...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel() defer cancel()

View file

@ -94,3 +94,26 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
func (p *defaultProvider) RegisterHandler() http.HandlerFunc { func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
return nil return nil
} }
func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
meta := map[string]interface{}{
"resource": p.cfg.ResourceIdentifier,
"scopes_supported": p.cfg.ScopesSupported,
"authorization_servers": p.cfg.AuthorizationServers,
}
if p.cfg.JwksURI != "" {
meta["jwks_uri"] = p.cfg.JwksURI
}
if len(p.cfg.BearerMethodsSupported) > 0 {
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
}
if err := json.NewEncoder(w).Encode(meta); err != nil {
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
}
}
}

19
internal/authz/policy.go Normal file
View file

@ -0,0 +1,19 @@
package authz
import "net/http"
type Decision int
const (
DecisionAllow Decision = iota
DecisionDeny
)
type PolicyResult struct {
Decision Decision
Message string
}
type PolicyEngine interface {
Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult
}

View file

@ -7,4 +7,5 @@ import "net/http"
type Provider interface { type Provider interface {
WellKnownHandler() http.HandlerFunc WellKnownHandler() http.HandlerFunc
RegisterHandler() http.HandlerFunc RegisterHandler() http.HandlerFunc
ProtectedResourceMetadataHandler() http.HandlerFunc
} }

View file

@ -17,15 +17,15 @@ const (
// Common path configuration for all transport modes // Common path configuration for all transport modes
type PathsConfig struct { type PathsConfig struct {
SSE string `yaml:"sse"` SSE string `yaml:"sse"`
Messages string `yaml:"messages"` Messages string `yaml:"messages"`
} }
// StdioConfig contains stdio-specific configuration // StdioConfig contains stdio-specific configuration
type StdioConfig struct { type StdioConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
UserCommand string `yaml:"user_command"` // The command provided by the user UserCommand string `yaml:"user_command"` // The command provided by the user
WorkDir string `yaml:"work_dir"` // Working directory (optional) WorkDir string `yaml:"work_dir"` // Working directory (optional)
Args []string `yaml:"args,omitempty"` // Additional arguments Args []string `yaml:"args,omitempty"` // Additional arguments
Env []string `yaml:"env,omitempty"` // Environment variables Env []string `yaml:"env,omitempty"` // Environment variables
} }
@ -83,23 +83,31 @@ type DefaultConfig struct {
} }
type Config struct { type Config struct {
AuthServerBaseURL string AuthServerBaseURL string
ListenPort int `yaml:"listen_port"` ListenPort int `yaml:"listen_port"`
BaseURL string `yaml:"base_url"` BaseURL string `yaml:"base_url"`
Port int `yaml:"port"` Port int `yaml:"port"`
JWKSURL string JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"` TimeoutSeconds int `yaml:"timeout_seconds"`
PathMapping map[string]string `yaml:"path_mapping"` PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"` Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"` CORSConfig CORSConfig `yaml:"cors"`
TransportMode TransportMode `yaml:"transport_mode"` TransportMode TransportMode `yaml:"transport_mode"`
Paths PathsConfig `yaml:"paths"` Paths PathsConfig `yaml:"paths"`
Stdio StdioConfig `yaml:"stdio"` Stdio StdioConfig `yaml:"stdio"`
RequiredScopes map[string]string `yaml:"required_scopes"`
// Nested config for Asgardeo // Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"` Demo DemoConfig `yaml:"demo"`
Asgardeo AsgardeoConfig `yaml:"asgardeo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"`
Default DefaultConfig `yaml:"default"` Default DefaultConfig `yaml:"default"`
// Protected resource metadata
ResourceIdentifier string `yaml:"resource_identifier"`
ScopesSupported map[string]string `yaml:"scopes_supported"`
AuthorizationServers []string `yaml:"authorization_servers"`
JwksURI string `yaml:"jwks_uri,omitempty"`
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
} }
// Validate checks if the config is valid based on transport mode // Validate checks if the config is valid based on transport mode
@ -165,12 +173,12 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil { if err := decoder.Decode(&cfg); err != nil {
return nil, err return nil, err
} }
// Set default values // Set default values
if cfg.TimeoutSeconds == 0 { if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default cfg.TimeoutSeconds = 15 // default
} }
// Set default transport mode if not specified // Set default transport mode if not specified
if cfg.TransportMode == "" { if cfg.TransportMode == "" {
cfg.TransportMode = SSETransport // Default to SSE cfg.TransportMode = SSETransport // Default to SSE
@ -180,11 +188,11 @@ func LoadConfig(path string) (*Config, error) {
if cfg.Port == 0 { if cfg.Port == 0 {
cfg.Port = 8000 // default cfg.Port = 8000 // default
} }
// Validate the configuration // Validate the configuration
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, err return nil, err
} }
return &cfg, nil return &cfg, nil
} }

View file

@ -1,7 +1,14 @@
package constants package constants
import "time"
// Package constant provides constants for the MCP Auth Proxy // Package constant provides constants for the MCP Auth Proxy
const ( const (
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
) )
// MCP specification version cutover date
var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC)
const TimeLayout = "2006-01-02"

View file

@ -2,6 +2,7 @@ package proxy
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@ -10,14 +11,15 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/constants"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util" "github.com/wso2/open-mcp-auth-proxy/internal/util"
) )
// NewRouter builds an http.ServeMux that routes // NewRouter builds an http.ServeMux that routes
// * /authorize, /token, /register, /.well-known to the provider or proxy // * /authorize, /token, /register, /.well-known to the provider or proxy
// * MCP paths to the MCP server, etc. // * MCP paths to the MCP server, etc.
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.PolicyEngine) http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
modifiers := map[string]RequestModifier{ modifiers := map[string]RequestModifier{
@ -55,6 +57,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server")
} }
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
allowed := getAllowedOrigin(origin, cfg)
if r.Method == http.MethodOptions {
addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers"))
w.WriteHeader(http.StatusNoContent)
return
}
addCORSHeaders(w, cfg, allowed, "")
provider.ProtectedResourceMetadataHandler()(w, r)
})
registeredPaths["/.well-known/oauth-protected-resource"] = true
defaultPaths = append(defaultPaths, "/authorize") defaultPaths = append(defaultPaths, "/authorize")
defaultPaths = append(defaultPaths, "/token") defaultPaths = append(defaultPaths, "/token")
defaultPaths = append(defaultPaths, "/register") defaultPaths = append(defaultPaths, "/register")
@ -76,7 +92,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
for _, path := range defaultPaths { for _, path := range defaultPaths {
if !registeredPaths[path] { if !registeredPaths[path] {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine))
registeredPaths[path] = true registeredPaths[path] = true
} }
} }
@ -84,14 +100,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
// MCP paths // MCP paths
mcpPaths := cfg.GetMCPPaths() mcpPaths := cfg.GetMCPPaths()
for _, path := range mcpPaths { for _, path := range mcpPaths {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine))
registeredPaths[path] = true registeredPaths[path] = true
} }
// Register paths from PathMapping that haven't been registered yet // Register paths from PathMapping that haven't been registered yet
for path := range cfg.PathMapping { for path := range cfg.PathMapping {
if !registeredPaths[path] { if !registeredPaths[path] {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine))
registeredPaths[path] = true registeredPaths[path] = true
} }
} }
@ -99,14 +115,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
return mux return mux
} }
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, policyEngine authz.PolicyEngine) http.HandlerFunc {
// Parse the base URLs up front // Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL) authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil { if err != nil {
logger.Error("Invalid auth server URL: %v", err) logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup panic(err) // Fatal error that prevents startup
} }
mcpBase, err := url.Parse(cfg.BaseURL) mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil { if err != nil {
logger.Error("Invalid MCP server URL: %v", err) logger.Error("Invalid MCP server URL: %v", err)
@ -141,6 +157,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Add CORS headers to all responses // Add CORS headers to all responses
addCORSHeaders(w, cfg, allowedOrigin, "") addCORSHeaders(w, cfg, allowedOrigin, "")
versionRaw := r.Header.Get("MCP-Protocol-Version")
ver, err := time.Parse(constants.TimeLayout, versionRaw)
isLatestSpec := err == nil && !ver.Before(constants.SpecCutoverDate)
// Decide whether the request should go to the auth server or MCP // Decide whether the request should go to the auth server or MCP
var targetURL *url.URL var targetURL *url.URL
isSSE := false isSSE := false
@ -148,13 +168,29 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
if isAuthPath(r.URL.Path) { if isAuthPath(r.URL.Path) {
targetURL = authBase targetURL = authBase
} else if isMCPPath(r.URL.Path, cfg) { } else if isMCPPath(r.URL.Path, cfg) {
// Validate JWT for MCP paths if required if ssePaths[r.URL.Path] {
// Placeholder for JWT validation logic if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil {
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized)
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) return
http.Error(w, "Unauthorized", http.StatusUnauthorized) }
return isSSE = true
} else {
claims, err := authorizeMCP(w, r, isLatestSpec, cfg)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
if isLatestSpec {
scope := cfg.ScopesSupported[r.URL.Path]
pr := policyEngine.Evaluate(r, claims, scope)
if pr.Decision == authz.DecisionDeny {
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
return
}
}
} }
targetURL = mcpBase targetURL = mcpBase
if ssePaths[r.URL.Path] { if ssePaths[r.URL.Path] {
isSSE = true isSSE = true
@ -214,7 +250,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
}, },
ModifyResponse: func(resp *http.Response) error { ModifyResponse: func(resp *http.Response) error {
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts if resp.StatusCode == http.StatusUnauthorized {
resp.Header.Set(
"WWW-Authenticate",
fmt.Sprintf(
`Bearer resource_metadata="%s"`,
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
))
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
}
resp.Header.Del("Access-Control-Allow-Origin")
return nil return nil
}, },
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
@ -236,7 +282,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive") w.Header().Set("Connection", "keep-alive")
w.Header().Set("Content-Type", "text/event-stream")
// Keep SSE connections open // Keep SSE connections open
HandleSSE(w, r, rp) HandleSSE(w, r, rp)
} else { } else {
@ -248,6 +294,47 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
} }
} }
// Check if the request is for SSE handshake and authorize it
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
h := r.Header.Get("Authorization")
if !strings.HasPrefix(h, "Bearer ") {
if isLatestSpec {
realm := resourceID + "/.well-known/oauth-protected-resource"
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
}
return fmt.Errorf("missing bearer token")
}
return nil
}
// Handles both v1 (just signature) and v2 (aud + scope) flows
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
h := r.Header.Get("Authorization")
audience := cfg.ResourceIdentifier
if isLatestSpec {
scope := cfg.ScopesSupported[r.URL.Path]
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, scope)
if err != nil {
realm := audience + "/.well-known/oauth-protected-resource"
w.Header().Set("WWW-Authenticate",
fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope", scope="%s"`, realm, scope))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
return nil, err
}
return claims, nil
}
// v1: only check signature, then continue
if err := util.ValidateJWTOld(h); err != nil {
return nil, err
}
return &authz.TokenClaims{}, nil
}
func getAllowedOrigin(origin string, cfg *config.Config) string { func getAllowedOrigin(origin string, cfg *config.Config) string {
if origin == "" { if origin == "" {
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
@ -265,6 +352,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
if requestHeaders != "" { if requestHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestHeaders) w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
} else { } else {
@ -283,6 +371,7 @@ func isAuthPath(path string) bool {
"/token": true, "/token": true,
"/register": true, "/register": true,
"/.well-known/oauth-authorization-server": true, "/.well-known/oauth-authorization-server": true,
"/.well-known/oauth-protected-resource": true,
} }
if strings.HasPrefix(path, "/u/") { if strings.HasPrefix(path, "/u/") {
return true return true