diff --git a/README.md b/README.md
index 6be3ece..49a471f 100644
--- a/README.md
+++ b/README.md
@@ -10,16 +10,31 @@ A lightweight authorization proxy for Model Context Protocol (MCP) servers that

-## What it Does
-
-Open MCP Auth Proxy sits between MCP clients and your MCP server to:
+## ๐ก๏ธ What it Does?
- Intercept incoming requests
- Validate authorization tokens
- Offload authentication and authorization to OAuth-compliant Identity Providers
- Support the MCP authorization protocol
-## Quick Start
+
+## ๐ Features
+
+- **Dynamic Authorization** based on MCP Authorization Specification.
+- **JWT Validation** (signature, audience, and scopes).
+- **Identity Provider Integration** (OAuth/OIDC via Asgardeo, Auth0, Keycloak).
+- **Protocol Version Negotiation** via `MCP-Protocol-Version` header.
+- **Comprehensive Authentication Feedback** via RFC-compliant challenges.
+- **Flexible Transport Modes**: SSE and stdio.
+
+## ๐ MCP Specification Verions
+
+| Version | Behavior |
+| :-------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| 2025-03-26 | Only signature check of Bearer JWT on both `/sse` and `/message`
No scope or audience enforcement |
+| Latest(draft) | Read `MCP-Protocol-Version` from client header
SSE handshake returns `WWW-Authenticate: Bearer resource_metadata="โฆ"`
`/message` enforces:
`aud` claim == `ResourceIdentifier`
`scope` claim contains `requiredScope`
Scope based access control
Rich `WWW-Authenticate` on 401s
Serves `/โ.well-known/oauth-protected-resource` JSON |
+
+## ๐ ๏ธ Quick Start
### Prerequisites
@@ -67,7 +82,7 @@ Open MCP Auth Proxy sits between MCP clients and your MCP server to:
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
-## Connect an Identity Provider
+## ๐ Integrate an Identity Provider
### Asgardeo
@@ -81,13 +96,17 @@ To enable authorization through your Asgardeo organization:
3. Update `config.yaml` with the following parameters.
```yaml
-base_url: "http://localhost:8000" # URL of your MCP server
-listen_port: 8080 # Address where the proxy will listen
+base_url: "http://localhost:8000" # URL of your MCP server
+listen_port: 8080 # Address where the proxy will listen
-asgardeo:
- org_name: "" # Your Asgardeo org name
- client_id: "" # Client ID of the M2M app
- client_secret: "" # Client secret of the M2M app
+resource_identifier: "http://localhost:8080" # Proxy server URL
+scopes_supported: # Scopes required to defined for the MCP server
+- "read:tools"
+- "read:resources"
+audience: "" # Access token audience
+authorization_servers: # Authorization server issuer identifier(s)
+- "https://api.asgardeo.io/t/acme"
+jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" # JWKS URL
```
4. Start the proxy with Asgardeo integration:
@@ -101,7 +120,7 @@ asgardeo:
- [Auth0](docs/integrations/Auth0.md)
- [Keycloak](docs/integrations/keycloak.md)
-# Advanced Configuration
+# โ๏ธ Advanced Configuration
### Transport Modes
@@ -167,7 +186,7 @@ The proxy will:
- Handle all authorization requirements
- Forward messages between clients and the server
-### Complete Configuration Reference
+### ๐ Complete Configuration Reference
```yaml
# Common configuration
@@ -210,13 +229,17 @@ demo:
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
# Asgardeo configuration (used with --asgardeo flag)
-asgardeo:
- org_name: ""
- client_id: ""
- client_secret: ""
+resource_identifier: "http://localhost:8080"
+scopes_supported:
+- "read:tools"
+- "read:resources"
+audience: ""
+authorization_servers:
+- "https://api.asgardeo.io/t/acme"
+jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks"
```
-### Build from source
+### ๐ฅ๏ธ Build from source
```bash
git clone https://github.com/wso2/open-mcp-auth-proxy
diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go
index 6424f18..0208ead 100644
--- a/cmd/proxy/main.go
+++ b/cmd/proxy/main.go
@@ -11,7 +11,6 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
- "github.com/wso2/open-mcp-auth-proxy/internal/constants"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
@@ -68,23 +67,7 @@ func main() {
}
// 3. Create the chosen provider
- var provider authz.Provider
- if *demoMode {
- cfg.Mode = "demo"
- cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
- cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
- provider = authz.NewAsgardeoProvider(cfg)
- } else if *asgardeoMode {
- cfg.Mode = "asgardeo"
- cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
- cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
- provider = authz.NewAsgardeoProvider(cfg)
- } else {
- cfg.Mode = "default"
- cfg.JWKSURL = cfg.Default.JWKSURL
- cfg.AuthServerBaseURL = cfg.Default.BaseURL
- provider = authz.NewDefaultProvider(cfg)
- }
+ var provider authz.Provider = MakeProvider(cfg, *demoMode, *asgardeoMode)
// 4. (Optional) Fetch JWKS if you want local JWT validation
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
@@ -92,12 +75,15 @@ func main() {
os.Exit(1)
}
- // 5. Build the main router
- mux := proxy.NewRouter(cfg, provider)
+ // 5. (Optional) Build the access controler
+ accessController := &authz.ScopeValidator{}
+
+ // 6. Build the main router
+ mux := proxy.NewRouter(cfg, provider, accessController)
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
- // 6. Start the server
+ // 7. Start the server
srv := &http.Server{
Addr: listen_address,
Handler: mux,
@@ -111,18 +97,18 @@ func main() {
}
}()
- // 7. Wait for shutdown signal
+ // 8. Wait for shutdown signal
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
logger.Info("Shutting down...")
- // 8. First terminate subprocess if running
+ // 9. First terminate subprocess if running
if procManager != nil && procManager.IsRunning() {
procManager.Shutdown()
}
- // 9. Then shutdown the server
+ // 10. Then shutdown the server
logger.Info("Shutting down HTTP server...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel()
diff --git a/cmd/proxy/provider.go b/cmd/proxy/provider.go
new file mode 100644
index 0000000..be4ee21
--- /dev/null
+++ b/cmd/proxy/provider.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+ "github.com/wso2/open-mcp-auth-proxy/internal/authz"
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/constants"
+)
+
+func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
+ var mode, orgName string
+ switch {
+ case demoMode:
+ mode = "demo"
+ orgName = cfg.Demo.OrgName
+ case asgardeoMode:
+ mode = "asgardeo"
+ orgName = cfg.Asgardeo.OrgName
+ default:
+ mode = "default"
+ }
+ cfg.Mode = mode
+
+ switch mode {
+ case "demo", "asgardeo":
+ if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" {
+ base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
+ cfg.AuthServerBaseURL = base
+ cfg.JWKSURL = base + "/jwks"
+ } else {
+ cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
+ cfg.JWKSURL = cfg.JwksURI
+ }
+ return authz.NewAsgardeoProvider(cfg)
+
+ default:
+ if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
+ cfg.AuthServerBaseURL = cfg.Default.BaseURL
+ cfg.JWKSURL = cfg.Default.JWKSURL
+ } else if len(cfg.AuthorizationServers) > 0 {
+ cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
+ cfg.JWKSURL = cfg.JwksURI
+ }
+ return authz.NewDefaultProvider(cfg)
+ }
+}
\ No newline at end of file
diff --git a/config.yaml b/config.yaml
index 5621195..d97c8ca 100644
--- a/config.yaml
+++ b/config.yaml
@@ -45,3 +45,18 @@ demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
+
+# Protected resource metadata
+resource_identifier: http://localhost:3000
+audience: mcp_proxy
+scopes_supported:
+ - "tools":"read:tools"
+ - "resources":"read:resources"
+ - "prompts":"read:prompts"
+authorization_servers:
+ - https://api.asgardeo.io/t/acme/
+jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks
+bearer_methods_supported:
+ - header
+ - body
+ - query
diff --git a/internal/authz/access_control.go b/internal/authz/access_control.go
new file mode 100644
index 0000000..1f7ce7b
--- /dev/null
+++ b/internal/authz/access_control.go
@@ -0,0 +1,24 @@
+package authz
+
+import (
+ "net/http"
+
+ "github.com/golang-jwt/jwt/v4"
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+)
+
+type Decision int
+
+const (
+ DecisionAllow Decision = iota
+ DecisionDeny
+)
+
+type AccessControlResult struct {
+ Decision Decision
+ Message string
+}
+
+type AccessControl interface {
+ ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult
+}
diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go
index 9b8fdc5..aa699b6 100644
--- a/internal/authz/asgardeo.go
+++ b/internal/authz/asgardeo.go
@@ -350,3 +350,23 @@ func randomString(n int) string {
}
return string(b)
}
+
+func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ meta := map[string]interface{}{
+ "resource": p.cfg.ResourceIdentifier,
+ "scopes_supported": p.cfg.ScopesSupported,
+ "authorization_servers": p.cfg.AuthorizationServers,
+ }
+ if p.cfg.JwksURI != "" {
+ meta["jwks_uri"] = p.cfg.JwksURI
+ }
+ if len(p.cfg.BearerMethodsSupported) > 0 {
+ meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
+ }
+ if err := json.NewEncoder(w).Encode(meta); err != nil {
+ http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
+ }
+ }
+}
diff --git a/internal/authz/default.go b/internal/authz/default.go
index 929f586..dc8900d 100644
--- a/internal/authz/default.go
+++ b/internal/authz/default.go
@@ -94,3 +94,27 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
return nil
}
+
+func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ meta := map[string]interface{}{
+ "audience": p.cfg.Audience,
+ "resource": p.cfg.ResourceIdentifier,
+ "scopes_supported": p.cfg.ScopesSupported,
+ "authorization_servers": p.cfg.AuthorizationServers,
+ }
+
+ if p.cfg.JwksURI != "" {
+ meta["jwks_uri"] = p.cfg.JwksURI
+ }
+
+ if len(p.cfg.BearerMethodsSupported) > 0 {
+ meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
+ }
+
+ if err := json.NewEncoder(w).Encode(meta); err != nil {
+ http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
+ }
+ }
+}
diff --git a/internal/authz/provider.go b/internal/authz/provider.go
index 1629cf4..42a8343 100644
--- a/internal/authz/provider.go
+++ b/internal/authz/provider.go
@@ -7,4 +7,5 @@ import "net/http"
type Provider interface {
WellKnownHandler() http.HandlerFunc
RegisterHandler() http.HandlerFunc
+ ProtectedResourceMetadataHandler() http.HandlerFunc
}
diff --git a/internal/authz/scope_validator.go b/internal/authz/scope_validator.go
new file mode 100644
index 0000000..004fd80
--- /dev/null
+++ b/internal/authz/scope_validator.go
@@ -0,0 +1,71 @@
+package authz
+
+import (
+ "fmt"
+ "net/http"
+ "strings"
+
+ "github.com/golang-jwt/jwt/v4"
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/util"
+)
+
+type ScopeValidator struct{}
+
+// Evaluate and checks the token claims against one or more required scopes.
+func (d *ScopeValidator) ValidateAccess(
+ r *http.Request,
+ claims *jwt.MapClaims,
+ config *config.Config,
+) AccessControlResult {
+ env, err := util.ParseRPCRequest(r)
+ if err != nil {
+ return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
+ }
+ requiredScopes := util.GetRequiredScopes(config, env.Method)
+ if len(requiredScopes) == 0 {
+ return AccessControlResult{DecisionAllow, ""}
+ }
+
+ required := make(map[string]struct{}, len(requiredScopes))
+ for _, s := range requiredScopes {
+ s = strings.TrimSpace(s)
+ if s != "" {
+ required[s] = struct{}{}
+ }
+ }
+
+ var tokenScopes []string
+ if claims, ok := (*claims)["scope"]; ok {
+ switch v := claims.(type) {
+ case string:
+ tokenScopes = strings.Fields(v)
+ case []interface{}:
+ for _, x := range v {
+ if s, ok := x.(string); ok && s != "" {
+ tokenScopes = append(tokenScopes, s)
+ }
+ }
+ }
+ }
+
+ tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
+ for _, s := range tokenScopes {
+ tokenScopeSet[s] = struct{}{}
+ }
+
+ var missing []string
+ for s := range required {
+ if _, ok := tokenScopeSet[s]; !ok {
+ missing = append(missing, s)
+ }
+ }
+
+ if len(missing) == 0 {
+ return AccessControlResult{DecisionAllow, ""}
+ }
+ return AccessControlResult{
+ DecisionDeny,
+ fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")),
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index fc6743c..8c47d8e 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -17,15 +17,15 @@ const (
// Common path configuration for all transport modes
type PathsConfig struct {
- SSE string `yaml:"sse"`
- Messages string `yaml:"messages"`
+ SSE string `yaml:"sse"`
+ Messages string `yaml:"messages"`
}
// StdioConfig contains stdio-specific configuration
type StdioConfig struct {
Enabled bool `yaml:"enabled"`
- UserCommand string `yaml:"user_command"` // The command provided by the user
- WorkDir string `yaml:"work_dir"` // Working directory (optional)
+ UserCommand string `yaml:"user_command"` // The command provided by the user
+ WorkDir string `yaml:"work_dir"` // Working directory (optional)
Args []string `yaml:"args,omitempty"` // Additional arguments
Env []string `yaml:"env,omitempty"` // Environment variables
}
@@ -83,23 +83,32 @@ type DefaultConfig struct {
}
type Config struct {
- AuthServerBaseURL string
- ListenPort int `yaml:"listen_port"`
- BaseURL string `yaml:"base_url"`
- Port int `yaml:"port"`
- JWKSURL string
- TimeoutSeconds int `yaml:"timeout_seconds"`
- PathMapping map[string]string `yaml:"path_mapping"`
- Mode string `yaml:"mode"`
- CORSConfig CORSConfig `yaml:"cors"`
- TransportMode TransportMode `yaml:"transport_mode"`
- Paths PathsConfig `yaml:"paths"`
- Stdio StdioConfig `yaml:"stdio"`
+ AuthServerBaseURL string
+ ListenPort int `yaml:"listen_port"`
+ BaseURL string `yaml:"base_url"`
+ Port int `yaml:"port"`
+ JWKSURL string
+ TimeoutSeconds int `yaml:"timeout_seconds"`
+ PathMapping map[string]string `yaml:"path_mapping"`
+ Mode string `yaml:"mode"`
+ CORSConfig CORSConfig `yaml:"cors"`
+ TransportMode TransportMode `yaml:"transport_mode"`
+ Paths PathsConfig `yaml:"paths"`
+ Stdio StdioConfig `yaml:"stdio"`
+ RequiredScopes map[string]string `yaml:"required_scopes"`
// Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"`
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
Default DefaultConfig `yaml:"default"`
+
+ // Protected resource metadata
+ Audience string `yaml:"audience"`
+ ResourceIdentifier string `yaml:"resource_identifier"`
+ ScopesSupported any `yaml:"scopes_supported"`
+ AuthorizationServers []string `yaml:"authorization_servers"`
+ JwksURI string `yaml:"jwks_uri,omitempty"`
+ BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
}
// Validate checks if the config is valid based on transport mode
@@ -165,12 +174,12 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil {
return nil, err
}
-
+
// Set default values
if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default
}
-
+
// Set default transport mode if not specified
if cfg.TransportMode == "" {
cfg.TransportMode = SSETransport // Default to SSE
@@ -180,11 +189,11 @@ func LoadConfig(path string) (*Config, error) {
if cfg.Port == 0 {
cfg.Port = 8000 // default
}
-
+
// Validate the configuration
if err := cfg.Validate(); err != nil {
return nil, err
}
-
+
return &cfg, nil
}
diff --git a/internal/constants/constants.go b/internal/constants/constants.go
index 1e5808e..e7b1bec 100644
--- a/internal/constants/constants.go
+++ b/internal/constants/constants.go
@@ -1,7 +1,14 @@
package constants
+import "time"
+
// Package constant provides constants for the MCP Auth Proxy
const (
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
)
+
+// MCP specification version cutover date
+var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC)
+
+const TimeLayout = "2006-01-02"
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index 33a9ea3..b880f99 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -2,6 +2,7 @@ package proxy
import (
"context"
+ "fmt"
"net/http"
"net/http/httputil"
"net/url"
@@ -17,7 +18,7 @@ import (
// NewRouter builds an http.ServeMux that routes
// * /authorize, /token, /register, /.well-known to the provider or proxy
// * MCP paths to the MCP server, etc.
-func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
+func NewRouter(cfg *config.Config, provider authz.Provider, accessController authz.AccessControl) http.Handler {
mux := http.NewServeMux()
modifiers := map[string]RequestModifier{
@@ -63,6 +64,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
}
}
+ mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
+ origin := r.Header.Get("Origin")
+ allowed := getAllowedOrigin(origin, cfg)
+ if r.Method == http.MethodOptions {
+ addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers"))
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ addCORSHeaders(w, cfg, allowed, "")
+ provider.ProtectedResourceMetadataHandler()(w, r)
+ })
+ registeredPaths["/.well-known/oauth-protected-resource"] = true
+
// Remove duplicates from defaultPaths
uniquePaths := make(map[string]bool)
cleanPaths := []string{}
@@ -76,7 +91,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
for _, path := range defaultPaths {
if !registeredPaths[path] {
- mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
+ mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
registeredPaths[path] = true
}
}
@@ -84,14 +99,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
// MCP paths
mcpPaths := cfg.GetMCPPaths()
for _, path := range mcpPaths {
- mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
+ mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
registeredPaths[path] = true
}
// Register paths from PathMapping that haven't been registered yet
for path := range cfg.PathMapping {
if !registeredPaths[path] {
- mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
+ mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
registeredPaths[path] = true
}
}
@@ -99,14 +114,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
return mux
}
-func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
+func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, accessController authz.AccessControl) http.HandlerFunc {
// Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil {
logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
-
+
mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil {
logger.Error("Invalid MCP server URL: %v", err)
@@ -141,6 +156,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Add CORS headers to all responses
addCORSHeaders(w, cfg, allowedOrigin, "")
+ // Check if the request is for the latest spec
+ specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version"))
+ ver, err := util.ParseVersionDate(specVersion)
+ isLatestSpec := util.IsLatestSpec(ver, err)
+
// Decide whether the request should go to the auth server or MCP
var targetURL *url.URL
isSSE := false
@@ -148,13 +168,19 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
if isAuthPath(r.URL.Path) {
targetURL = authBase
} else if isMCPPath(r.URL.Path, cfg) {
- // Validate JWT for MCP paths if required
- // Placeholder for JWT validation logic
- if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
- logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
- return
+ if ssePaths[r.URL.Path] {
+ if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil {
+ http.Error(w, err.Error(), http.StatusUnauthorized)
+ return
+ }
+ isSSE = true
+ } else {
+ if err := authorizeMCP(w, r, isLatestSpec, cfg, accessController); err != nil {
+ http.Error(w, err.Error(), http.StatusForbidden)
+ return
+ }
}
+
targetURL = mcpBase
if ssePaths[r.URL.Path] {
isSSE = true
@@ -191,13 +217,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Host = targetURL.Host
cleanHeaders := http.Header{}
-
+
// Set proper origin header to match the target
if isSSE {
// For SSE, ensure origin matches the target
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
}
-
+
for k, v := range r.Header {
// Skip hop-by-hop headers
if skipHeader(k) {
@@ -214,7 +240,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
},
ModifyResponse: func(resp *http.Response) error {
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
- resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
+ if resp.StatusCode == http.StatusUnauthorized {
+ resp.Header.Set(
+ "WWW-Authenticate",
+ fmt.Sprintf(
+ `Bearer resource_metadata="%s"`,
+ cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
+ ))
+ resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
+ }
+
+ resp.Header.Del("Access-Control-Allow-Origin")
return nil
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
@@ -231,12 +267,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
proxyHost: r.Host,
targetHost: targetURL.Host,
}
-
+
// Set SSE-specific headers
w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
-
+ w.Header().Set("Content-Type", "text/event-stream")
// Keep SSE connections open
HandleSSE(w, r, rp)
} else {
@@ -248,6 +284,76 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
}
}
+// Check if the request is for SSE handshake and authorize it
+func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
+ authHeader := r.Header.Get("Authorization")
+ if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
+ if isLatestSpec {
+ realm := resourceID + "/.well-known/oauth-protected-resource"
+ w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
+ w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
+ }
+ return fmt.Errorf("missing or invalid Authorization header")
+ }
+
+ return nil
+}
+
+// Handles both v1 (just signature) and v2 (aud + scope) flows
+func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error {
+ authzHeader := r.Header.Get("Authorization")
+ accessToken, _ := util.ExtractAccessToken(authzHeader)
+ if !strings.HasPrefix(authzHeader, "Bearer ") {
+ if isLatestSpec {
+ realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
+ w.Header().Set("WWW-Authenticate", fmt.Sprintf(
+ `Bearer resource_metadata=%q`, realm,
+ ))
+ w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
+ }
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ return fmt.Errorf("missing or invalid Authorization header")
+ }
+
+ err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience)
+ if err != nil {
+ if isLatestSpec {
+ realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
+ w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
+ `Bearer realm=%q`,
+ realm,
+ ))
+ w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
+ http.Error(w, "Forbidden", http.StatusForbidden)
+ } else {
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ }
+ return err
+ }
+
+ if isLatestSpec {
+ _, err := util.ParseRPCRequest(r)
+ if err != nil {
+ http.Error(w, "Bad request", http.StatusBadRequest)
+ return err
+ }
+
+ claimsMap, err := util.ParseJWT(accessToken)
+ if err != nil {
+ http.Error(w, "Invalid token claims", http.StatusUnauthorized)
+ return fmt.Errorf("invalid token claims")
+ }
+
+ pr := accessController.ValidateAccess(r, &claimsMap, cfg)
+ if pr.Decision == authz.DecisionDeny {
+ http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
+ return fmt.Errorf("forbidden โ %s", pr.Message)
+ }
+ }
+
+ return nil
+}
+
func getAllowedOrigin(origin string, cfg *config.Config) string {
if origin == "" {
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
@@ -265,6 +371,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
+ w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version")
if requestHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
} else {
@@ -272,6 +379,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
}
if cfg.CORSConfig.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
+ w.Header().Set("MCP-Protocol-Version", ", ")
}
w.Header().Set("Vary", "Origin")
w.Header().Set("X-Accel-Buffering", "no")
@@ -283,6 +391,7 @@ func isAuthPath(path string) bool {
"/token": true,
"/register": true,
"/.well-known/oauth-authorization-server": true,
+ "/.well-known/oauth-protected-resource": true,
}
if strings.HasPrefix(path, "/u/") {
return true
diff --git a/internal/util/jwks.go b/internal/util/jwks.go
index f80d82e..b1afb6f 100644
--- a/internal/util/jwks.go
+++ b/internal/util/jwks.go
@@ -4,21 +4,27 @@ import (
"crypto/rsa"
"encoding/json"
"errors"
+ "fmt"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
+type TokenClaims struct {
+ Scopes []string
+}
+
type JWKS struct {
Keys []json.RawMessage `json:"keys"`
}
var publicKeys map[string]*rsa.PublicKey
-// FetchJWKS downloads JWKS and stores in a package-level map
+// FetchJWKS downloads JWKS and stores in a packageโlevel map
func FetchJWKS(jwksURL string) error {
resp, err := http.Get(jwksURL)
if err != nil {
@@ -31,23 +37,23 @@ func FetchJWKS(jwksURL string) error {
return err
}
- publicKeys = make(map[string]*rsa.PublicKey)
+ publicKeys = make(map[string]*rsa.PublicKey, len(jwks.Keys))
for _, keyData := range jwks.Keys {
- var parsedKey struct {
+ var parsed struct {
Kid string `json:"kid"`
N string `json:"n"`
E string `json:"e"`
Kty string `json:"kty"`
}
- if err := json.Unmarshal(keyData, &parsedKey); err != nil {
+ if err := json.Unmarshal(keyData, &parsed); err != nil {
continue
}
- if parsedKey.Kty != "RSA" {
+ if parsed.Kty != "RSA" {
continue
}
- pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
+ pubKey, err := parseRSAPublicKey(parsed.N, parsed.E)
if err == nil {
- publicKeys[parsedKey.Kid] = pubKey
+ publicKeys[parsed.Kid] = pubKey
}
}
logger.Info("Loaded %d public keys.", len(publicKeys))
@@ -73,25 +79,130 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
return &rsa.PublicKey{N: n, E: e}, nil
}
-// ValidateJWT checks the Authorization: Bearer token using stored JWKS
-func ValidateJWT(authHeader string) error {
- if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
- return errors.New("missing or invalid Authorization header")
- }
- tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
- token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
- kid, _ := token.Header["kid"].(string)
- pubKey, ok := publicKeys[kid]
- if !ok {
- return nil, errors.New("unknown or missing kid in token header")
+// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
+func ValidateJWT(
+ isLatestSpec bool,
+ accessToken string,
+ audience string,
+) error {
+ logger.Warn("isLatestSpec: %s", isLatestSpec)
+ // Parse & verify the signature
+ token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
+ if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
- return pubKey, nil
+ kid, ok := token.Header["kid"].(string)
+ if !ok {
+ return nil, errors.New("kid header not found")
+ }
+ key, ok := publicKeys[kid]
+ if !ok {
+ return nil, fmt.Errorf("key not found for kid: %s", kid)
+ }
+ return key, nil
})
if err != nil {
- return errors.New("invalid token: " + err.Error())
+ logger.Warn("Error detected, returning early")
+ return fmt.Errorf("invalid token: %w", err)
}
if !token.Valid {
- return errors.New("invalid token: token not valid")
+ logger.Warn("Token invalid, returning early")
+ return errors.New("token not valid")
}
+
+ claimsMap, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ return errors.New("unexpected claim type")
+ }
+
+ if !isLatestSpec {
+ return nil
+ }
+
+ audRaw, exists := claimsMap["aud"]
+ if !exists {
+ return errors.New("aud claim missing")
+ }
+ switch v := audRaw.(type) {
+ case string:
+ if v != audience {
+ return fmt.Errorf("aud %q does not match %q", v, audience)
+ }
+ case []interface{}:
+ var found bool
+ for _, a := range v {
+ if s, ok := a.(string); ok && s == audience {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return fmt.Errorf("audience %v does not include %q", v, audience)
+ }
+ default:
+ return errors.New("aud claim has unexpected type")
+ }
+
return nil
}
+
+// Parses the JWT token and returns the claims
+func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
+ if tokenStr == "" {
+ return nil, fmt.Errorf("empty JWT")
+ }
+
+ var claims jwt.MapClaims
+ _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse JWT: %w", err)
+ }
+ return claims, nil
+}
+
+// Process the required scopes
+func GetRequiredScopes(cfg *config.Config, method string) []string {
+ switch raw := cfg.ScopesSupported.(type) {
+ case map[string]string:
+ if scope, ok := raw[method]; ok {
+ return []string{scope}
+ }
+ parts := strings.SplitN(method, "/", 2)
+ if len(parts) > 0 {
+ if scope, ok := raw[parts[0]]; ok {
+ return []string{scope}
+ }
+ }
+ return nil
+ case []interface{}:
+ out := make([]string, 0, len(raw))
+ for _, v := range raw {
+ if s, ok := v.(string); ok && s != "" {
+ out = append(out, s)
+ }
+ }
+ return out
+
+ case []string:
+ return raw
+ }
+
+ return nil
+}
+
+// Extracts the Bearer token from the Authorization header
+func ExtractAccessToken(authHeader string) (string, error) {
+ if authHeader == "" {
+ return "", errors.New("empty authorization header")
+ }
+ if !strings.HasPrefix(authHeader, "Bearer ") {
+ return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
+ }
+
+ tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
+ if tokenStr == "" {
+ return "", errors.New("empty bearer token")
+ }
+
+ return tokenStr, nil
+}
diff --git a/internal/util/rpc.go b/internal/util/rpc.go
new file mode 100644
index 0000000..5338437
--- /dev/null
+++ b/internal/util/rpc.go
@@ -0,0 +1,38 @@
+package util
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "net/http"
+
+ logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
+)
+
+type RPCEnvelope struct {
+ Method string `json:"method"`
+ Params any `json:"params"`
+ ID any `json:"id"`
+}
+
+// This function parses a JSON-RPC request from an HTTP request body
+func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
+ bodyBytes, err := io.ReadAll(r.Body)
+ if err != nil {
+ return nil, err
+ }
+ r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+ if len(bodyBytes) == 0 {
+ return nil, nil
+ }
+
+ var env RPCEnvelope
+ dec := json.NewDecoder(bytes.NewReader(bodyBytes))
+ if err := dec.Decode(&env); err != nil && err != io.EOF {
+ logger.Warn("Error parsing JSON-RPC envelope: %v", err)
+ return nil, err
+ }
+
+ return &env, nil
+}
diff --git a/internal/util/version.go b/internal/util/version.go
new file mode 100644
index 0000000..230ef1d
--- /dev/null
+++ b/internal/util/version.go
@@ -0,0 +1,26 @@
+package util
+
+import (
+ "time"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/constants"
+)
+
+// This function checks if the given version date is after the spec cutover date
+func IsLatestSpec(versionDate time.Time, err error) bool {
+ return err == nil && versionDate.After(constants.SpecCutoverDate)
+}
+
+// This function parses a version string into a time.Time
+func ParseVersionDate(version string) (time.Time, error) {
+ return time.Parse("2006-01-02", version)
+}
+
+// This function returns the version string, using the cutover date if empty
+func GetVersionWithDefault(version string) string {
+ if version == "" {
+ defaultTime, _ := time.Parse(constants.TimeLayout, "2025-05-15")
+ return defaultTime.Format(constants.TimeLayout)
+ }
+ return version
+}