diff --git a/README.md b/README.md index a0cff62..6fdd29e 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ go build -o openmcpauthproxy ./cmd/proxy The Open MCP Auth Proxy supports two transport modes: 1. **SSE Mode (Default)**: For MCP servers that use Server-Sent Events transport -2. **stdio Mode**: For MCP servers that use stdio transport, which requires starting a MCP Server as a subprocess +2. **stdio Mode**: For MCP servers that use stdio transport, which requires starting a subprocess You can specify the transport mode in the `config.yaml` file: @@ -45,24 +45,38 @@ Or use the `--stdio` flag to override the configuration: ./openmcpauthproxy --stdio ``` -**Configuration Requirements by Transport Mode:** +### Configuration -**SSE Mode:** -- `mcp_server_base_url` is required (points to an external MCP server) -- The `command` section is optional and will be ignored -- No subprocess will be started -- The proxy expects an external MCP server to be running at the specified URL +The configuration uses a unified structure with common settings and transport-specific options: -**stdio Mode:** -- The `command` section in `config.yaml` is mandatory -- `mcp_server_base_url` is optional (if not specified, it will use `command.base_url`) -- The proxy will start a subprocess as specified in the command configuration -- The subprocess will be terminated when the proxy shuts down +```yaml +# Common configuration +listen_port: 8080 +base_url: "http://localhost:8000" # Base URL for the MCP server +port: 8000 # Port for the MCP server + +# Path configuration +paths: + sse: "/sse" # SSE endpoint path + messages: "/messages" # Messages endpoint path + +# Transport mode configuration +transport_mode: "sse" # Options: "sse" or "stdio" + +# stdio-specific configuration (used only when transport_mode is "stdio") +stdio: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" + work_dir: "" # Working directory (optional) +``` + +**Notes:** +- In SSE mode, the proxy connects to an external MCP server at the specified `base_url` +- In stdio mode, the proxy starts a subprocess using the `stdio.user_command` configuration +- Common settings like `base_url`, `port`, and `paths` are used for both transport modes ### Quick Start -Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo. - If you don't have an MCP server, follow the instructions given here to start your own MCP server for testing purposes. 1. Navigate to `resources` directory. @@ -91,36 +105,28 @@ python3 echo_server.py #### Configure the Auth Proxy -Update the following parameters in `config.yaml`. - -### Configuration examples: - -**SSE mode (using external MCP server):** -```yaml -transport_mode: "sse" # Transport mode: "sse" or "stdio" -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server (required in SSE mode) -listen_port: 8080 # Address where the proxy will listen -``` - -**stdio mode (using subprocess):** -```yaml -transport_mode: "stdio" # Transport mode: "sse" or "stdio" - -command: - enabled: true # Must be true in stdio mode - user_command: "npx -y @modelcontextprotocol/server-github" # Required in stdio mode - base_url: "http://localhost:8000" # Used as MCP server base URL if not specified above - port: 8000 - sse_path: "/sse" # SSE endpoint path - message_path: "/messages" # Messages endpoint path -``` +Update the necessary parameters in `config.yaml` as shown in the examples above. #### Start the Auth Proxy +For the demo mode with pre-configured authentication: + ```bash ./openmcpauthproxy --demo ``` +For standard mode: + +```bash +./openmcpauthproxy +``` + +For stdio mode: + +```bash +./openmcpauthproxy --stdio +``` + The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/). #### Connect Using an MCP Client @@ -143,9 +149,17 @@ Enable authorization for the MCP server through your own Asgardeo organization Create a configuration file config.yaml with the following parameters: ```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_port: 8080 # Address where the proxy will listen -transport_mode: "sse" # Transport mode: "sse" or "stdio" +# Common configuration +listen_port: 8080 +base_url: "http://localhost:8000" # Base URL for the MCP server + +# Path configuration +paths: + sse: "/sse" + messages: "/messages" + +# Transport mode +transport_mode: "sse" # or "stdio" asgardeo: org_name: "" # Your Asgardeo org name @@ -159,26 +173,6 @@ asgardeo: ./openmcpauthproxy --asgardeo ``` -### Use with any standard OAuth Server - -Enable authorization for the MCP server with a compliant OAuth server - -#### Configuration - -Create a configuration file config.yaml with the following parameters: - -```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_port: 8080 # Address where the proxy will listen -transport_mode: "sse" # Transport mode: "sse" or "stdio" -``` -**TODO**: Update the configs for a standard OAuth Server. - -#### Start the Auth Proxy - -```bash -./openmcpauthproxy -``` #### Integrating with existing OAuth Providers - [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization. \ No newline at end of file diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 6886d9b..27bab5d 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -37,43 +37,37 @@ func main() { // Override transport mode if stdio flag is set if *stdioMode { cfg.TransportMode = config.StdioTransport - // Validate command config for stdio mode - if err := cfg.Command.Validate(cfg.TransportMode); err != nil { + // Ensure stdio is enabled + cfg.Stdio.Enabled = true + // Re-validate config + if err := cfg.Validate(); err != nil { logger.Error("Configuration error: %v", err) os.Exit(1) } } - // 2. Ensure MCPPaths are properly configured - if cfg.TransportMode == config.StdioTransport && cfg.Command.Enabled { - // Use command.base_url for MCPServerBaseURL in stdio mode - cfg.MCPServerBaseURL = cfg.Command.GetBaseURL() - - // Use command paths for MCPPaths in stdio mode - cfg.MCPPaths = cfg.Command.GetPaths() - - logger.Info("Using MCP server baseUrl: %s", cfg.MCPServerBaseURL) - logger.Info("Using MCP paths: %v", cfg.MCPPaths) - } + logger.Info("Using transport mode: %s", cfg.TransportMode) + logger.Info("Using MCP server base URL: %s", cfg.BaseURL) + logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages) - // 3. Start subprocess if configured and in stdio mode + // 2. Start subprocess if configured and in stdio mode var procManager *subprocess.Manager - if cfg.TransportMode == config.StdioTransport && cfg.Command.Enabled && cfg.Command.UserCommand != "" { + if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled { // Ensure all required dependencies are available - if err := subprocess.EnsureDependenciesAvailable(cfg.Command.UserCommand); err != nil { + if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil { logger.Warn("%v", err) logger.Warn("Subprocess may fail to start due to missing dependencies") } procManager = subprocess.NewManager() - if err := procManager.Start(&cfg.Command); err != nil { + if err := procManager.Start(cfg); err != nil { logger.Warn("Failed to start subprocess: %v", err) } } else if cfg.TransportMode == config.SSETransport { logger.Info("Using SSE transport mode, not starting subprocess") } - // 4. Create the chosen provider + // 3. Create the chosen provider var provider authz.Provider if *demoMode { cfg.Mode = "demo" @@ -92,18 +86,18 @@ func main() { provider = authz.NewDefaultProvider(cfg) } - // 5. (Optional) Fetch JWKS if you want local JWT validation + // 4. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { logger.Error("Failed to fetch JWKS: %v", err) os.Exit(1) } - // 6. Build the main router + // 5. Build the main router mux := proxy.NewRouter(cfg, provider) listen_address := fmt.Sprintf(":%d", cfg.ListenPort) - // 7. Start the server + // 6. Start the server srv := &http.Server{ Addr: listen_address, Handler: mux, @@ -117,18 +111,18 @@ func main() { } }() - // 8. Wait for shutdown signal + // 7. Wait for shutdown signal stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop logger.Info("Shutting down...") - // 9. First terminate subprocess if running + // 8. First terminate subprocess if running if procManager != nil && procManager.IsRunning() { procManager.Shutdown() } - // 10. Then shutdown the server + // 9. Then shutdown the server logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() @@ -137,30 +131,4 @@ func main() { logger.Error("HTTP server shutdown error: %v", err) } logger.Info("Stopped.") -} - -// Helper function to ensure a path is in a list -func ensurePathInList(paths *[]string, path string) { - // Check if path exists in the list - for _, p := range *paths { - if p == path { - return // Path already exists - } - } - // Path doesn't exist, add it - *paths = append(*paths, path) - logger.Info("Added path %s to MCPPaths", path) -} - -// Helper function to ensure an origin is in a list -func ensureOriginInList(origins *[]string, origin string) { - // Check if origin exists in the list - for _, o := range *origins { - if o == origin { - return // Origin already exists - } - } - // Origin doesn't exist, add it - *origins = append(*origins, origin) - logger.Info("Added %s to allowed CORS origins", origin) } \ No newline at end of file diff --git a/config.yaml b/config.yaml index e7e0537..29bf812 100644 --- a/config.yaml +++ b/config.yaml @@ -1,30 +1,31 @@ # config.yaml -transport_mode: "stdio" # Options: "sse" or "stdio" - -# For SSE mode, mcp_server_base_url and mcp_paths are required -# For stdio mode, both are optional and will be derived from command configuration if not specified -mcp_server_base_url: "http://localhost:8000" +# Common configuration for all transport modes listen_port: 8080 +base_url: "http://localhost:8000" # Base URL for the MCP server +port: 8000 # Port for the MCP server timeout_seconds: 10 -mcp_paths: # Required in SSE mode, ignored in stdio mode (derived from command) - - /messages/ - - /sse -# Subprocess configuration -command: +# Path configuration +paths: + sse: "/sse" # SSE endpoint path + messages: "/messages" # Messages endpoint path + +# Transport mode configuration +transport_mode: "sse" # Options: "sse" or "stdio" + +# stdio-specific configuration (used only when transport_mode is "stdio") +stdio: enabled: true - user_command: "npx -y @modelcontextprotocol/server-github" # User only needs to provide this part - base_url: "http://localhost:8000" # Will be used for CORS and in the full command - port: 8000 # Port for the MCP server - sse_path: "/sse" # SSE endpoint path - message_path: "/messages" # Messages endpoint path + user_command: "npx -y @modelcontextprotocol/server-github" work_dir: "" # Working directory (optional) - # env: # Environment variables (optional) + # env: # Environment variables (optional) # - "NODE_ENV=development" +# Path mapping (optional) path_mapping: +# CORS configuration cors: allowed_origins: - "http://localhost:5173" @@ -38,6 +39,7 @@ cors: - "Content-Type" allow_credentials: true +# Demo configuration for Asgardeo demo: org_name: "openmcpauthdemo" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" diff --git a/internal/config/config.go b/internal/config/config.go index c9e4118..3aba71b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,7 +15,21 @@ const ( StdioTransport TransportMode = "stdio" ) -// AsgardeoConfig groups all Asgardeo-specific fields +// Common path configuration for all transport modes +type PathsConfig struct { + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` +} + +// StdioConfig contains stdio-specific configuration +type StdioConfig struct { + Enabled bool `yaml:"enabled"` + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) + Args []string `yaml:"args,omitempty"` // Additional arguments + Env []string `yaml:"env,omitempty"` // Environment variables +} + type DemoConfig struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` @@ -70,123 +84,74 @@ type DefaultConfig struct { type Config struct { AuthServerBaseURL string - MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenPort int `yaml:"listen_port"` + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` JWKSURL string TimeoutSeconds int `yaml:"timeout_seconds"` - MCPPaths []string `yaml:"mcp_paths"` PathMapping map[string]string `yaml:"path_mapping"` Mode string `yaml:"mode"` CORSConfig CORSConfig `yaml:"cors"` TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"` Default DefaultConfig `yaml:"default"` - Command Command `yaml:"command"` // Command to run } -// Command struct with explicit configuration for all relevant paths -type Command struct { - Enabled bool `yaml:"enabled"` - UserCommand string `yaml:"user_command"` // Only the part provided by the user - BaseUrl string `yaml:"base_url"` // Base URL for the MCP server - Port int `yaml:"port"` // Port for the MCP server - SsePath string `yaml:"sse_path"` // SSE endpoint path - MessagePath string `yaml:"message_path"` // Messages endpoint path - WorkDir string `yaml:"work_dir"` // Working directory - Args []string `yaml:"args,omitempty"` // Additional arguments - Env []string `yaml:"env,omitempty"` // Environment variables -} - -// Validate checks if the command config is valid based on transport mode -func (c *Command) Validate(transportMode TransportMode) error { - if transportMode == StdioTransport { - if !c.Enabled { - return fmt.Errorf("command must be enabled in stdio transport mode") +// Validate checks if the config is valid based on transport mode +func (c *Config) Validate() error { + // Validate based on transport mode + if c.TransportMode == StdioTransport { + if !c.Stdio.Enabled { + return fmt.Errorf("stdio.enabled must be true in stdio transport mode") } - if c.UserCommand == "" { - return fmt.Errorf("user_command is required in stdio transport mode") + if c.Stdio.UserCommand == "" { + return fmt.Errorf("stdio.user_command is required in stdio transport mode") } } + + // Validate paths + if c.Paths.SSE == "" { + c.Paths.SSE = "/sse" // Default value + } + if c.Paths.Messages == "" { + c.Paths.Messages = "/messages" // Default value + } + + // Validate base URL + if c.BaseURL == "" { + if c.Port > 0 { + c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port) + } else { + c.BaseURL = "http://localhost:8000" // Default value + } + } + return nil } -// GetBaseURL returns the base URL for the MCP server -func (c *Command) GetBaseURL() string { - if c.BaseUrl != "" { - return c.BaseUrl - } - if c.Port > 0 { - return fmt.Sprintf("http://localhost:%d", c.Port) - } - return "http://localhost:8000" // default +// GetMCPPaths returns the list of paths that should be proxied to the MCP server +func (c *Config) GetMCPPaths() []string { + return []string{c.Paths.SSE, c.Paths.Messages} } -// GetPaths returns the SSE and message paths -func (c *Command) GetPaths() []string { - var paths []string - - // Add SSE path - ssePath := c.SsePath - if ssePath == "" { - ssePath = "/sse" // default - } - paths = append(paths, ssePath) - - // Add message path - messagePath := c.MessagePath - if messagePath == "" { - messagePath = "/messages" // default - } - paths = append(paths, messagePath) - - return paths -} - -// BuildExecCommand constructs the full command string for execution -func (c *Command) BuildExecCommand() string { - if c.UserCommand == "" { +// BuildExecCommand constructs the full command string for execution in stdio mode +func (c *Config) BuildExecCommand() string { + if c.Stdio.UserCommand == "" { return "" } - // Apply defaults if not specified - port := c.Port - if port == 0 { - port = 8000 - } - - baseUrl := c.BaseUrl - if baseUrl == "" { - baseUrl = fmt.Sprintf("http://localhost:%d", port) - } - - ssePath := c.SsePath - if ssePath == "" { - ssePath = "/sse" - } - - messagePath := c.MessagePath - if messagePath == "" { - messagePath = "/messages" - } - // Construct the full command return fmt.Sprintf( `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, - c.UserCommand, port, baseUrl, ssePath, messagePath, + c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, ) } -// GetExec returns the complete command string for execution -func (c *Command) GetExec() string { - if c.UserCommand == "" { - return "" - } - return c.BuildExecCommand() -} - // LoadConfig reads a YAML config file into Config struct. func LoadConfig(path string) (*Config, error) { f, err := os.Open(path) @@ -200,6 +165,8 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } + + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } @@ -208,26 +175,16 @@ func LoadConfig(path string) (*Config, error) { if cfg.TransportMode == "" { cfg.TransportMode = SSETransport // Default to SSE } + + // Set default port if not specified + if cfg.Port == 0 { + cfg.Port = 8000 // default + } - // Validate command config based on transport mode - if err := cfg.Command.Validate(cfg.TransportMode); err != nil { + // Validate the configuration + if err := cfg.Validate(); err != nil { return nil, err } - // In stdio mode, use command.base_url for MCPServerBaseURL if it's not explicitly set - if cfg.TransportMode == StdioTransport && cfg.MCPServerBaseURL == "" { - cfg.MCPServerBaseURL = cfg.Command.GetBaseURL() - } else if cfg.TransportMode == SSETransport && cfg.MCPServerBaseURL == "" { - return nil, fmt.Errorf("mcp_server_base_url is required in SSE transport mode") - } - - // In stdio mode, set the MCPPaths from the command configuration - if cfg.TransportMode == StdioTransport && cfg.Command.Enabled { - // Override MCPPaths with paths from command configuration - cfg.MCPPaths = cfg.Command.GetPaths() - } else if cfg.TransportMode == SSETransport && len(cfg.MCPPaths) == 0 { - return nil, fmt.Errorf("mcp_paths are required in SSE transport mode") - } - return &cfg, nil } \ No newline at end of file diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 898804a..089f470 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { } // MCP paths - for _, path := range cfg.MCPPaths { + mcpPaths := cfg.GetMCPPaths() + for _, path := range mcpPaths { mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } @@ -105,7 +106,8 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) logger.Error("Invalid auth server URL: %v", err) panic(err) // Fatal error that prevents startup } - mcpBase, err := url.Parse(cfg.MCPServerBaseURL) + + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { logger.Error("Invalid MCP server URL: %v", err) panic(err) // Fatal error that prevents startup @@ -113,11 +115,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Detect SSE paths from config ssePaths := make(map[string]bool) - for _, p := range cfg.MCPPaths { - if p == "/sse" { - ssePaths[p] = true - } - } + ssePaths[cfg.Paths.SSE] = true return func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") @@ -294,7 +292,8 @@ func isAuthPath(path string) bool { // isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { - for _, p := range cfg.MCPPaths { + mcpPaths := cfg.GetMCPPaths() + for _, p := range mcpPaths { if strings.HasPrefix(path, p) { return true } diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go index e667886..6d140b9 100644 --- a/internal/subprocess/manager.go +++ b/internal/subprocess/manager.go @@ -66,8 +66,8 @@ func (m *Manager) SetShutdownDelay(duration time.Duration) { m.shutdownDelay = duration } -// Start launches a subprocess based on the command configuration -func (m *Manager) Start(cmdConfig *config.Command) error { +// Start launches a subprocess based on the configuration +func (m *Manager) Start(cfg *config.Config) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -76,12 +76,12 @@ func (m *Manager) Start(cmdConfig *config.Command) error { return os.ErrExist } - if !cmdConfig.Enabled || cmdConfig.UserCommand == "" { + if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" { return nil // Nothing to start } // Get the full command string - execCommand := cmdConfig.GetExec() + execCommand := cfg.BuildExecCommand() if execCommand == "" { return nil // No command to execute } @@ -92,13 +92,13 @@ func (m *Manager) Start(cmdConfig *config.Command) error { cmd := exec.Command("sh", "-c", execCommand) // Set working directory if specified - if cmdConfig.WorkDir != "" { - cmd.Dir = cmdConfig.WorkDir + if cfg.Stdio.WorkDir != "" { + cmd.Dir = cfg.Stdio.WorkDir } // Set environment variables if specified - if len(cmdConfig.Env) > 0 { - cmd.Env = append(os.Environ(), cmdConfig.Env...) + if len(cfg.Stdio.Env) > 0 { + cmd.Env = append(os.Environ(), cfg.Stdio.Env...) } // Capture stdout/stderr