diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 699929f..cde3cf3 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -11,6 +11,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -30,15 +31,19 @@ func main() { var provider authz.Provider if *demoMode { cfg.Mode = "demo" - cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" - cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) - fmt.Println("Using Asgardeo provider (demo).") } else if *asgardeoMode { cfg.Mode = "asgardeo" - cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2" - cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Asgardeo.OrgName + "/oauth2/jwks" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) + } else { + cfg.Mode = "default" + cfg.JWKSURL = cfg.Default.JWKSURL + cfg.AuthServerBaseURL = cfg.Default.BaseURL + provider = authz.NewDefaultProvider(cfg) } // 3. (Optional) Fetch JWKS if you want local JWT validation diff --git a/config.yaml b/config.yaml index ab84ad2..0b0ade4 100644 --- a/config.yaml +++ b/config.yaml @@ -1,9 +1,7 @@ # config.yaml -auth_server_base_url: "" mcp_server_base_url: "" listen_port: 8080 -jwks_url: "" timeout_seconds: 10 mcp_paths: @@ -11,8 +9,10 @@ mcp_paths: - /sse path_mapping: - /token: /oauth/token - /.well-known/oauth-authorization-server: /.well-known/openid-configuration + /token: /token + /register: /register + /authorize: /authorize + /.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server cors: allowed_origins: @@ -36,3 +36,36 @@ asgardeo: org_name: "" client_id: "" client_secret: "" + +default: + base_url: "" + jwks_url: "" + path: + /.well-known/oauth-authorization-server: + response: + issuer: "" + jwks_uri: "" + authorization_endpoint: "" # Optional + token_endpoint: "" # Optional + registration_endpoint: "" # Optional + response_types_supported: + - "code" + grant_types_supported: + - "authorization_code" + - "refresh_token" + code_challenge_methods_supported: + - "S256" + - "plain" + /authroize: + addQueryParams: + - name: "" + value: "" + /token: + addBodyParams: + - name: "" + value: "" + /register: + addBodyParams: + - name: "" + value: "" + diff --git a/internal/authz/default.go b/internal/authz/default.go new file mode 100644 index 0000000..9230d39 --- /dev/null +++ b/internal/authz/default.go @@ -0,0 +1,94 @@ +package authz + +import ( + "encoding/json" + "net/http" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +type defaultProvider struct { + cfg *config.Config +} + +// NewDefaultProvider initializes a Provider for Asgardeo (demo mode). +func NewDefaultProvider(cfg *config.Config) Provider { + return &defaultProvider{cfg: cfg} +} + +func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check if we have a custom response configuration + if p.cfg.Default.Path != nil { + pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && pathConfig.Response != nil { + // Use configured response values + responseConfig := pathConfig.Response + + // Get current host for proxy endpoints + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { + scheme = forwardedProto + } + host := r.Host + if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + baseURL := scheme + "://" + host + + authorizationEndpoint := responseConfig.AuthorizationEndpoint + if authorizationEndpoint == "" { + authorizationEndpoint = baseURL + "/authorize" + } + tokenEndpoint := responseConfig.TokenEndpoint + if tokenEndpoint == "" { + tokenEndpoint = baseURL + "/token" + } + registraionEndpoint := responseConfig.RegistrationEndpoint + if registraionEndpoint == "" { + registraionEndpoint = baseURL + "/register" + } + + // Build response from config + response := map[string]interface{}{ + "issuer": responseConfig.Issuer, + "authorization_endpoint": authorizationEndpoint, + "token_endpoint": tokenEndpoint, + "jwks_uri": responseConfig.JwksURI, + "response_types_supported": responseConfig.ResponseTypesSupported, + "grant_types_supported": responseConfig.GrantTypesSupported, + "token_endpoint_auth_methods_supported": []string{"client_secret_basic"}, + "registration_endpoint": registraionEndpoint, + "code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + } + return + } + } + } +} + +func (p *defaultProvider) RegisterHandler() http.HandlerFunc { + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index 2f63e7c..01c3a6f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -26,11 +26,44 @@ type CORSConfig struct { AllowCredentials bool `yaml:"allow_credentials"` } +type ParamConfig struct { + Name string `yaml:"name"` + Value string `yaml:"value"` +} + +type ResponseConfig struct { + Issuer string `yaml:"issuer,omitempty"` + JwksURI string `yaml:"jwks_uri,omitempty"` + AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"` + TokenEndpoint string `yaml:"token_endpoint,omitempty"` + RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"` + ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"` + GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"` + CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` +} + +type PathConfig struct { + // For well-known endpoint + Response *ResponseConfig `yaml:"response,omitempty"` + + // For authorization endpoint + AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"` + + // For token and register endpoints + AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"` +} + +type DefaultConfig struct { + BaseURL string `yaml:"base_url,omitempty"` + Path map[string]PathConfig `yaml:"path,omitempty"` + JWKSURL string `yaml:"jwks_url,omitempty"` +} + type Config struct { - AuthServerBaseURL string `yaml:"auth_server_base_url"` - MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenPort int `yaml:"listen_port"` - JWKSURL string `yaml:"jwks_url"` + AuthServerBaseURL string + MCPServerBaseURL string `yaml:"mcp_server_base_url"` + ListenPort int `yaml:"listen_port"` + JWKSURL string TimeoutSeconds int `yaml:"timeout_seconds"` MCPPaths []string `yaml:"mcp_paths"` PathMapping map[string]string `yaml:"path_mapping"` @@ -40,6 +73,7 @@ type Config struct { // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"` + Default DefaultConfig `yaml:"default"` } // LoadConfig reads a YAML config file into Config struct. diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..1e5808e --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,7 @@ +package constants + +// Package constant provides constants for the MCP Auth Proxy + +const ( + ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" +) diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go new file mode 100644 index 0000000..8e2268b --- /dev/null +++ b/internal/proxy/modifier.go @@ -0,0 +1,199 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +// RequestModifier modifies requests before they are proxied +type RequestModifier interface { + ModifyRequest(req *http.Request) (*http.Request, error) +} + +// AuthorizationModifier adds parameters to authorization requests +type AuthorizationModifier struct { + Config *config.Config +} + +// TokenModifier adds parameters to token requests +type TokenModifier struct { + Config *config.Config +} + +type RegisterModifier struct { + Config *config.Config +} + +// ModifyRequest adds configured parameters to authorization requests +func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/authorize"] + if !exists || len(pathConfig.AddQueryParams) == 0 { + return req, nil + } + // Get current query parameters + query := req.URL.Query() + + // Add parameters from config + for _, param := range pathConfig.AddQueryParams { + query.Set(param.Name, param.Value) + } + + // Update the request URL + req.URL.RawQuery = query.Encode() + + return req, nil +} + +// ModifyRequest adds configured parameters to token requests +func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/token"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} + +func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/register"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5f74125..c999be4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -20,38 +20,77 @@ import ( func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { mux := http.NewServeMux() + modifiers := map[string]RequestModifier{ + "/authorize": &AuthorizationModifier{Config: cfg}, + "/token": &TokenModifier{Config: cfg}, + "/register": &RegisterModifier{Config: cfg}, + } + registeredPaths := make(map[string]bool) var defaultPaths []string + + // Handle based on mode configuration if cfg.Mode == "demo" || cfg.Mode == "asgardeo" { - // 1. Custom well-known + // Demo/Asgardeo mode: Custom handlers for well-known and register mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) registeredPaths["/.well-known/oauth-authorization-server"] = true - // 2. Registration mux.HandleFunc("/register", provider.RegisterHandler()) registeredPaths["/register"] = true + // Authorize and token will be proxied with parameter modification defaultPaths = []string{"/authorize", "/token"} } else { - defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + // Default provider mode + if cfg.Default.Path != nil { + // Check if we have custom response for well-known + wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && wellKnownConfig.Response != nil { + // If there's a custom response defined, use our handler + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + } else { + // No custom response, add well-known to proxy paths + defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") + } + + defaultPaths = append(defaultPaths, "/authorize") + defaultPaths = append(defaultPaths, "/token") + defaultPaths = append(defaultPaths, "/register") + } else { + defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + } } + // Remove duplicates from defaultPaths + uniquePaths := make(map[string]bool) + cleanPaths := []string{} + for _, path := range defaultPaths { + if !uniquePaths[path] { + uniquePaths[path] = true + cleanPaths = append(cleanPaths, path) + } + } + defaultPaths = cleanPaths + for _, path := range defaultPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) - registeredPaths[path] = true + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true + } } - // 4. MCP paths + // MCP paths for _, path := range cfg.MCPPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } - // 5. Register paths from PathMapping that haven't been registered yet + // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } } @@ -59,8 +98,9 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { return mux } -func buildProxyHandler(cfg *config.Config) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { // Parse the base URLs up front + authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { log.Fatalf("Invalid auth server URL: %v", err) @@ -125,6 +165,17 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { return } + // Apply request modifiers to add parameters + if modifier, exists := modifiers[r.URL.Path]; exists { + var err error + r, err = modifier.ModifyRequest(r) + if err != nil { + log.Printf("[proxy] Error modifying request: %v", err) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + } + // Build the reverse proxy rp := &httputil.ReverseProxy{ Director: func(req *http.Request) { @@ -152,23 +203,8 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { cleanHeaders.Set(k, v[0]) } - // Override or remove sensitive headers if needed - if strings.Contains(req.URL.Path, "/token") { - cleanHeaders.Set("Accept", "application/json") - cleanHeaders.Set("Content-Type", "application/x-www-form-urlencoded") - cleanHeaders.Set("User-Agent", "GoProxy/1.0") - cleanHeaders.Del("Origin") - cleanHeaders.Del("Referer") - } - req.Header = cleanHeaders - // DEBUG: log headers sent to Asgardeo - log.Println("[proxy] Outgoing request headers:") - for k, v := range req.Header { - log.Printf(" %s: %s", k, strings.Join(v, ", ")) - } - log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, ModifyResponse: func(resp *http.Response) error { @@ -226,8 +262,12 @@ func isAuthPath(path string) bool { authPaths := map[string]bool{ "/authorize": true, "/token": true, + "/register": true, "/.well-known/oauth-authorization-server": true, } + if strings.HasPrefix(path, "/u/") { + return true + } return authPaths[path] }