diff --git a/.github/scripts/release.sh b/.github/scripts/release.sh index 2a1f6a9..52b024d 100644 --- a/.github/scripts/release.sh +++ b/.github/scripts/release.sh @@ -51,7 +51,7 @@ else fi # Extract current version. -CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0") +CURRENT_VERSION=$(git tag --sort=-v:refname | head -n 1 | sed 's/^v//' || echo "0.0.0") IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}" # Determine which part to increment diff --git a/.gitignore b/.gitignore index d200b58..f2bcda1 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,7 @@ coverage.html # IDE files .vscode + +# node modules +node_modules +openmcpauthproxy diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..3375ec0 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,62 @@ +# Contributing + +## Build from Source + +> Prerequisites +> +> * Go 1.20 or higher +> * Git +> * Make (optional, for simplified builds) + +1. **Clone the repository:** + ```bash + git clone https://github.com/wso2/open-mcp-auth-proxy + cd open-mcp-auth-proxy + ``` + +2. **Install dependencies:** + ```bash + go get -v -t -d ./... + ``` + +3. **Build the application:** + + **Option A: Using Make** + + ```bash + # Build for all platforms + make all + + # Or build for specific platforms + make build-linux # For Linux (x86_64) + make build-linux-arm # For ARM-based Linux + make build-darwin # For macOS + make build-windows # For Windows + ``` + + **Option B: Manual build (works on all platforms)** + + ```bash + # Build for your current platform + go build -o openmcpauthproxy ./cmd/proxy + + # Cross-compile for other platforms + GOOS=linux GOARCH=amd64 go build -o openmcpauthproxy-linux ./cmd/proxy + GOOS=windows GOARCH=amd64 go build -o openmcpauthproxy.exe ./cmd/proxy + GOOS=darwin GOARCH=amd64 go build -o openmcpauthproxy-macos ./cmd/proxy + ``` + +After building, you'll find the executables in the `build` directory (when using Make) or in your project root (when building manually). + +### Additional Make Targets + +If you're using Make, these additional targets are available: + +```bash +make test # Run tests +make coverage # Run tests with coverage report +make fmt # Format code with gofmt +make vet # Run go vet +make clean # Clean build artifacts +make help # Show all available targets +``` \ No newline at end of file diff --git a/Makefile b/Makefile index b0d0926..3c0c590 100644 --- a/Makefile +++ b/Makefile @@ -24,9 +24,9 @@ TEST_OPTS := -v -race .PHONY: all clean test fmt lint vet coverage help # Default target -all: lint test build-linux build-linux-arm build-darwin +all: lint test build-linux build-linux-arm build-darwin build-windows -build: clean test build-linux build-linux-arm build-darwin +build: clean test build-linux build-linux-arm build-darwin build-windows build-linux: mkdir -p $(BUILD_DIR)/linux @@ -46,6 +46,12 @@ build-darwin: -o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy cp config.yaml $(BUILD_DIR)/darwin +build-windows: + mkdir -p $(BUILD_DIR)/windows + GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ + -o $(BUILD_DIR)/windows/openmcpauthproxy.exe ./cmd/proxy + cp config.yaml $(BUILD_DIR)/windows + # Clean build artifacts clean: @echo "Cleaning build artifacts..." diff --git a/README.md b/README.md index 6be3ece..bd7abac 100644 --- a/README.md +++ b/README.md @@ -10,64 +10,40 @@ A lightweight authorization proxy for Model Context Protocol (MCP) servers that ![Architecture Diagram](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) -## What it Does +## 🚀 Features -Open MCP Auth Proxy sits between MCP clients and your MCP server to: +- **Dynamic Authorization**: based on MCP Authorization Specification. +- **JWT Validation**: Validates the token’s signature, checks the `audience` claim, and enforces scope requirements. +- **Identity Provider Integration**: Supports integrating any OAuth/OIDC provider such as Asgardeo, Auth0, Keycloak, etc. +- **Protocol Version Negotiation**: via `MCP-Protocol-Version` header. +- **Flexible Transport Modes**: Supports STDIO, SSE and streamable HTTP transport options. -- Intercept incoming requests -- Validate authorization tokens -- Offload authentication and authorization to OAuth-compliant Identity Providers -- Support the MCP authorization protocol +## 🛠️ Quick Start -## Quick Start - -### Prerequisites - -* Go 1.20 or higher -* A running MCP server - -> If you don't have an MCP server, you can use the included example: -> -> 1. Navigate to the `resources` directory -> 2. Set up a Python environment: +> **Prerequisites** > -> ```bash -> python3 -m venv .venv -> source .venv/bin/activate -> pip3 install -r requirements.txt -> ``` -> -> 3. Start the example server: -> -> ```bash -> python3 echo_server.py -> ``` - -* An MCP client that supports MCP authorization - -### Basic Usage +> * A running MCP server (Use the [example MCP server](resources/README.md) if you don't have an MCP server already) +> * An MCP client that supports MCP authorization specification 1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest). 2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox): +- Linux/macOS: + ```bash ./openmcpauthproxy --demo ``` -> The repository comes with a default `config.yaml` file that contains the basic configuration: -> -> ```yaml -> listen_port: 8080 -> base_url: "http://localhost:8000" # Your MCP server URL -> paths: -> sse: "/sse" -> messages: "/messages/" -> ``` +- Windows: -3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) +```powershell +.\openmcpauthproxy.exe --demo +``` -## Connect an Identity Provider +3. Connect using an MCP client like [MCP Inspector](https://github.com/modelcontextprotocol/inspector). + +## 🔒 Integrate an Identity Provider ### Asgardeo @@ -75,22 +51,26 @@ To enable authorization through your Asgardeo organization: 1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo 2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) - 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope +3. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) -3. Update `config.yaml` with the following parameters. +4. Update `config.yaml` with the following parameters. ```yaml -base_url: "http://localhost:8000" # URL of your MCP server -listen_port: 8080 # Address where the proxy will listen +base_url: "http://localhost:8000" # URL of your MCP server +listen_port: 8080 # Address where the proxy will listen -asgardeo: - org_name: "" # Your Asgardeo org name - client_id: "" # Client ID of the M2M app - client_secret: "" # Client secret of the M2M app +resource_identifier: "http://localhost:8080" # Proxy server URL +scopes_supported: # Scopes required to defined for the MCP server +- "read:tools" +- "read:resources" +audience: "" # Access token audience +authorization_servers: # Authorization server issuer identifier(s) +- "https://api.asgardeo.io/t/acme" +jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" # JWKS URL ``` -4. Start the proxy with Asgardeo integration: +5. Start the proxy with Asgardeo integration: ```bash ./openmcpauthproxy --asgardeo @@ -101,59 +81,24 @@ asgardeo: - [Auth0](docs/integrations/Auth0.md) - [Keycloak](docs/integrations/keycloak.md) -# Advanced Configuration +## Transport Modes -### Transport Modes - -The proxy supports two transport modes: - -- **SSE Mode (Default)**: For Server-Sent Events transport -- **stdio Mode**: For MCP servers that use stdio transport +### **STDIO Mode** When using stdio mode, the proxy: - Starts an MCP server as a subprocess using the command specified in the configuration - Communicates with the subprocess through standard input/output (stdio) -- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first -To use stdio mode: - -```bash -./openmcpauthproxy --demo --stdio -``` - -#### Example: Running an MCP Server as a Subprocess +> **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first 1. Configure stdio mode in your `config.yaml`: ```yaml -listen_port: 8080 -base_url: "http://localhost:8000" - stdio: enabled: true user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server env: # Environment variables (optional) - - "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT" - -# CORS configuration -cors: - allowed_origins: - - "http://localhost:5173" # Origin of your client application - allowed_methods: - - "GET" - - "POST" - - "PUT" - - "DELETE" - allowed_headers: - - "Authorization" - - "Content-Type" - allow_credentials: true - -# Demo configuration for Asgardeo -demo: - org_name: "openmcpauthdemo" - client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" - client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + - "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT" ``` 2. Run the proxy with stdio mode: @@ -162,65 +107,28 @@ demo: ./openmcpauthproxy --demo ``` -The proxy will: -- Start the MCP server as a subprocess using the specified command -- Handle all authorization requirements -- Forward messages between clients and the server +- **SSE Mode (Default)**: For Server-Sent Events transport +- **Streamable HTTP Mode**: For Streamable HTTP transport -### Complete Configuration Reference - -```yaml -# Common configuration -listen_port: 8080 -base_url: "http://localhost:8000" -port: 8000 - -# Path configuration -paths: - sse: "/sse" - messages: "/messages/" - -# Transport mode -transport_mode: "sse" # Options: "sse" or "stdio" - -# stdio-specific configuration (used only in stdio mode) -stdio: - enabled: true - user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed) - work_dir: "" # Optional working directory for the subprocess - -# CORS configuration -cors: - allowed_origins: - - "http://localhost:5173" - allowed_methods: - - "GET" - - "POST" - - "PUT" - - "DELETE" - allowed_headers: - - "Authorization" - - "Content-Type" - allow_credentials: true - -# Demo configuration for Asgardeo -demo: - org_name: "openmcpauthdemo" - client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" - client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" - -# Asgardeo configuration (used with --asgardeo flag) -asgardeo: - org_name: "" - client_id: "" - client_secret: "" -``` - -### Build from source +## Available Command Line Options ```bash -git clone https://github.com/wso2/open-mcp-auth-proxy -cd open-mcp-auth-proxy -go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2 -go build -o openmcpauthproxy ./cmd/proxy +# Start in demo mode (using Asgardeo sandbox) +./openmcpauthproxy --demo + +# Start with your own Asgardeo organization +./openmcpauthproxy --asgardeo + +# Use stdio transport mode instead of SSE +./openmcpauthproxy --demo --stdio + +# Enable debug logging +./openmcpauthproxy --demo --debug + +# Show all available options +./openmcpauthproxy --help ``` + +## Contributing + +We appreciate your contributions, whether it is improving documentation, adding new features, or fixing bugs. To get started, please refer to our [contributing guide](CONTRIBUTING.md). diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 6424f18..0208ead 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -11,7 +11,6 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/constants" "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" @@ -68,23 +67,7 @@ func main() { } // 3. Create the chosen provider - var provider authz.Provider - if *demoMode { - cfg.Mode = "demo" - cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2" - cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks" - provider = authz.NewAsgardeoProvider(cfg) - } else if *asgardeoMode { - cfg.Mode = "asgardeo" - cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2" - cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks" - provider = authz.NewAsgardeoProvider(cfg) - } else { - cfg.Mode = "default" - cfg.JWKSURL = cfg.Default.JWKSURL - cfg.AuthServerBaseURL = cfg.Default.BaseURL - provider = authz.NewDefaultProvider(cfg) - } + var provider authz.Provider = MakeProvider(cfg, *demoMode, *asgardeoMode) // 4. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { @@ -92,12 +75,15 @@ func main() { os.Exit(1) } - // 5. Build the main router - mux := proxy.NewRouter(cfg, provider) + // 5. (Optional) Build the access controler + accessController := &authz.ScopeValidator{} + + // 6. Build the main router + mux := proxy.NewRouter(cfg, provider, accessController) listen_address := fmt.Sprintf(":%d", cfg.ListenPort) - // 6. Start the server + // 7. Start the server srv := &http.Server{ Addr: listen_address, Handler: mux, @@ -111,18 +97,18 @@ func main() { } }() - // 7. Wait for shutdown signal + // 8. Wait for shutdown signal stop := make(chan os.Signal, 1) signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop logger.Info("Shutting down...") - // 8. First terminate subprocess if running + // 9. First terminate subprocess if running if procManager != nil && procManager.IsRunning() { procManager.Shutdown() } - // 9. Then shutdown the server + // 10. Then shutdown the server logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() diff --git a/cmd/proxy/provider.go b/cmd/proxy/provider.go new file mode 100644 index 0000000..90ef369 --- /dev/null +++ b/cmd/proxy/provider.go @@ -0,0 +1,45 @@ +package main + +import ( + "github.com/wso2/open-mcp-auth-proxy/internal/authz" + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" +) + +func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider { + var mode, orgName string + switch { + case demoMode: + mode = "demo" + orgName = cfg.Demo.OrgName + case asgardeoMode: + mode = "asgardeo" + orgName = cfg.Asgardeo.OrgName + default: + mode = "default" + } + cfg.Mode = mode + + switch mode { + case "demo", "asgardeo": + if len(cfg.ProtectedResourceMetadata.AuthorizationServers) == 0 && cfg.ProtectedResourceMetadata.JwksURI == "" { + base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2" + cfg.AuthServerBaseURL = base + cfg.JWKSURL = base + "/jwks" + } else { + cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0] + cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI + } + return authz.NewAsgardeoProvider(cfg) + + default: + if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" { + cfg.AuthServerBaseURL = cfg.Default.BaseURL + cfg.JWKSURL = cfg.Default.JWKSURL + } else if len(cfg.ProtectedResourceMetadata.AuthorizationServers) > 0 { + cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0] + cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI + } + return authz.NewDefaultProvider(cfg) + } +} diff --git a/config.yaml b/config.yaml index 5621195..47eb8eb 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,7 @@ # config.yaml # Common configuration for all transport modes +proxy_base_url: http://localhost:8080 listen_port: 8080 base_url: "http://localhost:8000" # Base URL for the MCP server port: 8000 # Port for the MCP server @@ -10,13 +11,14 @@ timeout_seconds: 10 paths: sse: "/sse" # SSE endpoint path messages: "/messages/" # Messages endpoint path + streamable_http: "/mcp" # MCP endpoint path # Transport mode configuration transport_mode: "sse" # Options: "sse" or "stdio" # stdio-specific configuration (used only when transport_mode is "stdio") stdio: - enabled: true + enabled: false user_command: "npx -y @modelcontextprotocol/server-github" work_dir: "" # Working directory (optional) # env: # Environment variables (optional) @@ -28,7 +30,8 @@ path_mapping: # CORS configuration cors: allowed_origins: - - "http://localhost:5173" + - "http://127.0.0.1:6274" + - "http://localhost:6274" allowed_methods: - "GET" - "POST" @@ -45,3 +48,18 @@ demo: org_name: "openmcpauthdemo" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +protected_resource_metadata: + resource_identifier: http://localhost:8080/sse + audience: 2xGW_poFYoObUE_vUQxvGdPSUPwa + scopes_supported: + - initialize: "mcp_init" + - tools/call: + - echo_tool: "mcp_echo_tool" + authorization_servers: + - https://api.asgardeo.io/t/openmcpauthdemo/oauth2/token + jwks_uri: https://api.asgardeo.io/t/openmcpauthdemo/oauth2/jwks + bearer_methods_supported: + - header + - body + - query diff --git a/internal/authz/access_control.go b/internal/authz/access_control.go new file mode 100644 index 0000000..1f7ce7b --- /dev/null +++ b/internal/authz/access_control.go @@ -0,0 +1,24 @@ +package authz + +import ( + "net/http" + + "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +type Decision int + +const ( + DecisionAllow Decision = iota + DecisionDeny +) + +type AccessControlResult struct { + Decision Decision + Message string +} + +type AccessControl interface { + ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult +} diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index 9b8fdc5..647e570 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -42,31 +42,17 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { return } - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { - scheme = forwardedProto - } - host := r.Host - if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { - host = forwardedHost - } - - baseURL := scheme + "://" + host - issuer := strings.TrimSuffix(p.cfg.AuthServerBaseURL, "/") + "/token" response := map[string]interface{}{ "issuer": issuer, - "authorization_endpoint": baseURL + "/authorize", - "token_endpoint": baseURL + "/token", + "authorization_endpoint": p.cfg.BaseURL + "/authorize", + "token_endpoint": p.cfg.BaseURL + "/token", "jwks_uri": p.cfg.JWKSURL, "response_types_supported": []string{"code"}, "grant_types_supported": []string{"authorization_code", "refresh_token"}, "token_endpoint_auth_methods_supported": []string{"client_secret_basic"}, - "registration_endpoint": baseURL + "/register", + "registration_endpoint": p.cfg.BaseURL + "/register", "code_challenge_methods_supported": []string{"plain", "S256"}, } @@ -113,6 +99,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { if err := p.createAsgardeoApplication(regReq); err != nil { logger.Warn("Asgardeo application creation failed: %v", err) + http.Error(w, "Failed to create application in Asgardeo", http.StatusInternalServerError) // Optionally http.Error(...) if you want to fail // or continue to return partial data. } @@ -193,7 +180,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) + return fmt.Errorf("asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) } logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) @@ -269,6 +256,18 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} { } appName += "-" + randomString(5) + // Build redirect URIs regex from list of redirect URIs : regexp=(https://app.example.com/callback1|https://app.example.com/callback2) + redirectURI := "regexp=(" + strings.Join(regReq.RedirectURIs, "|") + ")" + redirectURIs := []string{redirectURI} + + // Filter unsupported grant types + var grantTypes []string + for _, gt := range regReq.GrantTypes { + if gt == "authorization_code" || gt == "refresh_token" { + grantTypes = append(grantTypes, gt) + } + } + return map[string]interface{}{ "name": appName, "templateId": "custom-application-oidc", @@ -276,10 +275,10 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} { "oidc": map[string]interface{}{ "clientId": regReq.ClientID, "clientSecret": regReq.ClientSecret, - "grantTypes": regReq.GrantTypes, - "callbackURLs": regReq.RedirectURIs, + "grantTypes": grantTypes, + "callbackURLs": redirectURIs, "allowedOrigins": []string{}, - "publicClient": false, + "publicClient": true, "pkce": map[string]bool{ "mandatory": true, "supportPlainTransformAlgorithm": true, @@ -350,3 +349,48 @@ func randomString(n int) string { } return string(b) } + +func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + // Extract only the values into a []string + var supportedScopes []string + var extractStrings func(interface{}) + extractStrings = func(val interface{}) { + switch v := val.(type) { + case string: + supportedScopes = append(supportedScopes, v) + case []any: + for _, item := range v { + extractStrings(item) + } + case map[string]any: + for _, item := range v { + extractStrings(item) + } + } + } + for _, m := range p.cfg.ProtectedResourceMetadata.ScopesSupported { + for _, v := range m { + extractStrings(v) + } + } + + meta := map[string]interface{}{ + "resource": p.cfg.ProtectedResourceMetadata.ResourceIdentifier, + "scopes_supported": supportedScopes, + "authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers, + } + + if p.cfg.ProtectedResourceMetadata.JwksURI != "" { + meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI + } + if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 { + meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported + } + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, "failed to encode metadata", http.StatusInternalServerError) + } + } +} diff --git a/internal/authz/default.go b/internal/authz/default.go index 929f586..4f6647d 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type defaultProvider struct { @@ -40,31 +40,17 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { // Use configured response values responseConfig := pathConfig.Response - // Get current host for proxy endpoints - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { - scheme = forwardedProto - } - host := r.Host - if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { - host = forwardedHost - } - baseURL := scheme + "://" + host - authorizationEndpoint := responseConfig.AuthorizationEndpoint if authorizationEndpoint == "" { - authorizationEndpoint = baseURL + "/authorize" + authorizationEndpoint = p.cfg.BaseURL + "/authorize" } tokenEndpoint := responseConfig.TokenEndpoint if tokenEndpoint == "" { - tokenEndpoint = baseURL + "/token" + tokenEndpoint = p.cfg.BaseURL + "/token" } - registraionEndpoint := responseConfig.RegistrationEndpoint - if registraionEndpoint == "" { - registraionEndpoint = baseURL + "/register" + registrationEndpoint := responseConfig.RegistrationEndpoint + if registrationEndpoint == "" { + registrationEndpoint = p.cfg.BaseURL + "/register" } // Build response from config @@ -76,7 +62,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { "response_types_supported": responseConfig.ResponseTypesSupported, "grant_types_supported": responseConfig.GrantTypesSupported, "token_endpoint_auth_methods_supported": []string{"client_secret_basic"}, - "registration_endpoint": registraionEndpoint, + "registration_endpoint": registrationEndpoint, "code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported, } @@ -94,3 +80,26 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { func (p *defaultProvider) RegisterHandler() http.HandlerFunc { return nil } + +func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + meta := map[string]interface{}{ + "audience": p.cfg.ProtectedResourceMetadata.Audience, + "scopes_supported": p.cfg.ProtectedResourceMetadata.ScopesSupported, + "authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers, + } + + if p.cfg.ProtectedResourceMetadata.JwksURI != "" { + meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI + } + + if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 { + meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported + } + + if err := json.NewEncoder(w).Encode(meta); err != nil { + http.Error(w, "failed to encode metadata", http.StatusInternalServerError) + } + } +} diff --git a/internal/authz/provider.go b/internal/authz/provider.go index 1629cf4..42a8343 100644 --- a/internal/authz/provider.go +++ b/internal/authz/provider.go @@ -7,4 +7,5 @@ import "net/http" type Provider interface { WellKnownHandler() http.HandlerFunc RegisterHandler() http.HandlerFunc + ProtectedResourceMetadataHandler() http.HandlerFunc } diff --git a/internal/authz/scope_validator.go b/internal/authz/scope_validator.go new file mode 100644 index 0000000..bf18a07 --- /dev/null +++ b/internal/authz/scope_validator.go @@ -0,0 +1,72 @@ +package authz + +import ( + "fmt" + "net/http" + "strings" + + "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/util" +) + +type ScopeValidator struct{} + +// Evaluate and checks the token claims against one or more required scopes. +func (d *ScopeValidator) ValidateAccess( + r *http.Request, + claims *jwt.MapClaims, + config *config.Config, +) AccessControlResult { + env, err := util.ParseRPCRequest(r) + if err != nil { + return AccessControlResult{DecisionDeny, "bad JSON-RPC request"} + } + requiredScopes := util.GetRequiredScopes(config, env) + + if len(requiredScopes) == 0 { + return AccessControlResult{DecisionAllow, ""} + } + + required := make(map[string]struct{}, len(requiredScopes)) + for _, s := range requiredScopes { + s = strings.TrimSpace(s) + if s != "" { + required[s] = struct{}{} + } + } + + var tokenScopes []string + if claims, ok := (*claims)["scope"]; ok { + switch v := claims.(type) { + case string: + tokenScopes = strings.Fields(v) + case []interface{}: + for _, x := range v { + if s, ok := x.(string); ok && s != "" { + tokenScopes = append(tokenScopes, s) + } + } + } + } + + tokenScopeSet := make(map[string]struct{}, len(tokenScopes)) + for _, s := range tokenScopes { + tokenScopeSet[s] = struct{}{} + } + + var missing []string + for s := range required { + if _, ok := tokenScopeSet[s]; !ok { + missing = append(missing, s) + } + } + + if len(missing) == 0 { + return AccessControlResult{DecisionAllow, ""} + } + return AccessControlResult{ + DecisionDeny, + fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")), + } +} diff --git a/internal/config/config.go b/internal/config/config.go index fc6743c..2a7958a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "fmt" "os" + "runtime" + "strings" "gopkg.in/yaml.v2" ) @@ -11,21 +13,23 @@ import ( type TransportMode string const ( - SSETransport TransportMode = "sse" - StdioTransport TransportMode = "stdio" + SSETransport TransportMode = "sse" + StdioTransport TransportMode = "stdio" + StreamableHTTPTransport TransportMode = "streamable_http" ) // Common path configuration for all transport modes type PathsConfig struct { - SSE string `yaml:"sse"` - Messages string `yaml:"messages"` + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` + StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests } // StdioConfig contains stdio-specific configuration type StdioConfig struct { Enabled bool `yaml:"enabled"` - UserCommand string `yaml:"user_command"` // The command provided by the user - WorkDir string `yaml:"work_dir"` // Working directory (optional) + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) Args []string `yaml:"args,omitempty"` // Additional arguments Env []string `yaml:"env,omitempty"` // Environment variables } @@ -65,6 +69,15 @@ type ResponseConfig struct { CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` } +type ProtectedResourceMetadata struct { + ResourceIdentifier string `yaml:"resource_identifier"` + Audience string `yaml:"audience"` + ScopesSupported []map[string]interface{} `yaml:"scopes_supported"` + AuthorizationServers []string `yaml:"authorization_servers"` + JwksURI string `yaml:"jwks_uri,omitempty"` + BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"` +} + type PathConfig struct { // For well-known endpoint Response *ResponseConfig `yaml:"response,omitempty"` @@ -83,23 +96,27 @@ type DefaultConfig struct { } type Config struct { - AuthServerBaseURL string - ListenPort int `yaml:"listen_port"` - BaseURL string `yaml:"base_url"` - Port int `yaml:"port"` - JWKSURL string - TimeoutSeconds int `yaml:"timeout_seconds"` - PathMapping map[string]string `yaml:"path_mapping"` - Mode string `yaml:"mode"` - CORSConfig CORSConfig `yaml:"cors"` - TransportMode TransportMode `yaml:"transport_mode"` - Paths PathsConfig `yaml:"paths"` - Stdio StdioConfig `yaml:"stdio"` + ProxyBaseURL string `yaml:"proxy_base_url"` + AuthServerBaseURL string + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` + TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` Asgardeo AsgardeoConfig `yaml:"asgardeo"` Default DefaultConfig `yaml:"default"` + + // Protected resource metadata + ProtectedResourceMetadata ProtectedResourceMetadata `yaml:"protected_resource_metadata"` } // Validate checks if the config is valid based on transport mode @@ -136,7 +153,7 @@ func (c *Config) Validate() error { // GetMCPPaths returns the list of paths that should be proxied to the MCP server func (c *Config) GetMCPPaths() []string { - return []string{c.Paths.SSE, c.Paths.Messages} + return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP} } // BuildExecCommand constructs the full command string for execution in stdio mode @@ -145,7 +162,15 @@ func (c *Config) BuildExecCommand() string { return "" } - // Construct the full command + if runtime.GOOS == "windows" { + // For Windows, we need to properly escape the inner command + escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`) + return fmt.Sprintf( + `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, + escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, + ) + } + return fmt.Sprintf( `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, @@ -165,12 +190,12 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } - + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } - + // Set default transport mode if not specified if cfg.TransportMode == "" { cfg.TransportMode = SSETransport // Default to SSE @@ -180,11 +205,11 @@ func LoadConfig(path string) (*Config, error) { if cfg.Port == 0 { cfg.Port = 8000 // default } - + // Validate the configuration if err := cfg.Validate(); err != nil { return nil, err } - + return &cfg, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 20c0893..edf4182 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -136,20 +136,15 @@ func TestValidate(t *testing.T) { func TestGetMCPPaths(t *testing.T) { cfg := Config{ Paths: PathsConfig{ - SSE: "/custom-sse", - Messages: "/custom-messages", + SSE: "/custom-sse", + Messages: "/custom-messages", + StreamableHTTP: "/custom-streamable", }, } paths := cfg.GetMCPPaths() - if len(paths) != 2 { - t.Errorf("Expected 2 MCP paths, got %d", len(paths)) - } - if paths[0] != "/custom-sse" { - t.Errorf("Expected first path=/custom-sse, got %s", paths[0]) - } - if paths[1] != "/custom-messages" { - t.Errorf("Expected second path=/custom-messages, got %s", paths[1]) + if len(paths) != 3 { + t.Errorf("Expected 3 MCP paths, got %d", len(paths)) } } diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 1e5808e..e7b1bec 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -1,7 +1,14 @@ package constants +import "time" + // Package constant provides constants for the MCP Auth Proxy const ( ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" ) + +// MCP specification version cutover date +var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC) + +const TimeLayout = "2006-01-02" diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 33a9ea3..fa72d58 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "fmt" "net/http" "net/http/httputil" "net/url" @@ -10,14 +11,14 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) // NewRouter builds an http.ServeMux that routes // * /authorize, /token, /register, /.well-known to the provider or proxy // * MCP paths to the MCP server, etc. -func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { +func NewRouter(cfg *config.Config, provider authz.Provider, accessController authz.AccessControl) http.Handler { mux := http.NewServeMux() modifiers := map[string]RequestModifier{ @@ -63,6 +64,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { } } + mux.HandleFunc(getProtectedResourceMetadataEndpointPath(cfg), func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowed := getAllowedOrigin(origin, cfg) + if r.Method == http.MethodOptions { + addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers")) + w.WriteHeader(http.StatusNoContent) + return + } + + addCORSHeaders(w, cfg, allowed, "") + provider.ProtectedResourceMetadataHandler()(w, r) + }) + registeredPaths[getProtectedResourceMetadataEndpointPath(cfg)] = true + // Remove duplicates from defaultPaths uniquePaths := make(map[string]bool) cleanPaths := []string{} @@ -76,7 +91,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { for _, path := range defaultPaths { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController)) registeredPaths[path] = true } } @@ -84,14 +99,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { // MCP paths mcpPaths := cfg.GetMCPPaths() for _, path := range mcpPaths { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController)) registeredPaths[path] = true } // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController)) registeredPaths[path] = true } } @@ -99,14 +114,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { return mux } -func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, accessController authz.AccessControl) http.HandlerFunc { // Parse the base URLs up front authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { logger.Error("Invalid auth server URL: %v", err) panic(err) // Fatal error that prevents startup } - + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { logger.Error("Invalid MCP server URL: %v", err) @@ -141,20 +156,31 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Add CORS headers to all responses addCORSHeaders(w, cfg, allowedOrigin, "") + // Check if the request is for the latest spec + specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version")) + ver, err := util.ParseVersionDate(specVersion) + isLatestSpec := util.IsLatestSpec(ver, err) + // Decide whether the request should go to the auth server or MCP var targetURL *url.URL isSSE := false - if isAuthPath(r.URL.Path) { + if isAuthPath(r.URL.Path, cfg) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { - // Validate JWT for MCP paths if required - // Placeholder for JWT validation logic - if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { - logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return + if ssePaths[r.URL.Path] { + if err := authorizeSSE(w, r, isLatestSpec, cfg); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + isSSE = true + } else { + if err := authorizeMCP(w, r, isLatestSpec, cfg, accessController); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } } + targetURL = mcpBase if ssePaths[r.URL.Path] { isSSE = true @@ -191,13 +217,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Host = targetURL.Host cleanHeaders := http.Header{} - + // Set proper origin header to match the target if isSSE { // For SSE, ensure origin matches the target req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) } - + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -214,7 +240,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) }, ModifyResponse: func(resp *http.Response) error { logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) - resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts + if resp.StatusCode == http.StatusUnauthorized { + resp.Header.Set( + "WWW-Authenticate", + fmt.Sprintf( + `Bearer resource_metadata="%s"`, + cfg.ProxyBaseURL+getProtectedResourceMetadataEndpointPath(cfg), + )) + resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + + resp.Header.Del("Access-Control-Allow-Origin") return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { @@ -231,12 +267,12 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) proxyHost: r.Host, targetHost: targetURL.Host, } - + // Set SSE-specific headers w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - + w.Header().Set("Content-Type", "text/event-stream") // Keep SSE connections open HandleSSE(w, r, rp) } else { @@ -248,6 +284,76 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) } } +// Check if the request is for SSE handshake and authorize it +func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) error { + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { + if isLatestSpec { + realm := cfg.BaseURL + getProtectedResourceMetadataEndpointPath(cfg) + w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm)) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + return fmt.Errorf("missing or invalid Authorization header") + } + + return nil +} + +// Handles both v1 (just signature) and v2 (aud + scope) flows +func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error { + authzHeader := r.Header.Get("Authorization") + accessToken, _ := util.ExtractAccessToken(authzHeader) + if !strings.HasPrefix(authzHeader, "Bearer ") { + if isLatestSpec { + realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg) + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Bearer resource_metadata=%q`, realm, + )) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + } + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return fmt.Errorf("missing or invalid Authorization header") + } + + err := util.ValidateJWT(isLatestSpec, accessToken, cfg.ProtectedResourceMetadata.Audience) + if err != nil { + if isLatestSpec { + realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg) + w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(), + `Bearer realm=%q`, + realm, + )) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate") + http.Error(w, "Forbidden", http.StatusForbidden) + } else { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } + return err + } + + if isLatestSpec { + _, err := util.ParseRPCRequest(r) + if err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return err + } + + claimsMap, err := util.ParseJWT(accessToken) + if err != nil { + http.Error(w, "Invalid token claims", http.StatusUnauthorized) + return fmt.Errorf("invalid token claims") + } + + pr := accessController.ValidateAccess(r, &claimsMap, cfg) + if pr.Decision == authz.DecisionDeny { + http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) + return fmt.Errorf("forbidden — %s", pr.Message) + } + } + + return nil +} + func getAllowedOrigin(origin string, cfg *config.Config) string { if origin == "" { return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin @@ -265,6 +371,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) + w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version") if requestHeaders != "" { w.Header().Set("Access-Control-Allow-Headers", requestHeaders) } else { @@ -272,17 +379,19 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re } if cfg.CORSConfig.AllowCredentials { w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("MCP-Protocol-Version", ", ") } w.Header().Set("Vary", "Origin") w.Header().Set("X-Accel-Buffering", "no") } -func isAuthPath(path string) bool { +func isAuthPath(path string, cfg *config.Config) bool { authPaths := map[string]bool{ "/authorize": true, "/token": true, "/register": true, - "/.well-known/oauth-authorization-server": true, + "/.well-known/oauth-authorization-server": true, + getProtectedResourceMetadataEndpointPath(cfg): true, } if strings.HasPrefix(path, "/u/") { return true @@ -308,3 +417,17 @@ func skipHeader(h string) bool { } return false } + +func getProtectedResourceMetadataEndpointPath(cfg *config.Config) string { + + protectedResourceMetadataPath := "/.well-known/oauth-protected-resource" + + switch cfg.TransportMode { + case config.SSETransport: + protectedResourceMetadataPath += cfg.Paths.SSE + case config.StreamableHTTPTransport: + protectedResourceMetadataPath += cfg.Paths.StreamableHTTP + } + + return protectedResourceMetadataPath +} diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go index fa64337..902a517 100644 --- a/internal/subprocess/manager.go +++ b/internal/subprocess/manager.go @@ -4,13 +4,14 @@ import ( "fmt" "os" "os/exec" + "runtime" + "strings" "sync" "syscall" "time" - "strings" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // Manager handles starting and graceful shutdown of subprocesses @@ -31,34 +32,39 @@ func NewManager() *Manager { // EnsureDependenciesAvailable checks and installs required package executors func EnsureDependenciesAvailable(command string) error { - // Always ensure npx is available regardless of the command - if _, err := exec.LookPath("npx"); err != nil { - // npx is not available, check if npm is installed - if _, err := exec.LookPath("npm"); err != nil { - return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") - } - - // Try to install npx using npm - logger.Info("npx not found, attempting to install...") - cmd := exec.Command("npm", "install", "-g", "npx") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install npx: %w", err) - } - - logger.Info("npx installed successfully") - } - - // Check if uv is needed based on the command - if strings.Contains(command, "uv ") { - if _, err := exec.LookPath("uv"); err != nil { - return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") - } - } - - return nil + // Always ensure npx is available regardless of the command + if _, err := exec.LookPath("npx"); err != nil { + // npx is not available, check if npm is installed + if _, err := exec.LookPath("npm"); err != nil { + return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") + } + + // Try to install npx using npm + logger.Info("npx not found, attempting to install...") + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("npm.cmd", "install", "-g", "npx") + } else { + cmd = exec.Command("npm", "install", "-g", "npx") + } + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install npx: %w", err) + } + + logger.Info("npx installed successfully") + } + + // Check if uv is needed based on the command + if strings.Contains(command, "uv ") { + if _, err := exec.LookPath("uv"); err != nil { + return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") + } + } + + return nil } // SetShutdownDelay sets the maximum time to wait for graceful shutdown @@ -88,8 +94,13 @@ func (m *Manager) Start(cfg *config.Config) error { logger.Info("Starting subprocess with command: %s", execCommand) - // Use the shell to execute the command - cmd := exec.Command("sh", "-c", execCommand) + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + // Use PowerShell on Windows for better quote handling + cmd = exec.Command("powershell", "-Command", execCommand) + } else { + cmd = exec.Command("sh", "-c", execCommand) + } // Set working directory if specified if cfg.Stdio.WorkDir != "" { @@ -105,8 +116,8 @@ func (m *Manager) Start(cfg *config.Config) error { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Set the process group for proper termination - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + // Set platform-specific process attributes + setProcAttr(cmd) // Start the process if err := cmd.Start(); err != nil { @@ -117,11 +128,13 @@ func (m *Manager) Start(cfg *config.Config) error { m.cmd = cmd logger.Info("Subprocess started with PID: %d", m.process.Pid) - // Get and store the process group ID - pgid, err := syscall.Getpgid(m.process.Pid) + // Get and store the process group ID (Unix) or PID (Windows) + pgid, err := getProcessGroup(m.process.Pid) if err == nil { m.processGroup = pgid - logger.Debug("Process group ID: %d", m.processGroup) + if runtime.GOOS != "windows" { + logger.Debug("Process group ID: %d", m.processGroup) + } } else { logger.Warn("Failed to get process group ID: %v", err) m.processGroup = m.process.Pid @@ -155,7 +168,7 @@ func (m *Manager) IsRunning() bool { // Shutdown gracefully terminates the subprocess func (m *Manager) Shutdown() { m.mutex.Lock() - processToTerminate := m.process // Local copy of the process reference + processToTerminate := m.process // Local copy of the process reference processGroupToTerminate := m.processGroup m.mutex.Unlock() @@ -169,48 +182,73 @@ func (m *Manager) Shutdown() { go func() { defer close(terminateComplete) - // Try graceful termination first with SIGTERM + // Try graceful termination first terminatedGracefully := false - // Try to terminate the process group first - if processGroupToTerminate != 0 { - err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process group: %v", err) + if runtime.GOOS == "windows" { + // Windows: Try to terminate the process + m.mutex.Lock() + if m.process != nil { + err := m.process.Kill() + if err != nil { + logger.Warn("Failed to terminate process: %v", err) + } + } + m.mutex.Unlock() - // Fallback to terminating just the process + // Wait a bit to see if it terminates + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break + } + m.mutex.Unlock() + } + } else { + // Unix: Use SIGTERM followed by SIGKILL if necessary + // Try to terminate the process group first + if processGroupToTerminate != 0 { + err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process group: %v", err) + + // Fallback to terminating just the process + m.mutex.Lock() + if m.process != nil { + err = m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to terminate just the process m.mutex.Lock() if m.process != nil { - err = m.process.Signal(syscall.SIGTERM) + err := m.process.Signal(syscall.SIGTERM) if err != nil { logger.Warn("Failed to send SIGTERM to process: %v", err) } } m.mutex.Unlock() } - } else { - // Try to terminate just the process - m.mutex.Lock() - if m.process != nil { - err := m.process.Signal(syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process: %v", err) + + // Wait for the process to exit gracefully + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break } - } - m.mutex.Unlock() - } - - // Wait for the process to exit gracefully - for i := 0; i < 10; i++ { - time.Sleep(200 * time.Millisecond) - - m.mutex.Lock() - if m.process == nil { - terminatedGracefully = true m.mutex.Unlock() - break } - m.mutex.Unlock() } if terminatedGracefully { @@ -221,12 +259,33 @@ func (m *Manager) Shutdown() { // If the process didn't exit gracefully, force kill logger.Warn("Subprocess didn't exit gracefully, forcing termination...") - // Try to kill the process group first - if processGroupToTerminate != 0 { - if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { - logger.Warn("Failed to send SIGKILL to process group: %v", err) + if runtime.GOOS == "windows" { + // On Windows, Kill() is already forceful + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } else { + // Unix: Try SIGKILL + // Try to kill the process group first + if processGroupToTerminate != 0 { + if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil { + logger.Warn("Failed to send SIGKILL to process group: %v", err) - // Fallback to killing just the process + // Fallback to killing just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to kill just the process m.mutex.Lock() if m.process != nil { if err := m.process.Kill(); err != nil { @@ -235,15 +294,6 @@ func (m *Manager) Shutdown() { } m.mutex.Unlock() } - } else { - // Try to kill just the process - m.mutex.Lock() - if m.process != nil { - if err := m.process.Kill(); err != nil { - logger.Error("Failed to kill process: %v", err) - } - } - m.mutex.Unlock() } // Wait a bit more to confirm termination diff --git a/internal/subprocess/manager_unix.go b/internal/subprocess/manager_unix.go new file mode 100644 index 0000000..03ae1a8 --- /dev/null +++ b/internal/subprocess/manager_unix.go @@ -0,0 +1,23 @@ +//go:build !windows + +package subprocess + +import ( + "os/exec" + "syscall" +) + +// setProcAttr sets Unix-specific process attributes +func setProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} +} + +// getProcessGroup gets the process group ID on Unix systems +func getProcessGroup(pid int) (int, error) { + return syscall.Getpgid(pid) +} + +// killProcessGroup kills a process group on Unix systems +func killProcessGroup(pgid int, signal syscall.Signal) error { + return syscall.Kill(-pgid, signal) +} diff --git a/internal/subprocess/manager_windows.go b/internal/subprocess/manager_windows.go new file mode 100644 index 0000000..a039897 --- /dev/null +++ b/internal/subprocess/manager_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package subprocess + +import ( + "os/exec" + "syscall" +) + +// setProcAttr sets Windows-specific process attributes +func setProcAttr(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, + } +} + +// getProcessGroup returns the PID itself on Windows (no process groups) +func getProcessGroup(pid int) (int, error) { + return pid, nil +} + +// killProcessGroup kills a process on Windows (no process groups) +func killProcessGroup(pgid int, signal syscall.Signal) error { + // On Windows, we'll use the process handle directly + // This function shouldn't be called on Windows, but we provide it for compatibility + return nil +} diff --git a/internal/util/jwks.go b/internal/util/jwks.go index f80d82e..1a00d6e 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -4,21 +4,27 @@ import ( "crypto/rsa" "encoding/json" "errors" + "fmt" "math/big" "net/http" "strings" "github.com/golang-jwt/jwt/v4" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" + "github.com/wso2/open-mcp-auth-proxy/internal/config" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) +type TokenClaims struct { + Scopes []string +} + type JWKS struct { Keys []json.RawMessage `json:"keys"` } var publicKeys map[string]*rsa.PublicKey -// FetchJWKS downloads JWKS and stores in a package-level map +// FetchJWKS downloads JWKS and stores in a package‐level map func FetchJWKS(jwksURL string) error { resp, err := http.Get(jwksURL) if err != nil { @@ -31,23 +37,23 @@ func FetchJWKS(jwksURL string) error { return err } - publicKeys = make(map[string]*rsa.PublicKey) + publicKeys = make(map[string]*rsa.PublicKey, len(jwks.Keys)) for _, keyData := range jwks.Keys { - var parsedKey struct { + var parsed struct { Kid string `json:"kid"` N string `json:"n"` E string `json:"e"` Kty string `json:"kty"` } - if err := json.Unmarshal(keyData, &parsedKey); err != nil { + if err := json.Unmarshal(keyData, &parsed); err != nil { continue } - if parsedKey.Kty != "RSA" { + if parsed.Kty != "RSA" { continue } - pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E) + pubKey, err := parseRSAPublicKey(parsed.N, parsed.E) if err == nil { - publicKeys[parsedKey.Kid] = pubKey + publicKeys[parsed.Kid] = pubKey } } logger.Info("Loaded %d public keys.", len(publicKeys)) @@ -73,25 +79,150 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) { return &rsa.PublicKey{N: n, E: e}, nil } -// ValidateJWT checks the Authorization: Bearer token using stored JWKS -func ValidateJWT(authHeader string) error { - if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { - return errors.New("missing or invalid Authorization header") - } - tokenStr := strings.TrimPrefix(authHeader, "Bearer ") - token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { - kid, _ := token.Header["kid"].(string) - pubKey, ok := publicKeys[kid] - if !ok { - return nil, errors.New("unknown or missing kid in token header") +// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version. +func ValidateJWT( + isLatestSpec bool, + accessToken string, + audience string, +) error { + logger.Warn("isLatestSpec: %s", isLatestSpec) + // Parse & verify the signature + token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - return pubKey, nil + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, errors.New("kid header not found") + } + key, ok := publicKeys[kid] + if !ok { + return nil, fmt.Errorf("key not found for kid: %s", kid) + } + return key, nil }) if err != nil { - return errors.New("invalid token: " + err.Error()) + logger.Warn("Error detected, returning early") + return fmt.Errorf("invalid token: %w", err) } if !token.Valid { - return errors.New("invalid token: token not valid") + logger.Warn("Token invalid, returning early") + return errors.New("token not valid") } + + claimsMap, ok := token.Claims.(jwt.MapClaims) + if !ok { + return errors.New("unexpected claim type") + } + + if !isLatestSpec { + return nil + } + + audRaw, exists := claimsMap["aud"] + if !exists { + return errors.New("aud claim missing") + } + switch v := audRaw.(type) { + case string: + if v != audience { + return fmt.Errorf("aud %q does not match %q", v, audience) + } + case []interface{}: + var found bool + for _, a := range v { + if s, ok := a.(string); ok && s == audience { + found = true + break + } + } + if !found { + return fmt.Errorf("audience %v does not include %q", v, audience) + } + default: + return errors.New("aud claim has unexpected type") + } + return nil } + +// Parses the JWT token and returns the claims +func ParseJWT(tokenStr string) (jwt.MapClaims, error) { + if tokenStr == "" { + return nil, fmt.Errorf("empty JWT") + } + + var claims jwt.MapClaims + _, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + return claims, nil +} + +// Process the required scopes +func GetRequiredScopes(cfg *config.Config, requestBody *RPCEnvelope) []string { + + var scopeObj interface{} + found := false + for _, m := range cfg.ProtectedResourceMetadata.ScopesSupported { + if val, ok := m[requestBody.Method]; ok { + scopeObj = val + found = true + break + } + } + if !found { + return nil + } + + switch v := scopeObj.(type) { + case string: + return []string{v} + case []any: + if requestBody.Params != nil { + if paramsMap, ok := requestBody.Params.(map[string]any); ok { + name, ok := paramsMap["name"].(string) + if ok { + for _, item := range v { + if scopeMap, ok := item.(map[interface{}]interface{}); ok { + if scopeVal, exists := scopeMap[name]; exists { + if scopeStr, ok := scopeVal.(string); ok { + return []string{scopeStr} + } + if scopeArr, ok := scopeVal.([]any); ok { + var scopes []string + for _, s := range scopeArr { + if str, ok := s.(string); ok { + scopes = append(scopes, str) + } + } + return scopes + } + } + } + } + } + } + } + } + + return nil +} + +// Extracts the Bearer token from the Authorization header +func ExtractAccessToken(authHeader string) (string, error) { + if authHeader == "" { + return "", errors.New("empty authorization header") + } + if !strings.HasPrefix(authHeader, "Bearer ") { + return "", fmt.Errorf("invalid authorization header format: %s", authHeader) + } + + tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer ")) + if tokenStr == "" { + return "", errors.New("empty bearer token") + } + + return tokenStr, nil +} diff --git a/internal/util/jwks_test.go b/internal/util/jwks_test.go index 3b00c68..19c506e 100644 --- a/internal/util/jwks_test.go +++ b/internal/util/jwks_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -47,7 +48,14 @@ func TestValidateJWT(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - err := ValidateJWT(tc.authHeader) + var accessToken string + parts := strings.Split(tc.authHeader, "Bearer ") + if len(parts) == 2 { + accessToken = parts[1] + } else { + accessToken = "" + } + err := ValidateJWT(true, accessToken, "test-audience") if tc.expectError && err == nil { t.Errorf("Expected error but got none") } @@ -128,6 +136,7 @@ func createValidJWT(t *testing.T) string { token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "sub": "1234567890", "name": "Test User", + "aud": "test-audience", "iat": time.Now().Unix(), "exp": time.Now().Add(time.Hour).Unix(), }) diff --git a/internal/util/rpc.go b/internal/util/rpc.go new file mode 100644 index 0000000..5338437 --- /dev/null +++ b/internal/util/rpc.go @@ -0,0 +1,38 @@ +package util + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" +) + +type RPCEnvelope struct { + Method string `json:"method"` + Params any `json:"params"` + ID any `json:"id"` +} + +// This function parses a JSON-RPC request from an HTTP request body +func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + if len(bodyBytes) == 0 { + return nil, nil + } + + var env RPCEnvelope + dec := json.NewDecoder(bytes.NewReader(bodyBytes)) + if err := dec.Decode(&env); err != nil && err != io.EOF { + logger.Warn("Error parsing JSON-RPC envelope: %v", err) + return nil, err + } + + return &env, nil +} diff --git a/internal/util/version.go b/internal/util/version.go new file mode 100644 index 0000000..230ef1d --- /dev/null +++ b/internal/util/version.go @@ -0,0 +1,26 @@ +package util + +import ( + "time" + + "github.com/wso2/open-mcp-auth-proxy/internal/constants" +) + +// This function checks if the given version date is after the spec cutover date +func IsLatestSpec(versionDate time.Time, err error) bool { + return err == nil && versionDate.After(constants.SpecCutoverDate) +} + +// This function parses a version string into a time.Time +func ParseVersionDate(version string) (time.Time, error) { + return time.Parse("2006-01-02", version) +} + +// This function returns the version string, using the cutover date if empty +func GetVersionWithDefault(version string) string { + if version == "" { + defaultTime, _ := time.Parse(constants.TimeLayout, "2025-05-15") + return defaultTime.Format(constants.TimeLayout) + } + return version +} diff --git a/resources/README.md b/resources/README.md new file mode 100644 index 0000000..a10285d --- /dev/null +++ b/resources/README.md @@ -0,0 +1,19 @@ +# Example MCP server + +Use this example MCP server, if you don't already have an MCP server to test the open-mcp-auth-proxy. + +## Setting Up + +1. Set up a Python virtual environment. + +```bash +python3 -m venv .venv +source .venv/bin/activate +pip3 install -r requirements.txt +``` + +2. Start the example server. + +```bash +python3 echo_server.py +``` diff --git a/resources/echo_server.py b/resources/echo_server.py index 889bcc7..f9339c5 100644 --- a/resources/echo_server.py +++ b/resources/echo_server.py @@ -2,7 +2,6 @@ from mcp.server.fastmcp import FastMCP mcp = FastMCP("Echo") - @mcp.resource("echo://{message}") def echo_resource(message: str) -> str: """Echo a message as a resource"""