Compare commits

...
Sign in to create a new pull request.

31 commits

Author SHA1 Message Date
8e80e2956c
add uvx
Some checks failed
Go CI / Test (push) Failing after 19s
Go CI / Build (push) Successful in 38s
Build and Push container / build_concierge_backend (push) Successful in 1m24s
2025-06-02 03:55:14 +02:00
c3e2abd2bc
change cmd
Some checks failed
Go CI / Test (push) Failing after 21s
Go CI / Build (push) Successful in 55s
Build and Push container / build_concierge_backend (push) Successful in 1m21s
2025-06-02 03:41:06 +02:00
e0beca18cf
adjust configuration to kvant 2025-05-19 10:25:27 +02:00
Pavindu Lakshan
ad5185ad72
Merge pull request #28 from pavinduLakshan/update_pr_template 2025-05-03 01:08:19 +05:30
Pavindu Lakshan
0bbc20ca5a Remove unnecessary fields from PR template 2025-05-03 01:06:41 +05:30
Pavindu Lakshan
4a5cf4e1cc
Update README.md 2025-04-27 17:23:13 +05:30
Pavindu Lakshan
87a1cbe21a
Update release.yml 2025-04-26 20:02:45 +05:30
Pavindu Lakshan
9ce9509cce
Fix issues in makefile (#26) 2025-04-21 15:29:48 +05:30
Pavindu Lakshan
5261a69f7a
Improve ordering in README (#24) 2025-04-18 21:40:36 +05:30
Pavindu Lakshan
23c282dcfc
Add badges to README (#25) 2025-04-18 21:40:12 +05:30
Pavindu Lakshan
f4be3de30f
Add release workflow (#23)
* Add release workflow
2025-04-18 15:12:32 +05:30
Pavindu Lakshan
6036ab30ec
Update README.md 2025-04-17 14:15:04 +05:30
Pavindu Lakshan
11fc3cdfcd
Merge pull request #16 from shashimalcse/shashimalcse-patch-0003 2025-04-16 12:09:04 +05:30
Pavindu Lakshan
42efe1f48a
Update internal/proxy/modifier_test.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-04-16 12:08:31 +05:30
Pavindu Lakshan
5a392ea496
Merge pull request #21 from wso2/keycloak 2025-04-16 10:40:59 +05:30
Pavindu Lakshan
d3d2f33661 Improve formatting 2025-04-15 09:12:51 +05:30
Pavindu Lakshan
ecee345f9c Add guide for Keycloak integration 2025-04-15 08:53:53 +05:30
Omindu Rathnaweera
b32f25e694
Delete .github/workflows/go.yml 2025-04-12 10:56:10 +05:30
Omindu Rathnaweera
aa7f76a548
Update pr-builder.yml 2025-04-12 10:55:31 +05:30
Thilina Shashimal Senarath
9ecaabecd2 Remove vscode files 2025-04-09 09:52:40 +05:30
Omindu Rathnaweera
8f57cfe3e0
Update pr-builder.yml 2025-04-08 21:45:11 +05:30
Omindu Rathnaweera
aa5b32aa8a
Create pr-builder.yml 2025-04-08 21:40:08 +05:30
Thilina Shashimal Senarath
9a3d5346f2
Fix Auth0 configs (#15) 2025-04-08 13:32:07 +05:30
Thilina Shashimal Senarath
b2b2124b76 Add unit tests 2025-04-08 13:26:16 +05:30
Chiran Fernando
32c9378aad
Add transport mode support for stdio, SSE stability fixes (#13)
Add transport mode support for stdio, SSE stability fixes
2025-04-08 12:46:00 +05:30
Pavindu Lakshan
6ce52261db
Add venv activate step 2025-04-04 14:45:58 +05:30
Pavindu Lakshan
86fb278ba5
Merge pull request #12 from wso2/update_sample_in
Add instructions to run the sample MCP server
2025-04-04 14:10:27 +05:30
Pavindu Lakshan
8d7aab073e Add instructions to run the sample MCP server 2025-04-04 14:08:11 +05:30
Thilina Shashimal Senarath
28f830dfbf
remove default config (#11) 2025-04-03 22:02:12 +05:30
Ayesha Dissanayaka
48c7f30ea8
Update README.md 2025-04-03 20:00:02 +05:30
Thilina Shashimal Senarath
97ceeb3a1d
Update readme with reason for the inspector fork (#10)
* add reason for fork
2025-04-03 18:37:28 +05:30
29 changed files with 2016 additions and 214 deletions

124
.github/scripts/release.sh vendored Normal file
View file

@ -0,0 +1,124 @@
#!/bin/bash
# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
#
# This software is the property of WSO2 LLC. and its suppliers, if any.
# Dissemination of any information or reproduction of any material contained
# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
# You may not alter or remove any copyright or other notice from copies of this content.
#
# Exit the script on any command with non-zero exit status.
set -e
set -o pipefail
UPSTREAM_BRANCH="main"
# Assign command line arguments to variables.
GIT_TOKEN=$1
WORK_DIR=$2
VERSION_TYPE=$3 # possible values: major, minor, patch
# Check if GIT_TOKEN is empty
if [ -z "$GIT_TOKEN" ]; then
echo "❌ Error: GIT_TOKEN is not set."
exit 1
fi
# Check if WORK_DIR is empty
if [ -z "$WORK_DIR" ]; then
echo "❌ Error: WORK_DIR is not set."
exit 1
fi
# Validate VERSION_TYPE
if [[ "$VERSION_TYPE" != "major" && "$VERSION_TYPE" != "minor" && "$VERSION_TYPE" != "patch" ]]; then
echo "❌ Error: VERSION_TYPE must be one of: major, minor, or patch."
exit 1
fi
BUILD_DIRECTORY="$WORK_DIR/build"
RELEASE_DIRECTORY="$BUILD_DIRECTORY/releases"
# Navigate to the working directory.
cd "${WORK_DIR}"
# Create the release directory.
if [ ! -d "$RELEASE_DIRECTORY" ]; then
mkdir -p "$RELEASE_DIRECTORY"
else
rm -rf "$RELEASE_DIRECTORY"/*
fi
# Extract current version.
CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0")
IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}"
# Determine which part to increment
case "$VERSION_TYPE" in
major)
MAJOR=$((MAJOR + 1))
MINOR=0
PATCH=0
;;
minor)
MINOR=$((MINOR + 1))
PATCH=0
;;
patch|*)
PATCH=$((PATCH + 1))
;;
esac
NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}"
echo "Creating release packages for version $NEW_VERSION..."
# List of supported OSes.
oses=("linux" "linux-arm" "darwin")
# Navigate to the release directory.
cd "${RELEASE_DIRECTORY}"
for os in "${oses[@]}"; do
os_dir="../$os"
if [ -d "$os_dir" ]; then
release_artifact_folder="openmcpauthproxy_${os}-v${NEW_VERSION}"
mkdir -p "$release_artifact_folder"
cp -r $os_dir/* "$release_artifact_folder"
# Zip the release package.
zip_file="$release_artifact_folder.zip"
echo "Creating $zip_file..."
zip -r "$zip_file" "$release_artifact_folder"
# Delete the folder after zipping.
rm -rf "$release_artifact_folder"
# Generate checksum file.
sha256sum "$zip_file" | sed "s|target/releases/||" > "$zip_file.sha256"
echo "Checksum generated for the $os package."
echo "Release packages created successfully for $os."
else
echo "Skipping $os release package creation as the build artifacts are not available."
fi
done
echo "Release packages created successfully in $RELEASE_DIRECTORY."
# Navigate back to the project root directory.
cd "${WORK_DIR}"
# Collect all ZIP and .sha256 files in the target/releases directory.
FILES_TO_UPLOAD=$(find build/releases -type f \( -name "*.zip" -o -name "*.sha256" \))
# Create a release with the current version.
TAG_NAME="v${NEW_VERSION}"
export GITHUB_TOKEN="${GIT_TOKEN}"
gh release create "${TAG_NAME}" ${FILES_TO_UPLOAD} --title "${TAG_NAME}" --notes "OpenMCPAuthProxy - ${TAG_NAME}" --target "${UPSTREAM_BRANCH}" || { echo "Failed to create release"; exit 1; }
echo "Release ${TAG_NAME} created successfully."

71
.github/workflows/ci.yaml vendored Normal file
View file

@ -0,0 +1,71 @@
name: Build and Push container
run-name: Build and Push container
on:
workflow_dispatch:
#schedule:
# - cron: "0 10 * * *"
push:
branches:
- 'main'
- 'master'
tags:
- 'v*'
pull_request:
branches:
- 'main'
- 'master'
env:
IMAGE: git.kvant.cloud/${{github.repository}}
jobs:
build_concierge_backend:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set current time
uses: https://github.com/gerred/actions/current-time@master
id: current_time
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to git.kvant.cloud registry
uses: docker/login-action@v3
with:
registry: git.kvant.cloud
username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
# list of Docker images to use as base name for tags
images: |
${{env.IMAGE}}
# generate Docker tags based on the following events/attributes
tags: |
type=schedule
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
- name: Build and push to gitea registry
uses: docker/build-push-action@v6
with:
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
context: .
provenance: mode=max
sbom: true
build-args: |
BUILD_DATE=${{ steps.current_time.outputs.time }}
cache-from: |
type=registry,ref=${{ env.IMAGE }}:buildcache
type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
type=registry,ref=${{ env.IMAGE }}:main
cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true

62
.github/workflows/pr-builder.yml vendored Normal file
View file

@ -0,0 +1,62 @@
name: Go CI
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
test:
name: Test
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ['1.20', '1.21']
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go-version }}
- name: Get dependencies
run: go get -v -t -d ./...
- name: Verify dependencies
run: go mod verify
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
files: ./coverage.txt
fail_ci_if_error: false
build:
name: Build
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ['1.20', '1.21']
os: [ubuntu-latest, macos-latest]
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go-version }}
- name: Build
run: go build -v ./cmd/proxy

64
.github/workflows/release.yml vendored Normal file
View file

@ -0,0 +1,64 @@
#
# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
#
# This software is the property of WSO2 LLC. and its suppliers, if any.
# Dissemination of any information or reproduction of any material contained
# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
# You may not alter or remove any copyright or other notice from copies of this content.
#
name: Release
on:
workflow_dispatch:
inputs:
version_type:
type: choice
description: Choose the type of version update
options:
- 'major'
- 'minor'
- 'patch'
required: true
jobs:
update-and-release:
runs-on: ubuntu-latest
env:
GOPROXY: https://proxy.golang.org
if: github.event.pull_request.merged == true || github.event_name == 'workflow_dispatch'
steps:
- uses: actions/checkout@v2
with:
ref: 'main'
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/checkout@v2
- name: Set up Go 1.x
uses: actions/setup-go@v3
with:
go-version: "^1.x"
- name: Cache Go modules
id: cache-go-modules
uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-modules-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-modules-
- name: Install dependencies
run: go mod download
- name: Build and test
run: make build
working-directory: .
- name: Update artifact version, package, commit, and create release.
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: bash ./.github/scripts/release.sh $GITHUB_TOKEN ${{ github.workspace }} ${{ github.event.inputs.version_type }}

14
.gitignore vendored
View file

@ -18,15 +18,21 @@
*.zip
*.tar.gz
*.rar
.venv
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
replay_pid*
# Go module cache files
go.sum
# OS generated files
.DS_Store
openmcpauthproxy
# builds
build
# test out files
coverage.out
coverage.html
# IDE files
.vscode

48
Dockerfile Normal file
View file

@ -0,0 +1,48 @@
FROM --platform=${BUILDPLATFORM:-linux/amd64} golang:1.24@sha256:d9db32125db0c3a680cfb7a1afcaefb89c898a075ec148fdc2f0f646cc2ed509 AS build
ARG TARGETPLATFORM
ARG BUILDPLATFORM
ARG TARGETOS
ARG TARGETARCH
WORKDIR /workspace
RUN apt update -qq && apt install -qq -y git bash curl g++
# Download libraries
ADD go.* .
RUN go mod download
# Build
ADD cmd cmd
ADD internal internal
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o webhook -ldflags '-w -extldflags "-static"' -o openmcpauthproxy ./cmd/proxy
#Test
RUN CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go test -v -race ./...
# Build production container
FROM --platform=${BUILDPLATFORM:-linux/amd64} ubuntu:24.04
RUN apt-get update \
&& apt-get install --no-install-recommends -y \
python3-pip \
python-is-python3 \
npm \
&& apt-get autoremove \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN pip install uvenv --break-system-packages
WORKDIR /app
COPY --from=build /workspace/openmcpauthproxy /app/
ADD config.yaml /app
ENTRYPOINT ["/app/openmcpauthproxy"]
ARG IMAGE_SOURCE
LABEL org.opencontainers.image.source=$IMAGE_SOURCE

88
Makefile Normal file
View file

@ -0,0 +1,88 @@
# Makefile for open-mcp-auth-proxy
# Variables
PROJECT_ROOT := $(realpath $(dir $(abspath $(lastword $(MAKEFILE_LIST)))))
BINARY_NAME := openmcpauthproxy
GO := go
GOFMT := gofmt
GOVET := go vet
GOTEST := go test
GOLINT := golangci-lint
GOCOV := go tool cover
BUILD_DIR := build
# Source files
SRC := $(shell find . -name "*.go" -not -path "./vendor/*")
PKGS := $(shell go list ./... | grep -v /vendor/)
# Set build options
BUILD_OPTS := -v
# Set test options
TEST_OPTS := -v -race
.PHONY: all clean test fmt lint vet coverage help
# Default target
all: lint test build-linux build-linux-arm build-darwin
build: clean test build-linux build-linux-arm build-darwin
build-linux:
mkdir -p $(BUILD_DIR)/linux
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
-o $(BUILD_DIR)/linux/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
cp config.yaml $(BUILD_DIR)/linux
build-linux-arm:
mkdir -p $(BUILD_DIR)/linux-arm
GOOS=linux GOARCH=arm CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
-o $(BUILD_DIR)/linux-arm/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
cp config.yaml $(BUILD_DIR)/linux-arm
build-darwin:
mkdir -p $(BUILD_DIR)/darwin
GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
-o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
cp config.yaml $(BUILD_DIR)/darwin
# Clean build artifacts
clean:
@echo "Cleaning build artifacts..."
@rm -rf $(BUILD_DIR)
@rm -f coverage.out
# Run tests
test:
@echo "Running tests..."
$(GOTEST) $(TEST_OPTS) ./...
# Run tests with coverage report
coverage:
@echo "Running tests with coverage..."
@$(GOTEST) -coverprofile=coverage.out ./...
@$(GOCOV) -func=coverage.out
@$(GOCOV) -html=coverage.out -o coverage.html
@echo "Coverage report generated in coverage.html"
# Run gofmt
fmt:
@echo "Running gofmt..."
@$(GOFMT) -w -s $(SRC)
# Run go vet
vet:
@echo "Running go vet..."
@$(GOVET) ./...
# Show help
help:
@echo "Available targets:"
@echo " all : Run lint, test, and build"
@echo " build : Build the application"
@echo " clean : Clean build artifacts"
@echo " test : Run tests"
@echo " coverage : Run tests with coverage report"
@echo " fmt : Run gofmt"
@echo " vet : Run go vet"
@echo " help : Show this help message"

243
README.md
View file

@ -1,82 +1,87 @@
# Open MCP Auth Proxy
The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider.
A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/)
![image](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92)
<a href="">[![🚀 Release](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml/badge.svg)](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml)</a>
<a href="">[![💬 Stackoverflow](https://img.shields.io/badge/Ask%20for%20help%20on-Stackoverflow-orange)](https://stackoverflow.com/questions/tagged/wso2is)</a>
<a href="">[![💬 Discord](https://img.shields.io/badge/Join%20us%20on-Discord-%23e01563.svg)](https://discord.gg/wso2)</a>
<a href="">[![🐦 Twitter](https://img.shields.io/twitter/follow/wso2.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=wso2)</a>
<a href="">[![📝 License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/wso2/product-is/blob/master/LICENSE)</a>
## **Setup and Installation**
![Architecture Diagram](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92)
### **Prerequisites**
## What it Does
* Go 1.20 or higher
* A running MCP server (SSE transport supported)
* An MCP client that supports MCP authorization
Open MCP Auth Proxy sits between MCP clients and your MCP server to:
### **Installation**
- Intercept incoming requests
- Validate authorization tokens
- Offload authentication and authorization to OAuth-compliant Identity Providers
- Support the MCP authorization protocol
```bash
git clone https://github.com/wso2/open-mcp-auth-proxy
cd open-mcp-auth-proxy
## Quick Start
go get github.com/golang-jwt/jwt/v4
go get gopkg.in/yaml.v2
### Prerequisites
go build -o openmcpauthproxy ./cmd/proxy
```
* Go 1.20 or higher
* A running MCP server
## Using Open MCP Auth Proxy
> If you don't have an MCP server, you can use the included example:
>
> 1. Navigate to the `resources` directory
> 2. Set up a Python environment:
>
> ```bash
> python3 -m venv .venv
> source .venv/bin/activate
> pip3 install -r requirements.txt
> ```
>
> 3. Start the example server:
>
> ```bash
> python3 echo_server.py
> ```
### Quick Start
* An MCP client that supports MCP authorization
Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo.
### Basic Usage
If you dont have an MCP server, follow the instructions given here to start your own MCP server for testing purposes.
1. Download [sample MCP server](resources/echo_server.py)
2. Run the server with
```bash
python3 echo_server.py
```
1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest).
#### Configure the Auth Proxy
Update the following parameters in `config.yaml`.
### demo mode configuration:
```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
listen_port: 8080 # Address where the proxy will listen
```
#### Start the Auth Proxy
2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
```bash
./openmcpauthproxy --demo
```
The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/).
> The repository comes with a default `config.yaml` file that contains the basic configuration:
>
> ```yaml
> listen_port: 8080
> base_url: "http://localhost:8000" # Your MCP server URL
> paths:
> sse: "/sse"
> messages: "/messages/"
> ```
#### Connect Using an MCP Client
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
You can use the [MCP Inspector](https://github.com/shashimalcse/inspector) to test the connection and try out the complete authorization flow.
## Connect an Identity Provider
### Use with Asgardeo
### Asgardeo
Enable authorization for the MCP server through your own Asgardeo organization
To enable authorization through your Asgardeo organization:
1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo
2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that,
1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope.
![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a)
2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy
#### Configure the Auth Proxy
Create a configuration file config.yaml with the following parameters:
1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo
2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a)
3. Update `config.yaml` with the following parameters.
```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
base_url: "http://localhost:8000" # URL of your MCP server
listen_port: 8080 # Address where the proxy will listen
asgardeo:
@ -85,31 +90,137 @@ asgardeo:
client_secret: "<client_secret>" # Client secret of the M2M app
```
#### Start the Auth Proxy
4. Start the proxy with Asgardeo integration:
```bash
./openmcpauthproxy --asgardeo
```
### Use with any standard OAuth Server
### Other OAuth Providers
Enable authorization for the MCP server with a compliant OAuth server
- [Auth0](docs/integrations/Auth0.md)
- [Keycloak](docs/integrations/keycloak.md)
#### Configuration
# Advanced Configuration
Create a configuration file config.yaml with the following parameters:
### Transport Modes
```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
listen_port: 8080 # Address where the proxy will listen
```
**TODO**: Update the configs for a standard OAuth Server.
The proxy supports two transport modes:
#### Start the Auth Proxy
- **SSE Mode (Default)**: For Server-Sent Events transport
- **stdio Mode**: For MCP servers that use stdio transport
When using stdio mode, the proxy:
- Starts an MCP server as a subprocess using the command specified in the configuration
- Communicates with the subprocess through standard input/output (stdio)
- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
To use stdio mode:
```bash
./openmcpauthproxy
./openmcpauthproxy --demo --stdio
```
#### Integrating with existing OAuth Providers
- [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization.
#### Example: Running an MCP Server as a Subprocess
1. Configure stdio mode in your `config.yaml`:
```yaml
listen_port: 8080
base_url: "http://localhost:8000"
stdio:
enabled: true
user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server
env: # Environment variables (optional)
- "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT"
# CORS configuration
cors:
allowed_origins:
- "http://localhost:5173" # Origin of your client application
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
allowed_headers:
- "Authorization"
- "Content-Type"
allow_credentials: true
# Demo configuration for Asgardeo
demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
```
2. Run the proxy with stdio mode:
```bash
./openmcpauthproxy --demo
```
The proxy will:
- Start the MCP server as a subprocess using the specified command
- Handle all authorization requirements
- Forward messages between clients and the server
### Complete Configuration Reference
```yaml
# Common configuration
listen_port: 8080
base_url: "http://localhost:8000"
port: 8000
# Path configuration
paths:
sse: "/sse"
messages: "/messages/"
# Transport mode
transport_mode: "sse" # Options: "sse" or "stdio"
# stdio-specific configuration (used only in stdio mode)
stdio:
enabled: true
user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed)
work_dir: "" # Optional working directory for the subprocess
# CORS configuration
cors:
allowed_origins:
- "http://localhost:5173"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
allowed_headers:
- "Authorization"
- "Content-Type"
allow_credentials: true
# Demo configuration for Asgardeo
demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
# Asgardeo configuration (used with --asgardeo flag)
asgardeo:
org_name: "<org_name>"
client_id: "<client_id>"
client_secret: "<client_secret>"
```
### Build from source
```bash
git clone https://github.com/wso2/open-mcp-auth-proxy
cd open-mcp-auth-proxy
go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
go build -o openmcpauthproxy ./cmd/proxy
```

View file

@ -3,31 +3,71 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
func main() {
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
debugMode := flag.Bool("debug", false, "Enable debug logging")
stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE")
flag.Parse()
logger.SetDebug(*debugMode)
// 1. Load config
cfg, err := config.LoadConfig("config.yaml")
if err != nil {
log.Fatalf("Error loading config: %v", err)
logger.Error("Error loading config: %v", err)
os.Exit(1)
}
// 2. Create the chosen provider
// Override transport mode if stdio flag is set
if *stdioMode {
cfg.TransportMode = config.StdioTransport
// Ensure stdio is enabled
cfg.Stdio.Enabled = true
// Re-validate config
if err := cfg.Validate(); err != nil {
logger.Error("Configuration error: %v", err)
os.Exit(1)
}
}
logger.Info("Using transport mode: %s", cfg.TransportMode)
logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
// 2. Start subprocess if configured and in stdio mode
var procManager *subprocess.Manager
if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
// Ensure all required dependencies are available
if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
logger.Warn("%v", err)
logger.Warn("Subprocess may fail to start due to missing dependencies")
}
procManager = subprocess.NewManager()
if err := procManager.Start(cfg); err != nil {
logger.Warn("Failed to start subprocess: %v", err)
}
} else if cfg.TransportMode == config.SSETransport {
logger.Info("Using SSE transport mode, not starting subprocess")
}
// 3. Create the chosen provider
var provider authz.Provider
if *demoMode {
cfg.Mode = "demo"
@ -46,41 +86,49 @@ func main() {
provider = authz.NewDefaultProvider(cfg)
}
// 3. (Optional) Fetch JWKS if you want local JWT validation
// 4. (Optional) Fetch JWKS if you want local JWT validation
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
log.Fatalf("Failed to fetch JWKS: %v", err)
logger.Error("Failed to fetch JWKS: %v", err)
os.Exit(1)
}
// 4. Build the main router
// 5. Build the main router
mux := proxy.NewRouter(cfg, provider)
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
listen_address := fmt.Sprintf("0.0.0.0:%d", cfg.ListenPort)
// 5. Start the server
// 6. Start the server
srv := &http.Server{
Addr: listen_address,
Handler: mux,
}
go func() {
log.Printf("Server listening on %s", listen_address)
logger.Info("Server listening on %s", listen_address)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err)
logger.Error("Server error: %v", err)
os.Exit(1)
}
}()
// 6. Graceful shutdown on Ctrl+C
// 7. Wait for shutdown signal
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
log.Println("Shutting down...")
logger.Info("Shutting down...")
// 8. First terminate subprocess if running
if procManager != nil && procManager.IsRunning() {
procManager.Shutdown()
}
// 9. Then shutdown the server
logger.Info("Shutting down HTTP server...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("Shutdown error: %v", err)
logger.Error("HTTP server shutdown error: %v", err)
}
log.Println("Stopped.")
logger.Info("Stopped.")
}

View file

@ -1,22 +1,28 @@
# config.yaml
mcp_server_base_url: ""
# Common configuration for all transport modes
listen_port: 8080
base_url: "http://localhost:8000" # Base URL for the MCP server
port: 8000 # Port for the MCP server
timeout_seconds: 10
mcp_paths:
- /messages/
- /sse
path_mapping:
/token: /token
/register: /register
/authorize: /authorize
/.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server
# Transport mode configuration
transport_mode: "stdio" # Options: "sse" or "stdio"
# stdio-specific configuration (used only when transport_mode is "stdio")
stdio:
enabled: true
user_command: uvx mcp-server-time --local-timezone=Europe/Zurich
#user_command: "npx -y @modelcontextprotocol/server-github"
work_dir: "" # Working directory (optional)
# env: # Environment variables (optional)
# - "NODE_ENV=development"
# CORS settings
cors:
allowed_origins:
- ""
- "http://localhost:6274" # Origin of your frontend/client app
allowed_methods:
- "GET"
- "POST"
@ -25,29 +31,26 @@ cors:
allowed_headers:
- "Authorization"
- "Content-Type"
- "mcp-protocol-version"
allow_credentials: true
demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
asgardeo:
org_name: "<org_name>"
client_id: "<client_id>"
client_secret: "<client_secret>"
# Keycloak endpoint path mappings
path_mapping:
sse: "/sse" # SSE endpoint path
messages: "/messages/" # Messages endpoint path
/token: /realms/master/protocol/openid-connect/token
/register: /realms/master/clients-registrations/openid-connect
# Keycloak configuration block
default:
base_url: "<base_url>"
jwks_url: "<jwks_url>"
base_url: "https://iam.phoenix-systems.ch"
jwks_url: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs"
path:
/.well-known/oauth-authorization-server:
response:
issuer: "<issuer>"
jwks_uri: "<jwks_uri>"
authorization_endpoint: "<authorization_endpoint>" # Optional
token_endpoint: "<token_endpoint>" # Optional
registration_endpoint: "<registration_endpoint>" # Optional
issuer: "https://iam.phoenix-systems.ch/realms/kvant"
jwks_uri: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs"
authorization_endpoint: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/auth"
response_types_supported:
- "code"
grant_types_supported:
@ -55,17 +58,8 @@ default:
- "refresh_token"
code_challenge_methods_supported:
- "S256"
- "plain"
/authroize:
addQueryParams:
- name: "<name>"
value: "<value>"
- "plain"
/token:
addBodyParams:
- name: "<name>"
value: "<value>"
/register:
addBodyParams:
- name: "<name>"
value: "<value>"
- name: "audience"
value: "mcp_proxy"

View file

@ -4,7 +4,7 @@ This guide will help you configure Open MCP Auth Proxy to use Auth0 as your iden
### Prerequisites
- An Auth0 organization (sign up here if you don't have one)
- An Auth0 organization (sign up [here](https://auth0.com) if you don't have one)
- Open MCP Auth Proxy installed
### Setting Up Auth0
@ -28,9 +28,17 @@ Update your `config.yaml` with Auth0 settings:
```yaml
# Basic proxy configuration
mcp_server_base_url: "http://localhost:8000"
listen_port: 8080
timeout_seconds: 10
base_url: "http://localhost:8000"
port: 8000
# Path configuration
paths:
sse: "/sse"
messages: "/messages/"
# Transport mode
transport_mode: "sse"
# CORS configuration
cors:

View file

@ -0,0 +1,92 @@
## Integrating Open MCP Auth Proxy with Keycloak
This guide walks you through configuring the Open MCP Auth Proxy to authenticate using Keycloak as the identity provider.
---
### Prerequisites
Before you begin, ensure you have the following:
- A running Keycloak instance
- Open MCP Auth Proxy installed and accessible
---
### Step 1: Configure Keycloak for Client Registration
Set up dynamic client registration in your Keycloak realm by following the [Keycloak client registration guide](https://www.keycloak.org/securing-apps/client-registration).
---
### Step 2: Configure Open MCP Auth Proxy
Update the `config.yaml` file in your Open MCP Auth Proxy setup using your Keycloak realm's [OIDC settings](https://www.keycloak.org/securing-apps/oidc-layers). Below is an example configuration:
```yaml
# Proxy server configuration
listen_port: 8081 # Port for the auth proxy
base_url: "http://localhost:8000" # Base URL of the MCP server
port: 8000 # MCP server port
# Define path mappings
paths:
sse: "/sse"
messages: "/messages/"
# Set the transport mode
transport_mode: "sse"
# CORS settings
cors:
allowed_origins:
- "http://localhost:5173" # Origin of your frontend/client app
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
allowed_headers:
- "Authorization"
- "Content-Type"
- "mcp-protocol-version"
allow_credentials: true
# Keycloak endpoint path mappings
path_mapping:
/token: /realms/master/protocol/openid-connect/token
/register: /realms/master/clients-registrations/openid-connect
# Keycloak configuration block
default:
base_url: "http://localhost:8080"
jwks_url: "http://localhost:8080/realms/master/protocol/openid-connect/certs"
path:
/.well-known/oauth-authorization-server:
response:
issuer: "http://localhost:8080/realms/master"
jwks_uri: "http://localhost:8080/realms/master/protocol/openid-connect/certs"
authorization_endpoint: "http://localhost:8080/realms/master/protocol/openid-connect/auth"
response_types_supported:
- "code"
grant_types_supported:
- "authorization_code"
- "refresh_token"
code_challenge_methods_supported:
- "S256"
- "plain"
/token:
addBodyParams:
- name: "audience"
value: "mcp_proxy"
```
### Step 3: Start the Auth Proxy
Launch the proxy with the updated Keycloak configuration:
```bash
./openmcpauthproxy
```
Once running, the proxy will handle authentication requests through your configured Keycloak realm.

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/wso2/open-mcp-auth-proxy
go 1.22.3
go 1.21
require (
github.com/golang-jwt/jwt/v4 v4.5.2

6
go.sum Normal file
View file

@ -0,0 +1,6 @@
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

View file

@ -7,13 +7,13 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type asgardeoProvider struct {
@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Accel-Buffering", "no")
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
logger.Error("Error encoding well-known: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
var regReq RegisterRequest
if err := json.NewDecoder(r.Body).Decode(&regReq); err != nil {
log.Printf("ERROR: reading register request: %v", err)
logger.Error("Reading register request: %v", err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
regReq.ClientSecret = randomString(16)
if err := p.createAsgardeoApplication(regReq); err != nil {
log.Printf("WARN: Asgardeo application creation failed: %v", err)
logger.Warn("Asgardeo application creation failed: %v", err)
// Optionally http.Error(...) if you want to fail
// or continue to return partial data.
}
@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("ERROR: encoding /register response: %v", err)
logger.Error("Encoding /register response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
}
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
return nil
}
@ -202,8 +206,11 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Sensitive data - should not be logged at INFO level
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
return "", fmt.Errorf("failed to parse token JSON: %w", err)
}
// Don't log the actual token at info level, only at debug level
logger.Debug("Received access token: %s", tokenResp.AccessToken)
logger.Info("Successfully obtained admin token from Asgardeo")
return tokenResp.AccessToken, nil
}

View file

@ -5,6 +5,7 @@ import (
"net/http"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type defaultProvider struct {
@ -81,6 +82,7 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
logger.Error("Error encoding well-known response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
return

View file

@ -0,0 +1,125 @@
package authz
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
)
func TestNewDefaultProvider(t *testing.T) {
cfg := &config.Config{}
provider := NewDefaultProvider(cfg)
if provider == nil {
t.Fatal("Expected non-nil provider")
}
// Ensure it implements the Provider interface
var _ Provider = provider
}
func TestDefaultProviderWellKnownHandler(t *testing.T) {
// Create a config with a custom well-known response
cfg := &config.Config{
Default: config.DefaultConfig{
Path: map[string]config.PathConfig{
"/.well-known/oauth-authorization-server": {
Response: &config.ResponseConfig{
Issuer: "https://test-issuer.com",
JwksURI: "https://test-issuer.com/jwks",
ResponseTypesSupported: []string{"code"},
GrantTypesSupported: []string{"authorization_code"},
CodeChallengeMethodsSupported: []string{"S256"},
},
},
},
},
}
provider := NewDefaultProvider(cfg)
handler := provider.WellKnownHandler()
// Create a test request
req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)
req.Host = "test-host.com"
req.Header.Set("X-Forwarded-Proto", "https")
// Create a response recorder
w := httptest.NewRecorder()
// Call the handler
handler(w, req)
// Check response status
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %v", w.Code)
}
// Verify content type
contentType := w.Header().Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
}
// Decode and check the response body
var response map[string]interface{}
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response JSON: %v", err)
}
// Check expected values
if response["issuer"] != "https://test-issuer.com" {
t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"])
}
if response["jwks_uri"] != "https://test-issuer.com/jwks" {
t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"])
}
if response["authorization_endpoint"] != "https://test-host.com/authorize" {
t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"])
}
}
func TestDefaultProviderHandleOPTIONS(t *testing.T) {
provider := NewDefaultProvider(&config.Config{})
handler := provider.WellKnownHandler()
// Create OPTIONS request
req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil)
w := httptest.NewRecorder()
// Call the handler
handler(w, req)
// Check response
if w.Code != http.StatusNoContent {
t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code)
}
// Check CORS headers
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin"))
}
if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" {
t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods"))
}
}
func TestDefaultProviderInvalidMethod(t *testing.T) {
provider := NewDefaultProvider(&config.Config{})
handler := provider.WellKnownHandler()
// Create POST request (which should be rejected)
req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil)
w := httptest.NewRecorder()
// Call the handler
handler(w, req)
// Check response
if w.Code != http.StatusMethodNotAllowed {
t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code)
}
}

View file

@ -1,12 +1,35 @@
package config
import (
"fmt"
"os"
"gopkg.in/yaml.v2"
)
// AsgardeoConfig groups all Asgardeo-specific fields
// Transport mode for MCP server
type TransportMode string
const (
SSETransport TransportMode = "sse"
StdioTransport TransportMode = "stdio"
)
// Common path configuration for all transport modes
type PathsConfig struct {
SSE string `yaml:"sse"`
Messages string `yaml:"messages"`
}
// StdioConfig contains stdio-specific configuration
type StdioConfig struct {
Enabled bool `yaml:"enabled"`
UserCommand string `yaml:"user_command"` // The command provided by the user
WorkDir string `yaml:"work_dir"` // Working directory (optional)
Args []string `yaml:"args,omitempty"` // Additional arguments
Env []string `yaml:"env,omitempty"` // Environment variables
}
type DemoConfig struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
@ -60,15 +83,18 @@ type DefaultConfig struct {
}
type Config struct {
AuthServerBaseURL string
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
ListenPort int `yaml:"listen_port"`
JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"`
MCPPaths []string `yaml:"mcp_paths"`
PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"`
AuthServerBaseURL string
ListenPort int `yaml:"listen_port"`
BaseURL string `yaml:"base_url"`
Port int `yaml:"port"`
JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"`
PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"`
TransportMode TransportMode `yaml:"transport_mode"`
Paths PathsConfig `yaml:"paths"`
Stdio StdioConfig `yaml:"stdio"`
// Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"`
@ -76,6 +102,56 @@ type Config struct {
Default DefaultConfig `yaml:"default"`
}
// Validate checks if the config is valid based on transport mode
func (c *Config) Validate() error {
// Validate based on transport mode
if c.TransportMode == StdioTransport {
if !c.Stdio.Enabled {
return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
}
if c.Stdio.UserCommand == "" {
return fmt.Errorf("stdio.user_command is required in stdio transport mode")
}
}
// Validate paths
if c.Paths.SSE == "" {
c.Paths.SSE = "/sse" // Default value
}
if c.Paths.Messages == "" {
c.Paths.Messages = "/messages" // Default value
}
// Validate base URL
if c.BaseURL == "" {
if c.Port > 0 {
c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
} else {
c.BaseURL = "http://localhost:8000" // Default value
}
}
return nil
}
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
func (c *Config) GetMCPPaths() []string {
return []string{c.Paths.SSE, c.Paths.Messages}
}
// BuildExecCommand constructs the full command string for execution in stdio mode
func (c *Config) BuildExecCommand() string {
if c.Stdio.UserCommand == "" {
return ""
}
// Construct the full command
return fmt.Sprintf(
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
)
}
// LoadConfig reads a YAML config file into Config struct.
func LoadConfig(path string) (*Config, error) {
f, err := os.Open(path)
@ -89,8 +165,26 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil {
return nil, err
}
// Set default values
if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default
}
// Set default transport mode if not specified
if cfg.TransportMode == "" {
cfg.TransportMode = SSETransport // Default to SSE
}
// Set default port if not specified
if cfg.Port == 0 {
cfg.Port = 8000 // default
}
// Validate the configuration
if err := cfg.Validate(); err != nil {
return nil, err
}
return &cfg, nil
}

View file

@ -0,0 +1,196 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfig(t *testing.T) {
// Create a temporary config file
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "test_config.yaml")
// Basic valid config
validConfig := `
listen_port: 8080
base_url: "http://localhost:8000"
transport_mode: "sse"
paths:
sse: "/sse"
messages: "/messages"
cors:
allowed_origins:
- "http://localhost:5173"
allowed_methods:
- "GET"
- "POST"
allowed_headers:
- "Authorization"
- "Content-Type"
allow_credentials: true
`
err := os.WriteFile(configPath, []byte(validConfig), 0644)
if err != nil {
t.Fatalf("Failed to create test config file: %v", err)
}
// Test loading the valid config
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("Failed to load valid config: %v", err)
}
// Verify expected values from the config
if cfg.ListenPort != 8080 {
t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort)
}
if cfg.BaseURL != "http://localhost:8000" {
t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL)
}
if cfg.TransportMode != SSETransport {
t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode)
}
if cfg.Paths.SSE != "/sse" {
t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE)
}
if cfg.Paths.Messages != "/messages" {
t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages)
}
// Test default values
if cfg.TimeoutSeconds != 15 {
t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds)
}
if cfg.Port != 8000 {
t.Errorf("Expected default Port=8000, got %d", cfg.Port)
}
}
func TestValidate(t *testing.T) {
tests := []struct {
name string
config Config
expectError bool
}{
{
name: "Valid SSE config",
config: Config{
TransportMode: SSETransport,
Paths: PathsConfig{
SSE: "/sse",
Messages: "/messages",
},
BaseURL: "http://localhost:8000",
},
expectError: false,
},
{
name: "Valid stdio config",
config: Config{
TransportMode: StdioTransport,
Stdio: StdioConfig{
Enabled: true,
UserCommand: "some-command",
},
},
expectError: false,
},
{
name: "Invalid stdio config - not enabled",
config: Config{
TransportMode: StdioTransport,
Stdio: StdioConfig{
Enabled: false,
UserCommand: "some-command",
},
},
expectError: true,
},
{
name: "Invalid stdio config - no command",
config: Config{
TransportMode: StdioTransport,
Stdio: StdioConfig{
Enabled: true,
UserCommand: "",
},
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectError && err == nil {
t.Errorf("Expected validation error but got none")
}
if !tc.expectError && err != nil {
t.Errorf("Expected no validation error but got: %v", err)
}
})
}
}
func TestGetMCPPaths(t *testing.T) {
cfg := Config{
Paths: PathsConfig{
SSE: "/custom-sse",
Messages: "/custom-messages",
},
}
paths := cfg.GetMCPPaths()
if len(paths) != 2 {
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
}
if paths[0] != "/custom-sse" {
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
}
if paths[1] != "/custom-messages" {
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
}
}
func TestBuildExecCommand(t *testing.T) {
tests := []struct {
name string
config Config
expectedResult string
}{
{
name: "Valid command",
config: Config{
Stdio: StdioConfig{
UserCommand: "test-command",
},
Port: 8080,
BaseURL: "http://example.com",
Paths: PathsConfig{
SSE: "/sse-path",
Messages: "/msgs",
},
},
expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`,
},
{
name: "Empty command",
config: Config{
Stdio: StdioConfig{
UserCommand: "",
},
},
expectedResult: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := tc.config.BuildExecCommand()
if result != tc.expectedResult {
t.Errorf("Expected command=%s, got %s", tc.expectedResult, result)
}
})
}
}

View file

@ -0,0 +1,34 @@
package logger
import (
"log"
)
var isDebug = false
// SetDebug enables or disables debug logging
func SetDebug(debug bool) {
isDebug = debug
}
// Debug logs a debug-level message
func Debug(format string, v ...interface{}) {
if isDebug {
log.Printf("DEBUG: "+format, v...)
}
}
// Info logs an info-level message
func Info(format string, v ...interface{}) {
log.Printf("INFO: "+format, v...)
}
// Warn logs a warning-level message
func Warn(format string, v ...interface{}) {
log.Printf("WARN: "+format, v...)
}
// Error logs an error-level message
func Error(format string, v ...interface{}) {
log.Printf("ERROR: "+format, v...)
}

View file

@ -9,6 +9,7 @@ import (
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// RequestModifier modifies requests before they are proxied
@ -148,6 +149,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
// Parse form data
if err := req.ParseForm(); err != nil {
logger.Error("Failed to parse form data: %v", err)
return nil, err
}
@ -169,12 +171,14 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Read body
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
logger.Error("Failed to read request body: %v", err)
return nil, err
}
// Parse JSON
var jsonData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
logger.Error("Failed to parse JSON body: %v", err)
return nil, err
}
@ -186,6 +190,7 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
// Marshal back to JSON
modifiedBody, err := json.Marshal(jsonData)
if err != nil {
logger.Error("Failed to marshal modified JSON: %v", err)
return nil, err
}

View file

@ -0,0 +1,147 @@
package proxy
import (
"net/http"
"net/url"
"strings"
"testing"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
)
func TestAuthorizationModifier(t *testing.T) {
cfg := &config.Config{
Default: config.DefaultConfig{
Path: map[string]config.PathConfig{
"/authorize": {
AddQueryParams: []config.ParamConfig{
{Name: "client_id", Value: "test-client-id"},
{Name: "scope", Value: "openid"},
},
},
},
},
}
modifier := &AuthorizationModifier{Config: cfg}
// Create a test request
req, err := http.NewRequest("GET", "/authorize?response_type=code", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Modify the request
modifiedReq, err := modifier.ModifyRequest(req)
if err != nil {
t.Fatalf("ModifyRequest failed: %v", err)
}
// Check that the query parameters were added
query := modifiedReq.URL.Query()
if query.Get("client_id") != "test-client-id" {
t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id"))
}
if query.Get("scope") != "openid" {
t.Errorf("Expected scope=openid, got %s", query.Get("scope"))
}
if query.Get("response_type") != "code" {
t.Errorf("Expected response_type=code, got %s", query.Get("response_type"))
}
}
func TestTokenModifier(t *testing.T) {
cfg := &config.Config{
Default: config.DefaultConfig{
Path: map[string]config.PathConfig{
"/token": {
AddBodyParams: []config.ParamConfig{
{Name: "audience", Value: "test-audience"},
},
},
},
},
}
modifier := &TokenModifier{Config: cfg}
// Create a test request with form data
form := url.Values{}
req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Modify the request
modifiedReq, err := modifier.ModifyRequest(req)
if err != nil {
t.Fatalf("ModifyRequest failed: %v", err)
}
body := make([]byte, 1024)
n, err := modifiedReq.Body.Read(body)
if err != nil && err.Error() != "EOF" {
t.Fatalf("Failed to read body: %v", err)
}
bodyStr := string(body[:n])
// Parse the form data from the modified request
if err := modifiedReq.ParseForm(); err != nil {
t.Fatalf("Failed to parse form data: %v", err)
}
// Check that the body parameters were added
if !strings.Contains(bodyStr, "audience") {
t.Errorf("Expected body to contain audience, got %s", bodyStr)
}
}
func TestRegisterModifier(t *testing.T) {
cfg := &config.Config{
Default: config.DefaultConfig{
Path: map[string]config.PathConfig{
"/register": {
AddBodyParams: []config.ParamConfig{
{Name: "client_name", Value: "test-client"},
},
},
},
},
}
modifier := &RegisterModifier{Config: cfg}
// Create a test request with JSON data
jsonBody := `{"redirect_uris":["https://example.com/callback"]}`
req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
// Modify the request
modifiedReq, err := modifier.ModifyRequest(req)
if err != nil {
t.Fatalf("ModifyRequest failed: %v", err)
}
// Read the body and check that it still contains the original data
// This test would need to be enhanced with a proper JSON parsing to verify
// the added parameters
body := make([]byte, 1024)
n, err := modifiedReq.Body.Read(body)
if err != nil && err.Error() != "EOF" {
t.Fatalf("Failed to read body: %v", err)
}
bodyStr := string(body[:n])
// Simple check to see if the modified body contains the expected fields
if !strings.Contains(bodyStr, "client_name") {
t.Errorf("Expected body to contain client_name, got %s", bodyStr)
}
if !strings.Contains(bodyStr, "redirect_uris") {
t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr)
}
}

View file

@ -2,7 +2,6 @@ package proxy
import (
"context"
"log"
"net/http"
"net/http/httputil"
"net/url"
@ -11,6 +10,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
@ -82,7 +82,8 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
}
// MCP paths
for _, path := range cfg.MCPPaths {
mcpPaths := cfg.GetMCPPaths()
for _, path := range mcpPaths {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true
}
@ -100,23 +101,21 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
// Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil {
log.Fatalf("Invalid auth server URL: %v", err)
logger.Error("Invalid auth server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil {
log.Fatalf("Invalid MCP server URL: %v", err)
logger.Error("Invalid MCP server URL: %v", err)
panic(err) // Fatal error that prevents startup
}
// Detect SSE paths from config
ssePaths := make(map[string]bool)
for _, p := range cfg.MCPPaths {
if p == "/sse" {
ssePaths[p] = true
}
}
ssePaths[cfg.Paths.SSE] = true
return func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
@ -124,7 +123,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Handle OPTIONS
if r.Method == http.MethodOptions {
if allowedOrigin == "" {
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
logger.Warn("Preflight request from disallowed origin: %s", origin)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
@ -134,7 +133,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
}
if allowedOrigin == "" {
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
@ -152,7 +151,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
// Validate JWT for MCP paths if required
// Placeholder for JWT validation logic
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
@ -170,7 +169,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
var err error
r, err = modifier.ModifyRequest(r)
if err != nil {
log.Printf("[proxy] Error modifying request: %v", err)
logger.Error("Error modifying request: %v", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
@ -192,7 +191,13 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Host = targetURL.Host
cleanHeaders := http.Header{}
// Set proper origin header to match the target
if isSSE {
// For SSE, ensure origin matches the target
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
}
for k, v := range r.Header {
// Skip hop-by-hop headers
if skipHeader(k) {
@ -205,21 +210,33 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
req.Header = cleanHeaders
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
},
ModifyResponse: func(resp *http.Response) error {
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
return nil
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[proxy] Error proxying: %v", err)
logger.Error("Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
},
FlushInterval: -1, // immediate flush for SSE
}
if isSSE {
// Add special response handling for SSE connections to rewrite endpoint URLs
rp.Transport = &sseTransport{
Transport: http.DefaultTransport,
proxyHost: r.Host,
targetHost: targetURL.Host,
}
// Set SSE-specific headers
w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// Keep SSE connections open
HandleSSE(w, r, rp)
} else {
@ -236,6 +253,7 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
}
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
if allowed == origin {
return allowed
}
@ -256,6 +274,7 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
w.Header().Set("Vary", "Origin")
w.Header().Set("X-Accel-Buffering", "no")
}
func isAuthPath(path string) bool {
@ -273,7 +292,8 @@ func isAuthPath(path string) bool {
// isMCPPath checks if the path is an MCP path
func isMCPPath(path string, cfg *config.Config) bool {
for _, p := range cfg.MCPPaths {
mcpPaths := cfg.GetMCPPaths()
for _, p := range mcpPaths {
if strings.HasPrefix(path, p) {
return true
}

View file

@ -1,11 +1,16 @@
package proxy
import (
"bufio"
"context"
"log"
"fmt"
"io"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// HandleSSE sets up a go-routine to wait for context cancellation
@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
go func() {
<-ctx.Done()
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
close(done)
}()
@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
type sseTransport struct {
Transport http.RoundTripper
proxyHost string
targetHost string
}
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Call the underlying transport
resp, err := t.Transport.RoundTrip(req)
if err != nil {
return nil, err
}
// Check if this is an SSE response
contentType := resp.Header.Get("Content-Type")
if !strings.Contains(contentType, "text/event-stream") {
return resp, nil
}
logger.Info("Intercepting SSE response to modify endpoint events")
// Create a response wrapper that modifies the response body
originalBody := resp.Body
pr, pw := io.Pipe()
go func() {
defer originalBody.Close()
defer pw.Close()
scanner := bufio.NewScanner(originalBody)
for scanner.Scan() {
line := scanner.Text()
// Check if this line contains an endpoint event
if strings.HasPrefix(line, "event: endpoint") {
// Read the data line
if scanner.Scan() {
dataLine := scanner.Text()
if strings.HasPrefix(dataLine, "data: ") {
// Extract the endpoint URL
endpoint := strings.TrimPrefix(dataLine, "data: ")
// Replace the host in the endpoint
logger.Debug("Original endpoint: %s", endpoint)
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
logger.Debug("Modified endpoint: %s", endpoint)
// Write the modified event lines
fmt.Fprintln(pw, line)
fmt.Fprintln(pw, "data: "+endpoint)
continue
}
}
}
// Write the original line for non-endpoint events
fmt.Fprintln(pw, line)
}
if err := scanner.Err(); err != nil {
logger.Error("Error reading SSE stream: %v", err)
}
}()
// Replace the response body with our modified pipe
resp.Body = pr
return resp, nil
}

View file

@ -0,0 +1,268 @@
package subprocess
import (
"fmt"
"os"
"os/exec"
"sync"
"syscall"
"time"
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// Manager handles starting and graceful shutdown of subprocesses
type Manager struct {
process *os.Process
processGroup int
mutex sync.Mutex
cmd *exec.Cmd
shutdownDelay time.Duration
}
// NewManager creates a new subprocess manager
func NewManager() *Manager {
return &Manager{
shutdownDelay: 5 * time.Second,
}
}
// EnsureDependenciesAvailable checks and installs required package executors
func EnsureDependenciesAvailable(command string) error {
// Always ensure npx is available regardless of the command
if _, err := exec.LookPath("npx"); err != nil {
// npx is not available, check if npm is installed
if _, err := exec.LookPath("npm"); err != nil {
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
}
// Try to install npx using npm
logger.Info("npx not found, attempting to install...")
cmd := exec.Command("npm", "install", "-g", "npx")
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to install npx: %w", err)
}
logger.Info("npx installed successfully")
}
// Check if uv is needed based on the command
if strings.Contains(command, "uv ") {
if _, err := exec.LookPath("uv"); err != nil {
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
}
}
return nil
}
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
func (m *Manager) SetShutdownDelay(duration time.Duration) {
m.shutdownDelay = duration
}
// Start launches a subprocess based on the configuration
func (m *Manager) Start(cfg *config.Config) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// If a process is already running, return an error
if m.process != nil {
return os.ErrExist
}
if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
return nil // Nothing to start
}
// Get the full command string
execCommand := cfg.BuildExecCommand()
if execCommand == "" {
return nil // No command to execute
}
logger.Info("Starting subprocess with command: %s", execCommand)
// Use the shell to execute the command
cmd := exec.Command("sh", "-c", execCommand)
// Set working directory if specified
if cfg.Stdio.WorkDir != "" {
cmd.Dir = cfg.Stdio.WorkDir
}
// Set environment variables if specified
if len(cfg.Stdio.Env) > 0 {
cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
}
// Capture stdout/stderr
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
// Set the process group for proper termination
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
// Start the process
if err := cmd.Start(); err != nil {
return err
}
m.process = cmd.Process
m.cmd = cmd
logger.Info("Subprocess started with PID: %d", m.process.Pid)
// Get and store the process group ID
pgid, err := syscall.Getpgid(m.process.Pid)
if err == nil {
m.processGroup = pgid
logger.Debug("Process group ID: %d", m.processGroup)
} else {
logger.Warn("Failed to get process group ID: %v", err)
m.processGroup = m.process.Pid
}
// Handle process termination in background
go func() {
if err := cmd.Wait(); err != nil {
logger.Error("Subprocess exited with error: %v", err)
} else {
logger.Info("Subprocess exited successfully")
}
// Clear the process reference when it exits
m.mutex.Lock()
m.process = nil
m.cmd = nil
m.mutex.Unlock()
}()
return nil
}
// IsRunning checks if the subprocess is running
func (m *Manager) IsRunning() bool {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.process != nil
}
// Shutdown gracefully terminates the subprocess
func (m *Manager) Shutdown() {
m.mutex.Lock()
processToTerminate := m.process // Local copy of the process reference
processGroupToTerminate := m.processGroup
m.mutex.Unlock()
if processToTerminate == nil {
return // No process to terminate
}
logger.Info("Terminating subprocess...")
terminateComplete := make(chan struct{})
go func() {
defer close(terminateComplete)
// Try graceful termination first with SIGTERM
terminatedGracefully := false
// Try to terminate the process group first
if processGroupToTerminate != 0 {
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process group: %v", err)
// Fallback to terminating just the process
m.mutex.Lock()
if m.process != nil {
err = m.process.Signal(syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process: %v", err)
}
}
m.mutex.Unlock()
}
} else {
// Try to terminate just the process
m.mutex.Lock()
if m.process != nil {
err := m.process.Signal(syscall.SIGTERM)
if err != nil {
logger.Warn("Failed to send SIGTERM to process: %v", err)
}
}
m.mutex.Unlock()
}
// Wait for the process to exit gracefully
for i := 0; i < 10; i++ {
time.Sleep(200 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
terminatedGracefully = true
m.mutex.Unlock()
break
}
m.mutex.Unlock()
}
if terminatedGracefully {
logger.Info("Subprocess terminated gracefully")
return
}
// If the process didn't exit gracefully, force kill
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
// Try to kill the process group first
if processGroupToTerminate != 0 {
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
logger.Warn("Failed to send SIGKILL to process group: %v", err)
// Fallback to killing just the process
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
}
} else {
// Try to kill just the process
m.mutex.Lock()
if m.process != nil {
if err := m.process.Kill(); err != nil {
logger.Error("Failed to kill process: %v", err)
}
}
m.mutex.Unlock()
}
// Wait a bit more to confirm termination
time.Sleep(500 * time.Millisecond)
m.mutex.Lock()
if m.process == nil {
logger.Info("Subprocess terminated by force")
} else {
logger.Warn("Failed to terminate subprocess")
}
m.mutex.Unlock()
}()
// Wait for termination with timeout
select {
case <-terminateComplete:
// Termination completed
case <-time.After(m.shutdownDelay):
logger.Warn("Subprocess termination timed out")
}
}

View file

@ -4,12 +4,12 @@ import (
"crypto/rsa"
"encoding/json"
"errors"
"log"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type JWKS struct {
@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
publicKeys[parsedKey.Kid] = pubKey
}
}
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
logger.Info("Loaded %d public keys.", len(publicKeys))
return nil
}

143
internal/util/jwks_test.go Normal file
View file

@ -0,0 +1,143 @@
package util
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
)
func TestValidateJWT(t *testing.T) {
// Initialize the test JWKS data
initTestJWKS(t)
// Test cases
tests := []struct {
name string
authHeader string
expectError bool
}{
{
name: "Valid JWT token",
authHeader: "Bearer " + createValidJWT(t),
expectError: false,
},
{
name: "No auth header",
authHeader: "",
expectError: true,
},
{
name: "Invalid auth header format",
authHeader: "InvalidFormat",
expectError: true,
},
{
name: "Invalid JWT token",
authHeader: "Bearer invalid.jwt.token",
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := ValidateJWT(tc.authHeader)
if tc.expectError && err == nil {
t.Errorf("Expected error but got none")
}
if !tc.expectError && err != nil {
t.Errorf("Expected no error but got: %v", err)
}
})
}
}
func TestFetchJWKS(t *testing.T) {
// Create a mock JWKS server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Generate a test RSA key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
// Create JWKS response
jwks := map[string]interface{}{
"keys": []map[string]interface{}{
{
"kty": "RSA",
"kid": "test-key-id",
"n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
"e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(jwks)
}))
defer server.Close()
// Test fetching JWKS
err := FetchJWKS(server.URL)
if err != nil {
t.Fatalf("FetchJWKS failed: %v", err)
}
// Check that keys were stored
if len(publicKeys) == 0 {
t.Errorf("Expected publicKeys to be populated")
}
}
// Helper function to initialize test JWKS data
func initTestJWKS(t *testing.T) {
// Create a test RSA key pair
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
// Initialize the publicKeys map
publicKeys = map[string]*rsa.PublicKey{
"test-key-id": &privateKey.PublicKey,
}
}
// Helper function to create a valid JWT token for testing
func createValidJWT(t *testing.T) string {
// Create a test RSA key pair
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
// Ensure the test key is in the publicKeys map
if publicKeys == nil {
publicKeys = map[string]*rsa.PublicKey{}
}
publicKeys["test-key-id"] = &privateKey.PublicKey
// Create token
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"sub": "1234567890",
"name": "Test User",
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(),
})
token.Header["kid"] = "test-key-id"
// Sign the token
tokenString, err := token.SignedString(privateKey)
if err != nil {
t.Fatalf("Failed to sign token: %v", err)
}
return tokenString
}

View file

@ -1,52 +1,11 @@
## Purpose
> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc.
<!-- Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc. -->
## Goals
> Describe the solutions that this feature/fix will introduce to resolve the problems described above
## Approach
> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here.
## User stories
> Summary of user stories addressed by this change>
## Release note
> Brief description of the new feature or bug fix as it will appear in the release notes
## Documentation
> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why theres no doc impact
## Training
> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable
## Certification
> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why.
## Marketing
> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable
## Automation tests
- Unit tests
> Code coverage information
- Integration tests
> Details about the test cases and coverage
## Security checks
- Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no
- Ran FindSecurityBugs plugin and verified report? yes/no
- Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no
## Samples
> Provide high-level details about the samples related to this feature
## Related Issues
<!-- List any related issues -->
## Related PRs
> List any other related PRs
<!-- List any other related PRs -->
## Migrations (if applicable)
> Describe migration steps and platforms on which migration has been tested
## Test environment
> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested
## Learning
> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem.
<!-- Describe migration steps and platforms on which migration has been tested -->

View file

@ -0,0 +1 @@
fastmcp==0.4.1