From 68015ae8fc9897526cdc01c1c49afc0d7770c925 Mon Sep 17 00:00:00 2001 From: Chiran Fernando Date: Sun, 18 May 2025 14:19:36 +0530 Subject: [PATCH] Handle spawning subprocess within windows --- .gitignore | 4 + internal/config/config.go | 13 ++- internal/subprocess/manager.go | 152 ++++++++++++++++--------- internal/subprocess/manager_unix.go | 23 ++++ internal/subprocess/manager_windows.go | 27 +++++ 5 files changed, 167 insertions(+), 52 deletions(-) create mode 100644 internal/subprocess/manager_unix.go create mode 100644 internal/subprocess/manager_windows.go diff --git a/.gitignore b/.gitignore index d200b58..f2bcda1 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,7 @@ coverage.html # IDE files .vscode + +# node modules +node_modules +openmcpauthproxy diff --git a/internal/config/config.go b/internal/config/config.go index fc6743c..c50d9ed 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "fmt" "os" + "runtime" + "strings" "gopkg.in/yaml.v2" ) @@ -145,7 +147,16 @@ func (c *Config) BuildExecCommand() string { return "" } - // Construct the full command + + if runtime.GOOS == "windows" { + // For Windows, we need to properly escape the inner command + escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`) + return fmt.Sprintf( + `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, + escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, + ) + } + return fmt.Sprintf( `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go index fa64337..6230fe0 100644 --- a/internal/subprocess/manager.go +++ b/internal/subprocess/manager.go @@ -8,6 +8,7 @@ import ( "syscall" "time" "strings" + "runtime" "github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/logging" @@ -40,7 +41,12 @@ func EnsureDependenciesAvailable(command string) error { // Try to install npx using npm logger.Info("npx not found, attempting to install...") - cmd := exec.Command("npm", "install", "-g", "npx") + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("npm.cmd", "install", "-g", "npx") + } else { + cmd = exec.Command("npm", "install", "-g", "npx") + } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -88,8 +94,13 @@ func (m *Manager) Start(cfg *config.Config) error { logger.Info("Starting subprocess with command: %s", execCommand) - // Use the shell to execute the command - cmd := exec.Command("sh", "-c", execCommand) + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + // Use PowerShell on Windows for better quote handling + cmd = exec.Command("powershell", "-Command", execCommand) + } else { + cmd = exec.Command("sh", "-c", execCommand) + } // Set working directory if specified if cfg.Stdio.WorkDir != "" { @@ -105,8 +116,8 @@ func (m *Manager) Start(cfg *config.Config) error { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Set the process group for proper termination - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + // Set platform-specific process attributes + setProcAttr(cmd) // Start the process if err := cmd.Start(); err != nil { @@ -117,11 +128,13 @@ func (m *Manager) Start(cfg *config.Config) error { m.cmd = cmd logger.Info("Subprocess started with PID: %d", m.process.Pid) - // Get and store the process group ID - pgid, err := syscall.Getpgid(m.process.Pid) + // Get and store the process group ID (Unix) or PID (Windows) + pgid, err := getProcessGroup(m.process.Pid) if err == nil { m.processGroup = pgid - logger.Debug("Process group ID: %d", m.processGroup) + if runtime.GOOS != "windows" { + logger.Debug("Process group ID: %d", m.processGroup) + } } else { logger.Warn("Failed to get process group ID: %v", err) m.processGroup = m.process.Pid @@ -169,48 +182,73 @@ func (m *Manager) Shutdown() { go func() { defer close(terminateComplete) - // Try graceful termination first with SIGTERM + // Try graceful termination first terminatedGracefully := false - // Try to terminate the process group first - if processGroupToTerminate != 0 { - err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process group: %v", err) + if runtime.GOOS == "windows" { + // Windows: Try to terminate the process + m.mutex.Lock() + if m.process != nil { + err := m.process.Kill() + if err != nil { + logger.Warn("Failed to terminate process: %v", err) + } + } + m.mutex.Unlock() - // Fallback to terminating just the process + // Wait a bit to see if it terminates + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break + } + m.mutex.Unlock() + } + } else { + // Unix: Use SIGTERM followed by SIGKILL if necessary + // Try to terminate the process group first + if processGroupToTerminate != 0 { + err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process group: %v", err) + + // Fallback to terminating just the process + m.mutex.Lock() + if m.process != nil { + err = m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to terminate just the process m.mutex.Lock() if m.process != nil { - err = m.process.Signal(syscall.SIGTERM) + err := m.process.Signal(syscall.SIGTERM) if err != nil { logger.Warn("Failed to send SIGTERM to process: %v", err) } } m.mutex.Unlock() } - } else { - // Try to terminate just the process - m.mutex.Lock() - if m.process != nil { - err := m.process.Signal(syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process: %v", err) + + // Wait for the process to exit gracefully + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break } - } - m.mutex.Unlock() - } - - // Wait for the process to exit gracefully - for i := 0; i < 10; i++ { - time.Sleep(200 * time.Millisecond) - - m.mutex.Lock() - if m.process == nil { - terminatedGracefully = true m.mutex.Unlock() - break } - m.mutex.Unlock() } if terminatedGracefully { @@ -221,12 +259,33 @@ func (m *Manager) Shutdown() { // If the process didn't exit gracefully, force kill logger.Warn("Subprocess didn't exit gracefully, forcing termination...") - // Try to kill the process group first - if processGroupToTerminate != 0 { - if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { - logger.Warn("Failed to send SIGKILL to process group: %v", err) + if runtime.GOOS == "windows" { + // On Windows, Kill() is already forceful + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } else { + // Unix: Try SIGKILL + // Try to kill the process group first + if processGroupToTerminate != 0 { + if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil { + logger.Warn("Failed to send SIGKILL to process group: %v", err) - // Fallback to killing just the process + // Fallback to killing just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to kill just the process m.mutex.Lock() if m.process != nil { if err := m.process.Kill(); err != nil { @@ -235,15 +294,6 @@ func (m *Manager) Shutdown() { } m.mutex.Unlock() } - } else { - // Try to kill just the process - m.mutex.Lock() - if m.process != nil { - if err := m.process.Kill(); err != nil { - logger.Error("Failed to kill process: %v", err) - } - } - m.mutex.Unlock() } // Wait a bit more to confirm termination @@ -265,4 +315,4 @@ func (m *Manager) Shutdown() { case <-time.After(m.shutdownDelay): logger.Warn("Subprocess termination timed out") } -} +} \ No newline at end of file diff --git a/internal/subprocess/manager_unix.go b/internal/subprocess/manager_unix.go new file mode 100644 index 0000000..2f3dc35 --- /dev/null +++ b/internal/subprocess/manager_unix.go @@ -0,0 +1,23 @@ +//go:build !windows + +package subprocess + +import ( + "os/exec" + "syscall" +) + +// setProcAttr sets Unix-specific process attributes +func setProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} +} + +// getProcessGroup gets the process group ID on Unix systems +func getProcessGroup(pid int) (int, error) { + return syscall.Getpgid(pid) +} + +// killProcessGroup kills a process group on Unix systems +func killProcessGroup(pgid int, signal syscall.Signal) error { + return syscall.Kill(-pgid, signal) +} \ No newline at end of file diff --git a/internal/subprocess/manager_windows.go b/internal/subprocess/manager_windows.go new file mode 100644 index 0000000..30cb2c8 --- /dev/null +++ b/internal/subprocess/manager_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package subprocess + +import ( + "os/exec" + "syscall" +) + +// setProcAttr sets Windows-specific process attributes +func setProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} + +// getProcessGroup returns the PID itself on Windows (no process groups) +func getProcessGroup(pid int) (int, error) { + return pid, nil +} + +// killProcessGroup kills a process on Windows (no process groups) +func killProcessGroup(pgid int, signal syscall.Signal) error { + // On Windows, we'll use the process handle directly + // This function shouldn't be called on Windows, but we provide it for compatibility + return nil +} \ No newline at end of file