mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Start a MCP server as a subprocess
This commit is contained in:
parent
72c44afc14
commit
f78385ef23
5 changed files with 439 additions and 48 deletions
|
@ -7,12 +7,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,12 +24,53 @@ func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
// 1. Load config
|
// 1. Load config
|
||||||
cfg, err := config.LoadConfig("/etc/open-mcp-auth-proxy/config.yaml")
|
cfg, err := config.LoadConfig("config.yaml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error loading config: %v", err)
|
log.Fatalf("Error loading config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Create the chosen provider
|
// 2. Ensure MCPPaths includes the configured paths from the command
|
||||||
|
if cfg.Command.Enabled {
|
||||||
|
// Add SSE path to MCPPaths if not already present
|
||||||
|
ssePath := cfg.Command.SsePath
|
||||||
|
if ssePath == "" {
|
||||||
|
ssePath = "/sse" // default
|
||||||
|
}
|
||||||
|
|
||||||
|
messagePath := cfg.Command.MessagePath
|
||||||
|
if messagePath == "" {
|
||||||
|
messagePath = "/messages" // default
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure paths are in MCPPaths
|
||||||
|
ensurePathInList(&cfg.MCPPaths, ssePath)
|
||||||
|
ensurePathInList(&cfg.MCPPaths, messagePath)
|
||||||
|
|
||||||
|
// Configure baseUrl
|
||||||
|
baseUrl := cfg.Command.BaseUrl
|
||||||
|
if baseUrl == "" {
|
||||||
|
if cfg.Command.Port > 0 {
|
||||||
|
baseUrl = fmt.Sprintf("http://localhost:%d", cfg.Command.Port)
|
||||||
|
} else {
|
||||||
|
baseUrl = "http://localhost:8000" // default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the baseUrl to allowed origins if not already present
|
||||||
|
// ensureOriginInList(&cfg.CORSConfig.AllowedOrigins, "http://localhost:8080")
|
||||||
|
log.Printf("Using MCP server baseUrl: %s", baseUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Start subprocess if configured
|
||||||
|
var procManager *subprocess.Manager
|
||||||
|
if cfg.Command.Enabled && cfg.Command.UserCommand != "" {
|
||||||
|
procManager = subprocess.NewManager()
|
||||||
|
if err := procManager.Start(&cfg.Command); err != nil {
|
||||||
|
log.Printf("Warning: Failed to start subprocess: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Create the chosen provider
|
||||||
var provider authz.Provider
|
var provider authz.Provider
|
||||||
if *demoMode {
|
if *demoMode {
|
||||||
cfg.Mode = "demo"
|
cfg.Mode = "demo"
|
||||||
|
@ -46,19 +89,18 @@ func main() {
|
||||||
provider = authz.NewDefaultProvider(cfg)
|
provider = authz.NewDefaultProvider(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
// 5. (Optional) Fetch JWKS if you want local JWT validation
|
||||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||||
log.Fatalf("Failed to fetch JWKS: %v", err)
|
log.Fatalf("Failed to fetch JWKS: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Build the main router
|
// 6. Build the main router
|
||||||
mux := proxy.NewRouter(cfg, provider)
|
mux := proxy.NewRouter(cfg, provider)
|
||||||
|
|
||||||
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||||
|
|
||||||
// 5. Start the server
|
// 7. Start the server
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
|
||||||
Addr: listen_address,
|
Addr: listen_address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
@ -70,17 +112,50 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 6. Graceful shutdown on Ctrl+C
|
// 8. Wait for shutdown signal
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, os.Interrupt)
|
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||||
<-stop
|
<-stop
|
||||||
log.Println("Shutting down...")
|
log.Println("Shutting down...")
|
||||||
|
|
||||||
|
// 9. First terminate subprocess if running
|
||||||
|
if procManager != nil && procManager.IsRunning() {
|
||||||
|
procManager.Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 10. Then shutdown the server
|
||||||
|
log.Println("Shutting down HTTP server...")
|
||||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
log.Printf("Shutdown error: %v", err)
|
log.Printf("HTTP server shutdown error: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("Stopped.")
|
log.Println("Stopped.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to ensure a path is in a list
|
||||||
|
func ensurePathInList(paths *[]string, path string) {
|
||||||
|
// Check if path exists in the list
|
||||||
|
for _, p := range *paths {
|
||||||
|
if p == path {
|
||||||
|
return // Path already exists
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Path doesn't exist, add it
|
||||||
|
*paths = append(*paths, path)
|
||||||
|
log.Printf("Added path %s to MCPPaths", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to ensure an origin is in a list
|
||||||
|
func ensureOriginInList(origins *[]string, origin string) {
|
||||||
|
// Check if origin exists in the list
|
||||||
|
for _, o := range *origins {
|
||||||
|
if o == origin {
|
||||||
|
return // Origin already exists
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Origin doesn't exist, add it
|
||||||
|
*origins = append(*origins, origin)
|
||||||
|
log.Printf("Added %s to allowed CORS origins", origin)
|
||||||
|
}
|
88
config.yaml
88
config.yaml
|
@ -1,6 +1,6 @@
|
||||||
# config.yaml
|
# config.yaml
|
||||||
|
|
||||||
mcp_server_base_url: ""
|
mcp_server_base_url: "http://localhost:8000"
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
|
@ -8,15 +8,28 @@ mcp_paths:
|
||||||
- /messages/
|
- /messages/
|
||||||
- /sse
|
- /sse
|
||||||
|
|
||||||
|
# Subprocess configuration
|
||||||
|
command:
|
||||||
|
enabled: true
|
||||||
|
user_command: "npx -y @modelcontextprotocol/server-github" # User only needs to provide this part
|
||||||
|
base_url: "http://localhost:8000" # Will be used for CORS and in the full command
|
||||||
|
port: 8000 # Port for the MCP server
|
||||||
|
sse_path: "/sse" # SSE endpoint path
|
||||||
|
message_path: "/messages" # Messages endpoint path
|
||||||
|
work_dir: "" # Working directory (optional)
|
||||||
|
# env: # Environment variables (optional)
|
||||||
|
# - "NODE_ENV=development"
|
||||||
|
|
||||||
path_mapping:
|
path_mapping:
|
||||||
/token: /token
|
# /token: /oauth/token
|
||||||
/register: /register
|
# /register: /oidc/register
|
||||||
/authorize: /authorize
|
# /authorize: /authorize
|
||||||
/.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server
|
# /u/login: /u/login
|
||||||
|
# /.well-known/oauth-authorization-server: /.well-known/openid-configuration
|
||||||
|
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
- ""
|
- "http://localhost:5173"
|
||||||
allowed_methods:
|
allowed_methods:
|
||||||
- "GET"
|
- "GET"
|
||||||
- "POST"
|
- "POST"
|
||||||
|
@ -36,36 +49,35 @@ asgardeo:
|
||||||
org_name: "<org_name>"
|
org_name: "<org_name>"
|
||||||
client_id: "<client_id>"
|
client_id: "<client_id>"
|
||||||
client_secret: "<client_secret>"
|
client_secret: "<client_secret>"
|
||||||
|
# default:
|
||||||
default:
|
# base_url: "https://dev-mw4ipgsq1454jrwm.us.auth0.com"
|
||||||
base_url: "<base_url>"
|
# jwks_url: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/.well-known/jwks.json"
|
||||||
jwks_url: "<jwks_url>"
|
# path:
|
||||||
path:
|
# /.well-known/oauth-authorization-server:
|
||||||
/.well-known/oauth-authorization-server:
|
# response:
|
||||||
response:
|
# issuer: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/"
|
||||||
issuer: "<issuer>"
|
# jwks_uri: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/.well-known/jwks.json"
|
||||||
jwks_uri: "<jwks_uri>"
|
# authorization_endpoint: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/authorize?audience=mcp_proxy"
|
||||||
authorization_endpoint: "<authorization_endpoint>" # Optional
|
# # token_endpoint: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/oauth/token"
|
||||||
token_endpoint: "<token_endpoint>" # Optional
|
# # registration_endpoint: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/oidc/register"
|
||||||
registration_endpoint: "<registration_endpoint>" # Optional
|
# response_types_supported:
|
||||||
response_types_supported:
|
# - "code"
|
||||||
- "code"
|
# grant_types_supported:
|
||||||
grant_types_supported:
|
# - "authorization_code"
|
||||||
- "authorization_code"
|
# - "refresh_token"
|
||||||
- "refresh_token"
|
# code_challenge_methods_supported:
|
||||||
code_challenge_methods_supported:
|
# - "S256"
|
||||||
- "S256"
|
# - "plain"
|
||||||
- "plain"
|
# /authroize:
|
||||||
/authroize:
|
# addQueryParams:
|
||||||
addQueryParams:
|
# - name: "audience"
|
||||||
- name: "<name>"
|
# value: "mcp_proxy"
|
||||||
value: "<value>"
|
# /token:
|
||||||
/token:
|
# addBodyParams:
|
||||||
addBodyParams:
|
# - name: "audience"
|
||||||
- name: "<name>"
|
# value: "mcp_proxy"
|
||||||
value: "<value>"
|
# /register:
|
||||||
/register:
|
# addBodyParams:
|
||||||
addBodyParams:
|
# - name: "audience"
|
||||||
- name: "<name>"
|
# value: "mcp_proxy"
|
||||||
value: "<value>"
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"fmt"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -74,6 +74,62 @@ type Config struct {
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
|
Command Command `yaml:"command"` // Command to run
|
||||||
|
}
|
||||||
|
|
||||||
|
// Command struct with explicit configuration for all relevant paths
|
||||||
|
type Command struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
UserCommand string `yaml:"user_command"` // Only the part provided by the user
|
||||||
|
BaseUrl string `yaml:"base_url"` // Base URL for the MCP server
|
||||||
|
Port int `yaml:"port"` // Port for the MCP server
|
||||||
|
SsePath string `yaml:"sse_path"` // SSE endpoint path
|
||||||
|
MessagePath string `yaml:"message_path"` // Messages endpoint path
|
||||||
|
WorkDir string `yaml:"work_dir"` // Working directory
|
||||||
|
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||||
|
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildExecCommand constructs the full command string for execution
|
||||||
|
func (c *Command) BuildExecCommand() string {
|
||||||
|
if c.UserCommand == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults if not specified
|
||||||
|
port := c.Port
|
||||||
|
if port == 0 {
|
||||||
|
port = 8000
|
||||||
|
}
|
||||||
|
|
||||||
|
baseUrl := c.BaseUrl
|
||||||
|
if baseUrl == "" {
|
||||||
|
baseUrl = fmt.Sprintf("http://localhost:%d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ssePath := c.SsePath
|
||||||
|
if ssePath == "" {
|
||||||
|
ssePath = "/sse"
|
||||||
|
}
|
||||||
|
|
||||||
|
messagePath := c.MessagePath
|
||||||
|
if messagePath == "" {
|
||||||
|
messagePath = "/messages"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the full command
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||||
|
c.UserCommand, port, baseUrl, ssePath, messagePath,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExec returns the complete command string for execution
|
||||||
|
func (c *Command) GetExec() string {
|
||||||
|
if c.UserCommand == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.BuildExecCommand()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfig reads a YAML config file into Config struct.
|
// LoadConfig reads a YAML config file into Config struct.
|
||||||
|
|
|
@ -191,8 +191,20 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
req.URL.RawQuery = r.URL.RawQuery
|
req.URL.RawQuery = r.URL.RawQuery
|
||||||
req.Host = targetURL.Host
|
req.Host = targetURL.Host
|
||||||
|
|
||||||
|
// for key, values := range r.Header {
|
||||||
|
// log.Printf("Header: %s, Values: %v", key, values)
|
||||||
|
// }
|
||||||
|
|
||||||
cleanHeaders := http.Header{}
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
|
// Preserve the original Origin header if present
|
||||||
|
// if origin := r.Header.Get("Origin"); origin != "" {
|
||||||
|
// cleanHeaders.Set("Origin", origin)
|
||||||
|
// } else {
|
||||||
|
// log.Printf("[proxy] No Origin header found, setting to target URL: http://localhost:8080")
|
||||||
|
// cleanHeaders.Set("Origin", "http://localhost:8080")
|
||||||
|
// }
|
||||||
|
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if skipHeader(k) {
|
if skipHeader(k) {
|
||||||
|
@ -236,6 +248,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
}
|
}
|
||||||
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||||
|
log.Printf("[proxy] Checking CORS origin: %s against allowed: %s", origin, allowed)
|
||||||
if allowed == origin {
|
if allowed == origin {
|
||||||
return allowed
|
return allowed
|
||||||
}
|
}
|
||||||
|
@ -256,6 +269,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
}
|
}
|
||||||
w.Header().Set("Vary", "Origin")
|
w.Header().Set("Vary", "Origin")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAuthPath(path string) bool {
|
func isAuthPath(path string) bool {
|
||||||
|
|
234
internal/subprocess/manager.go
Normal file
234
internal/subprocess/manager.go
Normal file
|
@ -0,0 +1,234 @@
|
||||||
|
package subprocess
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager handles starting and graceful shutdown of subprocesses
|
||||||
|
type Manager struct {
|
||||||
|
process *os.Process
|
||||||
|
processGroup int
|
||||||
|
mutex sync.Mutex
|
||||||
|
cmd *exec.Cmd
|
||||||
|
shutdownDelay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new subprocess manager
|
||||||
|
func NewManager() *Manager {
|
||||||
|
return &Manager{
|
||||||
|
shutdownDelay: 5 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
|
||||||
|
func (m *Manager) SetShutdownDelay(duration time.Duration) {
|
||||||
|
m.shutdownDelay = duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start launches a subprocess based on the command configuration
|
||||||
|
func (m *Manager) Start(cmdConfig *config.Command) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
// If a process is already running, return an error
|
||||||
|
if m.process != nil {
|
||||||
|
return os.ErrExist
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cmdConfig.Enabled || cmdConfig.UserCommand == "" {
|
||||||
|
return nil // Nothing to start
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the full command string
|
||||||
|
execCommand := cmdConfig.GetExec()
|
||||||
|
if execCommand == "" {
|
||||||
|
return nil // No command to execute
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Starting subprocess with command: %s", execCommand)
|
||||||
|
|
||||||
|
// Use the shell to execute the command
|
||||||
|
cmd := exec.Command("sh", "-c", execCommand)
|
||||||
|
|
||||||
|
// Set working directory if specified
|
||||||
|
if cmdConfig.WorkDir != "" {
|
||||||
|
cmd.Dir = cmdConfig.WorkDir
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set environment variables if specified
|
||||||
|
if len(cmdConfig.Env) > 0 {
|
||||||
|
cmd.Env = append(os.Environ(), cmdConfig.Env...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture stdout/stderr
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// Set the process group for proper termination
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||||
|
|
||||||
|
// Start the process
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.process = cmd.Process
|
||||||
|
m.cmd = cmd
|
||||||
|
log.Printf("Subprocess started with PID: %d", m.process.Pid)
|
||||||
|
|
||||||
|
// Get and store the process group ID
|
||||||
|
pgid, err := syscall.Getpgid(m.process.Pid)
|
||||||
|
if err == nil {
|
||||||
|
m.processGroup = pgid
|
||||||
|
log.Printf("Process group ID: %d", m.processGroup)
|
||||||
|
} else {
|
||||||
|
log.Printf("Warning: Failed to get process group ID: %v", err)
|
||||||
|
m.processGroup = m.process.Pid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle process termination in background
|
||||||
|
go func() {
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
log.Printf("Subprocess exited with error: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("Subprocess exited successfully")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the process reference when it exits
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.process = nil
|
||||||
|
m.cmd = nil
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning checks if the subprocess is running
|
||||||
|
func (m *Manager) IsRunning() bool {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
return m.process != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully terminates the subprocess
|
||||||
|
func (m *Manager) Shutdown() {
|
||||||
|
m.mutex.Lock()
|
||||||
|
processToTerminate := m.process // Local copy of the process reference
|
||||||
|
processGroupToTerminate := m.processGroup
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
if processToTerminate == nil {
|
||||||
|
return // No process to terminate
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Terminating subprocess...")
|
||||||
|
terminateComplete := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(terminateComplete)
|
||||||
|
|
||||||
|
// Try graceful termination first with SIGTERM
|
||||||
|
terminatedGracefully := false
|
||||||
|
|
||||||
|
// Try to terminate the process group first
|
||||||
|
if processGroupToTerminate != 0 {
|
||||||
|
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to send SIGTERM to process group: %v", err)
|
||||||
|
|
||||||
|
// Fallback to terminating just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
err = m.process.Signal(syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to send SIGTERM to process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to terminate just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
err := m.process.Signal(syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to send SIGTERM to process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the process to exit gracefully
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process == nil {
|
||||||
|
terminatedGracefully = true
|
||||||
|
m.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if terminatedGracefully {
|
||||||
|
log.Println("Subprocess terminated gracefully")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the process didn't exit gracefully, force kill
|
||||||
|
log.Println("Subprocess didn't exit gracefully, forcing termination...")
|
||||||
|
|
||||||
|
// Try to kill the process group first
|
||||||
|
if processGroupToTerminate != 0 {
|
||||||
|
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||||
|
log.Printf("Failed to send SIGKILL to process group: %v", err)
|
||||||
|
|
||||||
|
// Fallback to killing just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
if err := m.process.Kill(); err != nil {
|
||||||
|
log.Printf("Failed to kill process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to kill just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
if err := m.process.Kill(); err != nil {
|
||||||
|
log.Printf("Failed to kill process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit more to confirm termination
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process == nil {
|
||||||
|
log.Println("Subprocess terminated by force")
|
||||||
|
} else {
|
||||||
|
log.Println("Warning: Failed to terminate subprocess")
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for termination with timeout
|
||||||
|
select {
|
||||||
|
case <-terminateComplete:
|
||||||
|
// Termination completed
|
||||||
|
case <-time.After(m.shutdownDelay):
|
||||||
|
log.Println("Warning: Subprocess termination timed out")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue