mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-28 17:34:19 +00:00
Merge 1964829dcd
into 56cdc96cb6
This commit is contained in:
commit
5e1ea3aaf1
15 changed files with 609 additions and 100 deletions
59
README.md
59
README.md
|
@ -10,16 +10,31 @@ A lightweight authorization proxy for Model Context Protocol (MCP) servers that
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## What it Does
|
## 🛡️ What it Does?
|
||||||
|
|
||||||
Open MCP Auth Proxy sits between MCP clients and your MCP server to:
|
|
||||||
|
|
||||||
- Intercept incoming requests
|
- Intercept incoming requests
|
||||||
- Validate authorization tokens
|
- Validate authorization tokens
|
||||||
- Offload authentication and authorization to OAuth-compliant Identity Providers
|
- Offload authentication and authorization to OAuth-compliant Identity Providers
|
||||||
- Support the MCP authorization protocol
|
- Support the MCP authorization protocol
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
## 🚀 Features
|
||||||
|
|
||||||
|
- **Dynamic Authorization** based on MCP Authorization Specification.
|
||||||
|
- **JWT Validation** (signature, audience, and scopes).
|
||||||
|
- **Identity Provider Integration** (OAuth/OIDC via Asgardeo, Auth0, Keycloak).
|
||||||
|
- **Protocol Version Negotiation** via `MCP-Protocol-Version` header.
|
||||||
|
- **Comprehensive Authentication Feedback** via RFC-compliant challenges.
|
||||||
|
- **Flexible Transport Modes**: SSE and stdio.
|
||||||
|
|
||||||
|
## 📌 MCP Specification Verions
|
||||||
|
|
||||||
|
| Version | Behavior |
|
||||||
|
| :-------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| 2025-03-26 | Only signature check of Bearer JWT on both `/sse` and `/message`<br> No scope or audience enforcement |
|
||||||
|
| Latest(draft) | Read `MCP-Protocol-Version` from client header<br> SSE handshake returns `WWW-Authenticate: Bearer resource_metadata="…"`<br> `/message` enforces:<br>`aud` claim == `ResourceIdentifier`<br>`scope` claim contains `requiredScope`<br>Scope based access control<br>Rich `WWW-Authenticate` on 401s<br>Serves `/.well-known/oauth-protected-resource` JSON |
|
||||||
|
|
||||||
|
## 🛠️ Quick Start
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
|
@ -67,7 +82,7 @@ Open MCP Auth Proxy sits between MCP clients and your MCP server to:
|
||||||
|
|
||||||
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
||||||
|
|
||||||
## Connect an Identity Provider
|
## 🔒 Integrate an Identity Provider
|
||||||
|
|
||||||
### Asgardeo
|
### Asgardeo
|
||||||
|
|
||||||
|
@ -81,13 +96,17 @@ To enable authorization through your Asgardeo organization:
|
||||||
3. Update `config.yaml` with the following parameters.
|
3. Update `config.yaml` with the following parameters.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_url: "http://localhost:8000" # URL of your MCP server
|
base_url: "http://localhost:8000" # URL of your MCP server
|
||||||
listen_port: 8080 # Address where the proxy will listen
|
listen_port: 8080 # Address where the proxy will listen
|
||||||
|
|
||||||
asgardeo:
|
resource_identifier: "http://localhost:8080" # Proxy server URL
|
||||||
org_name: "<org_name>" # Your Asgardeo org name
|
scopes_supported: # Scopes required to defined for the MCP server
|
||||||
client_id: "<client_id>" # Client ID of the M2M app
|
- "read:tools"
|
||||||
client_secret: "<client_secret>" # Client secret of the M2M app
|
- "read:resources"
|
||||||
|
audience: "<audience_value>" # Access token audience
|
||||||
|
authorization_servers: # Authorization server issuer identifier(s)
|
||||||
|
- "https://api.asgardeo.io/t/acme"
|
||||||
|
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" # JWKS URL
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Start the proxy with Asgardeo integration:
|
4. Start the proxy with Asgardeo integration:
|
||||||
|
@ -101,7 +120,7 @@ asgardeo:
|
||||||
- [Auth0](docs/integrations/Auth0.md)
|
- [Auth0](docs/integrations/Auth0.md)
|
||||||
- [Keycloak](docs/integrations/keycloak.md)
|
- [Keycloak](docs/integrations/keycloak.md)
|
||||||
|
|
||||||
# Advanced Configuration
|
# ⚙️ Advanced Configuration
|
||||||
|
|
||||||
### Transport Modes
|
### Transport Modes
|
||||||
|
|
||||||
|
@ -167,7 +186,7 @@ The proxy will:
|
||||||
- Handle all authorization requirements
|
- Handle all authorization requirements
|
||||||
- Forward messages between clients and the server
|
- Forward messages between clients and the server
|
||||||
|
|
||||||
### Complete Configuration Reference
|
### 📝 Complete Configuration Reference
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Common configuration
|
# Common configuration
|
||||||
|
@ -210,13 +229,17 @@ demo:
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
# Asgardeo configuration (used with --asgardeo flag)
|
# Asgardeo configuration (used with --asgardeo flag)
|
||||||
asgardeo:
|
resource_identifier: "http://localhost:8080"
|
||||||
org_name: "<org_name>"
|
scopes_supported:
|
||||||
client_id: "<client_id>"
|
- "read:tools"
|
||||||
client_secret: "<client_secret>"
|
- "read:resources"
|
||||||
|
audience: "<audience_value>"
|
||||||
|
authorization_servers:
|
||||||
|
- "https://api.asgardeo.io/t/acme"
|
||||||
|
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Build from source
|
### 🖥️ Build from source
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||||
|
|
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
||||||
|
@ -68,23 +67,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Create the chosen provider
|
// 3. Create the chosen provider
|
||||||
var provider authz.Provider
|
var provider authz.Provider = MakeProvider(cfg, *demoMode, *asgardeoMode)
|
||||||
if *demoMode {
|
|
||||||
cfg.Mode = "demo"
|
|
||||||
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
|
|
||||||
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
|
|
||||||
provider = authz.NewAsgardeoProvider(cfg)
|
|
||||||
} else if *asgardeoMode {
|
|
||||||
cfg.Mode = "asgardeo"
|
|
||||||
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
|
|
||||||
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
|
|
||||||
provider = authz.NewAsgardeoProvider(cfg)
|
|
||||||
} else {
|
|
||||||
cfg.Mode = "default"
|
|
||||||
cfg.JWKSURL = cfg.Default.JWKSURL
|
|
||||||
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
|
||||||
provider = authz.NewDefaultProvider(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
||||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||||
|
@ -92,12 +75,15 @@ func main() {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Build the main router
|
// 5. (Optional) Build the access controler
|
||||||
mux := proxy.NewRouter(cfg, provider)
|
accessController := &authz.ScopeValidator{}
|
||||||
|
|
||||||
|
// 6. Build the main router
|
||||||
|
mux := proxy.NewRouter(cfg, provider, accessController)
|
||||||
|
|
||||||
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||||
|
|
||||||
// 6. Start the server
|
// 7. Start the server
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: listen_address,
|
Addr: listen_address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
|
@ -111,18 +97,18 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 7. Wait for shutdown signal
|
// 8. Wait for shutdown signal
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||||
<-stop
|
<-stop
|
||||||
logger.Info("Shutting down...")
|
logger.Info("Shutting down...")
|
||||||
|
|
||||||
// 8. First terminate subprocess if running
|
// 9. First terminate subprocess if running
|
||||||
if procManager != nil && procManager.IsRunning() {
|
if procManager != nil && procManager.IsRunning() {
|
||||||
procManager.Shutdown()
|
procManager.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 9. Then shutdown the server
|
// 10. Then shutdown the server
|
||||||
logger.Info("Shutting down HTTP server...")
|
logger.Info("Shutting down HTTP server...")
|
||||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
45
cmd/proxy/provider.go
Normal file
45
cmd/proxy/provider.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
|
||||||
|
var mode, orgName string
|
||||||
|
switch {
|
||||||
|
case demoMode:
|
||||||
|
mode = "demo"
|
||||||
|
orgName = cfg.Demo.OrgName
|
||||||
|
case asgardeoMode:
|
||||||
|
mode = "asgardeo"
|
||||||
|
orgName = cfg.Asgardeo.OrgName
|
||||||
|
default:
|
||||||
|
mode = "default"
|
||||||
|
}
|
||||||
|
cfg.Mode = mode
|
||||||
|
|
||||||
|
switch mode {
|
||||||
|
case "demo", "asgardeo":
|
||||||
|
if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" {
|
||||||
|
base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
|
||||||
|
cfg.AuthServerBaseURL = base
|
||||||
|
cfg.JWKSURL = base + "/jwks"
|
||||||
|
} else {
|
||||||
|
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
|
||||||
|
cfg.JWKSURL = cfg.JwksURI
|
||||||
|
}
|
||||||
|
return authz.NewAsgardeoProvider(cfg)
|
||||||
|
|
||||||
|
default:
|
||||||
|
if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
|
||||||
|
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||||
|
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||||
|
} else if len(cfg.AuthorizationServers) > 0 {
|
||||||
|
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
|
||||||
|
cfg.JWKSURL = cfg.JwksURI
|
||||||
|
}
|
||||||
|
return authz.NewDefaultProvider(cfg)
|
||||||
|
}
|
||||||
|
}
|
15
config.yaml
15
config.yaml
|
@ -45,3 +45,18 @@ demo:
|
||||||
org_name: "openmcpauthdemo"
|
org_name: "openmcpauthdemo"
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
|
# Protected resource metadata
|
||||||
|
resource_identifier: http://localhost:3000
|
||||||
|
audience: mcp_proxy
|
||||||
|
scopes_supported:
|
||||||
|
- "tools":"read:tools"
|
||||||
|
- "resources":"read:resources"
|
||||||
|
- "prompts":"read:prompts"
|
||||||
|
authorization_servers:
|
||||||
|
- https://api.asgardeo.io/t/acme/
|
||||||
|
jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks
|
||||||
|
bearer_methods_supported:
|
||||||
|
- header
|
||||||
|
- body
|
||||||
|
- query
|
||||||
|
|
24
internal/authz/access_control.go
Normal file
24
internal/authz/access_control.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Decision int
|
||||||
|
|
||||||
|
const (
|
||||||
|
DecisionAllow Decision = iota
|
||||||
|
DecisionDeny
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccessControlResult struct {
|
||||||
|
Decision Decision
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccessControl interface {
|
||||||
|
ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult
|
||||||
|
}
|
|
@ -350,3 +350,23 @@ func randomString(n int) string {
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
meta := map[string]interface{}{
|
||||||
|
"resource": p.cfg.ResourceIdentifier,
|
||||||
|
"scopes_supported": p.cfg.ScopesSupported,
|
||||||
|
"authorization_servers": p.cfg.AuthorizationServers,
|
||||||
|
}
|
||||||
|
if p.cfg.JwksURI != "" {
|
||||||
|
meta["jwks_uri"] = p.cfg.JwksURI
|
||||||
|
}
|
||||||
|
if len(p.cfg.BearerMethodsSupported) > 0 {
|
||||||
|
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||||
|
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -94,3 +94,27 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
|
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
meta := map[string]interface{}{
|
||||||
|
"audience": p.cfg.Audience,
|
||||||
|
"resource": p.cfg.ResourceIdentifier,
|
||||||
|
"scopes_supported": p.cfg.ScopesSupported,
|
||||||
|
"authorization_servers": p.cfg.AuthorizationServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.cfg.JwksURI != "" {
|
||||||
|
meta["jwks_uri"] = p.cfg.JwksURI
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.cfg.BearerMethodsSupported) > 0 {
|
||||||
|
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||||
|
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -7,4 +7,5 @@ import "net/http"
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
WellKnownHandler() http.HandlerFunc
|
WellKnownHandler() http.HandlerFunc
|
||||||
RegisterHandler() http.HandlerFunc
|
RegisterHandler() http.HandlerFunc
|
||||||
|
ProtectedResourceMetadataHandler() http.HandlerFunc
|
||||||
}
|
}
|
||||||
|
|
71
internal/authz/scope_validator.go
Normal file
71
internal/authz/scope_validator.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ScopeValidator struct{}
|
||||||
|
|
||||||
|
// Evaluate and checks the token claims against one or more required scopes.
|
||||||
|
func (d *ScopeValidator) ValidateAccess(
|
||||||
|
r *http.Request,
|
||||||
|
claims *jwt.MapClaims,
|
||||||
|
config *config.Config,
|
||||||
|
) AccessControlResult {
|
||||||
|
env, err := util.ParseRPCRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
|
||||||
|
}
|
||||||
|
requiredScopes := util.GetRequiredScopes(config, env.Method)
|
||||||
|
if len(requiredScopes) == 0 {
|
||||||
|
return AccessControlResult{DecisionAllow, ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
required := make(map[string]struct{}, len(requiredScopes))
|
||||||
|
for _, s := range requiredScopes {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s != "" {
|
||||||
|
required[s] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenScopes []string
|
||||||
|
if claims, ok := (*claims)["scope"]; ok {
|
||||||
|
switch v := claims.(type) {
|
||||||
|
case string:
|
||||||
|
tokenScopes = strings.Fields(v)
|
||||||
|
case []interface{}:
|
||||||
|
for _, x := range v {
|
||||||
|
if s, ok := x.(string); ok && s != "" {
|
||||||
|
tokenScopes = append(tokenScopes, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
|
||||||
|
for _, s := range tokenScopes {
|
||||||
|
tokenScopeSet[s] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var missing []string
|
||||||
|
for s := range required {
|
||||||
|
if _, ok := tokenScopeSet[s]; !ok {
|
||||||
|
missing = append(missing, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missing) == 0 {
|
||||||
|
return AccessControlResult{DecisionAllow, ""}
|
||||||
|
}
|
||||||
|
return AccessControlResult{
|
||||||
|
DecisionDeny,
|
||||||
|
fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")),
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,15 +17,15 @@ const (
|
||||||
|
|
||||||
// Common path configuration for all transport modes
|
// Common path configuration for all transport modes
|
||||||
type PathsConfig struct {
|
type PathsConfig struct {
|
||||||
SSE string `yaml:"sse"`
|
SSE string `yaml:"sse"`
|
||||||
Messages string `yaml:"messages"`
|
Messages string `yaml:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StdioConfig contains stdio-specific configuration
|
// StdioConfig contains stdio-specific configuration
|
||||||
type StdioConfig struct {
|
type StdioConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
UserCommand string `yaml:"user_command"` // The command provided by the user
|
UserCommand string `yaml:"user_command"` // The command provided by the user
|
||||||
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
||||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
Args []string `yaml:"args,omitempty"` // Additional arguments
|
||||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
Env []string `yaml:"env,omitempty"` // Environment variables
|
||||||
}
|
}
|
||||||
|
@ -83,23 +83,32 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
BaseURL string `yaml:"base_url"`
|
BaseURL string `yaml:"base_url"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
JWKSURL string
|
JWKSURL string
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
Mode string `yaml:"mode"`
|
Mode string `yaml:"mode"`
|
||||||
CORSConfig CORSConfig `yaml:"cors"`
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
TransportMode TransportMode `yaml:"transport_mode"`
|
TransportMode TransportMode `yaml:"transport_mode"`
|
||||||
Paths PathsConfig `yaml:"paths"`
|
Paths PathsConfig `yaml:"paths"`
|
||||||
Stdio StdioConfig `yaml:"stdio"`
|
Stdio StdioConfig `yaml:"stdio"`
|
||||||
|
RequiredScopes map[string]string `yaml:"required_scopes"`
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
|
|
||||||
|
// Protected resource metadata
|
||||||
|
Audience string `yaml:"audience"`
|
||||||
|
ResourceIdentifier string `yaml:"resource_identifier"`
|
||||||
|
ScopesSupported any `yaml:"scopes_supported"`
|
||||||
|
AuthorizationServers []string `yaml:"authorization_servers"`
|
||||||
|
JwksURI string `yaml:"jwks_uri,omitempty"`
|
||||||
|
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks if the config is valid based on transport mode
|
// Validate checks if the config is valid based on transport mode
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
package constants
|
package constants
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
// Package constant provides constants for the MCP Auth Proxy
|
// Package constant provides constants for the MCP Auth Proxy
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
|
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MCP specification version cutover date
|
||||||
|
var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
const TimeLayout = "2006-01-02"
|
||||||
|
|
|
@ -2,6 +2,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -17,7 +18,7 @@ import (
|
||||||
// NewRouter builds an http.ServeMux that routes
|
// NewRouter builds an http.ServeMux that routes
|
||||||
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
||||||
// * MCP paths to the MCP server, etc.
|
// * MCP paths to the MCP server, etc.
|
||||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
func NewRouter(cfg *config.Config, provider authz.Provider, accessController authz.AccessControl) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
modifiers := map[string]RequestModifier{
|
modifiers := map[string]RequestModifier{
|
||||||
|
@ -63,6 +64,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
allowed := getAllowedOrigin(origin, cfg)
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers"))
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
addCORSHeaders(w, cfg, allowed, "")
|
||||||
|
provider.ProtectedResourceMetadataHandler()(w, r)
|
||||||
|
})
|
||||||
|
registeredPaths["/.well-known/oauth-protected-resource"] = true
|
||||||
|
|
||||||
// Remove duplicates from defaultPaths
|
// Remove duplicates from defaultPaths
|
||||||
uniquePaths := make(map[string]bool)
|
uniquePaths := make(map[string]bool)
|
||||||
cleanPaths := []string{}
|
cleanPaths := []string{}
|
||||||
|
@ -76,7 +91,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
|
|
||||||
for _, path := range defaultPaths {
|
for _, path := range defaultPaths {
|
||||||
if !registeredPaths[path] {
|
if !registeredPaths[path] {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,14 +99,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
// MCP paths
|
// MCP paths
|
||||||
mcpPaths := cfg.GetMCPPaths()
|
mcpPaths := cfg.GetMCPPaths()
|
||||||
for _, path := range mcpPaths {
|
for _, path := range mcpPaths {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register paths from PathMapping that haven't been registered yet
|
// Register paths from PathMapping that haven't been registered yet
|
||||||
for path := range cfg.PathMapping {
|
for path := range cfg.PathMapping {
|
||||||
if !registeredPaths[path] {
|
if !registeredPaths[path] {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -99,7 +114,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, accessController authz.AccessControl) http.HandlerFunc {
|
||||||
// Parse the base URLs up front
|
// Parse the base URLs up front
|
||||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -141,6 +156,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Add CORS headers to all responses
|
// Add CORS headers to all responses
|
||||||
addCORSHeaders(w, cfg, allowedOrigin, "")
|
addCORSHeaders(w, cfg, allowedOrigin, "")
|
||||||
|
|
||||||
|
// Check if the request is for the latest spec
|
||||||
|
specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version"))
|
||||||
|
ver, err := util.ParseVersionDate(specVersion)
|
||||||
|
isLatestSpec := util.IsLatestSpec(ver, err)
|
||||||
|
|
||||||
// Decide whether the request should go to the auth server or MCP
|
// Decide whether the request should go to the auth server or MCP
|
||||||
var targetURL *url.URL
|
var targetURL *url.URL
|
||||||
isSSE := false
|
isSSE := false
|
||||||
|
@ -148,13 +168,19 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
if isAuthPath(r.URL.Path) {
|
if isAuthPath(r.URL.Path) {
|
||||||
targetURL = authBase
|
targetURL = authBase
|
||||||
} else if isMCPPath(r.URL.Path, cfg) {
|
} else if isMCPPath(r.URL.Path, cfg) {
|
||||||
// Validate JWT for MCP paths if required
|
if ssePaths[r.URL.Path] {
|
||||||
// Placeholder for JWT validation logic
|
if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil {
|
||||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
return
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
}
|
||||||
return
|
isSSE = true
|
||||||
|
} else {
|
||||||
|
if err := authorizeMCP(w, r, isLatestSpec, cfg, accessController); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
targetURL = mcpBase
|
targetURL = mcpBase
|
||||||
if ssePaths[r.URL.Path] {
|
if ssePaths[r.URL.Path] {
|
||||||
isSSE = true
|
isSSE = true
|
||||||
|
@ -214,7 +240,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
},
|
},
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
resp.Header.Set(
|
||||||
|
"WWW-Authenticate",
|
||||||
|
fmt.Sprintf(
|
||||||
|
`Bearer resource_metadata="%s"`,
|
||||||
|
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
|
||||||
|
))
|
||||||
|
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Header.Del("Access-Control-Allow-Origin")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
@ -236,7 +272,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
w.Header().Set("Connection", "keep-alive")
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
// Keep SSE connections open
|
// Keep SSE connections open
|
||||||
HandleSSE(w, r, rp)
|
HandleSSE(w, r, rp)
|
||||||
} else {
|
} else {
|
||||||
|
@ -248,6 +284,76 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the request is for SSE handshake and authorize it
|
||||||
|
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
if isLatestSpec {
|
||||||
|
realm := resourceID + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("missing or invalid Authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||||
|
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error {
|
||||||
|
authzHeader := r.Header.Get("Authorization")
|
||||||
|
accessToken, _ := util.ExtractAccessToken(authzHeader)
|
||||||
|
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
||||||
|
if isLatestSpec {
|
||||||
|
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||||
|
`Bearer resource_metadata=%q`, realm,
|
||||||
|
))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
}
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return fmt.Errorf("missing or invalid Authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience)
|
||||||
|
if err != nil {
|
||||||
|
if isLatestSpec {
|
||||||
|
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
|
||||||
|
`Bearer realm=%q`,
|
||||||
|
realm,
|
||||||
|
))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isLatestSpec {
|
||||||
|
_, err := util.ParseRPCRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
claimsMap, err := util.ParseJWT(accessToken)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||||
|
return fmt.Errorf("invalid token claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
pr := accessController.ValidateAccess(r, &claimsMap, cfg)
|
||||||
|
if pr.Decision == authz.DecisionDeny {
|
||||||
|
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
||||||
|
return fmt.Errorf("forbidden — %s", pr.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
if origin == "" {
|
if origin == "" {
|
||||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
|
@ -265,6 +371,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version")
|
||||||
if requestHeaders != "" {
|
if requestHeaders != "" {
|
||||||
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||||
} else {
|
} else {
|
||||||
|
@ -272,6 +379,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
||||||
}
|
}
|
||||||
if cfg.CORSConfig.AllowCredentials {
|
if cfg.CORSConfig.AllowCredentials {
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
w.Header().Set("MCP-Protocol-Version", ", ")
|
||||||
}
|
}
|
||||||
w.Header().Set("Vary", "Origin")
|
w.Header().Set("Vary", "Origin")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
|
@ -283,6 +391,7 @@ func isAuthPath(path string) bool {
|
||||||
"/token": true,
|
"/token": true,
|
||||||
"/register": true,
|
"/register": true,
|
||||||
"/.well-known/oauth-authorization-server": true,
|
"/.well-known/oauth-authorization-server": true,
|
||||||
|
"/.well-known/oauth-protected-resource": true,
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(path, "/u/") {
|
if strings.HasPrefix(path, "/u/") {
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -4,21 +4,27 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type TokenClaims struct {
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
type JWKS struct {
|
type JWKS struct {
|
||||||
Keys []json.RawMessage `json:"keys"`
|
Keys []json.RawMessage `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var publicKeys map[string]*rsa.PublicKey
|
var publicKeys map[string]*rsa.PublicKey
|
||||||
|
|
||||||
// FetchJWKS downloads JWKS and stores in a package-level map
|
// FetchJWKS downloads JWKS and stores in a package‐level map
|
||||||
func FetchJWKS(jwksURL string) error {
|
func FetchJWKS(jwksURL string) error {
|
||||||
resp, err := http.Get(jwksURL)
|
resp, err := http.Get(jwksURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -31,23 +37,23 @@ func FetchJWKS(jwksURL string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKeys = make(map[string]*rsa.PublicKey)
|
publicKeys = make(map[string]*rsa.PublicKey, len(jwks.Keys))
|
||||||
for _, keyData := range jwks.Keys {
|
for _, keyData := range jwks.Keys {
|
||||||
var parsedKey struct {
|
var parsed struct {
|
||||||
Kid string `json:"kid"`
|
Kid string `json:"kid"`
|
||||||
N string `json:"n"`
|
N string `json:"n"`
|
||||||
E string `json:"e"`
|
E string `json:"e"`
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
|
if err := json.Unmarshal(keyData, &parsed); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if parsedKey.Kty != "RSA" {
|
if parsed.Kty != "RSA" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
|
pubKey, err := parseRSAPublicKey(parsed.N, parsed.E)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
publicKeys[parsedKey.Kid] = pubKey
|
publicKeys[parsed.Kid] = pubKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.Info("Loaded %d public keys.", len(publicKeys))
|
logger.Info("Loaded %d public keys.", len(publicKeys))
|
||||||
|
@ -73,25 +79,130 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||||
return &rsa.PublicKey{N: n, E: e}, nil
|
return &rsa.PublicKey{N: n, E: e}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
|
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||||
func ValidateJWT(authHeader string) error {
|
func ValidateJWT(
|
||||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
isLatestSpec bool,
|
||||||
return errors.New("missing or invalid Authorization header")
|
accessToken string,
|
||||||
}
|
audience string,
|
||||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
) error {
|
||||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
logger.Warn("isLatestSpec: %s", isLatestSpec)
|
||||||
kid, _ := token.Header["kid"].(string)
|
// Parse & verify the signature
|
||||||
pubKey, ok := publicKeys[kid]
|
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||||
if !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
return nil, errors.New("unknown or missing kid in token header")
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
return pubKey, nil
|
kid, ok := token.Header["kid"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("kid header not found")
|
||||||
|
}
|
||||||
|
key, ok := publicKeys[kid]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("invalid token: " + err.Error())
|
logger.Warn("Error detected, returning early")
|
||||||
|
return fmt.Errorf("invalid token: %w", err)
|
||||||
}
|
}
|
||||||
if !token.Valid {
|
if !token.Valid {
|
||||||
return errors.New("invalid token: token not valid")
|
logger.Warn("Token invalid, returning early")
|
||||||
|
return errors.New("token not valid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("unexpected claim type")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isLatestSpec {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
audRaw, exists := claimsMap["aud"]
|
||||||
|
if !exists {
|
||||||
|
return errors.New("aud claim missing")
|
||||||
|
}
|
||||||
|
switch v := audRaw.(type) {
|
||||||
|
case string:
|
||||||
|
if v != audience {
|
||||||
|
return fmt.Errorf("aud %q does not match %q", v, audience)
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
var found bool
|
||||||
|
for _, a := range v {
|
||||||
|
if s, ok := a.(string); ok && s == audience {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("audience %v does not include %q", v, audience)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors.New("aud claim has unexpected type")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parses the JWT token and returns the claims
|
||||||
|
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
||||||
|
if tokenStr == "" {
|
||||||
|
return nil, fmt.Errorf("empty JWT")
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims jwt.MapClaims
|
||||||
|
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the required scopes
|
||||||
|
func GetRequiredScopes(cfg *config.Config, method string) []string {
|
||||||
|
switch raw := cfg.ScopesSupported.(type) {
|
||||||
|
case map[string]string:
|
||||||
|
if scope, ok := raw[method]; ok {
|
||||||
|
return []string{scope}
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(method, "/", 2)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
if scope, ok := raw[parts[0]]; ok {
|
||||||
|
return []string{scope}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case []interface{}:
|
||||||
|
out := make([]string, 0, len(raw))
|
||||||
|
for _, v := range raw {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
out = append(out, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
|
||||||
|
case []string:
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts the Bearer token from the Authorization header
|
||||||
|
func ExtractAccessToken(authHeader string) (string, error) {
|
||||||
|
if authHeader == "" {
|
||||||
|
return "", errors.New("empty authorization header")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||||
|
if tokenStr == "" {
|
||||||
|
return "", errors.New("empty bearer token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenStr, nil
|
||||||
|
}
|
||||||
|
|
38
internal/util/rpc.go
Normal file
38
internal/util/rpc.go
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RPCEnvelope struct {
|
||||||
|
Method string `json:"method"`
|
||||||
|
Params any `json:"params"`
|
||||||
|
ID any `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function parses a JSON-RPC request from an HTTP request body
|
||||||
|
func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
|
||||||
|
bodyBytes, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
if len(bodyBytes) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var env RPCEnvelope
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(bodyBytes))
|
||||||
|
if err := dec.Decode(&env); err != nil && err != io.EOF {
|
||||||
|
logger.Warn("Error parsing JSON-RPC envelope: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &env, nil
|
||||||
|
}
|
26
internal/util/version.go
Normal file
26
internal/util/version.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This function checks if the given version date is after the spec cutover date
|
||||||
|
func IsLatestSpec(versionDate time.Time, err error) bool {
|
||||||
|
return err == nil && versionDate.After(constants.SpecCutoverDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function parses a version string into a time.Time
|
||||||
|
func ParseVersionDate(version string) (time.Time, error) {
|
||||||
|
return time.Parse("2006-01-02", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function returns the version string, using the cutover date if empty
|
||||||
|
func GetVersionWithDefault(version string) string {
|
||||||
|
if version == "" {
|
||||||
|
defaultTime, _ := time.Parse(constants.TimeLayout, "2025-05-15")
|
||||||
|
return defaultTime.Format(constants.TimeLayout)
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue