From 316370be1c0fd7204d2ca7af7e9d5dffe7c5eb9b Mon Sep 17 00:00:00 2001 From: Thilina Shashimal Senarath <43197743+shashimalcse@users.noreply.github.com> Date: Tue, 27 May 2025 13:27:02 +0530 Subject: [PATCH] Add StreambleHTTP support (#35) * Add StreambleHTTP support --- config.yaml | 7 +++--- internal/config/config.go | 44 +++++++++++++++++----------------- internal/config/config_test.go | 15 ++++-------- internal/proxy/proxy.go | 12 +++++----- 4 files changed, 37 insertions(+), 41 deletions(-) diff --git a/config.yaml b/config.yaml index 5621195..427fc15 100644 --- a/config.yaml +++ b/config.yaml @@ -2,14 +2,15 @@ # Common configuration for all transport modes listen_port: 8080 -base_url: "http://localhost:8000" # Base URL for the MCP server -port: 8000 # Port for the MCP server +base_url: "http://localhost:3001" # Base URL for the MCP server +port: 3001 # Port for the MCP server timeout_seconds: 10 # Path configuration paths: sse: "/sse" # SSE endpoint path messages: "/messages/" # Messages endpoint path + streamable_http: "/mcp" # MCP endpoint path # Transport mode configuration transport_mode: "sse" # Options: "sse" or "stdio" @@ -28,7 +29,7 @@ path_mapping: # CORS configuration cors: allowed_origins: - - "http://localhost:5173" + - "http://127.0.0.1:6274" allowed_methods: - "GET" - "POST" diff --git a/internal/config/config.go b/internal/config/config.go index c50d9ed..c51688f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,15 +19,16 @@ const ( // Common path configuration for all transport modes type PathsConfig struct { - SSE string `yaml:"sse"` - Messages string `yaml:"messages"` + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` + StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests } // StdioConfig contains stdio-specific configuration type StdioConfig struct { Enabled bool `yaml:"enabled"` - UserCommand string `yaml:"user_command"` // The command provided by the user - WorkDir string `yaml:"work_dir"` // Working directory (optional) + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) Args []string `yaml:"args,omitempty"` // Additional arguments Env []string `yaml:"env,omitempty"` // Environment variables } @@ -85,18 +86,18 @@ type DefaultConfig struct { } type Config struct { - AuthServerBaseURL string - ListenPort int `yaml:"listen_port"` - BaseURL string `yaml:"base_url"` - Port int `yaml:"port"` - JWKSURL string - TimeoutSeconds int `yaml:"timeout_seconds"` - PathMapping map[string]string `yaml:"path_mapping"` - Mode string `yaml:"mode"` - CORSConfig CORSConfig `yaml:"cors"` - TransportMode TransportMode `yaml:"transport_mode"` - Paths PathsConfig `yaml:"paths"` - Stdio StdioConfig `yaml:"stdio"` + AuthServerBaseURL string + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` + TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` @@ -138,7 +139,7 @@ func (c *Config) Validate() error { // GetMCPPaths returns the list of paths that should be proxied to the MCP server func (c *Config) GetMCPPaths() []string { - return []string{c.Paths.SSE, c.Paths.Messages} + return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP} } // BuildExecCommand constructs the full command string for execution in stdio mode @@ -147,7 +148,6 @@ func (c *Config) BuildExecCommand() string { return "" } - if runtime.GOOS == "windows" { // For Windows, we need to properly escape the inner command escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`) @@ -176,12 +176,12 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } - + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } - + // Set default transport mode if not specified if cfg.TransportMode == "" { cfg.TransportMode = SSETransport // Default to SSE @@ -191,11 +191,11 @@ func LoadConfig(path string) (*Config, error) { if cfg.Port == 0 { cfg.Port = 8000 // default } - + // Validate the configuration if err := cfg.Validate(); err != nil { return nil, err } - + return &cfg, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 20c0893..edf4182 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -136,20 +136,15 @@ func TestValidate(t *testing.T) { func TestGetMCPPaths(t *testing.T) { cfg := Config{ Paths: PathsConfig{ - SSE: "/custom-sse", - Messages: "/custom-messages", + SSE: "/custom-sse", + Messages: "/custom-messages", + StreamableHTTP: "/custom-streamable", }, } paths := cfg.GetMCPPaths() - if len(paths) != 2 { - t.Errorf("Expected 2 MCP paths, got %d", len(paths)) - } - if paths[0] != "/custom-sse" { - t.Errorf("Expected first path=/custom-sse, got %s", paths[0]) - } - if paths[1] != "/custom-messages" { - t.Errorf("Expected second path=/custom-messages, got %s", paths[1]) + if len(paths) != 3 { + t.Errorf("Expected 3 MCP paths, got %d", len(paths)) } } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 33a9ea3..f4d0dec 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -10,7 +10,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -106,7 +106,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) logger.Error("Invalid auth server URL: %v", err) panic(err) // Fatal error that prevents startup } - + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { logger.Error("Invalid MCP server URL: %v", err) @@ -191,13 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Host = targetURL.Host cleanHeaders := http.Header{} - + // Set proper origin header to match the target if isSSE { // For SSE, ensure origin matches the target req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) } - + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -231,12 +231,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) proxyHost: r.Host, targetHost: targetURL.Host, } - + // Set SSE-specific headers w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - + // Keep SSE connections open HandleSSE(w, r, rp) } else {