diff --git a/.github/scripts/release.sh b/.github/scripts/release.sh index 52b024d..2a1f6a9 100644 --- a/.github/scripts/release.sh +++ b/.github/scripts/release.sh @@ -51,7 +51,7 @@ else fi # Extract current version. -CURRENT_VERSION=$(git tag --sort=-v:refname | head -n 1 | sed 's/^v//' || echo "0.0.0") +CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0") IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}" # Determine which part to increment diff --git a/.gitignore b/.gitignore index f2bcda1..d200b58 100644 --- a/.gitignore +++ b/.gitignore @@ -36,7 +36,3 @@ coverage.html # IDE files .vscode - -# node modules -node_modules -openmcpauthproxy diff --git a/Makefile b/Makefile index 3c0c590..b0d0926 100644 --- a/Makefile +++ b/Makefile @@ -24,9 +24,9 @@ TEST_OPTS := -v -race .PHONY: all clean test fmt lint vet coverage help # Default target -all: lint test build-linux build-linux-arm build-darwin build-windows +all: lint test build-linux build-linux-arm build-darwin -build: clean test build-linux build-linux-arm build-darwin build-windows +build: clean test build-linux build-linux-arm build-darwin build-linux: mkdir -p $(BUILD_DIR)/linux @@ -46,12 +46,6 @@ build-darwin: -o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy cp config.yaml $(BUILD_DIR)/darwin -build-windows: - mkdir -p $(BUILD_DIR)/windows - GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ - -o $(BUILD_DIR)/windows/openmcpauthproxy.exe ./cmd/proxy - cp config.yaml $(BUILD_DIR)/windows - # Clean build artifacts clean: @echo "Cleaning build artifacts..." diff --git a/README.md b/README.md index 4694164..71b4b60 100644 --- a/README.md +++ b/README.md @@ -47,30 +47,22 @@ Open MCP Auth Proxy sits between MCP clients and your MCP server to: ### Basic Usage -1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest). +1. The repository comes with a default `config.yaml` file that contains the basic configuration: + +```yaml +listen_port: 8080 +base_url: "http://localhost:8000" # Your MCP server URL +paths: + sse: "/sse" + messages: "/messages/" +``` 2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox): -#### Linux/macOS: ```bash ./openmcpauthproxy --demo ``` -#### Windows: -```powershell -.\openmcpauthproxy.exe --demo -``` - -> The repository comes with a default `config.yaml` file that contains the basic configuration: -> -> ```yaml -> listen_port: 8080 -> base_url: "http://localhost:8000" # Your MCP server URL -> paths: -> sse: "/sse" -> messages: "/messages/" -> ``` - 3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) ## Connect an Identity Provider @@ -221,102 +213,12 @@ asgardeo: client_id: "" client_secret: "" ``` -## Build from Source -### Prerequisites - -* Go 1.20 or higher -* Git -* Make (optional, for simplified builds) - -### Clone and Build - -1. **Clone the repository:** - ```bash - git clone https://github.com/wso2/open-mcp-auth-proxy - cd open-mcp-auth-proxy - ``` - -2. **Install dependencies:** - ```bash - go get -v -t -d ./... - ``` - -3. **Build the application:** - - **Option A: Using Make** - ```bash - # Build for all platforms - make all - - # Or build for specific platforms - make build-linux # For Linux (x86_64) - make build-linux-arm # For ARM-based Linux - make build-darwin # For macOS - make build-windows # For Windows - ``` - - **Option B: Manual build (works on all platforms)** - ```bash - # Build for your current platform - go build -o openmcpauthproxy ./cmd/proxy - - # Cross-compile for other platforms - GOOS=linux GOARCH=amd64 go build -o openmcpauthproxy-linux ./cmd/proxy - GOOS=windows GOARCH=amd64 go build -o openmcpauthproxy.exe ./cmd/proxy - GOOS=darwin GOARCH=amd64 go build -o openmcpauthproxy-macos ./cmd/proxy - ``` - -### Run the Built Application - -After building, you'll find the executables in the `build` directory (when using Make) or in your project root (when building manually). - -**Linux/macOS:** -```bash -# If built with Make -./build/linux/openmcpauthproxy --demo - -# If built manually -./openmcpauthproxy --demo -``` - -**Windows:** -```powershell -# If built with Make -.\build\windows\openmcpauthproxy.exe --demo - -# If built manually -.\openmcpauthproxy.exe --demo -``` - -### Available Command Line Options +### Build from source ```bash -# Start in demo mode (using Asgardeo sandbox) -./openmcpauthproxy --demo - -# Start with your own Asgardeo organization -./openmcpauthproxy --asgardeo - -# Use stdio transport mode instead of SSE -./openmcpauthproxy --demo --stdio - -# Enable debug logging -./openmcpauthproxy --demo --debug - -# Show all available options -./openmcpauthproxy --help -``` - -### Additional Make Targets - -If you're using Make, these additional targets are available: - -```bash -make test # Run tests -make coverage # Run tests with coverage report -make fmt # Format code with gofmt -make vet # Run go vet -make clean # Clean build artifacts -make help # Show all available targets +git clone https://github.com/wso2/open-mcp-auth-proxy +cd open-mcp-auth-proxy +go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2 +go build -o openmcpauthproxy ./cmd/proxy ``` diff --git a/config.yaml b/config.yaml index 427fc15..5621195 100644 --- a/config.yaml +++ b/config.yaml @@ -2,15 +2,14 @@ # Common configuration for all transport modes listen_port: 8080 -base_url: "http://localhost:3001" # Base URL for the MCP server -port: 3001 # Port for the MCP server +base_url: "http://localhost:8000" # Base URL for the MCP server +port: 8000 # Port for the MCP server timeout_seconds: 10 # Path configuration paths: sse: "/sse" # SSE endpoint path messages: "/messages/" # Messages endpoint path - streamable_http: "/mcp" # MCP endpoint path # Transport mode configuration transport_mode: "sse" # Options: "sse" or "stdio" @@ -29,7 +28,7 @@ path_mapping: # CORS configuration cors: allowed_origins: - - "http://127.0.0.1:6274" + - "http://localhost:5173" allowed_methods: - "GET" - "POST" diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index eb26b83..a3c812c 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -13,14 +13,14 @@ import ( "time" "github.com/wso2/open-mcp-auth-proxy/internal/config" - logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type asgardeoProvider struct { cfg *config.Config } -// NewAsgardeoProvider initializes a Provider for Asgardeo. +// NewAsgardeoProvider initializes a Provider for Asgardeo (demo mode). func NewAsgardeoProvider(cfg *config.Config) Provider { return &asgardeoProvider{cfg: cfg} } @@ -113,7 +113,6 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { if err := p.createAsgardeoApplication(regReq); err != nil { logger.Warn("Asgardeo application creation failed: %v", err) - http.Error(w, "Failed to create application in Asgardeo", http.StatusInternalServerError) // Optionally http.Error(...) if you want to fail // or continue to return partial data. } @@ -160,19 +159,13 @@ type RegisterResponse struct { } func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) error { - - orgName := p.cfg.Demo.OrgName - if p.cfg.Mode == "asgardeo" { - orgName = p.cfg.Asgardeo.OrgName - } - body := buildAsgardeoPayload(regReq) reqBytes, err := json.Marshal(body) if err != nil { return fmt.Errorf("failed to marshal Asgardeo request: %w", err) } - asgardeoAppURL := "https://api.asgardeo.io/t/" + orgName + "/api/server/v1/applications" + asgardeoAppURL := "https://api.asgardeo.io/t/" + p.cfg.Demo.OrgName + "/api/server/v1/applications" req, err := http.NewRequest("POST", asgardeoAppURL, bytes.NewBuffer(reqBytes)) if err != nil { return fmt.Errorf("failed to create Asgardeo API request: %w", err) @@ -202,14 +195,6 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err } func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { - - clientId := p.cfg.Demo.ClientID - clientSecret := p.cfg.Demo.ClientSecret - if p.cfg.Mode == "asgardeo" { - clientId = p.cfg.Asgardeo.ClientID - clientSecret = p.cfg.Asgardeo.ClientSecret - } - tokenURL := p.cfg.AuthServerBaseURL + "/token" formData := "grant_type=client_credentials&scope=internal_application_mgt_create internal_application_mgt_delete " + @@ -222,10 +207,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Sensitive data - should not be logged at INFO level - auth := clientId + ":" + clientSecret + auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) - - logger.Debug("Requesting admin token for Asgardeo with client ID: %s", clientId) + + logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID) tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -270,18 +255,6 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} { } appName += "-" + randomString(5) - // Build redirect URIs regex from list of redirect URIs : regexp=(https://app.example.com/callback1|https://app.example.com/callback2) - redirectURI := "regexp=(" + strings.Join(regReq.RedirectURIs, "|") + ")" - redirectURIs := []string{redirectURI} - - // Filter unsupported grant types - var grantTypes []string - for _, gt := range regReq.GrantTypes { - if gt == "authorization_code" || gt == "refresh_token" { - grantTypes = append(grantTypes, gt) - } - } - return map[string]interface{}{ "name": appName, "templateId": "custom-application-oidc", @@ -289,10 +262,10 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} { "oidc": map[string]interface{}{ "clientId": regReq.ClientID, "clientSecret": regReq.ClientSecret, - "grantTypes": grantTypes, - "callbackURLs": redirectURIs, + "grantTypes": regReq.GrantTypes, + "callbackURLs": regReq.RedirectURIs, "allowedOrigins": []string{}, - "publicClient": true, + "publicClient": false, "pkce": map[string]bool{ "mandatory": true, "supportPlainTransformAlgorithm": true, diff --git a/internal/config/config.go b/internal/config/config.go index c51688f..fc6743c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,8 +3,6 @@ package config import ( "fmt" "os" - "runtime" - "strings" "gopkg.in/yaml.v2" ) @@ -19,16 +17,15 @@ const ( // Common path configuration for all transport modes type PathsConfig struct { - SSE string `yaml:"sse"` - Messages string `yaml:"messages"` - StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` } // StdioConfig contains stdio-specific configuration type StdioConfig struct { Enabled bool `yaml:"enabled"` - UserCommand string `yaml:"user_command"` // The command provided by the user - WorkDir string `yaml:"work_dir"` // Working directory (optional) + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) Args []string `yaml:"args,omitempty"` // Additional arguments Env []string `yaml:"env,omitempty"` // Environment variables } @@ -86,18 +83,18 @@ type DefaultConfig struct { } type Config struct { - AuthServerBaseURL string - ListenPort int `yaml:"listen_port"` - BaseURL string `yaml:"base_url"` - Port int `yaml:"port"` - JWKSURL string - TimeoutSeconds int `yaml:"timeout_seconds"` - PathMapping map[string]string `yaml:"path_mapping"` - Mode string `yaml:"mode"` - CORSConfig CORSConfig `yaml:"cors"` - TransportMode TransportMode `yaml:"transport_mode"` - Paths PathsConfig `yaml:"paths"` - Stdio StdioConfig `yaml:"stdio"` + AuthServerBaseURL string + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` + TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` @@ -139,7 +136,7 @@ func (c *Config) Validate() error { // GetMCPPaths returns the list of paths that should be proxied to the MCP server func (c *Config) GetMCPPaths() []string { - return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP} + return []string{c.Paths.SSE, c.Paths.Messages} } // BuildExecCommand constructs the full command string for execution in stdio mode @@ -148,15 +145,7 @@ func (c *Config) BuildExecCommand() string { return "" } - if runtime.GOOS == "windows" { - // For Windows, we need to properly escape the inner command - escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`) - return fmt.Sprintf( - `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, - escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, - ) - } - + // Construct the full command return fmt.Sprintf( `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, @@ -176,12 +165,12 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } - + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } - + // Set default transport mode if not specified if cfg.TransportMode == "" { cfg.TransportMode = SSETransport // Default to SSE @@ -191,11 +180,11 @@ func LoadConfig(path string) (*Config, error) { if cfg.Port == 0 { cfg.Port = 8000 // default } - + // Validate the configuration if err := cfg.Validate(); err != nil { return nil, err } - + return &cfg, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index edf4182..20c0893 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -136,15 +136,20 @@ func TestValidate(t *testing.T) { func TestGetMCPPaths(t *testing.T) { cfg := Config{ Paths: PathsConfig{ - SSE: "/custom-sse", - Messages: "/custom-messages", - StreamableHTTP: "/custom-streamable", + SSE: "/custom-sse", + Messages: "/custom-messages", }, } paths := cfg.GetMCPPaths() - if len(paths) != 3 { - t.Errorf("Expected 3 MCP paths, got %d", len(paths)) + if len(paths) != 2 { + t.Errorf("Expected 2 MCP paths, got %d", len(paths)) + } + if paths[0] != "/custom-sse" { + t.Errorf("Expected first path=/custom-sse, got %s", paths[0]) + } + if paths[1] != "/custom-messages" { + t.Errorf("Expected second path=/custom-messages, got %s", paths[1]) } } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index f4d0dec..33a9ea3 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -10,7 +10,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -106,7 +106,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) logger.Error("Invalid auth server URL: %v", err) panic(err) // Fatal error that prevents startup } - + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { logger.Error("Invalid MCP server URL: %v", err) @@ -191,13 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Host = targetURL.Host cleanHeaders := http.Header{} - + // Set proper origin header to match the target if isSSE { // For SSE, ensure origin matches the target req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) } - + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -231,12 +231,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) proxyHost: r.Host, targetHost: targetURL.Host, } - + // Set SSE-specific headers w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - + // Keep SSE connections open HandleSSE(w, r, rp) } else { diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go index 902a517..fa64337 100644 --- a/internal/subprocess/manager.go +++ b/internal/subprocess/manager.go @@ -4,14 +4,13 @@ import ( "fmt" "os" "os/exec" - "runtime" - "strings" "sync" "syscall" "time" + "strings" "github.com/wso2/open-mcp-auth-proxy/internal/config" - logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // Manager handles starting and graceful shutdown of subprocesses @@ -32,39 +31,34 @@ func NewManager() *Manager { // EnsureDependenciesAvailable checks and installs required package executors func EnsureDependenciesAvailable(command string) error { - // Always ensure npx is available regardless of the command - if _, err := exec.LookPath("npx"); err != nil { - // npx is not available, check if npm is installed - if _, err := exec.LookPath("npm"); err != nil { - return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") - } - - // Try to install npx using npm - logger.Info("npx not found, attempting to install...") - var cmd *exec.Cmd - if runtime.GOOS == "windows" { - cmd = exec.Command("npm.cmd", "install", "-g", "npx") - } else { - cmd = exec.Command("npm", "install", "-g", "npx") - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install npx: %w", err) - } - - logger.Info("npx installed successfully") - } - - // Check if uv is needed based on the command - if strings.Contains(command, "uv ") { - if _, err := exec.LookPath("uv"); err != nil { - return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") - } - } - - return nil + // Always ensure npx is available regardless of the command + if _, err := exec.LookPath("npx"); err != nil { + // npx is not available, check if npm is installed + if _, err := exec.LookPath("npm"); err != nil { + return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") + } + + // Try to install npx using npm + logger.Info("npx not found, attempting to install...") + cmd := exec.Command("npm", "install", "-g", "npx") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install npx: %w", err) + } + + logger.Info("npx installed successfully") + } + + // Check if uv is needed based on the command + if strings.Contains(command, "uv ") { + if _, err := exec.LookPath("uv"); err != nil { + return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") + } + } + + return nil } // SetShutdownDelay sets the maximum time to wait for graceful shutdown @@ -94,13 +88,8 @@ func (m *Manager) Start(cfg *config.Config) error { logger.Info("Starting subprocess with command: %s", execCommand) - var cmd *exec.Cmd - if runtime.GOOS == "windows" { - // Use PowerShell on Windows for better quote handling - cmd = exec.Command("powershell", "-Command", execCommand) - } else { - cmd = exec.Command("sh", "-c", execCommand) - } + // Use the shell to execute the command + cmd := exec.Command("sh", "-c", execCommand) // Set working directory if specified if cfg.Stdio.WorkDir != "" { @@ -116,8 +105,8 @@ func (m *Manager) Start(cfg *config.Config) error { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Set platform-specific process attributes - setProcAttr(cmd) + // Set the process group for proper termination + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} // Start the process if err := cmd.Start(); err != nil { @@ -128,13 +117,11 @@ func (m *Manager) Start(cfg *config.Config) error { m.cmd = cmd logger.Info("Subprocess started with PID: %d", m.process.Pid) - // Get and store the process group ID (Unix) or PID (Windows) - pgid, err := getProcessGroup(m.process.Pid) + // Get and store the process group ID + pgid, err := syscall.Getpgid(m.process.Pid) if err == nil { m.processGroup = pgid - if runtime.GOOS != "windows" { - logger.Debug("Process group ID: %d", m.processGroup) - } + logger.Debug("Process group ID: %d", m.processGroup) } else { logger.Warn("Failed to get process group ID: %v", err) m.processGroup = m.process.Pid @@ -168,7 +155,7 @@ func (m *Manager) IsRunning() bool { // Shutdown gracefully terminates the subprocess func (m *Manager) Shutdown() { m.mutex.Lock() - processToTerminate := m.process // Local copy of the process reference + processToTerminate := m.process // Local copy of the process reference processGroupToTerminate := m.processGroup m.mutex.Unlock() @@ -182,73 +169,48 @@ func (m *Manager) Shutdown() { go func() { defer close(terminateComplete) - // Try graceful termination first + // Try graceful termination first with SIGTERM terminatedGracefully := false - if runtime.GOOS == "windows" { - // Windows: Try to terminate the process - m.mutex.Lock() - if m.process != nil { - err := m.process.Kill() - if err != nil { - logger.Warn("Failed to terminate process: %v", err) - } - } - m.mutex.Unlock() + // Try to terminate the process group first + if processGroupToTerminate != 0 { + err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process group: %v", err) - // Wait a bit to see if it terminates - for i := 0; i < 10; i++ { - time.Sleep(200 * time.Millisecond) - m.mutex.Lock() - if m.process == nil { - terminatedGracefully = true - m.mutex.Unlock() - break - } - m.mutex.Unlock() - } - } else { - // Unix: Use SIGTERM followed by SIGKILL if necessary - // Try to terminate the process group first - if processGroupToTerminate != 0 { - err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process group: %v", err) - - // Fallback to terminating just the process - m.mutex.Lock() - if m.process != nil { - err = m.process.Signal(syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process: %v", err) - } - } - m.mutex.Unlock() - } - } else { - // Try to terminate just the process + // Fallback to terminating just the process m.mutex.Lock() if m.process != nil { - err := m.process.Signal(syscall.SIGTERM) + err = m.process.Signal(syscall.SIGTERM) if err != nil { logger.Warn("Failed to send SIGTERM to process: %v", err) } } m.mutex.Unlock() } - - // Wait for the process to exit gracefully - for i := 0; i < 10; i++ { - time.Sleep(200 * time.Millisecond) - - m.mutex.Lock() - if m.process == nil { - terminatedGracefully = true - m.mutex.Unlock() - break + } else { + // Try to terminate just the process + m.mutex.Lock() + if m.process != nil { + err := m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) } - m.mutex.Unlock() } + m.mutex.Unlock() + } + + // Wait for the process to exit gracefully + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break + } + m.mutex.Unlock() } if terminatedGracefully { @@ -259,33 +221,12 @@ func (m *Manager) Shutdown() { // If the process didn't exit gracefully, force kill logger.Warn("Subprocess didn't exit gracefully, forcing termination...") - if runtime.GOOS == "windows" { - // On Windows, Kill() is already forceful - m.mutex.Lock() - if m.process != nil { - if err := m.process.Kill(); err != nil { - logger.Error("Failed to kill process: %v", err) - } - } - m.mutex.Unlock() - } else { - // Unix: Try SIGKILL - // Try to kill the process group first - if processGroupToTerminate != 0 { - if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil { - logger.Warn("Failed to send SIGKILL to process group: %v", err) + // Try to kill the process group first + if processGroupToTerminate != 0 { + if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { + logger.Warn("Failed to send SIGKILL to process group: %v", err) - // Fallback to killing just the process - m.mutex.Lock() - if m.process != nil { - if err := m.process.Kill(); err != nil { - logger.Error("Failed to kill process: %v", err) - } - } - m.mutex.Unlock() - } - } else { - // Try to kill just the process + // Fallback to killing just the process m.mutex.Lock() if m.process != nil { if err := m.process.Kill(); err != nil { @@ -294,6 +235,15 @@ func (m *Manager) Shutdown() { } m.mutex.Unlock() } + } else { + // Try to kill just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() } // Wait a bit more to confirm termination diff --git a/internal/subprocess/manager_unix.go b/internal/subprocess/manager_unix.go deleted file mode 100644 index 03ae1a8..0000000 --- a/internal/subprocess/manager_unix.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build !windows - -package subprocess - -import ( - "os/exec" - "syscall" -) - -// setProcAttr sets Unix-specific process attributes -func setProcAttr(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} -} - -// getProcessGroup gets the process group ID on Unix systems -func getProcessGroup(pid int) (int, error) { - return syscall.Getpgid(pid) -} - -// killProcessGroup kills a process group on Unix systems -func killProcessGroup(pgid int, signal syscall.Signal) error { - return syscall.Kill(-pgid, signal) -} diff --git a/internal/subprocess/manager_windows.go b/internal/subprocess/manager_windows.go deleted file mode 100644 index a039897..0000000 --- a/internal/subprocess/manager_windows.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build windows - -package subprocess - -import ( - "os/exec" - "syscall" -) - -// setProcAttr sets Windows-specific process attributes -func setProcAttr(cmd *exec.Cmd) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, - } -} - -// getProcessGroup returns the PID itself on Windows (no process groups) -func getProcessGroup(pid int) (int, error) { - return pid, nil -} - -// killProcessGroup kills a process on Windows (no process groups) -func killProcessGroup(pgid int, signal syscall.Signal) error { - // On Windows, we'll use the process handle directly - // This function shouldn't be called on Windows, but we provide it for compatibility - return nil -} diff --git a/pull_request_template.md b/pull_request_template.md index c401a06..9b32185 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,11 +1,52 @@ ## Purpose - +> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc. -## Related Issues - +## Goals +> Describe the solutions that this feature/fix will introduce to resolve the problems described above + +## Approach +> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here. + +## User stories +> Summary of user stories addressed by this change> + +## Release note +> Brief description of the new feature or bug fix as it will appear in the release notes + +## Documentation +> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why there’s no doc impact + +## Training +> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable + +## Certification +> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why. + +## Marketing +> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable + +## Automation tests + - Unit tests + > Code coverage information + - Integration tests + > Details about the test cases and coverage + +## Security checks + - Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no + - Ran FindSecurityBugs plugin and verified report? yes/no + - Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no + +## Samples +> Provide high-level details about the samples related to this feature ## Related PRs - +> List any other related PRs ## Migrations (if applicable) - +> Describe migration steps and platforms on which migration has been tested + +## Test environment +> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested + +## Learning +> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem. \ No newline at end of file