Add transport mode support for stdio, SSE stability fixes (#13)

Add transport mode support for stdio, SSE stability fixes
This commit is contained in:
Chiran Fernando 2025-04-08 12:46:00 +05:30 committed by GitHub
parent 6ce52261db
commit 32c9378aad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 808 additions and 142 deletions

245
README.md
View file

@ -1,101 +1,81 @@
# Open MCP Auth Proxy # Open MCP Auth Proxy
The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider. A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/).
![image](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) ![Architecture Diagram](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92)
## **Setup and Installation** ## What it Does
### **Prerequisites** Open MCP Auth Proxy sits between MCP clients and your MCP server to:
- Intercept incoming requests
- Validate authorization tokens
- Offload authentication and authorization to OAuth-compliant Identity Providers
- Support the MCP authorization protocol
## Quick Start
### Prerequisites
* Go 1.20 or higher * Go 1.20 or higher
* A running MCP server (SSE transport supported) * A running MCP server
* An MCP client that supports MCP authorization * An MCP client that supports MCP authorization
### **Installation** ### Installation
```bash ```bash
git clone https://github.com/wso2/open-mcp-auth-proxy git clone https://github.com/wso2/open-mcp-auth-proxy
cd open-mcp-auth-proxy cd open-mcp-auth-proxy
go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
go get github.com/golang-jwt/jwt/v4
go get gopkg.in/yaml.v2
go build -o openmcpauthproxy ./cmd/proxy go build -o openmcpauthproxy ./cmd/proxy
``` ```
## Using Open MCP Auth Proxy ### Basic Usage
### Quick Start 1. The repository comes with a default `config.yaml` file that contains the basic configuration:
Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo.
If you dont have an MCP server, follow the instructions given here to start your own MCP server for testing purposes.
1. Navigate to `resources` directory.
2. Initialize a virtual environment.
```bash
python3 -m venv .venv
```
3. Activate virtual environment.
```bash
source .venv/bin/activate
```
4. Install dependencies.
```
pip3 install -r requirements.txt
```
5. Start the server.
```bash
python3 echo_server.py
```
#### Configure the Auth Proxy
Update the following parameters in `config.yaml`.
### demo mode configuration:
```yaml ```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server listen_port: 8080
listen_port: 8080 # Address where the proxy will listen base_url: "http://localhost:8000" # Your MCP server URL
paths:
sse: "/sse"
messages: "/messages/"
``` ```
#### Start the Auth Proxy 2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
```bash ```bash
./openmcpauthproxy --demo ./openmcpauthproxy --demo
``` ```
The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/). 3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
#### Connect Using an MCP Client ## Identity Provider Integration
You can use this fork of the [MCP Inspector](https://github.com/shashimalcse/inspector) to test the connection and try out the complete authorization flow. (This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) ### Demo Mode
### Use with Asgardeo For quick testing, use the `--demo` flag which includes pre-configured authentication and authorization with an Asgardeo sandbox.
Enable authorization for the MCP server through your own Asgardeo organization ```bash
./openmcpauthproxy --demo
```
1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo ### Asgardeo Integration
2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that,
1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) To enable authorization through your own Asgardeo organization:
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope.
1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo
2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a)
2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy 2. Update the existing `config.yaml` with your Asgardeo details:
#### Configure the Auth Proxy #### Configure the Auth Proxy
Create a configuration file config.yaml with the following parameters: Create a configuration file config.yaml with the following parameters:
```yaml ```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server base_url: "http://localhost:8000" # URL of your MCP server
listen_port: 8080 # Address where the proxy will listen listen_port: 8080 # Address where the proxy will listen
asgardeo: asgardeo:
@ -104,31 +84,146 @@ asgardeo:
client_secret: "<client_secret>" # Client secret of the M2M app client_secret: "<client_secret>" # Client secret of the M2M app
``` ```
#### Start the Auth Proxy 3. Start the proxy with Asgardeo integration:
```bash ```bash
./openmcpauthproxy --asgardeo ./openmcpauthproxy --asgardeo
``` ```
### Use with any standard OAuth Server ### Other OAuth Providers
Enable authorization for the MCP server with a compliant OAuth server - [Auth0 Integration Guide](docs/Auth0.md)
#### Configuration ## Testing with an Example MCP Server
Create a configuration file config.yaml with the following parameters: If you don't have an MCP server, you can use the included example:
```yaml 1. Navigate to the `resources` directory
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server 2. Set up a Python environment:
listen_port: 8080 # Address where the proxy will listen
```
**TODO**: Update the configs for a standard OAuth Server.
#### Start the Auth Proxy
```bash ```bash
./openmcpauthproxy python3 -m venv .venv
source .venv/bin/activate
pip3 install -r requirements.txt
``` ```
#### Integrating with existing OAuth Providers
- [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization. 3. Start the example server:
```bash
python3 echo_server.py
```
# Advanced Configuration
### Transport Modes
The proxy supports two transport modes:
- **SSE Mode (Default)**: For Server-Sent Events transport
- **stdio Mode**: For MCP servers that use stdio transport
When using stdio mode, the proxy:
- Starts an MCP server as a subprocess using the command specified in the configuration
- Communicates with the subprocess through standard input/output (stdio)
- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
To use stdio mode:
```bash
./openmcpauthproxy --demo --stdio
```
#### Example: Running an MCP Server as a Subprocess
1. Configure stdio mode in your `config.yaml`:
```yaml
listen_port: 8080
base_url: "http://localhost:8000"
stdio:
enabled: true
user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server
env: # Environment variables (optional)
- "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT"
# CORS configuration
cors:
allowed_origins:
- "http://localhost:5173" # Origin of your client application
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
allowed_headers:
- "Authorization"
- "Content-Type"
allow_credentials: true
# Demo configuration for Asgardeo
demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
```
2. Run the proxy with stdio mode:
```bash
./openmcpauthproxy --demo
```
The proxy will:
- Start the MCP server as a subprocess using the specified command
- Handle all authorization requirements
- Forward messages between clients and the server
### Complete Configuration Reference
```yaml
# Common configuration
listen_port: 8080
base_url: "http://localhost:8000"
port: 8000
# Path configuration
paths:
sse: "/sse"
messages: "/messages/"
# Transport mode
transport_mode: "sse" # Options: "sse" or "stdio"
# stdio-specific configuration (used only in stdio mode)
stdio:
enabled: true
user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed)
work_dir: "" # Optional working directory for the subprocess
# CORS configuration
cors:
allowed_origins:
- "http://localhost:5173"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
allowed_headers:
- "Authorization"
- "Content-Type"
allow_credentials: true
# Demo configuration for Asgardeo
demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
# Asgardeo configuration (used with --asgardeo flag)
asgardeo:
org_name: "<org_name>"
client_id: "<client_id>"
client_secret: "<client_secret>"
```

View file

@ -3,31 +3,71 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"syscall"
"time" "time"
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/constants"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/proxy"
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
"github.com/wso2/open-mcp-auth-proxy/internal/util" "github.com/wso2/open-mcp-auth-proxy/internal/util"
) )
func main() { func main() {
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).") asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
debugMode := flag.Bool("debug", false, "Enable debug logging")
stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE")
flag.Parse() flag.Parse()
logger.SetDebug(*debugMode)
// 1. Load config // 1. Load config
cfg, err := config.LoadConfig("config.yaml") cfg, err := config.LoadConfig("config.yaml")
if err != nil { if err != nil {
log.Fatalf("Error loading config: %v", err) logger.Error("Error loading config: %v", err)
os.Exit(1)
} }
// 2. Create the chosen provider // Override transport mode if stdio flag is set
if *stdioMode {
cfg.TransportMode = config.StdioTransport
// Ensure stdio is enabled
cfg.Stdio.Enabled = true
// Re-validate config
if err := cfg.Validate(); err != nil {
logger.Error("Configuration error: %v", err)
os.Exit(1)
}
}
logger.Info("Using transport mode: %s", cfg.TransportMode)
logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
// 2. Start subprocess if configured and in stdio mode
var procManager *subprocess.Manager
if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
// Ensure all required dependencies are available
if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
logger.Warn("%v", err)
logger.Warn("Subprocess may fail to start due to missing dependencies")
}
procManager = subprocess.NewManager()
if err := procManager.Start(cfg); err != nil {
logger.Warn("Failed to start subprocess: %v", err)
}
} else if cfg.TransportMode == config.SSETransport {
logger.Info("Using SSE transport mode, not starting subprocess")
}
// 3. Create the chosen provider
var provider authz.Provider var provider authz.Provider
if *demoMode { if *demoMode {
cfg.Mode = "demo" cfg.Mode = "demo"
@ -46,41 +86,49 @@ func main() {
provider = authz.NewDefaultProvider(cfg) provider = authz.NewDefaultProvider(cfg)
} }
// 3. (Optional) Fetch JWKS if you want local JWT validation // 4. (Optional) Fetch JWKS if you want local JWT validation
if err := util.FetchJWKS(cfg.JWKSURL); err != nil { if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
log.Fatalf("Failed to fetch JWKS: %v", err) logger.Error("Failed to fetch JWKS: %v", err)
os.Exit(1)
} }
// 4. Build the main router // 5. Build the main router
mux := proxy.NewRouter(cfg, provider) mux := proxy.NewRouter(cfg, provider)
listen_address := fmt.Sprintf(":%d", cfg.ListenPort) listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
// 5. Start the server // 6. Start the server
srv := &http.Server{ srv := &http.Server{
Addr: listen_address, Addr: listen_address,
Handler: mux, Handler: mux,
} }
go func() { go func() {
log.Printf("Server listening on %s", listen_address) logger.Info("Server listening on %s", listen_address)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err) logger.Error("Server error: %v", err)
os.Exit(1)
} }
}() }()
// 6. Graceful shutdown on Ctrl+C // 7. Wait for shutdown signal
stop := make(chan os.Signal, 1) stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt) signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop <-stop
log.Println("Shutting down...") logger.Info("Shutting down...")
// 8. First terminate subprocess if running
if procManager != nil && procManager.IsRunning() {
procManager.Shutdown()
}
// 9. Then shutdown the server
logger.Info("Shutting down HTTP server...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil { if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("Shutdown error: %v", err) logger.Error("HTTP server shutdown error: %v", err)
} }
log.Println("Stopped.") logger.Info("Stopped.")
} }

View file

@ -1,15 +1,31 @@
# config.yaml # config.yaml
mcp_server_base_url: "" # Common configuration for all transport modes
listen_port: 8080 listen_port: 8080
base_url: "http://localhost:8000" # Base URL for the MCP server
port: 8000 # Port for the MCP server
timeout_seconds: 10 timeout_seconds: 10
mcp_paths: # Path configuration
- /messages/ paths:
- /sse sse: "/sse" # SSE endpoint path
messages: "/messages/" # Messages endpoint path
# Transport mode configuration
transport_mode: "sse" # Options: "sse" or "stdio"
# stdio-specific configuration (used only when transport_mode is "stdio")
stdio:
enabled: true
user_command: "npx -y @modelcontextprotocol/server-github"
work_dir: "" # Working directory (optional)
# env: # Environment variables (optional)
# - "NODE_ENV=development"
# Path mapping (optional)
path_mapping: path_mapping:
# CORS configuration
cors: cors:
allowed_origins: allowed_origins:
- "http://localhost:5173" - "http://localhost:5173"
@ -23,10 +39,8 @@ cors:
- "Content-Type" - "Content-Type"
allow_credentials: true allow_credentials: true
# Demo configuration for Asgardeo
demo: demo:
org_name: "openmcpauthdemo" org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"

View file

@ -7,13 +7,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type asgardeoProvider struct { type asgardeoProvider struct {
@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Accel-Buffering", "no")
if err := json.NewEncoder(w).Encode(response); err != nil { if err := json.NewEncoder(w).Encode(response); err != nil {
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err) logger.Error("Error encoding well-known: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
} }
} }
@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
var regReq RegisterRequest var regReq RegisterRequest
if err := json.NewDecoder(r.Body).Decode(&regReq); err != nil { if err := json.NewDecoder(r.Body).Decode(&regReq); err != nil {
log.Printf("ERROR: reading register request: %v", err) logger.Error("Reading register request: %v", err)
http.Error(w, "Invalid request body", http.StatusBadRequest) http.Error(w, "Invalid request body", http.StatusBadRequest)
return return
} }
@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
regReq.ClientSecret = randomString(16) regReq.ClientSecret = randomString(16)
if err := p.createAsgardeoApplication(regReq); err != nil { if err := p.createAsgardeoApplication(regReq); err != nil {
log.Printf("WARN: Asgardeo application creation failed: %v", err) logger.Warn("Asgardeo application creation failed: %v", err)
// Optionally http.Error(...) if you want to fail // Optionally http.Error(...) if you want to fail
// or continue to return partial data. // or continue to return partial data.
} }
@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(resp); err != nil { if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("ERROR: encoding /register response: %v", err) logger.Error("Encoding /register response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
} }
} }
@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
} }
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID) logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
return nil return nil
} }
@ -202,9 +206,12 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Sensitive data - should not be logged at INFO level
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
tr := &http.Transport{ tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
} }
@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
return "", fmt.Errorf("failed to parse token JSON: %w", err) return "", fmt.Errorf("failed to parse token JSON: %w", err)
} }
// Don't log the actual token at info level, only at debug level
logger.Debug("Received access token: %s", tokenResp.AccessToken)
logger.Info("Successfully obtained admin token from Asgardeo")
return tokenResp.AccessToken, nil return tokenResp.AccessToken, nil
} }

View file

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type defaultProvider struct { type defaultProvider struct {
@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil { if err := json.NewEncoder(w).Encode(response); err != nil {
logger.Error("Error encoding well-known response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
} }
return return

View file

@ -1,12 +1,35 @@
package config package config
import ( import (
"fmt"
"os" "os"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// AsgardeoConfig groups all Asgardeo-specific fields // Transport mode for MCP server
type TransportMode string
const (
SSETransport TransportMode = "sse"
StdioTransport TransportMode = "stdio"
)
// Common path configuration for all transport modes
type PathsConfig struct {
SSE string `yaml:"sse"`
Messages string `yaml:"messages"`
}
// StdioConfig contains stdio-specific configuration
type StdioConfig struct {
Enabled bool `yaml:"enabled"`
UserCommand string `yaml:"user_command"` // The command provided by the user
WorkDir string `yaml:"work_dir"` // Working directory (optional)
Args []string `yaml:"args,omitempty"` // Additional arguments
Env []string `yaml:"env,omitempty"` // Environment variables
}
type DemoConfig struct { type DemoConfig struct {
ClientID string `yaml:"client_id"` ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"` ClientSecret string `yaml:"client_secret"`
@ -61,14 +84,17 @@ type DefaultConfig struct {
type Config struct { type Config struct {
AuthServerBaseURL string AuthServerBaseURL string
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
ListenPort int `yaml:"listen_port"` ListenPort int `yaml:"listen_port"`
BaseURL string `yaml:"base_url"`
Port int `yaml:"port"`
JWKSURL string JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"` TimeoutSeconds int `yaml:"timeout_seconds"`
MCPPaths []string `yaml:"mcp_paths"`
PathMapping map[string]string `yaml:"path_mapping"` PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"` Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"` CORSConfig CORSConfig `yaml:"cors"`
TransportMode TransportMode `yaml:"transport_mode"`
Paths PathsConfig `yaml:"paths"`
Stdio StdioConfig `yaml:"stdio"`
// Nested config for Asgardeo // Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"` Demo DemoConfig `yaml:"demo"`
@ -76,6 +102,56 @@ type Config struct {
Default DefaultConfig `yaml:"default"` Default DefaultConfig `yaml:"default"`
} }
// Validate checks if the config is valid based on transport mode
func (c *Config) Validate() error {
// Validate based on transport mode
if c.TransportMode == StdioTransport {
if !c.Stdio.Enabled {
return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
}
if c.Stdio.UserCommand == "" {
return fmt.Errorf("stdio.user_command is required in stdio transport mode")
}
}
// Validate paths
if c.Paths.SSE == "" {
c.Paths.SSE = "/sse" // Default value
}
if c.Paths.Messages == "" {
c.Paths.Messages = "/messages" // Default value
}
// Validate base URL
if c.BaseURL == "" {
if c.Port > 0 {
c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
} else {
c.BaseURL = "http://localhost:8000" // Default value
}
}
return nil
}
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
func (c *Config) GetMCPPaths() []string {
return []string{c.Paths.SSE, c.Paths.Messages}
}
// BuildExecCommand constructs the full command string for execution in stdio mode
func (c *Config) BuildExecCommand() string {
if c.Stdio.UserCommand == "" {
return ""
}
// Construct the full command
return fmt.Sprintf(
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
)
}
// LoadConfig reads a YAML config file into Config struct. // LoadConfig reads a YAML config file into Config struct.
func LoadConfig(path string) (*Config, error) { func LoadConfig(path string) (*Config, error) {
f, err := os.Open(path) f, err := os.Open(path)
@ -89,8 +165,26 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil { if err := decoder.Decode(&cfg); err != nil {
return nil, err return nil, err
} }
// Set default values
if cfg.TimeoutSeconds == 0 { if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default cfg.TimeoutSeconds = 15 // default
} }
// Set default transport mode if not specified
if cfg.TransportMode == "" {
cfg.TransportMode = SSETransport // Default to SSE
}
// Set default port if not specified
if cfg.Port == 0 {
cfg.Port = 8000 // default
}
// Validate the configuration
if err := cfg.Validate(); err != nil {
return nil, err
}
return &cfg, nil return &cfg, nil
} }

View file

@ -0,0 +1,34 @@
package logger
import (
"log"
)
var isDebug = false
// SetDebug enables or disables debug logging
func SetDebug(debug bool) {
isDebug = debug
}
// Debug logs a debug-level message
func Debug(format string, v ...interface{}) {
if isDebug {
log.Printf("DEBUG: "+format, v...)
}
}
// Info logs an info-level message
func Info(format string, v ...interface{}) {
log.Printf("INFO: "+format, v...)
}
// Warn logs a warning-level message
func Warn(format string, v ...interface{}) {
log.Printf("WARN: "+format, v...)
}
// Error logs an error-level message
func Error(format string, v ...interface{}) {
log.Printf("ERROR: "+format, v...)
}

View file

@ -9,6 +9,7 @@ import (
"strings" "strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
// RequestModifier modifies requests before they are proxied // RequestModifier modifies requests before they are proxied
@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
if strings.Contains(contentType, "application/x-www-form-urlencoded") { if strings.Contains(contentType, "application/x-www-form-urlencoded") {
// Parse form data // Parse form data
if err := req.ParseForm(); err != nil { if err := req.ParseForm(); err != nil {
logger.Error("Failed to parse form data: %v", err)
return nil, err return nil, err
} }
@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Read body // Read body
bodyBytes, err := io.ReadAll(req.Body) bodyBytes, err := io.ReadAll(req.Body)
if err != nil { if err != nil {
logger.Error("Failed to read request body: %v", err)
return nil, err return nil, err
} }
// Parse JSON // Parse JSON
var jsonData map[string]interface{} var jsonData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
logger.Error("Failed to parse JSON body: %v", err)
return nil, err return nil, err
} }
@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Marshal back to JSON // Marshal back to JSON
modifiedBody, err := json.Marshal(jsonData) modifiedBody, err := json.Marshal(jsonData)
if err != nil { if err != nil {
logger.Error("Failed to marshal modified JSON: %v", err)
return nil, err return nil, err
} }

View file

@ -2,7 +2,6 @@ package proxy
import ( import (
"context" "context"
"log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
@ -11,6 +10,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util" "github.com/wso2/open-mcp-auth-proxy/internal/util"
) )
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
} }
// MCP paths // MCP paths
for _, path := range cfg.MCPPaths { mcpPaths := cfg.GetMCPPaths()
for _, path := range mcpPaths {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true registeredPaths[path] = true
} }
@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
// Parse the base URLs up front // Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL) authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil { if err != nil {
log.Fatalf("Invalid auth server URL: %v", err) logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup
} }
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil { if err != nil {
log.Fatalf("Invalid MCP server URL: %v", err) logger.Error("Invalid MCP server URL: %v", err)
panic(err) // Fatal error that prevents startup
} }
// Detect SSE paths from config // Detect SSE paths from config
ssePaths := make(map[string]bool) ssePaths := make(map[string]bool)
for _, p := range cfg.MCPPaths { ssePaths[cfg.Paths.SSE] = true
if p == "/sse" {
ssePaths[p] = true
}
}
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Handle OPTIONS // Handle OPTIONS
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
if allowedOrigin == "" { if allowedOrigin == "" {
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) logger.Warn("Preflight request from disallowed origin: %s", origin)
http.Error(w, "CORS origin not allowed", http.StatusForbidden) http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return return
} }
@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
} }
if allowedOrigin == "" { if allowedOrigin == "" {
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
http.Error(w, "CORS origin not allowed", http.StatusForbidden) http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return return
} }
@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Validate JWT for MCP paths if required // Validate JWT for MCP paths if required
// Placeholder for JWT validation logic // Placeholder for JWT validation logic
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
var err error var err error
r, err = modifier.ModifyRequest(r) r, err = modifier.ModifyRequest(r)
if err != nil { if err != nil {
log.Printf("[proxy] Error modifying request: %v", err) logger.Error("Error modifying request: %v", err)
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
} }
@ -193,6 +192,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
cleanHeaders := http.Header{} cleanHeaders := http.Header{}
// Set proper origin header to match the target
if isSSE {
// For SSE, ensure origin matches the target
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
}
for k, v := range r.Header { for k, v := range r.Header {
// Skip hop-by-hop headers // Skip hop-by-hop headers
if skipHeader(k) { if skipHeader(k) {
@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Header = cleanHeaders req.Header = cleanHeaders
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
}, },
ModifyResponse: func(resp *http.Response) error { ModifyResponse: func(resp *http.Response) error {
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
return nil return nil
}, },
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[proxy] Error proxying: %v", err) logger.Error("Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway) http.Error(rw, "Bad Gateway", http.StatusBadGateway)
}, },
FlushInterval: -1, // immediate flush for SSE FlushInterval: -1, // immediate flush for SSE
} }
if isSSE { if isSSE {
// Add special response handling for SSE connections to rewrite endpoint URLs
rp.Transport = &sseTransport{
Transport: http.DefaultTransport,
proxyHost: r.Host,
targetHost: targetURL.Host,
}
// Set SSE-specific headers
w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// Keep SSE connections open // Keep SSE connections open
HandleSSE(w, r, rp) HandleSSE(w, r, rp)
} else { } else {
@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
} }
for _, allowed := range cfg.CORSConfig.AllowedOrigins { for _, allowed := range cfg.CORSConfig.AllowedOrigins {
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
if allowed == origin { if allowed == origin {
return allowed return allowed
} }
@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Allow-Credentials", "true")
} }
w.Header().Set("Vary", "Origin") w.Header().Set("Vary", "Origin")
w.Header().Set("X-Accel-Buffering", "no")
} }
func isAuthPath(path string) bool { func isAuthPath(path string) bool {
@ -273,7 +292,8 @@ func isAuthPath(path string) bool {
// isMCPPath checks if the path is an MCP path // isMCPPath checks if the path is an MCP path
func isMCPPath(path string, cfg *config.Config) bool { func isMCPPath(path string, cfg *config.Config) bool {
for _, p := range cfg.MCPPaths { mcpPaths := cfg.GetMCPPaths()
for _, p := range mcpPaths {
if strings.HasPrefix(path, p) { if strings.HasPrefix(path, p) {
return true return true
} }

View file

@ -1,11 +1,16 @@
package proxy package proxy
import ( import (
"bufio"
"context" "context"
"log" "fmt"
"io"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"strings"
"time" "time"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
// HandleSSE sets up a go-routine to wait for context cancellation // HandleSSE sets up a go-routine to wait for context cancellation
@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
go func() { go func() {
<-ctx.Done() <-ctx.Done()
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
close(done) close(done)
}() }()
@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) { func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout) return context.WithTimeout(context.Background(), timeout)
} }
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
type sseTransport struct {
Transport http.RoundTripper
proxyHost string
targetHost string
}
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Call the underlying transport
resp, err := t.Transport.RoundTrip(req)
if err != nil {
return nil, err
}
// Check if this is an SSE response
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
return resp, nil
}
logger.Info("Intercepting SSE response to modify endpoint events")
// Create a response wrapper that modifies the response body
originalBody := resp.Body
pr, pw := io.Pipe()
go func() {
defer originalBody.Close()
defer pw.Close()
scanner := bufio.NewScanner(originalBody)
for scanner.Scan() {
line := scanner.Text()
// Check if this line contains an endpoint event
if strings.HasPrefix(line, "event: endpoint") {
// Read the data line
if scanner.Scan() {
dataLine := scanner.Text()
if strings.HasPrefix(dataLine, "data: ") {
// Extract the endpoint URL
endpoint := strings.TrimPrefix(dataLine, "data: ")
// Replace the host in the endpoint
logger.Debug("Original endpoint: %s", endpoint)
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
logger.Debug("Modified endpoint: %s", endpoint)
// Write the modified event lines
fmt.Fprintln(pw, line)
fmt.Fprintln(pw, "data: "+endpoint)
continue
}
}
}
// Write the original line for non-endpoint events
fmt.Fprintln(pw, line)
}
if err := scanner.Err(); err != nil {
logger.Error("Error reading SSE stream: %v", err)
}
}()
// Replace the response body with our modified pipe
resp.Body = pr
return resp, nil
}

View file

@ -0,0 +1,268 @@
package subprocess
import (
"fmt"
"os"
"os/exec"
"sync"
"syscall"
"time"
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// Manager handles starting and graceful shutdown of subprocesses
type Manager struct {
process *os.Process
processGroup int
mutex sync.Mutex
cmd *exec.Cmd
shutdownDelay time.Duration
}
// NewManager creates a new subprocess manager
func NewManager() *Manager {
return &Manager{
shutdownDelay: 5 * time.Second,
}
}
// EnsureDependenciesAvailable checks and installs required package executors
func EnsureDependenciesAvailable(command string) error {
// Always ensure npx is available regardless of the command
if _, err := exec.LookPath("npx"); err != nil {
// npx is not available, check if npm is installed
if _, err := exec.LookPath("npm"); err != nil {
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
}
// Try to install npx using npm
logger.Info("npx not found, attempting to install...")
cmd := exec.Command("npm", "install", "-g", "npx")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to install npx: %w", err)
}
logger.Info("npx installed successfully")
}
// Check if uv is needed based on the command
if strings.Contains(command, "uv ") {
if _, err := exec.LookPath("uv"); err != nil {
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
}
}
return nil
}
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
func (m *Manager) SetShutdownDelay(duration time.Duration) {
m.shutdownDelay = duration
}
// Start launches a subprocess based on the configuration
func (m *Manager) Start(cfg *config.Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// If a process is already running, return an error
if m.process != nil {
return os.ErrExist
}
if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
return nil // Nothing to start
}
// Get the full command string
execCommand := cfg.BuildExecCommand()
if execCommand == "" {
return nil // No command to execute
}
logger.Info("Starting subprocess with command: %s", execCommand)
// Use the shell to execute the command
cmd := exec.Command("sh", "-c", execCommand)
// Set working directory if specified
if cfg.Stdio.WorkDir != "" {
cmd.Dir = cfg.Stdio.WorkDir
}
// Set environment variables if specified
if len(cfg.Stdio.Env) > 0 {
cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
}
// Capture stdout/stderr
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// Set the process group for proper termination
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// Start the process
if err := cmd.Start(); err != nil {
return err
}
m.process = cmd.Process
m.cmd = cmd
logger.Info("Subprocess started with PID: %d", m.process.Pid)
// Get and store the process group ID
pgid, err := syscall.Getpgid(m.process.Pid)
if err == nil {
m.processGroup = pgid
logger.Debug("Process group ID: %d", m.processGroup)
} else {
logger.Warn("Failed to get process group ID: %v", err)
m.processGroup = m.process.Pid
}
// Handle process termination in background
go func() {
if err := cmd.Wait(); err != nil {
logger.Error("Subprocess exited with error: %v", err)
} else {
logger.Info("Subprocess exited successfully")
}
// Clear the process reference when it exits
m.mutex.Lock()
m.process = nil
m.cmd = nil
m.mutex.Unlock()
}()
return nil
}
// IsRunning checks if the subprocess is running
func (m *Manager) IsRunning() bool {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.process != nil
}
// Shutdown gracefully terminates the subprocess
func (m *Manager) Shutdown() {
m.mutex.Lock()
processToTerminate := m.process // Local copy of the process reference
processGroupToTerminate := m.processGroup
m.mutex.Unlock()
if processToTerminate == nil {
return // No process to terminate
}
logger.Info("Terminating subprocess...")
terminateComplete := make(chan struct{})
go func() {
defer close(terminateComplete)
// Try graceful termination first with SIGTERM
terminatedGracefully := false
// Try to terminate the process group first
if processGroupToTerminate != 0 {
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process group: %v", err)
// Fallback to terminating just the process
m.mutex.Lock()
if m.process != nil {
err = m.process.Signal(syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process: %v", err)
}
}
m.mutex.Unlock()
}
} else {
// Try to terminate just the process
m.mutex.Lock()
if m.process != nil {
err := m.process.Signal(syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process: %v", err)
}
}
m.mutex.Unlock()
}
// Wait for the process to exit gracefully
for i := 0; i < 10; i++ {
time.Sleep(200 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
terminatedGracefully = true
m.mutex.Unlock()
break
}
m.mutex.Unlock()
}
if terminatedGracefully {
logger.Info("Subprocess terminated gracefully")
return
}
// If the process didn't exit gracefully, force kill
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
// Try to kill the process group first
if processGroupToTerminate != 0 {
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
logger.Warn("Failed to send SIGKILL to process group: %v", err)
// Fallback to killing just the process
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
}
} else {
// Try to kill just the process
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
}
// Wait a bit more to confirm termination
time.Sleep(500 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
logger.Info("Subprocess terminated by force")
} else {
logger.Warn("Failed to terminate subprocess")
}
m.mutex.Unlock()
}()
// Wait for termination with timeout
select {
case <-terminateComplete:
// Termination completed
case <-time.After(m.shutdownDelay):
logger.Warn("Subprocess termination timed out")
}
}

View file

@ -4,12 +4,12 @@ import (
"crypto/rsa" "crypto/rsa"
"encoding/json" "encoding/json"
"errors" "errors"
"log"
"math/big" "math/big"
"net/http" "net/http"
"strings" "strings"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type JWKS struct { type JWKS struct {
@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
publicKeys[parsedKey.Kid] = pubKey publicKeys[parsedKey.Kid] = pubKey
} }
} }
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys)) logger.Info("Loaded %d public keys.", len(publicKeys))
return nil return nil
} }