From 960261fc809b377f5197a88c6db501f3bb6a06e5 Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath Date: Thu, 3 Apr 2025 02:53:14 +0530 Subject: [PATCH] fix standard auth --- cmd/proxy/main.go | 2 + config.yaml | 17 +++- internal/config/config.go | 9 +++ internal/proxy/proxy.go | 162 ++++++++++++++++++++++++++------------ 4 files changed, 137 insertions(+), 53 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index f02d9c3..c22dc96 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -29,11 +29,13 @@ func main() { // 2. Create the chosen provider var provider authz.Provider if *demoMode { + cfg.Mode = "demo" cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) fmt.Println("Using Asgardeo provider (demo).") } else if *asgardeoMode { + cfg.Mode = "asgardeo" cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2" cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) diff --git a/config.yaml b/config.yaml index 9385f58..4c6b196 100644 --- a/config.yaml +++ b/config.yaml @@ -1,7 +1,7 @@ # config.yaml auth_server_base_url: "" -mcp_server_base_url: "http://localhost:8000" +mcp_server_base_url: "" listen_address: ":8080" jwks_url: "" timeout_seconds: 10 @@ -11,6 +11,21 @@ mcp_paths: - /sse path_mapping: + /token: /oauth/token + /.well-known/oauth-authorization-server: /.well-known/openid-configuration + +cors: + allowed_origins: + - "" + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true demo: org_name: "openmcpauthdemo" diff --git a/internal/config/config.go b/internal/config/config.go index 6cdc949..f4f8218 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,13 @@ type AsgardeoConfig struct { OrgName string `yaml:"org_name"` } +type CORSConfig struct { + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + AllowCredentials bool `yaml:"allow_credentials"` +} + type Config struct { AuthServerBaseURL string `yaml:"auth_server_base_url"` MCPServerBaseURL string `yaml:"mcp_server_base_url"` @@ -27,6 +34,8 @@ type Config struct { TimeoutSeconds int `yaml:"timeout_seconds"` MCPPaths []string `yaml:"mcp_paths"` PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 382c8f3..5f74125 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,27 +20,40 @@ import ( func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { mux := http.NewServeMux() - // 1. Custom well-known - mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths := make(map[string]bool) - // 2. Registration - mux.HandleFunc("/register", provider.RegisterHandler()) + var defaultPaths []string + if cfg.Mode == "demo" || cfg.Mode == "asgardeo" { + // 1. Custom well-known + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + + // 2. Registration + mux.HandleFunc("/register", provider.RegisterHandler()) + registeredPaths["/register"] = true + + defaultPaths = []string{"/authorize", "/token"} + } else { + defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + } - // 3. Default "auth" paths, proxied - defaultPaths := []string{"/authorize", "/token"} for _, path := range defaultPaths { mux.HandleFunc(path, buildProxyHandler(cfg)) + registeredPaths[path] = true } // 4. MCP paths for _, path := range cfg.MCPPaths { mux.HandleFunc(path, buildProxyHandler(cfg)) + registeredPaths[path] = true } - // 5. If you want to map additional paths from config.PathMapping - // to the same proxy logic: + // 5. Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { - mux.HandleFunc(path, buildProxyHandler(cfg)) + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg)) + registeredPaths[path] = true + } } return mux @@ -57,13 +70,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { log.Fatalf("Invalid MCP server URL: %v", err) } - // We'll define sets for known auth paths, SSE paths, etc. - authPaths := map[string]bool{ - "/authorize": true, - "/token": true, - "/.well-known/oauth-authorization-server": true, - } - // Detect SSE paths from config ssePaths := make(map[string]bool) for _, p := range cfg.MCPPaths { @@ -73,23 +79,38 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } return func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowedOrigin := getAllowedOrigin(origin, cfg) // Handle OPTIONS if r.Method == http.MethodOptions { - addCORSHeaders(w) + if allowedOrigin == "" { + log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers")) w.WriteHeader(http.StatusNoContent) return } - addCORSHeaders(w) + if allowedOrigin == "" { + log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + + // Add CORS headers to all responses + addCORSHeaders(w, cfg, allowedOrigin, "") // Decide whether the request should go to the auth server or MCP var targetURL *url.URL isSSE := false - if authPaths[r.URL.Path] { + if isAuthPath(r.URL.Path) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { - // Validate JWT if you want + // Validate JWT for MCP paths if required + // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -100,7 +121,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { isSSE = true } } else { - // If it's not recognized as an auth path or an MCP path http.Error(w, "Forbidden", http.StatusForbidden) return } @@ -120,23 +140,42 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { req.URL.RawQuery = r.URL.RawQuery req.Host = targetURL.Host - for header, values := range r.Header { + cleanHeaders := http.Header{} + + for k, v := range r.Header { // Skip hop-by-hop headers - if strings.EqualFold(header, "Connection") || - strings.EqualFold(header, "Keep-Alive") || - strings.EqualFold(header, "Transfer-Encoding") || - strings.EqualFold(header, "Upgrade") || - strings.EqualFold(header, "Proxy-Authorization") || - strings.EqualFold(header, "Proxy-Connection") { + if skipHeader(k) { continue } - for _, value := range values { - req.Header.Set(header, value) - } + // Set only the first value to avoid duplicates + cleanHeaders.Set(k, v[0]) } + + // Override or remove sensitive headers if needed + if strings.Contains(req.URL.Path, "/token") { + cleanHeaders.Set("Accept", "application/json") + cleanHeaders.Set("Content-Type", "application/x-www-form-urlencoded") + cleanHeaders.Set("User-Agent", "GoProxy/1.0") + cleanHeaders.Del("Origin") + cleanHeaders.Del("Referer") + } + + req.Header = cleanHeaders + + // DEBUG: log headers sent to Asgardeo + log.Println("[proxy] Outgoing request headers:") + for k, v := range req.Header { + log.Printf(" %s: %s", k, strings.Join(v, ", ")) + } + log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, + ModifyResponse: func(resp *http.Response) error { + log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts + return nil + }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { log.Printf("[proxy] Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) @@ -156,12 +195,43 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } } -func addCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") +func getAllowedOrigin(origin string, cfg *config.Config) string { + if origin == "" { + return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin + } + for _, allowed := range cfg.CORSConfig.AllowedOrigins { + if allowed == origin { + return allowed + } + } + return "" } +// addCORSHeaders adds configurable CORS headers +func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) + if requestHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", requestHeaders) + } else { + w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", ")) + } + if cfg.CORSConfig.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + w.Header().Set("Vary", "Origin") +} + +func isAuthPath(path string) bool { + authPaths := map[string]bool{ + "/authorize": true, + "/token": true, + "/.well-known/oauth-authorization-server": true, + } + return authPaths[path] +} + +// isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { for _, p := range cfg.MCPPaths { if strings.HasPrefix(path, p) { @@ -171,22 +241,10 @@ func isMCPPath(path string, cfg *config.Config) bool { return false } -func copyHeaders(src http.Header, dst http.Header) { - // Exclude hop-by-hop - hopByHop := map[string]bool{ - "Connection": true, - "Keep-Alive": true, - "Transfer-Encoding": true, - "Upgrade": true, - "Proxy-Authorization": true, - "Proxy-Connection": true, - } - for k, vv := range src { - if hopByHop[strings.ToLower(k)] { - continue - } - for _, v := range vv { - dst.Add(k, v) - } +func skipHeader(h string) bool { + switch strings.ToLower(h) { + case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer": + return true } + return false }