mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Standardize logging and improve sensitive data handling
This commit is contained in:
parent
d7097e76e4
commit
2548eb569a
8 changed files with 86 additions and 69 deletions
|
@ -3,7 +3,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
@ -30,7 +29,8 @@ func main() {
|
||||||
// 1. Load config
|
// 1. Load config
|
||||||
cfg, err := config.LoadConfig("config.yaml")
|
cfg, err := config.LoadConfig("config.yaml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error loading config: %v", err)
|
logger.Error("Error loading config: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Ensure MCPPaths includes the configured paths from the command
|
// 2. Ensure MCPPaths includes the configured paths from the command
|
||||||
|
@ -60,9 +60,7 @@ func main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the baseUrl to allowed origins if not already present
|
logger.Info("Using MCP server baseUrl: %s", baseUrl)
|
||||||
// ensureOriginInList(&cfg.CORSConfig.AllowedOrigins, "http://localhost:8080")
|
|
||||||
log.Printf("Using MCP server baseUrl: %s", baseUrl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Start subprocess if configured
|
// 3. Start subprocess if configured
|
||||||
|
@ -70,13 +68,13 @@ func main() {
|
||||||
if cfg.Command.Enabled && cfg.Command.UserCommand != "" {
|
if cfg.Command.Enabled && cfg.Command.UserCommand != "" {
|
||||||
// Ensure all required dependencies are available
|
// Ensure all required dependencies are available
|
||||||
if err := subprocess.EnsureDependenciesAvailable(cfg.Command.UserCommand); err != nil {
|
if err := subprocess.EnsureDependenciesAvailable(cfg.Command.UserCommand); err != nil {
|
||||||
log.Printf("Warning: %v", err)
|
logger.Warn("%v", err)
|
||||||
log.Printf("Subprocess may fail to start due to missing dependencies")
|
logger.Warn("Subprocess may fail to start due to missing dependencies")
|
||||||
}
|
}
|
||||||
|
|
||||||
procManager = subprocess.NewManager()
|
procManager = subprocess.NewManager()
|
||||||
if err := procManager.Start(&cfg.Command); err != nil {
|
if err := procManager.Start(&cfg.Command); err != nil {
|
||||||
log.Printf("Warning: Failed to start subprocess: %v", err)
|
logger.Warn("Failed to start subprocess: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +99,8 @@ func main() {
|
||||||
|
|
||||||
// 5. (Optional) Fetch JWKS if you want local JWT validation
|
// 5. (Optional) Fetch JWKS if you want local JWT validation
|
||||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||||
log.Fatalf("Failed to fetch JWKS: %v", err)
|
logger.Error("Failed to fetch JWKS: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Build the main router
|
// 6. Build the main router
|
||||||
|
@ -116,9 +115,10 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("Server listening on %s", listen_address)
|
logger.Info("Server listening on %s", listen_address)
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatalf("Server error: %v", err)
|
logger.Error("Server error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ func main() {
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||||
<-stop
|
<-stop
|
||||||
log.Println("Shutting down...")
|
logger.Info("Shutting down...")
|
||||||
|
|
||||||
// 9. First terminate subprocess if running
|
// 9. First terminate subprocess if running
|
||||||
if procManager != nil && procManager.IsRunning() {
|
if procManager != nil && procManager.IsRunning() {
|
||||||
|
@ -134,14 +134,14 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 10. Then shutdown the server
|
// 10. Then shutdown the server
|
||||||
log.Println("Shutting down HTTP server...")
|
logger.Info("Shutting down HTTP server...")
|
||||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
log.Printf("HTTP server shutdown error: %v", err)
|
logger.Error("HTTP server shutdown error: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("Stopped.")
|
logger.Info("Stopped.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to ensure a path is in a list
|
// Helper function to ensure a path is in a list
|
||||||
|
@ -154,7 +154,7 @@ func ensurePathInList(paths *[]string, path string) {
|
||||||
}
|
}
|
||||||
// Path doesn't exist, add it
|
// Path doesn't exist, add it
|
||||||
*paths = append(*paths, path)
|
*paths = append(*paths, path)
|
||||||
log.Printf("Added path %s to MCPPaths", path)
|
logger.Info("Added path %s to MCPPaths", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to ensure an origin is in a list
|
// Helper function to ensure an origin is in a list
|
||||||
|
@ -167,5 +167,5 @@ func ensureOriginInList(origins *[]string, origin string) {
|
||||||
}
|
}
|
||||||
// Origin doesn't exist, add it
|
// Origin doesn't exist, add it
|
||||||
*origins = append(*origins, origin)
|
*origins = append(*origins, origin)
|
||||||
log.Printf("Added %s to allowed CORS origins", origin)
|
logger.Info("Added %s to allowed CORS origins", origin)
|
||||||
}
|
}
|
|
@ -7,13 +7,13 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type asgardeoProvider struct {
|
type asgardeoProvider struct {
|
||||||
|
@ -73,7 +73,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
|
logger.Error("Error encoding well-known: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -98,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
|
|
||||||
var regReq RegisterRequest
|
var regReq RegisterRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
||||||
log.Printf("ERROR: reading register request: %v", err)
|
logger.Error("Reading register request: %v", err)
|
||||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -112,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
regReq.ClientSecret = randomString(16)
|
regReq.ClientSecret = randomString(16)
|
||||||
|
|
||||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||||
log.Printf("WARN: Asgardeo application creation failed: %v", err)
|
logger.Warn("Asgardeo application creation failed: %v", err)
|
||||||
// Optionally http.Error(...) if you want to fail
|
// Optionally http.Error(...) if you want to fail
|
||||||
// or continue to return partial data.
|
// or continue to return partial data.
|
||||||
}
|
}
|
||||||
|
@ -130,7 +130,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
log.Printf("ERROR: encoding /register response: %v", err)
|
logger.Error("Encoding /register response: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
|
||||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
|
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,9 +206,12 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
// Sensitive data - should not be logged at INFO level
|
||||||
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||||
|
|
||||||
|
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
@ -238,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't log the actual token at info level, only at debug level
|
||||||
|
logger.Debug("Received access token: %s", tokenResp.AccessToken)
|
||||||
|
logger.Info("Successfully obtained admin token from Asgardeo")
|
||||||
|
|
||||||
return tokenResp.AccessToken, nil
|
return tokenResp.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type defaultProvider struct {
|
type defaultProvider struct {
|
||||||
|
@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
logger.Error("Error encoding well-known response: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestModifier modifies requests before they are proxied
|
// RequestModifier modifies requests before they are proxied
|
||||||
|
@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
||||||
// Parse form data
|
// Parse form data
|
||||||
if err := req.ParseForm(); err != nil {
|
if err := req.ParseForm(); err != nil {
|
||||||
|
logger.Error("Failed to parse form data: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
// Read body
|
// Read body
|
||||||
bodyBytes, err := io.ReadAll(req.Body)
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to read request body: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var jsonData map[string]interface{}
|
var jsonData map[string]interface{}
|
||||||
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
||||||
|
logger.Error("Failed to parse JSON body: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
// Marshal back to JSON
|
// Marshal back to JSON
|
||||||
modifiedBody, err := json.Marshal(jsonData)
|
modifiedBody, err := json.Marshal(jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to marshal modified JSON: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -11,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -102,11 +102,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Parse the base URLs up front
|
// Parse the base URLs up front
|
||||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Invalid auth server URL: %v", err)
|
logger.Error("Invalid auth server URL: %v", err)
|
||||||
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
logger.Error("Invalid MCP server URL: %v", err)
|
||||||
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect SSE paths from config
|
// Detect SSE paths from config
|
||||||
|
@ -123,7 +125,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Handle OPTIONS
|
// Handle OPTIONS
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
if allowedOrigin == "" {
|
if allowedOrigin == "" {
|
||||||
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
|
logger.Warn("Preflight request from disallowed origin: %s", origin)
|
||||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -133,7 +135,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowedOrigin == "" {
|
if allowedOrigin == "" {
|
||||||
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -151,7 +153,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Validate JWT for MCP paths if required
|
// Validate JWT for MCP paths if required
|
||||||
// Placeholder for JWT validation logic
|
// Placeholder for JWT validation logic
|
||||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -169,7 +171,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
var err error
|
var err error
|
||||||
r, err = modifier.ModifyRequest(r)
|
r, err = modifier.ModifyRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[proxy] Error modifying request: %v", err)
|
logger.Error("Error modifying request: %v", err)
|
||||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -210,15 +212,15 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
|
|
||||||
req.Header = cleanHeaders
|
req.Header = cleanHeaders
|
||||||
|
|
||||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||||
},
|
},
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
log.Printf("[proxy] Error proxying: %v", err)
|
logger.Error("Error proxying: %v", err)
|
||||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||||
},
|
},
|
||||||
FlushInterval: -1, // immediate flush for SSE
|
FlushInterval: -1, // immediate flush for SSE
|
||||||
|
@ -253,7 +255,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
}
|
}
|
||||||
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||||
log.Printf("[proxy] Checking CORS origin: %s against allowed: %s", origin, allowed)
|
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
|
||||||
if allowed == origin {
|
if allowed == origin {
|
||||||
return allowed
|
return allowed
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,11 +5,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleSSE sets up a go-routine to wait for context cancellation
|
// HandleSSE sets up a go-routine to wait for context cancellation
|
||||||
|
@ -20,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -57,7 +58,7 @@ func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("INFO: Intercepting SSE response to modify endpoint events")
|
logger.Info("Intercepting SSE response to modify endpoint events")
|
||||||
|
|
||||||
// Create a response wrapper that modifies the response body
|
// Create a response wrapper that modifies the response body
|
||||||
originalBody := resp.Body
|
originalBody := resp.Body
|
||||||
|
@ -81,9 +82,9 @@ func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
endpoint := strings.TrimPrefix(dataLine, "data: ")
|
endpoint := strings.TrimPrefix(dataLine, "data: ")
|
||||||
|
|
||||||
// Replace the host in the endpoint
|
// Replace the host in the endpoint
|
||||||
log.Printf("DEBUG: Original endpoint: %s", endpoint)
|
logger.Debug("Original endpoint: %s", endpoint)
|
||||||
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
|
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
|
||||||
log.Printf("DEBUG: Modified endpoint: %s", endpoint)
|
logger.Debug("Modified endpoint: %s", endpoint)
|
||||||
|
|
||||||
// Write the modified event lines
|
// Write the modified event lines
|
||||||
fmt.Fprintln(pw, line)
|
fmt.Fprintln(pw, line)
|
||||||
|
@ -98,7 +99,7 @@ func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
log.Printf("Error reading SSE stream: %v", err)
|
logger.Error("Error reading SSE stream: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
package subprocess
|
package subprocess
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager handles starting and graceful shutdown of subprocesses
|
// Manager handles starting and graceful shutdown of subprocesses
|
||||||
|
@ -39,7 +39,7 @@ func EnsureDependenciesAvailable(command string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to install npx using npm
|
// Try to install npx using npm
|
||||||
log.Printf("npx not found, attempting to install...")
|
logger.Info("npx not found, attempting to install...")
|
||||||
cmd := exec.Command("npm", "install", "-g", "npx")
|
cmd := exec.Command("npm", "install", "-g", "npx")
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
@ -48,7 +48,7 @@ func EnsureDependenciesAvailable(command string) error {
|
||||||
return fmt.Errorf("failed to install npx: %w", err)
|
return fmt.Errorf("failed to install npx: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("npx installed successfully")
|
logger.Info("npx installed successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if uv is needed based on the command
|
// Check if uv is needed based on the command
|
||||||
|
@ -86,7 +86,7 @@ func (m *Manager) Start(cmdConfig *config.Command) error {
|
||||||
return nil // No command to execute
|
return nil // No command to execute
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Starting subprocess with command: %s", execCommand)
|
logger.Info("Starting subprocess with command: %s", execCommand)
|
||||||
|
|
||||||
// Use the shell to execute the command
|
// Use the shell to execute the command
|
||||||
cmd := exec.Command("sh", "-c", execCommand)
|
cmd := exec.Command("sh", "-c", execCommand)
|
||||||
|
@ -115,24 +115,24 @@ func (m *Manager) Start(cmdConfig *config.Command) error {
|
||||||
|
|
||||||
m.process = cmd.Process
|
m.process = cmd.Process
|
||||||
m.cmd = cmd
|
m.cmd = cmd
|
||||||
log.Printf("Subprocess started with PID: %d", m.process.Pid)
|
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
||||||
|
|
||||||
// Get and store the process group ID
|
// Get and store the process group ID
|
||||||
pgid, err := syscall.Getpgid(m.process.Pid)
|
pgid, err := syscall.Getpgid(m.process.Pid)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m.processGroup = pgid
|
m.processGroup = pgid
|
||||||
log.Printf("Process group ID: %d", m.processGroup)
|
logger.Debug("Process group ID: %d", m.processGroup)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Warning: Failed to get process group ID: %v", err)
|
logger.Warn("Failed to get process group ID: %v", err)
|
||||||
m.processGroup = m.process.Pid
|
m.processGroup = m.process.Pid
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle process termination in background
|
// Handle process termination in background
|
||||||
go func() {
|
go func() {
|
||||||
if err := cmd.Wait(); err != nil {
|
if err := cmd.Wait(); err != nil {
|
||||||
log.Printf("Subprocess exited with error: %v", err)
|
logger.Error("Subprocess exited with error: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Printf("Subprocess exited successfully")
|
logger.Info("Subprocess exited successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear the process reference when it exits
|
// Clear the process reference when it exits
|
||||||
|
@ -163,7 +163,7 @@ func (m *Manager) Shutdown() {
|
||||||
return // No process to terminate
|
return // No process to terminate
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Terminating subprocess...")
|
logger.Info("Terminating subprocess...")
|
||||||
terminateComplete := make(chan struct{})
|
terminateComplete := make(chan struct{})
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -176,14 +176,14 @@ func (m *Manager) Shutdown() {
|
||||||
if processGroupToTerminate != 0 {
|
if processGroupToTerminate != 0 {
|
||||||
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to send SIGTERM to process group: %v", err)
|
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
||||||
|
|
||||||
// Fallback to terminating just the process
|
// Fallback to terminating just the process
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if m.process != nil {
|
if m.process != nil {
|
||||||
err = m.process.Signal(syscall.SIGTERM)
|
err = m.process.Signal(syscall.SIGTERM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to send SIGTERM to process: %v", err)
|
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
@ -194,7 +194,7 @@ func (m *Manager) Shutdown() {
|
||||||
if m.process != nil {
|
if m.process != nil {
|
||||||
err := m.process.Signal(syscall.SIGTERM)
|
err := m.process.Signal(syscall.SIGTERM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to send SIGTERM to process: %v", err)
|
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
@ -214,23 +214,23 @@ func (m *Manager) Shutdown() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if terminatedGracefully {
|
if terminatedGracefully {
|
||||||
log.Println("Subprocess terminated gracefully")
|
logger.Info("Subprocess terminated gracefully")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the process didn't exit gracefully, force kill
|
// If the process didn't exit gracefully, force kill
|
||||||
log.Println("Subprocess didn't exit gracefully, forcing termination...")
|
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
||||||
|
|
||||||
// Try to kill the process group first
|
// Try to kill the process group first
|
||||||
if processGroupToTerminate != 0 {
|
if processGroupToTerminate != 0 {
|
||||||
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||||
log.Printf("Failed to send SIGKILL to process group: %v", err)
|
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
||||||
|
|
||||||
// Fallback to killing just the process
|
// Fallback to killing just the process
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if m.process != nil {
|
if m.process != nil {
|
||||||
if err := m.process.Kill(); err != nil {
|
if err := m.process.Kill(); err != nil {
|
||||||
log.Printf("Failed to kill process: %v", err)
|
logger.Error("Failed to kill process: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
@ -240,7 +240,7 @@ func (m *Manager) Shutdown() {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if m.process != nil {
|
if m.process != nil {
|
||||||
if err := m.process.Kill(); err != nil {
|
if err := m.process.Kill(); err != nil {
|
||||||
log.Printf("Failed to kill process: %v", err)
|
logger.Error("Failed to kill process: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
@ -251,9 +251,9 @@ func (m *Manager) Shutdown() {
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if m.process == nil {
|
if m.process == nil {
|
||||||
log.Println("Subprocess terminated by force")
|
logger.Info("Subprocess terminated by force")
|
||||||
} else {
|
} else {
|
||||||
log.Println("Warning: Failed to terminate subprocess")
|
logger.Warn("Failed to terminate subprocess")
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}()
|
}()
|
||||||
|
@ -263,6 +263,6 @@ func (m *Manager) Shutdown() {
|
||||||
case <-terminateComplete:
|
case <-terminateComplete:
|
||||||
// Termination completed
|
// Termination completed
|
||||||
case <-time.After(m.shutdownDelay):
|
case <-time.After(m.shutdownDelay):
|
||||||
log.Println("Warning: Subprocess termination timed out")
|
logger.Warn("Subprocess termination timed out")
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -4,12 +4,12 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWKS struct {
|
type JWKS struct {
|
||||||
|
@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
|
||||||
publicKeys[parsedKey.Kid] = pubKey
|
publicKeys[parsedKey.Kid] = pubKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
|
logger.Info("Loaded %d public keys.", len(publicKeys))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue