mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Add transport mode support for stdio, SSE stability fixes (#13)
Add transport mode support for stdio, SSE stability fixes
This commit is contained in:
parent
6ce52261db
commit
32c9378aad
12 changed files with 808 additions and 142 deletions
|
@ -3,31 +3,71 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
func main() {
|
||||
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
||||
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
|
||||
debugMode := flag.Bool("debug", false, "Enable debug logging")
|
||||
stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE")
|
||||
flag.Parse()
|
||||
|
||||
logger.SetDebug(*debugMode)
|
||||
|
||||
// 1. Load config
|
||||
cfg, err := config.LoadConfig("config.yaml")
|
||||
if err != nil {
|
||||
log.Fatalf("Error loading config: %v", err)
|
||||
logger.Error("Error loading config: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 2. Create the chosen provider
|
||||
// Override transport mode if stdio flag is set
|
||||
if *stdioMode {
|
||||
cfg.TransportMode = config.StdioTransport
|
||||
// Ensure stdio is enabled
|
||||
cfg.Stdio.Enabled = true
|
||||
// Re-validate config
|
||||
if err := cfg.Validate(); err != nil {
|
||||
logger.Error("Configuration error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Using transport mode: %s", cfg.TransportMode)
|
||||
logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
|
||||
logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
|
||||
|
||||
// 2. Start subprocess if configured and in stdio mode
|
||||
var procManager *subprocess.Manager
|
||||
if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
|
||||
// Ensure all required dependencies are available
|
||||
if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
|
||||
logger.Warn("%v", err)
|
||||
logger.Warn("Subprocess may fail to start due to missing dependencies")
|
||||
}
|
||||
|
||||
procManager = subprocess.NewManager()
|
||||
if err := procManager.Start(cfg); err != nil {
|
||||
logger.Warn("Failed to start subprocess: %v", err)
|
||||
}
|
||||
} else if cfg.TransportMode == config.SSETransport {
|
||||
logger.Info("Using SSE transport mode, not starting subprocess")
|
||||
}
|
||||
|
||||
// 3. Create the chosen provider
|
||||
var provider authz.Provider
|
||||
if *demoMode {
|
||||
cfg.Mode = "demo"
|
||||
|
@ -46,41 +86,49 @@ func main() {
|
|||
provider = authz.NewDefaultProvider(cfg)
|
||||
}
|
||||
|
||||
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
||||
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||
log.Fatalf("Failed to fetch JWKS: %v", err)
|
||||
logger.Error("Failed to fetch JWKS: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 4. Build the main router
|
||||
// 5. Build the main router
|
||||
mux := proxy.NewRouter(cfg, provider)
|
||||
|
||||
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||
|
||||
// 5. Start the server
|
||||
// 6. Start the server
|
||||
srv := &http.Server{
|
||||
|
||||
Addr: listen_address,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("Server listening on %s", listen_address)
|
||||
logger.Info("Server listening on %s", listen_address)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
logger.Error("Server error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// 6. Graceful shutdown on Ctrl+C
|
||||
// 7. Wait for shutdown signal
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt)
|
||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||
<-stop
|
||||
log.Println("Shutting down...")
|
||||
logger.Info("Shutting down...")
|
||||
|
||||
// 8. First terminate subprocess if running
|
||||
if procManager != nil && procManager.IsRunning() {
|
||||
procManager.Shutdown()
|
||||
}
|
||||
|
||||
// 9. Then shutdown the server
|
||||
logger.Info("Shutting down HTTP server...")
|
||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
logger.Error("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
log.Println("Stopped.")
|
||||
logger.Info("Stopped.")
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue