From f78385ef23a2c0fa0b0fe1c95b40efcc4d30cd67 Mon Sep 17 00:00:00 2001 From: Chiran Fernando Date: Fri, 4 Apr 2025 17:29:02 +0530 Subject: [PATCH] Start a MCP server as a subprocess --- cmd/proxy/main.go | 93 +++++++++++-- config.yaml | 88 +++++++------ internal/config/config.go | 58 +++++++- internal/proxy/proxy.go | 14 ++ internal/subprocess/manager.go | 234 +++++++++++++++++++++++++++++++++ 5 files changed, 439 insertions(+), 48 deletions(-) create mode 100644 internal/subprocess/manager.go diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index f26585e..3f5c6c9 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -7,12 +7,14 @@ import ( "net/http" "os" "os/signal" + "syscall" "time" "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" + "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -22,12 +24,53 @@ func main() { flag.Parse() // 1. Load config - cfg, err := config.LoadConfig("/etc/open-mcp-auth-proxy/config.yaml") + cfg, err := config.LoadConfig("config.yaml") if err != nil { log.Fatalf("Error loading config: %v", err) } - // 2. Create the chosen provider + // 2. Ensure MCPPaths includes the configured paths from the command + if cfg.Command.Enabled { + // Add SSE path to MCPPaths if not already present + ssePath := cfg.Command.SsePath + if ssePath == "" { + ssePath = "/sse" // default + } + + messagePath := cfg.Command.MessagePath + if messagePath == "" { + messagePath = "/messages" // default + } + + // Make sure paths are in MCPPaths + ensurePathInList(&cfg.MCPPaths, ssePath) + ensurePathInList(&cfg.MCPPaths, messagePath) + + // Configure baseUrl + baseUrl := cfg.Command.BaseUrl + if baseUrl == "" { + if cfg.Command.Port > 0 { + baseUrl = fmt.Sprintf("http://localhost:%d", cfg.Command.Port) + } else { + baseUrl = "http://localhost:8000" // default + } + } + + // Add the baseUrl to allowed origins if not already present + // ensureOriginInList(&cfg.CORSConfig.AllowedOrigins, "http://localhost:8080") + log.Printf("Using MCP server baseUrl: %s", baseUrl) + } + + // 3. Start subprocess if configured + var procManager *subprocess.Manager + if cfg.Command.Enabled && cfg.Command.UserCommand != "" { + procManager = subprocess.NewManager() + if err := procManager.Start(&cfg.Command); err != nil { + log.Printf("Warning: Failed to start subprocess: %v", err) + } + } + + // 4. Create the chosen provider var provider authz.Provider if *demoMode { cfg.Mode = "demo" @@ -46,19 +89,18 @@ func main() { provider = authz.NewDefaultProvider(cfg) } - // 3. (Optional) Fetch JWKS if you want local JWT validation + // 5. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { log.Fatalf("Failed to fetch JWKS: %v", err) } - // 4. Build the main router + // 6. Build the main router mux := proxy.NewRouter(cfg, provider) listen_address := fmt.Sprintf(":%d", cfg.ListenPort) - // 5. Start the server + // 7. Start the server srv := &http.Server{ - Addr: listen_address, Handler: mux, } @@ -70,17 +112,50 @@ func main() { } }() - // 6. Graceful shutdown on Ctrl+C + // 8. Wait for shutdown signal stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop log.Println("Shutting down...") + // 9. First terminate subprocess if running + if procManager != nil && procManager.IsRunning() { + procManager.Shutdown() + } + + // 10. Then shutdown the server + log.Println("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { - log.Printf("Shutdown error: %v", err) + log.Printf("HTTP server shutdown error: %v", err) } log.Println("Stopped.") } + +// Helper function to ensure a path is in a list +func ensurePathInList(paths *[]string, path string) { + // Check if path exists in the list + for _, p := range *paths { + if p == path { + return // Path already exists + } + } + // Path doesn't exist, add it + *paths = append(*paths, path) + log.Printf("Added path %s to MCPPaths", path) +} + +// Helper function to ensure an origin is in a list +func ensureOriginInList(origins *[]string, origin string) { + // Check if origin exists in the list + for _, o := range *origins { + if o == origin { + return // Origin already exists + } + } + // Origin doesn't exist, add it + *origins = append(*origins, origin) + log.Printf("Added %s to allowed CORS origins", origin) +} \ No newline at end of file diff --git a/config.yaml b/config.yaml index 0b0ade4..76884b9 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,6 @@ # config.yaml -mcp_server_base_url: "" +mcp_server_base_url: "http://localhost:8000" listen_port: 8080 timeout_seconds: 10 @@ -8,15 +8,28 @@ mcp_paths: - /messages/ - /sse +# Subprocess configuration +command: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" # User only needs to provide this part + base_url: "http://localhost:8000" # Will be used for CORS and in the full command + port: 8000 # Port for the MCP server + sse_path: "/sse" # SSE endpoint path + message_path: "/messages" # Messages endpoint path + work_dir: "" # Working directory (optional) + # env: # Environment variables (optional) + # - "NODE_ENV=development" + path_mapping: - /token: /token - /register: /register - /authorize: /authorize - /.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server + # /token: /oauth/token + # /register: /oidc/register + # /authorize: /authorize + # /u/login: /u/login + # /.well-known/oauth-authorization-server: /.well-known/openid-configuration cors: allowed_origins: - - "" + - "http://localhost:5173" allowed_methods: - "GET" - "POST" @@ -36,36 +49,35 @@ asgardeo: org_name: "" client_id: "" client_secret: "" - -default: - base_url: "" - jwks_url: "" - path: - /.well-known/oauth-authorization-server: - response: - issuer: "" - jwks_uri: "" - authorization_endpoint: "" # Optional - token_endpoint: "" # Optional - registration_endpoint: "" # Optional - response_types_supported: - - "code" - grant_types_supported: - - "authorization_code" - - "refresh_token" - code_challenge_methods_supported: - - "S256" - - "plain" - /authroize: - addQueryParams: - - name: "" - value: "" - /token: - addBodyParams: - - name: "" - value: "" - /register: - addBodyParams: - - name: "" - value: "" +# default: +# base_url: "https://dev-mw4ipgsq1454jrwm.us.auth0.com" +# jwks_url: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/.well-known/jwks.json" +# path: +# /.well-known/oauth-authorization-server: +# response: +# issuer: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/" +# jwks_uri: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/.well-known/jwks.json" +# authorization_endpoint: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/authorize?audience=mcp_proxy" +# # token_endpoint: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/oauth/token" +# # registration_endpoint: "https://dev-mw4ipgsq1454jrwm.us.auth0.com/oidc/register" +# response_types_supported: +# - "code" +# grant_types_supported: +# - "authorization_code" +# - "refresh_token" +# code_challenge_methods_supported: +# - "S256" +# - "plain" +# /authroize: +# addQueryParams: +# - name: "audience" +# value: "mcp_proxy" +# /token: +# addBodyParams: +# - name: "audience" +# value: "mcp_proxy" +# /register: +# addBodyParams: +# - name: "audience" +# value: "mcp_proxy" diff --git a/internal/config/config.go b/internal/config/config.go index 01c3a6f..5e46d6d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,7 +2,7 @@ package config import ( "os" - + "fmt" "gopkg.in/yaml.v2" ) @@ -74,6 +74,62 @@ type Config struct { Demo DemoConfig `yaml:"demo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"` Default DefaultConfig `yaml:"default"` + Command Command `yaml:"command"` // Command to run +} + +// Command struct with explicit configuration for all relevant paths +type Command struct { + Enabled bool `yaml:"enabled"` + UserCommand string `yaml:"user_command"` // Only the part provided by the user + BaseUrl string `yaml:"base_url"` // Base URL for the MCP server + Port int `yaml:"port"` // Port for the MCP server + SsePath string `yaml:"sse_path"` // SSE endpoint path + MessagePath string `yaml:"message_path"` // Messages endpoint path + WorkDir string `yaml:"work_dir"` // Working directory + Args []string `yaml:"args,omitempty"` // Additional arguments + Env []string `yaml:"env,omitempty"` // Environment variables +} + +// BuildExecCommand constructs the full command string for execution +func (c *Command) BuildExecCommand() string { + if c.UserCommand == "" { + return "" + } + + // Apply defaults if not specified + port := c.Port + if port == 0 { + port = 8000 + } + + baseUrl := c.BaseUrl + if baseUrl == "" { + baseUrl = fmt.Sprintf("http://localhost:%d", port) + } + + ssePath := c.SsePath + if ssePath == "" { + ssePath = "/sse" + } + + messagePath := c.MessagePath + if messagePath == "" { + messagePath = "/messages" + } + + // Construct the full command + return fmt.Sprintf( + `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, + c.UserCommand, port, baseUrl, ssePath, messagePath, + ) +} + +// GetExec returns the complete command string for execution +func (c *Command) GetExec() string { + if c.UserCommand == "" { + return "" + } + return c.BuildExecCommand() } // LoadConfig reads a YAML config file into Config struct. diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index c999be4..0213ea4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -191,8 +191,20 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.URL.RawQuery = r.URL.RawQuery req.Host = targetURL.Host + // for key, values := range r.Header { + // log.Printf("Header: %s, Values: %v", key, values) + // } + cleanHeaders := http.Header{} + // Preserve the original Origin header if present + // if origin := r.Header.Get("Origin"); origin != "" { + // cleanHeaders.Set("Origin", origin) + // } else { + // log.Printf("[proxy] No Origin header found, setting to target URL: http://localhost:8080") + // cleanHeaders.Set("Origin", "http://localhost:8080") + // } + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -236,6 +248,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin } for _, allowed := range cfg.CORSConfig.AllowedOrigins { + log.Printf("[proxy] Checking CORS origin: %s against allowed: %s", origin, allowed) if allowed == origin { return allowed } @@ -256,6 +269,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re w.Header().Set("Access-Control-Allow-Credentials", "true") } w.Header().Set("Vary", "Origin") + w.Header().Set("X-Accel-Buffering", "no") } func isAuthPath(path string) bool { diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go new file mode 100644 index 0000000..faf22b4 --- /dev/null +++ b/internal/subprocess/manager.go @@ -0,0 +1,234 @@ +package subprocess + +import ( + "log" + "os" + "os/exec" + "sync" + "syscall" + "time" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +// Manager handles starting and graceful shutdown of subprocesses +type Manager struct { + process *os.Process + processGroup int + mutex sync.Mutex + cmd *exec.Cmd + shutdownDelay time.Duration +} + +// NewManager creates a new subprocess manager +func NewManager() *Manager { + return &Manager{ + shutdownDelay: 5 * time.Second, + } +} + +// SetShutdownDelay sets the maximum time to wait for graceful shutdown +func (m *Manager) SetShutdownDelay(duration time.Duration) { + m.shutdownDelay = duration +} + +// Start launches a subprocess based on the command configuration +func (m *Manager) Start(cmdConfig *config.Command) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + // If a process is already running, return an error + if m.process != nil { + return os.ErrExist + } + + if !cmdConfig.Enabled || cmdConfig.UserCommand == "" { + return nil // Nothing to start + } + + // Get the full command string + execCommand := cmdConfig.GetExec() + if execCommand == "" { + return nil // No command to execute + } + + log.Printf("Starting subprocess with command: %s", execCommand) + + // Use the shell to execute the command + cmd := exec.Command("sh", "-c", execCommand) + + // Set working directory if specified + if cmdConfig.WorkDir != "" { + cmd.Dir = cmdConfig.WorkDir + } + + // Set environment variables if specified + if len(cmdConfig.Env) > 0 { + cmd.Env = append(os.Environ(), cmdConfig.Env...) + } + + // Capture stdout/stderr + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Set the process group for proper termination + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Start the process + if err := cmd.Start(); err != nil { + return err + } + + m.process = cmd.Process + m.cmd = cmd + log.Printf("Subprocess started with PID: %d", m.process.Pid) + + // Get and store the process group ID + pgid, err := syscall.Getpgid(m.process.Pid) + if err == nil { + m.processGroup = pgid + log.Printf("Process group ID: %d", m.processGroup) + } else { + log.Printf("Warning: Failed to get process group ID: %v", err) + m.processGroup = m.process.Pid + } + + // Handle process termination in background + go func() { + if err := cmd.Wait(); err != nil { + log.Printf("Subprocess exited with error: %v", err) + } else { + log.Printf("Subprocess exited successfully") + } + + // Clear the process reference when it exits + m.mutex.Lock() + m.process = nil + m.cmd = nil + m.mutex.Unlock() + }() + + return nil +} + +// IsRunning checks if the subprocess is running +func (m *Manager) IsRunning() bool { + m.mutex.Lock() + defer m.mutex.Unlock() + return m.process != nil +} + +// Shutdown gracefully terminates the subprocess +func (m *Manager) Shutdown() { + m.mutex.Lock() + processToTerminate := m.process // Local copy of the process reference + processGroupToTerminate := m.processGroup + m.mutex.Unlock() + + if processToTerminate == nil { + return // No process to terminate + } + + log.Println("Terminating subprocess...") + terminateComplete := make(chan struct{}) + + go func() { + defer close(terminateComplete) + + // Try graceful termination first with SIGTERM + terminatedGracefully := false + + // Try to terminate the process group first + if processGroupToTerminate != 0 { + err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) + if err != nil { + log.Printf("Failed to send SIGTERM to process group: %v", err) + + // Fallback to terminating just the process + m.mutex.Lock() + if m.process != nil { + err = m.process.Signal(syscall.SIGTERM) + if err != nil { + log.Printf("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to terminate just the process + m.mutex.Lock() + if m.process != nil { + err := m.process.Signal(syscall.SIGTERM) + if err != nil { + log.Printf("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + + // Wait for the process to exit gracefully + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break + } + m.mutex.Unlock() + } + + if terminatedGracefully { + log.Println("Subprocess terminated gracefully") + return + } + + // If the process didn't exit gracefully, force kill + log.Println("Subprocess didn't exit gracefully, forcing termination...") + + // Try to kill the process group first + if processGroupToTerminate != 0 { + if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { + log.Printf("Failed to send SIGKILL to process group: %v", err) + + // Fallback to killing just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + log.Printf("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to kill just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + log.Printf("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + + // Wait a bit more to confirm termination + time.Sleep(500 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + log.Println("Subprocess terminated by force") + } else { + log.Println("Warning: Failed to terminate subprocess") + } + m.mutex.Unlock() + }() + + // Wait for termination with timeout + select { + case <-terminateComplete: + // Termination completed + case <-time.After(m.shutdownDelay): + log.Println("Warning: Subprocess termination timed out") + } +} \ No newline at end of file