mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-08-17 20:03:08 +00:00
Misc improvements
This commit is contained in:
parent
b30aa6273c
commit
8589035d64
8 changed files with 222 additions and 156 deletions
|
@ -7,39 +7,39 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
|
func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
|
||||||
var mode, orgName string
|
var mode, orgName string
|
||||||
switch {
|
switch {
|
||||||
case demoMode:
|
case demoMode:
|
||||||
mode = "demo"
|
mode = "demo"
|
||||||
orgName = cfg.Demo.OrgName
|
orgName = cfg.Demo.OrgName
|
||||||
case asgardeoMode:
|
case asgardeoMode:
|
||||||
mode = "asgardeo"
|
mode = "asgardeo"
|
||||||
orgName = cfg.Asgardeo.OrgName
|
orgName = cfg.Asgardeo.OrgName
|
||||||
default:
|
default:
|
||||||
mode = "default"
|
mode = "default"
|
||||||
}
|
}
|
||||||
cfg.Mode = mode
|
cfg.Mode = mode
|
||||||
|
|
||||||
switch mode {
|
switch mode {
|
||||||
case "demo", "asgardeo":
|
case "demo", "asgardeo":
|
||||||
if len(cfg.AuthorizationServers) == 0 && cfg.JwksURI == "" {
|
if len(cfg.ProtectedResourceMetadata.AuthorizationServers) == 0 && cfg.ProtectedResourceMetadata.JwksURI == "" {
|
||||||
base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
|
base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
|
||||||
cfg.AuthServerBaseURL = base
|
cfg.AuthServerBaseURL = base
|
||||||
cfg.JWKSURL = base + "/jwks"
|
cfg.JWKSURL = base + "/jwks"
|
||||||
} else {
|
} else {
|
||||||
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
|
cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0]
|
||||||
cfg.JWKSURL = cfg.JwksURI
|
cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI
|
||||||
}
|
}
|
||||||
return authz.NewAsgardeoProvider(cfg)
|
return authz.NewAsgardeoProvider(cfg)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
|
if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
|
||||||
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||||
cfg.JWKSURL = cfg.Default.JWKSURL
|
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||||
} else if len(cfg.AuthorizationServers) > 0 {
|
} else if len(cfg.ProtectedResourceMetadata.AuthorizationServers) > 0 {
|
||||||
cfg.AuthServerBaseURL = cfg.AuthorizationServers[0]
|
cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0]
|
||||||
cfg.JWKSURL = cfg.JwksURI
|
cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI
|
||||||
}
|
}
|
||||||
return authz.NewDefaultProvider(cfg)
|
return authz.NewDefaultProvider(cfg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
36
config.yaml
36
config.yaml
|
@ -1,9 +1,10 @@
|
||||||
# config.yaml
|
# config.yaml
|
||||||
|
|
||||||
# Common configuration for all transport modes
|
# Common configuration for all transport modes
|
||||||
|
proxy_base_url: http://localhost:8080
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
base_url: "http://localhost:3001" # Base URL for the MCP server
|
base_url: "http://localhost:8000" # Base URL for the MCP server
|
||||||
port: 3001 # Port for the MCP server
|
port: 8000 # Port for the MCP server
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
# Path configuration
|
# Path configuration
|
||||||
|
@ -17,7 +18,7 @@ transport_mode: "sse" # Options: "sse" or "stdio"
|
||||||
|
|
||||||
# stdio-specific configuration (used only when transport_mode is "stdio")
|
# stdio-specific configuration (used only when transport_mode is "stdio")
|
||||||
stdio:
|
stdio:
|
||||||
enabled: true
|
enabled: false
|
||||||
user_command: "npx -y @modelcontextprotocol/server-github"
|
user_command: "npx -y @modelcontextprotocol/server-github"
|
||||||
work_dir: "" # Working directory (optional)
|
work_dir: "" # Working directory (optional)
|
||||||
# env: # Environment variables (optional)
|
# env: # Environment variables (optional)
|
||||||
|
@ -30,6 +31,7 @@ path_mapping:
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
- "http://127.0.0.1:6274"
|
- "http://127.0.0.1:6274"
|
||||||
|
- "http://localhost:6274"
|
||||||
allowed_methods:
|
allowed_methods:
|
||||||
- "GET"
|
- "GET"
|
||||||
- "POST"
|
- "POST"
|
||||||
|
@ -47,17 +49,17 @@ demo:
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
# Protected resource metadata
|
protected_resource_metadata:
|
||||||
resource_identifier: http://localhost:3000
|
resource_identifier: http://localhost:8080/sse
|
||||||
audience: mcp_proxy
|
audience: 2xGW_poFYoObUE_vUQxvGdPSUPwa
|
||||||
scopes_supported:
|
scopes_supported:
|
||||||
- "tools":"read:tools"
|
- initialize: "mcp_init"
|
||||||
- "resources":"read:resources"
|
- tools/call:
|
||||||
- "prompts":"read:prompts"
|
- echo_tool: "mcp_echo_tool"
|
||||||
authorization_servers:
|
authorization_servers:
|
||||||
- https://api.asgardeo.io/t/acme/
|
- https://api.asgardeo.io/t/openmcpauthdemo/oauth2/token
|
||||||
jwks_uri: https://api.asgardeo.io/t/acme/oauth2/jwks
|
jwks_uri: https://api.asgardeo.io/t/openmcpauthdemo/oauth2/jwks
|
||||||
bearer_methods_supported:
|
bearer_methods_supported:
|
||||||
- header
|
- header
|
||||||
- body
|
- body
|
||||||
- query
|
- query
|
||||||
|
|
|
@ -194,7 +194,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
return fmt.Errorf("asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
|
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||||
|
@ -367,16 +367,41 @@ func randomString(n int) string {
|
||||||
func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Extract only the values into a []string
|
||||||
|
var supportedScopes []string
|
||||||
|
var extractStrings func(interface{})
|
||||||
|
extractStrings = func(val interface{}) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case string:
|
||||||
|
supportedScopes = append(supportedScopes, v)
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
extractStrings(item)
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
for _, item := range v {
|
||||||
|
extractStrings(item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range p.cfg.ProtectedResourceMetadata.ScopesSupported {
|
||||||
|
for _, v := range m {
|
||||||
|
extractStrings(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
meta := map[string]interface{}{
|
meta := map[string]interface{}{
|
||||||
"resource": p.cfg.ResourceIdentifier,
|
"resource": p.cfg.ProtectedResourceMetadata.ResourceIdentifier,
|
||||||
"scopes_supported": p.cfg.ScopesSupported,
|
"scopes_supported": supportedScopes,
|
||||||
"authorization_servers": p.cfg.AuthorizationServers,
|
"authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers,
|
||||||
}
|
}
|
||||||
if p.cfg.JwksURI != "" {
|
|
||||||
meta["jwks_uri"] = p.cfg.JwksURI
|
if p.cfg.ProtectedResourceMetadata.JwksURI != "" {
|
||||||
|
meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI
|
||||||
}
|
}
|
||||||
if len(p.cfg.BearerMethodsSupported) > 0 {
|
if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 {
|
||||||
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
|
meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported
|
||||||
}
|
}
|
||||||
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||||
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type defaultProvider struct {
|
type defaultProvider struct {
|
||||||
|
@ -99,18 +99,17 @@ func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
meta := map[string]interface{}{
|
meta := map[string]interface{}{
|
||||||
"audience": p.cfg.Audience,
|
"audience": p.cfg.ProtectedResourceMetadata.Audience,
|
||||||
"resource": p.cfg.ResourceIdentifier,
|
"scopes_supported": p.cfg.ProtectedResourceMetadata.ScopesSupported,
|
||||||
"scopes_supported": p.cfg.ScopesSupported,
|
"authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers,
|
||||||
"authorization_servers": p.cfg.AuthorizationServers,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.cfg.JwksURI != "" {
|
if p.cfg.ProtectedResourceMetadata.JwksURI != "" {
|
||||||
meta["jwks_uri"] = p.cfg.JwksURI
|
meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p.cfg.BearerMethodsSupported) > 0 {
|
if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 {
|
||||||
meta["bearer_methods_supported"] = p.cfg.BearerMethodsSupported
|
meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||||
|
|
|
@ -18,36 +18,37 @@ func (d *ScopeValidator) ValidateAccess(
|
||||||
claims *jwt.MapClaims,
|
claims *jwt.MapClaims,
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
) AccessControlResult {
|
) AccessControlResult {
|
||||||
env, err := util.ParseRPCRequest(r)
|
env, err := util.ParseRPCRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
|
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
|
||||||
}
|
}
|
||||||
requiredScopes := util.GetRequiredScopes(config, env.Method)
|
requiredScopes := util.GetRequiredScopes(config, env)
|
||||||
if len(requiredScopes) == 0 {
|
|
||||||
return AccessControlResult{DecisionAllow, ""}
|
|
||||||
}
|
|
||||||
|
|
||||||
required := make(map[string]struct{}, len(requiredScopes))
|
if len(requiredScopes) == 0 {
|
||||||
for _, s := range requiredScopes {
|
return AccessControlResult{DecisionAllow, ""}
|
||||||
s = strings.TrimSpace(s)
|
}
|
||||||
if s != "" {
|
|
||||||
required[s] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var tokenScopes []string
|
required := make(map[string]struct{}, len(requiredScopes))
|
||||||
if claims, ok := (*claims)["scope"]; ok {
|
for _, s := range requiredScopes {
|
||||||
switch v := claims.(type) {
|
s = strings.TrimSpace(s)
|
||||||
case string:
|
if s != "" {
|
||||||
tokenScopes = strings.Fields(v)
|
required[s] = struct{}{}
|
||||||
case []interface{}:
|
}
|
||||||
for _, x := range v {
|
}
|
||||||
if s, ok := x.(string); ok && s != "" {
|
|
||||||
tokenScopes = append(tokenScopes, s)
|
var tokenScopes []string
|
||||||
}
|
if claims, ok := (*claims)["scope"]; ok {
|
||||||
}
|
switch v := claims.(type) {
|
||||||
}
|
case string:
|
||||||
}
|
tokenScopes = strings.Fields(v)
|
||||||
|
case []interface{}:
|
||||||
|
for _, x := range v {
|
||||||
|
if s, ok := x.(string); ok && s != "" {
|
||||||
|
tokenScopes = append(tokenScopes, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
|
tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
|
||||||
for _, s := range tokenScopes {
|
for _, s := range tokenScopes {
|
||||||
|
|
|
@ -13,8 +13,9 @@ import (
|
||||||
type TransportMode string
|
type TransportMode string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SSETransport TransportMode = "sse"
|
SSETransport TransportMode = "sse"
|
||||||
StdioTransport TransportMode = "stdio"
|
StdioTransport TransportMode = "stdio"
|
||||||
|
StreamableHTTPTransport TransportMode = "streamable_http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Common path configuration for all transport modes
|
// Common path configuration for all transport modes
|
||||||
|
@ -68,6 +69,15 @@ type ResponseConfig struct {
|
||||||
CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
|
CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ProtectedResourceMetadata struct {
|
||||||
|
ResourceIdentifier string `yaml:"resource_identifier"`
|
||||||
|
Audience string `yaml:"audience"`
|
||||||
|
ScopesSupported []map[string]interface{} `yaml:"scopes_supported"`
|
||||||
|
AuthorizationServers []string `yaml:"authorization_servers"`
|
||||||
|
JwksURI string `yaml:"jwks_uri,omitempty"`
|
||||||
|
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type PathConfig struct {
|
type PathConfig struct {
|
||||||
// For well-known endpoint
|
// For well-known endpoint
|
||||||
Response *ResponseConfig `yaml:"response,omitempty"`
|
Response *ResponseConfig `yaml:"response,omitempty"`
|
||||||
|
@ -86,6 +96,7 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
ProxyBaseURL string `yaml:"proxy_base_url"`
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
ListenPort int `yaml:"listen_port"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
BaseURL string `yaml:"base_url"`
|
BaseURL string `yaml:"base_url"`
|
||||||
|
@ -98,7 +109,6 @@ type Config struct {
|
||||||
TransportMode TransportMode `yaml:"transport_mode"`
|
TransportMode TransportMode `yaml:"transport_mode"`
|
||||||
Paths PathsConfig `yaml:"paths"`
|
Paths PathsConfig `yaml:"paths"`
|
||||||
Stdio StdioConfig `yaml:"stdio"`
|
Stdio StdioConfig `yaml:"stdio"`
|
||||||
RequiredScopes map[string]string `yaml:"required_scopes"`
|
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
|
@ -106,12 +116,7 @@ type Config struct {
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
|
|
||||||
// Protected resource metadata
|
// Protected resource metadata
|
||||||
Audience string `yaml:"audience"`
|
ProtectedResourceMetadata ProtectedResourceMetadata `yaml:"protected_resource_metadata"`
|
||||||
ResourceIdentifier string `yaml:"resource_identifier"`
|
|
||||||
ScopesSupported any `yaml:"scopes_supported"`
|
|
||||||
AuthorizationServers []string `yaml:"authorization_servers"`
|
|
||||||
JwksURI string `yaml:"jwks_uri,omitempty"`
|
|
||||||
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks if the config is valid based on transport mode
|
// Validate checks if the config is valid based on transport mode
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, accessController aut
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc(getProtectedResourceMetadataEndpointPath(cfg), func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
allowed := getAllowedOrigin(origin, cfg)
|
allowed := getAllowedOrigin(origin, cfg)
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
|
@ -76,7 +76,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, accessController aut
|
||||||
addCORSHeaders(w, cfg, allowed, "")
|
addCORSHeaders(w, cfg, allowed, "")
|
||||||
provider.ProtectedResourceMetadataHandler()(w, r)
|
provider.ProtectedResourceMetadataHandler()(w, r)
|
||||||
})
|
})
|
||||||
registeredPaths["/.well-known/oauth-protected-resource"] = true
|
registeredPaths[getProtectedResourceMetadataEndpointPath(cfg)] = true
|
||||||
|
|
||||||
// Remove duplicates from defaultPaths
|
// Remove duplicates from defaultPaths
|
||||||
uniquePaths := make(map[string]bool)
|
uniquePaths := make(map[string]bool)
|
||||||
|
@ -165,11 +165,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
var targetURL *url.URL
|
var targetURL *url.URL
|
||||||
isSSE := false
|
isSSE := false
|
||||||
|
|
||||||
if isAuthPath(r.URL.Path) {
|
if isAuthPath(r.URL.Path, cfg) {
|
||||||
targetURL = authBase
|
targetURL = authBase
|
||||||
} else if isMCPPath(r.URL.Path, cfg) {
|
} else if isMCPPath(r.URL.Path, cfg) {
|
||||||
if ssePaths[r.URL.Path] {
|
if ssePaths[r.URL.Path] {
|
||||||
if err := authorizeSSE(w, r, isLatestSpec, cfg.ResourceIdentifier); err != nil {
|
if err := authorizeSSE(w, r, isLatestSpec, cfg); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -245,7 +245,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
"WWW-Authenticate",
|
"WWW-Authenticate",
|
||||||
fmt.Sprintf(
|
fmt.Sprintf(
|
||||||
`Bearer resource_metadata="%s"`,
|
`Bearer resource_metadata="%s"`,
|
||||||
cfg.ResourceIdentifier+"/.well-known/oauth-protected-resource",
|
cfg.ProxyBaseURL+getProtectedResourceMetadataEndpointPath(cfg),
|
||||||
))
|
))
|
||||||
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
}
|
}
|
||||||
|
@ -285,11 +285,11 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the request is for SSE handshake and authorize it
|
// Check if the request is for SSE handshake and authorize it
|
||||||
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, resourceID string) error {
|
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) error {
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
realm := resourceID + "/.well-known/oauth-protected-resource"
|
realm := cfg.BaseURL + getProtectedResourceMetadataEndpointPath(cfg)
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
||||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||||
}
|
}
|
||||||
|
@ -305,7 +305,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
|
||||||
accessToken, _ := util.ExtractAccessToken(authzHeader)
|
accessToken, _ := util.ExtractAccessToken(authzHeader)
|
||||||
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg)
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||||
`Bearer resource_metadata=%q`, realm,
|
`Bearer resource_metadata=%q`, realm,
|
||||||
))
|
))
|
||||||
|
@ -315,10 +315,10 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
|
||||||
return fmt.Errorf("missing or invalid Authorization header")
|
return fmt.Errorf("missing or invalid Authorization header")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.Audience)
|
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.ProtectedResourceMetadata.Audience)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isLatestSpec {
|
if isLatestSpec {
|
||||||
realm := cfg.ResourceIdentifier + "/.well-known/oauth-protected-resource"
|
realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg)
|
||||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
|
||||||
`Bearer realm=%q`,
|
`Bearer realm=%q`,
|
||||||
realm,
|
realm,
|
||||||
|
@ -343,7 +343,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg
|
||||||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||||
return fmt.Errorf("invalid token claims")
|
return fmt.Errorf("invalid token claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
pr := accessController.ValidateAccess(r, &claimsMap, cfg)
|
pr := accessController.ValidateAccess(r, &claimsMap, cfg)
|
||||||
if pr.Decision == authz.DecisionDeny {
|
if pr.Decision == authz.DecisionDeny {
|
||||||
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
||||||
|
@ -385,13 +385,13 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
w.Header().Set("X-Accel-Buffering", "no")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAuthPath(path string) bool {
|
func isAuthPath(path string, cfg *config.Config) bool {
|
||||||
authPaths := map[string]bool{
|
authPaths := map[string]bool{
|
||||||
"/authorize": true,
|
"/authorize": true,
|
||||||
"/token": true,
|
"/token": true,
|
||||||
"/register": true,
|
"/register": true,
|
||||||
"/.well-known/oauth-authorization-server": true,
|
"/.well-known/oauth-authorization-server": true,
|
||||||
"/.well-known/oauth-protected-resource": true,
|
getProtectedResourceMetadataEndpointPath(cfg): true,
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(path, "/u/") {
|
if strings.HasPrefix(path, "/u/") {
|
||||||
return true
|
return true
|
||||||
|
@ -417,3 +417,17 @@ func skipHeader(h string) bool {
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getProtectedResourceMetadataEndpointPath(cfg *config.Config) string {
|
||||||
|
|
||||||
|
protectedResourceMetadataPath := "/.well-known/oauth-protected-resource"
|
||||||
|
|
||||||
|
switch cfg.TransportMode {
|
||||||
|
case config.SSETransport:
|
||||||
|
protectedResourceMetadataPath += cfg.Paths.SSE
|
||||||
|
case config.StreamableHTTPTransport:
|
||||||
|
protectedResourceMetadataPath += cfg.Paths.StreamableHTTP
|
||||||
|
}
|
||||||
|
|
||||||
|
return protectedResourceMetadataPath
|
||||||
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenClaims struct {
|
type TokenClaims struct {
|
||||||
|
@ -82,7 +82,7 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||||
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||||
func ValidateJWT(
|
func ValidateJWT(
|
||||||
isLatestSpec bool,
|
isLatestSpec bool,
|
||||||
accessToken string,
|
accessToken string,
|
||||||
audience string,
|
audience string,
|
||||||
) error {
|
) error {
|
||||||
logger.Warn("isLatestSpec: %s", isLatestSpec)
|
logger.Warn("isLatestSpec: %s", isLatestSpec)
|
||||||
|
@ -148,46 +148,66 @@ func ValidateJWT(
|
||||||
|
|
||||||
// Parses the JWT token and returns the claims
|
// Parses the JWT token and returns the claims
|
||||||
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
return nil, fmt.Errorf("empty JWT")
|
return nil, fmt.Errorf("empty JWT")
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims jwt.MapClaims
|
var claims jwt.MapClaims
|
||||||
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
|
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||||
}
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the required scopes
|
// Process the required scopes
|
||||||
func GetRequiredScopes(cfg *config.Config, method string) []string {
|
func GetRequiredScopes(cfg *config.Config, requestBody *RPCEnvelope) []string {
|
||||||
switch raw := cfg.ScopesSupported.(type) {
|
|
||||||
case map[string]string:
|
|
||||||
if scope, ok := raw[method]; ok {
|
|
||||||
return []string{scope}
|
|
||||||
}
|
|
||||||
parts := strings.SplitN(method, "/", 2)
|
|
||||||
if len(parts) > 0 {
|
|
||||||
if scope, ok := raw[parts[0]]; ok {
|
|
||||||
return []string{scope}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case []interface{}:
|
|
||||||
out := make([]string, 0, len(raw))
|
|
||||||
for _, v := range raw {
|
|
||||||
if s, ok := v.(string); ok && s != "" {
|
|
||||||
out = append(out, s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
|
|
||||||
case []string:
|
var scopeObj interface{}
|
||||||
return raw
|
found := false
|
||||||
}
|
for _, m := range cfg.ProtectedResourceMetadata.ScopesSupported {
|
||||||
|
if val, ok := m[requestBody.Method]; ok {
|
||||||
|
scopeObj = val
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
switch v := scopeObj.(type) {
|
||||||
|
case string:
|
||||||
|
return []string{v}
|
||||||
|
case []any:
|
||||||
|
if requestBody.Params != nil {
|
||||||
|
if paramsMap, ok := requestBody.Params.(map[string]any); ok {
|
||||||
|
name, ok := paramsMap["name"].(string)
|
||||||
|
if ok {
|
||||||
|
for _, item := range v {
|
||||||
|
if scopeMap, ok := item.(map[interface{}]interface{}); ok {
|
||||||
|
if scopeVal, exists := scopeMap[name]; exists {
|
||||||
|
if scopeStr, ok := scopeVal.(string); ok {
|
||||||
|
return []string{scopeStr}
|
||||||
|
}
|
||||||
|
if scopeArr, ok := scopeVal.([]any); ok {
|
||||||
|
var scopes []string
|
||||||
|
for _, s := range scopeArr {
|
||||||
|
if str, ok := s.(string); ok {
|
||||||
|
scopes = append(scopes, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extracts the Bearer token from the Authorization header
|
// Extracts the Bearer token from the Authorization header
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue