diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 6424f18..f24c21d 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -92,12 +92,15 @@ func main() { os.Exit(1) } - // 5. Build the main router - mux := proxy.NewRouter(cfg, provider) + // 5. (Optional) Build the policy engine + engine := &authz.DefaulPolicyEngine{} + + // 6. Build the main router + mux := proxy.NewRouter(cfg, provider, engine) listen_address := fmt.Sprintf(":%d", cfg.ListenPort) - // 6. Start the server + // 7. Start the server srv := &http.Server{ Addr: listen_address, Handler: mux, @@ -111,18 +114,18 @@ func main() { } }() - // 7. Wait for shutdown signal + // 8. Wait for shutdown signal stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop logger.Info("Shutting down...") - // 8. First terminate subprocess if running + // 9. First terminate subprocess if running if procManager != nil && procManager.IsRunning() { procManager.Shutdown() } - // 9. Then shutdown the server + // 10. Then shutdown the server logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() diff --git a/internal/authz/default.go b/internal/authz/default.go index 929f586..f4d640d 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -94,3 +94,26 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { func (p *defaultProvider) RegisterHandler() http.HandlerFunc { return nil } + +func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + meta := map[string]interface{}{ + "resource": p.cfg.ResourceIdentifier, + "scopes_supported": p.cfg.ScopesSupported, + "authorization_servers": p.cfg.AuthorizationServers, + } + + if p.cfg.JwksURI != "" { + meta["jwks_uri"] = p.cfg.JwksURI + } + + if len(p.cfg.BearerMethodsSupported) > 0 { + meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported + } + + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, "failed to encode metadata", http.StatusInternalServerError) + } + } +} diff --git a/internal/authz/policy.go b/internal/authz/policy.go new file mode 100644 index 0000000..793e7bc --- /dev/null +++ b/internal/authz/policy.go @@ -0,0 +1,19 @@ +package authz + +import "net/http" + +type Decision int + +const ( + DecisionAllow Decision = iota + DecisionDeny +) + +type PolicyResult struct { + Decision Decision + Message string +} + +type PolicyEngine interface { + Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult +} diff --git a/internal/authz/provider.go b/internal/authz/provider.go index 1629cf4..42a8343 100644 --- a/internal/authz/provider.go +++ b/internal/authz/provider.go @@ -7,4 +7,5 @@ import "net/http" type Provider interface { WellKnownHandler() http.HandlerFunc RegisterHandler() http.HandlerFunc + ProtectedResourceMetadataHandler() http.HandlerFunc } diff --git a/internal/config/config.go b/internal/config/config.go index fc6743c..47778d0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,15 +17,15 @@ const ( // Common path configuration for all transport modes type PathsConfig struct { - SSE string `yaml:"sse"` - Messages string `yaml:"messages"` + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` } // StdioConfig contains stdio-specific configuration type StdioConfig struct { Enabled bool `yaml:"enabled"` - UserCommand string `yaml:"user_command"` // The command provided by the user - WorkDir string `yaml:"work_dir"` // Working directory (optional) + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) Args []string `yaml:"args,omitempty"` // Additional arguments Env []string `yaml:"env,omitempty"` // Environment variables } @@ -83,23 +83,31 @@ type DefaultConfig struct { } type Config struct { - AuthServerBaseURL string - ListenPort int `yaml:"listen_port"` - BaseURL string `yaml:"base_url"` - Port int `yaml:"port"` - JWKSURL string - TimeoutSeconds int `yaml:"timeout_seconds"` - PathMapping map[string]string `yaml:"path_mapping"` - Mode string `yaml:"mode"` - CORSConfig CORSConfig `yaml:"cors"` - TransportMode TransportMode `yaml:"transport_mode"` - Paths PathsConfig `yaml:"paths"` - Stdio StdioConfig `yaml:"stdio"` + AuthServerBaseURL string + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` + TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` + RequiredScopes map[string]string `yaml:"required_scopes"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"` Default DefaultConfig `yaml:"default"` + + // Protected resource metadata + ResourceIdentifier string `yaml:"resource_identifier"` + ScopesSupported map[string]string `yaml:"scopes_supported"` + AuthorizationServers []string `yaml:"authorization_servers"` + JwksURI string `yaml:"jwks_uri,omitempty"` + BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"` } // Validate checks if the config is valid based on transport mode @@ -165,12 +173,12 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } - + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } - + // Set default transport mode if not specified if cfg.TransportMode == "" { cfg.TransportMode = SSETransport // Default to SSE @@ -180,11 +188,11 @@ func LoadConfig(path string) (*Config, error) { if cfg.Port == 0 { cfg.Port = 8000 // default } - + // Validate the configuration if err := cfg.Validate(); err != nil { return nil, err } - + return &cfg, nil } diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 1e5808e..e7b1bec 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -1,7 +1,14 @@ package constants +import "time" + // Package constant provides constants for the MCP Auth Proxy const ( ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" ) + +// MCP specification version cutover date +var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC) + +const TimeLayout = "2006-01-02" diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 33a9ea3..682867e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "fmt" "net/http" "net/http/httputil" "net/url" @@ -10,14 +11,15 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) // NewRouter builds an http.ServeMux that routes // * /authorize, /token, /register, /.well-known to the provider or proxy // * MCP paths to the MCP server, etc. -func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { +func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.PolicyEngine) http.Handler { mux := http.NewServeMux() modifiers := map[string]RequestModifier{ @@ -55,6 +57,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") } + mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowed := getAllowedOrigin(origin, cfg) + if r.Method == http.MethodOptions { + addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers")) + w.WriteHeader(http.StatusNoContent) + return + } + + addCORSHeaders(w, cfg, allowed, "") + provider.ProtectedResourceMetadataHandler()(w, r) + }) + registeredPaths["/.well-known/oauth-protected-resource"] = true + defaultPaths = append(defaultPaths, "/authorize") defaultPaths = append(defaultPaths, "/token") defaultPaths = append(defaultPaths, "/register") @@ -76,7 +92,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { for _, path := range defaultPaths { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine)) registeredPaths[path] = true } } @@ -84,14 +100,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { // MCP paths mcpPaths := cfg.GetMCPPaths() for _, path := range mcpPaths { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine)) registeredPaths[path] = true } // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine)) registeredPaths[path] = true } } @@ -99,14 +115,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { return mux } -func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, policyEngine authz.PolicyEngine) http.HandlerFunc { // Parse the base URLs up front authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { logger.Error("Invalid auth server URL: %v", err) panic(err) // Fatal error that prevents startup } - + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { logger.Error("Invalid MCP server URL: %v", err) @@ -141,6 +157,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Add CORS headers to all responses addCORSHeaders(w, cfg, allowedOrigin, "") + versionRaw := r.Header.Get("MCP-Protocol-Version") + ver, err := time.Parse(constants.TimeLayout, versionRaw) + isLatestSpec := err == nil && !ver.Before(constants.SpecCutoverDate) + // Decide whether the request should go to the auth server or MCP var targetURL *url.URL isSSE := false @@ -148,13 +168,29 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) if isAuthPath(r.URL.Path) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { - // Validate JWT for MCP paths if required - // Placeholder for JWT validation logic - if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { - logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return + if ssePaths[r.URL.Path] { + if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + isSSE = true + } else { + claims, err := authorizeMCP(w, r, isLatestSpec, cfg) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + if isLatestSpec { + scope := cfg.ScopesSupported[r.URL.Path] + pr := policyEngine.Evaluate(r, claims, scope) + if pr.Decision == authz.DecisionDeny { + http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) + return + } + } } + targetURL = mcpBase if ssePaths[r.URL.Path] { isSSE = true @@ -214,7 +250,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) }, ModifyResponse: func(resp *http.Response) error { logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) - resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts + if resp.StatusCode == http.StatusUnauthorized { + resp.Header.Set( + "WWW-Authenticate", + fmt.Sprintf( + `Bearer resource_metadata="%s"`, + cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource", + )) + resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + + resp.Header.Del("Access-Control-Allow-Origin") return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { @@ -236,7 +282,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - + w.Header().Set("Content-Type", "text/event-stream") // Keep SSE connections open HandleSSE(w, r, rp) } else { @@ -248,6 +294,47 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) } } +// Check if the request is for SSE handshake and authorize it +func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error { + h := r.Header.Get("Authorization") + if !strings.HasPrefix(h, "Bearer ") { + if isLatestSpec { + realm := resourceID + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm)) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + + return fmt.Errorf("missing bearer token") + } + + return nil +} + +// Handles both v1 (just signature) and v2 (aud + scope) flows +func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) { + h := r.Header.Get("Authorization") + audience := cfg.ResourceIdentifier + if isLatestSpec { + scope := cfg.ScopesSupported[r.URL.Path] + claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, scope) + if err != nil { + realm := audience + "/.well-known/oauth-protected-resource" + w.Header().Set("WWW-Authenticate", + fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope", scope="%s"`, realm, scope)) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + return nil, err + } + return claims, nil + } + + // v1: only check signature, then continue + if err := util.ValidateJWTOld(h); err != nil { + return nil, err + } + + return &authz.TokenClaims{}, nil +} + func getAllowedOrigin(origin string, cfg *config.Config) string { if origin == "" { return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin @@ -265,6 +352,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") if requestHeaders != "" { w.Header().Set("Access-Control-Allow-Headers", requestHeaders) } else { @@ -283,6 +371,7 @@ func isAuthPath(path string) bool { "/token": true, "/register": true, "/.well-known/oauth-authorization-server": true, + "/.well-known/oauth-protected-resource": true, } if strings.HasPrefix(path, "/u/") { return true