From 2548eb569a76de0078eb77109d78b368e1ca0787 Mon Sep 17 00:00:00 2001 From: Chiran Fernando Date: Sat, 5 Apr 2025 08:57:33 +0530 Subject: [PATCH] Standardize logging and improve sensitive data handling --- cmd/proxy/main.go | 36 ++++++++++++++-------------- internal/authz/asgardeo.go | 21 ++++++++++------ internal/authz/default.go | 4 +++- internal/proxy/modifier.go | 7 +++++- internal/proxy/proxy.go | 24 ++++++++++--------- internal/proxy/sse.go | 13 +++++----- internal/subprocess/manager.go | 44 +++++++++++++++++----------------- internal/util/jwks.go | 6 ++--- 8 files changed, 86 insertions(+), 69 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index a61f227..d116fea 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -3,7 +3,6 @@ package main import ( "flag" "fmt" - "log" "net/http" "os" "os/signal" @@ -13,7 +12,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/constants" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" "github.com/wso2/open-mcp-auth-proxy/internal/util" @@ -30,7 +29,8 @@ func main() { // 1. Load config cfg, err := config.LoadConfig("config.yaml") if err != nil { - log.Fatalf("Error loading config: %v", err) + logger.Error("Error loading config: %v", err) + os.Exit(1) } // 2. Ensure MCPPaths includes the configured paths from the command @@ -60,9 +60,7 @@ func main() { } } - // Add the baseUrl to allowed origins if not already present - // ensureOriginInList(&cfg.CORSConfig.AllowedOrigins, "http://localhost:8080") - log.Printf("Using MCP server baseUrl: %s", baseUrl) + logger.Info("Using MCP server baseUrl: %s", baseUrl) } // 3. Start subprocess if configured @@ -70,13 +68,13 @@ func main() { if cfg.Command.Enabled && cfg.Command.UserCommand != "" { // Ensure all required dependencies are available if err := subprocess.EnsureDependenciesAvailable(cfg.Command.UserCommand); err != nil { - log.Printf("Warning: %v", err) - log.Printf("Subprocess may fail to start due to missing dependencies") + logger.Warn("%v", err) + logger.Warn("Subprocess may fail to start due to missing dependencies") } procManager = subprocess.NewManager() if err := procManager.Start(&cfg.Command); err != nil { - log.Printf("Warning: Failed to start subprocess: %v", err) + logger.Warn("Failed to start subprocess: %v", err) } } @@ -101,7 +99,8 @@ func main() { // 5. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { - log.Fatalf("Failed to fetch JWKS: %v", err) + logger.Error("Failed to fetch JWKS: %v", err) + os.Exit(1) } // 6. Build the main router @@ -116,9 +115,10 @@ func main() { } go func() { - log.Printf("Server listening on %s", listen_address) + logger.Info("Server listening on %s", listen_address) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Error("Server error: %v", err) + os.Exit(1) } }() @@ -126,7 +126,7 @@ func main() { stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop - log.Println("Shutting down...") + logger.Info("Shutting down...") // 9. First terminate subprocess if running if procManager != nil && procManager.IsRunning() { @@ -134,14 +134,14 @@ func main() { } // 10. Then shutdown the server - log.Println("Shutting down HTTP server...") + logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { - log.Printf("HTTP server shutdown error: %v", err) + logger.Error("HTTP server shutdown error: %v", err) } - log.Println("Stopped.") + logger.Info("Stopped.") } // Helper function to ensure a path is in a list @@ -154,7 +154,7 @@ func ensurePathInList(paths *[]string, path string) { } // Path doesn't exist, add it *paths = append(*paths, path) - log.Printf("Added path %s to MCPPaths", path) + logger.Info("Added path %s to MCPPaths", path) } // Helper function to ensure an origin is in a list @@ -167,5 +167,5 @@ func ensureOriginInList(origins *[]string, origin string) { } // Origin doesn't exist, add it *origins = append(*origins, origin) - log.Printf("Added %s to allowed CORS origins", origin) + logger.Info("Added %s to allowed CORS origins", origin) } \ No newline at end of file diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index 429e2e2..f0a7fac 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -7,13 +7,13 @@ import ( "encoding/json" "fmt" "io" - "log" "math/rand" "net/http" "strings" "time" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type asgardeoProvider struct { @@ -73,7 +73,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Accel-Buffering", "no") if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("[asgardeoProvider] Error encoding well-known: %v", err) + logger.Error("Error encoding well-known: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -98,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { var regReq RegisterRequest if err := json.NewDecoder(r.Body).Decode(®Req); err != nil { - log.Printf("ERROR: reading register request: %v", err) + logger.Error("Reading register request: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -112,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { regReq.ClientSecret = randomString(16) if err := p.createAsgardeoApplication(regReq); err != nil { - log.Printf("WARN: Asgardeo application creation failed: %v", err) + logger.Warn("Asgardeo application creation failed: %v", err) // Optionally http.Error(...) if you want to fail // or continue to return partial data. } @@ -130,7 +130,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { w.Header().Set("X-Accel-Buffering", "no") w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("ERROR: encoding /register response: %v", err) + logger.Error("Encoding /register response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -190,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) } - log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID) + logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) return nil } @@ -206,8 +206,11 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Sensitive data - should not be logged at INFO level auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + + logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID) tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -238,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { return "", fmt.Errorf("failed to parse token JSON: %w", err) } + // Don't log the actual token at info level, only at debug level + logger.Debug("Received access token: %s", tokenResp.AccessToken) + logger.Info("Successfully obtained admin token from Asgardeo") + return tokenResp.AccessToken, nil } @@ -328,4 +335,4 @@ func randomString(n int) string { b[i] = letters[rand.Intn(len(letters))] } return string(b) -} +} \ No newline at end of file diff --git a/internal/authz/default.go b/internal/authz/default.go index 9230d39..666b958 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type defaultProvider struct { @@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Error("Error encoding well-known response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } return @@ -91,4 +93,4 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { func (p *defaultProvider) RegisterHandler() http.HandlerFunc { return nil -} +} \ No newline at end of file diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go index 8e2268b..fe86e46 100644 --- a/internal/proxy/modifier.go +++ b/internal/proxy/modifier.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // RequestModifier modifies requests before they are proxied @@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro if strings.Contains(contentType, "application/x-www-form-urlencoded") { // Parse form data if err := req.ParseForm(); err != nil { + logger.Error("Failed to parse form data: %v", err) return nil, err } @@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro // Read body bodyBytes, err := io.ReadAll(req.Body) if err != nil { + logger.Error("Failed to read request body: %v", err) return nil, err } // Parse JSON var jsonData map[string]interface{} if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + logger.Error("Failed to parse JSON body: %v", err) return nil, err } @@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro // Marshal back to JSON modifiedBody, err := json.Marshal(jsonData) if err != nil { + logger.Error("Failed to marshal modified JSON: %v", err) return nil, err } @@ -196,4 +201,4 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro } return req, nil -} +} \ No newline at end of file diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index e02c8e5..898804a 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "log" "net/http" "net/http/httputil" "net/url" @@ -11,6 +10,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -102,11 +102,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Parse the base URLs up front authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { - log.Fatalf("Invalid auth server URL: %v", err) + logger.Error("Invalid auth server URL: %v", err) + panic(err) // Fatal error that prevents startup } mcpBase, err := url.Parse(cfg.MCPServerBaseURL) if err != nil { - log.Fatalf("Invalid MCP server URL: %v", err) + logger.Error("Invalid MCP server URL: %v", err) + panic(err) // Fatal error that prevents startup } // Detect SSE paths from config @@ -123,7 +125,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Handle OPTIONS if r.Method == http.MethodOptions { if allowedOrigin == "" { - log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) + logger.Warn("Preflight request from disallowed origin: %s", origin) http.Error(w, "CORS origin not allowed", http.StatusForbidden) return } @@ -133,7 +135,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) } if allowedOrigin == "" { - log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) + logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path) http.Error(w, "CORS origin not allowed", http.StatusForbidden) return } @@ -151,7 +153,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Validate JWT for MCP paths if required // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { - log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) + logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -169,7 +171,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) var err error r, err = modifier.ModifyRequest(r) if err != nil { - log.Printf("[proxy] Error modifying request: %v", err) + logger.Error("Error modifying request: %v", err) http.Error(w, "Bad Request", http.StatusBadRequest) return } @@ -210,15 +212,15 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Header = cleanHeaders - log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) + logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, ModifyResponse: func(resp *http.Response) error { - log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - log.Printf("[proxy] Error proxying: %v", err) + logger.Error("Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) }, FlushInterval: -1, // immediate flush for SSE @@ -253,7 +255,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin } for _, allowed := range cfg.CORSConfig.AllowedOrigins { - log.Printf("[proxy] Checking CORS origin: %s against allowed: %s", origin, allowed) + logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed) if allowed == origin { return allowed } diff --git a/internal/proxy/sse.go b/internal/proxy/sse.go index cd168ce..4b64d23 100644 --- a/internal/proxy/sse.go +++ b/internal/proxy/sse.go @@ -5,11 +5,12 @@ import ( "context" "fmt" "io" - "log" "net/http" "net/http/httputil" "strings" "time" + + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // HandleSSE sets up a go-routine to wait for context cancellation @@ -20,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy go func() { <-ctx.Done() - log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) + logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) close(done) }() @@ -57,7 +58,7 @@ func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) { return resp, nil } - log.Printf("INFO: Intercepting SSE response to modify endpoint events") + logger.Info("Intercepting SSE response to modify endpoint events") // Create a response wrapper that modifies the response body originalBody := resp.Body @@ -81,9 +82,9 @@ func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) { endpoint := strings.TrimPrefix(dataLine, "data: ") // Replace the host in the endpoint - log.Printf("DEBUG: Original endpoint: %s", endpoint) + logger.Debug("Original endpoint: %s", endpoint) endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1) - log.Printf("DEBUG: Modified endpoint: %s", endpoint) + logger.Debug("Modified endpoint: %s", endpoint) // Write the modified event lines fmt.Fprintln(pw, line) @@ -98,7 +99,7 @@ func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if err := scanner.Err(); err != nil { - log.Printf("Error reading SSE stream: %v", err) + logger.Error("Error reading SSE stream: %v", err) } }() diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go index a83e75b..e667886 100644 --- a/internal/subprocess/manager.go +++ b/internal/subprocess/manager.go @@ -1,16 +1,16 @@ package subprocess import ( - "log" + "fmt" "os" "os/exec" "sync" "syscall" "time" - "fmt" "strings" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // Manager handles starting and graceful shutdown of subprocesses @@ -39,7 +39,7 @@ func EnsureDependenciesAvailable(command string) error { } // Try to install npx using npm - log.Printf("npx not found, attempting to install...") + logger.Info("npx not found, attempting to install...") cmd := exec.Command("npm", "install", "-g", "npx") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -48,7 +48,7 @@ func EnsureDependenciesAvailable(command string) error { return fmt.Errorf("failed to install npx: %w", err) } - log.Printf("npx installed successfully") + logger.Info("npx installed successfully") } // Check if uv is needed based on the command @@ -86,7 +86,7 @@ func (m *Manager) Start(cmdConfig *config.Command) error { return nil // No command to execute } - log.Printf("Starting subprocess with command: %s", execCommand) + logger.Info("Starting subprocess with command: %s", execCommand) // Use the shell to execute the command cmd := exec.Command("sh", "-c", execCommand) @@ -115,24 +115,24 @@ func (m *Manager) Start(cmdConfig *config.Command) error { m.process = cmd.Process m.cmd = cmd - log.Printf("Subprocess started with PID: %d", m.process.Pid) + logger.Info("Subprocess started with PID: %d", m.process.Pid) // Get and store the process group ID pgid, err := syscall.Getpgid(m.process.Pid) if err == nil { m.processGroup = pgid - log.Printf("Process group ID: %d", m.processGroup) + logger.Debug("Process group ID: %d", m.processGroup) } else { - log.Printf("Warning: Failed to get process group ID: %v", err) + logger.Warn("Failed to get process group ID: %v", err) m.processGroup = m.process.Pid } // Handle process termination in background go func() { if err := cmd.Wait(); err != nil { - log.Printf("Subprocess exited with error: %v", err) + logger.Error("Subprocess exited with error: %v", err) } else { - log.Printf("Subprocess exited successfully") + logger.Info("Subprocess exited successfully") } // Clear the process reference when it exits @@ -163,7 +163,7 @@ func (m *Manager) Shutdown() { return // No process to terminate } - log.Println("Terminating subprocess...") + logger.Info("Terminating subprocess...") terminateComplete := make(chan struct{}) go func() { @@ -176,14 +176,14 @@ func (m *Manager) Shutdown() { if processGroupToTerminate != 0 { err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) if err != nil { - log.Printf("Failed to send SIGTERM to process group: %v", err) + logger.Warn("Failed to send SIGTERM to process group: %v", err) // Fallback to terminating just the process m.mutex.Lock() if m.process != nil { err = m.process.Signal(syscall.SIGTERM) if err != nil { - log.Printf("Failed to send SIGTERM to process: %v", err) + logger.Warn("Failed to send SIGTERM to process: %v", err) } } m.mutex.Unlock() @@ -194,7 +194,7 @@ func (m *Manager) Shutdown() { if m.process != nil { err := m.process.Signal(syscall.SIGTERM) if err != nil { - log.Printf("Failed to send SIGTERM to process: %v", err) + logger.Warn("Failed to send SIGTERM to process: %v", err) } } m.mutex.Unlock() @@ -214,23 +214,23 @@ func (m *Manager) Shutdown() { } if terminatedGracefully { - log.Println("Subprocess terminated gracefully") + logger.Info("Subprocess terminated gracefully") return } // If the process didn't exit gracefully, force kill - log.Println("Subprocess didn't exit gracefully, forcing termination...") + logger.Warn("Subprocess didn't exit gracefully, forcing termination...") // Try to kill the process group first if processGroupToTerminate != 0 { if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { - log.Printf("Failed to send SIGKILL to process group: %v", err) + logger.Warn("Failed to send SIGKILL to process group: %v", err) // Fallback to killing just the process m.mutex.Lock() if m.process != nil { if err := m.process.Kill(); err != nil { - log.Printf("Failed to kill process: %v", err) + logger.Error("Failed to kill process: %v", err) } } m.mutex.Unlock() @@ -240,7 +240,7 @@ func (m *Manager) Shutdown() { m.mutex.Lock() if m.process != nil { if err := m.process.Kill(); err != nil { - log.Printf("Failed to kill process: %v", err) + logger.Error("Failed to kill process: %v", err) } } m.mutex.Unlock() @@ -251,9 +251,9 @@ func (m *Manager) Shutdown() { m.mutex.Lock() if m.process == nil { - log.Println("Subprocess terminated by force") + logger.Info("Subprocess terminated by force") } else { - log.Println("Warning: Failed to terminate subprocess") + logger.Warn("Failed to terminate subprocess") } m.mutex.Unlock() }() @@ -263,6 +263,6 @@ func (m *Manager) Shutdown() { case <-terminateComplete: // Termination completed case <-time.After(m.shutdownDelay): - log.Println("Warning: Subprocess termination timed out") + logger.Warn("Subprocess termination timed out") } } \ No newline at end of file diff --git a/internal/util/jwks.go b/internal/util/jwks.go index 4832bf8..0d278a7 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -4,12 +4,12 @@ import ( "crypto/rsa" "encoding/json" "errors" - "log" "math/big" "net/http" "strings" "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type JWKS struct { @@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error { publicKeys[parsedKey.Kid] = pubKey } } - log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys)) + logger.Info("Loaded %d public keys.", len(publicKeys)) return nil } @@ -94,4 +94,4 @@ func ValidateJWT(authHeader string) error { return errors.New("invalid token: token not valid") } return nil -} +} \ No newline at end of file