diff --git a/README.md b/README.md index d00f16e..f197891 100644 --- a/README.md +++ b/README.md @@ -1,101 +1,81 @@ # Open MCP Auth Proxy -The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider. +A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/). -![image](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) +![Architecture Diagram](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) -## **Setup and Installation** +## What it Does -### **Prerequisites** +Open MCP Auth Proxy sits between MCP clients and your MCP server to: -* Go 1.20 or higher -* A running MCP server (SSE transport supported) -* An MCP client that supports MCP authorization +- Intercept incoming requests +- Validate authorization tokens +- Offload authentication and authorization to OAuth-compliant Identity Providers +- Support the MCP authorization protocol -### **Installation** +## Quick Start + +### Prerequisites + +* Go 1.20 or higher +* A running MCP server +* An MCP client that supports MCP authorization + +### Installation ```bash -git clone https://github.com/wso2/open-mcp-auth-proxy -cd open-mcp-auth-proxy - -go get github.com/golang-jwt/jwt/v4 -go get gopkg.in/yaml.v2 - +git clone https://github.com/wso2/open-mcp-auth-proxy +cd open-mcp-auth-proxy +go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2 go build -o openmcpauthproxy ./cmd/proxy ``` -## Using Open MCP Auth Proxy +### Basic Usage -### Quick Start - -Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo. - -If you don’t have an MCP server, follow the instructions given here to start your own MCP server for testing purposes. - -1. Navigate to `resources` directory. -2. Initialize a virtual environment. - -```bash -python3 -m venv .venv -``` -3. Activate virtual environment. - -```bash -source .venv/bin/activate -``` - -4. Install dependencies. - -``` -pip3 install -r requirements.txt -``` - -5. Start the server. - -```bash -python3 echo_server.py -``` - -#### Configure the Auth Proxy - -Update the following parameters in `config.yaml`. - -### demo mode configuration: +1. The repository comes with a default `config.yaml` file that contains the basic configuration: ```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_port: 8080 # Address where the proxy will listen +listen_port: 8080 +base_url: "http://localhost:8000" # Your MCP server URL +paths: + sse: "/sse" + messages: "/messages/" ``` -#### Start the Auth Proxy +2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox): ```bash ./openmcpauthproxy --demo ``` -The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/). +3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) -#### Connect Using an MCP Client +## Identity Provider Integration -You can use this fork of the [MCP Inspector](https://github.com/shashimalcse/inspector) to test the connection and try out the complete authorization flow. (This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) +### Demo Mode -### Use with Asgardeo +For quick testing, use the `--demo` flag which includes pre-configured authentication and authorization with an Asgardeo sandbox. -Enable authorization for the MCP server through your own Asgardeo organization +```bash +./openmcpauthproxy --demo +``` -1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo -2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that, - 1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) - 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope. - ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) - 2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy +### Asgardeo Integration + +To enable authorization through your own Asgardeo organization: + +1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo +2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) + 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope + ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) + 2. Update the existing `config.yaml` with your Asgardeo details: #### Configure the Auth Proxy Create a configuration file config.yaml with the following parameters: ```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server +base_url: "http://localhost:8000" # URL of your MCP server listen_port: 8080 # Address where the proxy will listen asgardeo: @@ -104,31 +84,146 @@ asgardeo: client_secret: "" # Client secret of the M2M app ``` -#### Start the Auth Proxy +3. Start the proxy with Asgardeo integration: ```bash ./openmcpauthproxy --asgardeo ``` -### Use with any standard OAuth Server +### Other OAuth Providers -Enable authorization for the MCP server with a compliant OAuth server +- [Auth0 Integration Guide](docs/Auth0.md) -#### Configuration +## Testing with an Example MCP Server -Create a configuration file config.yaml with the following parameters: +If you don't have an MCP server, you can use the included example: -```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_port: 8080 # Address where the proxy will listen -``` -**TODO**: Update the configs for a standard OAuth Server. - -#### Start the Auth Proxy +1. Navigate to the `resources` directory +2. Set up a Python environment: ```bash -./openmcpauthproxy +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt ``` -#### Integrating with existing OAuth Providers - - [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization. +3. Start the example server: + +```bash +python3 echo_server.py +``` + +# Advanced Configuration + +### Transport Modes + +The proxy supports two transport modes: + +- **SSE Mode (Default)**: For Server-Sent Events transport +- **stdio Mode**: For MCP servers that use stdio transport + +When using stdio mode, the proxy: +- Starts an MCP server as a subprocess using the command specified in the configuration +- Communicates with the subprocess through standard input/output (stdio) +- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first + +To use stdio mode: + +```bash +./openmcpauthproxy --demo --stdio +``` + +#### Example: Running an MCP Server as a Subprocess + +1. Configure stdio mode in your `config.yaml`: + +```yaml +listen_port: 8080 +base_url: "http://localhost:8000" + +stdio: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server + env: # Environment variables (optional) + - "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT" + +# CORS configuration +cors: + allowed_origins: + - "http://localhost:5173" # Origin of your client application + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true + +# Demo configuration for Asgardeo +demo: + org_name: "openmcpauthdemo" + client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" + client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" +``` + +2. Run the proxy with stdio mode: + +```bash +./openmcpauthproxy --demo +``` + +The proxy will: +- Start the MCP server as a subprocess using the specified command +- Handle all authorization requirements +- Forward messages between clients and the server + +### Complete Configuration Reference + +```yaml +# Common configuration +listen_port: 8080 +base_url: "http://localhost:8000" +port: 8000 + +# Path configuration +paths: + sse: "/sse" + messages: "/messages/" + +# Transport mode +transport_mode: "sse" # Options: "sse" or "stdio" + +# stdio-specific configuration (used only in stdio mode) +stdio: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed) + work_dir: "" # Optional working directory for the subprocess + +# CORS configuration +cors: + allowed_origins: + - "http://localhost:5173" + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true + +# Demo configuration for Asgardeo +demo: + org_name: "openmcpauthdemo" + client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" + client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +# Asgardeo configuration (used with --asgardeo flag) +asgardeo: + org_name: "" + client_id: "" + client_secret: "" +``` \ No newline at end of file diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index cde3cf3..6424f18 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -3,31 +3,71 @@ package main import ( "flag" "fmt" - "log" "net/http" "os" "os/signal" + "syscall" "time" "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/constants" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" + "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) func main() { demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).") + debugMode := flag.Bool("debug", false, "Enable debug logging") + stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE") flag.Parse() + logger.SetDebug(*debugMode) + // 1. Load config cfg, err := config.LoadConfig("config.yaml") if err != nil { - log.Fatalf("Error loading config: %v", err) + logger.Error("Error loading config: %v", err) + os.Exit(1) } - // 2. Create the chosen provider + // Override transport mode if stdio flag is set + if *stdioMode { + cfg.TransportMode = config.StdioTransport + // Ensure stdio is enabled + cfg.Stdio.Enabled = true + // Re-validate config + if err := cfg.Validate(); err != nil { + logger.Error("Configuration error: %v", err) + os.Exit(1) + } + } + + logger.Info("Using transport mode: %s", cfg.TransportMode) + logger.Info("Using MCP server base URL: %s", cfg.BaseURL) + logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages) + + // 2. Start subprocess if configured and in stdio mode + var procManager *subprocess.Manager + if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled { + // Ensure all required dependencies are available + if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil { + logger.Warn("%v", err) + logger.Warn("Subprocess may fail to start due to missing dependencies") + } + + procManager = subprocess.NewManager() + if err := procManager.Start(cfg); err != nil { + logger.Warn("Failed to start subprocess: %v", err) + } + } else if cfg.TransportMode == config.SSETransport { + logger.Info("Using SSE transport mode, not starting subprocess") + } + + // 3. Create the chosen provider var provider authz.Provider if *demoMode { cfg.Mode = "demo" @@ -46,41 +86,49 @@ func main() { provider = authz.NewDefaultProvider(cfg) } - // 3. (Optional) Fetch JWKS if you want local JWT validation + // 4. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { - log.Fatalf("Failed to fetch JWKS: %v", err) + logger.Error("Failed to fetch JWKS: %v", err) + os.Exit(1) } - // 4. Build the main router + // 5. Build the main router mux := proxy.NewRouter(cfg, provider) listen_address := fmt.Sprintf(":%d", cfg.ListenPort) - // 5. Start the server + // 6. Start the server srv := &http.Server{ - Addr: listen_address, Handler: mux, } go func() { - log.Printf("Server listening on %s", listen_address) + logger.Info("Server listening on %s", listen_address) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Error("Server error: %v", err) + os.Exit(1) } }() - // 6. Graceful shutdown on Ctrl+C + // 7. Wait for shutdown signal stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop - log.Println("Shutting down...") + logger.Info("Shutting down...") + // 8. First terminate subprocess if running + if procManager != nil && procManager.IsRunning() { + procManager.Shutdown() + } + + // 9. Then shutdown the server + logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { - log.Printf("Shutdown error: %v", err) + logger.Error("HTTP server shutdown error: %v", err) } - log.Println("Stopped.") + logger.Info("Stopped.") } diff --git a/config.yaml b/config.yaml index b949380..971b93c 100644 --- a/config.yaml +++ b/config.yaml @@ -1,15 +1,31 @@ # config.yaml -mcp_server_base_url: "" +# Common configuration for all transport modes listen_port: 8080 +base_url: "http://localhost:8000" # Base URL for the MCP server +port: 8000 # Port for the MCP server timeout_seconds: 10 -mcp_paths: - - /messages/ - - /sse +# Path configuration +paths: + sse: "/sse" # SSE endpoint path + messages: "/messages/" # Messages endpoint path +# Transport mode configuration +transport_mode: "sse" # Options: "sse" or "stdio" + +# stdio-specific configuration (used only when transport_mode is "stdio") +stdio: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" + work_dir: "" # Working directory (optional) + # env: # Environment variables (optional) + # - "NODE_ENV=development" + +# Path mapping (optional) path_mapping: +# CORS configuration cors: allowed_origins: - "http://localhost:5173" @@ -23,10 +39,8 @@ cors: - "Content-Type" allow_credentials: true +# Demo configuration for Asgardeo demo: org_name: "openmcpauthdemo" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" - - - diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index 7408f79..a3c812c 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -7,13 +7,13 @@ import ( "encoding/json" "fmt" "io" - "log" "math/rand" "net/http" "strings" "time" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type asgardeoProvider struct { @@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("X-Accel-Buffering", "no") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Accel-Buffering", "no") if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("[asgardeoProvider] Error encoding well-known: %v", err) + logger.Error("Error encoding well-known: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("X-Accel-Buffering", "no") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { var regReq RegisterRequest if err := json.NewDecoder(r.Body).Decode(®Req); err != nil { - log.Printf("ERROR: reading register request: %v", err) + logger.Error("Reading register request: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { regReq.ClientSecret = randomString(16) if err := p.createAsgardeoApplication(regReq); err != nil { - log.Printf("WARN: Asgardeo application creation failed: %v", err) + logger.Warn("Asgardeo application creation failed: %v", err) // Optionally http.Error(...) if you want to fail // or continue to return partial data. } @@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Accel-Buffering", "no") w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("ERROR: encoding /register response: %v", err) + logger.Error("Encoding /register response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) } - log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID) + logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) return nil } @@ -202,8 +206,11 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Sensitive data - should not be logged at INFO level auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + + logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID) tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { return "", fmt.Errorf("failed to parse token JSON: %w", err) } + // Don't log the actual token at info level, only at debug level + logger.Debug("Received access token: %s", tokenResp.AccessToken) + logger.Info("Successfully obtained admin token from Asgardeo") + return tokenResp.AccessToken, nil } diff --git a/internal/authz/default.go b/internal/authz/default.go index 9230d39..929f586 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type defaultProvider struct { @@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Error("Error encoding well-known response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } return diff --git a/internal/config/config.go b/internal/config/config.go index 01c3a6f..fc6743c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,12 +1,35 @@ package config import ( + "fmt" "os" "gopkg.in/yaml.v2" ) -// AsgardeoConfig groups all Asgardeo-specific fields +// Transport mode for MCP server +type TransportMode string + +const ( + SSETransport TransportMode = "sse" + StdioTransport TransportMode = "stdio" +) + +// Common path configuration for all transport modes +type PathsConfig struct { + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` +} + +// StdioConfig contains stdio-specific configuration +type StdioConfig struct { + Enabled bool `yaml:"enabled"` + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) + Args []string `yaml:"args,omitempty"` // Additional arguments + Env []string `yaml:"env,omitempty"` // Environment variables +} + type DemoConfig struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` @@ -60,15 +83,18 @@ type DefaultConfig struct { } type Config struct { - AuthServerBaseURL string - MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenPort int `yaml:"listen_port"` - JWKSURL string - TimeoutSeconds int `yaml:"timeout_seconds"` - MCPPaths []string `yaml:"mcp_paths"` - PathMapping map[string]string `yaml:"path_mapping"` - Mode string `yaml:"mode"` - CORSConfig CORSConfig `yaml:"cors"` + AuthServerBaseURL string + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` + TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` @@ -76,6 +102,56 @@ type Config struct { Default DefaultConfig `yaml:"default"` } +// Validate checks if the config is valid based on transport mode +func (c *Config) Validate() error { + // Validate based on transport mode + if c.TransportMode == StdioTransport { + if !c.Stdio.Enabled { + return fmt.Errorf("stdio.enabled must be true in stdio transport mode") + } + if c.Stdio.UserCommand == "" { + return fmt.Errorf("stdio.user_command is required in stdio transport mode") + } + } + + // Validate paths + if c.Paths.SSE == "" { + c.Paths.SSE = "/sse" // Default value + } + if c.Paths.Messages == "" { + c.Paths.Messages = "/messages" // Default value + } + + // Validate base URL + if c.BaseURL == "" { + if c.Port > 0 { + c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port) + } else { + c.BaseURL = "http://localhost:8000" // Default value + } + } + + return nil +} + +// GetMCPPaths returns the list of paths that should be proxied to the MCP server +func (c *Config) GetMCPPaths() []string { + return []string{c.Paths.SSE, c.Paths.Messages} +} + +// BuildExecCommand constructs the full command string for execution in stdio mode +func (c *Config) BuildExecCommand() string { + if c.Stdio.UserCommand == "" { + return "" + } + + // Construct the full command + return fmt.Sprintf( + `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, + c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, + ) +} + // LoadConfig reads a YAML config file into Config struct. func LoadConfig(path string) (*Config, error) { f, err := os.Open(path) @@ -89,8 +165,26 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } + + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } + + // Set default transport mode if not specified + if cfg.TransportMode == "" { + cfg.TransportMode = SSETransport // Default to SSE + } + + // Set default port if not specified + if cfg.Port == 0 { + cfg.Port = 8000 // default + } + + // Validate the configuration + if err := cfg.Validate(); err != nil { + return nil, err + } + return &cfg, nil } diff --git a/internal/logging/logger.go b/internal/logging/logger.go new file mode 100644 index 0000000..57bec27 --- /dev/null +++ b/internal/logging/logger.go @@ -0,0 +1,34 @@ +package logger + +import ( + "log" +) + +var isDebug = false + +// SetDebug enables or disables debug logging +func SetDebug(debug bool) { + isDebug = debug +} + +// Debug logs a debug-level message +func Debug(format string, v ...interface{}) { + if isDebug { + log.Printf("DEBUG: "+format, v...) + } +} + +// Info logs an info-level message +func Info(format string, v ...interface{}) { + log.Printf("INFO: "+format, v...) +} + +// Warn logs a warning-level message +func Warn(format string, v ...interface{}) { + log.Printf("WARN: "+format, v...) +} + +// Error logs an error-level message +func Error(format string, v ...interface{}) { + log.Printf("ERROR: "+format, v...) +} diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go index 8e2268b..6662b2c 100644 --- a/internal/proxy/modifier.go +++ b/internal/proxy/modifier.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // RequestModifier modifies requests before they are proxied @@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro if strings.Contains(contentType, "application/x-www-form-urlencoded") { // Parse form data if err := req.ParseForm(); err != nil { + logger.Error("Failed to parse form data: %v", err) return nil, err } @@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro // Read body bodyBytes, err := io.ReadAll(req.Body) if err != nil { + logger.Error("Failed to read request body: %v", err) return nil, err } // Parse JSON var jsonData map[string]interface{} if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + logger.Error("Failed to parse JSON body: %v", err) return nil, err } @@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro // Marshal back to JSON modifiedBody, err := json.Marshal(jsonData) if err != nil { + logger.Error("Failed to marshal modified JSON: %v", err) return nil, err } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index c999be4..33a9ea3 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "log" "net/http" "net/http/httputil" "net/url" @@ -11,6 +10,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { } // MCP paths - for _, path := range cfg.MCPPaths { + mcpPaths := cfg.GetMCPPaths() + for _, path := range mcpPaths { mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } @@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { // Parse the base URLs up front - authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { - log.Fatalf("Invalid auth server URL: %v", err) + logger.Error("Invalid auth server URL: %v", err) + panic(err) // Fatal error that prevents startup } - mcpBase, err := url.Parse(cfg.MCPServerBaseURL) + + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { - log.Fatalf("Invalid MCP server URL: %v", err) + logger.Error("Invalid MCP server URL: %v", err) + panic(err) // Fatal error that prevents startup } // Detect SSE paths from config ssePaths := make(map[string]bool) - for _, p := range cfg.MCPPaths { - if p == "/sse" { - ssePaths[p] = true - } - } + ssePaths[cfg.Paths.SSE] = true return func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") @@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Handle OPTIONS if r.Method == http.MethodOptions { if allowedOrigin == "" { - log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) + logger.Warn("Preflight request from disallowed origin: %s", origin) http.Error(w, "CORS origin not allowed", http.StatusForbidden) return } @@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) } if allowedOrigin == "" { - log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) + logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path) http.Error(w, "CORS origin not allowed", http.StatusForbidden) return } @@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Validate JWT for MCP paths if required // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { - log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) + logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) var err error r, err = modifier.ModifyRequest(r) if err != nil { - log.Printf("[proxy] Error modifying request: %v", err) + logger.Error("Error modifying request: %v", err) http.Error(w, "Bad Request", http.StatusBadRequest) return } @@ -192,7 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Host = targetURL.Host cleanHeaders := http.Header{} - + + // Set proper origin header to match the target + if isSSE { + // For SSE, ensure origin matches the target + req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) + } + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Header = cleanHeaders - log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) + logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, ModifyResponse: func(resp *http.Response) error { - log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - log.Printf("[proxy] Error proxying: %v", err) + logger.Error("Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) }, FlushInterval: -1, // immediate flush for SSE } if isSSE { + // Add special response handling for SSE connections to rewrite endpoint URLs + rp.Transport = &sseTransport{ + Transport: http.DefaultTransport, + proxyHost: r.Host, + targetHost: targetURL.Host, + } + + // Set SSE-specific headers + w.Header().Set("X-Accel-Buffering", "no") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + // Keep SSE connections open HandleSSE(w, r, rp) } else { @@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin } for _, allowed := range cfg.CORSConfig.AllowedOrigins { + logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed) if allowed == origin { return allowed } @@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re w.Header().Set("Access-Control-Allow-Credentials", "true") } w.Header().Set("Vary", "Origin") + w.Header().Set("X-Accel-Buffering", "no") } func isAuthPath(path string) bool { @@ -273,7 +292,8 @@ func isAuthPath(path string) bool { // isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { - for _, p := range cfg.MCPPaths { + mcpPaths := cfg.GetMCPPaths() + for _, p := range mcpPaths { if strings.HasPrefix(path, p) { return true } diff --git a/internal/proxy/sse.go b/internal/proxy/sse.go index 44d6558..ce72e04 100644 --- a/internal/proxy/sse.go +++ b/internal/proxy/sse.go @@ -1,11 +1,16 @@ package proxy import ( + "bufio" "context" - "log" + "fmt" + "io" "net/http" "net/http/httputil" + "strings" "time" + + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // HandleSSE sets up a go-routine to wait for context cancellation @@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy go func() { <-ctx.Done() - log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) + logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) close(done) }() @@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), timeout) } + +// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses +type sseTransport struct { + Transport http.RoundTripper + proxyHost string + targetHost string +} + +func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Call the underlying transport + resp, err := t.Transport.RoundTrip(req) + if err != nil { + return nil, err + } + + // Check if this is an SSE response + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + return resp, nil + } + + logger.Info("Intercepting SSE response to modify endpoint events") + + // Create a response wrapper that modifies the response body + originalBody := resp.Body + pr, pw := io.Pipe() + + go func() { + defer originalBody.Close() + defer pw.Close() + + scanner := bufio.NewScanner(originalBody) + for scanner.Scan() { + line := scanner.Text() + + // Check if this line contains an endpoint event + if strings.HasPrefix(line, "event: endpoint") { + // Read the data line + if scanner.Scan() { + dataLine := scanner.Text() + if strings.HasPrefix(dataLine, "data: ") { + // Extract the endpoint URL + endpoint := strings.TrimPrefix(dataLine, "data: ") + + // Replace the host in the endpoint + logger.Debug("Original endpoint: %s", endpoint) + endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1) + logger.Debug("Modified endpoint: %s", endpoint) + + // Write the modified event lines + fmt.Fprintln(pw, line) + fmt.Fprintln(pw, "data: "+endpoint) + continue + } + } + } + + // Write the original line for non-endpoint events + fmt.Fprintln(pw, line) + } + + if err := scanner.Err(); err != nil { + logger.Error("Error reading SSE stream: %v", err) + } + }() + + // Replace the response body with our modified pipe + resp.Body = pr + return resp, nil +} diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go new file mode 100644 index 0000000..fa64337 --- /dev/null +++ b/internal/subprocess/manager.go @@ -0,0 +1,268 @@ +package subprocess + +import ( + "fmt" + "os" + "os/exec" + "sync" + "syscall" + "time" + "strings" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" +) + +// Manager handles starting and graceful shutdown of subprocesses +type Manager struct { + process *os.Process + processGroup int + mutex sync.Mutex + cmd *exec.Cmd + shutdownDelay time.Duration +} + +// NewManager creates a new subprocess manager +func NewManager() *Manager { + return &Manager{ + shutdownDelay: 5 * time.Second, + } +} + +// EnsureDependenciesAvailable checks and installs required package executors +func EnsureDependenciesAvailable(command string) error { + // Always ensure npx is available regardless of the command + if _, err := exec.LookPath("npx"); err != nil { + // npx is not available, check if npm is installed + if _, err := exec.LookPath("npm"); err != nil { + return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") + } + + // Try to install npx using npm + logger.Info("npx not found, attempting to install...") + cmd := exec.Command("npm", "install", "-g", "npx") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install npx: %w", err) + } + + logger.Info("npx installed successfully") + } + + // Check if uv is needed based on the command + if strings.Contains(command, "uv ") { + if _, err := exec.LookPath("uv"); err != nil { + return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") + } + } + + return nil +} + +// SetShutdownDelay sets the maximum time to wait for graceful shutdown +func (m *Manager) SetShutdownDelay(duration time.Duration) { + m.shutdownDelay = duration +} + +// Start launches a subprocess based on the configuration +func (m *Manager) Start(cfg *config.Config) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + // If a process is already running, return an error + if m.process != nil { + return os.ErrExist + } + + if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" { + return nil // Nothing to start + } + + // Get the full command string + execCommand := cfg.BuildExecCommand() + if execCommand == "" { + return nil // No command to execute + } + + logger.Info("Starting subprocess with command: %s", execCommand) + + // Use the shell to execute the command + cmd := exec.Command("sh", "-c", execCommand) + + // Set working directory if specified + if cfg.Stdio.WorkDir != "" { + cmd.Dir = cfg.Stdio.WorkDir + } + + // Set environment variables if specified + if len(cfg.Stdio.Env) > 0 { + cmd.Env = append(os.Environ(), cfg.Stdio.Env...) + } + + // Capture stdout/stderr + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Set the process group for proper termination + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Start the process + if err := cmd.Start(); err != nil { + return err + } + + m.process = cmd.Process + m.cmd = cmd + logger.Info("Subprocess started with PID: %d", m.process.Pid) + + // Get and store the process group ID + pgid, err := syscall.Getpgid(m.process.Pid) + if err == nil { + m.processGroup = pgid + logger.Debug("Process group ID: %d", m.processGroup) + } else { + logger.Warn("Failed to get process group ID: %v", err) + m.processGroup = m.process.Pid + } + + // Handle process termination in background + go func() { + if err := cmd.Wait(); err != nil { + logger.Error("Subprocess exited with error: %v", err) + } else { + logger.Info("Subprocess exited successfully") + } + + // Clear the process reference when it exits + m.mutex.Lock() + m.process = nil + m.cmd = nil + m.mutex.Unlock() + }() + + return nil +} + +// IsRunning checks if the subprocess is running +func (m *Manager) IsRunning() bool { + m.mutex.Lock() + defer m.mutex.Unlock() + return m.process != nil +} + +// Shutdown gracefully terminates the subprocess +func (m *Manager) Shutdown() { + m.mutex.Lock() + processToTerminate := m.process // Local copy of the process reference + processGroupToTerminate := m.processGroup + m.mutex.Unlock() + + if processToTerminate == nil { + return // No process to terminate + } + + logger.Info("Terminating subprocess...") + terminateComplete := make(chan struct{}) + + go func() { + defer close(terminateComplete) + + // Try graceful termination first with SIGTERM + terminatedGracefully := false + + // Try to terminate the process group first + if processGroupToTerminate != 0 { + err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process group: %v", err) + + // Fallback to terminating just the process + m.mutex.Lock() + if m.process != nil { + err = m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to terminate just the process + m.mutex.Lock() + if m.process != nil { + err := m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + + // Wait for the process to exit gracefully + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break + } + m.mutex.Unlock() + } + + if terminatedGracefully { + logger.Info("Subprocess terminated gracefully") + return + } + + // If the process didn't exit gracefully, force kill + logger.Warn("Subprocess didn't exit gracefully, forcing termination...") + + // Try to kill the process group first + if processGroupToTerminate != 0 { + if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { + logger.Warn("Failed to send SIGKILL to process group: %v", err) + + // Fallback to killing just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to kill just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + + // Wait a bit more to confirm termination + time.Sleep(500 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + logger.Info("Subprocess terminated by force") + } else { + logger.Warn("Failed to terminate subprocess") + } + m.mutex.Unlock() + }() + + // Wait for termination with timeout + select { + case <-terminateComplete: + // Termination completed + case <-time.After(m.shutdownDelay): + logger.Warn("Subprocess termination timed out") + } +} diff --git a/internal/util/jwks.go b/internal/util/jwks.go index 4832bf8..f80d82e 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -4,12 +4,12 @@ import ( "crypto/rsa" "encoding/json" "errors" - "log" "math/big" "net/http" "strings" "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type JWKS struct { @@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error { publicKeys[parsedKey.Kid] = pubKey } } - log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys)) + logger.Info("Loaded %d public keys.", len(publicKeys)) return nil }