mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Update scope validation implementation
This commit is contained in:
parent
5c22f36ddc
commit
64caaa0f7c
7 changed files with 202 additions and 138 deletions
57
README.md
57
README.md
|
@ -32,9 +32,7 @@ A lightweight authorization proxy for Model Context Protocol (MCP) servers that
|
|||
| Version | Behavior |
|
||||
| :-------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| 2025-03-26 | Only signature check of Bearer JWT on both `/sse` and `/message`<br> No scope or audience enforcement |
|
||||
| Latest(draft) | Read `MCP-Protocol-Version` from client header<br> SSE handshake returns `WWW-Authenticate: Bearer resource_metadata="…"`<br> `/message` enforces:<br> 1. `aud` claim == `ResourceIdentifier`<br> 2. `scope` claim contains per-path `requiredScope`<br> 3. PolicyEngine decision<br> Rich `WWW-Authenticate` on 401s<br> Serves `/.well-known/oauth-protected-resource` JSON |
|
||||
|
||||
> ⚠️ **Note:** MCP v2 support is available **only in SSE mode**. The stdio mode supports only v1.
|
||||
| Latest(draft) | Read `MCP-Protocol-Version` from client header<br> SSE handshake returns `WWW-Authenticate: Bearer resource_metadata="…"`<br> `/message` enforces:<br>`aud` claim == `ResourceIdentifier`<br>`scope` claim contains `requiredScope`<br>Scope based access control<br>Rich `WWW-Authenticate` on 401s<br>Serves `/.well-known/oauth-protected-resource` JSON |
|
||||
|
||||
## 🛠️ Quick Start
|
||||
|
||||
|
@ -98,26 +96,17 @@ To enable authorization through your Asgardeo organization:
|
|||
3. Update `config.yaml` with the following parameters.
|
||||
|
||||
```yaml
|
||||
base_url: "http://localhost:8000" # URL of your MCP server
|
||||
listen_port: 8080 # Address where the proxy will listen
|
||||
base_url: "http://localhost:8000" # URL of your MCP server
|
||||
listen_port: 8080 # Address where the proxy will listen
|
||||
|
||||
asgardeo:
|
||||
org_name: "<org_name>" # Your Asgardeo org name
|
||||
client_id: "<client_id>" # Client ID of the M2M app
|
||||
client_secret: "<client_secret>" # Client secret of the M2M app
|
||||
|
||||
resource_identifier: "http://localhost:8080"
|
||||
scopes_supported:
|
||||
- "read:tools"
|
||||
- "read:resources"
|
||||
audience: "<audience_value>"
|
||||
authorization_servers:
|
||||
- "https://api.asgardeo.io/t/acme"
|
||||
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks"
|
||||
bearer_methods_supported:
|
||||
- header
|
||||
- body
|
||||
- query
|
||||
resource_identifier: "http://localhost:8080" # Proxy server URL
|
||||
scopes_supported: # Scopes required to access the MCP server
|
||||
- "read:tools"
|
||||
- "read:resources"
|
||||
audience: "<audience_value>" # Access token audience
|
||||
authorization_servers: # Authorization server URL
|
||||
- "https://api.asgardeo.io/t/acme"
|
||||
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" # JWKS URL of the Authorization server
|
||||
```
|
||||
|
||||
4. Start the proxy with Asgardeo integration:
|
||||
|
@ -240,22 +229,14 @@ demo:
|
|||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||
|
||||
# Asgardeo configuration (used with --asgardeo flag)
|
||||
asgardeo:
|
||||
org_name: "<org_name>"
|
||||
client_id: "<client_id>"
|
||||
client_secret: "<client_secret>"
|
||||
resource_identifier: "http://localhost:8080"
|
||||
scopes_supported:
|
||||
- "read:tools"
|
||||
- "read:resources"
|
||||
audience: "<audience_value>"
|
||||
authorization_servers:
|
||||
- "https://api.asgardeo.io/t/acme"
|
||||
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks"
|
||||
bearer_methods_supported:
|
||||
- header
|
||||
- body
|
||||
- query
|
||||
resource_identifier: "http://localhost:8080"
|
||||
scopes_supported:
|
||||
- "read:tools"
|
||||
- "read:resources"
|
||||
audience: "<audience_value>"
|
||||
authorization_servers:
|
||||
- "https://api.asgardeo.io/t/acme"
|
||||
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks"
|
||||
```
|
||||
|
||||
### 🖥️ Build from source
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
||||
|
@ -68,23 +67,7 @@ func main() {
|
|||
}
|
||||
|
||||
// 3. Create the chosen provider
|
||||
var provider authz.Provider
|
||||
if *demoMode {
|
||||
cfg.Mode = "demo"
|
||||
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
|
||||
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
|
||||
provider = authz.NewAsgardeoProvider(cfg)
|
||||
} else if *asgardeoMode {
|
||||
cfg.Mode = "asgardeo"
|
||||
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
|
||||
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
|
||||
provider = authz.NewAsgardeoProvider(cfg)
|
||||
} else {
|
||||
cfg.Mode = "default"
|
||||
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||
provider = authz.NewDefaultProvider(cfg)
|
||||
}
|
||||
var provider authz.Provider = MakeProvider(cfg, *demoMode, *asgardeoMode)
|
||||
|
||||
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||
|
|
45
cmd/proxy/provider.go
Normal file
45
cmd/proxy/provider.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
)
|
||||
|
||||
func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
|
||||
var mode, orgName string
|
||||
switch {
|
||||
case demoMode:
|
||||
mode = "demo"
|
||||
orgName = cfg.Demo.OrgName
|
||||
case asgardeoMode:
|
||||
mode = "asgardeo"
|
||||
orgName = cfg.Asgardeo.OrgName
|
||||
default:
|
||||
mode = "default"
|
||||
}
|
||||
cfg.Mode = mode
|
||||
|
||||
switch mode {
|
||||
case "demo", "asgardeo":
|
||||
if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" {
|
||||
base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
|
||||
cfg.AuthServerBaseURL = base
|
||||
cfg.JWKSURL = base + "/jwks"
|
||||
} else {
|
||||
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
|
||||
cfg.JWKSURL = cfg.JwksURI
|
||||
}
|
||||
return authz.NewAsgardeoProvider(cfg)
|
||||
|
||||
default:
|
||||
if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
|
||||
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||
} else if len(cfg.AuthorizationServers) > 0 {
|
||||
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
|
||||
cfg.JWKSURL = cfg.JwksURI
|
||||
}
|
||||
return authz.NewDefaultProvider(cfg)
|
||||
}
|
||||
}
|
|
@ -1,6 +1,11 @@
|
|||
package authz
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
)
|
||||
|
||||
type Decision int
|
||||
|
||||
|
@ -15,5 +20,5 @@ type AccessControlResult struct {
|
|||
}
|
||||
|
||||
type AccessControl interface {
|
||||
ValidateAccess(r *http.Request, claims *TokenClaims, requiredScopes any) AccessControlResult
|
||||
ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult
|
||||
}
|
||||
|
|
|
@ -4,54 +4,68 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
Scopes []string
|
||||
}
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
type ScopeValidator struct{}
|
||||
|
||||
// Evaluate and checks the token claims against one or more required scopes.
|
||||
func (d *ScopeValidator) ValidateAccess(
|
||||
_ *http.Request,
|
||||
claims *TokenClaims,
|
||||
requiredScopes any,
|
||||
r *http.Request,
|
||||
claims *jwt.MapClaims,
|
||||
config *config.Config,
|
||||
) AccessControlResult {
|
||||
var scopeStr string
|
||||
switch v := requiredScopes.(type) {
|
||||
case string:
|
||||
scopeStr = v
|
||||
case []string:
|
||||
scopeStr = strings.Join(v, " ")
|
||||
env, err := util.ParseRPCRequest(r)
|
||||
if err != nil {
|
||||
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
|
||||
}
|
||||
requiredScopes := util.GetRequiredScopes(config, env.Method)
|
||||
if len(requiredScopes) == 0 {
|
||||
return AccessControlResult{DecisionAllow, ""}
|
||||
}
|
||||
|
||||
required := make(map[string]struct{}, len(requiredScopes))
|
||||
for _, s := range requiredScopes {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
required[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var tokenScopes []string
|
||||
if claims, ok := (*claims)["scope"]; ok {
|
||||
switch v := claims.(type) {
|
||||
case string:
|
||||
tokenScopes = strings.Fields(v)
|
||||
case []interface{}:
|
||||
for _, x := range v {
|
||||
if s, ok := x.(string); ok && s != "" {
|
||||
tokenScopes = append(tokenScopes, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
|
||||
for _, s := range tokenScopes {
|
||||
tokenScopeSet[s] = struct{}{}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(scopeStr) == "" {
|
||||
return AccessControlResult{DecisionAllow, ""}
|
||||
}
|
||||
|
||||
scopes := strings.FieldsFunc(scopeStr, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
required := make(map[string]struct{}, len(scopes))
|
||||
for _, s := range scopes {
|
||||
if s = strings.TrimSpace(s); s != "" {
|
||||
required[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tokenScope := range claims.Scopes {
|
||||
if _, ok := required[tokenScope]; ok {
|
||||
return AccessControlResult{DecisionAllow, ""}
|
||||
}
|
||||
}
|
||||
|
||||
var list []string
|
||||
var missing []string
|
||||
for s := range required {
|
||||
list = append(list, s)
|
||||
if _, ok := tokenScopeSet[s]; !ok {
|
||||
missing = append(missing, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) == 0 {
|
||||
return AccessControlResult{DecisionAllow, ""}
|
||||
}
|
||||
return AccessControlResult{
|
||||
DecisionDeny,
|
||||
fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")),
|
||||
fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -302,6 +302,7 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res
|
|||
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error {
|
||||
authzHeader := r.Header.Get("Authorization")
|
||||
accessToken, _ := util.ExtractAccessToken(authzHeader)
|
||||
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
||||
if isLatestSpec {
|
||||
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||
|
@ -314,7 +315,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
|
|||
return fmt.Errorf("missing or invalid Authorization header")
|
||||
}
|
||||
|
||||
claims, err := util.ValidateJWT(isLatestSpec, authzHeader, cfg.Audience)
|
||||
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience)
|
||||
if err != nil {
|
||||
if isLatestSpec {
|
||||
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
||||
|
@ -331,16 +332,19 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
|
|||
}
|
||||
|
||||
if isLatestSpec {
|
||||
env, err := util.ParseRPCRequest(r)
|
||||
_, err := util.ParseRPCRequest(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
requiredScopes := util.GetRequiredScopes(cfg, env.Method)
|
||||
if len(requiredScopes) == 0 {
|
||||
return nil
|
||||
|
||||
claimsMap, err := util.ParseJWT(accessToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||
return fmt.Errorf("invalid token claims")
|
||||
}
|
||||
pr := accessController.ValidateAccess(r, claims, requiredScopes)
|
||||
|
||||
pr := accessController.ValidateAccess(r, &claimsMap, cfg)
|
||||
if pr.Decision == authz.DecisionDeny {
|
||||
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
||||
return fmt.Errorf("forbidden — %s", pr.Message)
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
@ -83,15 +82,12 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
|||
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||
func ValidateJWT(
|
||||
isLatestSpec bool,
|
||||
authHeader, audience string,
|
||||
) (*authz.TokenClaims, error) {
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenStr == "" {
|
||||
return nil, errors.New("empty bearer token")
|
||||
}
|
||||
|
||||
accessToken string,
|
||||
audience string,
|
||||
) error {
|
||||
logger.Warn("isLatestSpec: %s", isLatestSpec)
|
||||
// Parse & verify the signature
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
@ -106,29 +102,31 @@ func ValidateJWT(
|
|||
return key, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid token: %w", err)
|
||||
logger.Warn("Error detected, returning early")
|
||||
return fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("token not valid")
|
||||
logger.Warn("Token invalid, returning early")
|
||||
return errors.New("token not valid")
|
||||
}
|
||||
|
||||
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected claim type")
|
||||
return errors.New("unexpected claim type")
|
||||
}
|
||||
|
||||
if !isLatestSpec {
|
||||
return &authz.TokenClaims{Scopes: nil}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
audRaw, exists := claimsMap["aud"]
|
||||
if !exists {
|
||||
return nil, errors.New("aud claim missing")
|
||||
return errors.New("aud claim missing")
|
||||
}
|
||||
switch v := audRaw.(type) {
|
||||
case string:
|
||||
if v != audience {
|
||||
return nil, fmt.Errorf("aud %q does not match %q", v, audience)
|
||||
return fmt.Errorf("aud %q does not match %q", v, audience)
|
||||
}
|
||||
case []interface{}:
|
||||
var found bool
|
||||
|
@ -139,38 +137,72 @@ func ValidateJWT(
|
|||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("audience %v does not include %q", v, audience)
|
||||
return fmt.Errorf("audience %v does not include %q", v, audience)
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("aud claim has unexpected type")
|
||||
return errors.New("aud claim has unexpected type")
|
||||
}
|
||||
|
||||
rawScope := claimsMap["scope"]
|
||||
scopeList := []string{}
|
||||
if s, ok := rawScope.(string); ok {
|
||||
scopeList = strings.Fields(s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return &authz.TokenClaims{Scopes: scopeList}, nil
|
||||
// Parses the JWT token and returns the claims
|
||||
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
||||
if tokenStr == "" {
|
||||
return nil, fmt.Errorf("empty JWT")
|
||||
}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Process the required scopes
|
||||
func GetRequiredScopes(cfg *config.Config, method string) []string {
|
||||
if scopes, ok := cfg.ScopesSupported.(map[string]string); ok && len(scopes) > 0 {
|
||||
if scope, ok := scopes[method]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
if parts := strings.SplitN(method, "/", 2); len(parts) > 0 {
|
||||
if scope, ok := scopes[parts[0]]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
switch raw := cfg.ScopesSupported.(type) {
|
||||
case map[string]string:
|
||||
if scope, ok := raw[method]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
parts := strings.SplitN(method, "/", 2)
|
||||
if len(parts) > 0 {
|
||||
if scope, ok := raw[parts[0]]; ok {
|
||||
return []string{scope}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
|
||||
if scopes, ok := cfg.ScopesSupported.([]string); ok && len(scopes) > 0 {
|
||||
return scopes
|
||||
}
|
||||
case []string:
|
||||
return raw
|
||||
}
|
||||
|
||||
return []string{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extracts the Bearer token from the Authorization header
|
||||
func ExtractAccessToken(authHeader string) (string, error) {
|
||||
if authHeader == "" {
|
||||
return "", errors.New("empty authorization header")
|
||||
}
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||
if tokenStr == "" {
|
||||
return "", errors.New("empty bearer token")
|
||||
}
|
||||
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue