diff --git a/README.md b/README.md index b587878..258f848 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,8 @@ python3 echo_server.py Update the following parameters in `config.yaml`. +### demo mode configuration: + ```yaml mcp_server_base_url: "http://localhost:8000" # URL of your MCP server listen_address: ":8080" # Address where the proxy will listen diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 9a4b472..cde3cf3 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -11,12 +11,14 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) func main() { demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") + asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).") flag.Parse() // 1. Load config @@ -28,12 +30,20 @@ func main() { // 2. Create the chosen provider var provider authz.Provider if *demoMode { - cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" - cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" + cfg.Mode = "demo" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks" + provider = authz.NewAsgardeoProvider(cfg) + } else if *asgardeoMode { + cfg.Mode = "asgardeo" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) - fmt.Println("Using Asgardeo provider (demo).") } else { - log.Fatalf("Not supported yet.") + cfg.Mode = "default" + cfg.JWKSURL = cfg.Default.JWKSURL + cfg.AuthServerBaseURL = cfg.Default.BaseURL + provider = authz.NewDefaultProvider(cfg) } // 3. (Optional) Fetch JWKS if you want local JWT validation @@ -44,14 +54,17 @@ func main() { // 4. Build the main router mux := proxy.NewRouter(cfg, provider) + listen_address := fmt.Sprintf(":%d", cfg.ListenPort) + // 5. Start the server srv := &http.Server{ - Addr: cfg.ListenAddress, + + Addr: listen_address, Handler: mux, } go func() { - log.Printf("Server listening on %s", cfg.ListenAddress) + log.Printf("Server listening on %s", listen_address) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Server error: %v", err) } diff --git a/config.yaml b/config.yaml index 9725f14..0b0ade4 100644 --- a/config.yaml +++ b/config.yaml @@ -1,9 +1,7 @@ # config.yaml -auth_server_base_url: "" -mcp_server_base_url: "http://localhost:8000" -listen_address: ":8080" -jwks_url: "" +mcp_server_base_url: "" +listen_port: 8080 timeout_seconds: 10 mcp_paths: @@ -11,8 +9,63 @@ mcp_paths: - /sse path_mapping: + /token: /token + /register: /register + /authorize: /authorize + /.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server + +cors: + allowed_origins: + - "" + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true demo: org_name: "openmcpauthdemo" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +asgardeo: + org_name: "" + client_id: "" + client_secret: "" + +default: + base_url: "" + jwks_url: "" + path: + /.well-known/oauth-authorization-server: + response: + issuer: "" + jwks_uri: "" + authorization_endpoint: "" # Optional + token_endpoint: "" # Optional + registration_endpoint: "" # Optional + response_types_supported: + - "code" + grant_types_supported: + - "authorization_code" + - "refresh_token" + code_challenge_methods_supported: + - "S256" + - "plain" + /authroize: + addQueryParams: + - name: "" + value: "" + /token: + addBodyParams: + - name: "" + value: "" + /register: + addBodyParams: + - name: "" + value: "" + diff --git a/internal/authz/default.go b/internal/authz/default.go new file mode 100644 index 0000000..9230d39 --- /dev/null +++ b/internal/authz/default.go @@ -0,0 +1,94 @@ +package authz + +import ( + "encoding/json" + "net/http" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +type defaultProvider struct { + cfg *config.Config +} + +// NewDefaultProvider initializes a Provider for Asgardeo (demo mode). +func NewDefaultProvider(cfg *config.Config) Provider { + return &defaultProvider{cfg: cfg} +} + +func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check if we have a custom response configuration + if p.cfg.Default.Path != nil { + pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && pathConfig.Response != nil { + // Use configured response values + responseConfig := pathConfig.Response + + // Get current host for proxy endpoints + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { + scheme = forwardedProto + } + host := r.Host + if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + baseURL := scheme + "://" + host + + authorizationEndpoint := responseConfig.AuthorizationEndpoint + if authorizationEndpoint == "" { + authorizationEndpoint = baseURL + "/authorize" + } + tokenEndpoint := responseConfig.TokenEndpoint + if tokenEndpoint == "" { + tokenEndpoint = baseURL + "/token" + } + registraionEndpoint := responseConfig.RegistrationEndpoint + if registraionEndpoint == "" { + registraionEndpoint = baseURL + "/register" + } + + // Build response from config + response := map[string]interface{}{ + "issuer": responseConfig.Issuer, + "authorization_endpoint": authorizationEndpoint, + "token_endpoint": tokenEndpoint, + "jwks_uri": responseConfig.JwksURI, + "response_types_supported": responseConfig.ResponseTypesSupported, + "grant_types_supported": responseConfig.GrantTypesSupported, + "token_endpoint_auth_methods_supported": []string{"client_secret_basic"}, + "registration_endpoint": registraionEndpoint, + "code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + } + return + } + } + } +} + +func (p *defaultProvider) RegisterHandler() http.HandlerFunc { + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 3a3b231..01c3a6f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,17 +13,67 @@ type DemoConfig struct { OrgName string `yaml:"org_name"` } +type AsgardeoConfig struct { + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + OrgName string `yaml:"org_name"` +} + +type CORSConfig struct { + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + AllowCredentials bool `yaml:"allow_credentials"` +} + +type ParamConfig struct { + Name string `yaml:"name"` + Value string `yaml:"value"` +} + +type ResponseConfig struct { + Issuer string `yaml:"issuer,omitempty"` + JwksURI string `yaml:"jwks_uri,omitempty"` + AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"` + TokenEndpoint string `yaml:"token_endpoint,omitempty"` + RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"` + ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"` + GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"` + CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` +} + +type PathConfig struct { + // For well-known endpoint + Response *ResponseConfig `yaml:"response,omitempty"` + + // For authorization endpoint + AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"` + + // For token and register endpoints + AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"` +} + +type DefaultConfig struct { + BaseURL string `yaml:"base_url,omitempty"` + Path map[string]PathConfig `yaml:"path,omitempty"` + JWKSURL string `yaml:"jwks_url,omitempty"` +} + type Config struct { - AuthServerBaseURL string `yaml:"auth_server_base_url"` - MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenAddress string `yaml:"listen_address"` - JWKSURL string `yaml:"jwks_url"` + AuthServerBaseURL string + MCPServerBaseURL string `yaml:"mcp_server_base_url"` + ListenPort int `yaml:"listen_port"` + JWKSURL string TimeoutSeconds int `yaml:"timeout_seconds"` MCPPaths []string `yaml:"mcp_paths"` PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` // Nested config for Asgardeo - Demo DemoConfig `yaml:"demo"` + Demo DemoConfig `yaml:"demo"` + Asgardeo AsgardeoConfig `yaml:"asgardeo"` + Default DefaultConfig `yaml:"default"` } // LoadConfig reads a YAML config file into Config struct. diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..1e5808e --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,7 @@ +package constants + +// Package constant provides constants for the MCP Auth Proxy + +const ( + ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" +) diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go new file mode 100644 index 0000000..8e2268b --- /dev/null +++ b/internal/proxy/modifier.go @@ -0,0 +1,199 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +// RequestModifier modifies requests before they are proxied +type RequestModifier interface { + ModifyRequest(req *http.Request) (*http.Request, error) +} + +// AuthorizationModifier adds parameters to authorization requests +type AuthorizationModifier struct { + Config *config.Config +} + +// TokenModifier adds parameters to token requests +type TokenModifier struct { + Config *config.Config +} + +type RegisterModifier struct { + Config *config.Config +} + +// ModifyRequest adds configured parameters to authorization requests +func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/authorize"] + if !exists || len(pathConfig.AddQueryParams) == 0 { + return req, nil + } + // Get current query parameters + query := req.URL.Query() + + // Add parameters from config + for _, param := range pathConfig.AddQueryParams { + query.Set(param.Name, param.Value) + } + + // Update the request URL + req.URL.RawQuery = query.Encode() + + return req, nil +} + +// ModifyRequest adds configured parameters to token requests +func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/token"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} + +func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/register"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 382c8f3..c999be4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,34 +20,87 @@ import ( func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { mux := http.NewServeMux() - // 1. Custom well-known - mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + modifiers := map[string]RequestModifier{ + "/authorize": &AuthorizationModifier{Config: cfg}, + "/token": &TokenModifier{Config: cfg}, + "/register": &RegisterModifier{Config: cfg}, + } - // 2. Registration - mux.HandleFunc("/register", provider.RegisterHandler()) + registeredPaths := make(map[string]bool) - // 3. Default "auth" paths, proxied - defaultPaths := []string{"/authorize", "/token"} + var defaultPaths []string + + // Handle based on mode configuration + if cfg.Mode == "demo" || cfg.Mode == "asgardeo" { + // Demo/Asgardeo mode: Custom handlers for well-known and register + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + + mux.HandleFunc("/register", provider.RegisterHandler()) + registeredPaths["/register"] = true + + // Authorize and token will be proxied with parameter modification + defaultPaths = []string{"/authorize", "/token"} + } else { + // Default provider mode + if cfg.Default.Path != nil { + // Check if we have custom response for well-known + wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && wellKnownConfig.Response != nil { + // If there's a custom response defined, use our handler + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + } else { + // No custom response, add well-known to proxy paths + defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") + } + + defaultPaths = append(defaultPaths, "/authorize") + defaultPaths = append(defaultPaths, "/token") + defaultPaths = append(defaultPaths, "/register") + } else { + defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + } + } + + // Remove duplicates from defaultPaths + uniquePaths := make(map[string]bool) + cleanPaths := []string{} for _, path := range defaultPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) + if !uniquePaths[path] { + uniquePaths[path] = true + cleanPaths = append(cleanPaths, path) + } + } + defaultPaths = cleanPaths + + for _, path := range defaultPaths { + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true + } } - // 4. MCP paths + // MCP paths for _, path := range cfg.MCPPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true } - // 5. If you want to map additional paths from config.PathMapping - // to the same proxy logic: + // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { - mux.HandleFunc(path, buildProxyHandler(cfg)) + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true + } } return mux } -func buildProxyHandler(cfg *config.Config) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { // Parse the base URLs up front + authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { log.Fatalf("Invalid auth server URL: %v", err) @@ -57,13 +110,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { log.Fatalf("Invalid MCP server URL: %v", err) } - // We'll define sets for known auth paths, SSE paths, etc. - authPaths := map[string]bool{ - "/authorize": true, - "/token": true, - "/.well-known/oauth-authorization-server": true, - } - // Detect SSE paths from config ssePaths := make(map[string]bool) for _, p := range cfg.MCPPaths { @@ -73,23 +119,38 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } return func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowedOrigin := getAllowedOrigin(origin, cfg) // Handle OPTIONS if r.Method == http.MethodOptions { - addCORSHeaders(w) + if allowedOrigin == "" { + log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers")) w.WriteHeader(http.StatusNoContent) return } - addCORSHeaders(w) + if allowedOrigin == "" { + log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + + // Add CORS headers to all responses + addCORSHeaders(w, cfg, allowedOrigin, "") // Decide whether the request should go to the auth server or MCP var targetURL *url.URL isSSE := false - if authPaths[r.URL.Path] { + if isAuthPath(r.URL.Path) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { - // Validate JWT if you want + // Validate JWT for MCP paths if required + // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -100,11 +161,21 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { isSSE = true } } else { - // If it's not recognized as an auth path or an MCP path http.Error(w, "Forbidden", http.StatusForbidden) return } + // Apply request modifiers to add parameters + if modifier, exists := modifiers[r.URL.Path]; exists { + var err error + r, err = modifier.ModifyRequest(r) + if err != nil { + log.Printf("[proxy] Error modifying request: %v", err) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + } + // Build the reverse proxy rp := &httputil.ReverseProxy{ Director: func(req *http.Request) { @@ -120,23 +191,27 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { req.URL.RawQuery = r.URL.RawQuery req.Host = targetURL.Host - for header, values := range r.Header { + cleanHeaders := http.Header{} + + for k, v := range r.Header { // Skip hop-by-hop headers - if strings.EqualFold(header, "Connection") || - strings.EqualFold(header, "Keep-Alive") || - strings.EqualFold(header, "Transfer-Encoding") || - strings.EqualFold(header, "Upgrade") || - strings.EqualFold(header, "Proxy-Authorization") || - strings.EqualFold(header, "Proxy-Connection") { + if skipHeader(k) { continue } - for _, value := range values { - req.Header.Set(header, value) - } + // Set only the first value to avoid duplicates + cleanHeaders.Set(k, v[0]) } + + req.Header = cleanHeaders + log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, + ModifyResponse: func(resp *http.Response) error { + log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts + return nil + }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { log.Printf("[proxy] Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) @@ -156,12 +231,47 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } } -func addCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") +func getAllowedOrigin(origin string, cfg *config.Config) string { + if origin == "" { + return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin + } + for _, allowed := range cfg.CORSConfig.AllowedOrigins { + if allowed == origin { + return allowed + } + } + return "" } +// addCORSHeaders adds configurable CORS headers +func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) + if requestHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", requestHeaders) + } else { + w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", ")) + } + if cfg.CORSConfig.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + w.Header().Set("Vary", "Origin") +} + +func isAuthPath(path string) bool { + authPaths := map[string]bool{ + "/authorize": true, + "/token": true, + "/register": true, + "/.well-known/oauth-authorization-server": true, + } + if strings.HasPrefix(path, "/u/") { + return true + } + return authPaths[path] +} + +// isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { for _, p := range cfg.MCPPaths { if strings.HasPrefix(path, p) { @@ -171,22 +281,10 @@ func isMCPPath(path string, cfg *config.Config) bool { return false } -func copyHeaders(src http.Header, dst http.Header) { - // Exclude hop-by-hop - hopByHop := map[string]bool{ - "Connection": true, - "Keep-Alive": true, - "Transfer-Encoding": true, - "Upgrade": true, - "Proxy-Authorization": true, - "Proxy-Connection": true, - } - for k, vv := range src { - if hopByHop[strings.ToLower(k)] { - continue - } - for _, v := range vv { - dst.Add(k, v) - } +func skipHeader(h string) bool { + switch strings.ToLower(h) { + case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer": + return true } + return false }