mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-07-06 12:30:26 +00:00
Refactor configurations
This commit is contained in:
parent
61d3c7e7e1
commit
5c1cc13ff3
6 changed files with 167 additions and 247 deletions
|
@ -15,7 +15,21 @@ const (
|
|||
StdioTransport TransportMode = "stdio"
|
||||
)
|
||||
|
||||
// AsgardeoConfig groups all Asgardeo-specific fields
|
||||
// Common path configuration for all transport modes
|
||||
type PathsConfig struct {
|
||||
SSE string `yaml:"sse"`
|
||||
Messages string `yaml:"messages"`
|
||||
}
|
||||
|
||||
// StdioConfig contains stdio-specific configuration
|
||||
type StdioConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
UserCommand string `yaml:"user_command"` // The command provided by the user
|
||||
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||
}
|
||||
|
||||
type DemoConfig struct {
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
|
@ -70,123 +84,74 @@ type DefaultConfig struct {
|
|||
|
||||
type Config struct {
|
||||
AuthServerBaseURL string
|
||||
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
Port int `yaml:"port"`
|
||||
JWKSURL string
|
||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||
MCPPaths []string `yaml:"mcp_paths"`
|
||||
PathMapping map[string]string `yaml:"path_mapping"`
|
||||
Mode string `yaml:"mode"`
|
||||
CORSConfig CORSConfig `yaml:"cors"`
|
||||
TransportMode TransportMode `yaml:"transport_mode"`
|
||||
Paths PathsConfig `yaml:"paths"`
|
||||
Stdio StdioConfig `yaml:"stdio"`
|
||||
|
||||
// Nested config for Asgardeo
|
||||
Demo DemoConfig `yaml:"demo"`
|
||||
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
||||
Default DefaultConfig `yaml:"default"`
|
||||
Command Command `yaml:"command"` // Command to run
|
||||
}
|
||||
|
||||
// Command struct with explicit configuration for all relevant paths
|
||||
type Command struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
UserCommand string `yaml:"user_command"` // Only the part provided by the user
|
||||
BaseUrl string `yaml:"base_url"` // Base URL for the MCP server
|
||||
Port int `yaml:"port"` // Port for the MCP server
|
||||
SsePath string `yaml:"sse_path"` // SSE endpoint path
|
||||
MessagePath string `yaml:"message_path"` // Messages endpoint path
|
||||
WorkDir string `yaml:"work_dir"` // Working directory
|
||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||
}
|
||||
|
||||
// Validate checks if the command config is valid based on transport mode
|
||||
func (c *Command) Validate(transportMode TransportMode) error {
|
||||
if transportMode == StdioTransport {
|
||||
if !c.Enabled {
|
||||
return fmt.Errorf("command must be enabled in stdio transport mode")
|
||||
// Validate checks if the config is valid based on transport mode
|
||||
func (c *Config) Validate() error {
|
||||
// Validate based on transport mode
|
||||
if c.TransportMode == StdioTransport {
|
||||
if !c.Stdio.Enabled {
|
||||
return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
|
||||
}
|
||||
if c.UserCommand == "" {
|
||||
return fmt.Errorf("user_command is required in stdio transport mode")
|
||||
if c.Stdio.UserCommand == "" {
|
||||
return fmt.Errorf("stdio.user_command is required in stdio transport mode")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate paths
|
||||
if c.Paths.SSE == "" {
|
||||
c.Paths.SSE = "/sse" // Default value
|
||||
}
|
||||
if c.Paths.Messages == "" {
|
||||
c.Paths.Messages = "/messages" // Default value
|
||||
}
|
||||
|
||||
// Validate base URL
|
||||
if c.BaseURL == "" {
|
||||
if c.Port > 0 {
|
||||
c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||
} else {
|
||||
c.BaseURL = "http://localhost:8000" // Default value
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBaseURL returns the base URL for the MCP server
|
||||
func (c *Command) GetBaseURL() string {
|
||||
if c.BaseUrl != "" {
|
||||
return c.BaseUrl
|
||||
}
|
||||
if c.Port > 0 {
|
||||
return fmt.Sprintf("http://localhost:%d", c.Port)
|
||||
}
|
||||
return "http://localhost:8000" // default
|
||||
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
||||
func (c *Config) GetMCPPaths() []string {
|
||||
return []string{c.Paths.SSE, c.Paths.Messages}
|
||||
}
|
||||
|
||||
// GetPaths returns the SSE and message paths
|
||||
func (c *Command) GetPaths() []string {
|
||||
var paths []string
|
||||
|
||||
// Add SSE path
|
||||
ssePath := c.SsePath
|
||||
if ssePath == "" {
|
||||
ssePath = "/sse" // default
|
||||
}
|
||||
paths = append(paths, ssePath)
|
||||
|
||||
// Add message path
|
||||
messagePath := c.MessagePath
|
||||
if messagePath == "" {
|
||||
messagePath = "/messages" // default
|
||||
}
|
||||
paths = append(paths, messagePath)
|
||||
|
||||
return paths
|
||||
}
|
||||
|
||||
// BuildExecCommand constructs the full command string for execution
|
||||
func (c *Command) BuildExecCommand() string {
|
||||
if c.UserCommand == "" {
|
||||
// BuildExecCommand constructs the full command string for execution in stdio mode
|
||||
func (c *Config) BuildExecCommand() string {
|
||||
if c.Stdio.UserCommand == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Apply defaults if not specified
|
||||
port := c.Port
|
||||
if port == 0 {
|
||||
port = 8000
|
||||
}
|
||||
|
||||
baseUrl := c.BaseUrl
|
||||
if baseUrl == "" {
|
||||
baseUrl = fmt.Sprintf("http://localhost:%d", port)
|
||||
}
|
||||
|
||||
ssePath := c.SsePath
|
||||
if ssePath == "" {
|
||||
ssePath = "/sse"
|
||||
}
|
||||
|
||||
messagePath := c.MessagePath
|
||||
if messagePath == "" {
|
||||
messagePath = "/messages"
|
||||
}
|
||||
|
||||
// Construct the full command
|
||||
return fmt.Sprintf(
|
||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||
c.UserCommand, port, baseUrl, ssePath, messagePath,
|
||||
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||
)
|
||||
}
|
||||
|
||||
// GetExec returns the complete command string for execution
|
||||
func (c *Command) GetExec() string {
|
||||
if c.UserCommand == "" {
|
||||
return ""
|
||||
}
|
||||
return c.BuildExecCommand()
|
||||
}
|
||||
|
||||
// LoadConfig reads a YAML config file into Config struct.
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
f, err := os.Open(path)
|
||||
|
@ -200,6 +165,8 @@ func LoadConfig(path string) (*Config, error) {
|
|||
if err := decoder.Decode(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set default values
|
||||
if cfg.TimeoutSeconds == 0 {
|
||||
cfg.TimeoutSeconds = 15 // default
|
||||
}
|
||||
|
@ -208,26 +175,16 @@ func LoadConfig(path string) (*Config, error) {
|
|||
if cfg.TransportMode == "" {
|
||||
cfg.TransportMode = SSETransport // Default to SSE
|
||||
}
|
||||
|
||||
// Set default port if not specified
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 8000 // default
|
||||
}
|
||||
|
||||
// Validate command config based on transport mode
|
||||
if err := cfg.Command.Validate(cfg.TransportMode); err != nil {
|
||||
// Validate the configuration
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// In stdio mode, use command.base_url for MCPServerBaseURL if it's not explicitly set
|
||||
if cfg.TransportMode == StdioTransport && cfg.MCPServerBaseURL == "" {
|
||||
cfg.MCPServerBaseURL = cfg.Command.GetBaseURL()
|
||||
} else if cfg.TransportMode == SSETransport && cfg.MCPServerBaseURL == "" {
|
||||
return nil, fmt.Errorf("mcp_server_base_url is required in SSE transport mode")
|
||||
}
|
||||
|
||||
// In stdio mode, set the MCPPaths from the command configuration
|
||||
if cfg.TransportMode == StdioTransport && cfg.Command.Enabled {
|
||||
// Override MCPPaths with paths from command configuration
|
||||
cfg.MCPPaths = cfg.Command.GetPaths()
|
||||
} else if cfg.TransportMode == SSETransport && len(cfg.MCPPaths) == 0 {
|
||||
return nil, fmt.Errorf("mcp_paths are required in SSE transport mode")
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
|
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
}
|
||||
|
||||
// MCP paths
|
||||
for _, path := range cfg.MCPPaths {
|
||||
mcpPaths := cfg.GetMCPPaths()
|
||||
for _, path := range mcpPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
|
@ -105,7 +106,8 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
logger.Error("Invalid auth server URL: %v", err)
|
||||
panic(err) // Fatal error that prevents startup
|
||||
}
|
||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||
|
||||
mcpBase, err := url.Parse(cfg.BaseURL)
|
||||
if err != nil {
|
||||
logger.Error("Invalid MCP server URL: %v", err)
|
||||
panic(err) // Fatal error that prevents startup
|
||||
|
@ -113,11 +115,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
|
||||
// Detect SSE paths from config
|
||||
ssePaths := make(map[string]bool)
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if p == "/sse" {
|
||||
ssePaths[p] = true
|
||||
}
|
||||
}
|
||||
ssePaths[cfg.Paths.SSE] = true
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
@ -294,7 +292,8 @@ func isAuthPath(path string) bool {
|
|||
|
||||
// isMCPPath checks if the path is an MCP path
|
||||
func isMCPPath(path string, cfg *config.Config) bool {
|
||||
for _, p := range cfg.MCPPaths {
|
||||
mcpPaths := cfg.GetMCPPaths()
|
||||
for _, p := range mcpPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -66,8 +66,8 @@ func (m *Manager) SetShutdownDelay(duration time.Duration) {
|
|||
m.shutdownDelay = duration
|
||||
}
|
||||
|
||||
// Start launches a subprocess based on the command configuration
|
||||
func (m *Manager) Start(cmdConfig *config.Command) error {
|
||||
// Start launches a subprocess based on the configuration
|
||||
func (m *Manager) Start(cfg *config.Config) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
|
@ -76,12 +76,12 @@ func (m *Manager) Start(cmdConfig *config.Command) error {
|
|||
return os.ErrExist
|
||||
}
|
||||
|
||||
if !cmdConfig.Enabled || cmdConfig.UserCommand == "" {
|
||||
if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
|
||||
return nil // Nothing to start
|
||||
}
|
||||
|
||||
// Get the full command string
|
||||
execCommand := cmdConfig.GetExec()
|
||||
execCommand := cfg.BuildExecCommand()
|
||||
if execCommand == "" {
|
||||
return nil // No command to execute
|
||||
}
|
||||
|
@ -92,13 +92,13 @@ func (m *Manager) Start(cmdConfig *config.Command) error {
|
|||
cmd := exec.Command("sh", "-c", execCommand)
|
||||
|
||||
// Set working directory if specified
|
||||
if cmdConfig.WorkDir != "" {
|
||||
cmd.Dir = cmdConfig.WorkDir
|
||||
if cfg.Stdio.WorkDir != "" {
|
||||
cmd.Dir = cfg.Stdio.WorkDir
|
||||
}
|
||||
|
||||
// Set environment variables if specified
|
||||
if len(cmdConfig.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), cmdConfig.Env...)
|
||||
if len(cfg.Stdio.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
|
||||
}
|
||||
|
||||
// Capture stdout/stderr
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue