diff --git a/.github/scripts/release.sh b/.github/scripts/release.sh
new file mode 100644
index 0000000..2a1f6a9
--- /dev/null
+++ b/.github/scripts/release.sh
@@ -0,0 +1,124 @@
+#!/bin/bash
+
+# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
+#
+# This software is the property of WSO2 LLC. and its suppliers, if any.
+# Dissemination of any information or reproduction of any material contained
+# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
+# You may not alter or remove any copyright or other notice from copies of this content.
+#
+
+# Exit the script on any command with non-zero exit status.
+set -e
+set -o pipefail
+
+UPSTREAM_BRANCH="main"
+
+# Assign command line arguments to variables.
+GIT_TOKEN=$1
+WORK_DIR=$2
+VERSION_TYPE=$3 # possible values: major, minor, patch
+
+# Check if GIT_TOKEN is empty
+if [ -z "$GIT_TOKEN" ]; then
+ echo "❌ Error: GIT_TOKEN is not set."
+ exit 1
+fi
+
+# Check if WORK_DIR is empty
+if [ -z "$WORK_DIR" ]; then
+ echo "❌ Error: WORK_DIR is not set."
+ exit 1
+fi
+
+# Validate VERSION_TYPE
+if [[ "$VERSION_TYPE" != "major" && "$VERSION_TYPE" != "minor" && "$VERSION_TYPE" != "patch" ]]; then
+ echo "❌ Error: VERSION_TYPE must be one of: major, minor, or patch."
+ exit 1
+fi
+
+BUILD_DIRECTORY="$WORK_DIR/build"
+RELEASE_DIRECTORY="$BUILD_DIRECTORY/releases"
+
+# Navigate to the working directory.
+cd "${WORK_DIR}"
+
+# Create the release directory.
+if [ ! -d "$RELEASE_DIRECTORY" ]; then
+ mkdir -p "$RELEASE_DIRECTORY"
+else
+ rm -rf "$RELEASE_DIRECTORY"/*
+fi
+
+# Extract current version.
+CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0")
+IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}"
+
+# Determine which part to increment
+case "$VERSION_TYPE" in
+ major)
+ MAJOR=$((MAJOR + 1))
+ MINOR=0
+ PATCH=0
+ ;;
+ minor)
+ MINOR=$((MINOR + 1))
+ PATCH=0
+ ;;
+ patch|*)
+ PATCH=$((PATCH + 1))
+ ;;
+esac
+
+NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}"
+
+echo "Creating release packages for version $NEW_VERSION..."
+
+# List of supported OSes.
+oses=("linux" "linux-arm" "darwin")
+
+# Navigate to the release directory.
+cd "${RELEASE_DIRECTORY}"
+
+for os in "${oses[@]}"; do
+ os_dir="../$os"
+
+ if [ -d "$os_dir" ]; then
+ release_artifact_folder="openmcpauthproxy_${os}-v${NEW_VERSION}"
+ mkdir -p "$release_artifact_folder"
+
+ cp -r $os_dir/* "$release_artifact_folder"
+
+ # Zip the release package.
+ zip_file="$release_artifact_folder.zip"
+ echo "Creating $zip_file..."
+ zip -r "$zip_file" "$release_artifact_folder"
+
+ # Delete the folder after zipping.
+ rm -rf "$release_artifact_folder"
+
+ # Generate checksum file.
+ sha256sum "$zip_file" | sed "s|target/releases/||" > "$zip_file.sha256"
+ echo "Checksum generated for the $os package."
+
+ echo "Release packages created successfully for $os."
+ else
+ echo "Skipping $os release package creation as the build artifacts are not available."
+ fi
+done
+
+echo "Release packages created successfully in $RELEASE_DIRECTORY."
+
+# Navigate back to the project root directory.
+cd "${WORK_DIR}"
+
+# Collect all ZIP and .sha256 files in the target/releases directory.
+FILES_TO_UPLOAD=$(find build/releases -type f \( -name "*.zip" -o -name "*.sha256" \))
+
+# Create a release with the current version.
+TAG_NAME="v${NEW_VERSION}"
+export GITHUB_TOKEN="${GIT_TOKEN}"
+gh release create "${TAG_NAME}" ${FILES_TO_UPLOAD} --title "${TAG_NAME}" --notes "OpenMCPAuthProxy - ${TAG_NAME}" --target "${UPSTREAM_BRANCH}" || { echo "Failed to create release"; exit 1; }
+
+
+echo "Release ${TAG_NAME} created successfully."
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
new file mode 100644
index 0000000..775003e
--- /dev/null
+++ b/.github/workflows/ci.yaml
@@ -0,0 +1,71 @@
+name: Build and Push container
+run-name: Build and Push container
+on:
+ workflow_dispatch:
+ #schedule:
+ # - cron: "0 10 * * *"
+ push:
+ branches:
+ - 'main'
+ - 'master'
+ tags:
+ - 'v*'
+ pull_request:
+ branches:
+ - 'main'
+ - 'master'
+env:
+ IMAGE: git.kvant.cloud/${{github.repository}}
+jobs:
+ build_concierge_backend:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Set current time
+ uses: https://github.com/gerred/actions/current-time@master
+ id: current_time
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Login to git.kvant.cloud registry
+ uses: docker/login-action@v3
+ with:
+ registry: git.kvant.cloud
+ username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
+ password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
+
+ - name: Docker meta
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ # list of Docker images to use as base name for tags
+ images: |
+ ${{env.IMAGE}}
+ # generate Docker tags based on the following events/attributes
+ tags: |
+ type=schedule
+ type=ref,event=branch
+ type=ref,event=pr
+ type=semver,pattern={{version}}
+
+ - name: Build and push to gitea registry
+ uses: docker/build-push-action@v6
+ with:
+ push: ${{ github.event_name != 'pull_request' }}
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
+ context: .
+ provenance: mode=max
+ sbom: true
+ build-args: |
+ BUILD_DATE=${{ steps.current_time.outputs.time }}
+ cache-from: |
+ type=registry,ref=${{ env.IMAGE }}:buildcache
+ type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
+ type=registry,ref=${{ env.IMAGE }}:main
+ cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
diff --git a/.github/workflows/pr-builder.yml b/.github/workflows/pr-builder.yml
new file mode 100644
index 0000000..a055e0d
--- /dev/null
+++ b/.github/workflows/pr-builder.yml
@@ -0,0 +1,62 @@
+name: Go CI
+
+on:
+ push:
+ branches: [ main ]
+ pull_request:
+ branches: [ main ]
+
+jobs:
+ test:
+ name: Test
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ go-version: ['1.20', '1.21']
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Set up Go
+ uses: actions/setup-go@v4
+ with:
+ go-version: ${{ matrix.go-version }}
+
+ - name: Get dependencies
+ run: go get -v -t -d ./...
+
+ - name: Verify dependencies
+ run: go mod verify
+
+ - name: Run go vet
+ run: go vet ./...
+
+ - name: Run tests
+ run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...
+
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ files: ./coverage.txt
+ fail_ci_if_error: false
+
+ build:
+ name: Build
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ go-version: ['1.20', '1.21']
+ os: [ubuntu-latest, macos-latest]
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Set up Go
+ uses: actions/setup-go@v4
+ with:
+ go-version: ${{ matrix.go-version }}
+
+ - name: Build
+ run: go build -v ./cmd/proxy
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..0c51bc7
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,64 @@
+#
+# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
+#
+# This software is the property of WSO2 LLC. and its suppliers, if any.
+# Dissemination of any information or reproduction of any material contained
+# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
+# You may not alter or remove any copyright or other notice from copies of this content.
+#
+
+name: Release
+
+on:
+ workflow_dispatch:
+ inputs:
+ version_type:
+ type: choice
+ description: Choose the type of version update
+ options:
+ - 'major'
+ - 'minor'
+ - 'patch'
+ required: true
+
+jobs:
+ update-and-release:
+ runs-on: ubuntu-latest
+ env:
+ GOPROXY: https://proxy.golang.org
+ if: github.event.pull_request.merged == true || github.event_name == 'workflow_dispatch'
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ ref: 'main'
+ fetch-depth: 0
+ token: ${{ secrets.GITHUB_TOKEN }}
+ - uses: actions/checkout@v2
+
+ - name: Set up Go 1.x
+ uses: actions/setup-go@v3
+ with:
+ go-version: "^1.x"
+
+ - name: Cache Go modules
+ id: cache-go-modules
+ uses: actions/cache@v3
+ with:
+ path: |
+ ~/.cache/go-build
+ ~/go/pkg/mod
+ key: ${{ runner.os }}-go-modules-${{ hashFiles('**/go.sum') }}
+ restore-keys: |
+ ${{ runner.os }}-go-modules-
+
+ - name: Install dependencies
+ run: go mod download
+
+ - name: Build and test
+ run: make build
+ working-directory: .
+
+ - name: Update artifact version, package, commit, and create release.
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ run: bash ./.github/scripts/release.sh $GITHUB_TOKEN ${{ github.workspace }} ${{ github.event.inputs.version_type }}
diff --git a/.gitignore b/.gitignore
index 6c1dd97..d200b58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,15 +18,21 @@
*.zip
*.tar.gz
*.rar
+.venv
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
replay_pid*
-# Go module cache files
-go.sum
-
# OS generated files
.DS_Store
-openmcpauthproxy
+# builds
+build
+
+# test out files
+coverage.out
+coverage.html
+
+# IDE files
+.vscode
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..dc468b1
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,48 @@
+FROM --platform=${BUILDPLATFORM:-linux/amd64} golang:1.24@sha256:d9db32125db0c3a680cfb7a1afcaefb89c898a075ec148fdc2f0f646cc2ed509 AS build
+
+ARG TARGETPLATFORM
+ARG BUILDPLATFORM
+ARG TARGETOS
+ARG TARGETARCH
+
+WORKDIR /workspace
+
+RUN apt update -qq && apt install -qq -y git bash curl g++
+
+# Download libraries
+ADD go.* .
+RUN go mod download
+
+# Build
+ADD cmd cmd
+ADD internal internal
+RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o webhook -ldflags '-w -extldflags "-static"' -o openmcpauthproxy ./cmd/proxy
+
+#Test
+RUN CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go test -v -race ./...
+
+
+# Build production container
+FROM --platform=${BUILDPLATFORM:-linux/amd64} ubuntu:24.04
+
+RUN apt-get update \
+ && apt-get install --no-install-recommends -y \
+ python3-pip \
+ python-is-python3 \
+ npm \
+ && apt-get autoremove \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN pip install uvenv --break-system-packages
+
+WORKDIR /app
+COPY --from=build /workspace/openmcpauthproxy /app/
+
+ADD config.yaml /app
+
+
+ENTRYPOINT ["/app/openmcpauthproxy"]
+
+ARG IMAGE_SOURCE
+LABEL org.opencontainers.image.source=$IMAGE_SOURCE
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..b0d0926
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,88 @@
+# Makefile for open-mcp-auth-proxy
+
+# Variables
+PROJECT_ROOT := $(realpath $(dir $(abspath $(lastword $(MAKEFILE_LIST)))))
+BINARY_NAME := openmcpauthproxy
+GO := go
+GOFMT := gofmt
+GOVET := go vet
+GOTEST := go test
+GOLINT := golangci-lint
+GOCOV := go tool cover
+BUILD_DIR := build
+
+# Source files
+SRC := $(shell find . -name "*.go" -not -path "./vendor/*")
+PKGS := $(shell go list ./... | grep -v /vendor/)
+
+# Set build options
+BUILD_OPTS := -v
+
+# Set test options
+TEST_OPTS := -v -race
+
+.PHONY: all clean test fmt lint vet coverage help
+
+# Default target
+all: lint test build-linux build-linux-arm build-darwin
+
+build: clean test build-linux build-linux-arm build-darwin
+
+build-linux:
+ mkdir -p $(BUILD_DIR)/linux
+ GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
+ -o $(BUILD_DIR)/linux/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
+ cp config.yaml $(BUILD_DIR)/linux
+
+build-linux-arm:
+ mkdir -p $(BUILD_DIR)/linux-arm
+ GOOS=linux GOARCH=arm CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
+ -o $(BUILD_DIR)/linux-arm/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
+ cp config.yaml $(BUILD_DIR)/linux-arm
+
+build-darwin:
+ mkdir -p $(BUILD_DIR)/darwin
+ GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
+ -o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
+ cp config.yaml $(BUILD_DIR)/darwin
+
+# Clean build artifacts
+clean:
+ @echo "Cleaning build artifacts..."
+ @rm -rf $(BUILD_DIR)
+ @rm -f coverage.out
+
+# Run tests
+test:
+ @echo "Running tests..."
+ $(GOTEST) $(TEST_OPTS) ./...
+
+# Run tests with coverage report
+coverage:
+ @echo "Running tests with coverage..."
+ @$(GOTEST) -coverprofile=coverage.out ./...
+ @$(GOCOV) -func=coverage.out
+ @$(GOCOV) -html=coverage.out -o coverage.html
+ @echo "Coverage report generated in coverage.html"
+
+# Run gofmt
+fmt:
+ @echo "Running gofmt..."
+ @$(GOFMT) -w -s $(SRC)
+
+# Run go vet
+vet:
+ @echo "Running go vet..."
+ @$(GOVET) ./...
+
+# Show help
+help:
+ @echo "Available targets:"
+ @echo " all : Run lint, test, and build"
+ @echo " build : Build the application"
+ @echo " clean : Clean build artifacts"
+ @echo " test : Run tests"
+ @echo " coverage : Run tests with coverage report"
+ @echo " fmt : Run gofmt"
+ @echo " vet : Run go vet"
+ @echo " help : Show this help message"
diff --git a/README.md b/README.md
index b587878..6be3ece 100644
--- a/README.md
+++ b/README.md
@@ -1,81 +1,88 @@
# Open MCP Auth Proxy
-The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider.
+A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/)
-
+[](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml)
+[](https://stackoverflow.com/questions/tagged/wso2is)
+[](https://discord.gg/wso2)
+[](https://twitter.com/intent/follow?screen_name=wso2)
+[](https://github.com/wso2/product-is/blob/master/LICENSE)
-## **Setup and Installation**
+
-### **Prerequisites**
+## What it Does
-* Go 1.20 or higher
-* A running MCP server (SSE transport supported)
-* An MCP client that supports MCP authorization
+Open MCP Auth Proxy sits between MCP clients and your MCP server to:
-### **Installation**
+- Intercept incoming requests
+- Validate authorization tokens
+- Offload authentication and authorization to OAuth-compliant Identity Providers
+- Support the MCP authorization protocol
-```bash
-git clone https://github.com/wso2/open-mcp-auth-proxy
-cd open-mcp-auth-proxy
+## Quick Start
-go get github.com/golang-jwt/jwt/v4
-go get gopkg.in/yaml.v2
+### Prerequisites
-go build -o openmcpauthproxy ./cmd/proxy
-```
+* Go 1.20 or higher
+* A running MCP server
-## Using Open MCP Auth Proxy
+> If you don't have an MCP server, you can use the included example:
+>
+> 1. Navigate to the `resources` directory
+> 2. Set up a Python environment:
+>
+> ```bash
+> python3 -m venv .venv
+> source .venv/bin/activate
+> pip3 install -r requirements.txt
+> ```
+>
+> 3. Start the example server:
+>
+> ```bash
+> python3 echo_server.py
+> ```
-### Quick Start
+* An MCP client that supports MCP authorization
-Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo.
+### Basic Usage
-If you don’t have an MCP server, follow the instructions given here to start your own MCP server for testing purposes.
-1. Download [sample MCP server](resources/echo_server.py)
-2. Run the server with
-```bash
-python3 echo_server.py
-```
+1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest).
-#### Configure the Auth Proxy
-
-Update the following parameters in `config.yaml`.
-
-```yaml
-mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
-listen_address: ":8080" # Address where the proxy will listen
-```
-
-#### Start the Auth Proxy
+2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
```bash
./openmcpauthproxy --demo
```
-The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/).
+> The repository comes with a default `config.yaml` file that contains the basic configuration:
+>
+> ```yaml
+> listen_port: 8080
+> base_url: "http://localhost:8000" # Your MCP server URL
+> paths:
+> sse: "/sse"
+> messages: "/messages/"
+> ```
-#### Connect Using an MCP Client
+3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
-You can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) to test the connection and try out the complete authorization flow.
+## Connect an Identity Provider
-### Use with Asgardeo
+### Asgardeo
-Enable authorization for the MCP server through your own Asgardeo organization
+To enable authorization through your Asgardeo organization:
-1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo
-2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that,
- 1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
- 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope.
- 
- 2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy
-
-#### Configure the Auth Proxy
-
-Create a configuration file config.yaml with the following parameters:
+1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo
+2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
+ 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
+ 
+
+3. Update `config.yaml` with the following parameters.
```yaml
-mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
-listen_address: ":8080" # Address where the proxy will listen
+base_url: "http://localhost:8000" # URL of your MCP server
+listen_port: 8080 # Address where the proxy will listen
asgardeo:
org_name: "" # Your Asgardeo org name
@@ -83,31 +90,137 @@ asgardeo:
client_secret: "" # Client secret of the M2M app
```
-#### Start the Auth Proxy
+4. Start the proxy with Asgardeo integration:
```bash
./openmcpauthproxy --asgardeo
```
-### Use with any standard OAuth Server
+### Other OAuth Providers
-Enable authorization for the MCP server with a compliant OAuth server
+- [Auth0](docs/integrations/Auth0.md)
+- [Keycloak](docs/integrations/keycloak.md)
-#### Configuration
+# Advanced Configuration
-Create a configuration file config.yaml with the following parameters:
+### Transport Modes
-```yaml
-mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
-listen_address: ":8080" # Address where the proxy will listen
-```
-**TODO**: Update the configs for a standard OAuth Server.
+The proxy supports two transport modes:
-#### Start the Auth Proxy
+- **SSE Mode (Default)**: For Server-Sent Events transport
+- **stdio Mode**: For MCP servers that use stdio transport
+
+When using stdio mode, the proxy:
+- Starts an MCP server as a subprocess using the command specified in the configuration
+- Communicates with the subprocess through standard input/output (stdio)
+- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
+
+To use stdio mode:
```bash
-./openmcpauthproxy
+./openmcpauthproxy --demo --stdio
```
-#### Integrating with existing OAuth Providers
- - [Auth0](URL) - Enable authorization for the MCP server through your Auth0 organization. **TODO**: Add instructions under docs and link
+#### Example: Running an MCP Server as a Subprocess
+
+1. Configure stdio mode in your `config.yaml`:
+
+```yaml
+listen_port: 8080
+base_url: "http://localhost:8000"
+
+stdio:
+ enabled: true
+ user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server
+ env: # Environment variables (optional)
+ - "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT"
+
+# CORS configuration
+cors:
+ allowed_origins:
+ - "http://localhost:5173" # Origin of your client application
+ allowed_methods:
+ - "GET"
+ - "POST"
+ - "PUT"
+ - "DELETE"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
+ allow_credentials: true
+
+# Demo configuration for Asgardeo
+demo:
+ org_name: "openmcpauthdemo"
+ client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
+ client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
+```
+
+2. Run the proxy with stdio mode:
+
+```bash
+./openmcpauthproxy --demo
+```
+
+The proxy will:
+- Start the MCP server as a subprocess using the specified command
+- Handle all authorization requirements
+- Forward messages between clients and the server
+
+### Complete Configuration Reference
+
+```yaml
+# Common configuration
+listen_port: 8080
+base_url: "http://localhost:8000"
+port: 8000
+
+# Path configuration
+paths:
+ sse: "/sse"
+ messages: "/messages/"
+
+# Transport mode
+transport_mode: "sse" # Options: "sse" or "stdio"
+
+# stdio-specific configuration (used only in stdio mode)
+stdio:
+ enabled: true
+ user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed)
+ work_dir: "" # Optional working directory for the subprocess
+
+# CORS configuration
+cors:
+ allowed_origins:
+ - "http://localhost:5173"
+ allowed_methods:
+ - "GET"
+ - "POST"
+ - "PUT"
+ - "DELETE"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
+ allow_credentials: true
+
+# Demo configuration for Asgardeo
+demo:
+ org_name: "openmcpauthdemo"
+ client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
+ client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
+
+# Asgardeo configuration (used with --asgardeo flag)
+asgardeo:
+ org_name: ""
+ client_id: ""
+ client_secret: ""
+```
+
+### Build from source
+
+```bash
+git clone https://github.com/wso2/open-mcp-auth-proxy
+cd open-mcp-auth-proxy
+go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
+go build -o openmcpauthproxy ./cmd/proxy
+```
diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go
index 9a4b472..c43dd7d 100644
--- a/cmd/proxy/main.go
+++ b/cmd/proxy/main.go
@@ -3,71 +3,132 @@ package main
import (
"flag"
"fmt"
- "log"
"net/http"
"os"
"os/signal"
+ "syscall"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/constants"
+ logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
+ "github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
func main() {
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
+ asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
+ debugMode := flag.Bool("debug", false, "Enable debug logging")
+ stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE")
flag.Parse()
+ logger.SetDebug(*debugMode)
+
// 1. Load config
cfg, err := config.LoadConfig("config.yaml")
if err != nil {
- log.Fatalf("Error loading config: %v", err)
+ logger.Error("Error loading config: %v", err)
+ os.Exit(1)
}
- // 2. Create the chosen provider
+ // Override transport mode if stdio flag is set
+ if *stdioMode {
+ cfg.TransportMode = config.StdioTransport
+ // Ensure stdio is enabled
+ cfg.Stdio.Enabled = true
+ // Re-validate config
+ if err := cfg.Validate(); err != nil {
+ logger.Error("Configuration error: %v", err)
+ os.Exit(1)
+ }
+ }
+
+ logger.Info("Using transport mode: %s", cfg.TransportMode)
+ logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
+ logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
+
+ // 2. Start subprocess if configured and in stdio mode
+ var procManager *subprocess.Manager
+ if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
+ // Ensure all required dependencies are available
+ if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
+ logger.Warn("%v", err)
+ logger.Warn("Subprocess may fail to start due to missing dependencies")
+ }
+
+ procManager = subprocess.NewManager()
+ if err := procManager.Start(cfg); err != nil {
+ logger.Warn("Failed to start subprocess: %v", err)
+ }
+ } else if cfg.TransportMode == config.SSETransport {
+ logger.Info("Using SSE transport mode, not starting subprocess")
+ }
+
+ // 3. Create the chosen provider
var provider authz.Provider
if *demoMode {
- cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2"
- cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks"
+ cfg.Mode = "demo"
+ cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
+ cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
+ provider = authz.NewAsgardeoProvider(cfg)
+ } else if *asgardeoMode {
+ cfg.Mode = "asgardeo"
+ cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
+ cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
provider = authz.NewAsgardeoProvider(cfg)
- fmt.Println("Using Asgardeo provider (demo).")
} else {
- log.Fatalf("Not supported yet.")
+ cfg.Mode = "default"
+ cfg.JWKSURL = cfg.Default.JWKSURL
+ cfg.AuthServerBaseURL = cfg.Default.BaseURL
+ provider = authz.NewDefaultProvider(cfg)
}
- // 3. (Optional) Fetch JWKS if you want local JWT validation
+ // 4. (Optional) Fetch JWKS if you want local JWT validation
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
- log.Fatalf("Failed to fetch JWKS: %v", err)
+ logger.Error("Failed to fetch JWKS: %v", err)
+ os.Exit(1)
}
- // 4. Build the main router
+ // 5. Build the main router
mux := proxy.NewRouter(cfg, provider)
- // 5. Start the server
+ listen_address := fmt.Sprintf("0.0.0.0:%d", cfg.ListenPort)
+
+ // 6. Start the server
srv := &http.Server{
- Addr: cfg.ListenAddress,
+ Addr: listen_address,
Handler: mux,
}
go func() {
- log.Printf("Server listening on %s", cfg.ListenAddress)
+ logger.Info("Server listening on %s", listen_address)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- log.Fatalf("Server error: %v", err)
+ logger.Error("Server error: %v", err)
+ os.Exit(1)
}
}()
- // 6. Graceful shutdown on Ctrl+C
+ // 7. Wait for shutdown signal
stop := make(chan os.Signal, 1)
- signal.Notify(stop, os.Interrupt)
+ signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
<-stop
- log.Println("Shutting down...")
+ logger.Info("Shutting down...")
+ // 8. First terminate subprocess if running
+ if procManager != nil && procManager.IsRunning() {
+ procManager.Shutdown()
+ }
+
+ // 9. Then shutdown the server
+ logger.Info("Shutting down HTTP server...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
- log.Printf("Shutdown error: %v", err)
+ logger.Error("HTTP server shutdown error: %v", err)
}
- log.Println("Stopped.")
+ logger.Info("Stopped.")
}
diff --git a/config.yaml b/config.yaml
index 9725f14..ef70fbb 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,18 +1,65 @@
# config.yaml
-auth_server_base_url: ""
-mcp_server_base_url: "http://localhost:8000"
-listen_address: ":8080"
-jwks_url: ""
+# Common configuration for all transport modes
+listen_port: 8080
+base_url: "http://localhost:8000" # Base URL for the MCP server
+port: 8000 # Port for the MCP server
timeout_seconds: 10
-mcp_paths:
- - /messages/
- - /sse
+# Transport mode configuration
+transport_mode: "stdio" # Options: "sse" or "stdio"
+
+# stdio-specific configuration (used only when transport_mode is "stdio")
+stdio:
+ enabled: true
+ user_command: uvx mcp-server-time --local-timezone=Europe/Zurich
+ #user_command: "npx -y @modelcontextprotocol/server-github"
+ work_dir: "" # Working directory (optional)
+ # env: # Environment variables (optional)
+ # - "NODE_ENV=development"
+
+# CORS settings
+cors:
+ allowed_origins:
+ - "http://localhost:6274" # Origin of your frontend/client app
+ allowed_methods:
+ - "GET"
+ - "POST"
+ - "PUT"
+ - "DELETE"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
+ - "mcp-protocol-version"
+ allow_credentials: true
+
+# Keycloak endpoint path mappings
path_mapping:
+ sse: "/sse" # SSE endpoint path
+ messages: "/messages/" # Messages endpoint path
+ /token: /realms/master/protocol/openid-connect/token
+ /register: /realms/master/clients-registrations/openid-connect
-demo:
- org_name: "openmcpauthdemo"
- client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
- client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
+# Keycloak configuration block
+default:
+ base_url: "https://iam.phoenix-systems.ch"
+ jwks_url: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs"
+ path:
+ /.well-known/oauth-authorization-server:
+ response:
+ issuer: "https://iam.phoenix-systems.ch/realms/kvant"
+ jwks_uri: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs"
+ authorization_endpoint: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/auth"
+ response_types_supported:
+ - "code"
+ grant_types_supported:
+ - "authorization_code"
+ - "refresh_token"
+ code_challenge_methods_supported:
+ - "S256"
+ - "plain"
+ /token:
+ addBodyParams:
+ - name: "audience"
+ value: "mcp_proxy"
\ No newline at end of file
diff --git a/docs/integrations/Auth0.md b/docs/integrations/Auth0.md
new file mode 100644
index 0000000..9195659
--- /dev/null
+++ b/docs/integrations/Auth0.md
@@ -0,0 +1,93 @@
+## Integrating with Auth0
+
+This guide will help you configure Open MCP Auth Proxy to use Auth0 as your identity provider.
+
+### Prerequisites
+
+- An Auth0 organization (sign up [here](https://auth0.com) if you don't have one)
+- Open MCP Auth Proxy installed
+
+### Setting Up Auth0
+1. [Enable Dynamic Client Registration](https://auth0.com/docs/get-started/applications/dynamic-client-registration)
+ - Go to your Auth0 dashboard
+ - Navigate to Settings > Advanced
+ - Enable "OIDC Dynamic Application Registration"
+2. In order to setup connections in dynamically created clients [promote Connections to Domain Level](https://auth0.com/docs/authenticate/identity-providers/promote-connections-to-domain-level)
+3. Create an API in Auth0:
+ - Go to your Auth0 dashboard
+ - Navigate to Applications > APIs
+ - Click on "Create API"
+ - Set a Name (e.g., "MCP API")
+ - Set an Identifier (e.g., "mcp_proxy")
+ - Keep the default signing algorithm (RS256)
+ - Click "Create"
+
+### Configuring the Open MCP Auth Proxy
+
+Update your `config.yaml` with Auth0 settings:
+
+```yaml
+# Basic proxy configuration
+listen_port: 8080
+base_url: "http://localhost:8000"
+port: 8000
+
+# Path configuration
+paths:
+ sse: "/sse"
+ messages: "/messages/"
+
+# Transport mode
+transport_mode: "sse"
+
+# CORS configuration
+cors:
+ allowed_origins:
+ - "http://localhost:5173" # Your client application origin
+ allowed_methods:
+ - "GET"
+ - "POST"
+ - "PUT"
+ - "DELETE"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
+ allow_credentials: true
+
+# Path mappings for Auth0 endpoints
+path_mapping:
+ /token: /oauth/token
+ /register: /oidc/register
+
+# Auth0 configuration
+default:
+ base_url: "https://YOUR_AUTH0_DOMAIN" # e.g., https://dev-123456.us.auth0.com
+ jwks_url: "https://YOUR_AUTH0_DOMAIN/.well-known/jwks.json"
+ path:
+ /.well-known/oauth-authorization-server:
+ response:
+ issuer: "https://YOUR_AUTH0_DOMAIN/"
+ jwks_uri: "https://YOUR_AUTH0_DOMAIN/.well-known/jwks.json"
+ authorization_endpoint: "https://YOUR_AUTH0_DOMAIN/authorize?audience=mcp_proxy" # Only if you created an API with this identifier
+ response_types_supported:
+ - "code"
+ grant_types_supported:
+ - "authorization_code"
+ - "refresh_token"
+ code_challenge_methods_supported:
+ - "S256"
+ - "plain"
+ /token:
+ addBodyParams:
+ - name: "audience"
+ value: "mcp_proxy" # Only if you created an API with this identifier
+```
+
+Replace YOUR_AUTH0_DOMAIN with your Auth0 domain (e.g., dev-abc123.us.auth0.com).
+
+## Starting the Proxy with Auth0 Integration
+Start the proxy in default mode (which will use Auth0 based on your configuration):
+
+```bash
+./openmcpauthproxy
+```
diff --git a/docs/integrations/keycloak.md b/docs/integrations/keycloak.md
new file mode 100644
index 0000000..5e338cc
--- /dev/null
+++ b/docs/integrations/keycloak.md
@@ -0,0 +1,92 @@
+## Integrating Open MCP Auth Proxy with Keycloak
+
+This guide walks you through configuring the Open MCP Auth Proxy to authenticate using Keycloak as the identity provider.
+
+---
+
+### Prerequisites
+
+Before you begin, ensure you have the following:
+
+- A running Keycloak instance
+- Open MCP Auth Proxy installed and accessible
+
+---
+
+### Step 1: Configure Keycloak for Client Registration
+
+Set up dynamic client registration in your Keycloak realm by following the [Keycloak client registration guide](https://www.keycloak.org/securing-apps/client-registration).
+
+---
+
+### Step 2: Configure Open MCP Auth Proxy
+
+Update the `config.yaml` file in your Open MCP Auth Proxy setup using your Keycloak realm's [OIDC settings](https://www.keycloak.org/securing-apps/oidc-layers). Below is an example configuration:
+
+```yaml
+# Proxy server configuration
+listen_port: 8081 # Port for the auth proxy
+base_url: "http://localhost:8000" # Base URL of the MCP server
+port: 8000 # MCP server port
+
+# Define path mappings
+paths:
+ sse: "/sse"
+ messages: "/messages/"
+
+# Set the transport mode
+transport_mode: "sse"
+
+# CORS settings
+cors:
+ allowed_origins:
+ - "http://localhost:5173" # Origin of your frontend/client app
+ allowed_methods:
+ - "GET"
+ - "POST"
+ - "PUT"
+ - "DELETE"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
+ - "mcp-protocol-version"
+ allow_credentials: true
+
+# Keycloak endpoint path mappings
+path_mapping:
+ /token: /realms/master/protocol/openid-connect/token
+ /register: /realms/master/clients-registrations/openid-connect
+
+# Keycloak configuration block
+default:
+ base_url: "http://localhost:8080"
+ jwks_url: "http://localhost:8080/realms/master/protocol/openid-connect/certs"
+ path:
+ /.well-known/oauth-authorization-server:
+ response:
+ issuer: "http://localhost:8080/realms/master"
+ jwks_uri: "http://localhost:8080/realms/master/protocol/openid-connect/certs"
+ authorization_endpoint: "http://localhost:8080/realms/master/protocol/openid-connect/auth"
+ response_types_supported:
+ - "code"
+ grant_types_supported:
+ - "authorization_code"
+ - "refresh_token"
+ code_challenge_methods_supported:
+ - "S256"
+ - "plain"
+ /token:
+ addBodyParams:
+ - name: "audience"
+ value: "mcp_proxy"
+```
+
+### Step 3: Start the Auth Proxy
+
+Launch the proxy with the updated Keycloak configuration:
+
+```bash
+./openmcpauthproxy
+```
+
+Once running, the proxy will handle authentication requests through your configured Keycloak realm.
diff --git a/go.mod b/go.mod
index 2d26216..0bceb4f 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/wso2/open-mcp-auth-proxy
-go 1.22.3
+go 1.21
require (
github.com/golang-jwt/jwt/v4 v4.5.2
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..9d27ad1
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,6 @@
+github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
+github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go
index 7408f79..a3c812c 100644
--- a/internal/authz/asgardeo.go
+++ b/internal/authz/asgardeo.go
@@ -7,13 +7,13 @@ import (
"encoding/json"
"fmt"
"io"
- "log"
"math/rand"
"net/http"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type asgardeoProvider struct {
@@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
+ w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
}
w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("X-Accel-Buffering", "no")
if err := json.NewEncoder(w).Encode(response); err != nil {
- log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
+ logger.Error("Error encoding well-known: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
@@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
+ w.Header().Set("X-Accel-Buffering", "no")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
var regReq RegisterRequest
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
- log.Printf("ERROR: reading register request: %v", err)
+ logger.Error("Reading register request: %v", err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
@@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
regReq.ClientSecret = randomString(16)
if err := p.createAsgardeoApplication(regReq); err != nil {
- log.Printf("WARN: Asgardeo application creation failed: %v", err)
+ logger.Warn("Asgardeo application creation failed: %v", err)
// Optionally http.Error(...) if you want to fail
// or continue to return partial data.
}
@@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
}
w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(resp); err != nil {
- log.Printf("ERROR: encoding /register response: %v", err)
+ logger.Error("Encoding /register response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
@@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
}
- log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
+ logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
return nil
}
@@ -202,8 +206,11 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ // Sensitive data - should not be logged at INFO level
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
+
+ logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
@@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
return "", fmt.Errorf("failed to parse token JSON: %w", err)
}
+ // Don't log the actual token at info level, only at debug level
+ logger.Debug("Received access token: %s", tokenResp.AccessToken)
+ logger.Info("Successfully obtained admin token from Asgardeo")
+
return tokenResp.AccessToken, nil
}
diff --git a/internal/authz/default.go b/internal/authz/default.go
new file mode 100644
index 0000000..929f586
--- /dev/null
+++ b/internal/authz/default.go
@@ -0,0 +1,96 @@
+package authz
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
+)
+
+type defaultProvider struct {
+ cfg *config.Config
+}
+
+// NewDefaultProvider initializes a Provider for Asgardeo (demo mode).
+func NewDefaultProvider(cfg *config.Config) Provider {
+ return &defaultProvider{cfg: cfg}
+}
+
+func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+ w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
+ w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
+
+ if r.Method == http.MethodOptions {
+ w.WriteHeader(http.StatusNoContent)
+ return
+ }
+
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // Check if we have a custom response configuration
+ if p.cfg.Default.Path != nil {
+ pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"]
+ if exists && pathConfig.Response != nil {
+ // Use configured response values
+ responseConfig := pathConfig.Response
+
+ // Get current host for proxy endpoints
+ scheme := "http"
+ if r.TLS != nil {
+ scheme = "https"
+ }
+ if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
+ scheme = forwardedProto
+ }
+ host := r.Host
+ if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
+ host = forwardedHost
+ }
+ baseURL := scheme + "://" + host
+
+ authorizationEndpoint := responseConfig.AuthorizationEndpoint
+ if authorizationEndpoint == "" {
+ authorizationEndpoint = baseURL + "/authorize"
+ }
+ tokenEndpoint := responseConfig.TokenEndpoint
+ if tokenEndpoint == "" {
+ tokenEndpoint = baseURL + "/token"
+ }
+ registraionEndpoint := responseConfig.RegistrationEndpoint
+ if registraionEndpoint == "" {
+ registraionEndpoint = baseURL + "/register"
+ }
+
+ // Build response from config
+ response := map[string]interface{}{
+ "issuer": responseConfig.Issuer,
+ "authorization_endpoint": authorizationEndpoint,
+ "token_endpoint": tokenEndpoint,
+ "jwks_uri": responseConfig.JwksURI,
+ "response_types_supported": responseConfig.ResponseTypesSupported,
+ "grant_types_supported": responseConfig.GrantTypesSupported,
+ "token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
+ "registration_endpoint": registraionEndpoint,
+ "code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ logger.Error("Error encoding well-known response: %v", err)
+ http.Error(w, "Internal server error", http.StatusInternalServerError)
+ }
+ return
+ }
+ }
+ }
+}
+
+func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
+ return nil
+}
diff --git a/internal/authz/default_test.go b/internal/authz/default_test.go
new file mode 100644
index 0000000..f40030f
--- /dev/null
+++ b/internal/authz/default_test.go
@@ -0,0 +1,125 @@
+package authz
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+)
+
+func TestNewDefaultProvider(t *testing.T) {
+ cfg := &config.Config{}
+ provider := NewDefaultProvider(cfg)
+
+ if provider == nil {
+ t.Fatal("Expected non-nil provider")
+ }
+
+ // Ensure it implements the Provider interface
+ var _ Provider = provider
+}
+
+func TestDefaultProviderWellKnownHandler(t *testing.T) {
+ // Create a config with a custom well-known response
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ Path: map[string]config.PathConfig{
+ "/.well-known/oauth-authorization-server": {
+ Response: &config.ResponseConfig{
+ Issuer: "https://test-issuer.com",
+ JwksURI: "https://test-issuer.com/jwks",
+ ResponseTypesSupported: []string{"code"},
+ GrantTypesSupported: []string{"authorization_code"},
+ CodeChallengeMethodsSupported: []string{"S256"},
+ },
+ },
+ },
+ },
+ }
+
+ provider := NewDefaultProvider(cfg)
+ handler := provider.WellKnownHandler()
+
+ // Create a test request
+ req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)
+ req.Host = "test-host.com"
+ req.Header.Set("X-Forwarded-Proto", "https")
+
+ // Create a response recorder
+ w := httptest.NewRecorder()
+
+ // Call the handler
+ handler(w, req)
+
+ // Check response status
+ if w.Code != http.StatusOK {
+ t.Errorf("Expected status OK, got %v", w.Code)
+ }
+
+ // Verify content type
+ contentType := w.Header().Get("Content-Type")
+ if contentType != "application/json" {
+ t.Errorf("Expected Content-Type: application/json, got %s", contentType)
+ }
+
+ // Decode and check the response body
+ var response map[string]interface{}
+ if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
+ t.Fatalf("Failed to decode response JSON: %v", err)
+ }
+
+ // Check expected values
+ if response["issuer"] != "https://test-issuer.com" {
+ t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"])
+ }
+ if response["jwks_uri"] != "https://test-issuer.com/jwks" {
+ t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"])
+ }
+ if response["authorization_endpoint"] != "https://test-host.com/authorize" {
+ t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"])
+ }
+}
+
+func TestDefaultProviderHandleOPTIONS(t *testing.T) {
+ provider := NewDefaultProvider(&config.Config{})
+ handler := provider.WellKnownHandler()
+
+ // Create OPTIONS request
+ req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil)
+ w := httptest.NewRecorder()
+
+ // Call the handler
+ handler(w, req)
+
+ // Check response
+ if w.Code != http.StatusNoContent {
+ t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code)
+ }
+
+ // Check CORS headers
+ if w.Header().Get("Access-Control-Allow-Origin") != "*" {
+ t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin"))
+ }
+ if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" {
+ t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods"))
+ }
+}
+
+func TestDefaultProviderInvalidMethod(t *testing.T) {
+ provider := NewDefaultProvider(&config.Config{})
+ handler := provider.WellKnownHandler()
+
+ // Create POST request (which should be rejected)
+ req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil)
+ w := httptest.NewRecorder()
+
+ // Call the handler
+ handler(w, req)
+
+ // Check response
+ if w.Code != http.StatusMethodNotAllowed {
+ t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code)
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 3a3b231..fc6743c 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,29 +1,155 @@
package config
import (
+ "fmt"
"os"
"gopkg.in/yaml.v2"
)
-// AsgardeoConfig groups all Asgardeo-specific fields
+// Transport mode for MCP server
+type TransportMode string
+
+const (
+ SSETransport TransportMode = "sse"
+ StdioTransport TransportMode = "stdio"
+)
+
+// Common path configuration for all transport modes
+type PathsConfig struct {
+ SSE string `yaml:"sse"`
+ Messages string `yaml:"messages"`
+}
+
+// StdioConfig contains stdio-specific configuration
+type StdioConfig struct {
+ Enabled bool `yaml:"enabled"`
+ UserCommand string `yaml:"user_command"` // The command provided by the user
+ WorkDir string `yaml:"work_dir"` // Working directory (optional)
+ Args []string `yaml:"args,omitempty"` // Additional arguments
+ Env []string `yaml:"env,omitempty"` // Environment variables
+}
+
type DemoConfig struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
OrgName string `yaml:"org_name"`
}
+type AsgardeoConfig struct {
+ ClientID string `yaml:"client_id"`
+ ClientSecret string `yaml:"client_secret"`
+ OrgName string `yaml:"org_name"`
+}
+
+type CORSConfig struct {
+ AllowedOrigins []string `yaml:"allowed_origins"`
+ AllowedMethods []string `yaml:"allowed_methods"`
+ AllowedHeaders []string `yaml:"allowed_headers"`
+ AllowCredentials bool `yaml:"allow_credentials"`
+}
+
+type ParamConfig struct {
+ Name string `yaml:"name"`
+ Value string `yaml:"value"`
+}
+
+type ResponseConfig struct {
+ Issuer string `yaml:"issuer,omitempty"`
+ JwksURI string `yaml:"jwks_uri,omitempty"`
+ AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"`
+ TokenEndpoint string `yaml:"token_endpoint,omitempty"`
+ RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"`
+ ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"`
+ GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"`
+ CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
+}
+
+type PathConfig struct {
+ // For well-known endpoint
+ Response *ResponseConfig `yaml:"response,omitempty"`
+
+ // For authorization endpoint
+ AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"`
+
+ // For token and register endpoints
+ AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"`
+}
+
+type DefaultConfig struct {
+ BaseURL string `yaml:"base_url,omitempty"`
+ Path map[string]PathConfig `yaml:"path,omitempty"`
+ JWKSURL string `yaml:"jwks_url,omitempty"`
+}
+
type Config struct {
- AuthServerBaseURL string `yaml:"auth_server_base_url"`
- MCPServerBaseURL string `yaml:"mcp_server_base_url"`
- ListenAddress string `yaml:"listen_address"`
- JWKSURL string `yaml:"jwks_url"`
- TimeoutSeconds int `yaml:"timeout_seconds"`
- MCPPaths []string `yaml:"mcp_paths"`
- PathMapping map[string]string `yaml:"path_mapping"`
+ AuthServerBaseURL string
+ ListenPort int `yaml:"listen_port"`
+ BaseURL string `yaml:"base_url"`
+ Port int `yaml:"port"`
+ JWKSURL string
+ TimeoutSeconds int `yaml:"timeout_seconds"`
+ PathMapping map[string]string `yaml:"path_mapping"`
+ Mode string `yaml:"mode"`
+ CORSConfig CORSConfig `yaml:"cors"`
+ TransportMode TransportMode `yaml:"transport_mode"`
+ Paths PathsConfig `yaml:"paths"`
+ Stdio StdioConfig `yaml:"stdio"`
// Nested config for Asgardeo
- Demo DemoConfig `yaml:"demo"`
+ Demo DemoConfig `yaml:"demo"`
+ Asgardeo AsgardeoConfig `yaml:"asgardeo"`
+ Default DefaultConfig `yaml:"default"`
+}
+
+// Validate checks if the config is valid based on transport mode
+func (c *Config) Validate() error {
+ // Validate based on transport mode
+ if c.TransportMode == StdioTransport {
+ if !c.Stdio.Enabled {
+ return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
+ }
+ if c.Stdio.UserCommand == "" {
+ return fmt.Errorf("stdio.user_command is required in stdio transport mode")
+ }
+ }
+
+ // Validate paths
+ if c.Paths.SSE == "" {
+ c.Paths.SSE = "/sse" // Default value
+ }
+ if c.Paths.Messages == "" {
+ c.Paths.Messages = "/messages" // Default value
+ }
+
+ // Validate base URL
+ if c.BaseURL == "" {
+ if c.Port > 0 {
+ c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
+ } else {
+ c.BaseURL = "http://localhost:8000" // Default value
+ }
+ }
+
+ return nil
+}
+
+// GetMCPPaths returns the list of paths that should be proxied to the MCP server
+func (c *Config) GetMCPPaths() []string {
+ return []string{c.Paths.SSE, c.Paths.Messages}
+}
+
+// BuildExecCommand constructs the full command string for execution in stdio mode
+func (c *Config) BuildExecCommand() string {
+ if c.Stdio.UserCommand == "" {
+ return ""
+ }
+
+ // Construct the full command
+ return fmt.Sprintf(
+ `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
+ c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
+ )
}
// LoadConfig reads a YAML config file into Config struct.
@@ -39,8 +165,26 @@ func LoadConfig(path string) (*Config, error) {
if err := decoder.Decode(&cfg); err != nil {
return nil, err
}
+
+ // Set default values
if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default
}
+
+ // Set default transport mode if not specified
+ if cfg.TransportMode == "" {
+ cfg.TransportMode = SSETransport // Default to SSE
+ }
+
+ // Set default port if not specified
+ if cfg.Port == 0 {
+ cfg.Port = 8000 // default
+ }
+
+ // Validate the configuration
+ if err := cfg.Validate(); err != nil {
+ return nil, err
+ }
+
return &cfg, nil
}
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000..20c0893
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,196 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestLoadConfig(t *testing.T) {
+ // Create a temporary config file
+ tempDir := t.TempDir()
+ configPath := filepath.Join(tempDir, "test_config.yaml")
+
+ // Basic valid config
+ validConfig := `
+listen_port: 8080
+base_url: "http://localhost:8000"
+transport_mode: "sse"
+paths:
+ sse: "/sse"
+ messages: "/messages"
+cors:
+ allowed_origins:
+ - "http://localhost:5173"
+ allowed_methods:
+ - "GET"
+ - "POST"
+ allowed_headers:
+ - "Authorization"
+ - "Content-Type"
+ allow_credentials: true
+`
+ err := os.WriteFile(configPath, []byte(validConfig), 0644)
+ if err != nil {
+ t.Fatalf("Failed to create test config file: %v", err)
+ }
+
+ // Test loading the valid config
+ cfg, err := LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("Failed to load valid config: %v", err)
+ }
+
+ // Verify expected values from the config
+ if cfg.ListenPort != 8080 {
+ t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort)
+ }
+ if cfg.BaseURL != "http://localhost:8000" {
+ t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL)
+ }
+ if cfg.TransportMode != SSETransport {
+ t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode)
+ }
+ if cfg.Paths.SSE != "/sse" {
+ t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE)
+ }
+ if cfg.Paths.Messages != "/messages" {
+ t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages)
+ }
+
+ // Test default values
+ if cfg.TimeoutSeconds != 15 {
+ t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds)
+ }
+ if cfg.Port != 8000 {
+ t.Errorf("Expected default Port=8000, got %d", cfg.Port)
+ }
+}
+
+func TestValidate(t *testing.T) {
+ tests := []struct {
+ name string
+ config Config
+ expectError bool
+ }{
+ {
+ name: "Valid SSE config",
+ config: Config{
+ TransportMode: SSETransport,
+ Paths: PathsConfig{
+ SSE: "/sse",
+ Messages: "/messages",
+ },
+ BaseURL: "http://localhost:8000",
+ },
+ expectError: false,
+ },
+ {
+ name: "Valid stdio config",
+ config: Config{
+ TransportMode: StdioTransport,
+ Stdio: StdioConfig{
+ Enabled: true,
+ UserCommand: "some-command",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Invalid stdio config - not enabled",
+ config: Config{
+ TransportMode: StdioTransport,
+ Stdio: StdioConfig{
+ Enabled: false,
+ UserCommand: "some-command",
+ },
+ },
+ expectError: true,
+ },
+ {
+ name: "Invalid stdio config - no command",
+ config: Config{
+ TransportMode: StdioTransport,
+ Stdio: StdioConfig{
+ Enabled: true,
+ UserCommand: "",
+ },
+ },
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ err := tc.config.Validate()
+ if tc.expectError && err == nil {
+ t.Errorf("Expected validation error but got none")
+ }
+ if !tc.expectError && err != nil {
+ t.Errorf("Expected no validation error but got: %v", err)
+ }
+ })
+ }
+}
+
+func TestGetMCPPaths(t *testing.T) {
+ cfg := Config{
+ Paths: PathsConfig{
+ SSE: "/custom-sse",
+ Messages: "/custom-messages",
+ },
+ }
+
+ paths := cfg.GetMCPPaths()
+ if len(paths) != 2 {
+ t.Errorf("Expected 2 MCP paths, got %d", len(paths))
+ }
+ if paths[0] != "/custom-sse" {
+ t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
+ }
+ if paths[1] != "/custom-messages" {
+ t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
+ }
+}
+
+func TestBuildExecCommand(t *testing.T) {
+ tests := []struct {
+ name string
+ config Config
+ expectedResult string
+ }{
+ {
+ name: "Valid command",
+ config: Config{
+ Stdio: StdioConfig{
+ UserCommand: "test-command",
+ },
+ Port: 8080,
+ BaseURL: "http://example.com",
+ Paths: PathsConfig{
+ SSE: "/sse-path",
+ Messages: "/msgs",
+ },
+ },
+ expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`,
+ },
+ {
+ name: "Empty command",
+ config: Config{
+ Stdio: StdioConfig{
+ UserCommand: "",
+ },
+ },
+ expectedResult: "",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ result := tc.config.BuildExecCommand()
+ if result != tc.expectedResult {
+ t.Errorf("Expected command=%s, got %s", tc.expectedResult, result)
+ }
+ })
+ }
+}
diff --git a/internal/constants/constants.go b/internal/constants/constants.go
new file mode 100644
index 0000000..1e5808e
--- /dev/null
+++ b/internal/constants/constants.go
@@ -0,0 +1,7 @@
+package constants
+
+// Package constant provides constants for the MCP Auth Proxy
+
+const (
+ ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
+)
diff --git a/internal/logging/logger.go b/internal/logging/logger.go
new file mode 100644
index 0000000..57bec27
--- /dev/null
+++ b/internal/logging/logger.go
@@ -0,0 +1,34 @@
+package logger
+
+import (
+ "log"
+)
+
+var isDebug = false
+
+// SetDebug enables or disables debug logging
+func SetDebug(debug bool) {
+ isDebug = debug
+}
+
+// Debug logs a debug-level message
+func Debug(format string, v ...interface{}) {
+ if isDebug {
+ log.Printf("DEBUG: "+format, v...)
+ }
+}
+
+// Info logs an info-level message
+func Info(format string, v ...interface{}) {
+ log.Printf("INFO: "+format, v...)
+}
+
+// Warn logs a warning-level message
+func Warn(format string, v ...interface{}) {
+ log.Printf("WARN: "+format, v...)
+}
+
+// Error logs an error-level message
+func Error(format string, v ...interface{}) {
+ log.Printf("ERROR: "+format, v...)
+}
diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go
new file mode 100644
index 0000000..6662b2c
--- /dev/null
+++ b/internal/proxy/modifier.go
@@ -0,0 +1,204 @@
+package proxy
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
+)
+
+// RequestModifier modifies requests before they are proxied
+type RequestModifier interface {
+ ModifyRequest(req *http.Request) (*http.Request, error)
+}
+
+// AuthorizationModifier adds parameters to authorization requests
+type AuthorizationModifier struct {
+ Config *config.Config
+}
+
+// TokenModifier adds parameters to token requests
+type TokenModifier struct {
+ Config *config.Config
+}
+
+type RegisterModifier struct {
+ Config *config.Config
+}
+
+// ModifyRequest adds configured parameters to authorization requests
+func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
+ // Check if we have parameters to add
+ if m.Config.Default.Path == nil {
+ return req, nil
+ }
+
+ pathConfig, exists := m.Config.Default.Path["/authorize"]
+ if !exists || len(pathConfig.AddQueryParams) == 0 {
+ return req, nil
+ }
+ // Get current query parameters
+ query := req.URL.Query()
+
+ // Add parameters from config
+ for _, param := range pathConfig.AddQueryParams {
+ query.Set(param.Name, param.Value)
+ }
+
+ // Update the request URL
+ req.URL.RawQuery = query.Encode()
+
+ return req, nil
+}
+
+// ModifyRequest adds configured parameters to token requests
+func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
+ // Only modify POST requests
+ if req.Method != http.MethodPost {
+ return req, nil
+ }
+
+ // Check if we have parameters to add
+ if m.Config.Default.Path == nil {
+ return req, nil
+ }
+
+ pathConfig, exists := m.Config.Default.Path["/token"]
+ if !exists || len(pathConfig.AddBodyParams) == 0 {
+ return req, nil
+ }
+
+ contentType := req.Header.Get("Content-Type")
+
+ if strings.Contains(contentType, "application/x-www-form-urlencoded") {
+ // Parse form data
+ if err := req.ParseForm(); err != nil {
+ return nil, err
+ }
+
+ // Clone form data
+ formData := req.PostForm
+
+ // Add configured parameters
+ for _, param := range pathConfig.AddBodyParams {
+ formData.Set(param.Name, param.Value)
+ }
+
+ // Create new request body with modified form
+ formEncoded := formData.Encode()
+ req.Body = io.NopCloser(strings.NewReader(formEncoded))
+ req.ContentLength = int64(len(formEncoded))
+ req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded)))
+
+ } else if strings.Contains(contentType, "application/json") {
+ // Read body
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse JSON
+ var jsonData map[string]interface{}
+ if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
+ return nil, err
+ }
+
+ // Add parameters
+ for _, param := range pathConfig.AddBodyParams {
+ jsonData[param.Name] = param.Value
+ }
+
+ // Marshal back to JSON
+ modifiedBody, err := json.Marshal(jsonData)
+ if err != nil {
+ return nil, err
+ }
+
+ // Update request
+ req.Body = io.NopCloser(bytes.NewReader(modifiedBody))
+ req.ContentLength = int64(len(modifiedBody))
+ req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
+ }
+
+ return req, nil
+}
+
+func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
+ // Only modify POST requests
+ if req.Method != http.MethodPost {
+ return req, nil
+ }
+
+ // Check if we have parameters to add
+ if m.Config.Default.Path == nil {
+ return req, nil
+ }
+
+ pathConfig, exists := m.Config.Default.Path["/register"]
+ if !exists || len(pathConfig.AddBodyParams) == 0 {
+ return req, nil
+ }
+
+ contentType := req.Header.Get("Content-Type")
+
+ if strings.Contains(contentType, "application/x-www-form-urlencoded") {
+ // Parse form data
+ if err := req.ParseForm(); err != nil {
+ logger.Error("Failed to parse form data: %v", err)
+ return nil, err
+ }
+
+ // Clone form data
+ formData := req.PostForm
+
+ // Add configured parameters
+ for _, param := range pathConfig.AddBodyParams {
+ formData.Set(param.Name, param.Value)
+ }
+
+ // Create new request body with modified form
+ formEncoded := formData.Encode()
+ req.Body = io.NopCloser(strings.NewReader(formEncoded))
+ req.ContentLength = int64(len(formEncoded))
+ req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded)))
+
+ } else if strings.Contains(contentType, "application/json") {
+ // Read body
+ bodyBytes, err := io.ReadAll(req.Body)
+ if err != nil {
+ logger.Error("Failed to read request body: %v", err)
+ return nil, err
+ }
+
+ // Parse JSON
+ var jsonData map[string]interface{}
+ if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
+ logger.Error("Failed to parse JSON body: %v", err)
+ return nil, err
+ }
+
+ // Add parameters
+ for _, param := range pathConfig.AddBodyParams {
+ jsonData[param.Name] = param.Value
+ }
+
+ // Marshal back to JSON
+ modifiedBody, err := json.Marshal(jsonData)
+ if err != nil {
+ logger.Error("Failed to marshal modified JSON: %v", err)
+ return nil, err
+ }
+
+ // Update request
+ req.Body = io.NopCloser(bytes.NewReader(modifiedBody))
+ req.ContentLength = int64(len(modifiedBody))
+ req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
+ }
+
+ return req, nil
+}
diff --git a/internal/proxy/modifier_test.go b/internal/proxy/modifier_test.go
new file mode 100644
index 0000000..3d2fd44
--- /dev/null
+++ b/internal/proxy/modifier_test.go
@@ -0,0 +1,147 @@
+package proxy
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+)
+
+func TestAuthorizationModifier(t *testing.T) {
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ Path: map[string]config.PathConfig{
+ "/authorize": {
+ AddQueryParams: []config.ParamConfig{
+ {Name: "client_id", Value: "test-client-id"},
+ {Name: "scope", Value: "openid"},
+ },
+ },
+ },
+ },
+ }
+
+ modifier := &AuthorizationModifier{Config: cfg}
+
+ // Create a test request
+ req, err := http.NewRequest("GET", "/authorize?response_type=code", nil)
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+
+ // Modify the request
+ modifiedReq, err := modifier.ModifyRequest(req)
+ if err != nil {
+ t.Fatalf("ModifyRequest failed: %v", err)
+ }
+
+ // Check that the query parameters were added
+ query := modifiedReq.URL.Query()
+ if query.Get("client_id") != "test-client-id" {
+ t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id"))
+ }
+ if query.Get("scope") != "openid" {
+ t.Errorf("Expected scope=openid, got %s", query.Get("scope"))
+ }
+ if query.Get("response_type") != "code" {
+ t.Errorf("Expected response_type=code, got %s", query.Get("response_type"))
+ }
+}
+
+func TestTokenModifier(t *testing.T) {
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ Path: map[string]config.PathConfig{
+ "/token": {
+ AddBodyParams: []config.ParamConfig{
+ {Name: "audience", Value: "test-audience"},
+ },
+ },
+ },
+ },
+ }
+
+ modifier := &TokenModifier{Config: cfg}
+
+ // Create a test request with form data
+ form := url.Values{}
+
+ req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ // Modify the request
+ modifiedReq, err := modifier.ModifyRequest(req)
+ if err != nil {
+ t.Fatalf("ModifyRequest failed: %v", err)
+ }
+
+ body := make([]byte, 1024)
+ n, err := modifiedReq.Body.Read(body)
+ if err != nil && err.Error() != "EOF" {
+ t.Fatalf("Failed to read body: %v", err)
+ }
+ bodyStr := string(body[:n])
+
+ // Parse the form data from the modified request
+ if err := modifiedReq.ParseForm(); err != nil {
+ t.Fatalf("Failed to parse form data: %v", err)
+ }
+
+ // Check that the body parameters were added
+ if !strings.Contains(bodyStr, "audience") {
+ t.Errorf("Expected body to contain audience, got %s", bodyStr)
+ }
+}
+
+func TestRegisterModifier(t *testing.T) {
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ Path: map[string]config.PathConfig{
+ "/register": {
+ AddBodyParams: []config.ParamConfig{
+ {Name: "client_name", Value: "test-client"},
+ },
+ },
+ },
+ },
+ }
+
+ modifier := &RegisterModifier{Config: cfg}
+
+ // Create a test request with JSON data
+ jsonBody := `{"redirect_uris":["https://example.com/callback"]}`
+ req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody))
+ if err != nil {
+ t.Fatalf("Failed to create request: %v", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ // Modify the request
+ modifiedReq, err := modifier.ModifyRequest(req)
+ if err != nil {
+ t.Fatalf("ModifyRequest failed: %v", err)
+ }
+
+ // Read the body and check that it still contains the original data
+ // This test would need to be enhanced with a proper JSON parsing to verify
+ // the added parameters
+ body := make([]byte, 1024)
+ n, err := modifiedReq.Body.Read(body)
+ if err != nil && err.Error() != "EOF" {
+ t.Fatalf("Failed to read body: %v", err)
+ }
+ bodyStr := string(body[:n])
+
+ // Simple check to see if the modified body contains the expected fields
+ if !strings.Contains(bodyStr, "client_name") {
+ t.Errorf("Expected body to contain client_name, got %s", bodyStr)
+ }
+ if !strings.Contains(bodyStr, "redirect_uris") {
+ t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr)
+ }
+}
diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go
index 382c8f3..33a9ea3 100644
--- a/internal/proxy/proxy.go
+++ b/internal/proxy/proxy.go
@@ -2,7 +2,6 @@ package proxy
import (
"context"
- "log"
"net/http"
"net/http/httputil"
"net/url"
@@ -11,6 +10,7 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
@@ -20,78 +20,138 @@ import (
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
mux := http.NewServeMux()
- // 1. Custom well-known
- mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
+ modifiers := map[string]RequestModifier{
+ "/authorize": &AuthorizationModifier{Config: cfg},
+ "/token": &TokenModifier{Config: cfg},
+ "/register": &RegisterModifier{Config: cfg},
+ }
- // 2. Registration
- mux.HandleFunc("/register", provider.RegisterHandler())
+ registeredPaths := make(map[string]bool)
- // 3. Default "auth" paths, proxied
- defaultPaths := []string{"/authorize", "/token"}
+ var defaultPaths []string
+
+ // Handle based on mode configuration
+ if cfg.Mode == "demo" || cfg.Mode == "asgardeo" {
+ // Demo/Asgardeo mode: Custom handlers for well-known and register
+ mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
+ registeredPaths["/.well-known/oauth-authorization-server"] = true
+
+ mux.HandleFunc("/register", provider.RegisterHandler())
+ registeredPaths["/register"] = true
+
+ // Authorize and token will be proxied with parameter modification
+ defaultPaths = []string{"/authorize", "/token"}
+ } else {
+ // Default provider mode
+ if cfg.Default.Path != nil {
+ // Check if we have custom response for well-known
+ wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"]
+ if exists && wellKnownConfig.Response != nil {
+ // If there's a custom response defined, use our handler
+ mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
+ registeredPaths["/.well-known/oauth-authorization-server"] = true
+ } else {
+ // No custom response, add well-known to proxy paths
+ defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server")
+ }
+
+ defaultPaths = append(defaultPaths, "/authorize")
+ defaultPaths = append(defaultPaths, "/token")
+ defaultPaths = append(defaultPaths, "/register")
+ } else {
+ defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"}
+ }
+ }
+
+ // Remove duplicates from defaultPaths
+ uniquePaths := make(map[string]bool)
+ cleanPaths := []string{}
for _, path := range defaultPaths {
- mux.HandleFunc(path, buildProxyHandler(cfg))
+ if !uniquePaths[path] {
+ uniquePaths[path] = true
+ cleanPaths = append(cleanPaths, path)
+ }
+ }
+ defaultPaths = cleanPaths
+
+ for _, path := range defaultPaths {
+ if !registeredPaths[path] {
+ mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
+ registeredPaths[path] = true
+ }
}
- // 4. MCP paths
- for _, path := range cfg.MCPPaths {
- mux.HandleFunc(path, buildProxyHandler(cfg))
+ // MCP paths
+ mcpPaths := cfg.GetMCPPaths()
+ for _, path := range mcpPaths {
+ mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
+ registeredPaths[path] = true
}
- // 5. If you want to map additional paths from config.PathMapping
- // to the same proxy logic:
+ // Register paths from PathMapping that haven't been registered yet
for path := range cfg.PathMapping {
- mux.HandleFunc(path, buildProxyHandler(cfg))
+ if !registeredPaths[path] {
+ mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
+ registeredPaths[path] = true
+ }
}
return mux
}
-func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
+func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
// Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil {
- log.Fatalf("Invalid auth server URL: %v", err)
+ logger.Error("Invalid auth server URL: %v", err)
+ panic(err) // Fatal error that prevents startup
}
- mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
+
+ mcpBase, err := url.Parse(cfg.BaseURL)
if err != nil {
- log.Fatalf("Invalid MCP server URL: %v", err)
- }
-
- // We'll define sets for known auth paths, SSE paths, etc.
- authPaths := map[string]bool{
- "/authorize": true,
- "/token": true,
- "/.well-known/oauth-authorization-server": true,
+ logger.Error("Invalid MCP server URL: %v", err)
+ panic(err) // Fatal error that prevents startup
}
// Detect SSE paths from config
ssePaths := make(map[string]bool)
- for _, p := range cfg.MCPPaths {
- if p == "/sse" {
- ssePaths[p] = true
- }
- }
+ ssePaths[cfg.Paths.SSE] = true
return func(w http.ResponseWriter, r *http.Request) {
+ origin := r.Header.Get("Origin")
+ allowedOrigin := getAllowedOrigin(origin, cfg)
// Handle OPTIONS
if r.Method == http.MethodOptions {
- addCORSHeaders(w)
+ if allowedOrigin == "" {
+ logger.Warn("Preflight request from disallowed origin: %s", origin)
+ http.Error(w, "CORS origin not allowed", http.StatusForbidden)
+ return
+ }
+ addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers"))
w.WriteHeader(http.StatusNoContent)
return
}
- addCORSHeaders(w)
+ if allowedOrigin == "" {
+ logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
+ http.Error(w, "CORS origin not allowed", http.StatusForbidden)
+ return
+ }
+
+ // Add CORS headers to all responses
+ addCORSHeaders(w, cfg, allowedOrigin, "")
// Decide whether the request should go to the auth server or MCP
var targetURL *url.URL
isSSE := false
- if authPaths[r.URL.Path] {
+ if isAuthPath(r.URL.Path) {
targetURL = authBase
} else if isMCPPath(r.URL.Path, cfg) {
- // Validate JWT if you want
+ // Validate JWT for MCP paths if required
+ // Placeholder for JWT validation logic
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
- log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
+ logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
@@ -100,11 +160,21 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
isSSE = true
}
} else {
- // If it's not recognized as an auth path or an MCP path
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
+ // Apply request modifiers to add parameters
+ if modifier, exists := modifiers[r.URL.Path]; exists {
+ var err error
+ r, err = modifier.ModifyRequest(r)
+ if err != nil {
+ logger.Error("Error modifying request: %v", err)
+ http.Error(w, "Bad Request", http.StatusBadRequest)
+ return
+ }
+ }
+
// Build the reverse proxy
rp := &httputil.ReverseProxy{
Director: func(req *http.Request) {
@@ -120,31 +190,53 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
req.URL.RawQuery = r.URL.RawQuery
req.Host = targetURL.Host
- for header, values := range r.Header {
+ cleanHeaders := http.Header{}
+
+ // Set proper origin header to match the target
+ if isSSE {
+ // For SSE, ensure origin matches the target
+ req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
+ }
+
+ for k, v := range r.Header {
// Skip hop-by-hop headers
- if strings.EqualFold(header, "Connection") ||
- strings.EqualFold(header, "Keep-Alive") ||
- strings.EqualFold(header, "Transfer-Encoding") ||
- strings.EqualFold(header, "Upgrade") ||
- strings.EqualFold(header, "Proxy-Authorization") ||
- strings.EqualFold(header, "Proxy-Connection") {
+ if skipHeader(k) {
continue
}
- for _, value := range values {
- req.Header.Set(header, value)
- }
+ // Set only the first value to avoid duplicates
+ cleanHeaders.Set(k, v[0])
}
- log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
+
+ req.Header = cleanHeaders
+
+ logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
+ },
+ ModifyResponse: func(resp *http.Response) error {
+ logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
+ resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
+ return nil
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
- log.Printf("[proxy] Error proxying: %v", err)
+ logger.Error("Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
},
FlushInterval: -1, // immediate flush for SSE
}
if isSSE {
+ // Add special response handling for SSE connections to rewrite endpoint URLs
+ rp.Transport = &sseTransport{
+ Transport: http.DefaultTransport,
+ proxyHost: r.Host,
+ targetHost: targetURL.Host,
+ }
+
+ // Set SSE-specific headers
+ w.Header().Set("X-Accel-Buffering", "no")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+
// Keep SSE connections open
HandleSSE(w, r, rp)
} else {
@@ -156,14 +248,52 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
}
}
-func addCORSHeaders(w http.ResponseWriter) {
- w.Header().Set("Access-Control-Allow-Origin", "*")
- w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
- w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+func getAllowedOrigin(origin string, cfg *config.Config) string {
+ if origin == "" {
+ return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
+ }
+ for _, allowed := range cfg.CORSConfig.AllowedOrigins {
+ logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
+ if allowed == origin {
+ return allowed
+ }
+ }
+ return ""
}
+// addCORSHeaders adds configurable CORS headers
+func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
+ w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
+ w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
+ if requestHeaders != "" {
+ w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
+ } else {
+ w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", "))
+ }
+ if cfg.CORSConfig.AllowCredentials {
+ w.Header().Set("Access-Control-Allow-Credentials", "true")
+ }
+ w.Header().Set("Vary", "Origin")
+ w.Header().Set("X-Accel-Buffering", "no")
+}
+
+func isAuthPath(path string) bool {
+ authPaths := map[string]bool{
+ "/authorize": true,
+ "/token": true,
+ "/register": true,
+ "/.well-known/oauth-authorization-server": true,
+ }
+ if strings.HasPrefix(path, "/u/") {
+ return true
+ }
+ return authPaths[path]
+}
+
+// isMCPPath checks if the path is an MCP path
func isMCPPath(path string, cfg *config.Config) bool {
- for _, p := range cfg.MCPPaths {
+ mcpPaths := cfg.GetMCPPaths()
+ for _, p := range mcpPaths {
if strings.HasPrefix(path, p) {
return true
}
@@ -171,22 +301,10 @@ func isMCPPath(path string, cfg *config.Config) bool {
return false
}
-func copyHeaders(src http.Header, dst http.Header) {
- // Exclude hop-by-hop
- hopByHop := map[string]bool{
- "Connection": true,
- "Keep-Alive": true,
- "Transfer-Encoding": true,
- "Upgrade": true,
- "Proxy-Authorization": true,
- "Proxy-Connection": true,
- }
- for k, vv := range src {
- if hopByHop[strings.ToLower(k)] {
- continue
- }
- for _, v := range vv {
- dst.Add(k, v)
- }
+func skipHeader(h string) bool {
+ switch strings.ToLower(h) {
+ case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer":
+ return true
}
+ return false
}
diff --git a/internal/proxy/sse.go b/internal/proxy/sse.go
index 44d6558..ce72e04 100644
--- a/internal/proxy/sse.go
+++ b/internal/proxy/sse.go
@@ -1,11 +1,16 @@
package proxy
import (
+ "bufio"
"context"
- "log"
+ "fmt"
+ "io"
"net/http"
"net/http/httputil"
+ "strings"
"time"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// HandleSSE sets up a go-routine to wait for context cancellation
@@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
go func() {
<-ctx.Done()
- log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
+ logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
close(done)
}()
@@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}
+
+// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
+type sseTransport struct {
+ Transport http.RoundTripper
+ proxyHost string
+ targetHost string
+}
+
+func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ // Call the underlying transport
+ resp, err := t.Transport.RoundTrip(req)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check if this is an SSE response
+ contentType := resp.Header.Get("Content-Type")
+ if !strings.Contains(contentType, "text/event-stream") {
+ return resp, nil
+ }
+
+ logger.Info("Intercepting SSE response to modify endpoint events")
+
+ // Create a response wrapper that modifies the response body
+ originalBody := resp.Body
+ pr, pw := io.Pipe()
+
+ go func() {
+ defer originalBody.Close()
+ defer pw.Close()
+
+ scanner := bufio.NewScanner(originalBody)
+ for scanner.Scan() {
+ line := scanner.Text()
+
+ // Check if this line contains an endpoint event
+ if strings.HasPrefix(line, "event: endpoint") {
+ // Read the data line
+ if scanner.Scan() {
+ dataLine := scanner.Text()
+ if strings.HasPrefix(dataLine, "data: ") {
+ // Extract the endpoint URL
+ endpoint := strings.TrimPrefix(dataLine, "data: ")
+
+ // Replace the host in the endpoint
+ logger.Debug("Original endpoint: %s", endpoint)
+ endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
+ logger.Debug("Modified endpoint: %s", endpoint)
+
+ // Write the modified event lines
+ fmt.Fprintln(pw, line)
+ fmt.Fprintln(pw, "data: "+endpoint)
+ continue
+ }
+ }
+ }
+
+ // Write the original line for non-endpoint events
+ fmt.Fprintln(pw, line)
+ }
+
+ if err := scanner.Err(); err != nil {
+ logger.Error("Error reading SSE stream: %v", err)
+ }
+ }()
+
+ // Replace the response body with our modified pipe
+ resp.Body = pr
+ return resp, nil
+}
diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go
new file mode 100644
index 0000000..fa64337
--- /dev/null
+++ b/internal/subprocess/manager.go
@@ -0,0 +1,268 @@
+package subprocess
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "sync"
+ "syscall"
+ "time"
+ "strings"
+
+ "github.com/wso2/open-mcp-auth-proxy/internal/config"
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
+)
+
+// Manager handles starting and graceful shutdown of subprocesses
+type Manager struct {
+ process *os.Process
+ processGroup int
+ mutex sync.Mutex
+ cmd *exec.Cmd
+ shutdownDelay time.Duration
+}
+
+// NewManager creates a new subprocess manager
+func NewManager() *Manager {
+ return &Manager{
+ shutdownDelay: 5 * time.Second,
+ }
+}
+
+// EnsureDependenciesAvailable checks and installs required package executors
+func EnsureDependenciesAvailable(command string) error {
+ // Always ensure npx is available regardless of the command
+ if _, err := exec.LookPath("npx"); err != nil {
+ // npx is not available, check if npm is installed
+ if _, err := exec.LookPath("npm"); err != nil {
+ return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
+ }
+
+ // Try to install npx using npm
+ logger.Info("npx not found, attempting to install...")
+ cmd := exec.Command("npm", "install", "-g", "npx")
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Run(); err != nil {
+ return fmt.Errorf("failed to install npx: %w", err)
+ }
+
+ logger.Info("npx installed successfully")
+ }
+
+ // Check if uv is needed based on the command
+ if strings.Contains(command, "uv ") {
+ if _, err := exec.LookPath("uv"); err != nil {
+ return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
+ }
+ }
+
+ return nil
+}
+
+// SetShutdownDelay sets the maximum time to wait for graceful shutdown
+func (m *Manager) SetShutdownDelay(duration time.Duration) {
+ m.shutdownDelay = duration
+}
+
+// Start launches a subprocess based on the configuration
+func (m *Manager) Start(cfg *config.Config) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ // If a process is already running, return an error
+ if m.process != nil {
+ return os.ErrExist
+ }
+
+ if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
+ return nil // Nothing to start
+ }
+
+ // Get the full command string
+ execCommand := cfg.BuildExecCommand()
+ if execCommand == "" {
+ return nil // No command to execute
+ }
+
+ logger.Info("Starting subprocess with command: %s", execCommand)
+
+ // Use the shell to execute the command
+ cmd := exec.Command("sh", "-c", execCommand)
+
+ // Set working directory if specified
+ if cfg.Stdio.WorkDir != "" {
+ cmd.Dir = cfg.Stdio.WorkDir
+ }
+
+ // Set environment variables if specified
+ if len(cfg.Stdio.Env) > 0 {
+ cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
+ }
+
+ // Capture stdout/stderr
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ // Set the process group for proper termination
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+
+ // Start the process
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ m.process = cmd.Process
+ m.cmd = cmd
+ logger.Info("Subprocess started with PID: %d", m.process.Pid)
+
+ // Get and store the process group ID
+ pgid, err := syscall.Getpgid(m.process.Pid)
+ if err == nil {
+ m.processGroup = pgid
+ logger.Debug("Process group ID: %d", m.processGroup)
+ } else {
+ logger.Warn("Failed to get process group ID: %v", err)
+ m.processGroup = m.process.Pid
+ }
+
+ // Handle process termination in background
+ go func() {
+ if err := cmd.Wait(); err != nil {
+ logger.Error("Subprocess exited with error: %v", err)
+ } else {
+ logger.Info("Subprocess exited successfully")
+ }
+
+ // Clear the process reference when it exits
+ m.mutex.Lock()
+ m.process = nil
+ m.cmd = nil
+ m.mutex.Unlock()
+ }()
+
+ return nil
+}
+
+// IsRunning checks if the subprocess is running
+func (m *Manager) IsRunning() bool {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+ return m.process != nil
+}
+
+// Shutdown gracefully terminates the subprocess
+func (m *Manager) Shutdown() {
+ m.mutex.Lock()
+ processToTerminate := m.process // Local copy of the process reference
+ processGroupToTerminate := m.processGroup
+ m.mutex.Unlock()
+
+ if processToTerminate == nil {
+ return // No process to terminate
+ }
+
+ logger.Info("Terminating subprocess...")
+ terminateComplete := make(chan struct{})
+
+ go func() {
+ defer close(terminateComplete)
+
+ // Try graceful termination first with SIGTERM
+ terminatedGracefully := false
+
+ // Try to terminate the process group first
+ if processGroupToTerminate != 0 {
+ err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
+ if err != nil {
+ logger.Warn("Failed to send SIGTERM to process group: %v", err)
+
+ // Fallback to terminating just the process
+ m.mutex.Lock()
+ if m.process != nil {
+ err = m.process.Signal(syscall.SIGTERM)
+ if err != nil {
+ logger.Warn("Failed to send SIGTERM to process: %v", err)
+ }
+ }
+ m.mutex.Unlock()
+ }
+ } else {
+ // Try to terminate just the process
+ m.mutex.Lock()
+ if m.process != nil {
+ err := m.process.Signal(syscall.SIGTERM)
+ if err != nil {
+ logger.Warn("Failed to send SIGTERM to process: %v", err)
+ }
+ }
+ m.mutex.Unlock()
+ }
+
+ // Wait for the process to exit gracefully
+ for i := 0; i < 10; i++ {
+ time.Sleep(200 * time.Millisecond)
+
+ m.mutex.Lock()
+ if m.process == nil {
+ terminatedGracefully = true
+ m.mutex.Unlock()
+ break
+ }
+ m.mutex.Unlock()
+ }
+
+ if terminatedGracefully {
+ logger.Info("Subprocess terminated gracefully")
+ return
+ }
+
+ // If the process didn't exit gracefully, force kill
+ logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
+
+ // Try to kill the process group first
+ if processGroupToTerminate != 0 {
+ if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
+ logger.Warn("Failed to send SIGKILL to process group: %v", err)
+
+ // Fallback to killing just the process
+ m.mutex.Lock()
+ if m.process != nil {
+ if err := m.process.Kill(); err != nil {
+ logger.Error("Failed to kill process: %v", err)
+ }
+ }
+ m.mutex.Unlock()
+ }
+ } else {
+ // Try to kill just the process
+ m.mutex.Lock()
+ if m.process != nil {
+ if err := m.process.Kill(); err != nil {
+ logger.Error("Failed to kill process: %v", err)
+ }
+ }
+ m.mutex.Unlock()
+ }
+
+ // Wait a bit more to confirm termination
+ time.Sleep(500 * time.Millisecond)
+
+ m.mutex.Lock()
+ if m.process == nil {
+ logger.Info("Subprocess terminated by force")
+ } else {
+ logger.Warn("Failed to terminate subprocess")
+ }
+ m.mutex.Unlock()
+ }()
+
+ // Wait for termination with timeout
+ select {
+ case <-terminateComplete:
+ // Termination completed
+ case <-time.After(m.shutdownDelay):
+ logger.Warn("Subprocess termination timed out")
+ }
+}
diff --git a/internal/util/jwks.go b/internal/util/jwks.go
index 4832bf8..f80d82e 100644
--- a/internal/util/jwks.go
+++ b/internal/util/jwks.go
@@ -4,12 +4,12 @@ import (
"crypto/rsa"
"encoding/json"
"errors"
- "log"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
+ "github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
type JWKS struct {
@@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
publicKeys[parsedKey.Kid] = pubKey
}
}
- log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
+ logger.Info("Loaded %d public keys.", len(publicKeys))
return nil
}
diff --git a/internal/util/jwks_test.go b/internal/util/jwks_test.go
new file mode 100644
index 0000000..3b00c68
--- /dev/null
+++ b/internal/util/jwks_test.go
@@ -0,0 +1,143 @@
+package util
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v4"
+)
+
+func TestValidateJWT(t *testing.T) {
+ // Initialize the test JWKS data
+ initTestJWKS(t)
+
+ // Test cases
+ tests := []struct {
+ name string
+ authHeader string
+ expectError bool
+ }{
+ {
+ name: "Valid JWT token",
+ authHeader: "Bearer " + createValidJWT(t),
+ expectError: false,
+ },
+ {
+ name: "No auth header",
+ authHeader: "",
+ expectError: true,
+ },
+ {
+ name: "Invalid auth header format",
+ authHeader: "InvalidFormat",
+ expectError: true,
+ },
+ {
+ name: "Invalid JWT token",
+ authHeader: "Bearer invalid.jwt.token",
+ expectError: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ err := ValidateJWT(tc.authHeader)
+ if tc.expectError && err == nil {
+ t.Errorf("Expected error but got none")
+ }
+ if !tc.expectError && err != nil {
+ t.Errorf("Expected no error but got: %v", err)
+ }
+ })
+ }
+}
+
+func TestFetchJWKS(t *testing.T) {
+ // Create a mock JWKS server
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Generate a test RSA key
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatalf("Failed to generate RSA key: %v", err)
+ }
+
+ // Create JWKS response
+ jwks := map[string]interface{}{
+ "keys": []map[string]interface{}{
+ {
+ "kty": "RSA",
+ "kid": "test-key-id",
+ "n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
+ "e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537
+ },
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(jwks)
+ }))
+ defer server.Close()
+
+ // Test fetching JWKS
+ err := FetchJWKS(server.URL)
+ if err != nil {
+ t.Fatalf("FetchJWKS failed: %v", err)
+ }
+
+ // Check that keys were stored
+ if len(publicKeys) == 0 {
+ t.Errorf("Expected publicKeys to be populated")
+ }
+}
+
+// Helper function to initialize test JWKS data
+func initTestJWKS(t *testing.T) {
+ // Create a test RSA key pair
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatalf("Failed to generate RSA key: %v", err)
+ }
+
+ // Initialize the publicKeys map
+ publicKeys = map[string]*rsa.PublicKey{
+ "test-key-id": &privateKey.PublicKey,
+ }
+}
+
+// Helper function to create a valid JWT token for testing
+func createValidJWT(t *testing.T) string {
+ // Create a test RSA key pair
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatalf("Failed to generate RSA key: %v", err)
+ }
+
+ // Ensure the test key is in the publicKeys map
+ if publicKeys == nil {
+ publicKeys = map[string]*rsa.PublicKey{}
+ }
+ publicKeys["test-key-id"] = &privateKey.PublicKey
+
+ // Create token
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
+ "sub": "1234567890",
+ "name": "Test User",
+ "iat": time.Now().Unix(),
+ "exp": time.Now().Add(time.Hour).Unix(),
+ })
+ token.Header["kid"] = "test-key-id"
+
+ // Sign the token
+ tokenString, err := token.SignedString(privateKey)
+ if err != nil {
+ t.Fatalf("Failed to sign token: %v", err)
+ }
+
+ return tokenString
+}
diff --git a/pull_request_template.md b/pull_request_template.md
index 9b32185..c401a06 100644
--- a/pull_request_template.md
+++ b/pull_request_template.md
@@ -1,52 +1,11 @@
## Purpose
-> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc.
+
-## Goals
-> Describe the solutions that this feature/fix will introduce to resolve the problems described above
-
-## Approach
-> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here.
-
-## User stories
-> Summary of user stories addressed by this change>
-
-## Release note
-> Brief description of the new feature or bug fix as it will appear in the release notes
-
-## Documentation
-> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why there’s no doc impact
-
-## Training
-> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable
-
-## Certification
-> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why.
-
-## Marketing
-> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable
-
-## Automation tests
- - Unit tests
- > Code coverage information
- - Integration tests
- > Details about the test cases and coverage
-
-## Security checks
- - Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no
- - Ran FindSecurityBugs plugin and verified report? yes/no
- - Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no
-
-## Samples
-> Provide high-level details about the samples related to this feature
+## Related Issues
+
## Related PRs
-> List any other related PRs
+
## Migrations (if applicable)
-> Describe migration steps and platforms on which migration has been tested
-
-## Test environment
-> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested
-
-## Learning
-> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem.
\ No newline at end of file
+
diff --git a/resources/requirements.txt b/resources/requirements.txt
new file mode 100644
index 0000000..102b728
--- /dev/null
+++ b/resources/requirements.txt
@@ -0,0 +1 @@
+fastmcp==0.4.1
\ No newline at end of file