Misc improvements

This commit is contained in:
Pavindu Lakshan 2025-08-11 16:39:35 +05:30
parent b30aa6273c
commit 8589035d64
8 changed files with 222 additions and 156 deletions

View file

@ -7,39 +7,39 @@ import (
) )
func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider { func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
var mode, orgName string var mode, orgName string
switch { switch {
case demoMode: case demoMode:
mode = "demo" mode = "demo"
orgName = cfg.Demo.OrgName orgName = cfg.Demo.OrgName
case asgardeoMode: case asgardeoMode:
mode = "asgardeo" mode = "asgardeo"
orgName = cfg.Asgardeo.OrgName orgName = cfg.Asgardeo.OrgName
default: default:
mode = "default" mode = "default"
} }
cfg.Mode = mode cfg.Mode = mode
switch mode { switch mode {
case "demo", "asgardeo": case "demo", "asgardeo":
if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" { if len(cfg.ProtectedResourceMetadata.AuthorizationServers) == 0 && cfg.ProtectedResourceMetadata.JwksURI == "" {
base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2" base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
cfg.AuthServerBaseURL = base cfg.AuthServerBaseURL = base
cfg.JWKSURL = base + "/jwks" cfg.JWKSURL = base + "/jwks"
} else { } else {
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0] cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0]
cfg.JWKSURL = cfg.JwksURI cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI
} }
return authz.NewAsgardeoProvider(cfg) return authz.NewAsgardeoProvider(cfg)
default: default:
if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" { if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
cfg.AuthServerBaseURL = cfg.Default.BaseURL cfg.AuthServerBaseURL = cfg.Default.BaseURL
cfg.JWKSURL = cfg.Default.JWKSURL cfg.JWKSURL = cfg.Default.JWKSURL
} else if len(cfg.AuthorizationServers) > 0 { } else if len(cfg.ProtectedResourceMetadata.AuthorizationServers) > 0 {
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0] cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0]
cfg.JWKSURL = cfg.JwksURI cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI
} }
return authz.NewDefaultProvider(cfg) return authz.NewDefaultProvider(cfg)
} }
} }

View file

@ -1,9 +1,10 @@
# config.yaml # config.yaml
# Common configuration for all transport modes # Common configuration for all transport modes
proxy_base_url: http://localhost:8080
listen_port: 8080 listen_port: 8080
base_url: "http://localhost:3001" # Base URL for the MCP server base_url: "http://localhost:8000" # Base URL for the MCP server
port: 3001 # Port for the MCP server port: 8000 # Port for the MCP server
timeout_seconds: 10 timeout_seconds: 10
# Path configuration # Path configuration
@ -17,7 +18,7 @@ transport_mode: "sse" # Options: "sse" or "stdio"
# stdio-specific configuration (used only when transport_mode is "stdio") # stdio-specific configuration (used only when transport_mode is "stdio")
stdio: stdio:
enabled: true enabled: false
user_command: "npx -y @modelcontextprotocol/server-github" user_command: "npx -y @modelcontextprotocol/server-github"
work_dir: "" # Working directory (optional) work_dir: "" # Working directory (optional)
# env: # Environment variables (optional) # env: # Environment variables (optional)
@ -30,6 +31,7 @@ path_mapping:
cors: cors:
allowed_origins: allowed_origins:
- "http://127.0.0.1:6274" - "http://127.0.0.1:6274"
- "http://localhost:6274"
allowed_methods: allowed_methods:
- "GET" - "GET"
- "POST" - "POST"
@ -47,17 +49,17 @@ demo:
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
# Protected resource metadata protected_resource_metadata:
resource_identifier: http://localhost:3000 resource_identifier: http://localhost:8080/sse
audience: mcp_proxy audience: 2xGW_poFYoObUE_vUQxvGdPSUPwa
scopes_supported: scopes_supported:
- "tools":"read:tools" - initialize: "mcp_init"
- "resources":"read:resources" - tools/call:
- "prompts":"read:prompts" - echo_tool: "mcp_echo_tool"
authorization_servers: authorization_servers:
- https://api.asgardeo.io/t/acme/ - https://api.asgardeo.io/t/openmcpauthdemo/oauth2/token
jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks jwks_uri: https://api.asgardeo.io/t/openmcpauthdemo/oauth2/jwks
bearer_methods_supported: bearer_methods_supported:
- header - header
- body - body
- query - query

View file

@ -194,7 +194,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(resp.Body) respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) return fmt.Errorf("asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
} }
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
@ -367,16 +367,41 @@ func randomString(n int) string {
func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
// Extract only the values into a []string
var supportedScopes []string
var extractStrings func(interface{})
extractStrings = func(val interface{}) {
switch v := val.(type) {
case string:
supportedScopes = append(supportedScopes, v)
case []any:
for _, item := range v {
extractStrings(item)
}
case map[string]any:
for _, item := range v {
extractStrings(item)
}
}
}
for _, m := range p.cfg.ProtectedResourceMetadata.ScopesSupported {
for _, v := range m {
extractStrings(v)
}
}
meta := map[string]interface{}{ meta := map[string]interface{}{
"resource": p.cfg.ResourceIdentifier, "resource": p.cfg.ProtectedResourceMetadata.ResourceIdentifier,
"scopes_supported": p.cfg.ScopesSupported, "scopes_supported": supportedScopes,
"authorization_servers": p.cfg.AuthorizationServers, "authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers,
} }
if p.cfg.JwksURI != "" {
meta["jwks_uri"] = p.cfg.JwksURI if p.cfg.ProtectedResourceMetadata.JwksURI != "" {
meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI
} }
if len(p.cfg.BearerMethodsSupported) > 0 { if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 {
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported
} }
if err := json.NewEncoder(w).Encode(meta); err != nil { if err := json.NewEncoder(w).Encode(meta); err != nil {
http.Error(w, "failed to encode metadata", http.StatusInternalServerError) http.Error(w, "failed to encode metadata", http.StatusInternalServerError)

View file

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type defaultProvider struct { type defaultProvider struct {
@ -99,18 +99,17 @@ func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
meta := map[string]interface{}{ meta := map[string]interface{}{
"audience": p.cfg.Audience, "audience": p.cfg.ProtectedResourceMetadata.Audience,
"resource": p.cfg.ResourceIdentifier, "scopes_supported": p.cfg.ProtectedResourceMetadata.ScopesSupported,
"scopes_supported": p.cfg.ScopesSupported, "authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers,
"authorization_servers": p.cfg.AuthorizationServers,
} }
if p.cfg.JwksURI != "" { if p.cfg.ProtectedResourceMetadata.JwksURI != "" {
meta["jwks_uri"] = p.cfg.JwksURI meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI
} }
if len(p.cfg.BearerMethodsSupported) > 0 { if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 {
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported
} }
if err := json.NewEncoder(w).Encode(meta); err != nil { if err := json.NewEncoder(w).Encode(meta); err != nil {

View file

@ -18,36 +18,37 @@ func (d *ScopeValidator) ValidateAccess(
claims *jwt.MapClaims, claims *jwt.MapClaims,
config *config.Config, config *config.Config,
) AccessControlResult { ) AccessControlResult {
env, err := util.ParseRPCRequest(r) env, err := util.ParseRPCRequest(r)
if err != nil { if err != nil {
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"} return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
} }
requiredScopes := util.GetRequiredScopes(config, env.Method) requiredScopes := util.GetRequiredScopes(config, env)
if len(requiredScopes) == 0 {
return AccessControlResult{DecisionAllow, ""}
}
required := make(map[string]struct{}, len(requiredScopes)) if len(requiredScopes) == 0 {
for _, s := range requiredScopes { return AccessControlResult{DecisionAllow, ""}
s = strings.TrimSpace(s) }
if s != "" {
required[s] = struct{}{}
}
}
var tokenScopes []string required := make(map[string]struct{}, len(requiredScopes))
if claims, ok := (*claims)["scope"]; ok { for _, s := range requiredScopes {
switch v := claims.(type) { s = strings.TrimSpace(s)
case string: if s != "" {
tokenScopes = strings.Fields(v) required[s] = struct{}{}
case []interface{}: }
for _, x := range v { }
if s, ok := x.(string); ok && s != "" {
tokenScopes = append(tokenScopes, s) var tokenScopes []string
} if claims, ok := (*claims)["scope"]; ok {
} switch v := claims.(type) {
} case string:
} tokenScopes = strings.Fields(v)
case []interface{}:
for _, x := range v {
if s, ok := x.(string); ok && s != "" {
tokenScopes = append(tokenScopes, s)
}
}
}
}
tokenScopeSet := make(map[string]struct{}, len(tokenScopes)) tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
for _, s := range tokenScopes { for _, s := range tokenScopes {

View file

@ -13,8 +13,9 @@ import (
type TransportMode string type TransportMode string
const ( const (
SSETransport TransportMode = "sse" SSETransport TransportMode = "sse"
StdioTransport TransportMode = "stdio" StdioTransport TransportMode = "stdio"
StreamableHTTPTransport TransportMode = "streamable_http"
) )
// Common path configuration for all transport modes // Common path configuration for all transport modes
@ -68,6 +69,15 @@ type ResponseConfig struct {
CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
} }
type ProtectedResourceMetadata struct {
ResourceIdentifier string `yaml:"resource_identifier"`
Audience string `yaml:"audience"`
ScopesSupported []map[string]interface{} `yaml:"scopes_supported"`
AuthorizationServers []string `yaml:"authorization_servers"`
JwksURI string `yaml:"jwks_uri,omitempty"`
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
}
type PathConfig struct { type PathConfig struct {
// For well-known endpoint // For well-known endpoint
Response *ResponseConfig `yaml:"response,omitempty"` Response *ResponseConfig `yaml:"response,omitempty"`
@ -86,6 +96,7 @@ type DefaultConfig struct {
} }
type Config struct { type Config struct {
ProxyBaseURL string `yaml:"proxy_base_url"`
AuthServerBaseURL string AuthServerBaseURL string
ListenPort int `yaml:"listen_port"` ListenPort int `yaml:"listen_port"`
BaseURL string `yaml:"base_url"` BaseURL string `yaml:"base_url"`
@ -98,7 +109,6 @@ type Config struct {
TransportMode TransportMode `yaml:"transport_mode"` TransportMode TransportMode `yaml:"transport_mode"`
Paths PathsConfig `yaml:"paths"` Paths PathsConfig `yaml:"paths"`
Stdio StdioConfig `yaml:"stdio"` Stdio StdioConfig `yaml:"stdio"`
RequiredScopes map[string]string `yaml:"required_scopes"`
// Nested config for Asgardeo // Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"` Demo DemoConfig `yaml:"demo"`
@ -106,12 +116,7 @@ type Config struct {
Default DefaultConfig `yaml:"default"` Default DefaultConfig `yaml:"default"`
// Protected resource metadata // Protected resource metadata
Audience string `yaml:"audience"` ProtectedResourceMetadata ProtectedResourceMetadata `yaml:"protected_resource_metadata"`
ResourceIdentifier string `yaml:"resource_identifier"`
ScopesSupported any `yaml:"scopes_supported"`
AuthorizationServers []string `yaml:"authorization_servers"`
JwksURI string `yaml:"jwks_uri,omitempty"`
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
} }
// Validate checks if the config is valid based on transport mode // Validate checks if the config is valid based on transport mode

View file

@ -11,7 +11,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util" "github.com/wso2/open-mcp-auth-proxy/internal/util"
) )
@ -64,7 +64,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, accessController aut
} }
} }
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(getProtectedResourceMetadataEndpointPath(cfg), func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
allowed := getAllowedOrigin(origin, cfg) allowed := getAllowedOrigin(origin, cfg)
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
@ -76,7 +76,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, accessController aut
addCORSHeaders(w, cfg, allowed, "") addCORSHeaders(w, cfg, allowed, "")
provider.ProtectedResourceMetadataHandler()(w, r) provider.ProtectedResourceMetadataHandler()(w, r)
}) })
registeredPaths["/.well-known/oauth-protected-resource"] = true registeredPaths[getProtectedResourceMetadataEndpointPath(cfg)] = true
// Remove duplicates from defaultPaths // Remove duplicates from defaultPaths
uniquePaths := make(map[string]bool) uniquePaths := make(map[string]bool)
@ -165,11 +165,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
var targetURL *url.URL var targetURL *url.URL
isSSE := false isSSE := false
if isAuthPath(r.URL.Path) { if isAuthPath(r.URL.Path, cfg) {
targetURL = authBase targetURL = authBase
} else if isMCPPath(r.URL.Path, cfg) { } else if isMCPPath(r.URL.Path, cfg) {
if ssePaths[r.URL.Path] { if ssePaths[r.URL.Path] {
if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil { if err := authorizeSSE(w, r, isLatestSpec, cfg); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
return return
} }
@ -245,7 +245,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
"WWW-Authenticate", "WWW-Authenticate",
fmt.Sprintf( fmt.Sprintf(
`Bearer resource_metadata="%s"`, `Bearer resource_metadata="%s"`,
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource", cfg.ProxyBaseURL+getProtectedResourceMetadataEndpointPath(cfg),
)) ))
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate") resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
} }
@ -285,11 +285,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
} }
// Check if the request is for SSE handshake and authorize it // Check if the request is for SSE handshake and authorize it
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error { func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) error {
authHeader := r.Header.Get("Authorization") authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
if isLatestSpec { if isLatestSpec {
realm := resourceID + "/.well-known/oauth-protected-resource" realm := cfg.BaseURL + getProtectedResourceMetadataEndpointPath(cfg)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm)) w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
} }
@ -305,7 +305,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
accessToken, _ := util.ExtractAccessToken(authzHeader) accessToken, _ := util.ExtractAccessToken(authzHeader)
if !strings.HasPrefix(authzHeader, "Bearer ") { if !strings.HasPrefix(authzHeader, "Bearer ") {
if isLatestSpec { if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg)
w.Header().Set("WWW-Authenticate", fmt.Sprintf( w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Bearer resource_metadata=%q`, realm, `Bearer resource_metadata=%q`, realm,
)) ))
@ -315,10 +315,10 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
return fmt.Errorf("missing or invalid Authorization header") return fmt.Errorf("missing or invalid Authorization header")
} }
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience) err := util.ValidateJWT(isLatestSpec, accessToken, cfg.ProtectedResourceMetadata.Audience)
if err != nil { if err != nil {
if isLatestSpec { if isLatestSpec {
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource" realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg)
w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(), w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
`Bearer realm=%q`, `Bearer realm=%q`,
realm, realm,
@ -343,7 +343,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
http.Error(w, "Invalid token claims", http.StatusUnauthorized) http.Error(w, "Invalid token claims", http.StatusUnauthorized)
return fmt.Errorf("invalid token claims") return fmt.Errorf("invalid token claims")
} }
pr := accessController.ValidateAccess(r, &claimsMap, cfg) pr := accessController.ValidateAccess(r, &claimsMap, cfg)
if pr.Decision == authz.DecisionDeny { if pr.Decision == authz.DecisionDeny {
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
@ -385,13 +385,13 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("X-Accel-Buffering", "no")
} }
func isAuthPath(path string) bool { func isAuthPath(path string, cfg *config.Config) bool {
authPaths := map[string]bool{ authPaths := map[string]bool{
"/authorize": true, "/authorize": true,
"/token": true, "/token": true,
"/register": true, "/register": true,
"/.well-known/oauth-authorization-server": true, "/.well-known/oauth-authorization-server": true,
"/.well-known/oauth-protected-resource": true, getProtectedResourceMetadataEndpointPath(cfg): true,
} }
if strings.HasPrefix(path, "/u/") { if strings.HasPrefix(path, "/u/") {
return true return true
@ -417,3 +417,17 @@ func skipHeader(h string) bool {
} }
return false return false
} }
func getProtectedResourceMetadataEndpointPath(cfg *config.Config) string {
protectedResourceMetadataPath := "/.well-known/oauth-protected-resource"
switch cfg.TransportMode {
case config.SSETransport:
protectedResourceMetadataPath += cfg.Paths.SSE
case config.StreamableHTTPTransport:
protectedResourceMetadataPath += cfg.Paths.StreamableHTTP
}
return protectedResourceMetadataPath
}

View file

@ -11,7 +11,7 @@ import (
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging" logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
) )
type TokenClaims struct { type TokenClaims struct {
@ -82,7 +82,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. // ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
func ValidateJWT( func ValidateJWT(
isLatestSpec bool, isLatestSpec bool,
accessToken string, accessToken string,
audience string, audience string,
) error { ) error {
logger.Warn("isLatestSpec: %s", isLatestSpec) logger.Warn("isLatestSpec: %s", isLatestSpec)
@ -148,46 +148,66 @@ func ValidateJWT(
// Parses the JWT token and returns the claims // Parses the JWT token and returns the claims
func ParseJWT(tokenStr string) (jwt.MapClaims, error) { func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
if tokenStr == "" { if tokenStr == "" {
return nil, fmt.Errorf("empty JWT") return nil, fmt.Errorf("empty JWT")
} }
var claims jwt.MapClaims var claims jwt.MapClaims
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse JWT: %w", err) return nil, fmt.Errorf("failed to parse JWT: %w", err)
} }
return claims, nil return claims, nil
} }
// Process the required scopes // Process the required scopes
func GetRequiredScopes(cfg *config.Config, method string) []string { func GetRequiredScopes(cfg *config.Config, requestBody *RPCEnvelope) []string {
switch raw := cfg.ScopesSupported.(type) {
case map[string]string:
if scope, ok := raw[method]; ok {
return []string{scope}
}
parts := strings.SplitN(method, "/", 2)
if len(parts) > 0 {
if scope, ok := raw[parts[0]]; ok {
return []string{scope}
}
}
return nil
case []interface{}:
out := make([]string, 0, len(raw))
for _, v := range raw {
if s, ok := v.(string); ok && s != "" {
out = append(out, s)
}
}
return out
case []string: var scopeObj interface{}
return raw found := false
} for _, m := range cfg.ProtectedResourceMetadata.ScopesSupported {
if val, ok := m[requestBody.Method]; ok {
scopeObj = val
found = true
break
}
}
if !found {
return nil
}
return nil switch v := scopeObj.(type) {
case string:
return []string{v}
case []any:
if requestBody.Params != nil {
if paramsMap, ok := requestBody.Params.(map[string]any); ok {
name, ok := paramsMap["name"].(string)
if ok {
for _, item := range v {
if scopeMap, ok := item.(map[interface{}]interface{}); ok {
if scopeVal, exists := scopeMap[name]; exists {
if scopeStr, ok := scopeVal.(string); ok {
return []string{scopeStr}
}
if scopeArr, ok := scopeVal.([]any); ok {
var scopes []string
for _, s := range scopeArr {
if str, ok := s.(string); ok {
scopes = append(scopes, str)
}
}
return scopes
}
}
}
}
}
}
}
}
return nil
} }
// Extracts the Bearer token from the Authorization header // Extracts the Bearer token from the Authorization header