Handle spawning subprocess within windows

This commit is contained in:
Chiran Fernando 2025-05-18 14:19:36 +05:30
parent ad5185ad72
commit 68015ae8fc
5 changed files with 167 additions and 52 deletions

4
.gitignore vendored
View file

@ -36,3 +36,7 @@ coverage.html
# IDE files # IDE files
.vscode .vscode
# node modules
node_modules
openmcpauthproxy

View file

@ -3,6 +3,8 @@ package config
import ( import (
"fmt" "fmt"
"os" "os"
"runtime"
"strings"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -145,7 +147,16 @@ func (c *Config) BuildExecCommand() string {
return "" return ""
} }
// Construct the full command
if runtime.GOOS == "windows" {
// For Windows, we need to properly escape the inner command
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
return fmt.Sprintf(
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
)
}
return fmt.Sprintf( return fmt.Sprintf(
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,

View file

@ -8,6 +8,7 @@ import (
"syscall" "syscall"
"time" "time"
"strings" "strings"
"runtime"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/logging"
@ -40,7 +41,12 @@ func EnsureDependenciesAvailable(command string) error {
// Try to install npx using npm // Try to install npx using npm
logger.Info("npx not found, attempting to install...") logger.Info("npx not found, attempting to install...")
cmd := exec.Command("npm", "install", "-g", "npx") var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.Command("npm.cmd", "install", "-g", "npx")
} else {
cmd = exec.Command("npm", "install", "-g", "npx")
}
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@ -88,8 +94,13 @@ func (m *Manager) Start(cfg *config.Config) error {
logger.Info("Starting subprocess with command: %s", execCommand) logger.Info("Starting subprocess with command: %s", execCommand)
// Use the shell to execute the command var cmd *exec.Cmd
cmd := exec.Command("sh", "-c", execCommand) if runtime.GOOS == "windows" {
// Use PowerShell on Windows for better quote handling
cmd = exec.Command("powershell", "-Command", execCommand)
} else {
cmd = exec.Command("sh", "-c", execCommand)
}
// Set working directory if specified // Set working directory if specified
if cfg.Stdio.WorkDir != "" { if cfg.Stdio.WorkDir != "" {
@ -105,8 +116,8 @@ func (m *Manager) Start(cfg *config.Config) error {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
// Set the process group for proper termination // Set platform-specific process attributes
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} setProcAttr(cmd)
// Start the process // Start the process
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
@ -117,11 +128,13 @@ func (m *Manager) Start(cfg *config.Config) error {
m.cmd = cmd m.cmd = cmd
logger.Info("Subprocess started with PID: %d", m.process.Pid) logger.Info("Subprocess started with PID: %d", m.process.Pid)
// Get and store the process group ID // Get and store the process group ID (Unix) or PID (Windows)
pgid, err := syscall.Getpgid(m.process.Pid) pgid, err := getProcessGroup(m.process.Pid)
if err == nil { if err == nil {
m.processGroup = pgid m.processGroup = pgid
if runtime.GOOS != "windows" {
logger.Debug("Process group ID: %d", m.processGroup) logger.Debug("Process group ID: %d", m.processGroup)
}
} else { } else {
logger.Warn("Failed to get process group ID: %v", err) logger.Warn("Failed to get process group ID: %v", err)
m.processGroup = m.process.Pid m.processGroup = m.process.Pid
@ -169,12 +182,36 @@ func (m *Manager) Shutdown() {
go func() { go func() {
defer close(terminateComplete) defer close(terminateComplete)
// Try graceful termination first with SIGTERM // Try graceful termination first
terminatedGracefully := false terminatedGracefully := false
if runtime.GOOS == "windows" {
// Windows: Try to terminate the process
m.mutex.Lock()
if m.process != nil {
err := m.process.Kill()
if err != nil {
logger.Warn("Failed to terminate process: %v", err)
}
}
m.mutex.Unlock()
// Wait a bit to see if it terminates
for i := 0; i < 10; i++ {
time.Sleep(200 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
terminatedGracefully = true
m.mutex.Unlock()
break
}
m.mutex.Unlock()
}
} else {
// Unix: Use SIGTERM followed by SIGKILL if necessary
// Try to terminate the process group first // Try to terminate the process group first
if processGroupToTerminate != 0 { if processGroupToTerminate != 0 {
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM)
if err != nil { if err != nil {
logger.Warn("Failed to send SIGTERM to process group: %v", err) logger.Warn("Failed to send SIGTERM to process group: %v", err)
@ -212,6 +249,7 @@ func (m *Manager) Shutdown() {
} }
m.mutex.Unlock() m.mutex.Unlock()
} }
}
if terminatedGracefully { if terminatedGracefully {
logger.Info("Subprocess terminated gracefully") logger.Info("Subprocess terminated gracefully")
@ -221,9 +259,20 @@ func (m *Manager) Shutdown() {
// If the process didn't exit gracefully, force kill // If the process didn't exit gracefully, force kill
logger.Warn("Subprocess didn't exit gracefully, forcing termination...") logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
if runtime.GOOS == "windows" {
// On Windows, Kill() is already forceful
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
} else {
// Unix: Try SIGKILL
// Try to kill the process group first // Try to kill the process group first
if processGroupToTerminate != 0 { if processGroupToTerminate != 0 {
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil {
logger.Warn("Failed to send SIGKILL to process group: %v", err) logger.Warn("Failed to send SIGKILL to process group: %v", err)
// Fallback to killing just the process // Fallback to killing just the process
@ -245,6 +294,7 @@ func (m *Manager) Shutdown() {
} }
m.mutex.Unlock() m.mutex.Unlock()
} }
}
// Wait a bit more to confirm termination // Wait a bit more to confirm termination
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)

View file

@ -0,0 +1,23 @@
//go:build !windows
package subprocess
import (
"os/exec"
"syscall"
)
// setProcAttr sets Unix-specific process attributes
func setProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}
// getProcessGroup gets the process group ID on Unix systems
func getProcessGroup(pid int) (int, error) {
return syscall.Getpgid(pid)
}
// killProcessGroup kills a process group on Unix systems
func killProcessGroup(pgid int, signal syscall.Signal) error {
return syscall.Kill(-pgid, signal)
}

View file

@ -0,0 +1,27 @@
//go:build windows
package subprocess
import (
"os/exec"
"syscall"
)
// setProcAttr sets Windows-specific process attributes
func setProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
// getProcessGroup returns the PID itself on Windows (no process groups)
func getProcessGroup(pid int) (int, error) {
return pid, nil
}
// killProcessGroup kills a process on Windows (no process groups)
func killProcessGroup(pgid int, signal syscall.Signal) error {
// On Windows, we'll use the process handle directly
// This function shouldn't be called on Windows, but we provide it for compatibility
return nil
}