mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-12-14 12:12:30 +00:00
Compare commits
36 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
017a3f3a13 | ||
|
|
8db310ec06 | ||
|
|
56d969b785 | ||
|
|
eb72be0aab | ||
|
|
c38e27b097 | ||
|
|
8589035d64 | ||
|
|
b30aa6273c | ||
|
|
2cad797aee | ||
|
|
a2d775a902 | ||
|
|
31fce13ca4 | ||
|
|
5b13c7167f | ||
|
|
2c175869a9 | ||
|
|
885b32ee80 | ||
|
|
1ab421dc81 | ||
|
|
7120886e9b | ||
|
|
21805b4f0b | ||
|
|
5601c7836c | ||
|
|
c5cfecf243 | ||
|
|
ba19c7363a | ||
|
|
c65f73a6ce | ||
|
|
d71ee4052c | ||
|
|
3e2c49db5c | ||
|
|
1edfed91b5 | ||
|
|
53e0fa65a1 | ||
|
|
fdb81007d4 | ||
|
|
316370be1c | ||
|
|
fc0d939e16 | ||
|
|
2a0075b22e | ||
|
|
edd3ce483e | ||
|
|
c7fc15399b | ||
|
|
9e1316b420 | ||
|
|
9f856c4279 | ||
|
|
be697b5868 | ||
|
|
8bc2e6e76b | ||
|
|
561b8fb637 | ||
|
|
68015ae8fc |
26 changed files with 1030 additions and 379 deletions
2
.github/scripts/release.sh
vendored
2
.github/scripts/release.sh
vendored
|
|
@ -51,7 +51,7 @@ else
|
|||
fi
|
||||
|
||||
# Extract current version.
|
||||
CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0")
|
||||
CURRENT_VERSION=$(git tag --sort=-v:refname | head -n 1 | sed 's/^v//' || echo "0.0.0")
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}"
|
||||
|
||||
# Determine which part to increment
|
||||
|
|
|
|||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -36,3 +36,7 @@ coverage.html
|
|||
|
||||
# IDE files
|
||||
.vscode
|
||||
|
||||
# node modules
|
||||
node_modules
|
||||
openmcpauthproxy
|
||||
|
|
|
|||
62
CONTRIBUTING.md
Normal file
62
CONTRIBUTING.md
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
# Contributing
|
||||
|
||||
## Build from Source
|
||||
|
||||
> Prerequisites
|
||||
>
|
||||
> * Go 1.20 or higher
|
||||
> * Git
|
||||
> * Make (optional, for simplified builds)
|
||||
|
||||
1. **Clone the repository:**
|
||||
```bash
|
||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||
cd open-mcp-auth-proxy
|
||||
```
|
||||
|
||||
2. **Install dependencies:**
|
||||
```bash
|
||||
go get -v -t -d ./...
|
||||
```
|
||||
|
||||
3. **Build the application:**
|
||||
|
||||
**Option A: Using Make**
|
||||
|
||||
```bash
|
||||
# Build for all platforms
|
||||
make all
|
||||
|
||||
# Or build for specific platforms
|
||||
make build-linux # For Linux (x86_64)
|
||||
make build-linux-arm # For ARM-based Linux
|
||||
make build-darwin # For macOS
|
||||
make build-windows # For Windows
|
||||
```
|
||||
|
||||
**Option B: Manual build (works on all platforms)**
|
||||
|
||||
```bash
|
||||
# Build for your current platform
|
||||
go build -o openmcpauthproxy ./cmd/proxy
|
||||
|
||||
# Cross-compile for other platforms
|
||||
GOOS=linux GOARCH=amd64 go build -o openmcpauthproxy-linux ./cmd/proxy
|
||||
GOOS=windows GOARCH=amd64 go build -o openmcpauthproxy.exe ./cmd/proxy
|
||||
GOOS=darwin GOARCH=amd64 go build -o openmcpauthproxy-macos ./cmd/proxy
|
||||
```
|
||||
|
||||
After building, you'll find the executables in the `build` directory (when using Make) or in your project root (when building manually).
|
||||
|
||||
### Additional Make Targets
|
||||
|
||||
If you're using Make, these additional targets are available:
|
||||
|
||||
```bash
|
||||
make test # Run tests
|
||||
make coverage # Run tests with coverage report
|
||||
make fmt # Format code with gofmt
|
||||
make vet # Run go vet
|
||||
make clean # Clean build artifacts
|
||||
make help # Show all available targets
|
||||
```
|
||||
10
Makefile
10
Makefile
|
|
@ -24,9 +24,9 @@ TEST_OPTS := -v -race
|
|||
.PHONY: all clean test fmt lint vet coverage help
|
||||
|
||||
# Default target
|
||||
all: lint test build-linux build-linux-arm build-darwin
|
||||
all: lint test build-linux build-linux-arm build-darwin build-windows
|
||||
|
||||
build: clean test build-linux build-linux-arm build-darwin
|
||||
build: clean test build-linux build-linux-arm build-darwin build-windows
|
||||
|
||||
build-linux:
|
||||
mkdir -p $(BUILD_DIR)/linux
|
||||
|
|
@ -46,6 +46,12 @@ build-darwin:
|
|||
-o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
|
||||
cp config.yaml $(BUILD_DIR)/darwin
|
||||
|
||||
build-windows:
|
||||
mkdir -p $(BUILD_DIR)/windows
|
||||
GOOS=windows GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
|
||||
-o $(BUILD_DIR)/windows/openmcpauthproxy.exe ./cmd/proxy
|
||||
cp config.yaml $(BUILD_DIR)/windows
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning build artifacts..."
|
||||
|
|
|
|||
200
README.md
200
README.md
|
|
@ -10,64 +10,40 @@ A lightweight authorization proxy for Model Context Protocol (MCP) servers that
|
|||
|
||||

|
||||
|
||||
## What it Does
|
||||
## 🚀 Features
|
||||
|
||||
Open MCP Auth Proxy sits between MCP clients and your MCP server to:
|
||||
- **Dynamic Authorization**: based on MCP Authorization Specification.
|
||||
- **JWT Validation**: Validates the token’s signature, checks the `audience` claim, and enforces scope requirements.
|
||||
- **Identity Provider Integration**: Supports integrating any OAuth/OIDC provider such as Asgardeo, Auth0, Keycloak, etc.
|
||||
- **Protocol Version Negotiation**: via `MCP-Protocol-Version` header.
|
||||
- **Flexible Transport Modes**: Supports STDIO, SSE and streamable HTTP transport options.
|
||||
|
||||
- Intercept incoming requests
|
||||
- Validate authorization tokens
|
||||
- Offload authentication and authorization to OAuth-compliant Identity Providers
|
||||
- Support the MCP authorization protocol
|
||||
## 🛠️ Quick Start
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
* Go 1.20 or higher
|
||||
* A running MCP server
|
||||
|
||||
> If you don't have an MCP server, you can use the included example:
|
||||
> **Prerequisites**
|
||||
>
|
||||
> 1. Navigate to the `resources` directory
|
||||
> 2. Set up a Python environment:
|
||||
>
|
||||
> ```bash
|
||||
> python3 -m venv .venv
|
||||
> source .venv/bin/activate
|
||||
> pip3 install -r requirements.txt
|
||||
> ```
|
||||
>
|
||||
> 3. Start the example server:
|
||||
>
|
||||
> ```bash
|
||||
> python3 echo_server.py
|
||||
> ```
|
||||
|
||||
* An MCP client that supports MCP authorization
|
||||
|
||||
### Basic Usage
|
||||
> * A running MCP server (Use the [example MCP server](resources/README.md) if you don't have an MCP server already)
|
||||
> * An MCP client that supports MCP authorization specification
|
||||
|
||||
1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest).
|
||||
|
||||
2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
|
||||
|
||||
- Linux/macOS:
|
||||
|
||||
```bash
|
||||
./openmcpauthproxy --demo
|
||||
```
|
||||
|
||||
> The repository comes with a default `config.yaml` file that contains the basic configuration:
|
||||
>
|
||||
> ```yaml
|
||||
> listen_port: 8080
|
||||
> base_url: "http://localhost:8000" # Your MCP server URL
|
||||
> paths:
|
||||
> sse: "/sse"
|
||||
> messages: "/messages/"
|
||||
> ```
|
||||
- Windows:
|
||||
|
||||
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
||||
```powershell
|
||||
.\openmcpauthproxy.exe --demo
|
||||
```
|
||||
|
||||
## Connect an Identity Provider
|
||||
3. Connect using an MCP client like [MCP Inspector](https://github.com/modelcontextprotocol/inspector).
|
||||
|
||||
## 🔒 Integrate an Identity Provider
|
||||
|
||||
### Asgardeo
|
||||
|
||||
|
|
@ -75,22 +51,26 @@ To enable authorization through your Asgardeo organization:
|
|||
|
||||
1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo
|
||||
2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
|
||||
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
|
||||
3. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
|
||||

|
||||
|
||||
3. Update `config.yaml` with the following parameters.
|
||||
4. Update `config.yaml` with the following parameters.
|
||||
|
||||
```yaml
|
||||
base_url: "http://localhost:8000" # URL of your MCP server
|
||||
listen_port: 8080 # Address where the proxy will listen
|
||||
|
||||
asgardeo:
|
||||
org_name: "<org_name>" # Your Asgardeo org name
|
||||
client_id: "<client_id>" # Client ID of the M2M app
|
||||
client_secret: "<client_secret>" # Client secret of the M2M app
|
||||
resource_identifier: "http://localhost:8080" # Proxy server URL
|
||||
scopes_supported: # Scopes required to defined for the MCP server
|
||||
- "read:tools"
|
||||
- "read:resources"
|
||||
audience: "<audience_value>" # Access token audience
|
||||
authorization_servers: # Authorization server issuer identifier(s)
|
||||
- "https://api.asgardeo.io/t/acme"
|
||||
jwks_uri: "https://api.asgardeo.io/t/acme/oauth2/jwks" # JWKS URL
|
||||
```
|
||||
|
||||
4. Start the proxy with Asgardeo integration:
|
||||
5. Start the proxy with Asgardeo integration:
|
||||
|
||||
```bash
|
||||
./openmcpauthproxy --asgardeo
|
||||
|
|
@ -101,59 +81,24 @@ asgardeo:
|
|||
- [Auth0](docs/integrations/Auth0.md)
|
||||
- [Keycloak](docs/integrations/keycloak.md)
|
||||
|
||||
# Advanced Configuration
|
||||
## Transport Modes
|
||||
|
||||
### Transport Modes
|
||||
|
||||
The proxy supports two transport modes:
|
||||
|
||||
- **SSE Mode (Default)**: For Server-Sent Events transport
|
||||
- **stdio Mode**: For MCP servers that use stdio transport
|
||||
### **STDIO Mode**
|
||||
|
||||
When using stdio mode, the proxy:
|
||||
- Starts an MCP server as a subprocess using the command specified in the configuration
|
||||
- Communicates with the subprocess through standard input/output (stdio)
|
||||
- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
|
||||
|
||||
To use stdio mode:
|
||||
|
||||
```bash
|
||||
./openmcpauthproxy --demo --stdio
|
||||
```
|
||||
|
||||
#### Example: Running an MCP Server as a Subprocess
|
||||
> **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
|
||||
|
||||
1. Configure stdio mode in your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
listen_port: 8080
|
||||
base_url: "http://localhost:8000"
|
||||
|
||||
stdio:
|
||||
enabled: true
|
||||
user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server
|
||||
env: # Environment variables (optional)
|
||||
- "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT"
|
||||
|
||||
# CORS configuration
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "http://localhost:5173" # Origin of your client application
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
allowed_headers:
|
||||
- "Authorization"
|
||||
- "Content-Type"
|
||||
allow_credentials: true
|
||||
|
||||
# Demo configuration for Asgardeo
|
||||
demo:
|
||||
org_name: "openmcpauthdemo"
|
||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||
```
|
||||
|
||||
2. Run the proxy with stdio mode:
|
||||
|
|
@ -162,65 +107,28 @@ demo:
|
|||
./openmcpauthproxy --demo
|
||||
```
|
||||
|
||||
The proxy will:
|
||||
- Start the MCP server as a subprocess using the specified command
|
||||
- Handle all authorization requirements
|
||||
- Forward messages between clients and the server
|
||||
- **SSE Mode (Default)**: For Server-Sent Events transport
|
||||
- **Streamable HTTP Mode**: For Streamable HTTP transport
|
||||
|
||||
### Complete Configuration Reference
|
||||
|
||||
```yaml
|
||||
# Common configuration
|
||||
listen_port: 8080
|
||||
base_url: "http://localhost:8000"
|
||||
port: 8000
|
||||
|
||||
# Path configuration
|
||||
paths:
|
||||
sse: "/sse"
|
||||
messages: "/messages/"
|
||||
|
||||
# Transport mode
|
||||
transport_mode: "sse" # Options: "sse" or "stdio"
|
||||
|
||||
# stdio-specific configuration (used only in stdio mode)
|
||||
stdio:
|
||||
enabled: true
|
||||
user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed)
|
||||
work_dir: "" # Optional working directory for the subprocess
|
||||
|
||||
# CORS configuration
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "http://localhost:5173"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
allowed_headers:
|
||||
- "Authorization"
|
||||
- "Content-Type"
|
||||
allow_credentials: true
|
||||
|
||||
# Demo configuration for Asgardeo
|
||||
demo:
|
||||
org_name: "openmcpauthdemo"
|
||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||
|
||||
# Asgardeo configuration (used with --asgardeo flag)
|
||||
asgardeo:
|
||||
org_name: "<org_name>"
|
||||
client_id: "<client_id>"
|
||||
client_secret: "<client_secret>"
|
||||
```
|
||||
|
||||
### Build from source
|
||||
## Available Command Line Options
|
||||
|
||||
```bash
|
||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||
cd open-mcp-auth-proxy
|
||||
go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
|
||||
go build -o openmcpauthproxy ./cmd/proxy
|
||||
# Start in demo mode (using Asgardeo sandbox)
|
||||
./openmcpauthproxy --demo
|
||||
|
||||
# Start with your own Asgardeo organization
|
||||
./openmcpauthproxy --asgardeo
|
||||
|
||||
# Use stdio transport mode instead of SSE
|
||||
./openmcpauthproxy --demo --stdio
|
||||
|
||||
# Enable debug logging
|
||||
./openmcpauthproxy --demo --debug
|
||||
|
||||
# Show all available options
|
||||
./openmcpauthproxy --help
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
We appreciate your contributions, whether it is improving documentation, adding new features, or fixing bugs. To get started, please refer to our [contributing guide](CONTRIBUTING.md).
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
||||
|
|
@ -68,23 +67,7 @@ func main() {
|
|||
}
|
||||
|
||||
// 3. Create the chosen provider
|
||||
var provider authz.Provider
|
||||
if *demoMode {
|
||||
cfg.Mode = "demo"
|
||||
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
|
||||
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
|
||||
provider = authz.NewAsgardeoProvider(cfg)
|
||||
} else if *asgardeoMode {
|
||||
cfg.Mode = "asgardeo"
|
||||
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
|
||||
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
|
||||
provider = authz.NewAsgardeoProvider(cfg)
|
||||
} else {
|
||||
cfg.Mode = "default"
|
||||
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||
provider = authz.NewDefaultProvider(cfg)
|
||||
}
|
||||
var provider authz.Provider = MakeProvider(cfg, *demoMode, *asgardeoMode)
|
||||
|
||||
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||
|
|
@ -92,12 +75,15 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 5. Build the main router
|
||||
mux := proxy.NewRouter(cfg, provider)
|
||||
// 5. (Optional) Build the access controler
|
||||
accessController := &authz.ScopeValidator{}
|
||||
|
||||
// 6. Build the main router
|
||||
mux := proxy.NewRouter(cfg, provider, accessController)
|
||||
|
||||
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||
|
||||
// 6. Start the server
|
||||
// 7. Start the server
|
||||
srv := &http.Server{
|
||||
Addr: listen_address,
|
||||
Handler: mux,
|
||||
|
|
@ -111,18 +97,18 @@ func main() {
|
|||
}
|
||||
}()
|
||||
|
||||
// 7. Wait for shutdown signal
|
||||
// 8. Wait for shutdown signal
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
||||
<-stop
|
||||
logger.Info("Shutting down...")
|
||||
|
||||
// 8. First terminate subprocess if running
|
||||
// 9. First terminate subprocess if running
|
||||
if procManager != nil && procManager.IsRunning() {
|
||||
procManager.Shutdown()
|
||||
}
|
||||
|
||||
// 9. Then shutdown the server
|
||||
// 10. Then shutdown the server
|
||||
logger.Info("Shutting down HTTP server...")
|
||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||
defer cancel()
|
||||
|
|
|
|||
45
cmd/proxy/provider.go
Normal file
45
cmd/proxy/provider.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
)
|
||||
|
||||
func MakeProvider(cfg *config.Config, demoMode, asgardeoMode bool) authz.Provider {
|
||||
var mode, orgName string
|
||||
switch {
|
||||
case demoMode:
|
||||
mode = "demo"
|
||||
orgName = cfg.Demo.OrgName
|
||||
case asgardeoMode:
|
||||
mode = "asgardeo"
|
||||
orgName = cfg.Asgardeo.OrgName
|
||||
default:
|
||||
mode = "default"
|
||||
}
|
||||
cfg.Mode = mode
|
||||
|
||||
switch mode {
|
||||
case "demo", "asgardeo":
|
||||
if len(cfg.ProtectedResourceMetadata.AuthorizationServers) == 0 && cfg.ProtectedResourceMetadata.JwksURI == "" {
|
||||
base := constants.ASGARDEO_BASE_URL + orgName + "/oauth2"
|
||||
cfg.AuthServerBaseURL = base
|
||||
cfg.JWKSURL = base + "/jwks"
|
||||
} else {
|
||||
cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0]
|
||||
cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI
|
||||
}
|
||||
return authz.NewAsgardeoProvider(cfg)
|
||||
|
||||
default:
|
||||
if cfg.Default.BaseURL != "" && cfg.Default.JWKSURL != "" {
|
||||
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||
} else if len(cfg.ProtectedResourceMetadata.AuthorizationServers) > 0 {
|
||||
cfg.AuthServerBaseURL = cfg.ProtectedResourceMetadata.AuthorizationServers[0]
|
||||
cfg.JWKSURL = cfg.ProtectedResourceMetadata.JwksURI
|
||||
}
|
||||
return authz.NewDefaultProvider(cfg)
|
||||
}
|
||||
}
|
||||
22
config.yaml
22
config.yaml
|
|
@ -1,6 +1,7 @@
|
|||
# config.yaml
|
||||
|
||||
# Common configuration for all transport modes
|
||||
proxy_base_url: http://localhost:8080
|
||||
listen_port: 8080
|
||||
base_url: "http://localhost:8000" # Base URL for the MCP server
|
||||
port: 8000 # Port for the MCP server
|
||||
|
|
@ -10,13 +11,14 @@ timeout_seconds: 10
|
|||
paths:
|
||||
sse: "/sse" # SSE endpoint path
|
||||
messages: "/messages/" # Messages endpoint path
|
||||
streamable_http: "/mcp" # MCP endpoint path
|
||||
|
||||
# Transport mode configuration
|
||||
transport_mode: "sse" # Options: "sse" or "stdio"
|
||||
|
||||
# stdio-specific configuration (used only when transport_mode is "stdio")
|
||||
stdio:
|
||||
enabled: true
|
||||
enabled: false
|
||||
user_command: "npx -y @modelcontextprotocol/server-github"
|
||||
work_dir: "" # Working directory (optional)
|
||||
# env: # Environment variables (optional)
|
||||
|
|
@ -28,7 +30,8 @@ path_mapping:
|
|||
# CORS configuration
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "http://localhost:5173"
|
||||
- "http://127.0.0.1:6274"
|
||||
- "http://localhost:6274"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
|
|
@ -45,3 +48,18 @@ demo:
|
|||
org_name: "openmcpauthdemo"
|
||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||
|
||||
protected_resource_metadata:
|
||||
resource_identifier: http://localhost:8080/sse
|
||||
audience: 2xGW_poFYoObUE_vUQxvGdPSUPwa
|
||||
scopes_supported:
|
||||
- initialize: "mcp_init"
|
||||
- tools/call:
|
||||
- echo_tool: "mcp_echo_tool"
|
||||
authorization_servers:
|
||||
- https://api.asgardeo.io/t/openmcpauthdemo/oauth2/token
|
||||
jwks_uri: https://api.asgardeo.io/t/openmcpauthdemo/oauth2/jwks
|
||||
bearer_methods_supported:
|
||||
- header
|
||||
- body
|
||||
- query
|
||||
|
|
|
|||
24
internal/authz/access_control.go
Normal file
24
internal/authz/access_control.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package authz
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
)
|
||||
|
||||
type Decision int
|
||||
|
||||
const (
|
||||
DecisionAllow Decision = iota
|
||||
DecisionDeny
|
||||
)
|
||||
|
||||
type AccessControlResult struct {
|
||||
Decision Decision
|
||||
Message string
|
||||
}
|
||||
|
||||
type AccessControl interface {
|
||||
ValidateAccess(r *http.Request, claims *jwt.MapClaims, config *config.Config) AccessControlResult
|
||||
}
|
||||
|
|
@ -42,31 +42,17 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
|
||||
scheme = forwardedProto
|
||||
}
|
||||
host := r.Host
|
||||
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
|
||||
baseURL := scheme + "://" + host
|
||||
|
||||
issuer := strings.TrimSuffix(p.cfg.AuthServerBaseURL, "/") + "/token"
|
||||
|
||||
response := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": baseURL + "/authorize",
|
||||
"token_endpoint": baseURL + "/token",
|
||||
"authorization_endpoint": p.cfg.BaseURL + "/authorize",
|
||||
"token_endpoint": p.cfg.BaseURL + "/token",
|
||||
"jwks_uri": p.cfg.JWKSURL,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
|
||||
"registration_endpoint": baseURL + "/register",
|
||||
"registration_endpoint": p.cfg.BaseURL + "/register",
|
||||
"code_challenge_methods_supported": []string{"plain", "S256"},
|
||||
}
|
||||
|
||||
|
|
@ -113,6 +99,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
|||
|
||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||
logger.Warn("Asgardeo application creation failed: %v", err)
|
||||
http.Error(w, "Failed to create application in Asgardeo", http.StatusInternalServerError)
|
||||
// Optionally http.Error(...) if you want to fail
|
||||
// or continue to return partial data.
|
||||
}
|
||||
|
|
@ -193,7 +180,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
|
|||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||
return fmt.Errorf("asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||
|
|
@ -269,6 +256,18 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
|
|||
}
|
||||
appName += "-" + randomString(5)
|
||||
|
||||
// Build redirect URIs regex from list of redirect URIs : regexp=(https://app.example.com/callback1|https://app.example.com/callback2)
|
||||
redirectURI := "regexp=(" + strings.Join(regReq.RedirectURIs, "|") + ")"
|
||||
redirectURIs := []string{redirectURI}
|
||||
|
||||
// Filter unsupported grant types
|
||||
var grantTypes []string
|
||||
for _, gt := range regReq.GrantTypes {
|
||||
if gt == "authorization_code" || gt == "refresh_token" {
|
||||
grantTypes = append(grantTypes, gt)
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"name": appName,
|
||||
"templateId": "custom-application-oidc",
|
||||
|
|
@ -276,10 +275,10 @@ func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
|
|||
"oidc": map[string]interface{}{
|
||||
"clientId": regReq.ClientID,
|
||||
"clientSecret": regReq.ClientSecret,
|
||||
"grantTypes": regReq.GrantTypes,
|
||||
"callbackURLs": regReq.RedirectURIs,
|
||||
"grantTypes": grantTypes,
|
||||
"callbackURLs": redirectURIs,
|
||||
"allowedOrigins": []string{},
|
||||
"publicClient": false,
|
||||
"publicClient": true,
|
||||
"pkce": map[string]bool{
|
||||
"mandatory": true,
|
||||
"supportPlainTransformAlgorithm": true,
|
||||
|
|
@ -350,3 +349,48 @@ func randomString(n int) string {
|
|||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Extract only the values into a []string
|
||||
var supportedScopes []string
|
||||
var extractStrings func(interface{})
|
||||
extractStrings = func(val interface{}) {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
supportedScopes = append(supportedScopes, v)
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
extractStrings(item)
|
||||
}
|
||||
case map[string]any:
|
||||
for _, item := range v {
|
||||
extractStrings(item)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range p.cfg.ProtectedResourceMetadata.ScopesSupported {
|
||||
for _, v := range m {
|
||||
extractStrings(v)
|
||||
}
|
||||
}
|
||||
|
||||
meta := map[string]interface{}{
|
||||
"resource": p.cfg.ProtectedResourceMetadata.ResourceIdentifier,
|
||||
"scopes_supported": supportedScopes,
|
||||
"authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers,
|
||||
}
|
||||
|
||||
if p.cfg.ProtectedResourceMetadata.JwksURI != "" {
|
||||
meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI
|
||||
}
|
||||
if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 {
|
||||
meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type defaultProvider struct {
|
||||
|
|
@ -40,31 +40,17 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
|||
// Use configured response values
|
||||
responseConfig := pathConfig.Response
|
||||
|
||||
// Get current host for proxy endpoints
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
|
||||
scheme = forwardedProto
|
||||
}
|
||||
host := r.Host
|
||||
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
baseURL := scheme + "://" + host
|
||||
|
||||
authorizationEndpoint := responseConfig.AuthorizationEndpoint
|
||||
if authorizationEndpoint == "" {
|
||||
authorizationEndpoint = baseURL + "/authorize"
|
||||
authorizationEndpoint = p.cfg.BaseURL + "/authorize"
|
||||
}
|
||||
tokenEndpoint := responseConfig.TokenEndpoint
|
||||
if tokenEndpoint == "" {
|
||||
tokenEndpoint = baseURL + "/token"
|
||||
tokenEndpoint = p.cfg.BaseURL + "/token"
|
||||
}
|
||||
registraionEndpoint := responseConfig.RegistrationEndpoint
|
||||
if registraionEndpoint == "" {
|
||||
registraionEndpoint = baseURL + "/register"
|
||||
registrationEndpoint := responseConfig.RegistrationEndpoint
|
||||
if registrationEndpoint == "" {
|
||||
registrationEndpoint = p.cfg.BaseURL + "/register"
|
||||
}
|
||||
|
||||
// Build response from config
|
||||
|
|
@ -76,7 +62,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
|||
"response_types_supported": responseConfig.ResponseTypesSupported,
|
||||
"grant_types_supported": responseConfig.GrantTypesSupported,
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
|
||||
"registration_endpoint": registraionEndpoint,
|
||||
"registration_endpoint": registrationEndpoint,
|
||||
"code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported,
|
||||
}
|
||||
|
||||
|
|
@ -94,3 +80,26 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
|||
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *defaultProvider) ProtectedResourceMetadataHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
meta := map[string]interface{}{
|
||||
"audience": p.cfg.ProtectedResourceMetadata.Audience,
|
||||
"scopes_supported": p.cfg.ProtectedResourceMetadata.ScopesSupported,
|
||||
"authorization_servers": p.cfg.ProtectedResourceMetadata.AuthorizationServers,
|
||||
}
|
||||
|
||||
if p.cfg.ProtectedResourceMetadata.JwksURI != "" {
|
||||
meta["jwks_uri"] = p.cfg.ProtectedResourceMetadata.JwksURI
|
||||
}
|
||||
|
||||
if len(p.cfg.ProtectedResourceMetadata.BearerMethodsSupported) > 0 {
|
||||
meta["bearer_methods_supported"] = p.cfg.ProtectedResourceMetadata.BearerMethodsSupported
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(meta); err != nil {
|
||||
http.Error(w, "failed to encode metadata", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,4 +7,5 @@ import "net/http"
|
|||
type Provider interface {
|
||||
WellKnownHandler() http.HandlerFunc
|
||||
RegisterHandler() http.HandlerFunc
|
||||
ProtectedResourceMetadataHandler() http.HandlerFunc
|
||||
}
|
||||
|
|
|
|||
72
internal/authz/scope_validator.go
Normal file
72
internal/authz/scope_validator.go
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
package authz
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
type ScopeValidator struct{}
|
||||
|
||||
// Evaluate and checks the token claims against one or more required scopes.
|
||||
func (d *ScopeValidator) ValidateAccess(
|
||||
r *http.Request,
|
||||
claims *jwt.MapClaims,
|
||||
config *config.Config,
|
||||
) AccessControlResult {
|
||||
env, err := util.ParseRPCRequest(r)
|
||||
if err != nil {
|
||||
return AccessControlResult{DecisionDeny, "bad JSON-RPC request"}
|
||||
}
|
||||
requiredScopes := util.GetRequiredScopes(config, env)
|
||||
|
||||
if len(requiredScopes) == 0 {
|
||||
return AccessControlResult{DecisionAllow, ""}
|
||||
}
|
||||
|
||||
required := make(map[string]struct{}, len(requiredScopes))
|
||||
for _, s := range requiredScopes {
|
||||
s = strings.TrimSpace(s)
|
||||
if s != "" {
|
||||
required[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var tokenScopes []string
|
||||
if claims, ok := (*claims)["scope"]; ok {
|
||||
switch v := claims.(type) {
|
||||
case string:
|
||||
tokenScopes = strings.Fields(v)
|
||||
case []interface{}:
|
||||
for _, x := range v {
|
||||
if s, ok := x.(string); ok && s != "" {
|
||||
tokenScopes = append(tokenScopes, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenScopeSet := make(map[string]struct{}, len(tokenScopes))
|
||||
for _, s := range tokenScopes {
|
||||
tokenScopeSet[s] = struct{}{}
|
||||
}
|
||||
|
||||
var missing []string
|
||||
for s := range required {
|
||||
if _, ok := tokenScopeSet[s]; !ok {
|
||||
missing = append(missing, s)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) == 0 {
|
||||
return AccessControlResult{DecisionAllow, ""}
|
||||
}
|
||||
return AccessControlResult{
|
||||
DecisionDeny,
|
||||
fmt.Sprintf("missing required scope(s): %s", strings.Join(missing, ", ")),
|
||||
}
|
||||
}
|
||||
|
|
@ -3,6 +3,8 @@ package config
|
|||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
|
@ -13,12 +15,14 @@ type TransportMode string
|
|||
const (
|
||||
SSETransport TransportMode = "sse"
|
||||
StdioTransport TransportMode = "stdio"
|
||||
StreamableHTTPTransport TransportMode = "streamable_http"
|
||||
)
|
||||
|
||||
// Common path configuration for all transport modes
|
||||
type PathsConfig struct {
|
||||
SSE string `yaml:"sse"`
|
||||
Messages string `yaml:"messages"`
|
||||
StreamableHTTP string `yaml:"streamable_http"` // Path for streamable HTTP requests
|
||||
}
|
||||
|
||||
// StdioConfig contains stdio-specific configuration
|
||||
|
|
@ -65,6 +69,15 @@ type ResponseConfig struct {
|
|||
CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
type ProtectedResourceMetadata struct {
|
||||
ResourceIdentifier string `yaml:"resource_identifier"`
|
||||
Audience string `yaml:"audience"`
|
||||
ScopesSupported []map[string]interface{} `yaml:"scopes_supported"`
|
||||
AuthorizationServers []string `yaml:"authorization_servers"`
|
||||
JwksURI string `yaml:"jwks_uri,omitempty"`
|
||||
BearerMethodsSupported []string `yaml:"bearer_methods_supported,omitempty"`
|
||||
}
|
||||
|
||||
type PathConfig struct {
|
||||
// For well-known endpoint
|
||||
Response *ResponseConfig `yaml:"response,omitempty"`
|
||||
|
|
@ -83,6 +96,7 @@ type DefaultConfig struct {
|
|||
}
|
||||
|
||||
type Config struct {
|
||||
ProxyBaseURL string `yaml:"proxy_base_url"`
|
||||
AuthServerBaseURL string
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
|
|
@ -100,6 +114,9 @@ type Config struct {
|
|||
Demo DemoConfig `yaml:"demo"`
|
||||
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
||||
Default DefaultConfig `yaml:"default"`
|
||||
|
||||
// Protected resource metadata
|
||||
ProtectedResourceMetadata ProtectedResourceMetadata `yaml:"protected_resource_metadata"`
|
||||
}
|
||||
|
||||
// Validate checks if the config is valid based on transport mode
|
||||
|
|
@ -136,7 +153,7 @@ func (c *Config) Validate() error {
|
|||
|
||||
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
||||
func (c *Config) GetMCPPaths() []string {
|
||||
return []string{c.Paths.SSE, c.Paths.Messages}
|
||||
return []string{c.Paths.SSE, c.Paths.Messages, c.Paths.StreamableHTTP}
|
||||
}
|
||||
|
||||
// BuildExecCommand constructs the full command string for execution in stdio mode
|
||||
|
|
@ -145,7 +162,15 @@ func (c *Config) BuildExecCommand() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// Construct the full command
|
||||
if runtime.GOOS == "windows" {
|
||||
// For Windows, we need to properly escape the inner command
|
||||
escapedCommand := strings.ReplaceAll(c.Stdio.UserCommand, `"`, `\"`)
|
||||
return fmt.Sprintf(
|
||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||
escapedCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
||||
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
||||
|
|
|
|||
|
|
@ -138,18 +138,13 @@ func TestGetMCPPaths(t *testing.T) {
|
|||
Paths: PathsConfig{
|
||||
SSE: "/custom-sse",
|
||||
Messages: "/custom-messages",
|
||||
StreamableHTTP: "/custom-streamable",
|
||||
},
|
||||
}
|
||||
|
||||
paths := cfg.GetMCPPaths()
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != "/custom-sse" {
|
||||
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
|
||||
}
|
||||
if paths[1] != "/custom-messages" {
|
||||
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
|
||||
if len(paths) != 3 {
|
||||
t.Errorf("Expected 3 MCP paths, got %d", len(paths))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
package constants
|
||||
|
||||
import "time"
|
||||
|
||||
// Package constant provides constants for the MCP Auth Proxy
|
||||
|
||||
const (
|
||||
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
|
||||
)
|
||||
|
||||
// MCP specification version cutover date
|
||||
var SpecCutoverDate = time.Date(2025, 3, 26, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
const TimeLayout = "2006-01-02"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
|
@ -10,14 +11,14 @@ import (
|
|||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
// NewRouter builds an http.ServeMux that routes
|
||||
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
||||
// * MCP paths to the MCP server, etc.
|
||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||
func NewRouter(cfg *config.Config, provider authz.Provider, accessController authz.AccessControl) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
modifiers := map[string]RequestModifier{
|
||||
|
|
@ -63,6 +64,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
mux.HandleFunc(getProtectedResourceMetadataEndpointPath(cfg), func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
allowed := getAllowedOrigin(origin, cfg)
|
||||
if r.Method == http.MethodOptions {
|
||||
addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers"))
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
addCORSHeaders(w, cfg, allowed, "")
|
||||
provider.ProtectedResourceMetadataHandler()(w, r)
|
||||
})
|
||||
registeredPaths[getProtectedResourceMetadataEndpointPath(cfg)] = true
|
||||
|
||||
// Remove duplicates from defaultPaths
|
||||
uniquePaths := make(map[string]bool)
|
||||
cleanPaths := []string{}
|
||||
|
|
@ -76,7 +91,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
|
||||
for _, path := range defaultPaths {
|
||||
if !registeredPaths[path] {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
}
|
||||
|
|
@ -84,14 +99,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
// MCP paths
|
||||
mcpPaths := cfg.GetMCPPaths()
|
||||
for _, path := range mcpPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
|
||||
// Register paths from PathMapping that haven't been registered yet
|
||||
for path := range cfg.PathMapping {
|
||||
if !registeredPaths[path] {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController))
|
||||
registeredPaths[path] = true
|
||||
}
|
||||
}
|
||||
|
|
@ -99,7 +114,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
|||
return mux
|
||||
}
|
||||
|
||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, accessController authz.AccessControl) http.HandlerFunc {
|
||||
// Parse the base URLs up front
|
||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||
if err != nil {
|
||||
|
|
@ -141,20 +156,31 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
// Add CORS headers to all responses
|
||||
addCORSHeaders(w, cfg, allowedOrigin, "")
|
||||
|
||||
// Check if the request is for the latest spec
|
||||
specVersion := util.GetVersionWithDefault(r.Header.Get("MCP-Protocol-Version"))
|
||||
ver, err := util.ParseVersionDate(specVersion)
|
||||
isLatestSpec := util.IsLatestSpec(ver, err)
|
||||
|
||||
// Decide whether the request should go to the auth server or MCP
|
||||
var targetURL *url.URL
|
||||
isSSE := false
|
||||
|
||||
if isAuthPath(r.URL.Path) {
|
||||
if isAuthPath(r.URL.Path, cfg) {
|
||||
targetURL = authBase
|
||||
} else if isMCPPath(r.URL.Path, cfg) {
|
||||
// Validate JWT for MCP paths if required
|
||||
// Placeholder for JWT validation logic
|
||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
if ssePaths[r.URL.Path] {
|
||||
if err := authorizeSSE(w, r, isLatestSpec, cfg); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
isSSE = true
|
||||
} else {
|
||||
if err := authorizeMCP(w, r, isLatestSpec, cfg, accessController); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
targetURL = mcpBase
|
||||
if ssePaths[r.URL.Path] {
|
||||
isSSE = true
|
||||
|
|
@ -214,7 +240,17 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
},
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
resp.Header.Set(
|
||||
"WWW-Authenticate",
|
||||
fmt.Sprintf(
|
||||
`Bearer resource_metadata="%s"`,
|
||||
cfg.ProxyBaseURL+getProtectedResourceMetadataEndpointPath(cfg),
|
||||
))
|
||||
resp.Header.Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
}
|
||||
|
||||
resp.Header.Del("Access-Control-Allow-Origin")
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
|
|
@ -236,7 +272,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
// Keep SSE connections open
|
||||
HandleSSE(w, r, rp)
|
||||
} else {
|
||||
|
|
@ -248,6 +284,76 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
|||
}
|
||||
}
|
||||
|
||||
// Check if the request is for SSE handshake and authorize it
|
||||
func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config) error {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
if isLatestSpec {
|
||||
realm := cfg.BaseURL + getProtectedResourceMetadataEndpointPath(cfg)
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata="%s"`, realm))
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
}
|
||||
return fmt.Errorf("missing or invalid Authorization header")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handles both v1 (just signature) and v2 (aud + scope) flows
|
||||
func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error {
|
||||
authzHeader := r.Header.Get("Authorization")
|
||||
accessToken, _ := util.ExtractAccessToken(authzHeader)
|
||||
if !strings.HasPrefix(authzHeader, "Bearer ") {
|
||||
if isLatestSpec {
|
||||
realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg)
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||
`Bearer resource_metadata=%q`, realm,
|
||||
))
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
}
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return fmt.Errorf("missing or invalid Authorization header")
|
||||
}
|
||||
|
||||
err := util.ValidateJWT(isLatestSpec, accessToken, cfg.ProtectedResourceMetadata.Audience)
|
||||
if err != nil {
|
||||
if isLatestSpec {
|
||||
realm := cfg.ProxyBaseURL + getProtectedResourceMetadataEndpointPath(cfg)
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(err.Error(),
|
||||
`Bearer realm=%q`,
|
||||
realm,
|
||||
))
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
} else {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if isLatestSpec {
|
||||
_, err := util.ParseRPCRequest(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return err
|
||||
}
|
||||
|
||||
claimsMap, err := util.ParseJWT(accessToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token claims", http.StatusUnauthorized)
|
||||
return fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
pr := accessController.ValidateAccess(r, &claimsMap, cfg)
|
||||
if pr.Decision == authz.DecisionDeny {
|
||||
http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden)
|
||||
return fmt.Errorf("forbidden — %s", pr.Message)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||
if origin == "" {
|
||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||
|
|
@ -265,6 +371,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
|||
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||
w.Header().Set("Access-Control-Expose-Headers", "WWW-Authenticate, MCP-Protocol-Version")
|
||||
if requestHeaders != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||
} else {
|
||||
|
|
@ -272,17 +379,19 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
|||
}
|
||||
if cfg.CORSConfig.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("MCP-Protocol-Version", ", ")
|
||||
}
|
||||
w.Header().Set("Vary", "Origin")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func isAuthPath(path string) bool {
|
||||
func isAuthPath(path string, cfg *config.Config) bool {
|
||||
authPaths := map[string]bool{
|
||||
"/authorize": true,
|
||||
"/token": true,
|
||||
"/register": true,
|
||||
"/.well-known/oauth-authorization-server": true,
|
||||
getProtectedResourceMetadataEndpointPath(cfg): true,
|
||||
}
|
||||
if strings.HasPrefix(path, "/u/") {
|
||||
return true
|
||||
|
|
@ -308,3 +417,17 @@ func skipHeader(h string) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getProtectedResourceMetadataEndpointPath(cfg *config.Config) string {
|
||||
|
||||
protectedResourceMetadataPath := "/.well-known/oauth-protected-resource"
|
||||
|
||||
switch cfg.TransportMode {
|
||||
case config.SSETransport:
|
||||
protectedResourceMetadataPath += cfg.Paths.SSE
|
||||
case config.StreamableHTTPTransport:
|
||||
protectedResourceMetadataPath += cfg.Paths.StreamableHTTP
|
||||
}
|
||||
|
||||
return protectedResourceMetadataPath
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"strings"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
// Manager handles starting and graceful shutdown of subprocesses
|
||||
|
|
@ -40,7 +41,12 @@ func EnsureDependenciesAvailable(command string) error {
|
|||
|
||||
// Try to install npx using npm
|
||||
logger.Info("npx not found, attempting to install...")
|
||||
cmd := exec.Command("npm", "install", "-g", "npx")
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("npm.cmd", "install", "-g", "npx")
|
||||
} else {
|
||||
cmd = exec.Command("npm", "install", "-g", "npx")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
|
|
@ -88,8 +94,13 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||
|
||||
logger.Info("Starting subprocess with command: %s", execCommand)
|
||||
|
||||
// Use the shell to execute the command
|
||||
cmd := exec.Command("sh", "-c", execCommand)
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
// Use PowerShell on Windows for better quote handling
|
||||
cmd = exec.Command("powershell", "-Command", execCommand)
|
||||
} else {
|
||||
cmd = exec.Command("sh", "-c", execCommand)
|
||||
}
|
||||
|
||||
// Set working directory if specified
|
||||
if cfg.Stdio.WorkDir != "" {
|
||||
|
|
@ -105,8 +116,8 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
// Set the process group for proper termination
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
// Set platform-specific process attributes
|
||||
setProcAttr(cmd)
|
||||
|
||||
// Start the process
|
||||
if err := cmd.Start(); err != nil {
|
||||
|
|
@ -117,11 +128,13 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||
m.cmd = cmd
|
||||
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
||||
|
||||
// Get and store the process group ID
|
||||
pgid, err := syscall.Getpgid(m.process.Pid)
|
||||
// Get and store the process group ID (Unix) or PID (Windows)
|
||||
pgid, err := getProcessGroup(m.process.Pid)
|
||||
if err == nil {
|
||||
m.processGroup = pgid
|
||||
if runtime.GOOS != "windows" {
|
||||
logger.Debug("Process group ID: %d", m.processGroup)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Failed to get process group ID: %v", err)
|
||||
m.processGroup = m.process.Pid
|
||||
|
|
@ -169,12 +182,36 @@ func (m *Manager) Shutdown() {
|
|||
go func() {
|
||||
defer close(terminateComplete)
|
||||
|
||||
// Try graceful termination first with SIGTERM
|
||||
// Try graceful termination first
|
||||
terminatedGracefully := false
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows: Try to terminate the process
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
err := m.process.Kill()
|
||||
if err != nil {
|
||||
logger.Warn("Failed to terminate process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
// Wait a bit to see if it terminates
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
m.mutex.Lock()
|
||||
if m.process == nil {
|
||||
terminatedGracefully = true
|
||||
m.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
} else {
|
||||
// Unix: Use SIGTERM followed by SIGKILL if necessary
|
||||
// Try to terminate the process group first
|
||||
if processGroupToTerminate != 0 {
|
||||
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
||||
err := killProcessGroup(processGroupToTerminate, syscall.SIGTERM)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
||||
|
||||
|
|
@ -212,6 +249,7 @@ func (m *Manager) Shutdown() {
|
|||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if terminatedGracefully {
|
||||
logger.Info("Subprocess terminated gracefully")
|
||||
|
|
@ -221,9 +259,20 @@ func (m *Manager) Shutdown() {
|
|||
// If the process didn't exit gracefully, force kill
|
||||
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// On Windows, Kill() is already forceful
|
||||
m.mutex.Lock()
|
||||
if m.process != nil {
|
||||
if err := m.process.Kill(); err != nil {
|
||||
logger.Error("Failed to kill process: %v", err)
|
||||
}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
} else {
|
||||
// Unix: Try SIGKILL
|
||||
// Try to kill the process group first
|
||||
if processGroupToTerminate != 0 {
|
||||
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||
if err := killProcessGroup(processGroupToTerminate, syscall.SIGKILL); err != nil {
|
||||
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
||||
|
||||
// Fallback to killing just the process
|
||||
|
|
@ -245,6 +294,7 @@ func (m *Manager) Shutdown() {
|
|||
}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Wait a bit more to confirm termination
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
|
|
|||
23
internal/subprocess/manager_unix.go
Normal file
23
internal/subprocess/manager_unix.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
//go:build !windows
|
||||
|
||||
package subprocess
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttr sets Unix-specific process attributes
|
||||
func setProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
}
|
||||
|
||||
// getProcessGroup gets the process group ID on Unix systems
|
||||
func getProcessGroup(pid int) (int, error) {
|
||||
return syscall.Getpgid(pid)
|
||||
}
|
||||
|
||||
// killProcessGroup kills a process group on Unix systems
|
||||
func killProcessGroup(pgid int, signal syscall.Signal) error {
|
||||
return syscall.Kill(-pgid, signal)
|
||||
}
|
||||
27
internal/subprocess/manager_windows.go
Normal file
27
internal/subprocess/manager_windows.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
//go:build windows
|
||||
|
||||
package subprocess
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setProcAttr sets Windows-specific process attributes
|
||||
func setProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||
}
|
||||
}
|
||||
|
||||
// getProcessGroup returns the PID itself on Windows (no process groups)
|
||||
func getProcessGroup(pid int) (int, error) {
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// killProcessGroup kills a process on Windows (no process groups)
|
||||
func killProcessGroup(pgid int, signal syscall.Signal) error {
|
||||
// On Windows, we'll use the process handle directly
|
||||
// This function shouldn't be called on Windows, but we provide it for compatibility
|
||||
return nil
|
||||
}
|
||||
|
|
@ -4,21 +4,27 @@ import (
|
|||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type TokenClaims struct {
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
type JWKS struct {
|
||||
Keys []json.RawMessage `json:"keys"`
|
||||
}
|
||||
|
||||
var publicKeys map[string]*rsa.PublicKey
|
||||
|
||||
// FetchJWKS downloads JWKS and stores in a package-level map
|
||||
// FetchJWKS downloads JWKS and stores in a package‐level map
|
||||
func FetchJWKS(jwksURL string) error {
|
||||
resp, err := http.Get(jwksURL)
|
||||
if err != nil {
|
||||
|
|
@ -31,23 +37,23 @@ func FetchJWKS(jwksURL string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
publicKeys = make(map[string]*rsa.PublicKey)
|
||||
publicKeys = make(map[string]*rsa.PublicKey, len(jwks.Keys))
|
||||
for _, keyData := range jwks.Keys {
|
||||
var parsedKey struct {
|
||||
var parsed struct {
|
||||
Kid string `json:"kid"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Kty string `json:"kty"`
|
||||
}
|
||||
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
|
||||
if err := json.Unmarshal(keyData, &parsed); err != nil {
|
||||
continue
|
||||
}
|
||||
if parsedKey.Kty != "RSA" {
|
||||
if parsed.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
|
||||
pubKey, err := parseRSAPublicKey(parsed.N, parsed.E)
|
||||
if err == nil {
|
||||
publicKeys[parsedKey.Kid] = pubKey
|
||||
publicKeys[parsed.Kid] = pubKey
|
||||
}
|
||||
}
|
||||
logger.Info("Loaded %d public keys.", len(publicKeys))
|
||||
|
|
@ -73,25 +79,150 @@ func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
|||
return &rsa.PublicKey{N: n, E: e}, nil
|
||||
}
|
||||
|
||||
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
|
||||
func ValidateJWT(authHeader string) error {
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return errors.New("missing or invalid Authorization header")
|
||||
// ValidateJWT checks the Bearer token according to the Mcp-Protocol-Version.
|
||||
func ValidateJWT(
|
||||
isLatestSpec bool,
|
||||
accessToken string,
|
||||
audience string,
|
||||
) error {
|
||||
logger.Warn("isLatestSpec: %s", isLatestSpec)
|
||||
// Parse & verify the signature
|
||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
pubKey, ok := publicKeys[kid]
|
||||
kid, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("unknown or missing kid in token header")
|
||||
return nil, errors.New("kid header not found")
|
||||
}
|
||||
return pubKey, nil
|
||||
key, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key not found for kid: %s", kid)
|
||||
}
|
||||
return key, nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("invalid token: " + err.Error())
|
||||
logger.Warn("Error detected, returning early")
|
||||
return fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
if !token.Valid {
|
||||
return errors.New("invalid token: token not valid")
|
||||
logger.Warn("Token invalid, returning early")
|
||||
return errors.New("token not valid")
|
||||
}
|
||||
|
||||
claimsMap, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return errors.New("unexpected claim type")
|
||||
}
|
||||
|
||||
if !isLatestSpec {
|
||||
return nil
|
||||
}
|
||||
|
||||
audRaw, exists := claimsMap["aud"]
|
||||
if !exists {
|
||||
return errors.New("aud claim missing")
|
||||
}
|
||||
switch v := audRaw.(type) {
|
||||
case string:
|
||||
if v != audience {
|
||||
return fmt.Errorf("aud %q does not match %q", v, audience)
|
||||
}
|
||||
case []interface{}:
|
||||
var found bool
|
||||
for _, a := range v {
|
||||
if s, ok := a.(string); ok && s == audience {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("audience %v does not include %q", v, audience)
|
||||
}
|
||||
default:
|
||||
return errors.New("aud claim has unexpected type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parses the JWT token and returns the claims
|
||||
func ParseJWT(tokenStr string) (jwt.MapClaims, error) {
|
||||
if tokenStr == "" {
|
||||
return nil, fmt.Errorf("empty JWT")
|
||||
}
|
||||
|
||||
var claims jwt.MapClaims
|
||||
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Process the required scopes
|
||||
func GetRequiredScopes(cfg *config.Config, requestBody *RPCEnvelope) []string {
|
||||
|
||||
var scopeObj interface{}
|
||||
found := false
|
||||
for _, m := range cfg.ProtectedResourceMetadata.ScopesSupported {
|
||||
if val, ok := m[requestBody.Method]; ok {
|
||||
scopeObj = val
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := scopeObj.(type) {
|
||||
case string:
|
||||
return []string{v}
|
||||
case []any:
|
||||
if requestBody.Params != nil {
|
||||
if paramsMap, ok := requestBody.Params.(map[string]any); ok {
|
||||
name, ok := paramsMap["name"].(string)
|
||||
if ok {
|
||||
for _, item := range v {
|
||||
if scopeMap, ok := item.(map[interface{}]interface{}); ok {
|
||||
if scopeVal, exists := scopeMap[name]; exists {
|
||||
if scopeStr, ok := scopeVal.(string); ok {
|
||||
return []string{scopeStr}
|
||||
}
|
||||
if scopeArr, ok := scopeVal.([]any); ok {
|
||||
var scopes []string
|
||||
for _, s := range scopeArr {
|
||||
if str, ok := s.(string); ok {
|
||||
scopes = append(scopes, str)
|
||||
}
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extracts the Bearer token from the Authorization header
|
||||
func ExtractAccessToken(authHeader string) (string, error) {
|
||||
if authHeader == "" {
|
||||
return "", errors.New("empty authorization header")
|
||||
}
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return "", fmt.Errorf("invalid authorization header format: %s", authHeader)
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimSpace(strings.TrimPrefix(authHeader, "Bearer "))
|
||||
if tokenStr == "" {
|
||||
return "", errors.New("empty bearer token")
|
||||
}
|
||||
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -47,7 +48,14 @@ func TestValidateJWT(t *testing.T) {
|
|||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateJWT(tc.authHeader)
|
||||
var accessToken string
|
||||
parts := strings.Split(tc.authHeader, "Bearer ")
|
||||
if len(parts) == 2 {
|
||||
accessToken = parts[1]
|
||||
} else {
|
||||
accessToken = ""
|
||||
}
|
||||
err := ValidateJWT(true, accessToken, "test-audience")
|
||||
if tc.expectError && err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
|
|
@ -128,6 +136,7 @@ func createValidJWT(t *testing.T) string {
|
|||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"sub": "1234567890",
|
||||
"name": "Test User",
|
||||
"aud": "test-audience",
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
|
|
|
|||
38
internal/util/rpc.go
Normal file
38
internal/util/rpc.go
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
||||
)
|
||||
|
||||
type RPCEnvelope struct {
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params"`
|
||||
ID any `json:"id"`
|
||||
}
|
||||
|
||||
// This function parses a JSON-RPC request from an HTTP request body
|
||||
func ParseRPCRequest(r *http.Request) (*RPCEnvelope, error) {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
if len(bodyBytes) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var env RPCEnvelope
|
||||
dec := json.NewDecoder(bytes.NewReader(bodyBytes))
|
||||
if err := dec.Decode(&env); err != nil && err != io.EOF {
|
||||
logger.Warn("Error parsing JSON-RPC envelope: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &env, nil
|
||||
}
|
||||
26
internal/util/version.go
Normal file
26
internal/util/version.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||
)
|
||||
|
||||
// This function checks if the given version date is after the spec cutover date
|
||||
func IsLatestSpec(versionDate time.Time, err error) bool {
|
||||
return err == nil && versionDate.After(constants.SpecCutoverDate)
|
||||
}
|
||||
|
||||
// This function parses a version string into a time.Time
|
||||
func ParseVersionDate(version string) (time.Time, error) {
|
||||
return time.Parse("2006-01-02", version)
|
||||
}
|
||||
|
||||
// This function returns the version string, using the cutover date if empty
|
||||
func GetVersionWithDefault(version string) string {
|
||||
if version == "" {
|
||||
defaultTime, _ := time.Parse(constants.TimeLayout, "2025-05-15")
|
||||
return defaultTime.Format(constants.TimeLayout)
|
||||
}
|
||||
return version
|
||||
}
|
||||
19
resources/README.md
Normal file
19
resources/README.md
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Example MCP server
|
||||
|
||||
Use this example MCP server, if you don't already have an MCP server to test the open-mcp-auth-proxy.
|
||||
|
||||
## Setting Up
|
||||
|
||||
1. Set up a Python virtual environment.
|
||||
|
||||
```bash
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Start the example server.
|
||||
|
||||
```bash
|
||||
python3 echo_server.py
|
||||
```
|
||||
|
|
@ -2,7 +2,6 @@ from mcp.server.fastmcp import FastMCP
|
|||
|
||||
mcp = FastMCP("Echo")
|
||||
|
||||
|
||||
@mcp.resource("echo://{message}")
|
||||
def echo_resource(message: str) -> str:
|
||||
"""Echo a message as a resource"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue