From 06f0aeb461238aa001e9ad92132e25dcb749e6a1 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Wed, 2 Apr 2025 22:37:01 +0530 Subject: [PATCH 1/6] add --asgardeo --- cmd/proxy/main.go | 7 +++++-- config.yaml | 5 +++++ internal/config/config.go | 9 ++++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 9a4b472..2308eee 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -17,6 +17,7 @@ import ( func main() { demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") + asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (demo).") flag.Parse() // 1. Load config @@ -32,8 +33,10 @@ func main() { cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) fmt.Println("Using Asgardeo provider (demo).") - } else { - log.Fatalf("Not supported yet.") + } else if *asgardeoMode { + cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2" + cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2/jwks" + provider = authz.NewAsgardeoProvider(cfg) } // 3. (Optional) Fetch JWKS if you want local JWT validation diff --git a/config.yaml b/config.yaml index 9725f14..9385f58 100644 --- a/config.yaml +++ b/config.yaml @@ -16,3 +16,8 @@ demo: org_name: "openmcpauthdemo" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +asgardeo: + org_name: "" + client_id: "" + client_secret: "" diff --git a/internal/config/config.go b/internal/config/config.go index 3a3b231..6cdc949 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,12 @@ type DemoConfig struct { OrgName string `yaml:"org_name"` } +type AsgardeoConfig struct { + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + OrgName string `yaml:"org_name"` +} + type Config struct { AuthServerBaseURL string `yaml:"auth_server_base_url"` MCPServerBaseURL string `yaml:"mcp_server_base_url"` @@ -23,7 +29,8 @@ type Config struct { PathMapping map[string]string `yaml:"path_mapping"` // Nested config for Asgardeo - Demo DemoConfig `yaml:"demo"` + Demo DemoConfig `yaml:"demo"` + Asgardeo AsgardeoConfig `yaml:"asgardeo"` } // LoadConfig reads a YAML config file into Config struct. From 3d085008a86865d17f6c18afaef6de4bcfa34ade Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Wed, 2 Apr 2025 22:50:33 +0530 Subject: [PATCH 2/6] fix minor issue --- README.md | 18 ++++++++++++++++-- cmd/proxy/main.go | 2 +- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c13d162..65259a2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ OpenMCPAuthProxy is a security middleware that implements the Model Context Prot ### Prerequisites - Go 1.20 or higher -- A running MCP server (SSE transport supported) ### Installation ```bash @@ -21,11 +20,26 @@ go build -o openmcpauthproxy ./cmd/proxy Create a configuration file `config.yaml` with the following parameters: +### demo mode configuration: + ```yaml mcp_server_base_url: "http://localhost:8000" # URL of your MCP server listen_address: ":8080" # Address where the proxy will listen ``` +### asgardeo configuration: + +```yaml +mcp_server_base_url: "http://localhost:8000" # URL of your MCP server +listen_address: ":8080" # Address where the proxy will listen + +asgardeo: + org_name: "your-org-name" + client_id: "your-client-id" + client_secret: "your-client-secret" + ``` + + ## Usage Example ### 1. Start the MCP Server @@ -70,7 +84,7 @@ python3 echo_server.py ./openmcpauthproxy --demo ``` -The `--demo` flag enables a demonstration mode with pre-configured authentication with [Asgardeo](https://asgardeo.io/). +The `--demo` flag enables a demonstration mode with pre-configured authentication with [Asgardeo](https://asgardeo.io/) You can also use the `--asgardeo` flag to use your own Asgardeo configuration. ### 3. Connect Using an MCP Client diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 2308eee..f02d9c3 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -17,7 +17,7 @@ import ( func main() { demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") - asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (demo).") + asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).") flag.Parse() // 1. Load config From 960261fc809b377f5197a88c6db501f3bb6a06e5 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Thu, 3 Apr 2025 02:53:14 +0530 Subject: [PATCH 3/6] fix standard auth --- cmd/proxy/main.go | 2 + config.yaml | 17 +++- internal/config/config.go | 9 +++ internal/proxy/proxy.go | 162 ++++++++++++++++++++++++++------------ 4 files changed, 137 insertions(+), 53 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index f02d9c3..c22dc96 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -29,11 +29,13 @@ func main() { // 2. Create the chosen provider var provider authz.Provider if *demoMode { + cfg.Mode = "demo" cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) fmt.Println("Using Asgardeo provider (demo).") } else if *asgardeoMode { + cfg.Mode = "asgardeo" cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2" cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) diff --git a/config.yaml b/config.yaml index 9385f58..4c6b196 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,7 @@ # config.yaml auth_server_base_url: "" -mcp_server_base_url: "http://localhost:8000" +mcp_server_base_url: "" listen_address: ":8080" jwks_url: "" timeout_seconds: 10 @@ -11,6 +11,21 @@ mcp_paths: - /sse path_mapping: + /token: /oauth/token + /.well-known/oauth-authorization-server: /.well-known/openid-configuration + +cors: + allowed_origins: + - "" + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true demo: org_name: "openmcpauthdemo" diff --git a/internal/config/config.go b/internal/config/config.go index 6cdc949..f4f8218 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,13 @@ type AsgardeoConfig struct { OrgName string `yaml:"org_name"` } +type CORSConfig struct { + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + AllowCredentials bool `yaml:"allow_credentials"` +} + type Config struct { AuthServerBaseURL string `yaml:"auth_server_base_url"` MCPServerBaseURL string `yaml:"mcp_server_base_url"` @@ -27,6 +34,8 @@ type Config struct { TimeoutSeconds int `yaml:"timeout_seconds"` MCPPaths []string `yaml:"mcp_paths"` PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 382c8f3..5f74125 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,27 +20,40 @@ import ( func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { mux := http.NewServeMux() - // 1. Custom well-known - mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths := make(map[string]bool) - // 2. Registration - mux.HandleFunc("/register", provider.RegisterHandler()) + var defaultPaths []string + if cfg.Mode == "demo" || cfg.Mode == "asgardeo" { + // 1. Custom well-known + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + + // 2. Registration + mux.HandleFunc("/register", provider.RegisterHandler()) + registeredPaths["/register"] = true + + defaultPaths = []string{"/authorize", "/token"} + } else { + defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + } - // 3. Default "auth" paths, proxied - defaultPaths := []string{"/authorize", "/token"} for _, path := range defaultPaths { mux.HandleFunc(path, buildProxyHandler(cfg)) + registeredPaths[path] = true } // 4. MCP paths for _, path := range cfg.MCPPaths { mux.HandleFunc(path, buildProxyHandler(cfg)) + registeredPaths[path] = true } - // 5. If you want to map additional paths from config.PathMapping - // to the same proxy logic: + // 5. Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { - mux.HandleFunc(path, buildProxyHandler(cfg)) + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg)) + registeredPaths[path] = true + } } return mux @@ -57,13 +70,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { log.Fatalf("Invalid MCP server URL: %v", err) } - // We'll define sets for known auth paths, SSE paths, etc. - authPaths := map[string]bool{ - "/authorize": true, - "/token": true, - "/.well-known/oauth-authorization-server": true, - } - // Detect SSE paths from config ssePaths := make(map[string]bool) for _, p := range cfg.MCPPaths { @@ -73,23 +79,38 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } return func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowedOrigin := getAllowedOrigin(origin, cfg) // Handle OPTIONS if r.Method == http.MethodOptions { - addCORSHeaders(w) + if allowedOrigin == "" { + log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers")) w.WriteHeader(http.StatusNoContent) return } - addCORSHeaders(w) + if allowedOrigin == "" { + log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + + // Add CORS headers to all responses + addCORSHeaders(w, cfg, allowedOrigin, "") // Decide whether the request should go to the auth server or MCP var targetURL *url.URL isSSE := false - if authPaths[r.URL.Path] { + if isAuthPath(r.URL.Path) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { - // Validate JWT if you want + // Validate JWT for MCP paths if required + // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -100,7 +121,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { isSSE = true } } else { - // If it's not recognized as an auth path or an MCP path http.Error(w, "Forbidden", http.StatusForbidden) return } @@ -120,23 +140,42 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { req.URL.RawQuery = r.URL.RawQuery req.Host = targetURL.Host - for header, values := range r.Header { + cleanHeaders := http.Header{} + + for k, v := range r.Header { // Skip hop-by-hop headers - if strings.EqualFold(header, "Connection") || - strings.EqualFold(header, "Keep-Alive") || - strings.EqualFold(header, "Transfer-Encoding") || - strings.EqualFold(header, "Upgrade") || - strings.EqualFold(header, "Proxy-Authorization") || - strings.EqualFold(header, "Proxy-Connection") { + if skipHeader(k) { continue } - for _, value := range values { - req.Header.Set(header, value) - } + // Set only the first value to avoid duplicates + cleanHeaders.Set(k, v[0]) } + + // Override or remove sensitive headers if needed + if strings.Contains(req.URL.Path, "/token") { + cleanHeaders.Set("Accept", "application/json") + cleanHeaders.Set("Content-Type", "application/x-www-form-urlencoded") + cleanHeaders.Set("User-Agent", "GoProxy/1.0") + cleanHeaders.Del("Origin") + cleanHeaders.Del("Referer") + } + + req.Header = cleanHeaders + + // DEBUG: log headers sent to Asgardeo + log.Println("[proxy] Outgoing request headers:") + for k, v := range req.Header { + log.Printf(" %s: %s", k, strings.Join(v, ", ")) + } + log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, + ModifyResponse: func(resp *http.Response) error { + log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts + return nil + }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { log.Printf("[proxy] Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) @@ -156,12 +195,43 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } } -func addCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") +func getAllowedOrigin(origin string, cfg *config.Config) string { + if origin == "" { + return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin + } + for _, allowed := range cfg.CORSConfig.AllowedOrigins { + if allowed == origin { + return allowed + } + } + return "" } +// addCORSHeaders adds configurable CORS headers +func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) + if requestHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", requestHeaders) + } else { + w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", ")) + } + if cfg.CORSConfig.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + w.Header().Set("Vary", "Origin") +} + +func isAuthPath(path string) bool { + authPaths := map[string]bool{ + "/authorize": true, + "/token": true, + "/.well-known/oauth-authorization-server": true, + } + return authPaths[path] +} + +// isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { for _, p := range cfg.MCPPaths { if strings.HasPrefix(path, p) { @@ -171,22 +241,10 @@ func isMCPPath(path string, cfg *config.Config) bool { return false } -func copyHeaders(src http.Header, dst http.Header) { - // Exclude hop-by-hop - hopByHop := map[string]bool{ - "Connection": true, - "Keep-Alive": true, - "Transfer-Encoding": true, - "Upgrade": true, - "Proxy-Authorization": true, - "Proxy-Connection": true, - } - for k, vv := range src { - if hopByHop[strings.ToLower(k)] { - continue - } - for _, v := range vv { - dst.Add(k, v) - } +func skipHeader(h string) bool { + switch strings.ToLower(h) { + case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer": + return true } + return false } From a6d8eecdcc6646bc953f23b05221728239ea1a59 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Thu, 3 Apr 2025 09:32:36 +0530 Subject: [PATCH 4/6] fix ListenPort --- cmd/proxy/main.go | 7 +++++-- config.yaml | 2 +- internal/config/config.go | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index c22dc96..699929f 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -49,14 +49,17 @@ func main() { // 4. Build the main router mux := proxy.NewRouter(cfg, provider) + listen_address := fmt.Sprintf(":%d", cfg.ListenPort) + // 5. Start the server srv := &http.Server{ - Addr: cfg.ListenAddress, + + Addr: listen_address, Handler: mux, } go func() { - log.Printf("Server listening on %s", cfg.ListenAddress) + log.Printf("Server listening on %s", listen_address) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Server error: %v", err) } diff --git a/config.yaml b/config.yaml index 4c6b196..ab84ad2 100644 --- a/config.yaml +++ b/config.yaml @@ -2,7 +2,7 @@ auth_server_base_url: "" mcp_server_base_url: "" -listen_address: ":8080" +listen_port: 8080 jwks_url: "" timeout_seconds: 10 diff --git a/internal/config/config.go b/internal/config/config.go index f4f8218..2f63e7c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,7 +29,7 @@ type CORSConfig struct { type Config struct { AuthServerBaseURL string `yaml:"auth_server_base_url"` MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenAddress string `yaml:"listen_address"` + ListenPort int `yaml:"listen_port"` JWKSURL string `yaml:"jwks_url"` TimeoutSeconds int `yaml:"timeout_seconds"` MCPPaths []string `yaml:"mcp_paths"` From ec2335252cbec29f280710a23769162aaa63c9c3 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Thu, 3 Apr 2025 09:34:40 +0530 Subject: [PATCH 5/6] fix ListenPort readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 65259a2..9c5b809 100644 --- a/README.md +++ b/README.md @@ -24,14 +24,14 @@ Create a configuration file `config.yaml` with the following parameters: ```yaml mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_address: ":8080" # Address where the proxy will listen +listen_port: 8080 # Port where the proxy will listen ``` ### asgardeo configuration: ```yaml mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_address: ":8080" # Address where the proxy will listen +listen_port: 8080 # Port where the proxy will listen asgardeo: org_name: "your-org-name" From d58d93d3a1a1fa0a848437a5284e7550883b5980 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Thu, 3 Apr 2025 13:51:57 +0530 Subject: [PATCH 6/6] add default mode --- cmd/proxy/main.go | 15 ++- config.yaml | 41 ++++++- internal/authz/default.go | 94 +++++++++++++++ internal/config/config.go | 42 ++++++- internal/constants/constants.go | 7 ++ internal/proxy/modifier.go | 199 ++++++++++++++++++++++++++++++++ internal/proxy/proxy.go | 90 +++++++++++---- 7 files changed, 450 insertions(+), 38 deletions(-) create mode 100644 internal/authz/default.go create mode 100644 internal/constants/constants.go create mode 100644 internal/proxy/modifier.go diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 699929f..cde3cf3 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -11,6 +11,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -30,15 +31,19 @@ func main() { var provider authz.Provider if *demoMode { cfg.Mode = "demo" - cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" - cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) - fmt.Println("Using Asgardeo provider (demo).") } else if *asgardeoMode { cfg.Mode = "asgardeo" - cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2" - cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2/jwks" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) + } else { + cfg.Mode = "default" + cfg.JWKSURL = cfg.Default.JWKSURL + cfg.AuthServerBaseURL = cfg.Default.BaseURL + provider = authz.NewDefaultProvider(cfg) } // 3. (Optional) Fetch JWKS if you want local JWT validation diff --git a/config.yaml b/config.yaml index ab84ad2..0b0ade4 100644 --- a/config.yaml +++ b/config.yaml @@ -1,9 +1,7 @@ # config.yaml -auth_server_base_url: "" mcp_server_base_url: "" listen_port: 8080 -jwks_url: "" timeout_seconds: 10 mcp_paths: @@ -11,8 +9,10 @@ mcp_paths: - /sse path_mapping: - /token: /oauth/token - /.well-known/oauth-authorization-server: /.well-known/openid-configuration + /token: /token + /register: /register + /authorize: /authorize + /.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server cors: allowed_origins: @@ -36,3 +36,36 @@ asgardeo: org_name: "" client_id: "" client_secret: "" + +default: + base_url: "" + jwks_url: "" + path: + /.well-known/oauth-authorization-server: + response: + issuer: "" + jwks_uri: "" + authorization_endpoint: "" # Optional + token_endpoint: "" # Optional + registration_endpoint: "" # Optional + response_types_supported: + - "code" + grant_types_supported: + - "authorization_code" + - "refresh_token" + code_challenge_methods_supported: + - "S256" + - "plain" + /authroize: + addQueryParams: + - name: "" + value: "" + /token: + addBodyParams: + - name: "" + value: "" + /register: + addBodyParams: + - name: "" + value: "" + diff --git a/internal/authz/default.go b/internal/authz/default.go new file mode 100644 index 0000000..9230d39 --- /dev/null +++ b/internal/authz/default.go @@ -0,0 +1,94 @@ +package authz + +import ( + "encoding/json" + "net/http" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +type defaultProvider struct { + cfg *config.Config +} + +// NewDefaultProvider initializes a Provider for Asgardeo (demo mode). +func NewDefaultProvider(cfg *config.Config) Provider { + return &defaultProvider{cfg: cfg} +} + +func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check if we have a custom response configuration + if p.cfg.Default.Path != nil { + pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && pathConfig.Response != nil { + // Use configured response values + responseConfig := pathConfig.Response + + // Get current host for proxy endpoints + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { + scheme = forwardedProto + } + host := r.Host + if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + baseURL := scheme + "://" + host + + authorizationEndpoint := responseConfig.AuthorizationEndpoint + if authorizationEndpoint == "" { + authorizationEndpoint = baseURL + "/authorize" + } + tokenEndpoint := responseConfig.TokenEndpoint + if tokenEndpoint == "" { + tokenEndpoint = baseURL + "/token" + } + registraionEndpoint := responseConfig.RegistrationEndpoint + if registraionEndpoint == "" { + registraionEndpoint = baseURL + "/register" + } + + // Build response from config + response := map[string]interface{}{ + "issuer": responseConfig.Issuer, + "authorization_endpoint": authorizationEndpoint, + "token_endpoint": tokenEndpoint, + "jwks_uri": responseConfig.JwksURI, + "response_types_supported": responseConfig.ResponseTypesSupported, + "grant_types_supported": responseConfig.GrantTypesSupported, + "token_endpoint_auth_methods_supported": []string{"client_secret_basic"}, + "registration_endpoint": registraionEndpoint, + "code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + } + return + } + } + } +} + +func (p *defaultProvider) RegisterHandler() http.HandlerFunc { + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 2f63e7c..01c3a6f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,11 +26,44 @@ type CORSConfig struct { AllowCredentials bool `yaml:"allow_credentials"` } +type ParamConfig struct { + Name string `yaml:"name"` + Value string `yaml:"value"` +} + +type ResponseConfig struct { + Issuer string `yaml:"issuer,omitempty"` + JwksURI string `yaml:"jwks_uri,omitempty"` + AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"` + TokenEndpoint string `yaml:"token_endpoint,omitempty"` + RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"` + ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"` + GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"` + CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` +} + +type PathConfig struct { + // For well-known endpoint + Response *ResponseConfig `yaml:"response,omitempty"` + + // For authorization endpoint + AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"` + + // For token and register endpoints + AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"` +} + +type DefaultConfig struct { + BaseURL string `yaml:"base_url,omitempty"` + Path map[string]PathConfig `yaml:"path,omitempty"` + JWKSURL string `yaml:"jwks_url,omitempty"` +} + type Config struct { - AuthServerBaseURL string `yaml:"auth_server_base_url"` - MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenPort int `yaml:"listen_port"` - JWKSURL string `yaml:"jwks_url"` + AuthServerBaseURL string + MCPServerBaseURL string `yaml:"mcp_server_base_url"` + ListenPort int `yaml:"listen_port"` + JWKSURL string TimeoutSeconds int `yaml:"timeout_seconds"` MCPPaths []string `yaml:"mcp_paths"` PathMapping map[string]string `yaml:"path_mapping"` @@ -40,6 +73,7 @@ type Config struct { // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"` + Default DefaultConfig `yaml:"default"` } // LoadConfig reads a YAML config file into Config struct. diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..1e5808e --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,7 @@ +package constants + +// Package constant provides constants for the MCP Auth Proxy + +const ( + ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" +) diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go new file mode 100644 index 0000000..8e2268b --- /dev/null +++ b/internal/proxy/modifier.go @@ -0,0 +1,199 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +// RequestModifier modifies requests before they are proxied +type RequestModifier interface { + ModifyRequest(req *http.Request) (*http.Request, error) +} + +// AuthorizationModifier adds parameters to authorization requests +type AuthorizationModifier struct { + Config *config.Config +} + +// TokenModifier adds parameters to token requests +type TokenModifier struct { + Config *config.Config +} + +type RegisterModifier struct { + Config *config.Config +} + +// ModifyRequest adds configured parameters to authorization requests +func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/authorize"] + if !exists || len(pathConfig.AddQueryParams) == 0 { + return req, nil + } + // Get current query parameters + query := req.URL.Query() + + // Add parameters from config + for _, param := range pathConfig.AddQueryParams { + query.Set(param.Name, param.Value) + } + + // Update the request URL + req.URL.RawQuery = query.Encode() + + return req, nil +} + +// ModifyRequest adds configured parameters to token requests +func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/token"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} + +func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/register"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5f74125..c999be4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,38 +20,77 @@ import ( func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { mux := http.NewServeMux() + modifiers := map[string]RequestModifier{ + "/authorize": &AuthorizationModifier{Config: cfg}, + "/token": &TokenModifier{Config: cfg}, + "/register": &RegisterModifier{Config: cfg}, + } + registeredPaths := make(map[string]bool) var defaultPaths []string + + // Handle based on mode configuration if cfg.Mode == "demo" || cfg.Mode == "asgardeo" { - // 1. Custom well-known + // Demo/Asgardeo mode: Custom handlers for well-known and register mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) registeredPaths["/.well-known/oauth-authorization-server"] = true - // 2. Registration mux.HandleFunc("/register", provider.RegisterHandler()) registeredPaths["/register"] = true + // Authorize and token will be proxied with parameter modification defaultPaths = []string{"/authorize", "/token"} } else { - defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + // Default provider mode + if cfg.Default.Path != nil { + // Check if we have custom response for well-known + wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && wellKnownConfig.Response != nil { + // If there's a custom response defined, use our handler + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + } else { + // No custom response, add well-known to proxy paths + defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") + } + + defaultPaths = append(defaultPaths, "/authorize") + defaultPaths = append(defaultPaths, "/token") + defaultPaths = append(defaultPaths, "/register") + } else { + defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + } } + // Remove duplicates from defaultPaths + uniquePaths := make(map[string]bool) + cleanPaths := []string{} + for _, path := range defaultPaths { + if !uniquePaths[path] { + uniquePaths[path] = true + cleanPaths = append(cleanPaths, path) + } + } + defaultPaths = cleanPaths + for _, path := range defaultPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) - registeredPaths[path] = true + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true + } } - // 4. MCP paths + // MCP paths for _, path := range cfg.MCPPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } - // 5. Register paths from PathMapping that haven't been registered yet + // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } } @@ -59,8 +98,9 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { return mux } -func buildProxyHandler(cfg *config.Config) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { // Parse the base URLs up front + authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { log.Fatalf("Invalid auth server URL: %v", err) @@ -125,6 +165,17 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { return } + // Apply request modifiers to add parameters + if modifier, exists := modifiers[r.URL.Path]; exists { + var err error + r, err = modifier.ModifyRequest(r) + if err != nil { + log.Printf("[proxy] Error modifying request: %v", err) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + } + // Build the reverse proxy rp := &httputil.ReverseProxy{ Director: func(req *http.Request) { @@ -152,23 +203,8 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { cleanHeaders.Set(k, v[0]) } - // Override or remove sensitive headers if needed - if strings.Contains(req.URL.Path, "/token") { - cleanHeaders.Set("Accept", "application/json") - cleanHeaders.Set("Content-Type", "application/x-www-form-urlencoded") - cleanHeaders.Set("User-Agent", "GoProxy/1.0") - cleanHeaders.Del("Origin") - cleanHeaders.Del("Referer") - } - req.Header = cleanHeaders - // DEBUG: log headers sent to Asgardeo - log.Println("[proxy] Outgoing request headers:") - for k, v := range req.Header { - log.Printf(" %s: %s", k, strings.Join(v, ", ")) - } - log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, ModifyResponse: func(resp *http.Response) error { @@ -226,8 +262,12 @@ func isAuthPath(path string) bool { authPaths := map[string]bool{ "/authorize": true, "/token": true, + "/register": true, "/.well-known/oauth-authorization-server": true, } + if strings.HasPrefix(path, "/u/") { + return true + } return authPaths[path] }