mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 09:05:41 +00:00
Handle spawning subprocess within windows
This commit is contained in:
parent
ad5185ad72
commit
68015ae8fc
5 changed files with 167 additions and 52 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -36,3 +36,7 @@ coverage.html
|
|||
|
||||
# IDE files
|
||||
.vscode
|
||||
|
||||
# node modules
|
||||
node_modules
|
||||
openmcpauthproxy
|
||||
|
|
|
@ -3,6 +3,8 @@ package config
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
@ -145,7 +147,16 @@ func (c *Config) BuildExecCommand() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// Construct the full command
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// For Windows, we need to properly escape the inner command
|
||||
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
|
||||
return fmt.Sprintf(
|
||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||
escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
"strings"
|
||||
"runtime"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
|
@ -40,7 +41,12 @@ func EnsureDependenciesAvailable(command string) error {
|
|||
|
||||
// Try to install npx using npm
|
||||
logger.Info("npx not found, attempting to install...")
|
||||
cmd := exec.Command("npm", "install", "-g", "npx")
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("npm.cmd", "install", "-g", "npx")
|
||||
} else {
|
||||
cmd = exec.Command("npm", "install", "-g", "npx")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
|
@ -88,8 +94,13 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||
|
||||
logger.Info("Starting subprocess with command: %s", execCommand)
|
||||
|
||||
// Use the shell to execute the command
|
||||
cmd := exec.Command("sh", "-c", execCommand)
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
// Use PowerShell on Windows for better quote handling
|
||||
cmd = exec.Command("powershell", "-Command", execCommand)
|
||||
} else {
|
||||
cmd = exec.Command("sh", "-c", execCommand)
|
||||
}
|
||||
|
||||
// Set working directory if specified
|
||||
if cfg.Stdio.WorkDir != "" {
|
||||
|
@ -105,8 +116,8 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// Set the process group for proper termination
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
// Set platform-specific process attributes
|
||||
setProcAttr(cmd)
|
||||
|
||||
// Start the process
|
||||
if err := cmd.Start(); err != nil {
|
||||
|
@ -117,11 +128,13 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||
m.cmd = cmd
|
||||
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
||||
|
||||
// Get and store the process group ID
|
||||
pgid, err := syscall.Getpgid(m.process.Pid)
|
||||
// Get and store the process group ID (Unix) or PID (Windows)
|
||||
pgid, err := getProcessGroup(m.process.Pid)
|
||||
if err == nil {
|
||||
m.processGroup = pgid
|
||||
logger.Debug("Process group ID: %d", m.processGroup)
|
||||
if runtime.GOOS != "windows" {
|
||||
logger.Debug("Process group ID: %d", m.processGroup)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Failed to get process group ID: %v", err)
|
||||
m.processGroup = m.process.Pid
|
||||
|
@ -169,48 +182,73 @@ func (m *Manager) Shutdown() {
|
|||
go func() {
|
||||
defer close(terminateComplete)
|
||||
|
||||
// Try graceful termination first with SIGTERM
|
||||
// Try graceful termination first
|
||||
terminatedGracefully := false
|
||||
|
||||
// Try to terminate the process group first
|
||||
if processGroupToTerminate != 0 {
|
||||
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows: Try to terminate the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
err := m.process.Kill()
|
||||
if err != nil {
|
||||
logger.Warn("Failed to terminate process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
// Fallback to terminating just the process
|
||||
// Wait a bit to see if it terminates
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
m.mutex.Lock()
|
||||
if m.process == nil {
|
||||
terminatedGracefully = true
|
||||
m.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
} else {
|
||||
// Unix: Use SIGTERM followed by SIGKILL if necessary
|
||||
// Try to terminate the process group first
|
||||
if processGroupToTerminate != 0 {
|
||||
err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
||||
|
||||
// Fallback to terminating just the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
err = m.process.Signal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
} else {
|
||||
// Try to terminate just the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
err = m.process.Signal(syscall.SIGTERM)
|
||||
err := m.process.Signal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
} else {
|
||||
// Try to terminate just the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
err := m.process.Signal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||
|
||||
// Wait for the process to exit gracefully
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
m.mutex.Lock()
|
||||
if m.process == nil {
|
||||
terminatedGracefully = true
|
||||
m.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait for the process to exit gracefully
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
m.mutex.Lock()
|
||||
if m.process == nil {
|
||||
terminatedGracefully = true
|
||||
m.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
if terminatedGracefully {
|
||||
|
@ -221,12 +259,33 @@ func (m *Manager) Shutdown() {
|
|||
// If the process didn't exit gracefully, force kill
|
||||
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
||||
|
||||
// Try to kill the process group first
|
||||
if processGroupToTerminate != 0 {
|
||||
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
||||
if runtime.GOOS == "windows" {
|
||||
// On Windows, Kill() is already forceful
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
if err := m.process.Kill(); err != nil {
|
||||
logger.Error("Failed to kill process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
} else {
|
||||
// Unix: Try SIGKILL
|
||||
// Try to kill the process group first
|
||||
if processGroupToTerminate != 0 {
|
||||
if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
||||
|
||||
// Fallback to killing just the process
|
||||
// Fallback to killing just the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
if err := m.process.Kill(); err != nil {
|
||||
logger.Error("Failed to kill process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
} else {
|
||||
// Try to kill just the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
if err := m.process.Kill(); err != nil {
|
||||
|
@ -235,15 +294,6 @@ func (m *Manager) Shutdown() {
|
|||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
} else {
|
||||
// Try to kill just the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
if err := m.process.Kill(); err != nil {
|
||||
logger.Error("Failed to kill process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait a bit more to confirm termination
|
||||
|
|
23
internal/subprocess/manager_unix.go
Normal file
23
internal/subprocess/manager_unix.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
//go:build !windows
|
||||
|
||||
package subprocess
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttr sets Unix-specific process attributes
|
||||
func setProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
// getProcessGroup gets the process group ID on Unix systems
|
||||
func getProcessGroup(pid int) (int, error) {
|
||||
return syscall.Getpgid(pid)
|
||||
}
|
||||
|
||||
// killProcessGroup kills a process group on Unix systems
|
||||
func killProcessGroup(pgid int, signal syscall.Signal) error {
|
||||
return syscall.Kill(-pgid, signal)
|
||||
}
|
27
internal/subprocess/manager_windows.go
Normal file
27
internal/subprocess/manager_windows.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
//go:build windows
|
||||
|
||||
package subprocess
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttr sets Windows-specific process attributes
|
||||
func setProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||
}
|
||||
}
|
||||
|
||||
// getProcessGroup returns the PID itself on Windows (no process groups)
|
||||
func getProcessGroup(pid int) (int, error) {
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// killProcessGroup kills a process on Windows (no process groups)
|
||||
func killProcessGroup(pgid int, signal syscall.Signal) error {
|
||||
// On Windows, we'll use the process handle directly
|
||||
// This function shouldn't be called on Windows, but we provide it for compatibility
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue