Add StreambleHTTP support (#35)
Some checks failed
Go CI / Test (push) Failing after 47s
Go CI / Build (push) Successful in 47s

* Add StreambleHTTP support
This commit is contained in:
Thilina Shashimal Senarath 2025-05-27 13:27:02 +05:30 committed by GitHub
parent fc0d939e16
commit 316370be1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 37 additions and 41 deletions

View file

@ -2,14 +2,15 @@
# Common configuration for all transport modes # Common configuration for all transport modes
listen_port: 8080 listen_port: 8080
base_url: "http://localhost:8000" # Base URL for the MCP server base_url: "http://localhost:3001" # Base URL for the MCP server
port: 8000 # Port for the MCP server port: 3001 # Port for the MCP server
timeout_seconds: 10 timeout_seconds: 10
# Path configuration # Path configuration
paths: paths:
sse: "/sse" # SSE endpoint path sse: "/sse" # SSE endpoint path
messages: "/messages/" # Messages endpoint path messages: "/messages/" # Messages endpoint path
streamable_http: "/mcp" # MCP endpoint path
# Transport mode configuration # Transport mode configuration
transport_mode: "sse" # Options: "sse" or "stdio" transport_mode: "sse" # Options: "sse" or "stdio"
@ -28,7 +29,7 @@ path_mapping:
# CORS configuration # CORS configuration
cors: cors:
allowed_origins: allowed_origins:
- "http://localhost:5173" - "http://127.0.0.1:6274"
allowed_methods: allowed_methods:
- "GET" - "GET"
- "POST" - "POST"

View file

@ -19,15 +19,16 @@ const (
// Common path configuration for all transport modes // Common path configuration for all transport modes
type PathsConfig struct { type PathsConfig struct {
SSE string `yaml:"sse"` SSE string `yaml:"sse"`
Messages string `yaml:"messages"` Messages string `yaml:"messages"`
StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests
} }
// StdioConfig contains stdio-specific configuration // StdioConfig contains stdio-specific configuration
type StdioConfig struct { type StdioConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
UserCommand string `yaml:"user_command"` // The command provided by the user UserCommand string `yaml:"user_command"` // The command provided by the user
WorkDir string `yaml:"work_dir"` // Working directory (optional) WorkDir string `yaml:"work_dir"` // Working directory (optional)
Args []string `yaml:"args,omitempty"` // Additional arguments Args []string `yaml:"args,omitempty"` // Additional arguments
Env []string `yaml:"env,omitempty"` // Environment variables Env []string `yaml:"env,omitempty"` // Environment variables
} }
@ -85,18 +86,18 @@ type DefaultConfig struct {
} }
type Config struct { type Config struct {
AuthServerBaseURL string AuthServerBaseURL string
ListenPort int `yaml:"listen_port"` ListenPort int `yaml:"listen_port"`
BaseURL string `yaml:"base_url"` BaseURL string `yaml:"base_url"`
Port int `yaml:"port"` Port int `yaml:"port"`
JWKSURL string JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"` TimeoutSeconds int `yaml:"timeout_seconds"`
PathMapping map[string]string `yaml:"path_mapping"` PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"` Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"` CORSConfig CORSConfig `yaml:"cors"`
TransportMode TransportMode `yaml:"transport_mode"` TransportMode TransportMode `yaml:"transport_mode"`
Paths PathsConfig `yaml:"paths"` Paths PathsConfig `yaml:"paths"`
Stdio StdioConfig `yaml:"stdio"` Stdio StdioConfig `yaml:"stdio"`
// Nested config for Asgardeo // Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"` Demo DemoConfig `yaml:"demo"`
@ -138,7 +139,7 @@ func (c *Config) Validate() error {
// GetMCPPaths returns the list of paths that should be proxied to the MCP server // GetMCPPaths returns the list of paths that should be proxied to the MCP server
func (c *Config) GetMCPPaths() []string { func (c *Config) GetMCPPaths() []string {
return []string{c.Paths.SSE, c.Paths.Messages} return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP}
} }
// BuildExecCommand constructs the full command string for execution in stdio mode // BuildExecCommand constructs the full command string for execution in stdio mode
@ -147,7 +148,6 @@ func (c *Config) BuildExecCommand() string {
return "" return ""
} }
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
// For Windows, we need to properly escape the inner command // For Windows, we need to properly escape the inner command
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`) escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
@ -176,12 +176,12 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil { if err := decoder.Decode(&cfg); err != nil {
return nil, err return nil, err
} }
// Set default values // Set default values
if cfg.TimeoutSeconds == 0 { if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default cfg.TimeoutSeconds = 15 // default
} }
// Set default transport mode if not specified // Set default transport mode if not specified
if cfg.TransportMode == "" { if cfg.TransportMode == "" {
cfg.TransportMode = SSETransport // Default to SSE cfg.TransportMode = SSETransport // Default to SSE
@ -191,11 +191,11 @@ func LoadConfig(path string) (*Config, error) {
if cfg.Port == 0 { if cfg.Port == 0 {
cfg.Port = 8000 // default cfg.Port = 8000 // default
} }
// Validate the configuration // Validate the configuration
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, err return nil, err
} }
return &cfg, nil return &cfg, nil
} }

View file

@ -136,20 +136,15 @@ func TestValidate(t *testing.T) {
func TestGetMCPPaths(t *testing.T) { func TestGetMCPPaths(t *testing.T) {
cfg := Config{ cfg := Config{
Paths: PathsConfig{ Paths: PathsConfig{
SSE: "/custom-sse", SSE: "/custom-sse",
Messages: "/custom-messages", Messages: "/custom-messages",
StreamableHTTP: "/custom-streamable",
}, },
} }
paths := cfg.GetMCPPaths() paths := cfg.GetMCPPaths()
if len(paths) != 2 { if len(paths) != 3 {
t.Errorf("Expected 2 MCP paths, got %d", len(paths)) t.Errorf("Expected 3 MCP paths, got %d", len(paths))
}
if paths[0] != "/custom-sse" {
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
}
if paths[1] != "/custom-messages" {
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
} }
} }

View file

@ -10,7 +10,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util" "github.com/wso2/open-mcp-auth-proxy/internal/util"
) )
@ -106,7 +106,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
logger.Error("Invalid auth server URL: %v", err) logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup panic(err) // Fatal error that prevents startup
} }
mcpBase, err := url.Parse(cfg.BaseURL) mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil { if err != nil {
logger.Error("Invalid MCP server URL: %v", err) logger.Error("Invalid MCP server URL: %v", err)
@ -191,13 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Host = targetURL.Host req.Host = targetURL.Host
cleanHeaders := http.Header{} cleanHeaders := http.Header{}
// Set proper origin header to match the target // Set proper origin header to match the target
if isSSE { if isSSE {
// For SSE, ensure origin matches the target // For SSE, ensure origin matches the target
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
} }
for k, v := range r.Header { for k, v := range r.Header {
// Skip hop-by-hop headers // Skip hop-by-hop headers
if skipHeader(k) { if skipHeader(k) {
@ -231,12 +231,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
proxyHost: r.Host, proxyHost: r.Host,
targetHost: targetURL.Host, targetHost: targetURL.Host,
} }
// Set SSE-specific headers // Set SSE-specific headers
w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive") w.Header().Set("Connection", "keep-alive")
// Keep SSE connections open // Keep SSE connections open
HandleSSE(w, r, rp) HandleSSE(w, r, rp)
} else { } else {