Add transport mode support for stdio, SSE stability fixes (#13)

Add transport mode support for stdio, SSE stability fixes
This commit is contained in:
Chiran Fernando 2025-04-08 12:46:00 +05:30 committed by GitHub
parent 6ce52261db
commit 32c9378aad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 808 additions and 142 deletions

View file

@ -9,6 +9,7 @@ import (
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// RequestModifier modifies requests before they are proxied
@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
// Parse form data
if err := req.ParseForm(); err != nil {
logger.Error("Failed to parse form data: %v", err)
return nil, err
}
@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Read body
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
logger.Error("Failed to read request body: %v", err)
return nil, err
}
// Parse JSON
var jsonData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
logger.Error("Failed to parse JSON body: %v", err)
return nil, err
}
@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Marshal back to JSON
modifiedBody, err := json.Marshal(jsonData)
if err != nil {
logger.Error("Failed to marshal modified JSON: %v", err)
return nil, err
}

View file

@ -2,7 +2,6 @@ package proxy
import (
"context"
"log"
"net/http"
"net/http/httputil"
"net/url"
@ -11,6 +10,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
}
// MCP paths
for _, path := range cfg.MCPPaths {
mcpPaths := cfg.GetMCPPaths()
for _, path := range mcpPaths {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true
}
@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
// Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil {
log.Fatalf("Invalid auth server URL: %v", err)
logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil {
log.Fatalf("Invalid MCP server URL: %v", err)
logger.Error("Invalid MCP server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
// Detect SSE paths from config
ssePaths := make(map[string]bool)
for _, p := range cfg.MCPPaths {
if p == "/sse" {
ssePaths[p] = true
}
}
ssePaths[cfg.Paths.SSE] = true
return func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Handle OPTIONS
if r.Method == http.MethodOptions {
if allowedOrigin == "" {
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
logger.Warn("Preflight request from disallowed origin: %s", origin)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
}
if allowedOrigin == "" {
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Validate JWT for MCP paths if required
// Placeholder for JWT validation logic
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
var err error
r, err = modifier.ModifyRequest(r)
if err != nil {
log.Printf("[proxy] Error modifying request: %v", err)
logger.Error("Error modifying request: %v", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
@ -192,7 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Host = targetURL.Host
cleanHeaders := http.Header{}
// Set proper origin header to match the target
if isSSE {
// For SSE, ensure origin matches the target
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
}
for k, v := range r.Header {
// Skip hop-by-hop headers
if skipHeader(k) {
@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Header = cleanHeaders
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
},
ModifyResponse: func(resp *http.Response) error {
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
return nil
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[proxy] Error proxying: %v", err)
logger.Error("Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
},
FlushInterval: -1, // immediate flush for SSE
}
if isSSE {
// Add special response handling for SSE connections to rewrite endpoint URLs
rp.Transport = &sseTransport{
Transport: http.DefaultTransport,
proxyHost: r.Host,
targetHost: targetURL.Host,
}
// Set SSE-specific headers
w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// Keep SSE connections open
HandleSSE(w, r, rp)
} else {
@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
}
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
if allowed == origin {
return allowed
}
@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
w.Header().Set("Vary", "Origin")
w.Header().Set("X-Accel-Buffering", "no")
}
func isAuthPath(path string) bool {
@ -273,7 +292,8 @@ func isAuthPath(path string) bool {
// isMCPPath checks if the path is an MCP path
func isMCPPath(path string, cfg *config.Config) bool {
for _, p := range cfg.MCPPaths {
mcpPaths := cfg.GetMCPPaths()
for _, p := range mcpPaths {
if strings.HasPrefix(path, p) {
return true
}

View file

@ -1,11 +1,16 @@
package proxy
import (
"bufio"
"context"
"log"
"fmt"
"io"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// HandleSSE sets up a go-routine to wait for context cancellation
@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
go func() {
<-ctx.Done()
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
close(done)
}()
@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
type sseTransport struct {
Transport http.RoundTripper
proxyHost string
targetHost string
}
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Call the underlying transport
resp, err := t.Transport.RoundTrip(req)
if err != nil {
return nil, err
}
// Check if this is an SSE response
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
return resp, nil
}
logger.Info("Intercepting SSE response to modify endpoint events")
// Create a response wrapper that modifies the response body
originalBody := resp.Body
pr, pw := io.Pipe()
go func() {
defer originalBody.Close()
defer pw.Close()
scanner := bufio.NewScanner(originalBody)
for scanner.Scan() {
line := scanner.Text()
// Check if this line contains an endpoint event
if strings.HasPrefix(line, "event: endpoint") {
// Read the data line
if scanner.Scan() {
dataLine := scanner.Text()
if strings.HasPrefix(dataLine, "data: ") {
// Extract the endpoint URL
endpoint := strings.TrimPrefix(dataLine, "data: ")
// Replace the host in the endpoint
logger.Debug("Original endpoint: %s", endpoint)
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
logger.Debug("Modified endpoint: %s", endpoint)
// Write the modified event lines
fmt.Fprintln(pw, line)
fmt.Fprintln(pw, "data: "+endpoint)
continue
}
}
}
// Write the original line for non-endpoint events
fmt.Fprintln(pw, line)
}
if err := scanner.Err(); err != nil {
logger.Error("Error reading SSE stream: %v", err)
}
}()
// Replace the response body with our modified pipe
resp.Body = pr
return resp, nil
}