mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
fix standard auth
This commit is contained in:
parent
3d085008a8
commit
960261fc80
4 changed files with 137 additions and 53 deletions
|
@ -19,6 +19,13 @@ type AsgardeoConfig struct {
|
|||
OrgName string `yaml:"org_name"`
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
AllowedMethods []string `yaml:"allowed_methods"`
|
||||
AllowedHeaders []string `yaml:"allowed_headers"`
|
||||
AllowCredentials bool `yaml:"allow_credentials"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AuthServerBaseURL string `yaml:"auth_server_base_url"`
|
||||
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
||||
|
@ -27,6 +34,8 @@ type Config struct {
|
|||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||
MCPPaths []string `yaml:"mcp_paths"`
|
||||
PathMapping map[string]string `yaml:"path_mapping"`
|
||||
Mode string `yaml:"mode"`
|
||||
CORSConfig CORSConfig `yaml:"cors"`
|
||||
|
||||
// Nested config for Asgardeo
|
||||
Demo DemoConfig `yaml:"demo"`
|
||||
|
|
|
@ -20,27 +20,40 @@ import (
|
|||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 1. Custom well-known
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||
registeredPaths := make(map[string]bool)
|
||||
|
||||
// 2. Registration
|
||||
mux.HandleFunc("/register", provider.RegisterHandler())
|
||||
var defaultPaths []string
|
||||
if cfg.Mode == "demo" || cfg.Mode == "asgardeo" {
|
||||
// 1. Custom well-known
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||
registeredPaths["/.well-known/oauth-authorization-server"] = true
|
||||
|
||||
// 2. Registration
|
||||
mux.HandleFunc("/register", provider.RegisterHandler())
|
||||
registeredPaths["/register"] = true
|
||||
|
||||
defaultPaths = []string{"/authorize", "/token"}
|
||||
} else {
|
||||
defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"}
|
||||
}
|
||||
|
||||
// 3. Default "auth" paths, proxied
|
||||
defaultPaths := []string{"/authorize", "/token"}
|
||||
for _, path := range defaultPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
|
||||
// 4. MCP paths
|
||||
for _, path := range cfg.MCPPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
|
||||
// 5. If you want to map additional paths from config.PathMapping
|
||||
// to the same proxy logic:
|
||||
// 5. Register paths from PathMapping that haven't been registered yet
|
||||
for path := range cfg.PathMapping {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
if !registeredPaths[path] {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
}
|
||||
|
||||
return mux
|
||||
|
@ -57,13 +70,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
|||
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||
}
|
||||
|
||||
// We'll define sets for known auth paths, SSE paths, etc.
|
||||
authPaths := map[string]bool{
|
||||
"/authorize": true,
|
||||
"/token": true,
|
||||
"/.well-known/oauth-authorization-server": true,
|
||||
}
|
||||
|
||||
// Detect SSE paths from config
|
||||
ssePaths := make(map[string]bool)
|
||||
for _, p := range cfg.MCPPaths {
|
||||
|
@ -73,23 +79,38 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
|||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
allowedOrigin := getAllowedOrigin(origin, cfg)
|
||||
// Handle OPTIONS
|
||||
if r.Method == http.MethodOptions {
|
||||
addCORSHeaders(w)
|
||||
if allowedOrigin == "" {
|
||||
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
|
||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers"))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
addCORSHeaders(w)
|
||||
if allowedOrigin == "" {
|
||||
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Add CORS headers to all responses
|
||||
addCORSHeaders(w, cfg, allowedOrigin, "")
|
||||
|
||||
// Decide whether the request should go to the auth server or MCP
|
||||
var targetURL *url.URL
|
||||
isSSE := false
|
||||
|
||||
if authPaths[r.URL.Path] {
|
||||
if isAuthPath(r.URL.Path) {
|
||||
targetURL = authBase
|
||||
} else if isMCPPath(r.URL.Path, cfg) {
|
||||
// Validate JWT if you want
|
||||
// Validate JWT for MCP paths if required
|
||||
// Placeholder for JWT validation logic
|
||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
|
@ -100,7 +121,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
|||
isSSE = true
|
||||
}
|
||||
} else {
|
||||
// If it's not recognized as an auth path or an MCP path
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
@ -120,23 +140,42 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
|||
req.URL.RawQuery = r.URL.RawQuery
|
||||
req.Host = targetURL.Host
|
||||
|
||||
for header, values := range r.Header {
|
||||
cleanHeaders := http.Header{}
|
||||
|
||||
for k, v := range r.Header {
|
||||
// Skip hop-by-hop headers
|
||||
if strings.EqualFold(header, "Connection") ||
|
||||
strings.EqualFold(header, "Keep-Alive") ||
|
||||
strings.EqualFold(header, "Transfer-Encoding") ||
|
||||
strings.EqualFold(header, "Upgrade") ||
|
||||
strings.EqualFold(header, "Proxy-Authorization") ||
|
||||
strings.EqualFold(header, "Proxy-Connection") {
|
||||
if skipHeader(k) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
// Set only the first value to avoid duplicates
|
||||
cleanHeaders.Set(k, v[0])
|
||||
}
|
||||
|
||||
// Override or remove sensitive headers if needed
|
||||
if strings.Contains(req.URL.Path, "/token") {
|
||||
cleanHeaders.Set("Accept", "application/json")
|
||||
cleanHeaders.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
cleanHeaders.Set("User-Agent", "GoProxy/1.0")
|
||||
cleanHeaders.Del("Origin")
|
||||
cleanHeaders.Del("Referer")
|
||||
}
|
||||
|
||||
req.Header = cleanHeaders
|
||||
|
||||
// DEBUG: log headers sent to Asgardeo
|
||||
log.Println("[proxy] Outgoing request headers:")
|
||||
for k, v := range req.Header {
|
||||
log.Printf(" %s: %s", k, strings.Join(v, ", "))
|
||||
}
|
||||
|
||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||
},
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[proxy] Error proxying: %v", err)
|
||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||
|
@ -156,12 +195,43 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func addCORSHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||
if origin == "" {
|
||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||
}
|
||||
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||
if allowed == origin {
|
||||
return allowed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// addCORSHeaders adds configurable CORS headers
|
||||
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||
if requestHeaders != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", "))
|
||||
}
|
||||
if cfg.CORSConfig.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
w.Header().Set("Vary", "Origin")
|
||||
}
|
||||
|
||||
func isAuthPath(path string) bool {
|
||||
authPaths := map[string]bool{
|
||||
"/authorize": true,
|
||||
"/token": true,
|
||||
"/.well-known/oauth-authorization-server": true,
|
||||
}
|
||||
return authPaths[path]
|
||||
}
|
||||
|
||||
// isMCPPath checks if the path is an MCP path
|
||||
func isMCPPath(path string, cfg *config.Config) bool {
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
|
@ -171,22 +241,10 @@ func isMCPPath(path string, cfg *config.Config) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func copyHeaders(src http.Header, dst http.Header) {
|
||||
// Exclude hop-by-hop
|
||||
hopByHop := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
"Proxy-Authorization": true,
|
||||
"Proxy-Connection": true,
|
||||
}
|
||||
for k, vv := range src {
|
||||
if hopByHop[strings.ToLower(k)] {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
func skipHeader(h string) bool {
|
||||
switch strings.ToLower(h) {
|
||||
case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue