mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Add transport mode support for stdio, SSE stability fixes (#13)
Add transport mode support for stdio, SSE stability fixes
This commit is contained in:
parent
6ce52261db
commit
32c9378aad
12 changed files with 808 additions and 142 deletions
247
README.md
247
README.md
|
@ -1,101 +1,81 @@
|
||||||
# Open MCP Auth Proxy
|
# Open MCP Auth Proxy
|
||||||
|
|
||||||
The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider.
|
A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/).
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## **Setup and Installation**
|
## What it Does
|
||||||
|
|
||||||
### **Prerequisites**
|
Open MCP Auth Proxy sits between MCP clients and your MCP server to:
|
||||||
|
|
||||||
|
- Intercept incoming requests
|
||||||
|
- Validate authorization tokens
|
||||||
|
- Offload authentication and authorization to OAuth-compliant Identity Providers
|
||||||
|
- Support the MCP authorization protocol
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
* Go 1.20 or higher
|
* Go 1.20 or higher
|
||||||
* A running MCP server (SSE transport supported)
|
* A running MCP server
|
||||||
* An MCP client that supports MCP authorization
|
* An MCP client that supports MCP authorization
|
||||||
|
|
||||||
### **Installation**
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||||
cd open-mcp-auth-proxy
|
cd open-mcp-auth-proxy
|
||||||
|
go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
|
||||||
go get github.com/golang-jwt/jwt/v4
|
|
||||||
go get gopkg.in/yaml.v2
|
|
||||||
|
|
||||||
go build -o openmcpauthproxy ./cmd/proxy
|
go build -o openmcpauthproxy ./cmd/proxy
|
||||||
```
|
```
|
||||||
|
|
||||||
## Using Open MCP Auth Proxy
|
### Basic Usage
|
||||||
|
|
||||||
### Quick Start
|
1. The repository comes with a default `config.yaml` file that contains the basic configuration:
|
||||||
|
|
||||||
Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo.
|
|
||||||
|
|
||||||
If you don’t have an MCP server, follow the instructions given here to start your own MCP server for testing purposes.
|
|
||||||
|
|
||||||
1. Navigate to `resources` directory.
|
|
||||||
2. Initialize a virtual environment.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -m venv .venv
|
|
||||||
```
|
|
||||||
3. Activate virtual environment.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
source .venv/bin/activate
|
|
||||||
```
|
|
||||||
|
|
||||||
4. Install dependencies.
|
|
||||||
|
|
||||||
```
|
|
||||||
pip3 install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
5. Start the server.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 echo_server.py
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Configure the Auth Proxy
|
|
||||||
|
|
||||||
Update the following parameters in `config.yaml`.
|
|
||||||
|
|
||||||
### demo mode configuration:
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
listen_port: 8080
|
||||||
listen_port: 8080 # Address where the proxy will listen
|
base_url: "http://localhost:8000" # Your MCP server URL
|
||||||
|
paths:
|
||||||
|
sse: "/sse"
|
||||||
|
messages: "/messages/"
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Start the Auth Proxy
|
2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy --demo
|
./openmcpauthproxy --demo
|
||||||
```
|
```
|
||||||
|
|
||||||
The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/).
|
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
||||||
|
|
||||||
#### Connect Using an MCP Client
|
## Identity Provider Integration
|
||||||
|
|
||||||
You can use this fork of the [MCP Inspector](https://github.com/shashimalcse/inspector) to test the connection and try out the complete authorization flow. (This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
### Demo Mode
|
||||||
|
|
||||||
### Use with Asgardeo
|
For quick testing, use the `--demo` flag which includes pre-configured authentication and authorization with an Asgardeo sandbox.
|
||||||
|
|
||||||
Enable authorization for the MCP server through your own Asgardeo organization
|
```bash
|
||||||
|
./openmcpauthproxy --demo
|
||||||
|
```
|
||||||
|
|
||||||
1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo
|
### Asgardeo Integration
|
||||||
2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that,
|
|
||||||
1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
|
To enable authorization through your own Asgardeo organization:
|
||||||
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope.
|
|
||||||

|
1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo
|
||||||
2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy
|
2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
|
||||||
|
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
|
||||||
|

|
||||||
|
2. Update the existing `config.yaml` with your Asgardeo details:
|
||||||
|
|
||||||
#### Configure the Auth Proxy
|
#### Configure the Auth Proxy
|
||||||
|
|
||||||
Create a configuration file config.yaml with the following parameters:
|
Create a configuration file config.yaml with the following parameters:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
base_url: "http://localhost:8000" # URL of your MCP server
|
||||||
listen_port: 8080 # Address where the proxy will listen
|
listen_port: 8080 # Address where the proxy will listen
|
||||||
|
|
||||||
asgardeo:
|
asgardeo:
|
||||||
|
@ -104,31 +84,146 @@ asgardeo:
|
||||||
client_secret: "<client_secret>" # Client secret of the M2M app
|
client_secret: "<client_secret>" # Client secret of the M2M app
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Start the Auth Proxy
|
3. Start the proxy with Asgardeo integration:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy --asgardeo
|
./openmcpauthproxy --asgardeo
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use with any standard OAuth Server
|
### Other OAuth Providers
|
||||||
|
|
||||||
Enable authorization for the MCP server with a compliant OAuth server
|
- [Auth0 Integration Guide](docs/Auth0.md)
|
||||||
|
|
||||||
#### Configuration
|
## Testing with an Example MCP Server
|
||||||
|
|
||||||
Create a configuration file config.yaml with the following parameters:
|
If you don't have an MCP server, you can use the included example:
|
||||||
|
|
||||||
```yaml
|
1. Navigate to the `resources` directory
|
||||||
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
2. Set up a Python environment:
|
||||||
listen_port: 8080 # Address where the proxy will listen
|
|
||||||
```
|
|
||||||
**TODO**: Update the configs for a standard OAuth Server.
|
|
||||||
|
|
||||||
#### Start the Auth Proxy
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip3 install -r requirements.txt
|
||||||
```
|
```
|
||||||
#### Integrating with existing OAuth Providers
|
|
||||||
|
|
||||||
- [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization.
|
3. Start the example server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 echo_server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
# Advanced Configuration
|
||||||
|
|
||||||
|
### Transport Modes
|
||||||
|
|
||||||
|
The proxy supports two transport modes:
|
||||||
|
|
||||||
|
- **SSE Mode (Default)**: For Server-Sent Events transport
|
||||||
|
- **stdio Mode**: For MCP servers that use stdio transport
|
||||||
|
|
||||||
|
When using stdio mode, the proxy:
|
||||||
|
- Starts an MCP server as a subprocess using the command specified in the configuration
|
||||||
|
- Communicates with the subprocess through standard input/output (stdio)
|
||||||
|
- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
|
||||||
|
|
||||||
|
To use stdio mode:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./openmcpauthproxy --demo --stdio
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Example: Running an MCP Server as a Subprocess
|
||||||
|
|
||||||
|
1. Configure stdio mode in your `config.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
listen_port: 8080
|
||||||
|
base_url: "http://localhost:8000"
|
||||||
|
|
||||||
|
stdio:
|
||||||
|
enabled: true
|
||||||
|
user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server
|
||||||
|
env: # Environment variables (optional)
|
||||||
|
- "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT"
|
||||||
|
|
||||||
|
# CORS configuration
|
||||||
|
cors:
|
||||||
|
allowed_origins:
|
||||||
|
- "http://localhost:5173" # Origin of your client application
|
||||||
|
allowed_methods:
|
||||||
|
- "GET"
|
||||||
|
- "POST"
|
||||||
|
- "PUT"
|
||||||
|
- "DELETE"
|
||||||
|
allowed_headers:
|
||||||
|
- "Authorization"
|
||||||
|
- "Content-Type"
|
||||||
|
allow_credentials: true
|
||||||
|
|
||||||
|
# Demo configuration for Asgardeo
|
||||||
|
demo:
|
||||||
|
org_name: "openmcpauthdemo"
|
||||||
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run the proxy with stdio mode:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./openmcpauthproxy --demo
|
||||||
|
```
|
||||||
|
|
||||||
|
The proxy will:
|
||||||
|
- Start the MCP server as a subprocess using the specified command
|
||||||
|
- Handle all authorization requirements
|
||||||
|
- Forward messages between clients and the server
|
||||||
|
|
||||||
|
### Complete Configuration Reference
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Common configuration
|
||||||
|
listen_port: 8080
|
||||||
|
base_url: "http://localhost:8000"
|
||||||
|
port: 8000
|
||||||
|
|
||||||
|
# Path configuration
|
||||||
|
paths:
|
||||||
|
sse: "/sse"
|
||||||
|
messages: "/messages/"
|
||||||
|
|
||||||
|
# Transport mode
|
||||||
|
transport_mode: "sse" # Options: "sse" or "stdio"
|
||||||
|
|
||||||
|
# stdio-specific configuration (used only in stdio mode)
|
||||||
|
stdio:
|
||||||
|
enabled: true
|
||||||
|
user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed)
|
||||||
|
work_dir: "" # Optional working directory for the subprocess
|
||||||
|
|
||||||
|
# CORS configuration
|
||||||
|
cors:
|
||||||
|
allowed_origins:
|
||||||
|
- "http://localhost:5173"
|
||||||
|
allowed_methods:
|
||||||
|
- "GET"
|
||||||
|
- "POST"
|
||||||
|
- "PUT"
|
||||||
|
- "DELETE"
|
||||||
|
allowed_headers:
|
||||||
|
- "Authorization"
|
||||||
|
- "Content-Type"
|
||||||
|
allow_credentials: true
|
||||||
|
|
||||||
|
# Demo configuration for Asgardeo
|
||||||
|
demo:
|
||||||
|
org_name: "openmcpauthdemo"
|
||||||
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
|
# Asgardeo configuration (used with --asgardeo flag)
|
||||||
|
asgardeo:
|
||||||
|
org_name: "<org_name>"
|
||||||
|
client_id: "<client_id>"
|
||||||
|
client_secret: "<client_secret>"
|
||||||
|
```
|
|
@ -3,31 +3,71 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
||||||
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
|
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
|
||||||
|
debugMode := flag.Bool("debug", false, "Enable debug logging")
|
||||||
|
stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
logger.SetDebug(*debugMode)
|
||||||
|
|
||||||
// 1. Load config
|
// 1. Load config
|
||||||
cfg, err := config.LoadConfig("config.yaml")
|
cfg, err := config.LoadConfig("config.yaml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error loading config: %v", err)
|
logger.Error("Error loading config: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Create the chosen provider
|
// Override transport mode if stdio flag is set
|
||||||
|
if *stdioMode {
|
||||||
|
cfg.TransportMode = config.StdioTransport
|
||||||
|
// Ensure stdio is enabled
|
||||||
|
cfg.Stdio.Enabled = true
|
||||||
|
// Re-validate config
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
logger.Error("Configuration error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Using transport mode: %s", cfg.TransportMode)
|
||||||
|
logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
|
||||||
|
logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
|
||||||
|
|
||||||
|
// 2. Start subprocess if configured and in stdio mode
|
||||||
|
var procManager *subprocess.Manager
|
||||||
|
if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
|
||||||
|
// Ensure all required dependencies are available
|
||||||
|
if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
|
||||||
|
logger.Warn("%v", err)
|
||||||
|
logger.Warn("Subprocess may fail to start due to missing dependencies")
|
||||||
|
}
|
||||||
|
|
||||||
|
procManager = subprocess.NewManager()
|
||||||
|
if err := procManager.Start(cfg); err != nil {
|
||||||
|
logger.Warn("Failed to start subprocess: %v", err)
|
||||||
|
}
|
||||||
|
} else if cfg.TransportMode == config.SSETransport {
|
||||||
|
logger.Info("Using SSE transport mode, not starting subprocess")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Create the chosen provider
|
||||||
var provider authz.Provider
|
var provider authz.Provider
|
||||||
if *demoMode {
|
if *demoMode {
|
||||||
cfg.Mode = "demo"
|
cfg.Mode = "demo"
|
||||||
|
@ -46,41 +86,49 @@ func main() {
|
||||||
provider = authz.NewDefaultProvider(cfg)
|
provider = authz.NewDefaultProvider(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
||||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||||
log.Fatalf("Failed to fetch JWKS: %v", err)
|
logger.Error("Failed to fetch JWKS: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Build the main router
|
// 5. Build the main router
|
||||||
mux := proxy.NewRouter(cfg, provider)
|
mux := proxy.NewRouter(cfg, provider)
|
||||||
|
|
||||||
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||||
|
|
||||||
// 5. Start the server
|
// 6. Start the server
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
|
||||||
Addr: listen_address,
|
Addr: listen_address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("Server listening on %s", listen_address)
|
logger.Info("Server listening on %s", listen_address)
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatalf("Server error: %v", err)
|
logger.Error("Server error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 6. Graceful shutdown on Ctrl+C
|
// 7. Wait for shutdown signal
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, os.Interrupt)
|
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||||
<-stop
|
<-stop
|
||||||
log.Println("Shutting down...")
|
logger.Info("Shutting down...")
|
||||||
|
|
||||||
|
// 8. First terminate subprocess if running
|
||||||
|
if procManager != nil && procManager.IsRunning() {
|
||||||
|
procManager.Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Then shutdown the server
|
||||||
|
logger.Info("Shutting down HTTP server...")
|
||||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
log.Printf("Shutdown error: %v", err)
|
logger.Error("HTTP server shutdown error: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("Stopped.")
|
logger.Info("Stopped.")
|
||||||
}
|
}
|
||||||
|
|
28
config.yaml
28
config.yaml
|
@ -1,15 +1,31 @@
|
||||||
# config.yaml
|
# config.yaml
|
||||||
|
|
||||||
mcp_server_base_url: ""
|
# Common configuration for all transport modes
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
|
base_url: "http://localhost:8000" # Base URL for the MCP server
|
||||||
|
port: 8000 # Port for the MCP server
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
mcp_paths:
|
# Path configuration
|
||||||
- /messages/
|
paths:
|
||||||
- /sse
|
sse: "/sse" # SSE endpoint path
|
||||||
|
messages: "/messages/" # Messages endpoint path
|
||||||
|
|
||||||
|
# Transport mode configuration
|
||||||
|
transport_mode: "sse" # Options: "sse" or "stdio"
|
||||||
|
|
||||||
|
# stdio-specific configuration (used only when transport_mode is "stdio")
|
||||||
|
stdio:
|
||||||
|
enabled: true
|
||||||
|
user_command: "npx -y @modelcontextprotocol/server-github"
|
||||||
|
work_dir: "" # Working directory (optional)
|
||||||
|
# env: # Environment variables (optional)
|
||||||
|
# - "NODE_ENV=development"
|
||||||
|
|
||||||
|
# Path mapping (optional)
|
||||||
path_mapping:
|
path_mapping:
|
||||||
|
|
||||||
|
# CORS configuration
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
- "http://localhost:5173"
|
- "http://localhost:5173"
|
||||||
|
@ -23,10 +39,8 @@ cors:
|
||||||
- "Content-Type"
|
- "Content-Type"
|
||||||
allow_credentials: true
|
allow_credentials: true
|
||||||
|
|
||||||
|
# Demo configuration for Asgardeo
|
||||||
demo:
|
demo:
|
||||||
org_name: "openmcpauthdemo"
|
org_name: "openmcpauthdemo"
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,13 +7,13 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type asgardeoProvider struct {
|
type asgardeoProvider struct {
|
||||||
|
@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
|
logger.Error("Error encoding well-known: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
|
|
||||||
var regReq RegisterRequest
|
var regReq RegisterRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
||||||
log.Printf("ERROR: reading register request: %v", err)
|
logger.Error("Reading register request: %v", err)
|
||||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
regReq.ClientSecret = randomString(16)
|
regReq.ClientSecret = randomString(16)
|
||||||
|
|
||||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||||
log.Printf("WARN: Asgardeo application creation failed: %v", err)
|
logger.Warn("Asgardeo application creation failed: %v", err)
|
||||||
// Optionally http.Error(...) if you want to fail
|
// Optionally http.Error(...) if you want to fail
|
||||||
// or continue to return partial data.
|
// or continue to return partial data.
|
||||||
}
|
}
|
||||||
|
@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
log.Printf("ERROR: encoding /register response: %v", err)
|
logger.Error("Encoding /register response: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
|
||||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
|
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,9 +206,12 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
// Sensitive data - should not be logged at INFO level
|
||||||
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||||
|
|
||||||
|
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't log the actual token at info level, only at debug level
|
||||||
|
logger.Debug("Received access token: %s", tokenResp.AccessToken)
|
||||||
|
logger.Info("Successfully obtained admin token from Asgardeo")
|
||||||
|
|
||||||
return tokenResp.AccessToken, nil
|
return tokenResp.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type defaultProvider struct {
|
type defaultProvider struct {
|
||||||
|
@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
logger.Error("Error encoding well-known response: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,12 +1,35 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AsgardeoConfig groups all Asgardeo-specific fields
|
// Transport mode for MCP server
|
||||||
|
type TransportMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SSETransport TransportMode = "sse"
|
||||||
|
StdioTransport TransportMode = "stdio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common path configuration for all transport modes
|
||||||
|
type PathsConfig struct {
|
||||||
|
SSE string `yaml:"sse"`
|
||||||
|
Messages string `yaml:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StdioConfig contains stdio-specific configuration
|
||||||
|
type StdioConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
UserCommand string `yaml:"user_command"` // The command provided by the user
|
||||||
|
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
||||||
|
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||||
|
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||||
|
}
|
||||||
|
|
||||||
type DemoConfig struct {
|
type DemoConfig struct {
|
||||||
ClientID string `yaml:"client_id"`
|
ClientID string `yaml:"client_id"`
|
||||||
ClientSecret string `yaml:"client_secret"`
|
ClientSecret string `yaml:"client_secret"`
|
||||||
|
@ -60,15 +83,18 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
ListenPort int `yaml:"listen_port"`
|
BaseURL string `yaml:"base_url"`
|
||||||
JWKSURL string
|
Port int `yaml:"port"`
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
JWKSURL string
|
||||||
MCPPaths []string `yaml:"mcp_paths"`
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
Mode string `yaml:"mode"`
|
Mode string `yaml:"mode"`
|
||||||
CORSConfig CORSConfig `yaml:"cors"`
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
|
TransportMode TransportMode `yaml:"transport_mode"`
|
||||||
|
Paths PathsConfig `yaml:"paths"`
|
||||||
|
Stdio StdioConfig `yaml:"stdio"`
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
|
@ -76,6 +102,56 @@ type Config struct {
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate checks if the config is valid based on transport mode
|
||||||
|
func (c *Config) Validate() error {
|
||||||
|
// Validate based on transport mode
|
||||||
|
if c.TransportMode == StdioTransport {
|
||||||
|
if !c.Stdio.Enabled {
|
||||||
|
return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
|
||||||
|
}
|
||||||
|
if c.Stdio.UserCommand == "" {
|
||||||
|
return fmt.Errorf("stdio.user_command is required in stdio transport mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate paths
|
||||||
|
if c.Paths.SSE == "" {
|
||||||
|
c.Paths.SSE = "/sse" // Default value
|
||||||
|
}
|
||||||
|
if c.Paths.Messages == "" {
|
||||||
|
c.Paths.Messages = "/messages" // Default value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate base URL
|
||||||
|
if c.BaseURL == "" {
|
||||||
|
if c.Port > 0 {
|
||||||
|
c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||||
|
} else {
|
||||||
|
c.BaseURL = "http://localhost:8000" // Default value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
||||||
|
func (c *Config) GetMCPPaths() []string {
|
||||||
|
return []string{c.Paths.SSE, c.Paths.Messages}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildExecCommand constructs the full command string for execution in stdio mode
|
||||||
|
func (c *Config) BuildExecCommand() string {
|
||||||
|
if c.Stdio.UserCommand == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the full command
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||||
|
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// LoadConfig reads a YAML config file into Config struct.
|
// LoadConfig reads a YAML config file into Config struct.
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (*Config, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
|
@ -89,8 +165,26 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if err := decoder.Decode(&cfg); err != nil {
|
if err := decoder.Decode(&cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set default values
|
||||||
if cfg.TimeoutSeconds == 0 {
|
if cfg.TimeoutSeconds == 0 {
|
||||||
cfg.TimeoutSeconds = 15 // default
|
cfg.TimeoutSeconds = 15 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set default transport mode if not specified
|
||||||
|
if cfg.TransportMode == "" {
|
||||||
|
cfg.TransportMode = SSETransport // Default to SSE
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default port if not specified
|
||||||
|
if cfg.Port == 0 {
|
||||||
|
cfg.Port = 8000 // default
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the configuration
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
34
internal/logging/logger.go
Normal file
34
internal/logging/logger.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
var isDebug = false
|
||||||
|
|
||||||
|
// SetDebug enables or disables debug logging
|
||||||
|
func SetDebug(debug bool) {
|
||||||
|
isDebug = debug
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug logs a debug-level message
|
||||||
|
func Debug(format string, v ...interface{}) {
|
||||||
|
if isDebug {
|
||||||
|
log.Printf("DEBUG: "+format, v...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info logs an info-level message
|
||||||
|
func Info(format string, v ...interface{}) {
|
||||||
|
log.Printf("INFO: "+format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn logs a warning-level message
|
||||||
|
func Warn(format string, v ...interface{}) {
|
||||||
|
log.Printf("WARN: "+format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs an error-level message
|
||||||
|
func Error(format string, v ...interface{}) {
|
||||||
|
log.Printf("ERROR: "+format, v...)
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestModifier modifies requests before they are proxied
|
// RequestModifier modifies requests before they are proxied
|
||||||
|
@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
||||||
// Parse form data
|
// Parse form data
|
||||||
if err := req.ParseForm(); err != nil {
|
if err := req.ParseForm(); err != nil {
|
||||||
|
logger.Error("Failed to parse form data: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
// Read body
|
// Read body
|
||||||
bodyBytes, err := io.ReadAll(req.Body)
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to read request body: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var jsonData map[string]interface{}
|
var jsonData map[string]interface{}
|
||||||
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
||||||
|
logger.Error("Failed to parse JSON body: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
// Marshal back to JSON
|
// Marshal back to JSON
|
||||||
modifiedBody, err := json.Marshal(jsonData)
|
modifiedBody, err := json.Marshal(jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Error("Failed to marshal modified JSON: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -11,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MCP paths
|
// MCP paths
|
||||||
for _, path := range cfg.MCPPaths {
|
mcpPaths := cfg.GetMCPPaths()
|
||||||
|
for _, path := range mcpPaths {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
|
@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
|
|
||||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
||||||
// Parse the base URLs up front
|
// Parse the base URLs up front
|
||||||
|
|
||||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Invalid auth server URL: %v", err)
|
logger.Error("Invalid auth server URL: %v", err)
|
||||||
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
|
||||||
|
mcpBase, err := url.Parse(cfg.BaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
logger.Error("Invalid MCP server URL: %v", err)
|
||||||
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect SSE paths from config
|
// Detect SSE paths from config
|
||||||
ssePaths := make(map[string]bool)
|
ssePaths := make(map[string]bool)
|
||||||
for _, p := range cfg.MCPPaths {
|
ssePaths[cfg.Paths.SSE] = true
|
||||||
if p == "/sse" {
|
|
||||||
ssePaths[p] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Handle OPTIONS
|
// Handle OPTIONS
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
if allowedOrigin == "" {
|
if allowedOrigin == "" {
|
||||||
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
|
logger.Warn("Preflight request from disallowed origin: %s", origin)
|
||||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowedOrigin == "" {
|
if allowedOrigin == "" {
|
||||||
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Validate JWT for MCP paths if required
|
// Validate JWT for MCP paths if required
|
||||||
// Placeholder for JWT validation logic
|
// Placeholder for JWT validation logic
|
||||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
var err error
|
var err error
|
||||||
r, err = modifier.ModifyRequest(r)
|
r, err = modifier.ModifyRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[proxy] Error modifying request: %v", err)
|
logger.Error("Error modifying request: %v", err)
|
||||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -193,6 +192,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
|
|
||||||
cleanHeaders := http.Header{}
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
|
// Set proper origin header to match the target
|
||||||
|
if isSSE {
|
||||||
|
// For SSE, ensure origin matches the target
|
||||||
|
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
||||||
|
}
|
||||||
|
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if skipHeader(k) {
|
if skipHeader(k) {
|
||||||
|
@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
|
|
||||||
req.Header = cleanHeaders
|
req.Header = cleanHeaders
|
||||||
|
|
||||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||||
},
|
},
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
log.Printf("[proxy] Error proxying: %v", err)
|
logger.Error("Error proxying: %v", err)
|
||||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||||
},
|
},
|
||||||
FlushInterval: -1, // immediate flush for SSE
|
FlushInterval: -1, // immediate flush for SSE
|
||||||
}
|
}
|
||||||
|
|
||||||
if isSSE {
|
if isSSE {
|
||||||
|
// Add special response handling for SSE connections to rewrite endpoint URLs
|
||||||
|
rp.Transport = &sseTransport{
|
||||||
|
Transport: http.DefaultTransport,
|
||||||
|
proxyHost: r.Host,
|
||||||
|
targetHost: targetURL.Host,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set SSE-specific headers
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
|
||||||
// Keep SSE connections open
|
// Keep SSE connections open
|
||||||
HandleSSE(w, r, rp)
|
HandleSSE(w, r, rp)
|
||||||
} else {
|
} else {
|
||||||
|
@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
}
|
}
|
||||||
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||||
|
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
|
||||||
if allowed == origin {
|
if allowed == origin {
|
||||||
return allowed
|
return allowed
|
||||||
}
|
}
|
||||||
|
@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
}
|
}
|
||||||
w.Header().Set("Vary", "Origin")
|
w.Header().Set("Vary", "Origin")
|
||||||
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAuthPath(path string) bool {
|
func isAuthPath(path string) bool {
|
||||||
|
@ -273,7 +292,8 @@ func isAuthPath(path string) bool {
|
||||||
|
|
||||||
// isMCPPath checks if the path is an MCP path
|
// isMCPPath checks if the path is an MCP path
|
||||||
func isMCPPath(path string, cfg *config.Config) bool {
|
func isMCPPath(path string, cfg *config.Config) bool {
|
||||||
for _, p := range cfg.MCPPaths {
|
mcpPaths := cfg.GetMCPPaths()
|
||||||
|
for _, p := range mcpPaths {
|
||||||
if strings.HasPrefix(path, p) {
|
if strings.HasPrefix(path, p) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,16 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleSSE sets up a go-routine to wait for context cancellation
|
// HandleSSE sets up a go-routine to wait for context cancellation
|
||||||
|
@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
|
||||||
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||||
return context.WithTimeout(context.Background(), timeout)
|
return context.WithTimeout(context.Background(), timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
|
||||||
|
type sseTransport struct {
|
||||||
|
Transport http.RoundTripper
|
||||||
|
proxyHost string
|
||||||
|
targetHost string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// Call the underlying transport
|
||||||
|
resp, err := t.Transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is an SSE response
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if !strings.Contains(contentType, "text/event-stream") {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Intercepting SSE response to modify endpoint events")
|
||||||
|
|
||||||
|
// Create a response wrapper that modifies the response body
|
||||||
|
originalBody := resp.Body
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer originalBody.Close()
|
||||||
|
defer pw.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(originalBody)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
|
||||||
|
// Check if this line contains an endpoint event
|
||||||
|
if strings.HasPrefix(line, "event: endpoint") {
|
||||||
|
// Read the data line
|
||||||
|
if scanner.Scan() {
|
||||||
|
dataLine := scanner.Text()
|
||||||
|
if strings.HasPrefix(dataLine, "data: ") {
|
||||||
|
// Extract the endpoint URL
|
||||||
|
endpoint := strings.TrimPrefix(dataLine, "data: ")
|
||||||
|
|
||||||
|
// Replace the host in the endpoint
|
||||||
|
logger.Debug("Original endpoint: %s", endpoint)
|
||||||
|
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
|
||||||
|
logger.Debug("Modified endpoint: %s", endpoint)
|
||||||
|
|
||||||
|
// Write the modified event lines
|
||||||
|
fmt.Fprintln(pw, line)
|
||||||
|
fmt.Fprintln(pw, "data: "+endpoint)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the original line for non-endpoint events
|
||||||
|
fmt.Fprintln(pw, line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
logger.Error("Error reading SSE stream: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Replace the response body with our modified pipe
|
||||||
|
resp.Body = pr
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
268
internal/subprocess/manager.go
Normal file
268
internal/subprocess/manager.go
Normal file
|
@ -0,0 +1,268 @@
|
||||||
|
package subprocess
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager handles starting and graceful shutdown of subprocesses
|
||||||
|
type Manager struct {
|
||||||
|
process *os.Process
|
||||||
|
processGroup int
|
||||||
|
mutex sync.Mutex
|
||||||
|
cmd *exec.Cmd
|
||||||
|
shutdownDelay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new subprocess manager
|
||||||
|
func NewManager() *Manager {
|
||||||
|
return &Manager{
|
||||||
|
shutdownDelay: 5 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureDependenciesAvailable checks and installs required package executors
|
||||||
|
func EnsureDependenciesAvailable(command string) error {
|
||||||
|
// Always ensure npx is available regardless of the command
|
||||||
|
if _, err := exec.LookPath("npx"); err != nil {
|
||||||
|
// npx is not available, check if npm is installed
|
||||||
|
if _, err := exec.LookPath("npm"); err != nil {
|
||||||
|
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to install npx using npm
|
||||||
|
logger.Info("npx not found, attempting to install...")
|
||||||
|
cmd := exec.Command("npm", "install", "-g", "npx")
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to install npx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("npx installed successfully")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if uv is needed based on the command
|
||||||
|
if strings.Contains(command, "uv ") {
|
||||||
|
if _, err := exec.LookPath("uv"); err != nil {
|
||||||
|
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
|
||||||
|
func (m *Manager) SetShutdownDelay(duration time.Duration) {
|
||||||
|
m.shutdownDelay = duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start launches a subprocess based on the configuration
|
||||||
|
func (m *Manager) Start(cfg *config.Config) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
// If a process is already running, return an error
|
||||||
|
if m.process != nil {
|
||||||
|
return os.ErrExist
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
|
||||||
|
return nil // Nothing to start
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the full command string
|
||||||
|
execCommand := cfg.BuildExecCommand()
|
||||||
|
if execCommand == "" {
|
||||||
|
return nil // No command to execute
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Starting subprocess with command: %s", execCommand)
|
||||||
|
|
||||||
|
// Use the shell to execute the command
|
||||||
|
cmd := exec.Command("sh", "-c", execCommand)
|
||||||
|
|
||||||
|
// Set working directory if specified
|
||||||
|
if cfg.Stdio.WorkDir != "" {
|
||||||
|
cmd.Dir = cfg.Stdio.WorkDir
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set environment variables if specified
|
||||||
|
if len(cfg.Stdio.Env) > 0 {
|
||||||
|
cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture stdout/stderr
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
// Set the process group for proper termination
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||||
|
|
||||||
|
// Start the process
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.process = cmd.Process
|
||||||
|
m.cmd = cmd
|
||||||
|
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
||||||
|
|
||||||
|
// Get and store the process group ID
|
||||||
|
pgid, err := syscall.Getpgid(m.process.Pid)
|
||||||
|
if err == nil {
|
||||||
|
m.processGroup = pgid
|
||||||
|
logger.Debug("Process group ID: %d", m.processGroup)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to get process group ID: %v", err)
|
||||||
|
m.processGroup = m.process.Pid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle process termination in background
|
||||||
|
go func() {
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
logger.Error("Subprocess exited with error: %v", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Subprocess exited successfully")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the process reference when it exits
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.process = nil
|
||||||
|
m.cmd = nil
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning checks if the subprocess is running
|
||||||
|
func (m *Manager) IsRunning() bool {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
return m.process != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully terminates the subprocess
|
||||||
|
func (m *Manager) Shutdown() {
|
||||||
|
m.mutex.Lock()
|
||||||
|
processToTerminate := m.process // Local copy of the process reference
|
||||||
|
processGroupToTerminate := m.processGroup
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
if processToTerminate == nil {
|
||||||
|
return // No process to terminate
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Terminating subprocess...")
|
||||||
|
terminateComplete := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(terminateComplete)
|
||||||
|
|
||||||
|
// Try graceful termination first with SIGTERM
|
||||||
|
terminatedGracefully := false
|
||||||
|
|
||||||
|
// Try to terminate the process group first
|
||||||
|
if processGroupToTerminate != 0 {
|
||||||
|
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
||||||
|
|
||||||
|
// Fallback to terminating just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
err = m.process.Signal(syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to terminate just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
err := m.process.Signal(syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the process to exit gracefully
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process == nil {
|
||||||
|
terminatedGracefully = true
|
||||||
|
m.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if terminatedGracefully {
|
||||||
|
logger.Info("Subprocess terminated gracefully")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the process didn't exit gracefully, force kill
|
||||||
|
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
||||||
|
|
||||||
|
// Try to kill the process group first
|
||||||
|
if processGroupToTerminate != 0 {
|
||||||
|
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||||
|
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
||||||
|
|
||||||
|
// Fallback to killing just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
if err := m.process.Kill(); err != nil {
|
||||||
|
logger.Error("Failed to kill process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to kill just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
if err := m.process.Kill(); err != nil {
|
||||||
|
logger.Error("Failed to kill process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a bit more to confirm termination
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process == nil {
|
||||||
|
logger.Info("Subprocess terminated by force")
|
||||||
|
} else {
|
||||||
|
logger.Warn("Failed to terminate subprocess")
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for termination with timeout
|
||||||
|
select {
|
||||||
|
case <-terminateComplete:
|
||||||
|
// Termination completed
|
||||||
|
case <-time.After(m.shutdownDelay):
|
||||||
|
logger.Warn("Subprocess termination timed out")
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,12 +4,12 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWKS struct {
|
type JWKS struct {
|
||||||
|
@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
|
||||||
publicKeys[parsedKey.Kid] = pubKey
|
publicKeys[parsedKey.Kid] = pubKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
|
logger.Info("Loaded %d public keys.", len(publicKeys))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue