mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 01:23:30 +00:00
Update MCP proxy to adhere to the latest draft of MCP specification
This commit is contained in:
parent
9c2d37e2df
commit
85e5fe1c1d
7 changed files with 191 additions and 41 deletions
|
@ -92,12 +92,15 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Build the main router
|
// 5. (Optional) Build the policy engine
|
||||||
mux := proxy.NewRouter(cfg, provider)
|
engine := &authz.DefaulPolicyEngine{}
|
||||||
|
|
||||||
|
// 6. Build the main router
|
||||||
|
mux := proxy.NewRouter(cfg, provider, engine)
|
||||||
|
|
||||||
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||||
|
|
||||||
// 6. Start the server
|
// 7. Start the server
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: listen_address,
|
Addr: listen_address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
|
@ -111,18 +114,18 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 7. Wait for shutdown signal
|
// 8. Wait for shutdown signal
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||||
<-stop
|
<-stop
|
||||||
logger.Info("Shutting down...")
|
logger.Info("Shutting down...")
|
||||||
|
|
||||||
// 8. First terminate subprocess if running
|
// 9. First terminate subprocess if running
|
||||||
if procManager != nil && procManager.IsRunning() {
|
if procManager != nil && procManager.IsRunning() {
|
||||||
procManager.Shutdown()
|
procManager.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 9. Then shutdown the server
|
// 10. Then shutdown the server
|
||||||
logger.Info("Shutting down HTTP server...")
|
logger.Info("Shutting down HTTP server...")
|
||||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
|
@ -94,3 +94,26 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
|
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
meta := map[string]interface{}{
|
||||||
|
"resource": p.cfg.ResourceIdentifier,
|
||||||
|
"scopes_supported": p.cfg.ScopesSupported,
|
||||||
|
"authorization_servers": p.cfg.AuthorizationServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.cfg.JwksURI != "" {
|
||||||
|
meta["jwks_uri"] = p.cfg.JwksURI
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.cfg.BearerMethodsSupported) > 0 {
|
||||||
|
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||||
|
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
19
internal/authz/policy.go
Normal file
19
internal/authz/policy.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type Decision int
|
||||||
|
|
||||||
|
const (
|
||||||
|
DecisionAllow Decision = iota
|
||||||
|
DecisionDeny
|
||||||
|
)
|
||||||
|
|
||||||
|
type PolicyResult struct {
|
||||||
|
Decision Decision
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
type PolicyEngine interface {
|
||||||
|
Evaluate(r *http.Request, claims *TokenClaims, requiredScope string) PolicyResult
|
||||||
|
}
|
|
@ -7,4 +7,5 @@ import "net/http"
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
WellKnownHandler() http.HandlerFunc
|
WellKnownHandler() http.HandlerFunc
|
||||||
RegisterHandler() http.HandlerFunc
|
RegisterHandler() http.HandlerFunc
|
||||||
|
ProtectedResourceMetadataHandler() http.HandlerFunc
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,15 +17,15 @@ const (
|
||||||
|
|
||||||
// Common path configuration for all transport modes
|
// Common path configuration for all transport modes
|
||||||
type PathsConfig struct {
|
type PathsConfig struct {
|
||||||
SSE string `yaml:"sse"`
|
SSE string `yaml:"sse"`
|
||||||
Messages string `yaml:"messages"`
|
Messages string `yaml:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StdioConfig contains stdio-specific configuration
|
// StdioConfig contains stdio-specific configuration
|
||||||
type StdioConfig struct {
|
type StdioConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
UserCommand string `yaml:"user_command"` // The command provided by the user
|
UserCommand string `yaml:"user_command"` // The command provided by the user
|
||||||
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
||||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||||
}
|
}
|
||||||
|
@ -83,23 +83,31 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
BaseURL string `yaml:"base_url"`
|
BaseURL string `yaml:"base_url"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
JWKSURL string
|
JWKSURL string
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
Mode string `yaml:"mode"`
|
Mode string `yaml:"mode"`
|
||||||
CORSConfig CORSConfig `yaml:"cors"`
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
TransportMode TransportMode `yaml:"transport_mode"`
|
TransportMode TransportMode `yaml:"transport_mode"`
|
||||||
Paths PathsConfig `yaml:"paths"`
|
Paths PathsConfig `yaml:"paths"`
|
||||||
Stdio StdioConfig `yaml:"stdio"`
|
Stdio StdioConfig `yaml:"stdio"`
|
||||||
|
RequiredScopes map[string]string `yaml:"required_scopes"`
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
|
|
||||||
|
// Protected resource metadata
|
||||||
|
ResourceIdentifier string `yaml:"resource_identifier"`
|
||||||
|
ScopesSupported map[string]string `yaml:"scopes_supported"`
|
||||||
|
AuthorizationServers []string `yaml:"authorization_servers"`
|
||||||
|
JwksURI string `yaml:"jwks_uri,omitempty"`
|
||||||
|
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks if the config is valid based on transport mode
|
// Validate checks if the config is valid based on transport mode
|
||||||
|
@ -165,12 +173,12 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if err := decoder.Decode(&cfg); err != nil {
|
if err := decoder.Decode(&cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default values
|
// Set default values
|
||||||
if cfg.TimeoutSeconds == 0 {
|
if cfg.TimeoutSeconds == 0 {
|
||||||
cfg.TimeoutSeconds = 15 // default
|
cfg.TimeoutSeconds = 15 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default transport mode if not specified
|
// Set default transport mode if not specified
|
||||||
if cfg.TransportMode == "" {
|
if cfg.TransportMode == "" {
|
||||||
cfg.TransportMode = SSETransport // Default to SSE
|
cfg.TransportMode = SSETransport // Default to SSE
|
||||||
|
@ -180,11 +188,11 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if cfg.Port == 0 {
|
if cfg.Port == 0 {
|
||||||
cfg.Port = 8000 // default
|
cfg.Port = 8000 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the configuration
|
// Validate the configuration
|
||||||
if err := cfg.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
package constants
|
package constants
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
// Package constant provides constants for the MCP Auth Proxy
|
// Package constant provides constants for the MCP Auth Proxy
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
|
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MCP specification version cutover date
|
||||||
|
var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
const TimeLayout = "2006-01-02"
|
||||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -10,14 +11,15 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRouter builds an http.ServeMux that routes
|
// NewRouter builds an http.ServeMux that routes
|
||||||
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
||||||
// * MCP paths to the MCP server, etc.
|
// * MCP paths to the MCP server, etc.
|
||||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.PolicyEngine) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
modifiers := map[string]RequestModifier{
|
modifiers := map[string]RequestModifier{
|
||||||
|
@ -55,6 +57,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server")
|
defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
allowed := getAllowedOrigin(origin, cfg)
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers"))
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
addCORSHeaders(w, cfg, allowed, "")
|
||||||
|
provider.ProtectedResourceMetadataHandler()(w, r)
|
||||||
|
})
|
||||||
|
registeredPaths["/.well-known/oauth-protected-resource"] = true
|
||||||
|
|
||||||
defaultPaths = append(defaultPaths, "/authorize")
|
defaultPaths = append(defaultPaths, "/authorize")
|
||||||
defaultPaths = append(defaultPaths, "/token")
|
defaultPaths = append(defaultPaths, "/token")
|
||||||
defaultPaths = append(defaultPaths, "/register")
|
defaultPaths = append(defaultPaths, "/register")
|
||||||
|
@ -76,7 +92,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
|
|
||||||
for _, path := range defaultPaths {
|
for _, path := range defaultPaths {
|
||||||
if !registeredPaths[path] {
|
if !registeredPaths[path] {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,14 +100,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
// MCP paths
|
// MCP paths
|
||||||
mcpPaths := cfg.GetMCPPaths()
|
mcpPaths := cfg.GetMCPPaths()
|
||||||
for _, path := range mcpPaths {
|
for _, path := range mcpPaths {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register paths from PathMapping that haven't been registered yet
|
// Register paths from PathMapping that haven't been registered yet
|
||||||
for path := range cfg.PathMapping {
|
for path := range cfg.PathMapping {
|
||||||
if !registeredPaths[path] {
|
if !registeredPaths[path] {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -99,14 +115,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, policyEngine authz.PolicyEngine) http.HandlerFunc {
|
||||||
// Parse the base URLs up front
|
// Parse the base URLs up front
|
||||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid auth server URL: %v", err)
|
logger.Error("Invalid auth server URL: %v", err)
|
||||||
panic(err) // Fatal error that prevents startup
|
panic(err) // Fatal error that prevents startup
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpBase, err := url.Parse(cfg.BaseURL)
|
mcpBase, err := url.Parse(cfg.BaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid MCP server URL: %v", err)
|
logger.Error("Invalid MCP server URL: %v", err)
|
||||||
|
@ -141,6 +157,10 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Add CORS headers to all responses
|
// Add CORS headers to all responses
|
||||||
addCORSHeaders(w, cfg, allowedOrigin, "")
|
addCORSHeaders(w, cfg, allowedOrigin, "")
|
||||||
|
|
||||||
|
versionRaw := r.Header.Get("MCP-Protocol-Version")
|
||||||
|
ver, err := time.Parse(constants.TimeLayout, versionRaw)
|
||||||
|
isLatestSpec := err == nil && !ver.Before(constants.SpecCutoverDate)
|
||||||
|
|
||||||
// Decide whether the request should go to the auth server or MCP
|
// Decide whether the request should go to the auth server or MCP
|
||||||
var targetURL *url.URL
|
var targetURL *url.URL
|
||||||
isSSE := false
|
isSSE := false
|
||||||
|
@ -148,13 +168,29 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
if isAuthPath(r.URL.Path) {
|
if isAuthPath(r.URL.Path) {
|
||||||
targetURL = authBase
|
targetURL = authBase
|
||||||
} else if isMCPPath(r.URL.Path, cfg) {
|
} else if isMCPPath(r.URL.Path, cfg) {
|
||||||
// Validate JWT for MCP paths if required
|
if ssePaths[r.URL.Path] {
|
||||||
// Placeholder for JWT validation logic
|
if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil {
|
||||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
return
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
}
|
||||||
return
|
isSSE = true
|
||||||
|
} else {
|
||||||
|
claims, err := authorizeMCP(w, r, isLatestSpec, cfg)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLatestSpec {
|
||||||
|
scope := cfg.ScopesSupported[r.URL.Path]
|
||||||
|
pr := policyEngine.Evaluate(r, claims, scope)
|
||||||
|
if pr.Decision == authz.DecisionDeny {
|
||||||
|
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL = mcpBase
|
targetURL = mcpBase
|
||||||
if ssePaths[r.URL.Path] {
|
if ssePaths[r.URL.Path] {
|
||||||
isSSE = true
|
isSSE = true
|
||||||
|
@ -214,7 +250,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
},
|
},
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
resp.Header.Set(
|
||||||
|
"WWW-Authenticate",
|
||||||
|
fmt.Sprintf(
|
||||||
|
`Bearer resource_metadata="%s"`,
|
||||||
|
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
|
||||||
|
))
|
||||||
|
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Header.Del("Access-Control-Allow-Origin")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
@ -236,7 +282,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
w.Header().Set("Connection", "keep-alive")
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
// Keep SSE connections open
|
// Keep SSE connections open
|
||||||
HandleSSE(w, r, rp)
|
HandleSSE(w, r, rp)
|
||||||
} else {
|
} else {
|
||||||
|
@ -248,6 +294,47 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the request is for SSE handshake and authorize it
|
||||||
|
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
|
||||||
|
h := r.Header.Get("Authorization")
|
||||||
|
if !strings.HasPrefix(h, "Bearer ") {
|
||||||
|
if isLatestSpec {
|
||||||
|
realm := resourceID + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("missing bearer token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||||
|
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) (*authz.TokenClaims, error) {
|
||||||
|
h := r.Header.Get("Authorization")
|
||||||
|
audience := cfg.ResourceIdentifier
|
||||||
|
if isLatestSpec {
|
||||||
|
scope := cfg.ScopesSupported[r.URL.Path]
|
||||||
|
claims, err := util.ValidateJWT(r.Header.Get("MCP-Protocol-Version"), h, audience, scope)
|
||||||
|
if err != nil {
|
||||||
|
realm := audience + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate",
|
||||||
|
fmt.Sprintf(`Bearer realm="%s", error="insufficient_scope", scope="%s"`, realm, scope))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// v1: only check signature, then continue
|
||||||
|
if err := util.ValidateJWTOld(h); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &authz.TokenClaims{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
if origin == "" {
|
if origin == "" {
|
||||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
|
@ -265,6 +352,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
if requestHeaders != "" {
|
if requestHeaders != "" {
|
||||||
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||||
} else {
|
} else {
|
||||||
|
@ -283,6 +371,7 @@ func isAuthPath(path string) bool {
|
||||||
"/token": true,
|
"/token": true,
|
||||||
"/register": true,
|
"/register": true,
|
||||||
"/.well-known/oauth-authorization-server": true,
|
"/.well-known/oauth-authorization-server": true,
|
||||||
|
"/.well-known/oauth-protected-resource": true,
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(path, "/u/") {
|
if strings.HasPrefix(path, "/u/") {
|
||||||
return true
|
return true
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue