Add transport mode support for stdio, SSE stability fixes (#13)

Add transport mode support for stdio, SSE stability fixes
This commit is contained in:
Chiran Fernando 2025-04-08 12:46:00 +05:30 committed by GitHub
parent 6ce52261db
commit 32c9378aad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 808 additions and 142 deletions

View file

@ -7,13 +7,13 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type asgardeoProvider struct {
@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Accel-Buffering", "no")
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
logger.Error("Error encoding well-known: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
var regReq RegisterRequest
if err := json.NewDecoder(r.Body).Decode(&regReq); err != nil {
log.Printf("ERROR: reading register request: %v", err)
logger.Error("Reading register request: %v", err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
regReq.ClientSecret = randomString(16)
if err := p.createAsgardeoApplication(regReq); err != nil {
log.Printf("WARN: Asgardeo application creation failed: %v", err)
logger.Warn("Asgardeo application creation failed: %v", err)
// Optionally http.Error(...) if you want to fail
// or continue to return partial data.
}
@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("ERROR: encoding /register response: %v", err)
logger.Error("Encoding /register response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
}
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
return nil
}
@ -202,8 +206,11 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Sensitive data - should not be logged at INFO level
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
return "", fmt.Errorf("failed to parse token JSON: %w", err)
}
// Don't log the actual token at info level, only at debug level
logger.Debug("Received access token: %s", tokenResp.AccessToken)
logger.Info("Successfully obtained admin token from Asgardeo")
return tokenResp.AccessToken, nil
}

View file

@ -5,6 +5,7 @@ import (
"net/http"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type defaultProvider struct {
@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
logger.Error("Error encoding well-known response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
return

View file

@ -1,12 +1,35 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v2"
)
// AsgardeoConfig groups all Asgardeo-specific fields
// Transport mode for MCP server
type TransportMode string
const (
SSETransport TransportMode = "sse"
StdioTransport TransportMode = "stdio"
)
// Common path configuration for all transport modes
type PathsConfig struct {
SSE string `yaml:"sse"`
Messages string `yaml:"messages"`
}
// StdioConfig contains stdio-specific configuration
type StdioConfig struct {
Enabled bool `yaml:"enabled"`
UserCommand string `yaml:"user_command"` // The command provided by the user
WorkDir string `yaml:"work_dir"` // Working directory (optional)
Args []string `yaml:"args,omitempty"` // Additional arguments
Env []string `yaml:"env,omitempty"` // Environment variables
}
type DemoConfig struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
@ -60,15 +83,18 @@ type DefaultConfig struct {
}
type Config struct {
AuthServerBaseURL string
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
ListenPort int `yaml:"listen_port"`
JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"`
MCPPaths []string `yaml:"mcp_paths"`
PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"`
AuthServerBaseURL string
ListenPort int `yaml:"listen_port"`
BaseURL string `yaml:"base_url"`
Port int `yaml:"port"`
JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"`
PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"`
TransportMode TransportMode `yaml:"transport_mode"`
Paths PathsConfig `yaml:"paths"`
Stdio StdioConfig `yaml:"stdio"`
// Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"`
@ -76,6 +102,56 @@ type Config struct {
Default DefaultConfig `yaml:"default"`
}
// Validate checks if the config is valid based on transport mode
func (c *Config) Validate() error {
// Validate based on transport mode
if c.TransportMode == StdioTransport {
if !c.Stdio.Enabled {
return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
}
if c.Stdio.UserCommand == "" {
return fmt.Errorf("stdio.user_command is required in stdio transport mode")
}
}
// Validate paths
if c.Paths.SSE == "" {
c.Paths.SSE = "/sse" // Default value
}
if c.Paths.Messages == "" {
c.Paths.Messages = "/messages" // Default value
}
// Validate base URL
if c.BaseURL == "" {
if c.Port > 0 {
c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
} else {
c.BaseURL = "http://localhost:8000" // Default value
}
}
return nil
}
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
func (c *Config) GetMCPPaths() []string {
return []string{c.Paths.SSE, c.Paths.Messages}
}
// BuildExecCommand constructs the full command string for execution in stdio mode
func (c *Config) BuildExecCommand() string {
if c.Stdio.UserCommand == "" {
return ""
}
// Construct the full command
return fmt.Sprintf(
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
)
}
// LoadConfig reads a YAML config file into Config struct.
func LoadConfig(path string) (*Config, error) {
f, err := os.Open(path)
@ -89,8 +165,26 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil {
return nil, err
}
// Set default values
if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default
}
// Set default transport mode if not specified
if cfg.TransportMode == "" {
cfg.TransportMode = SSETransport // Default to SSE
}
// Set default port if not specified
if cfg.Port == 0 {
cfg.Port = 8000 // default
}
// Validate the configuration
if err := cfg.Validate(); err != nil {
return nil, err
}
return &cfg, nil
}

View file

@ -0,0 +1,34 @@
package logger
import (
"log"
)
var isDebug = false
// SetDebug enables or disables debug logging
func SetDebug(debug bool) {
isDebug = debug
}
// Debug logs a debug-level message
func Debug(format string, v ...interface{}) {
if isDebug {
log.Printf("DEBUG: "+format, v...)
}
}
// Info logs an info-level message
func Info(format string, v ...interface{}) {
log.Printf("INFO: "+format, v...)
}
// Warn logs a warning-level message
func Warn(format string, v ...interface{}) {
log.Printf("WARN: "+format, v...)
}
// Error logs an error-level message
func Error(format string, v ...interface{}) {
log.Printf("ERROR: "+format, v...)
}

View file

@ -9,6 +9,7 @@ import (
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// RequestModifier modifies requests before they are proxied
@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
// Parse form data
if err := req.ParseForm(); err != nil {
logger.Error("Failed to parse form data: %v", err)
return nil, err
}
@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Read body
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
logger.Error("Failed to read request body: %v", err)
return nil, err
}
// Parse JSON
var jsonData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
logger.Error("Failed to parse JSON body: %v", err)
return nil, err
}
@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Marshal back to JSON
modifiedBody, err := json.Marshal(jsonData)
if err != nil {
logger.Error("Failed to marshal modified JSON: %v", err)
return nil, err
}

View file

@ -2,7 +2,6 @@ package proxy
import (
"context"
"log"
"net/http"
"net/http/httputil"
"net/url"
@ -11,6 +10,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
}
// MCP paths
for _, path := range cfg.MCPPaths {
mcpPaths := cfg.GetMCPPaths()
for _, path := range mcpPaths {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true
}
@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
// Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil {
log.Fatalf("Invalid auth server URL: %v", err)
logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil {
log.Fatalf("Invalid MCP server URL: %v", err)
logger.Error("Invalid MCP server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
// Detect SSE paths from config
ssePaths := make(map[string]bool)
for _, p := range cfg.MCPPaths {
if p == "/sse" {
ssePaths[p] = true
}
}
ssePaths[cfg.Paths.SSE] = true
return func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Handle OPTIONS
if r.Method == http.MethodOptions {
if allowedOrigin == "" {
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
logger.Warn("Preflight request from disallowed origin: %s", origin)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
}
if allowedOrigin == "" {
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Validate JWT for MCP paths if required
// Placeholder for JWT validation logic
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
var err error
r, err = modifier.ModifyRequest(r)
if err != nil {
log.Printf("[proxy] Error modifying request: %v", err)
logger.Error("Error modifying request: %v", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
@ -192,7 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Host = targetURL.Host
cleanHeaders := http.Header{}
// Set proper origin header to match the target
if isSSE {
// For SSE, ensure origin matches the target
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
}
for k, v := range r.Header {
// Skip hop-by-hop headers
if skipHeader(k) {
@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Header = cleanHeaders
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
},
ModifyResponse: func(resp *http.Response) error {
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
return nil
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[proxy] Error proxying: %v", err)
logger.Error("Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
},
FlushInterval: -1, // immediate flush for SSE
}
if isSSE {
// Add special response handling for SSE connections to rewrite endpoint URLs
rp.Transport = &sseTransport{
Transport: http.DefaultTransport,
proxyHost: r.Host,
targetHost: targetURL.Host,
}
// Set SSE-specific headers
w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// Keep SSE connections open
HandleSSE(w, r, rp)
} else {
@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
}
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
if allowed == origin {
return allowed
}
@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
w.Header().Set("Vary", "Origin")
w.Header().Set("X-Accel-Buffering", "no")
}
func isAuthPath(path string) bool {
@ -273,7 +292,8 @@ func isAuthPath(path string) bool {
// isMCPPath checks if the path is an MCP path
func isMCPPath(path string, cfg *config.Config) bool {
for _, p := range cfg.MCPPaths {
mcpPaths := cfg.GetMCPPaths()
for _, p := range mcpPaths {
if strings.HasPrefix(path, p) {
return true
}

View file

@ -1,11 +1,16 @@
package proxy
import (
"bufio"
"context"
"log"
"fmt"
"io"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// HandleSSE sets up a go-routine to wait for context cancellation
@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
go func() {
<-ctx.Done()
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
close(done)
}()
@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
type sseTransport struct {
Transport http.RoundTripper
proxyHost string
targetHost string
}
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Call the underlying transport
resp, err := t.Transport.RoundTrip(req)
if err != nil {
return nil, err
}
// Check if this is an SSE response
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
return resp, nil
}
logger.Info("Intercepting SSE response to modify endpoint events")
// Create a response wrapper that modifies the response body
originalBody := resp.Body
pr, pw := io.Pipe()
go func() {
defer originalBody.Close()
defer pw.Close()
scanner := bufio.NewScanner(originalBody)
for scanner.Scan() {
line := scanner.Text()
// Check if this line contains an endpoint event
if strings.HasPrefix(line, "event: endpoint") {
// Read the data line
if scanner.Scan() {
dataLine := scanner.Text()
if strings.HasPrefix(dataLine, "data: ") {
// Extract the endpoint URL
endpoint := strings.TrimPrefix(dataLine, "data: ")
// Replace the host in the endpoint
logger.Debug("Original endpoint: %s", endpoint)
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
logger.Debug("Modified endpoint: %s", endpoint)
// Write the modified event lines
fmt.Fprintln(pw, line)
fmt.Fprintln(pw, "data: "+endpoint)
continue
}
}
}
// Write the original line for non-endpoint events
fmt.Fprintln(pw, line)
}
if err := scanner.Err(); err != nil {
logger.Error("Error reading SSE stream: %v", err)
}
}()
// Replace the response body with our modified pipe
resp.Body = pr
return resp, nil
}

View file

@ -0,0 +1,268 @@
package subprocess
import (
"fmt"
"os"
"os/exec"
"sync"
"syscall"
"time"
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// Manager handles starting and graceful shutdown of subprocesses
type Manager struct {
process *os.Process
processGroup int
mutex sync.Mutex
cmd *exec.Cmd
shutdownDelay time.Duration
}
// NewManager creates a new subprocess manager
func NewManager() *Manager {
return &Manager{
shutdownDelay: 5 * time.Second,
}
}
// EnsureDependenciesAvailable checks and installs required package executors
func EnsureDependenciesAvailable(command string) error {
// Always ensure npx is available regardless of the command
if _, err := exec.LookPath("npx"); err != nil {
// npx is not available, check if npm is installed
if _, err := exec.LookPath("npm"); err != nil {
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
}
// Try to install npx using npm
logger.Info("npx not found, attempting to install...")
cmd := exec.Command("npm", "install", "-g", "npx")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to install npx: %w", err)
}
logger.Info("npx installed successfully")
}
// Check if uv is needed based on the command
if strings.Contains(command, "uv ") {
if _, err := exec.LookPath("uv"); err != nil {
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
}
}
return nil
}
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
func (m *Manager) SetShutdownDelay(duration time.Duration) {
m.shutdownDelay = duration
}
// Start launches a subprocess based on the configuration
func (m *Manager) Start(cfg *config.Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// If a process is already running, return an error
if m.process != nil {
return os.ErrExist
}
if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
return nil // Nothing to start
}
// Get the full command string
execCommand := cfg.BuildExecCommand()
if execCommand == "" {
return nil // No command to execute
}
logger.Info("Starting subprocess with command: %s", execCommand)
// Use the shell to execute the command
cmd := exec.Command("sh", "-c", execCommand)
// Set working directory if specified
if cfg.Stdio.WorkDir != "" {
cmd.Dir = cfg.Stdio.WorkDir
}
// Set environment variables if specified
if len(cfg.Stdio.Env) > 0 {
cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
}
// Capture stdout/stderr
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// Set the process group for proper termination
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// Start the process
if err := cmd.Start(); err != nil {
return err
}
m.process = cmd.Process
m.cmd = cmd
logger.Info("Subprocess started with PID: %d", m.process.Pid)
// Get and store the process group ID
pgid, err := syscall.Getpgid(m.process.Pid)
if err == nil {
m.processGroup = pgid
logger.Debug("Process group ID: %d", m.processGroup)
} else {
logger.Warn("Failed to get process group ID: %v", err)
m.processGroup = m.process.Pid
}
// Handle process termination in background
go func() {
if err := cmd.Wait(); err != nil {
logger.Error("Subprocess exited with error: %v", err)
} else {
logger.Info("Subprocess exited successfully")
}
// Clear the process reference when it exits
m.mutex.Lock()
m.process = nil
m.cmd = nil
m.mutex.Unlock()
}()
return nil
}
// IsRunning checks if the subprocess is running
func (m *Manager) IsRunning() bool {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.process != nil
}
// Shutdown gracefully terminates the subprocess
func (m *Manager) Shutdown() {
m.mutex.Lock()
processToTerminate := m.process // Local copy of the process reference
processGroupToTerminate := m.processGroup
m.mutex.Unlock()
if processToTerminate == nil {
return // No process to terminate
}
logger.Info("Terminating subprocess...")
terminateComplete := make(chan struct{})
go func() {
defer close(terminateComplete)
// Try graceful termination first with SIGTERM
terminatedGracefully := false
// Try to terminate the process group first
if processGroupToTerminate != 0 {
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process group: %v", err)
// Fallback to terminating just the process
m.mutex.Lock()
if m.process != nil {
err = m.process.Signal(syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process: %v", err)
}
}
m.mutex.Unlock()
}
} else {
// Try to terminate just the process
m.mutex.Lock()
if m.process != nil {
err := m.process.Signal(syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process: %v", err)
}
}
m.mutex.Unlock()
}
// Wait for the process to exit gracefully
for i := 0; i < 10; i++ {
time.Sleep(200 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
terminatedGracefully = true
m.mutex.Unlock()
break
}
m.mutex.Unlock()
}
if terminatedGracefully {
logger.Info("Subprocess terminated gracefully")
return
}
// If the process didn't exit gracefully, force kill
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
// Try to kill the process group first
if processGroupToTerminate != 0 {
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
logger.Warn("Failed to send SIGKILL to process group: %v", err)
// Fallback to killing just the process
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
}
} else {
// Try to kill just the process
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
}
// Wait a bit more to confirm termination
time.Sleep(500 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
logger.Info("Subprocess terminated by force")
} else {
logger.Warn("Failed to terminate subprocess")
}
m.mutex.Unlock()
}()
// Wait for termination with timeout
select {
case <-terminateComplete:
// Termination completed
case <-time.After(m.shutdownDelay):
logger.Warn("Subprocess termination timed out")
}
}

View file

@ -4,12 +4,12 @@ import (
"crypto/rsa"
"encoding/json"
"errors"
"log"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type JWKS struct {
@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
publicKeys[parsedKey.Kid] = pubKey
}
}
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
logger.Info("Loaded %d public keys.", len(publicKeys))
return nil
}