mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Compare commits
18 commits
Author | SHA1 | Date | |
---|---|---|---|
|
1edfed91b5 | ||
|
53e0fa65a1 | ||
|
fdb81007d4 | ||
|
316370be1c | ||
|
fc0d939e16 | ||
|
2a0075b22e | ||
|
edd3ce483e | ||
|
c7fc15399b | ||
|
9e1316b420 | ||
|
56cdc96cb6 | ||
|
9f856c4279 | ||
|
be697b5868 | ||
|
8bc2e6e76b | ||
|
561b8fb637 | ||
|
68015ae8fc | ||
|
ad5185ad72 | ||
|
0bbc20ca5a | ||
|
4a5cf4e1cc |
13 changed files with 394 additions and 193 deletions
2
.github/scripts/release.sh
vendored
2
.github/scripts/release.sh
vendored
|
@ -51,7 +51,7 @@ else
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Extract current version.
|
# Extract current version.
|
||||||
CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0")
|
CURRENT_VERSION=$(git tag --sort=-v:refname | head -n 1 | sed 's/^v//' || echo "0.0.0")
|
||||||
IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}"
|
IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}"
|
||||||
|
|
||||||
# Determine which part to increment
|
# Determine which part to increment
|
||||||
|
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -36,3 +36,7 @@ coverage.html
|
||||||
|
|
||||||
# IDE files
|
# IDE files
|
||||||
.vscode
|
.vscode
|
||||||
|
|
||||||
|
# node modules
|
||||||
|
node_modules
|
||||||
|
openmcpauthproxy
|
||||||
|
|
10
Makefile
10
Makefile
|
@ -24,9 +24,9 @@ TEST_OPTS := -v -race
|
||||||
.PHONY: all clean test fmt lint vet coverage help
|
.PHONY: all clean test fmt lint vet coverage help
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
all: lint test build-linux build-linux-arm build-darwin
|
all: lint test build-linux build-linux-arm build-darwin build-windows
|
||||||
|
|
||||||
build: clean test build-linux build-linux-arm build-darwin
|
build: clean test build-linux build-linux-arm build-darwin build-windows
|
||||||
|
|
||||||
build-linux:
|
build-linux:
|
||||||
mkdir -p $(BUILD_DIR)/linux
|
mkdir -p $(BUILD_DIR)/linux
|
||||||
|
@ -46,6 +46,12 @@ build-darwin:
|
||||||
-o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
|
-o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
|
||||||
cp config.yaml $(BUILD_DIR)/darwin
|
cp config.yaml $(BUILD_DIR)/darwin
|
||||||
|
|
||||||
|
build-windows:
|
||||||
|
mkdir -p $(BUILD_DIR)/windows
|
||||||
|
GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
|
||||||
|
-o $(BUILD_DIR)/windows/openmcpauthproxy.exe ./cmd/proxy
|
||||||
|
cp config.yaml $(BUILD_DIR)/windows
|
||||||
|
|
||||||
# Clean build artifacts
|
# Clean build artifacts
|
||||||
clean:
|
clean:
|
||||||
@echo "Cleaning build artifacts..."
|
@echo "Cleaning build artifacts..."
|
||||||
|
|
126
README.md
126
README.md
|
@ -47,22 +47,30 @@ Open MCP Auth Proxy sits between MCP clients and your MCP server to:
|
||||||
|
|
||||||
### Basic Usage
|
### Basic Usage
|
||||||
|
|
||||||
1. The repository comes with a default `config.yaml` file that contains the basic configuration:
|
1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest).
|
||||||
|
|
||||||
```yaml
|
|
||||||
listen_port: 8080
|
|
||||||
base_url: "http://localhost:8000" # Your MCP server URL
|
|
||||||
paths:
|
|
||||||
sse: "/sse"
|
|
||||||
messages: "/messages/"
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
|
2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
|
||||||
|
|
||||||
|
#### Linux/macOS:
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy --demo
|
./openmcpauthproxy --demo
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Windows:
|
||||||
|
```powershell
|
||||||
|
.\openmcpauthproxy.exe --demo
|
||||||
|
```
|
||||||
|
|
||||||
|
> The repository comes with a default `config.yaml` file that contains the basic configuration:
|
||||||
|
>
|
||||||
|
> ```yaml
|
||||||
|
> listen_port: 8080
|
||||||
|
> base_url: "http://localhost:8000" # Your MCP server URL
|
||||||
|
> paths:
|
||||||
|
> sse: "/sse"
|
||||||
|
> messages: "/messages/"
|
||||||
|
> ```
|
||||||
|
|
||||||
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
||||||
|
|
||||||
## Connect an Identity Provider
|
## Connect an Identity Provider
|
||||||
|
@ -213,12 +221,102 @@ asgardeo:
|
||||||
client_id: "<client_id>"
|
client_id: "<client_id>"
|
||||||
client_secret: "<client_secret>"
|
client_secret: "<client_secret>"
|
||||||
```
|
```
|
||||||
|
## Build from Source
|
||||||
|
|
||||||
### Build from source
|
### Prerequisites
|
||||||
|
|
||||||
|
* Go 1.20 or higher
|
||||||
|
* Git
|
||||||
|
* Make (optional, for simplified builds)
|
||||||
|
|
||||||
|
### Clone and Build
|
||||||
|
|
||||||
|
1. **Clone the repository:**
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||||
|
cd open-mcp-auth-proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install dependencies:**
|
||||||
|
```bash
|
||||||
|
go get -v -t -d ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Build the application:**
|
||||||
|
|
||||||
|
**Option A: Using Make**
|
||||||
|
```bash
|
||||||
|
# Build for all platforms
|
||||||
|
make all
|
||||||
|
|
||||||
|
# Or build for specific platforms
|
||||||
|
make build-linux # For Linux (x86_64)
|
||||||
|
make build-linux-arm # For ARM-based Linux
|
||||||
|
make build-darwin # For macOS
|
||||||
|
make build-windows # For Windows
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option B: Manual build (works on all platforms)**
|
||||||
|
```bash
|
||||||
|
# Build for your current platform
|
||||||
|
go build -o openmcpauthproxy ./cmd/proxy
|
||||||
|
|
||||||
|
# Cross-compile for other platforms
|
||||||
|
GOOS=linux GOARCH=amd64 go build -o openmcpauthproxy-linux ./cmd/proxy
|
||||||
|
GOOS=windows GOARCH=amd64 go build -o openmcpauthproxy.exe ./cmd/proxy
|
||||||
|
GOOS=darwin GOARCH=amd64 go build -o openmcpauthproxy-macos ./cmd/proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the Built Application
|
||||||
|
|
||||||
|
After building, you'll find the executables in the `build` directory (when using Make) or in your project root (when building manually).
|
||||||
|
|
||||||
|
**Linux/macOS:**
|
||||||
|
```bash
|
||||||
|
# If built with Make
|
||||||
|
./build/linux/openmcpauthproxy --demo
|
||||||
|
|
||||||
|
# If built manually
|
||||||
|
./openmcpauthproxy --demo
|
||||||
|
```
|
||||||
|
|
||||||
|
**Windows:**
|
||||||
|
```powershell
|
||||||
|
# If built with Make
|
||||||
|
.\build\windows\openmcpauthproxy.exe --demo
|
||||||
|
|
||||||
|
# If built manually
|
||||||
|
.\openmcpauthproxy.exe --demo
|
||||||
|
```
|
||||||
|
|
||||||
|
### Available Command Line Options
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
# Start in demo mode (using Asgardeo sandbox)
|
||||||
cd open-mcp-auth-proxy
|
./openmcpauthproxy --demo
|
||||||
go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
|
|
||||||
go build -o openmcpauthproxy ./cmd/proxy
|
# Start with your own Asgardeo organization
|
||||||
|
./openmcpauthproxy --asgardeo
|
||||||
|
|
||||||
|
# Use stdio transport mode instead of SSE
|
||||||
|
./openmcpauthproxy --demo --stdio
|
||||||
|
|
||||||
|
# Enable debug logging
|
||||||
|
./openmcpauthproxy --demo --debug
|
||||||
|
|
||||||
|
# Show all available options
|
||||||
|
./openmcpauthproxy --help
|
||||||
|
```
|
||||||
|
|
||||||
|
### Additional Make Targets
|
||||||
|
|
||||||
|
If you're using Make, these additional targets are available:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test # Run tests
|
||||||
|
make coverage # Run tests with coverage report
|
||||||
|
make fmt # Format code with gofmt
|
||||||
|
make vet # Run go vet
|
||||||
|
make clean # Clean build artifacts
|
||||||
|
make help # Show all available targets
|
||||||
```
|
```
|
||||||
|
|
|
@ -2,14 +2,15 @@
|
||||||
|
|
||||||
# Common configuration for all transport modes
|
# Common configuration for all transport modes
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
base_url: "http://localhost:8000" # Base URL for the MCP server
|
base_url: "http://localhost:3001" # Base URL for the MCP server
|
||||||
port: 8000 # Port for the MCP server
|
port: 3001 # Port for the MCP server
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
# Path configuration
|
# Path configuration
|
||||||
paths:
|
paths:
|
||||||
sse: "/sse" # SSE endpoint path
|
sse: "/sse" # SSE endpoint path
|
||||||
messages: "/messages/" # Messages endpoint path
|
messages: "/messages/" # Messages endpoint path
|
||||||
|
streamable_http: "/mcp" # MCP endpoint path
|
||||||
|
|
||||||
# Transport mode configuration
|
# Transport mode configuration
|
||||||
transport_mode: "sse" # Options: "sse" or "stdio"
|
transport_mode: "sse" # Options: "sse" or "stdio"
|
||||||
|
@ -28,7 +29,7 @@ path_mapping:
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
- "http://localhost:5173"
|
- "http://127.0.0.1:6274"
|
||||||
allowed_methods:
|
allowed_methods:
|
||||||
- "GET"
|
- "GET"
|
||||||
- "POST"
|
- "POST"
|
||||||
|
|
|
@ -13,14 +13,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type asgardeoProvider struct {
|
type asgardeoProvider struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAsgardeoProvider initializes a Provider for Asgardeo (demo mode).
|
// NewAsgardeoProvider initializes a Provider for Asgardeo.
|
||||||
func NewAsgardeoProvider(cfg *config.Config) Provider {
|
func NewAsgardeoProvider(cfg *config.Config) Provider {
|
||||||
return &asgardeoProvider{cfg: cfg}
|
return &asgardeoProvider{cfg: cfg}
|
||||||
}
|
}
|
||||||
|
@ -113,6 +113,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
|
|
||||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||||
logger.Warn("Asgardeo application creation failed: %v", err)
|
logger.Warn("Asgardeo application creation failed: %v", err)
|
||||||
|
http.Error(w, "Failed to create application in Asgardeo", http.StatusInternalServerError)
|
||||||
// Optionally http.Error(...) if you want to fail
|
// Optionally http.Error(...) if you want to fail
|
||||||
// or continue to return partial data.
|
// or continue to return partial data.
|
||||||
}
|
}
|
||||||
|
@ -159,13 +160,19 @@ type RegisterResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) error {
|
func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) error {
|
||||||
|
|
||||||
|
orgName := p.cfg.Demo.OrgName
|
||||||
|
if p.cfg.Mode == "asgardeo" {
|
||||||
|
orgName = p.cfg.Asgardeo.OrgName
|
||||||
|
}
|
||||||
|
|
||||||
body := buildAsgardeoPayload(regReq)
|
body := buildAsgardeoPayload(regReq)
|
||||||
reqBytes, err := json.Marshal(body)
|
reqBytes, err := json.Marshal(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal Asgardeo request: %w", err)
|
return fmt.Errorf("failed to marshal Asgardeo request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
asgardeoAppURL := "https://api.asgardeo.io/t/" + p.cfg.Demo.OrgName + "/api/server/v1/applications"
|
asgardeoAppURL := "https://api.asgardeo.io/t/" + orgName + "/api/server/v1/applications"
|
||||||
req, err := http.NewRequest("POST", asgardeoAppURL, bytes.NewBuffer(reqBytes))
|
req, err := http.NewRequest("POST", asgardeoAppURL, bytes.NewBuffer(reqBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Asgardeo API request: %w", err)
|
return fmt.Errorf("failed to create Asgardeo API request: %w", err)
|
||||||
|
@ -195,6 +202,14 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
|
|
||||||
|
clientId := p.cfg.Demo.ClientID
|
||||||
|
clientSecret := p.cfg.Demo.ClientSecret
|
||||||
|
if p.cfg.Mode == "asgardeo" {
|
||||||
|
clientId = p.cfg.Asgardeo.ClientID
|
||||||
|
clientSecret = p.cfg.Asgardeo.ClientSecret
|
||||||
|
}
|
||||||
|
|
||||||
tokenURL := p.cfg.AuthServerBaseURL + "/token"
|
tokenURL := p.cfg.AuthServerBaseURL + "/token"
|
||||||
|
|
||||||
formData := "grant_type=client_credentials&scope=internal_application_mgt_create internal_application_mgt_delete " +
|
formData := "grant_type=client_credentials&scope=internal_application_mgt_create internal_application_mgt_delete " +
|
||||||
|
@ -207,10 +222,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
// Sensitive data - should not be logged at INFO level
|
// Sensitive data - should not be logged at INFO level
|
||||||
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
auth := clientId + ":" + clientSecret
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||||
|
|
||||||
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
|
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", clientId)
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
@ -255,6 +270,18 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
|
||||||
}
|
}
|
||||||
appName += "-" + randomString(5)
|
appName += "-" + randomString(5)
|
||||||
|
|
||||||
|
// Build redirect URIs regex from list of redirect URIs : regexp=(https://app.example.com/callback1|https://app.example.com/callback2)
|
||||||
|
redirectURI := "regexp=(" + strings.Join(regReq.RedirectURIs, "|") + ")"
|
||||||
|
redirectURIs := []string{redirectURI}
|
||||||
|
|
||||||
|
// Filter unsupported grant types
|
||||||
|
var grantTypes []string
|
||||||
|
for _, gt := range regReq.GrantTypes {
|
||||||
|
if gt == "authorization_code" || gt == "refresh_token" {
|
||||||
|
grantTypes = append(grantTypes, gt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{
|
||||||
"name": appName,
|
"name": appName,
|
||||||
"templateId": "custom-application-oidc",
|
"templateId": "custom-application-oidc",
|
||||||
|
@ -262,10 +289,10 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
|
||||||
"oidc": map[string]interface{}{
|
"oidc": map[string]interface{}{
|
||||||
"clientId": regReq.ClientID,
|
"clientId": regReq.ClientID,
|
||||||
"clientSecret": regReq.ClientSecret,
|
"clientSecret": regReq.ClientSecret,
|
||||||
"grantTypes": regReq.GrantTypes,
|
"grantTypes": grantTypes,
|
||||||
"callbackURLs": regReq.RedirectURIs,
|
"callbackURLs": redirectURIs,
|
||||||
"allowedOrigins": []string{},
|
"allowedOrigins": []string{},
|
||||||
"publicClient": false,
|
"publicClient": true,
|
||||||
"pkce": map[string]bool{
|
"pkce": map[string]bool{
|
||||||
"mandatory": true,
|
"mandatory": true,
|
||||||
"supportPlainTransformAlgorithm": true,
|
"supportPlainTransformAlgorithm": true,
|
||||||
|
|
|
@ -3,6 +3,8 @@ package config
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
@ -17,15 +19,16 @@ const (
|
||||||
|
|
||||||
// Common path configuration for all transport modes
|
// Common path configuration for all transport modes
|
||||||
type PathsConfig struct {
|
type PathsConfig struct {
|
||||||
SSE string `yaml:"sse"`
|
SSE string `yaml:"sse"`
|
||||||
Messages string `yaml:"messages"`
|
Messages string `yaml:"messages"`
|
||||||
|
StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests
|
||||||
}
|
}
|
||||||
|
|
||||||
// StdioConfig contains stdio-specific configuration
|
// StdioConfig contains stdio-specific configuration
|
||||||
type StdioConfig struct {
|
type StdioConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
UserCommand string `yaml:"user_command"` // The command provided by the user
|
UserCommand string `yaml:"user_command"` // The command provided by the user
|
||||||
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
||||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||||
}
|
}
|
||||||
|
@ -83,18 +86,18 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
BaseURL string `yaml:"base_url"`
|
BaseURL string `yaml:"base_url"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
JWKSURL string
|
JWKSURL string
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
Mode string `yaml:"mode"`
|
Mode string `yaml:"mode"`
|
||||||
CORSConfig CORSConfig `yaml:"cors"`
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
TransportMode TransportMode `yaml:"transport_mode"`
|
TransportMode TransportMode `yaml:"transport_mode"`
|
||||||
Paths PathsConfig `yaml:"paths"`
|
Paths PathsConfig `yaml:"paths"`
|
||||||
Stdio StdioConfig `yaml:"stdio"`
|
Stdio StdioConfig `yaml:"stdio"`
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
|
@ -136,7 +139,7 @@ func (c *Config) Validate() error {
|
||||||
|
|
||||||
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
||||||
func (c *Config) GetMCPPaths() []string {
|
func (c *Config) GetMCPPaths() []string {
|
||||||
return []string{c.Paths.SSE, c.Paths.Messages}
|
return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildExecCommand constructs the full command string for execution in stdio mode
|
// BuildExecCommand constructs the full command string for execution in stdio mode
|
||||||
|
@ -145,7 +148,15 @@ func (c *Config) BuildExecCommand() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct the full command
|
if runtime.GOOS == "windows" {
|
||||||
|
// For Windows, we need to properly escape the inner command
|
||||||
|
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||||
|
escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||||
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||||
|
@ -165,12 +176,12 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if err := decoder.Decode(&cfg); err != nil {
|
if err := decoder.Decode(&cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
if cfg.TimeoutSeconds == 0 {
|
if cfg.TimeoutSeconds == 0 {
|
||||||
cfg.TimeoutSeconds = 15 // default
|
cfg.TimeoutSeconds = 15 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default transport mode if not specified
|
// Set default transport mode if not specified
|
||||||
if cfg.TransportMode == "" {
|
if cfg.TransportMode == "" {
|
||||||
cfg.TransportMode = SSETransport // Default to SSE
|
cfg.TransportMode = SSETransport // Default to SSE
|
||||||
|
@ -180,11 +191,11 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if cfg.Port == 0 {
|
if cfg.Port == 0 {
|
||||||
cfg.Port = 8000 // default
|
cfg.Port = 8000 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the configuration
|
// Validate the configuration
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,20 +136,15 @@ func TestValidate(t *testing.T) {
|
||||||
func TestGetMCPPaths(t *testing.T) {
|
func TestGetMCPPaths(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
Paths: PathsConfig{
|
Paths: PathsConfig{
|
||||||
SSE: "/custom-sse",
|
SSE: "/custom-sse",
|
||||||
Messages: "/custom-messages",
|
Messages: "/custom-messages",
|
||||||
|
StreamableHTTP: "/custom-streamable",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
paths := cfg.GetMCPPaths()
|
paths := cfg.GetMCPPaths()
|
||||||
if len(paths) != 2 {
|
if len(paths) != 3 {
|
||||||
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
|
t.Errorf("Expected 3 MCP paths, got %d", len(paths))
|
||||||
}
|
|
||||||
if paths[0] != "/custom-sse" {
|
|
||||||
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
|
|
||||||
}
|
|
||||||
if paths[1] != "/custom-messages" {
|
|
||||||
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
logger.Error("Invalid auth server URL: %v", err)
|
logger.Error("Invalid auth server URL: %v", err)
|
||||||
panic(err) // Fatal error that prevents startup
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpBase, err := url.Parse(cfg.BaseURL)
|
mcpBase, err := url.Parse(cfg.BaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid MCP server URL: %v", err)
|
logger.Error("Invalid MCP server URL: %v", err)
|
||||||
|
@ -191,13 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
req.Host = targetURL.Host
|
req.Host = targetURL.Host
|
||||||
|
|
||||||
cleanHeaders := http.Header{}
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
// Set proper origin header to match the target
|
// Set proper origin header to match the target
|
||||||
if isSSE {
|
if isSSE {
|
||||||
// For SSE, ensure origin matches the target
|
// For SSE, ensure origin matches the target
|
||||||
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if skipHeader(k) {
|
if skipHeader(k) {
|
||||||
|
@ -231,12 +231,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
proxyHost: r.Host,
|
proxyHost: r.Host,
|
||||||
targetHost: targetURL.Host,
|
targetHost: targetURL.Host,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set SSE-specific headers
|
// Set SSE-specific headers
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
w.Header().Set("Connection", "keep-alive")
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
|
||||||
// Keep SSE connections open
|
// Keep SSE connections open
|
||||||
HandleSSE(w, r, rp)
|
HandleSSE(w, r, rp)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -4,13 +4,14 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager handles starting and graceful shutdown of subprocesses
|
// Manager handles starting and graceful shutdown of subprocesses
|
||||||
|
@ -31,34 +32,39 @@ func NewManager() *Manager {
|
||||||
|
|
||||||
// EnsureDependenciesAvailable checks and installs required package executors
|
// EnsureDependenciesAvailable checks and installs required package executors
|
||||||
func EnsureDependenciesAvailable(command string) error {
|
func EnsureDependenciesAvailable(command string) error {
|
||||||
// Always ensure npx is available regardless of the command
|
// Always ensure npx is available regardless of the command
|
||||||
if _, err := exec.LookPath("npx"); err != nil {
|
if _, err := exec.LookPath("npx"); err != nil {
|
||||||
// npx is not available, check if npm is installed
|
// npx is not available, check if npm is installed
|
||||||
if _, err := exec.LookPath("npm"); err != nil {
|
if _, err := exec.LookPath("npm"); err != nil {
|
||||||
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
|
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to install npx using npm
|
// Try to install npx using npm
|
||||||
logger.Info("npx not found, attempting to install...")
|
logger.Info("npx not found, attempting to install...")
|
||||||
cmd := exec.Command("npm", "install", "-g", "npx")
|
var cmd *exec.Cmd
|
||||||
cmd.Stdout = os.Stdout
|
if runtime.GOOS == "windows" {
|
||||||
cmd.Stderr = os.Stderr
|
cmd = exec.Command("npm.cmd", "install", "-g", "npx")
|
||||||
|
} else {
|
||||||
if err := cmd.Run(); err != nil {
|
cmd = exec.Command("npm", "install", "-g", "npx")
|
||||||
return fmt.Errorf("failed to install npx: %w", err)
|
}
|
||||||
}
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
logger.Info("npx installed successfully")
|
|
||||||
}
|
if err := cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to install npx: %w", err)
|
||||||
// Check if uv is needed based on the command
|
}
|
||||||
if strings.Contains(command, "uv ") {
|
|
||||||
if _, err := exec.LookPath("uv"); err != nil {
|
logger.Info("npx installed successfully")
|
||||||
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
|
}
|
||||||
}
|
|
||||||
}
|
// Check if uv is needed based on the command
|
||||||
|
if strings.Contains(command, "uv ") {
|
||||||
return nil
|
if _, err := exec.LookPath("uv"); err != nil {
|
||||||
|
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
|
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
|
||||||
|
@ -88,8 +94,13 @@ func (m *Manager) Start(cfg *config.Config) error {
|
||||||
|
|
||||||
logger.Info("Starting subprocess with command: %s", execCommand)
|
logger.Info("Starting subprocess with command: %s", execCommand)
|
||||||
|
|
||||||
// Use the shell to execute the command
|
var cmd *exec.Cmd
|
||||||
cmd := exec.Command("sh", "-c", execCommand)
|
if runtime.GOOS == "windows" {
|
||||||
|
// Use PowerShell on Windows for better quote handling
|
||||||
|
cmd = exec.Command("powershell", "-Command", execCommand)
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command("sh", "-c", execCommand)
|
||||||
|
}
|
||||||
|
|
||||||
// Set working directory if specified
|
// Set working directory if specified
|
||||||
if cfg.Stdio.WorkDir != "" {
|
if cfg.Stdio.WorkDir != "" {
|
||||||
|
@ -105,8 +116,8 @@ func (m *Manager) Start(cfg *config.Config) error {
|
||||||
cmd.Stdout = os.Stdout
|
cmd.Stdout = os.Stdout
|
||||||
cmd.Stderr = os.Stderr
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
// Set the process group for proper termination
|
// Set platform-specific process attributes
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
setProcAttr(cmd)
|
||||||
|
|
||||||
// Start the process
|
// Start the process
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
|
@ -117,11 +128,13 @@ func (m *Manager) Start(cfg *config.Config) error {
|
||||||
m.cmd = cmd
|
m.cmd = cmd
|
||||||
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
||||||
|
|
||||||
// Get and store the process group ID
|
// Get and store the process group ID (Unix) or PID (Windows)
|
||||||
pgid, err := syscall.Getpgid(m.process.Pid)
|
pgid, err := getProcessGroup(m.process.Pid)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
m.processGroup = pgid
|
m.processGroup = pgid
|
||||||
logger.Debug("Process group ID: %d", m.processGroup)
|
if runtime.GOOS != "windows" {
|
||||||
|
logger.Debug("Process group ID: %d", m.processGroup)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Failed to get process group ID: %v", err)
|
logger.Warn("Failed to get process group ID: %v", err)
|
||||||
m.processGroup = m.process.Pid
|
m.processGroup = m.process.Pid
|
||||||
|
@ -155,7 +168,7 @@ func (m *Manager) IsRunning() bool {
|
||||||
// Shutdown gracefully terminates the subprocess
|
// Shutdown gracefully terminates the subprocess
|
||||||
func (m *Manager) Shutdown() {
|
func (m *Manager) Shutdown() {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
processToTerminate := m.process // Local copy of the process reference
|
processToTerminate := m.process // Local copy of the process reference
|
||||||
processGroupToTerminate := m.processGroup
|
processGroupToTerminate := m.processGroup
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -169,48 +182,73 @@ func (m *Manager) Shutdown() {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(terminateComplete)
|
defer close(terminateComplete)
|
||||||
|
|
||||||
// Try graceful termination first with SIGTERM
|
// Try graceful termination first
|
||||||
terminatedGracefully := false
|
terminatedGracefully := false
|
||||||
|
|
||||||
// Try to terminate the process group first
|
if runtime.GOOS == "windows" {
|
||||||
if processGroupToTerminate != 0 {
|
// Windows: Try to terminate the process
|
||||||
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
m.mutex.Lock()
|
||||||
if err != nil {
|
if m.process != nil {
|
||||||
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
err := m.process.Kill()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to terminate process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
// Fallback to terminating just the process
|
// Wait a bit to see if it terminates
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process == nil {
|
||||||
|
terminatedGracefully = true
|
||||||
|
m.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unix: Use SIGTERM followed by SIGKILL if necessary
|
||||||
|
// Try to terminate the process group first
|
||||||
|
if processGroupToTerminate != 0 {
|
||||||
|
err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
||||||
|
|
||||||
|
// Fallback to terminating just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
err = m.process.Signal(syscall.SIGTERM)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to terminate just the process
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if m.process != nil {
|
if m.process != nil {
|
||||||
err = m.process.Signal(syscall.SIGTERM)
|
err := m.process.Signal(syscall.SIGTERM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Try to terminate just the process
|
// Wait for the process to exit gracefully
|
||||||
m.mutex.Lock()
|
for i := 0; i < 10; i++ {
|
||||||
if m.process != nil {
|
time.Sleep(200 * time.Millisecond)
|
||||||
err := m.process.Signal(syscall.SIGTERM)
|
|
||||||
if err != nil {
|
m.mutex.Lock()
|
||||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
if m.process == nil {
|
||||||
|
terminatedGracefully = true
|
||||||
|
m.mutex.Unlock()
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the process to exit gracefully
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process == nil {
|
|
||||||
terminatedGracefully = true
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
break
|
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if terminatedGracefully {
|
if terminatedGracefully {
|
||||||
|
@ -221,12 +259,33 @@ func (m *Manager) Shutdown() {
|
||||||
// If the process didn't exit gracefully, force kill
|
// If the process didn't exit gracefully, force kill
|
||||||
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
||||||
|
|
||||||
// Try to kill the process group first
|
if runtime.GOOS == "windows" {
|
||||||
if processGroupToTerminate != 0 {
|
// On Windows, Kill() is already forceful
|
||||||
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
m.mutex.Lock()
|
||||||
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
if m.process != nil {
|
||||||
|
if err := m.process.Kill(); err != nil {
|
||||||
|
logger.Error("Failed to kill process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
} else {
|
||||||
|
// Unix: Try SIGKILL
|
||||||
|
// Try to kill the process group first
|
||||||
|
if processGroupToTerminate != 0 {
|
||||||
|
if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||||
|
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
||||||
|
|
||||||
// Fallback to killing just the process
|
// Fallback to killing just the process
|
||||||
|
m.mutex.Lock()
|
||||||
|
if m.process != nil {
|
||||||
|
if err := m.process.Kill(); err != nil {
|
||||||
|
logger.Error("Failed to kill process: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try to kill just the process
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if m.process != nil {
|
if m.process != nil {
|
||||||
if err := m.process.Kill(); err != nil {
|
if err := m.process.Kill(); err != nil {
|
||||||
|
@ -235,15 +294,6 @@ func (m *Manager) Shutdown() {
|
||||||
}
|
}
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Try to kill just the process
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process != nil {
|
|
||||||
if err := m.process.Kill(); err != nil {
|
|
||||||
logger.Error("Failed to kill process: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait a bit more to confirm termination
|
// Wait a bit more to confirm termination
|
||||||
|
|
23
internal/subprocess/manager_unix.go
Normal file
23
internal/subprocess/manager_unix.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package subprocess
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcAttr sets Unix-specific process attributes
|
||||||
|
func setProcAttr(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProcessGroup gets the process group ID on Unix systems
|
||||||
|
func getProcessGroup(pid int) (int, error) {
|
||||||
|
return syscall.Getpgid(pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// killProcessGroup kills a process group on Unix systems
|
||||||
|
func killProcessGroup(pgid int, signal syscall.Signal) error {
|
||||||
|
return syscall.Kill(-pgid, signal)
|
||||||
|
}
|
27
internal/subprocess/manager_windows.go
Normal file
27
internal/subprocess/manager_windows.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package subprocess
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setProcAttr sets Windows-specific process attributes
|
||||||
|
func setProcAttr(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProcessGroup returns the PID itself on Windows (no process groups)
|
||||||
|
func getProcessGroup(pid int) (int, error) {
|
||||||
|
return pid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// killProcessGroup kills a process on Windows (no process groups)
|
||||||
|
func killProcessGroup(pgid int, signal syscall.Signal) error {
|
||||||
|
// On Windows, we'll use the process handle directly
|
||||||
|
// This function shouldn't be called on Windows, but we provide it for compatibility
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,52 +1,11 @@
|
||||||
## Purpose
|
## Purpose
|
||||||
> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc.
|
<!-- Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc. -->
|
||||||
|
|
||||||
## Goals
|
## Related Issues
|
||||||
> Describe the solutions that this feature/fix will introduce to resolve the problems described above
|
<!-- List any related issues -->
|
||||||
|
|
||||||
## Approach
|
|
||||||
> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here.
|
|
||||||
|
|
||||||
## User stories
|
|
||||||
> Summary of user stories addressed by this change>
|
|
||||||
|
|
||||||
## Release note
|
|
||||||
> Brief description of the new feature or bug fix as it will appear in the release notes
|
|
||||||
|
|
||||||
## Documentation
|
|
||||||
> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why there’s no doc impact
|
|
||||||
|
|
||||||
## Training
|
|
||||||
> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable
|
|
||||||
|
|
||||||
## Certification
|
|
||||||
> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why.
|
|
||||||
|
|
||||||
## Marketing
|
|
||||||
> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable
|
|
||||||
|
|
||||||
## Automation tests
|
|
||||||
- Unit tests
|
|
||||||
> Code coverage information
|
|
||||||
- Integration tests
|
|
||||||
> Details about the test cases and coverage
|
|
||||||
|
|
||||||
## Security checks
|
|
||||||
- Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no
|
|
||||||
- Ran FindSecurityBugs plugin and verified report? yes/no
|
|
||||||
- Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no
|
|
||||||
|
|
||||||
## Samples
|
|
||||||
> Provide high-level details about the samples related to this feature
|
|
||||||
|
|
||||||
## Related PRs
|
## Related PRs
|
||||||
> List any other related PRs
|
<!-- List any other related PRs -->
|
||||||
|
|
||||||
## Migrations (if applicable)
|
## Migrations (if applicable)
|
||||||
> Describe migration steps and platforms on which migration has been tested
|
<!-- Describe migration steps and platforms on which migration has been tested -->
|
||||||
|
|
||||||
## Test environment
|
|
||||||
> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested
|
|
||||||
|
|
||||||
## Learning
|
|
||||||
> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem.
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue