Refactor configurations

This commit is contained in:
Chiran Fernando 2025-04-05 23:47:00 +05:30
parent 61d3c7e7e1
commit 5c1cc13ff3
6 changed files with 167 additions and 247 deletions

View file

@ -37,43 +37,37 @@ func main() {
// Override transport mode if stdio flag is set
if *stdioMode {
cfg.TransportMode = config.StdioTransport
// Validate command config for stdio mode
if err := cfg.Command.Validate(cfg.TransportMode); err != nil {
// Ensure stdio is enabled
cfg.Stdio.Enabled = true
// Re-validate config
if err := cfg.Validate(); err != nil {
logger.Error("Configuration error: %v", err)
os.Exit(1)
}
}
// 2. Ensure MCPPaths are properly configured
if cfg.TransportMode == config.StdioTransport && cfg.Command.Enabled {
// Use command.base_url for MCPServerBaseURL in stdio mode
cfg.MCPServerBaseURL = cfg.Command.GetBaseURL()
// Use command paths for MCPPaths in stdio mode
cfg.MCPPaths = cfg.Command.GetPaths()
logger.Info("Using MCP server baseUrl: %s", cfg.MCPServerBaseURL)
logger.Info("Using MCP paths: %v", cfg.MCPPaths)
}
logger.Info("Using transport mode: %s", cfg.TransportMode)
logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
// 3. Start subprocess if configured and in stdio mode
// 2. Start subprocess if configured and in stdio mode
var procManager *subprocess.Manager
if cfg.TransportMode == config.StdioTransport && cfg.Command.Enabled && cfg.Command.UserCommand != "" {
if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
// Ensure all required dependencies are available
if err := subprocess.EnsureDependenciesAvailable(cfg.Command.UserCommand); err != nil {
if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
logger.Warn("%v", err)
logger.Warn("Subprocess may fail to start due to missing dependencies")
}
procManager = subprocess.NewManager()
if err := procManager.Start(&cfg.Command); err != nil {
if err := procManager.Start(cfg); err != nil {
logger.Warn("Failed to start subprocess: %v", err)
}
} else if cfg.TransportMode == config.SSETransport {
logger.Info("Using SSE transport mode, not starting subprocess")
}
// 4. Create the chosen provider
// 3. Create the chosen provider
var provider authz.Provider
if *demoMode {
cfg.Mode = "demo"
@ -92,18 +86,18 @@ func main() {
provider = authz.NewDefaultProvider(cfg)
}
// 5. (Optional) Fetch JWKS if you want local JWT validation
// 4. (Optional) Fetch JWKS if you want local JWT validation
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
logger.Error("Failed to fetch JWKS: %v", err)
os.Exit(1)
}
// 6. Build the main router
// 5. Build the main router
mux := proxy.NewRouter(cfg, provider)
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
// 7. Start the server
// 6. Start the server
srv := &http.Server{
Addr: listen_address,
Handler: mux,
@ -117,18 +111,18 @@ func main() {
}
}()
// 8. Wait for shutdown signal
// 7. Wait for shutdown signal
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
logger.Info("Shutting down...")
// 9. First terminate subprocess if running
// 8. First terminate subprocess if running
if procManager != nil && procManager.IsRunning() {
procManager.Shutdown()
}
// 10. Then shutdown the server
// 9. Then shutdown the server
logger.Info("Shutting down HTTP server...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel()
@ -137,30 +131,4 @@ func main() {
logger.Error("HTTP server shutdown error: %v", err)
}
logger.Info("Stopped.")
}
// Helper function to ensure a path is in a list
func ensurePathInList(paths *[]string, path string) {
// Check if path exists in the list
for _, p := range *paths {
if p == path {
return // Path already exists
}
}
// Path doesn't exist, add it
*paths = append(*paths, path)
logger.Info("Added path %s to MCPPaths", path)
}
// Helper function to ensure an origin is in a list
func ensureOriginInList(origins *[]string, origin string) {
// Check if origin exists in the list
for _, o := range *origins {
if o == origin {
return // Origin already exists
}
}
// Origin doesn't exist, add it
*origins = append(*origins, origin)
logger.Info("Added %s to allowed CORS origins", origin)
}