mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 09:05:41 +00:00
parent
fc0d939e16
commit
316370be1c
4 changed files with 37 additions and 41 deletions
|
@ -2,14 +2,15 @@
|
||||||
|
|
||||||
# Common configuration for all transport modes
|
# Common configuration for all transport modes
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
base_url: "http://localhost:8000" # Base URL for the MCP server
|
base_url: "http://localhost:3001" # Base URL for the MCP server
|
||||||
port: 8000 # Port for the MCP server
|
port: 3001 # Port for the MCP server
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
# Path configuration
|
# Path configuration
|
||||||
paths:
|
paths:
|
||||||
sse: "/sse" # SSE endpoint path
|
sse: "/sse" # SSE endpoint path
|
||||||
messages: "/messages/" # Messages endpoint path
|
messages: "/messages/" # Messages endpoint path
|
||||||
|
streamable_http: "/mcp" # MCP endpoint path
|
||||||
|
|
||||||
# Transport mode configuration
|
# Transport mode configuration
|
||||||
transport_mode: "sse" # Options: "sse" or "stdio"
|
transport_mode: "sse" # Options: "sse" or "stdio"
|
||||||
|
@ -28,7 +29,7 @@ path_mapping:
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
- "http://localhost:5173"
|
- "http://127.0.0.1:6274"
|
||||||
allowed_methods:
|
allowed_methods:
|
||||||
- "GET"
|
- "GET"
|
||||||
- "POST"
|
- "POST"
|
||||||
|
|
|
@ -19,15 +19,16 @@ const (
|
||||||
|
|
||||||
// Common path configuration for all transport modes
|
// Common path configuration for all transport modes
|
||||||
type PathsConfig struct {
|
type PathsConfig struct {
|
||||||
SSE string `yaml:"sse"`
|
SSE string `yaml:"sse"`
|
||||||
Messages string `yaml:"messages"`
|
Messages string `yaml:"messages"`
|
||||||
|
StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests
|
||||||
}
|
}
|
||||||
|
|
||||||
// StdioConfig contains stdio-specific configuration
|
// StdioConfig contains stdio-specific configuration
|
||||||
type StdioConfig struct {
|
type StdioConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
UserCommand string `yaml:"user_command"` // The command provided by the user
|
UserCommand string `yaml:"user_command"` // The command provided by the user
|
||||||
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
||||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||||
}
|
}
|
||||||
|
@ -85,18 +86,18 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
BaseURL string `yaml:"base_url"`
|
BaseURL string `yaml:"base_url"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
JWKSURL string
|
JWKSURL string
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
Mode string `yaml:"mode"`
|
Mode string `yaml:"mode"`
|
||||||
CORSConfig CORSConfig `yaml:"cors"`
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
TransportMode TransportMode `yaml:"transport_mode"`
|
TransportMode TransportMode `yaml:"transport_mode"`
|
||||||
Paths PathsConfig `yaml:"paths"`
|
Paths PathsConfig `yaml:"paths"`
|
||||||
Stdio StdioConfig `yaml:"stdio"`
|
Stdio StdioConfig `yaml:"stdio"`
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
|
@ -138,7 +139,7 @@ func (c *Config) Validate() error {
|
||||||
|
|
||||||
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
||||||
func (c *Config) GetMCPPaths() []string {
|
func (c *Config) GetMCPPaths() []string {
|
||||||
return []string{c.Paths.SSE, c.Paths.Messages}
|
return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildExecCommand constructs the full command string for execution in stdio mode
|
// BuildExecCommand constructs the full command string for execution in stdio mode
|
||||||
|
@ -147,7 +148,6 @@ func (c *Config) BuildExecCommand() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
// For Windows, we need to properly escape the inner command
|
// For Windows, we need to properly escape the inner command
|
||||||
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
|
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
|
||||||
|
@ -176,12 +176,12 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if err := decoder.Decode(&cfg); err != nil {
|
if err := decoder.Decode(&cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
if cfg.TimeoutSeconds == 0 {
|
if cfg.TimeoutSeconds == 0 {
|
||||||
cfg.TimeoutSeconds = 15 // default
|
cfg.TimeoutSeconds = 15 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default transport mode if not specified
|
// Set default transport mode if not specified
|
||||||
if cfg.TransportMode == "" {
|
if cfg.TransportMode == "" {
|
||||||
cfg.TransportMode = SSETransport // Default to SSE
|
cfg.TransportMode = SSETransport // Default to SSE
|
||||||
|
@ -191,11 +191,11 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if cfg.Port == 0 {
|
if cfg.Port == 0 {
|
||||||
cfg.Port = 8000 // default
|
cfg.Port = 8000 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the configuration
|
// Validate the configuration
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,20 +136,15 @@ func TestValidate(t *testing.T) {
|
||||||
func TestGetMCPPaths(t *testing.T) {
|
func TestGetMCPPaths(t *testing.T) {
|
||||||
cfg := Config{
|
cfg := Config{
|
||||||
Paths: PathsConfig{
|
Paths: PathsConfig{
|
||||||
SSE: "/custom-sse",
|
SSE: "/custom-sse",
|
||||||
Messages: "/custom-messages",
|
Messages: "/custom-messages",
|
||||||
|
StreamableHTTP: "/custom-streamable",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
paths := cfg.GetMCPPaths()
|
paths := cfg.GetMCPPaths()
|
||||||
if len(paths) != 2 {
|
if len(paths) != 3 {
|
||||||
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
|
t.Errorf("Expected 3 MCP paths, got %d", len(paths))
|
||||||
}
|
|
||||||
if paths[0] != "/custom-sse" {
|
|
||||||
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
|
|
||||||
}
|
|
||||||
if paths[1] != "/custom-messages" {
|
|
||||||
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
logger.Error("Invalid auth server URL: %v", err)
|
logger.Error("Invalid auth server URL: %v", err)
|
||||||
panic(err) // Fatal error that prevents startup
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpBase, err := url.Parse(cfg.BaseURL)
|
mcpBase, err := url.Parse(cfg.BaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid MCP server URL: %v", err)
|
logger.Error("Invalid MCP server URL: %v", err)
|
||||||
|
@ -191,13 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
req.Host = targetURL.Host
|
req.Host = targetURL.Host
|
||||||
|
|
||||||
cleanHeaders := http.Header{}
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
// Set proper origin header to match the target
|
// Set proper origin header to match the target
|
||||||
if isSSE {
|
if isSSE {
|
||||||
// For SSE, ensure origin matches the target
|
// For SSE, ensure origin matches the target
|
||||||
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if skipHeader(k) {
|
if skipHeader(k) {
|
||||||
|
@ -231,12 +231,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
proxyHost: r.Host,
|
proxyHost: r.Host,
|
||||||
targetHost: targetURL.Host,
|
targetHost: targetURL.Host,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set SSE-specific headers
|
// Set SSE-specific headers
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
w.Header().Set("Connection", "keep-alive")
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
|
||||||
// Keep SSE connections open
|
// Keep SSE connections open
|
||||||
HandleSSE(w, r, rp)
|
HandleSSE(w, r, rp)
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue