mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 09:24:19 +00:00
Refactor proxy builder
This commit is contained in:
parent
85e5fe1c1d
commit
331cc281c6
5 changed files with 200 additions and 35 deletions
|
@ -177,7 +177,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
|||
} else {
|
||||
claims, err := authorizeMCP(w, r, isLatestSpec, cfg)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -227,13 +227,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
|||
req.Host = targetURL.Host
|
||||
|
||||
cleanHeaders := http.Header{}
|
||||
|
||||
|
||||
// Set proper origin header to match the target
|
||||
if isSSE {
|
||||
// For SSE, ensure origin matches the target
|
||||
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
||||
}
|
||||
|
||||
|
||||
for k, v := range r.Header {
|
||||
// Skip hop-by-hop headers
|
||||
if skipHeader(k) {
|
||||
|
@ -277,7 +277,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
|||
proxyHost: r.Host,
|
||||
targetHost: targetURL.Host,
|
||||
}
|
||||
|
||||
|
||||
// Set SSE-specific headers
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
|
@ -296,15 +296,14 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
|||
|
||||
// Check if the request is for SSE handshake and authorize it
|
||||
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
|
||||
h := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(h, "Bearer ") {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
if isLatestSpec {
|
||||
realm := resourceID + "/.well-known/oauth-protected-resource"
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
}
|
||||
|
||||
return fmt.Errorf("missing bearer token")
|
||||
return fmt.Errorf("missing or invalid Authorization header")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -312,23 +311,31 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
|
|||
|
||||
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
||||
logger.Info("authorizeMCP")
|
||||
h := r.Header.Get("Authorization")
|
||||
audience := cfg.ResourceIdentifier
|
||||
if isLatestSpec {
|
||||
scope := cfg.ScopesSupported[r.URL.Path]
|
||||
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, scope)
|
||||
required := cfg.ScopesSupported[r.URL.Path]
|
||||
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, required)
|
||||
logger.Info("claims: %v", claims)
|
||||
logger.Info("err: %v", err)
|
||||
if err != nil {
|
||||
realm := audience + "/.well-known/oauth-protected-resource"
|
||||
w.Header().Set("WWW-Authenticate",
|
||||
fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope", scope="%s"`, realm, scope))
|
||||
w.Header().Set(
|
||||
"WWW-Authenticate",
|
||||
fmt.Sprintf(
|
||||
`Bearer realm="%s", error="insufficient_scope", scope="%s"`,
|
||||
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
|
||||
required,
|
||||
),
|
||||
)
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("forbidden — insufficient scope")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// v1: only check signature, then continue
|
||||
if err := util.ValidateJWTOld(h); err != nil {
|
||||
if err := util.ValidateJWTLegacy(h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -352,7 +359,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
|||
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version")
|
||||
if requestHeaders != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||
} else {
|
||||
|
@ -360,6 +367,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
|||
}
|
||||
if cfg.CORSConfig.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("MCP-Protocol-Version", ", ")
|
||||
}
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue