diff --git a/cmd/proxy/provider.go b/cmd/proxy/provider.go index be4ee21..90ef369 100644 --- a/cmd/proxy/provider.go +++ b/cmd/proxy/provider.go @@ -7,39 +7,39 @@ import ( ) func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider { - var mode, orgName string - switch { - case demoMode: - mode = "demo" - orgName = cfg.Demo.OrgName - case asgardeoMode: - mode = "asgardeo" - orgName = cfg.Asgardeo.OrgName - default: - mode = "default" - } - cfg.Mode = mode + var mode, orgName string + switch { + case demoMode: + mode = "demo" + orgName = cfg.Demo.OrgName + case asgardeoMode: + mode = "asgardeo" + orgName = cfg.Asgardeo.OrgName + default: + mode = "default" + } + cfg.Mode = mode - switch mode { - case "demo", "asgardeo": - if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" { - base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2" - cfg.AuthServerBaseURL = base - cfg.JWKSURL = base + "/jwks" - } else { - cfg.AuthServerBaseURL = cfg.AuthorizationServers[0] - cfg.JWKSURL = cfg.JwksURI - } - return authz.NewAsgardeoProvider(cfg) + switch mode { + case "demo", "asgardeo": + if len(cfg.ProtectedResourceMetadata.AuthorizationServers) == 0 && cfg.ProtectedResourceMetadata.JwksURI == "" { + base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2" + cfg.AuthServerBaseURL = base + cfg.JWKSURL = base + "/jwks" + } else { + cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0] + cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI + } + return authz.NewAsgardeoProvider(cfg) - default: - if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" { - cfg.AuthServerBaseURL = cfg.Default.BaseURL - cfg.JWKSURL = cfg.Default.JWKSURL - } else if len(cfg.AuthorizationServers) > 0 { - cfg.AuthServerBaseURL = cfg.AuthorizationServers[0] - cfg.JWKSURL = cfg.JwksURI - } - return authz.NewDefaultProvider(cfg) - } -} \ No newline at end of file + default: + if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" { + cfg.AuthServerBaseURL = cfg.Default.BaseURL + cfg.JWKSURL = cfg.Default.JWKSURL + } else if len(cfg.ProtectedResourceMetadata.AuthorizationServers) > 0 { + cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0] + cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI + } + return authz.NewDefaultProvider(cfg) + } +} diff --git a/config.yaml b/config.yaml index 88c6117..47eb8eb 100644 --- a/config.yaml +++ b/config.yaml @@ -1,9 +1,10 @@ # config.yaml # Common configuration for all transport modes +proxy_base_url: http://localhost:8080 listen_port: 8080 -base_url: "http://localhost:3001" # Base URL for the MCP server -port: 3001 # Port for the MCP server +base_url: "http://localhost:8000" # Base URL for the MCP server +port: 8000 # Port for the MCP server timeout_seconds: 10 # Path configuration @@ -17,7 +18,7 @@ transport_mode: "sse" # Options: "sse" or "stdio" # stdio-specific configuration (used only when transport_mode is "stdio") stdio: - enabled: true + enabled: false user_command: "npx -y @modelcontextprotocol/server-github" work_dir: "" # Working directory (optional) # env: # Environment variables (optional) @@ -30,6 +31,7 @@ path_mapping: cors: allowed_origins: - "http://127.0.0.1:6274" + - "http://localhost:6274" allowed_methods: - "GET" - "POST" @@ -47,17 +49,17 @@ demo: client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" -# Protected resource metadata -resource_identifier: http://localhost:3000 -audience: mcp_proxy -scopes_supported: - - "tools":"read:tools" - - "resources":"read:resources" - - "prompts":"read:prompts" -authorization_servers: - - https://api.asgardeo.io/t/acme/ -jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks -bearer_methods_supported: - - header - - body - - query +protected_resource_metadata: + resource_identifier: http://localhost:8080/sse + audience: 2xGW_poFYoObUE_vUQxvGdPSUPwa + scopes_supported: + - initialize: "mcp_init" + - tools/call: + - echo_tool: "mcp_echo_tool" + authorization_servers: + - https://api.asgardeo.io/t/openmcpauthdemo/oauth2/token + jwks_uri: https://api.asgardeo.io/t/openmcpauthdemo/oauth2/jwks + bearer_methods_supported: + - header + - body + - query diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index aa458d5..598d1ca 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -194,7 +194,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) + return fmt.Errorf("asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) } logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) @@ -367,16 +367,41 @@ func randomString(n int) string { func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + + // Extract only the values into a []string + var supportedScopes []string + var extractStrings func(interface{}) + extractStrings = func(val interface{}) { + switch v := val.(type) { + case string: + supportedScopes = append(supportedScopes, v) + case []any: + for _, item := range v { + extractStrings(item) + } + case map[string]any: + for _, item := range v { + extractStrings(item) + } + } + } + for _, m := range p.cfg.ProtectedResourceMetadata.ScopesSupported { + for _, v := range m { + extractStrings(v) + } + } + meta := map[string]interface{}{ - "resource": p.cfg.ResourceIdentifier, - "scopes_supported": p.cfg.ScopesSupported, - "authorization_servers": p.cfg.AuthorizationServers, + "resource": p.cfg.ProtectedResourceMetadata.ResourceIdentifier, + "scopes_supported": supportedScopes, + "authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers, } - if p.cfg.JwksURI != "" { - meta["jwks_uri"] = p.cfg.JwksURI + + if p.cfg.ProtectedResourceMetadata.JwksURI != "" { + meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI } - if len(p.cfg.BearerMethodsSupported) > 0 { - meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported + if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 { + meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported } if err := json.NewEncoder(w).Encode(meta); err != nil { http.Error(w, "failed to encode metadata", http.StatusInternalServerError) diff --git a/internal/authz/default.go b/internal/authz/default.go index dc8900d..8b58fa0 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type defaultProvider struct { @@ -99,18 +99,17 @@ func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") meta := map[string]interface{}{ - "audience": p.cfg.Audience, - "resource": p.cfg.ResourceIdentifier, - "scopes_supported": p.cfg.ScopesSupported, - "authorization_servers": p.cfg.AuthorizationServers, + "audience": p.cfg.ProtectedResourceMetadata.Audience, + "scopes_supported": p.cfg.ProtectedResourceMetadata.ScopesSupported, + "authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers, } - if p.cfg.JwksURI != "" { - meta["jwks_uri"] = p.cfg.JwksURI + if p.cfg.ProtectedResourceMetadata.JwksURI != "" { + meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI } - if len(p.cfg.BearerMethodsSupported) > 0 { - meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported + if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 { + meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported } if err := json.NewEncoder(w).Encode(meta); err != nil { diff --git a/internal/authz/scope_validator.go b/internal/authz/scope_validator.go index 004fd80..bf18a07 100644 --- a/internal/authz/scope_validator.go +++ b/internal/authz/scope_validator.go @@ -18,36 +18,37 @@ func (d *ScopeValidator) ValidateAccess( claims *jwt.MapClaims, config *config.Config, ) AccessControlResult { - env, err := util.ParseRPCRequest(r) - if err != nil { - return AccessControlResult{DecisionDeny, "bad JSON-RPC request"} - } - requiredScopes := util.GetRequiredScopes(config, env.Method) - if len(requiredScopes) == 0 { - return AccessControlResult{DecisionAllow, ""} - } + env, err := util.ParseRPCRequest(r) + if err != nil { + return AccessControlResult{DecisionDeny, "bad JSON-RPC request"} + } + requiredScopes := util.GetRequiredScopes(config, env) - required := make(map[string]struct{}, len(requiredScopes)) - for _, s := range requiredScopes { - s = strings.TrimSpace(s) - if s != "" { - required[s] = struct{}{} - } - } + if len(requiredScopes) == 0 { + return AccessControlResult{DecisionAllow, ""} + } - var tokenScopes []string - if claims, ok := (*claims)["scope"]; ok { - switch v := claims.(type) { - case string: - tokenScopes = strings.Fields(v) - case []interface{}: - for _, x := range v { - if s, ok := x.(string); ok && s != "" { - tokenScopes = append(tokenScopes, s) - } - } - } - } + required := make(map[string]struct{}, len(requiredScopes)) + for _, s := range requiredScopes { + s = strings.TrimSpace(s) + if s != "" { + required[s] = struct{}{} + } + } + + var tokenScopes []string + if claims, ok := (*claims)["scope"]; ok { + switch v := claims.(type) { + case string: + tokenScopes = strings.Fields(v) + case []interface{}: + for _, x := range v { + if s, ok := x.(string); ok && s != "" { + tokenScopes = append(tokenScopes, s) + } + } + } + } tokenScopeSet := make(map[string]struct{}, len(tokenScopes)) for _, s := range tokenScopes { diff --git a/internal/config/config.go b/internal/config/config.go index dea7a79..2a7958a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,8 +13,9 @@ import ( type TransportMode string const ( - SSETransport TransportMode = "sse" - StdioTransport TransportMode = "stdio" + SSETransport TransportMode = "sse" + StdioTransport TransportMode = "stdio" + StreamableHTTPTransport TransportMode = "streamable_http" ) // Common path configuration for all transport modes @@ -68,6 +69,15 @@ type ResponseConfig struct { CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` } +type ProtectedResourceMetadata struct { + ResourceIdentifier string `yaml:"resource_identifier"` + Audience string `yaml:"audience"` + ScopesSupported []map[string]interface{} `yaml:"scopes_supported"` + AuthorizationServers []string `yaml:"authorization_servers"` + JwksURI string `yaml:"jwks_uri,omitempty"` + BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"` +} + type PathConfig struct { // For well-known endpoint Response *ResponseConfig `yaml:"response,omitempty"` @@ -86,6 +96,7 @@ type DefaultConfig struct { } type Config struct { + ProxyBaseURL string `yaml:"proxy_base_url"` AuthServerBaseURL string ListenPort int `yaml:"listen_port"` BaseURL string `yaml:"base_url"` @@ -98,7 +109,6 @@ type Config struct { TransportMode TransportMode `yaml:"transport_mode"` Paths PathsConfig `yaml:"paths"` Stdio StdioConfig `yaml:"stdio"` - RequiredScopes map[string]string `yaml:"required_scopes"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` @@ -106,12 +116,7 @@ type Config struct { Default DefaultConfig `yaml:"default"` // Protected resource metadata - Audience string `yaml:"audience"` - ResourceIdentifier string `yaml:"resource_identifier"` - ScopesSupported any `yaml:"scopes_supported"` - AuthorizationServers []string `yaml:"authorization_servers"` - JwksURI string `yaml:"jwks_uri,omitempty"` - BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"` + ProtectedResourceMetadata ProtectedResourceMetadata `yaml:"protected_resource_metadata"` } // Validate checks if the config is valid based on transport mode diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index b880f99..fa72d58 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -11,7 +11,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -64,7 +64,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, accessController aut } } - mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc(getProtectedResourceMetadataEndpointPath(cfg), func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") allowed := getAllowedOrigin(origin, cfg) if r.Method == http.MethodOptions { @@ -76,7 +76,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, accessController aut addCORSHeaders(w, cfg, allowed, "") provider.ProtectedResourceMetadataHandler()(w, r) }) - registeredPaths["/.well-known/oauth-protected-resource"] = true + registeredPaths[getProtectedResourceMetadataEndpointPath(cfg)] = true // Remove duplicates from defaultPaths uniquePaths := make(map[string]bool) @@ -165,11 +165,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, var targetURL *url.URL isSSE := false - if isAuthPath(r.URL.Path) { + if isAuthPath(r.URL.Path, cfg) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { if ssePaths[r.URL.Path] { - if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil { + if err := authorizeSSE(w, r, isLatestSpec, cfg); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } @@ -245,7 +245,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, "WWW-Authenticate", fmt.Sprintf( `Bearer resource_metadata="%s"`, - cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource", + cfg.ProxyBaseURL+getProtectedResourceMetadataEndpointPath(cfg), )) resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate") } @@ -285,11 +285,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, } // Check if the request is for SSE handshake and authorize it -func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error { +func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) error { authHeader := r.Header.Get("Authorization") if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { if isLatestSpec { - realm := resourceID + "/.well-known/oauth-protected-resource" + realm := cfg.BaseURL + getProtectedResourceMetadataEndpointPath(cfg) w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm)) w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") } @@ -305,7 +305,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg accessToken, _ := util.ExtractAccessToken(authzHeader) if !strings.HasPrefix(authzHeader, "Bearer ") { if isLatestSpec { - realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" + realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg) w.Header().Set("WWW-Authenticate", fmt.Sprintf( `Bearer resource_metadata=%q`, realm, )) @@ -315,10 +315,10 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg return fmt.Errorf("missing or invalid Authorization header") } - err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience) + err := util.ValidateJWT(isLatestSpec, accessToken, cfg.ProtectedResourceMetadata.Audience) if err != nil { if isLatestSpec { - realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" + realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg) w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(), `Bearer realm=%q`, realm, @@ -343,7 +343,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg http.Error(w, "Invalid token claims", http.StatusUnauthorized) return fmt.Errorf("invalid token claims") } - + pr := accessController.ValidateAccess(r, &claimsMap, cfg) if pr.Decision == authz.DecisionDeny { http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) @@ -385,13 +385,13 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re w.Header().Set("X-Accel-Buffering", "no") } -func isAuthPath(path string) bool { +func isAuthPath(path string, cfg *config.Config) bool { authPaths := map[string]bool{ "/authorize": true, "/token": true, "/register": true, - "/.well-known/oauth-authorization-server": true, - "/.well-known/oauth-protected-resource": true, + "/.well-known/oauth-authorization-server": true, + getProtectedResourceMetadataEndpointPath(cfg): true, } if strings.HasPrefix(path, "/u/") { return true @@ -417,3 +417,17 @@ func skipHeader(h string) bool { } return false } + +func getProtectedResourceMetadataEndpointPath(cfg *config.Config) string { + + protectedResourceMetadataPath := "/.well-known/oauth-protected-resource" + + switch cfg.TransportMode { + case config.SSETransport: + protectedResourceMetadataPath += cfg.Paths.SSE + case config.StreamableHTTPTransport: + protectedResourceMetadataPath += cfg.Paths.StreamableHTTP + } + + return protectedResourceMetadataPath +} diff --git a/internal/util/jwks.go b/internal/util/jwks.go index b1afb6f..1a00d6e 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -11,7 +11,7 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type TokenClaims struct { @@ -82,7 +82,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { // ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. func ValidateJWT( isLatestSpec bool, - accessToken string, + accessToken string, audience string, ) error { logger.Warn("isLatestSpec: %s", isLatestSpec) @@ -148,46 +148,66 @@ func ValidateJWT( // Parses the JWT token and returns the claims func ParseJWT(tokenStr string) (jwt.MapClaims, error) { - if tokenStr == "" { - return nil, fmt.Errorf("empty JWT") - } + if tokenStr == "" { + return nil, fmt.Errorf("empty JWT") + } - var claims jwt.MapClaims - _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) - if err != nil { - return nil, fmt.Errorf("failed to parse JWT: %w", err) - } - return claims, nil + var claims jwt.MapClaims + _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + return claims, nil } // Process the required scopes -func GetRequiredScopes(cfg *config.Config, method string) []string { - switch raw := cfg.ScopesSupported.(type) { - case map[string]string: - if scope, ok := raw[method]; ok { - return []string{scope} - } - parts := strings.SplitN(method, "/", 2) - if len(parts) > 0 { - if scope, ok := raw[parts[0]]; ok { - return []string{scope} - } - } - return nil - case []interface{}: - out := make([]string, 0, len(raw)) - for _, v := range raw { - if s, ok := v.(string); ok && s != "" { - out = append(out, s) - } - } - return out +func GetRequiredScopes(cfg *config.Config, requestBody *RPCEnvelope) []string { - case []string: - return raw - } + var scopeObj interface{} + found := false + for _, m := range cfg.ProtectedResourceMetadata.ScopesSupported { + if val, ok := m[requestBody.Method]; ok { + scopeObj = val + found = true + break + } + } + if !found { + return nil + } - return nil + switch v := scopeObj.(type) { + case string: + return []string{v} + case []any: + if requestBody.Params != nil { + if paramsMap, ok := requestBody.Params.(map[string]any); ok { + name, ok := paramsMap["name"].(string) + if ok { + for _, item := range v { + if scopeMap, ok := item.(map[interface{}]interface{}); ok { + if scopeVal, exists := scopeMap[name]; exists { + if scopeStr, ok := scopeVal.(string); ok { + return []string{scopeStr} + } + if scopeArr, ok := scopeVal.([]any); ok { + var scopes []string + for _, s := range scopeArr { + if str, ok := s.(string); ok { + scopes = append(scopes, str) + } + } + return scopes + } + } + } + } + } + } + } + } + + return nil } // Extracts the Bearer token from the Authorization header