mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Add transport mode support for stdio, SSE stability fixes (#13)
Add transport mode support for stdio, SSE stability fixes
This commit is contained in:
parent
6ce52261db
commit
32c9378aad
12 changed files with 808 additions and 142 deletions
|
@ -2,7 +2,6 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
@ -11,6 +10,7 @@ import (
|
|||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
|
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
}
|
||||
|
||||
// MCP paths
|
||||
for _, path := range cfg.MCPPaths {
|
||||
mcpPaths := cfg.GetMCPPaths()
|
||||
for _, path := range mcpPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
|
@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
|
||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
||||
// Parse the base URLs up front
|
||||
|
||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid auth server URL: %v", err)
|
||||
logger.Error("Invalid auth server URL: %v", err)
|
||||
panic(err) // Fatal error that prevents startup
|
||||
}
|
||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||
|
||||
mcpBase, err := url.Parse(cfg.BaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||
logger.Error("Invalid MCP server URL: %v", err)
|
||||
panic(err) // Fatal error that prevents startup
|
||||
}
|
||||
|
||||
// Detect SSE paths from config
|
||||
ssePaths := make(map[string]bool)
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if p == "/sse" {
|
||||
ssePaths[p] = true
|
||||
}
|
||||
}
|
||||
ssePaths[cfg.Paths.SSE] = true
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
// Handle OPTIONS
|
||||
if r.Method == http.MethodOptions {
|
||||
if allowedOrigin == "" {
|
||||
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
|
||||
logger.Warn("Preflight request from disallowed origin: %s", origin)
|
||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
}
|
||||
|
||||
if allowedOrigin == "" {
|
||||
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
// Validate JWT for MCP paths if required
|
||||
// Placeholder for JWT validation logic
|
||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
var err error
|
||||
r, err = modifier.ModifyRequest(r)
|
||||
if err != nil {
|
||||
log.Printf("[proxy] Error modifying request: %v", err)
|
||||
logger.Error("Error modifying request: %v", err)
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
@ -192,7 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
req.Host = targetURL.Host
|
||||
|
||||
cleanHeaders := http.Header{}
|
||||
|
||||
|
||||
// Set proper origin header to match the target
|
||||
if isSSE {
|
||||
// For SSE, ensure origin matches the target
|
||||
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
||||
}
|
||||
|
||||
for k, v := range r.Header {
|
||||
// Skip hop-by-hop headers
|
||||
if skipHeader(k) {
|
||||
|
@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
|
||||
req.Header = cleanHeaders
|
||||
|
||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||
},
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[proxy] Error proxying: %v", err)
|
||||
logger.Error("Error proxying: %v", err)
|
||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
FlushInterval: -1, // immediate flush for SSE
|
||||
}
|
||||
|
||||
if isSSE {
|
||||
// Add special response handling for SSE connections to rewrite endpoint URLs
|
||||
rp.Transport = &sseTransport{
|
||||
Transport: http.DefaultTransport,
|
||||
proxyHost: r.Host,
|
||||
targetHost: targetURL.Host,
|
||||
}
|
||||
|
||||
// Set SSE-specific headers
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Keep SSE connections open
|
||||
HandleSSE(w, r, rp)
|
||||
} else {
|
||||
|
@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
|||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||
}
|
||||
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
|
||||
if allowed == origin {
|
||||
return allowed
|
||||
}
|
||||
|
@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
|||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func isAuthPath(path string) bool {
|
||||
|
@ -273,7 +292,8 @@ func isAuthPath(path string) bool {
|
|||
|
||||
// isMCPPath checks if the path is an MCP path
|
||||
func isMCPPath(path string, cfg *config.Config) bool {
|
||||
for _, p := range cfg.MCPPaths {
|
||||
mcpPaths := cfg.GetMCPPaths()
|
||||
for _, p := range mcpPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue