Compare commits
1 commit
main
...
pavinduLak
Author | SHA1 | Date | |
---|---|---|---|
|
48fc6a5054 |
29 changed files with 212 additions and 2014 deletions
124
.github/scripts/release.sh
vendored
124
.github/scripts/release.sh
vendored
|
@ -1,124 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
|
|
||||||
#
|
|
||||||
# This software is the property of WSO2 LLC. and its suppliers, if any.
|
|
||||||
# Dissemination of any information or reproduction of any material contained
|
|
||||||
# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
|
|
||||||
# You may not alter or remove any copyright or other notice from copies of this content.
|
|
||||||
#
|
|
||||||
|
|
||||||
# Exit the script on any command with non-zero exit status.
|
|
||||||
set -e
|
|
||||||
set -o pipefail
|
|
||||||
|
|
||||||
UPSTREAM_BRANCH="main"
|
|
||||||
|
|
||||||
# Assign command line arguments to variables.
|
|
||||||
GIT_TOKEN=$1
|
|
||||||
WORK_DIR=$2
|
|
||||||
VERSION_TYPE=$3 # possible values: major, minor, patch
|
|
||||||
|
|
||||||
# Check if GIT_TOKEN is empty
|
|
||||||
if [ -z "$GIT_TOKEN" ]; then
|
|
||||||
echo "❌ Error: GIT_TOKEN is not set."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Check if WORK_DIR is empty
|
|
||||||
if [ -z "$WORK_DIR" ]; then
|
|
||||||
echo "❌ Error: WORK_DIR is not set."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Validate VERSION_TYPE
|
|
||||||
if [[ "$VERSION_TYPE" != "major" && "$VERSION_TYPE" != "minor" && "$VERSION_TYPE" != "patch" ]]; then
|
|
||||||
echo "❌ Error: VERSION_TYPE must be one of: major, minor, or patch."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
BUILD_DIRECTORY="$WORK_DIR/build"
|
|
||||||
RELEASE_DIRECTORY="$BUILD_DIRECTORY/releases"
|
|
||||||
|
|
||||||
# Navigate to the working directory.
|
|
||||||
cd "${WORK_DIR}"
|
|
||||||
|
|
||||||
# Create the release directory.
|
|
||||||
if [ ! -d "$RELEASE_DIRECTORY" ]; then
|
|
||||||
mkdir -p "$RELEASE_DIRECTORY"
|
|
||||||
else
|
|
||||||
rm -rf "$RELEASE_DIRECTORY"/*
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Extract current version.
|
|
||||||
CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0")
|
|
||||||
IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}"
|
|
||||||
|
|
||||||
# Determine which part to increment
|
|
||||||
case "$VERSION_TYPE" in
|
|
||||||
major)
|
|
||||||
MAJOR=$((MAJOR + 1))
|
|
||||||
MINOR=0
|
|
||||||
PATCH=0
|
|
||||||
;;
|
|
||||||
minor)
|
|
||||||
MINOR=$((MINOR + 1))
|
|
||||||
PATCH=0
|
|
||||||
;;
|
|
||||||
patch|*)
|
|
||||||
PATCH=$((PATCH + 1))
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}"
|
|
||||||
|
|
||||||
echo "Creating release packages for version $NEW_VERSION..."
|
|
||||||
|
|
||||||
# List of supported OSes.
|
|
||||||
oses=("linux" "linux-arm" "darwin")
|
|
||||||
|
|
||||||
# Navigate to the release directory.
|
|
||||||
cd "${RELEASE_DIRECTORY}"
|
|
||||||
|
|
||||||
for os in "${oses[@]}"; do
|
|
||||||
os_dir="../$os"
|
|
||||||
|
|
||||||
if [ -d "$os_dir" ]; then
|
|
||||||
release_artifact_folder="openmcpauthproxy_${os}-v${NEW_VERSION}"
|
|
||||||
mkdir -p "$release_artifact_folder"
|
|
||||||
|
|
||||||
cp -r $os_dir/* "$release_artifact_folder"
|
|
||||||
|
|
||||||
# Zip the release package.
|
|
||||||
zip_file="$release_artifact_folder.zip"
|
|
||||||
echo "Creating $zip_file..."
|
|
||||||
zip -r "$zip_file" "$release_artifact_folder"
|
|
||||||
|
|
||||||
# Delete the folder after zipping.
|
|
||||||
rm -rf "$release_artifact_folder"
|
|
||||||
|
|
||||||
# Generate checksum file.
|
|
||||||
sha256sum "$zip_file" | sed "s|target/releases/||" > "$zip_file.sha256"
|
|
||||||
echo "Checksum generated for the $os package."
|
|
||||||
|
|
||||||
echo "Release packages created successfully for $os."
|
|
||||||
else
|
|
||||||
echo "Skipping $os release package creation as the build artifacts are not available."
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "Release packages created successfully in $RELEASE_DIRECTORY."
|
|
||||||
|
|
||||||
# Navigate back to the project root directory.
|
|
||||||
cd "${WORK_DIR}"
|
|
||||||
|
|
||||||
# Collect all ZIP and .sha256 files in the target/releases directory.
|
|
||||||
FILES_TO_UPLOAD=$(find build/releases -type f \( -name "*.zip" -o -name "*.sha256" \))
|
|
||||||
|
|
||||||
# Create a release with the current version.
|
|
||||||
TAG_NAME="v${NEW_VERSION}"
|
|
||||||
export GITHUB_TOKEN="${GIT_TOKEN}"
|
|
||||||
gh release create "${TAG_NAME}" ${FILES_TO_UPLOAD} --title "${TAG_NAME}" --notes "OpenMCPAuthProxy - ${TAG_NAME}" --target "${UPSTREAM_BRANCH}" || { echo "Failed to create release"; exit 1; }
|
|
||||||
|
|
||||||
|
|
||||||
echo "Release ${TAG_NAME} created successfully."
|
|
71
.github/workflows/ci.yaml
vendored
71
.github/workflows/ci.yaml
vendored
|
@ -1,71 +0,0 @@
|
||||||
name: Build and Push container
|
|
||||||
run-name: Build and Push container
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
#schedule:
|
|
||||||
# - cron: "0 10 * * *"
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
- 'master'
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
- 'master'
|
|
||||||
env:
|
|
||||||
IMAGE: git.kvant.cloud/${{github.repository}}
|
|
||||||
jobs:
|
|
||||||
build_concierge_backend:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Set current time
|
|
||||||
uses: https://github.com/gerred/actions/current-time@master
|
|
||||||
id: current_time
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Login to git.kvant.cloud registry
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
registry: git.kvant.cloud
|
|
||||||
username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }}
|
|
||||||
password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }}
|
|
||||||
|
|
||||||
- name: Docker meta
|
|
||||||
id: meta
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
with:
|
|
||||||
# list of Docker images to use as base name for tags
|
|
||||||
images: |
|
|
||||||
${{env.IMAGE}}
|
|
||||||
# generate Docker tags based on the following events/attributes
|
|
||||||
tags: |
|
|
||||||
type=schedule
|
|
||||||
type=ref,event=branch
|
|
||||||
type=ref,event=pr
|
|
||||||
type=semver,pattern={{version}}
|
|
||||||
|
|
||||||
- name: Build and push to gitea registry
|
|
||||||
uses: docker/build-push-action@v6
|
|
||||||
with:
|
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
|
||||||
context: .
|
|
||||||
provenance: mode=max
|
|
||||||
sbom: true
|
|
||||||
build-args: |
|
|
||||||
BUILD_DATE=${{ steps.current_time.outputs.time }}
|
|
||||||
cache-from: |
|
|
||||||
type=registry,ref=${{ env.IMAGE }}:buildcache
|
|
||||||
type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }}
|
|
||||||
type=registry,ref=${{ env.IMAGE }}:main
|
|
||||||
cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true
|
|
62
.github/workflows/pr-builder.yml
vendored
62
.github/workflows/pr-builder.yml
vendored
|
@ -1,62 +0,0 @@
|
||||||
name: Go CI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main ]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
name: Test
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
go-version: ['1.20', '1.21']
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Set up Go
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: ${{ matrix.go-version }}
|
|
||||||
|
|
||||||
- name: Get dependencies
|
|
||||||
run: go get -v -t -d ./...
|
|
||||||
|
|
||||||
- name: Verify dependencies
|
|
||||||
run: go mod verify
|
|
||||||
|
|
||||||
- name: Run go vet
|
|
||||||
run: go vet ./...
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...
|
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
|
||||||
uses: codecov/codecov-action@v3
|
|
||||||
with:
|
|
||||||
files: ./coverage.txt
|
|
||||||
fail_ci_if_error: false
|
|
||||||
|
|
||||||
build:
|
|
||||||
name: Build
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
go-version: ['1.20', '1.21']
|
|
||||||
os: [ubuntu-latest, macos-latest]
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Set up Go
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: ${{ matrix.go-version }}
|
|
||||||
|
|
||||||
- name: Build
|
|
||||||
run: go build -v ./cmd/proxy
|
|
64
.github/workflows/release.yml
vendored
64
.github/workflows/release.yml
vendored
|
@ -1,64 +0,0 @@
|
||||||
#
|
|
||||||
# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com).
|
|
||||||
#
|
|
||||||
# This software is the property of WSO2 LLC. and its suppliers, if any.
|
|
||||||
# Dissemination of any information or reproduction of any material contained
|
|
||||||
# herein in any form is strictly forbidden, unless permitted by WSO2 expressly.
|
|
||||||
# You may not alter or remove any copyright or other notice from copies of this content.
|
|
||||||
#
|
|
||||||
|
|
||||||
name: Release
|
|
||||||
|
|
||||||
on:
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
version_type:
|
|
||||||
type: choice
|
|
||||||
description: Choose the type of version update
|
|
||||||
options:
|
|
||||||
- 'major'
|
|
||||||
- 'minor'
|
|
||||||
- 'patch'
|
|
||||||
required: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
update-and-release:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
env:
|
|
||||||
GOPROXY: https://proxy.golang.org
|
|
||||||
if: github.event.pull_request.merged == true || github.event_name == 'workflow_dispatch'
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
with:
|
|
||||||
ref: 'main'
|
|
||||||
fetch-depth: 0
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
|
|
||||||
- name: Set up Go 1.x
|
|
||||||
uses: actions/setup-go@v3
|
|
||||||
with:
|
|
||||||
go-version: "^1.x"
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
id: cache-go-modules
|
|
||||||
uses: actions/cache@v3
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
~/.cache/go-build
|
|
||||||
~/go/pkg/mod
|
|
||||||
key: ${{ runner.os }}-go-modules-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-go-modules-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: go mod download
|
|
||||||
|
|
||||||
- name: Build and test
|
|
||||||
run: make build
|
|
||||||
working-directory: .
|
|
||||||
|
|
||||||
- name: Update artifact version, package, commit, and create release.
|
|
||||||
env:
|
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
run: bash ./.github/scripts/release.sh $GITHUB_TOKEN ${{ github.workspace }} ${{ github.event.inputs.version_type }}
|
|
14
.gitignore
vendored
14
.gitignore
vendored
|
@ -18,21 +18,15 @@
|
||||||
*.zip
|
*.zip
|
||||||
*.tar.gz
|
*.tar.gz
|
||||||
*.rar
|
*.rar
|
||||||
.venv
|
|
||||||
|
|
||||||
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
|
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
|
||||||
hs_err_pid*
|
hs_err_pid*
|
||||||
replay_pid*
|
replay_pid*
|
||||||
|
|
||||||
|
# Go module cache files
|
||||||
|
go.sum
|
||||||
|
|
||||||
# OS generated files
|
# OS generated files
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
# builds
|
openmcpauthproxy
|
||||||
build
|
|
||||||
|
|
||||||
# test out files
|
|
||||||
coverage.out
|
|
||||||
coverage.html
|
|
||||||
|
|
||||||
# IDE files
|
|
||||||
.vscode
|
|
||||||
|
|
48
Dockerfile
48
Dockerfile
|
@ -1,48 +0,0 @@
|
||||||
FROM --platform=${BUILDPLATFORM:-linux/amd64} golang:1.24@sha256:d9db32125db0c3a680cfb7a1afcaefb89c898a075ec148fdc2f0f646cc2ed509 AS build
|
|
||||||
|
|
||||||
ARG TARGETPLATFORM
|
|
||||||
ARG BUILDPLATFORM
|
|
||||||
ARG TARGETOS
|
|
||||||
ARG TARGETARCH
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
RUN apt update -qq && apt install -qq -y git bash curl g++
|
|
||||||
|
|
||||||
# Download libraries
|
|
||||||
ADD go.* .
|
|
||||||
RUN go mod download
|
|
||||||
|
|
||||||
# Build
|
|
||||||
ADD cmd cmd
|
|
||||||
ADD internal internal
|
|
||||||
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o webhook -ldflags '-w -extldflags "-static"' -o openmcpauthproxy ./cmd/proxy
|
|
||||||
|
|
||||||
#Test
|
|
||||||
RUN CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go test -v -race ./...
|
|
||||||
|
|
||||||
|
|
||||||
# Build production container
|
|
||||||
FROM --platform=${BUILDPLATFORM:-linux/amd64} ubuntu:24.04
|
|
||||||
|
|
||||||
RUN apt-get update \
|
|
||||||
&& apt-get install --no-install-recommends -y \
|
|
||||||
python3-pip \
|
|
||||||
python-is-python3 \
|
|
||||||
npm \
|
|
||||||
&& apt-get autoremove \
|
|
||||||
&& apt-get clean \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
RUN pip install uvenv --break-system-packages
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
COPY --from=build /workspace/openmcpauthproxy /app/
|
|
||||||
|
|
||||||
ADD config.yaml /app
|
|
||||||
|
|
||||||
|
|
||||||
ENTRYPOINT ["/app/openmcpauthproxy"]
|
|
||||||
|
|
||||||
ARG IMAGE_SOURCE
|
|
||||||
LABEL org.opencontainers.image.source=$IMAGE_SOURCE
|
|
88
Makefile
88
Makefile
|
@ -1,88 +0,0 @@
|
||||||
# Makefile for open-mcp-auth-proxy
|
|
||||||
|
|
||||||
# Variables
|
|
||||||
PROJECT_ROOT := $(realpath $(dir $(abspath $(lastword $(MAKEFILE_LIST)))))
|
|
||||||
BINARY_NAME := openmcpauthproxy
|
|
||||||
GO := go
|
|
||||||
GOFMT := gofmt
|
|
||||||
GOVET := go vet
|
|
||||||
GOTEST := go test
|
|
||||||
GOLINT := golangci-lint
|
|
||||||
GOCOV := go tool cover
|
|
||||||
BUILD_DIR := build
|
|
||||||
|
|
||||||
# Source files
|
|
||||||
SRC := $(shell find . -name "*.go" -not -path "./vendor/*")
|
|
||||||
PKGS := $(shell go list ./... | grep -v /vendor/)
|
|
||||||
|
|
||||||
# Set build options
|
|
||||||
BUILD_OPTS := -v
|
|
||||||
|
|
||||||
# Set test options
|
|
||||||
TEST_OPTS := -v -race
|
|
||||||
|
|
||||||
.PHONY: all clean test fmt lint vet coverage help
|
|
||||||
|
|
||||||
# Default target
|
|
||||||
all: lint test build-linux build-linux-arm build-darwin
|
|
||||||
|
|
||||||
build: clean test build-linux build-linux-arm build-darwin
|
|
||||||
|
|
||||||
build-linux:
|
|
||||||
mkdir -p $(BUILD_DIR)/linux
|
|
||||||
GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
|
|
||||||
-o $(BUILD_DIR)/linux/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
|
|
||||||
cp config.yaml $(BUILD_DIR)/linux
|
|
||||||
|
|
||||||
build-linux-arm:
|
|
||||||
mkdir -p $(BUILD_DIR)/linux-arm
|
|
||||||
GOOS=linux GOARCH=arm CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
|
|
||||||
-o $(BUILD_DIR)/linux-arm/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
|
|
||||||
cp config.yaml $(BUILD_DIR)/linux-arm
|
|
||||||
|
|
||||||
build-darwin:
|
|
||||||
mkdir -p $(BUILD_DIR)/darwin
|
|
||||||
GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \
|
|
||||||
-o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy
|
|
||||||
cp config.yaml $(BUILD_DIR)/darwin
|
|
||||||
|
|
||||||
# Clean build artifacts
|
|
||||||
clean:
|
|
||||||
@echo "Cleaning build artifacts..."
|
|
||||||
@rm -rf $(BUILD_DIR)
|
|
||||||
@rm -f coverage.out
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
test:
|
|
||||||
@echo "Running tests..."
|
|
||||||
$(GOTEST) $(TEST_OPTS) ./...
|
|
||||||
|
|
||||||
# Run tests with coverage report
|
|
||||||
coverage:
|
|
||||||
@echo "Running tests with coverage..."
|
|
||||||
@$(GOTEST) -coverprofile=coverage.out ./...
|
|
||||||
@$(GOCOV) -func=coverage.out
|
|
||||||
@$(GOCOV) -html=coverage.out -o coverage.html
|
|
||||||
@echo "Coverage report generated in coverage.html"
|
|
||||||
|
|
||||||
# Run gofmt
|
|
||||||
fmt:
|
|
||||||
@echo "Running gofmt..."
|
|
||||||
@$(GOFMT) -w -s $(SRC)
|
|
||||||
|
|
||||||
# Run go vet
|
|
||||||
vet:
|
|
||||||
@echo "Running go vet..."
|
|
||||||
@$(GOVET) ./...
|
|
||||||
|
|
||||||
# Show help
|
|
||||||
help:
|
|
||||||
@echo "Available targets:"
|
|
||||||
@echo " all : Run lint, test, and build"
|
|
||||||
@echo " build : Build the application"
|
|
||||||
@echo " clean : Clean build artifacts"
|
|
||||||
@echo " test : Run tests"
|
|
||||||
@echo " coverage : Run tests with coverage report"
|
|
||||||
@echo " fmt : Run gofmt"
|
|
||||||
@echo " vet : Run go vet"
|
|
||||||
@echo " help : Show this help message"
|
|
245
README.md
245
README.md
|
@ -1,87 +1,82 @@
|
||||||
# Open MCP Auth Proxy
|
# Open MCP Auth Proxy
|
||||||
|
|
||||||
A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/)
|
The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider.
|
||||||
|
|
||||||
<a href="">[](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml)</a>
|

|
||||||
<a href="">[](https://stackoverflow.com/questions/tagged/wso2is)</a>
|
|
||||||
<a href="">[](https://discord.gg/wso2)</a>
|
|
||||||
<a href="">[](https://twitter.com/intent/follow?screen_name=wso2)</a>
|
|
||||||
<a href="">[](https://github.com/wso2/product-is/blob/master/LICENSE)</a>
|
|
||||||
|
|
||||||

|
## **Setup and Installation**
|
||||||
|
|
||||||
## What it Does
|
### **Prerequisites**
|
||||||
|
|
||||||
Open MCP Auth Proxy sits between MCP clients and your MCP server to:
|
|
||||||
|
|
||||||
- Intercept incoming requests
|
|
||||||
- Validate authorization tokens
|
|
||||||
- Offload authentication and authorization to OAuth-compliant Identity Providers
|
|
||||||
- Support the MCP authorization protocol
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
* Go 1.20 or higher
|
* Go 1.20 or higher
|
||||||
* A running MCP server
|
* A running MCP server (SSE transport supported)
|
||||||
|
|
||||||
> If you don't have an MCP server, you can use the included example:
|
|
||||||
>
|
|
||||||
> 1. Navigate to the `resources` directory
|
|
||||||
> 2. Set up a Python environment:
|
|
||||||
>
|
|
||||||
> ```bash
|
|
||||||
> python3 -m venv .venv
|
|
||||||
> source .venv/bin/activate
|
|
||||||
> pip3 install -r requirements.txt
|
|
||||||
> ```
|
|
||||||
>
|
|
||||||
> 3. Start the example server:
|
|
||||||
>
|
|
||||||
> ```bash
|
|
||||||
> python3 echo_server.py
|
|
||||||
> ```
|
|
||||||
|
|
||||||
* An MCP client that supports MCP authorization
|
* An MCP client that supports MCP authorization
|
||||||
|
|
||||||
### Basic Usage
|
### **Installation**
|
||||||
|
|
||||||
1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest).
|
```bash
|
||||||
|
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||||
|
cd open-mcp-auth-proxy
|
||||||
|
|
||||||
2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox):
|
go get github.com/golang-jwt/jwt/v4
|
||||||
|
go get gopkg.in/yaml.v2
|
||||||
|
|
||||||
|
go build -o openmcpauthproxy ./cmd/proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using Open MCP Auth Proxy
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo.
|
||||||
|
|
||||||
|
If you don’t have an MCP server, follow the instructions given here to start your own MCP server for testing purposes.
|
||||||
|
1. Download [sample MCP server](resources/echo_server.py)
|
||||||
|
2. Run the server with
|
||||||
|
```bash
|
||||||
|
python3 echo_server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Configure the Auth Proxy
|
||||||
|
|
||||||
|
Update the following parameters in `config.yaml`.
|
||||||
|
|
||||||
|
### demo mode configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
||||||
|
listen_port: 8080 # Address where the proxy will listen
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Start the Auth Proxy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy --demo
|
./openmcpauthproxy --demo
|
||||||
```
|
```
|
||||||
|
|
||||||
> The repository comes with a default `config.yaml` file that contains the basic configuration:
|
The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/).
|
||||||
>
|
|
||||||
> ```yaml
|
|
||||||
> listen_port: 8080
|
|
||||||
> base_url: "http://localhost:8000" # Your MCP server URL
|
|
||||||
> paths:
|
|
||||||
> sse: "/sse"
|
|
||||||
> messages: "/messages/"
|
|
||||||
> ```
|
|
||||||
|
|
||||||
3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation)
|
#### Connect Using an MCP Client
|
||||||
|
|
||||||
## Connect an Identity Provider
|
You can use this improved fork of [MCP Inspector](https://github.com/shashimalcse/inspector) to test the connection and try out the complete authorization flow.
|
||||||
|
|
||||||
### Asgardeo
|
### Use with Asgardeo
|
||||||
|
|
||||||
To enable authorization through your Asgardeo organization:
|
Enable authorization for the MCP server through your own Asgardeo organization
|
||||||
|
|
||||||
1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo
|
1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo
|
||||||
2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
|
2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that,
|
||||||
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope
|
1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/)
|
||||||

|
1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope.
|
||||||
|

|
||||||
|
2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy
|
||||||
|
|
||||||
3. Update `config.yaml` with the following parameters.
|
#### Configure the Auth Proxy
|
||||||
|
|
||||||
|
Create a configuration file config.yaml with the following parameters:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_url: "http://localhost:8000" # URL of your MCP server
|
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
||||||
listen_port: 8080 # Address where the proxy will listen
|
listen_port: 8080 # Address where the proxy will listen
|
||||||
|
|
||||||
asgardeo:
|
asgardeo:
|
||||||
|
@ -90,137 +85,31 @@ asgardeo:
|
||||||
client_secret: "<client_secret>" # Client secret of the M2M app
|
client_secret: "<client_secret>" # Client secret of the M2M app
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Start the proxy with Asgardeo integration:
|
#### Start the Auth Proxy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy --asgardeo
|
./openmcpauthproxy --asgardeo
|
||||||
```
|
```
|
||||||
|
|
||||||
### Other OAuth Providers
|
### Use with any standard OAuth Server
|
||||||
|
|
||||||
- [Auth0](docs/integrations/Auth0.md)
|
Enable authorization for the MCP server with a compliant OAuth server
|
||||||
- [Keycloak](docs/integrations/keycloak.md)
|
|
||||||
|
|
||||||
# Advanced Configuration
|
#### Configuration
|
||||||
|
|
||||||
### Transport Modes
|
Create a configuration file config.yaml with the following parameters:
|
||||||
|
|
||||||
The proxy supports two transport modes:
|
|
||||||
|
|
||||||
- **SSE Mode (Default)**: For Server-Sent Events transport
|
|
||||||
- **stdio Mode**: For MCP servers that use stdio transport
|
|
||||||
|
|
||||||
When using stdio mode, the proxy:
|
|
||||||
- Starts an MCP server as a subprocess using the command specified in the configuration
|
|
||||||
- Communicates with the subprocess through standard input/output (stdio)
|
|
||||||
- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first
|
|
||||||
|
|
||||||
To use stdio mode:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./openmcpauthproxy --demo --stdio
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Example: Running an MCP Server as a Subprocess
|
|
||||||
|
|
||||||
1. Configure stdio mode in your `config.yaml`:
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
listen_port: 8080
|
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
||||||
base_url: "http://localhost:8000"
|
listen_port: 8080 # Address where the proxy will listen
|
||||||
|
|
||||||
stdio:
|
|
||||||
enabled: true
|
|
||||||
user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server
|
|
||||||
env: # Environment variables (optional)
|
|
||||||
- "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT"
|
|
||||||
|
|
||||||
# CORS configuration
|
|
||||||
cors:
|
|
||||||
allowed_origins:
|
|
||||||
- "http://localhost:5173" # Origin of your client application
|
|
||||||
allowed_methods:
|
|
||||||
- "GET"
|
|
||||||
- "POST"
|
|
||||||
- "PUT"
|
|
||||||
- "DELETE"
|
|
||||||
allowed_headers:
|
|
||||||
- "Authorization"
|
|
||||||
- "Content-Type"
|
|
||||||
allow_credentials: true
|
|
||||||
|
|
||||||
# Demo configuration for Asgardeo
|
|
||||||
demo:
|
|
||||||
org_name: "openmcpauthdemo"
|
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
|
||||||
```
|
```
|
||||||
|
**TODO**: Update the configs for a standard OAuth Server.
|
||||||
|
|
||||||
2. Run the proxy with stdio mode:
|
#### Start the Auth Proxy
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./openmcpauthproxy --demo
|
./openmcpauthproxy
|
||||||
```
|
```
|
||||||
|
#### Integrating with existing OAuth Providers
|
||||||
|
|
||||||
The proxy will:
|
- [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization.
|
||||||
- Start the MCP server as a subprocess using the specified command
|
|
||||||
- Handle all authorization requirements
|
|
||||||
- Forward messages between clients and the server
|
|
||||||
|
|
||||||
### Complete Configuration Reference
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Common configuration
|
|
||||||
listen_port: 8080
|
|
||||||
base_url: "http://localhost:8000"
|
|
||||||
port: 8000
|
|
||||||
|
|
||||||
# Path configuration
|
|
||||||
paths:
|
|
||||||
sse: "/sse"
|
|
||||||
messages: "/messages/"
|
|
||||||
|
|
||||||
# Transport mode
|
|
||||||
transport_mode: "sse" # Options: "sse" or "stdio"
|
|
||||||
|
|
||||||
# stdio-specific configuration (used only in stdio mode)
|
|
||||||
stdio:
|
|
||||||
enabled: true
|
|
||||||
user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed)
|
|
||||||
work_dir: "" # Optional working directory for the subprocess
|
|
||||||
|
|
||||||
# CORS configuration
|
|
||||||
cors:
|
|
||||||
allowed_origins:
|
|
||||||
- "http://localhost:5173"
|
|
||||||
allowed_methods:
|
|
||||||
- "GET"
|
|
||||||
- "POST"
|
|
||||||
- "PUT"
|
|
||||||
- "DELETE"
|
|
||||||
allowed_headers:
|
|
||||||
- "Authorization"
|
|
||||||
- "Content-Type"
|
|
||||||
allow_credentials: true
|
|
||||||
|
|
||||||
# Demo configuration for Asgardeo
|
|
||||||
demo:
|
|
||||||
org_name: "openmcpauthdemo"
|
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
|
||||||
|
|
||||||
# Asgardeo configuration (used with --asgardeo flag)
|
|
||||||
asgardeo:
|
|
||||||
org_name: "<org_name>"
|
|
||||||
client_id: "<client_id>"
|
|
||||||
client_secret: "<client_secret>"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Build from source
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
|
||||||
cd open-mcp-auth-proxy
|
|
||||||
go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2
|
|
||||||
go build -o openmcpauthproxy ./cmd/proxy
|
|
||||||
```
|
|
||||||
|
|
|
@ -3,71 +3,31 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
logger "github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/subprocess"
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
||||||
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
|
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
|
||||||
debugMode := flag.Bool("debug", false, "Enable debug logging")
|
|
||||||
stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE")
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
logger.SetDebug(*debugMode)
|
|
||||||
|
|
||||||
// 1. Load config
|
// 1. Load config
|
||||||
cfg, err := config.LoadConfig("config.yaml")
|
cfg, err := config.LoadConfig("config.yaml")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error loading config: %v", err)
|
log.Fatalf("Error loading config: %v", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override transport mode if stdio flag is set
|
// 2. Create the chosen provider
|
||||||
if *stdioMode {
|
|
||||||
cfg.TransportMode = config.StdioTransport
|
|
||||||
// Ensure stdio is enabled
|
|
||||||
cfg.Stdio.Enabled = true
|
|
||||||
// Re-validate config
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
logger.Error("Configuration error: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Using transport mode: %s", cfg.TransportMode)
|
|
||||||
logger.Info("Using MCP server base URL: %s", cfg.BaseURL)
|
|
||||||
logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages)
|
|
||||||
|
|
||||||
// 2. Start subprocess if configured and in stdio mode
|
|
||||||
var procManager *subprocess.Manager
|
|
||||||
if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled {
|
|
||||||
// Ensure all required dependencies are available
|
|
||||||
if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil {
|
|
||||||
logger.Warn("%v", err)
|
|
||||||
logger.Warn("Subprocess may fail to start due to missing dependencies")
|
|
||||||
}
|
|
||||||
|
|
||||||
procManager = subprocess.NewManager()
|
|
||||||
if err := procManager.Start(cfg); err != nil {
|
|
||||||
logger.Warn("Failed to start subprocess: %v", err)
|
|
||||||
}
|
|
||||||
} else if cfg.TransportMode == config.SSETransport {
|
|
||||||
logger.Info("Using SSE transport mode, not starting subprocess")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Create the chosen provider
|
|
||||||
var provider authz.Provider
|
var provider authz.Provider
|
||||||
if *demoMode {
|
if *demoMode {
|
||||||
cfg.Mode = "demo"
|
cfg.Mode = "demo"
|
||||||
|
@ -86,49 +46,41 @@ func main() {
|
||||||
provider = authz.NewDefaultProvider(cfg)
|
provider = authz.NewDefaultProvider(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. (Optional) Fetch JWKS if you want local JWT validation
|
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
||||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||||
logger.Error("Failed to fetch JWKS: %v", err)
|
log.Fatalf("Failed to fetch JWKS: %v", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Build the main router
|
// 4. Build the main router
|
||||||
mux := proxy.NewRouter(cfg, provider)
|
mux := proxy.NewRouter(cfg, provider)
|
||||||
|
|
||||||
listen_address := fmt.Sprintf("0.0.0.0:%d", cfg.ListenPort)
|
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||||
|
|
||||||
// 6. Start the server
|
// 5. Start the server
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
|
||||||
Addr: listen_address,
|
Addr: listen_address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
logger.Info("Server listening on %s", listen_address)
|
log.Printf("Server listening on %s", listen_address)
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
logger.Error("Server error: %v", err)
|
log.Fatalf("Server error: %v", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 7. Wait for shutdown signal
|
// 6. Graceful shutdown on Ctrl+C
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(stop, os.Interrupt)
|
||||||
<-stop
|
<-stop
|
||||||
logger.Info("Shutting down...")
|
log.Println("Shutting down...")
|
||||||
|
|
||||||
// 8. First terminate subprocess if running
|
|
||||||
if procManager != nil && procManager.IsRunning() {
|
|
||||||
procManager.Shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 9. Then shutdown the server
|
|
||||||
logger.Info("Shutting down HTTP server...")
|
|
||||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
logger.Error("HTTP server shutdown error: %v", err)
|
log.Printf("Shutdown error: %v", err)
|
||||||
}
|
}
|
||||||
logger.Info("Stopped.")
|
log.Println("Stopped.")
|
||||||
}
|
}
|
||||||
|
|
68
config.yaml
68
config.yaml
|
@ -1,28 +1,22 @@
|
||||||
# config.yaml
|
# config.yaml
|
||||||
|
|
||||||
# Common configuration for all transport modes
|
mcp_server_base_url: ""
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
base_url: "http://localhost:8000" # Base URL for the MCP server
|
|
||||||
port: 8000 # Port for the MCP server
|
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
|
mcp_paths:
|
||||||
|
- /messages/
|
||||||
|
- /sse
|
||||||
|
|
||||||
# Transport mode configuration
|
path_mapping:
|
||||||
transport_mode: "stdio" # Options: "sse" or "stdio"
|
/token: /token
|
||||||
|
/register: /register
|
||||||
|
/authorize: /authorize
|
||||||
|
/.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server
|
||||||
|
|
||||||
# stdio-specific configuration (used only when transport_mode is "stdio")
|
|
||||||
stdio:
|
|
||||||
enabled: true
|
|
||||||
user_command: uvx mcp-server-time --local-timezone=Europe/Zurich
|
|
||||||
#user_command: "npx -y @modelcontextprotocol/server-github"
|
|
||||||
work_dir: "" # Working directory (optional)
|
|
||||||
# env: # Environment variables (optional)
|
|
||||||
# - "NODE_ENV=development"
|
|
||||||
|
|
||||||
# CORS settings
|
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
- "http://localhost:6274" # Origin of your frontend/client app
|
- ""
|
||||||
allowed_methods:
|
allowed_methods:
|
||||||
- "GET"
|
- "GET"
|
||||||
- "POST"
|
- "POST"
|
||||||
|
@ -31,26 +25,29 @@ cors:
|
||||||
allowed_headers:
|
allowed_headers:
|
||||||
- "Authorization"
|
- "Authorization"
|
||||||
- "Content-Type"
|
- "Content-Type"
|
||||||
- "mcp-protocol-version"
|
|
||||||
allow_credentials: true
|
allow_credentials: true
|
||||||
|
|
||||||
# Keycloak endpoint path mappings
|
demo:
|
||||||
path_mapping:
|
org_name: "openmcpauthdemo"
|
||||||
sse: "/sse" # SSE endpoint path
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
messages: "/messages/" # Messages endpoint path
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
/token: /realms/master/protocol/openid-connect/token
|
|
||||||
/register: /realms/master/clients-registrations/openid-connect
|
asgardeo:
|
||||||
|
org_name: "<org_name>"
|
||||||
|
client_id: "<client_id>"
|
||||||
|
client_secret: "<client_secret>"
|
||||||
|
|
||||||
# Keycloak configuration block
|
|
||||||
default:
|
default:
|
||||||
base_url: "https://iam.phoenix-systems.ch"
|
base_url: "<base_url>"
|
||||||
jwks_url: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs"
|
jwks_url: "<jwks_url>"
|
||||||
path:
|
path:
|
||||||
/.well-known/oauth-authorization-server:
|
/.well-known/oauth-authorization-server:
|
||||||
response:
|
response:
|
||||||
issuer: "https://iam.phoenix-systems.ch/realms/kvant"
|
issuer: "<issuer>"
|
||||||
jwks_uri: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs"
|
jwks_uri: "<jwks_uri>"
|
||||||
authorization_endpoint: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/auth"
|
authorization_endpoint: "<authorization_endpoint>" # Optional
|
||||||
|
token_endpoint: "<token_endpoint>" # Optional
|
||||||
|
registration_endpoint: "<registration_endpoint>" # Optional
|
||||||
response_types_supported:
|
response_types_supported:
|
||||||
- "code"
|
- "code"
|
||||||
grant_types_supported:
|
grant_types_supported:
|
||||||
|
@ -59,7 +56,16 @@ default:
|
||||||
code_challenge_methods_supported:
|
code_challenge_methods_supported:
|
||||||
- "S256"
|
- "S256"
|
||||||
- "plain"
|
- "plain"
|
||||||
|
/authroize:
|
||||||
|
addQueryParams:
|
||||||
|
- name: "<name>"
|
||||||
|
value: "<value>"
|
||||||
/token:
|
/token:
|
||||||
addBodyParams:
|
addBodyParams:
|
||||||
- name: "audience"
|
- name: "<name>"
|
||||||
value: "mcp_proxy"
|
value: "<value>"
|
||||||
|
/register:
|
||||||
|
addBodyParams:
|
||||||
|
- name: "<name>"
|
||||||
|
value: "<value>"
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ This guide will help you configure Open MCP Auth Proxy to use Auth0 as your iden
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
- An Auth0 organization (sign up [here](https://auth0.com) if you don't have one)
|
- An Auth0 organization (sign up here if you don't have one)
|
||||||
- Open MCP Auth Proxy installed
|
- Open MCP Auth Proxy installed
|
||||||
|
|
||||||
### Setting Up Auth0
|
### Setting Up Auth0
|
||||||
|
@ -28,17 +28,9 @@ Update your `config.yaml` with Auth0 settings:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Basic proxy configuration
|
# Basic proxy configuration
|
||||||
|
mcp_server_base_url: "http://localhost:8000"
|
||||||
listen_port: 8080
|
listen_port: 8080
|
||||||
base_url: "http://localhost:8000"
|
timeout_seconds: 10
|
||||||
port: 8000
|
|
||||||
|
|
||||||
# Path configuration
|
|
||||||
paths:
|
|
||||||
sse: "/sse"
|
|
||||||
messages: "/messages/"
|
|
||||||
|
|
||||||
# Transport mode
|
|
||||||
transport_mode: "sse"
|
|
||||||
|
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
cors:
|
cors:
|
|
@ -1,92 +0,0 @@
|
||||||
## Integrating Open MCP Auth Proxy with Keycloak
|
|
||||||
|
|
||||||
This guide walks you through configuring the Open MCP Auth Proxy to authenticate using Keycloak as the identity provider.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
Before you begin, ensure you have the following:
|
|
||||||
|
|
||||||
- A running Keycloak instance
|
|
||||||
- Open MCP Auth Proxy installed and accessible
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Step 1: Configure Keycloak for Client Registration
|
|
||||||
|
|
||||||
Set up dynamic client registration in your Keycloak realm by following the [Keycloak client registration guide](https://www.keycloak.org/securing-apps/client-registration).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Step 2: Configure Open MCP Auth Proxy
|
|
||||||
|
|
||||||
Update the `config.yaml` file in your Open MCP Auth Proxy setup using your Keycloak realm's [OIDC settings](https://www.keycloak.org/securing-apps/oidc-layers). Below is an example configuration:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Proxy server configuration
|
|
||||||
listen_port: 8081 # Port for the auth proxy
|
|
||||||
base_url: "http://localhost:8000" # Base URL of the MCP server
|
|
||||||
port: 8000 # MCP server port
|
|
||||||
|
|
||||||
# Define path mappings
|
|
||||||
paths:
|
|
||||||
sse: "/sse"
|
|
||||||
messages: "/messages/"
|
|
||||||
|
|
||||||
# Set the transport mode
|
|
||||||
transport_mode: "sse"
|
|
||||||
|
|
||||||
# CORS settings
|
|
||||||
cors:
|
|
||||||
allowed_origins:
|
|
||||||
- "http://localhost:5173" # Origin of your frontend/client app
|
|
||||||
allowed_methods:
|
|
||||||
- "GET"
|
|
||||||
- "POST"
|
|
||||||
- "PUT"
|
|
||||||
- "DELETE"
|
|
||||||
allowed_headers:
|
|
||||||
- "Authorization"
|
|
||||||
- "Content-Type"
|
|
||||||
- "mcp-protocol-version"
|
|
||||||
allow_credentials: true
|
|
||||||
|
|
||||||
# Keycloak endpoint path mappings
|
|
||||||
path_mapping:
|
|
||||||
/token: /realms/master/protocol/openid-connect/token
|
|
||||||
/register: /realms/master/clients-registrations/openid-connect
|
|
||||||
|
|
||||||
# Keycloak configuration block
|
|
||||||
default:
|
|
||||||
base_url: "http://localhost:8080"
|
|
||||||
jwks_url: "http://localhost:8080/realms/master/protocol/openid-connect/certs"
|
|
||||||
path:
|
|
||||||
/.well-known/oauth-authorization-server:
|
|
||||||
response:
|
|
||||||
issuer: "http://localhost:8080/realms/master"
|
|
||||||
jwks_uri: "http://localhost:8080/realms/master/protocol/openid-connect/certs"
|
|
||||||
authorization_endpoint: "http://localhost:8080/realms/master/protocol/openid-connect/auth"
|
|
||||||
response_types_supported:
|
|
||||||
- "code"
|
|
||||||
grant_types_supported:
|
|
||||||
- "authorization_code"
|
|
||||||
- "refresh_token"
|
|
||||||
code_challenge_methods_supported:
|
|
||||||
- "S256"
|
|
||||||
- "plain"
|
|
||||||
/token:
|
|
||||||
addBodyParams:
|
|
||||||
- name: "audience"
|
|
||||||
value: "mcp_proxy"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Start the Auth Proxy
|
|
||||||
|
|
||||||
Launch the proxy with the updated Keycloak configuration:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./openmcpauthproxy
|
|
||||||
```
|
|
||||||
|
|
||||||
Once running, the proxy will handle authentication requests through your configured Keycloak realm.
|
|
2
go.mod
2
go.mod
|
@ -1,6 +1,6 @@
|
||||||
module github.com/wso2/open-mcp-auth-proxy
|
module github.com/wso2/open-mcp-auth-proxy
|
||||||
|
|
||||||
go 1.21
|
go 1.22.3
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -1,6 +0,0 @@
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
|
|
@ -7,13 +7,13 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type asgardeoProvider struct {
|
type asgardeoProvider struct {
|
||||||
|
@ -31,7 +31,6 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
@ -71,9 +70,8 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
logger.Error("Error encoding well-known: %v", err)
|
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -85,7 +83,6 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
|
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
@ -98,7 +95,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
|
|
||||||
var regReq RegisterRequest
|
var regReq RegisterRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
||||||
logger.Error("Reading register request: %v", err)
|
log.Printf("ERROR: reading register request: %v", err)
|
||||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -112,7 +109,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
regReq.ClientSecret = randomString(16)
|
regReq.ClientSecret = randomString(16)
|
||||||
|
|
||||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||||
logger.Warn("Asgardeo application creation failed: %v", err)
|
log.Printf("WARN: Asgardeo application creation failed: %v", err)
|
||||||
// Optionally http.Error(...) if you want to fail
|
// Optionally http.Error(...) if you want to fail
|
||||||
// or continue to return partial data.
|
// or continue to return partial data.
|
||||||
}
|
}
|
||||||
|
@ -127,10 +124,9 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
logger.Error("Encoding /register response: %v", err)
|
log.Printf("ERROR: encoding /register response: %v", err)
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -190,7 +186,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err
|
||||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID)
|
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,12 +202,9 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
// Sensitive data - should not be logged at INFO level
|
|
||||||
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
||||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||||
|
|
||||||
logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID)
|
|
||||||
|
|
||||||
tr := &http.Transport{
|
tr := &http.Transport{
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
|
@ -241,10 +234,6 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||||
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't log the actual token at info level, only at debug level
|
|
||||||
logger.Debug("Received access token: %s", tokenResp.AccessToken)
|
|
||||||
logger.Info("Successfully obtained admin token from Asgardeo")
|
|
||||||
|
|
||||||
return tokenResp.AccessToken, nil
|
return tokenResp.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type defaultProvider struct {
|
type defaultProvider struct {
|
||||||
|
@ -82,7 +81,6 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
logger.Error("Error encoding well-known response: %v", err)
|
|
||||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,125 +0,0 @@
|
||||||
package authz
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewDefaultProvider(t *testing.T) {
|
|
||||||
cfg := &config.Config{}
|
|
||||||
provider := NewDefaultProvider(cfg)
|
|
||||||
|
|
||||||
if provider == nil {
|
|
||||||
t.Fatal("Expected non-nil provider")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure it implements the Provider interface
|
|
||||||
var _ Provider = provider
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultProviderWellKnownHandler(t *testing.T) {
|
|
||||||
// Create a config with a custom well-known response
|
|
||||||
cfg := &config.Config{
|
|
||||||
Default: config.DefaultConfig{
|
|
||||||
Path: map[string]config.PathConfig{
|
|
||||||
"/.well-known/oauth-authorization-server": {
|
|
||||||
Response: &config.ResponseConfig{
|
|
||||||
Issuer: "https://test-issuer.com",
|
|
||||||
JwksURI: "https://test-issuer.com/jwks",
|
|
||||||
ResponseTypesSupported: []string{"code"},
|
|
||||||
GrantTypesSupported: []string{"authorization_code"},
|
|
||||||
CodeChallengeMethodsSupported: []string{"S256"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
provider := NewDefaultProvider(cfg)
|
|
||||||
handler := provider.WellKnownHandler()
|
|
||||||
|
|
||||||
// Create a test request
|
|
||||||
req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)
|
|
||||||
req.Host = "test-host.com"
|
|
||||||
req.Header.Set("X-Forwarded-Proto", "https")
|
|
||||||
|
|
||||||
// Create a response recorder
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Call the handler
|
|
||||||
handler(w, req)
|
|
||||||
|
|
||||||
// Check response status
|
|
||||||
if w.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status OK, got %v", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify content type
|
|
||||||
contentType := w.Header().Get("Content-Type")
|
|
||||||
if contentType != "application/json" {
|
|
||||||
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode and check the response body
|
|
||||||
var response map[string]interface{}
|
|
||||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
|
||||||
t.Fatalf("Failed to decode response JSON: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check expected values
|
|
||||||
if response["issuer"] != "https://test-issuer.com" {
|
|
||||||
t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"])
|
|
||||||
}
|
|
||||||
if response["jwks_uri"] != "https://test-issuer.com/jwks" {
|
|
||||||
t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"])
|
|
||||||
}
|
|
||||||
if response["authorization_endpoint"] != "https://test-host.com/authorize" {
|
|
||||||
t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultProviderHandleOPTIONS(t *testing.T) {
|
|
||||||
provider := NewDefaultProvider(&config.Config{})
|
|
||||||
handler := provider.WellKnownHandler()
|
|
||||||
|
|
||||||
// Create OPTIONS request
|
|
||||||
req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Call the handler
|
|
||||||
handler(w, req)
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
if w.Code != http.StatusNoContent {
|
|
||||||
t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check CORS headers
|
|
||||||
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
|
|
||||||
t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin"))
|
|
||||||
}
|
|
||||||
if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" {
|
|
||||||
t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultProviderInvalidMethod(t *testing.T) {
|
|
||||||
provider := NewDefaultProvider(&config.Config{})
|
|
||||||
handler := provider.WellKnownHandler()
|
|
||||||
|
|
||||||
// Create POST request (which should be rejected)
|
|
||||||
req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Call the handler
|
|
||||||
handler(w, req)
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
if w.Code != http.StatusMethodNotAllowed {
|
|
||||||
t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,35 +1,12 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Transport mode for MCP server
|
// AsgardeoConfig groups all Asgardeo-specific fields
|
||||||
type TransportMode string
|
|
||||||
|
|
||||||
const (
|
|
||||||
SSETransport TransportMode = "sse"
|
|
||||||
StdioTransport TransportMode = "stdio"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Common path configuration for all transport modes
|
|
||||||
type PathsConfig struct {
|
|
||||||
SSE string `yaml:"sse"`
|
|
||||||
Messages string `yaml:"messages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// StdioConfig contains stdio-specific configuration
|
|
||||||
type StdioConfig struct {
|
|
||||||
Enabled bool `yaml:"enabled"`
|
|
||||||
UserCommand string `yaml:"user_command"` // The command provided by the user
|
|
||||||
WorkDir string `yaml:"work_dir"` // Working directory (optional)
|
|
||||||
Args []string `yaml:"args,omitempty"` // Additional arguments
|
|
||||||
Env []string `yaml:"env,omitempty"` // Environment variables
|
|
||||||
}
|
|
||||||
|
|
||||||
type DemoConfig struct {
|
type DemoConfig struct {
|
||||||
ClientID string `yaml:"client_id"`
|
ClientID string `yaml:"client_id"`
|
||||||
ClientSecret string `yaml:"client_secret"`
|
ClientSecret string `yaml:"client_secret"`
|
||||||
|
@ -83,18 +60,15 @@ type DefaultConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string
|
AuthServerBaseURL string
|
||||||
ListenPort int `yaml:"listen_port"`
|
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
||||||
BaseURL string `yaml:"base_url"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
Port int `yaml:"port"`
|
JWKSURL string
|
||||||
JWKSURL string
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
MCPPaths []string `yaml:"mcp_paths"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
Mode string `yaml:"mode"`
|
Mode string `yaml:"mode"`
|
||||||
CORSConfig CORSConfig `yaml:"cors"`
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
TransportMode TransportMode `yaml:"transport_mode"`
|
|
||||||
Paths PathsConfig `yaml:"paths"`
|
|
||||||
Stdio StdioConfig `yaml:"stdio"`
|
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
|
@ -102,56 +76,6 @@ type Config struct {
|
||||||
Default DefaultConfig `yaml:"default"`
|
Default DefaultConfig `yaml:"default"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate checks if the config is valid based on transport mode
|
|
||||||
func (c *Config) Validate() error {
|
|
||||||
// Validate based on transport mode
|
|
||||||
if c.TransportMode == StdioTransport {
|
|
||||||
if !c.Stdio.Enabled {
|
|
||||||
return fmt.Errorf("stdio.enabled must be true in stdio transport mode")
|
|
||||||
}
|
|
||||||
if c.Stdio.UserCommand == "" {
|
|
||||||
return fmt.Errorf("stdio.user_command is required in stdio transport mode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate paths
|
|
||||||
if c.Paths.SSE == "" {
|
|
||||||
c.Paths.SSE = "/sse" // Default value
|
|
||||||
}
|
|
||||||
if c.Paths.Messages == "" {
|
|
||||||
c.Paths.Messages = "/messages" // Default value
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate base URL
|
|
||||||
if c.BaseURL == "" {
|
|
||||||
if c.Port > 0 {
|
|
||||||
c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port)
|
|
||||||
} else {
|
|
||||||
c.BaseURL = "http://localhost:8000" // Default value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMCPPaths returns the list of paths that should be proxied to the MCP server
|
|
||||||
func (c *Config) GetMCPPaths() []string {
|
|
||||||
return []string{c.Paths.SSE, c.Paths.Messages}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildExecCommand constructs the full command string for execution in stdio mode
|
|
||||||
func (c *Config) BuildExecCommand() string {
|
|
||||||
if c.Stdio.UserCommand == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct the full command
|
|
||||||
return fmt.Sprintf(
|
|
||||||
`npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`,
|
|
||||||
c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadConfig reads a YAML config file into Config struct.
|
// LoadConfig reads a YAML config file into Config struct.
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (*Config, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
|
@ -165,26 +89,8 @@ func LoadConfig(path string) (*Config, error) {
|
||||||
if err := decoder.Decode(&cfg); err != nil {
|
if err := decoder.Decode(&cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default values
|
|
||||||
if cfg.TimeoutSeconds == 0 {
|
if cfg.TimeoutSeconds == 0 {
|
||||||
cfg.TimeoutSeconds = 15 // default
|
cfg.TimeoutSeconds = 15 // default
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set default transport mode if not specified
|
|
||||||
if cfg.TransportMode == "" {
|
|
||||||
cfg.TransportMode = SSETransport // Default to SSE
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set default port if not specified
|
|
||||||
if cfg.Port == 0 {
|
|
||||||
cfg.Port = 8000 // default
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate the configuration
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,196 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoadConfig(t *testing.T) {
|
|
||||||
// Create a temporary config file
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
configPath := filepath.Join(tempDir, "test_config.yaml")
|
|
||||||
|
|
||||||
// Basic valid config
|
|
||||||
validConfig := `
|
|
||||||
listen_port: 8080
|
|
||||||
base_url: "http://localhost:8000"
|
|
||||||
transport_mode: "sse"
|
|
||||||
paths:
|
|
||||||
sse: "/sse"
|
|
||||||
messages: "/messages"
|
|
||||||
cors:
|
|
||||||
allowed_origins:
|
|
||||||
- "http://localhost:5173"
|
|
||||||
allowed_methods:
|
|
||||||
- "GET"
|
|
||||||
- "POST"
|
|
||||||
allowed_headers:
|
|
||||||
- "Authorization"
|
|
||||||
- "Content-Type"
|
|
||||||
allow_credentials: true
|
|
||||||
`
|
|
||||||
err := os.WriteFile(configPath, []byte(validConfig), 0644)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test config file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test loading the valid config
|
|
||||||
cfg, err := LoadConfig(configPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to load valid config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify expected values from the config
|
|
||||||
if cfg.ListenPort != 8080 {
|
|
||||||
t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort)
|
|
||||||
}
|
|
||||||
if cfg.BaseURL != "http://localhost:8000" {
|
|
||||||
t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL)
|
|
||||||
}
|
|
||||||
if cfg.TransportMode != SSETransport {
|
|
||||||
t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode)
|
|
||||||
}
|
|
||||||
if cfg.Paths.SSE != "/sse" {
|
|
||||||
t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE)
|
|
||||||
}
|
|
||||||
if cfg.Paths.Messages != "/messages" {
|
|
||||||
t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test default values
|
|
||||||
if cfg.TimeoutSeconds != 15 {
|
|
||||||
t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds)
|
|
||||||
}
|
|
||||||
if cfg.Port != 8000 {
|
|
||||||
t.Errorf("Expected default Port=8000, got %d", cfg.Port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
config Config
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid SSE config",
|
|
||||||
config: Config{
|
|
||||||
TransportMode: SSETransport,
|
|
||||||
Paths: PathsConfig{
|
|
||||||
SSE: "/sse",
|
|
||||||
Messages: "/messages",
|
|
||||||
},
|
|
||||||
BaseURL: "http://localhost:8000",
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid stdio config",
|
|
||||||
config: Config{
|
|
||||||
TransportMode: StdioTransport,
|
|
||||||
Stdio: StdioConfig{
|
|
||||||
Enabled: true,
|
|
||||||
UserCommand: "some-command",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid stdio config - not enabled",
|
|
||||||
config: Config{
|
|
||||||
TransportMode: StdioTransport,
|
|
||||||
Stdio: StdioConfig{
|
|
||||||
Enabled: false,
|
|
||||||
UserCommand: "some-command",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid stdio config - no command",
|
|
||||||
config: Config{
|
|
||||||
TransportMode: StdioTransport,
|
|
||||||
Stdio: StdioConfig{
|
|
||||||
Enabled: true,
|
|
||||||
UserCommand: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := tc.config.Validate()
|
|
||||||
if tc.expectError && err == nil {
|
|
||||||
t.Errorf("Expected validation error but got none")
|
|
||||||
}
|
|
||||||
if !tc.expectError && err != nil {
|
|
||||||
t.Errorf("Expected no validation error but got: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetMCPPaths(t *testing.T) {
|
|
||||||
cfg := Config{
|
|
||||||
Paths: PathsConfig{
|
|
||||||
SSE: "/custom-sse",
|
|
||||||
Messages: "/custom-messages",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
paths := cfg.GetMCPPaths()
|
|
||||||
if len(paths) != 2 {
|
|
||||||
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
|
|
||||||
}
|
|
||||||
if paths[0] != "/custom-sse" {
|
|
||||||
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
|
|
||||||
}
|
|
||||||
if paths[1] != "/custom-messages" {
|
|
||||||
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildExecCommand(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
config Config
|
|
||||||
expectedResult string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid command",
|
|
||||||
config: Config{
|
|
||||||
Stdio: StdioConfig{
|
|
||||||
UserCommand: "test-command",
|
|
||||||
},
|
|
||||||
Port: 8080,
|
|
||||||
BaseURL: "http://example.com",
|
|
||||||
Paths: PathsConfig{
|
|
||||||
SSE: "/sse-path",
|
|
||||||
Messages: "/msgs",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty command",
|
|
||||||
config: Config{
|
|
||||||
Stdio: StdioConfig{
|
|
||||||
UserCommand: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedResult: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result := tc.config.BuildExecCommand()
|
|
||||||
if result != tc.expectedResult {
|
|
||||||
t.Errorf("Expected command=%s, got %s", tc.expectedResult, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,34 +0,0 @@
|
||||||
package logger
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
)
|
|
||||||
|
|
||||||
var isDebug = false
|
|
||||||
|
|
||||||
// SetDebug enables or disables debug logging
|
|
||||||
func SetDebug(debug bool) {
|
|
||||||
isDebug = debug
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug logs a debug-level message
|
|
||||||
func Debug(format string, v ...interface{}) {
|
|
||||||
if isDebug {
|
|
||||||
log.Printf("DEBUG: "+format, v...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info logs an info-level message
|
|
||||||
func Info(format string, v ...interface{}) {
|
|
||||||
log.Printf("INFO: "+format, v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warn logs a warning-level message
|
|
||||||
func Warn(format string, v ...interface{}) {
|
|
||||||
log.Printf("WARN: "+format, v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error logs an error-level message
|
|
||||||
func Error(format string, v ...interface{}) {
|
|
||||||
log.Printf("ERROR: "+format, v...)
|
|
||||||
}
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestModifier modifies requests before they are proxied
|
// RequestModifier modifies requests before they are proxied
|
||||||
|
@ -149,7 +148,6 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
||||||
// Parse form data
|
// Parse form data
|
||||||
if err := req.ParseForm(); err != nil {
|
if err := req.ParseForm(); err != nil {
|
||||||
logger.Error("Failed to parse form data: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,14 +169,12 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
// Read body
|
// Read body
|
||||||
bodyBytes, err := io.ReadAll(req.Body)
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to read request body: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse JSON
|
// Parse JSON
|
||||||
var jsonData map[string]interface{}
|
var jsonData map[string]interface{}
|
||||||
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
||||||
logger.Error("Failed to parse JSON body: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,7 +186,6 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro
|
||||||
// Marshal back to JSON
|
// Marshal back to JSON
|
||||||
modifiedBody, err := json.Marshal(jsonData)
|
modifiedBody, err := json.Marshal(jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to marshal modified JSON: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,147 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAuthorizationModifier(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Default: config.DefaultConfig{
|
|
||||||
Path: map[string]config.PathConfig{
|
|
||||||
"/authorize": {
|
|
||||||
AddQueryParams: []config.ParamConfig{
|
|
||||||
{Name: "client_id", Value: "test-client-id"},
|
|
||||||
{Name: "scope", Value: "openid"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
modifier := &AuthorizationModifier{Config: cfg}
|
|
||||||
|
|
||||||
// Create a test request
|
|
||||||
req, err := http.NewRequest("GET", "/authorize?response_type=code", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modify the request
|
|
||||||
modifiedReq, err := modifier.ModifyRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ModifyRequest failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the query parameters were added
|
|
||||||
query := modifiedReq.URL.Query()
|
|
||||||
if query.Get("client_id") != "test-client-id" {
|
|
||||||
t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id"))
|
|
||||||
}
|
|
||||||
if query.Get("scope") != "openid" {
|
|
||||||
t.Errorf("Expected scope=openid, got %s", query.Get("scope"))
|
|
||||||
}
|
|
||||||
if query.Get("response_type") != "code" {
|
|
||||||
t.Errorf("Expected response_type=code, got %s", query.Get("response_type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTokenModifier(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Default: config.DefaultConfig{
|
|
||||||
Path: map[string]config.PathConfig{
|
|
||||||
"/token": {
|
|
||||||
AddBodyParams: []config.ParamConfig{
|
|
||||||
{Name: "audience", Value: "test-audience"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
modifier := &TokenModifier{Config: cfg}
|
|
||||||
|
|
||||||
// Create a test request with form data
|
|
||||||
form := url.Values{}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
// Modify the request
|
|
||||||
modifiedReq, err := modifier.ModifyRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ModifyRequest failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
body := make([]byte, 1024)
|
|
||||||
n, err := modifiedReq.Body.Read(body)
|
|
||||||
if err != nil && err.Error() != "EOF" {
|
|
||||||
t.Fatalf("Failed to read body: %v", err)
|
|
||||||
}
|
|
||||||
bodyStr := string(body[:n])
|
|
||||||
|
|
||||||
// Parse the form data from the modified request
|
|
||||||
if err := modifiedReq.ParseForm(); err != nil {
|
|
||||||
t.Fatalf("Failed to parse form data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the body parameters were added
|
|
||||||
if !strings.Contains(bodyStr, "audience") {
|
|
||||||
t.Errorf("Expected body to contain audience, got %s", bodyStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegisterModifier(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Default: config.DefaultConfig{
|
|
||||||
Path: map[string]config.PathConfig{
|
|
||||||
"/register": {
|
|
||||||
AddBodyParams: []config.ParamConfig{
|
|
||||||
{Name: "client_name", Value: "test-client"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
modifier := &RegisterModifier{Config: cfg}
|
|
||||||
|
|
||||||
// Create a test request with JSON data
|
|
||||||
jsonBody := `{"redirect_uris":["https://example.com/callback"]}`
|
|
||||||
req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
// Modify the request
|
|
||||||
modifiedReq, err := modifier.ModifyRequest(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ModifyRequest failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the body and check that it still contains the original data
|
|
||||||
// This test would need to be enhanced with a proper JSON parsing to verify
|
|
||||||
// the added parameters
|
|
||||||
body := make([]byte, 1024)
|
|
||||||
n, err := modifiedReq.Body.Read(body)
|
|
||||||
if err != nil && err.Error() != "EOF" {
|
|
||||||
t.Fatalf("Failed to read body: %v", err)
|
|
||||||
}
|
|
||||||
bodyStr := string(body[:n])
|
|
||||||
|
|
||||||
// Simple check to see if the modified body contains the expected fields
|
|
||||||
if !strings.Contains(bodyStr, "client_name") {
|
|
||||||
t.Errorf("Expected body to contain client_name, got %s", bodyStr)
|
|
||||||
}
|
|
||||||
if !strings.Contains(bodyStr, "redirect_uris") {
|
|
||||||
t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,6 +2,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -10,7 +11,6 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -82,8 +82,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MCP paths
|
// MCP paths
|
||||||
mcpPaths := cfg.GetMCPPaths()
|
for _, path := range cfg.MCPPaths {
|
||||||
for _, path := range mcpPaths {
|
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||||
registeredPaths[path] = true
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
|
@ -101,21 +100,23 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
|
|
||||||
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
||||||
// Parse the base URLs up front
|
// Parse the base URLs up front
|
||||||
|
|
||||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid auth server URL: %v", err)
|
log.Fatalf("Invalid auth server URL: %v", err)
|
||||||
panic(err) // Fatal error that prevents startup
|
|
||||||
}
|
}
|
||||||
|
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||||
mcpBase, err := url.Parse(cfg.BaseURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid MCP server URL: %v", err)
|
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||||
panic(err) // Fatal error that prevents startup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect SSE paths from config
|
// Detect SSE paths from config
|
||||||
ssePaths := make(map[string]bool)
|
ssePaths := make(map[string]bool)
|
||||||
ssePaths[cfg.Paths.SSE] = true
|
for _, p := range cfg.MCPPaths {
|
||||||
|
if p == "/sse" {
|
||||||
|
ssePaths[p] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
@ -123,7 +124,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Handle OPTIONS
|
// Handle OPTIONS
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
if allowedOrigin == "" {
|
if allowedOrigin == "" {
|
||||||
logger.Warn("Preflight request from disallowed origin: %s", origin)
|
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
|
||||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -133,7 +134,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowedOrigin == "" {
|
if allowedOrigin == "" {
|
||||||
logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||||
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -151,7 +152,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
// Validate JWT for MCP paths if required
|
// Validate JWT for MCP paths if required
|
||||||
// Placeholder for JWT validation logic
|
// Placeholder for JWT validation logic
|
||||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||||
logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err)
|
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -169,7 +170,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
var err error
|
var err error
|
||||||
r, err = modifier.ModifyRequest(r)
|
r, err = modifier.ModifyRequest(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error modifying request: %v", err)
|
log.Printf("[proxy] Error modifying request: %v", err)
|
||||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -192,12 +193,6 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
|
|
||||||
cleanHeaders := http.Header{}
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
// Set proper origin header to match the target
|
|
||||||
if isSSE {
|
|
||||||
// For SSE, ensure origin matches the target
|
|
||||||
req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if skipHeader(k) {
|
if skipHeader(k) {
|
||||||
|
@ -210,33 +205,21 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier)
|
||||||
|
|
||||||
req.Header = cleanHeaders
|
req.Header = cleanHeaders
|
||||||
|
|
||||||
logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||||
},
|
},
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||||
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
logger.Error("Error proxying: %v", err)
|
log.Printf("[proxy] Error proxying: %v", err)
|
||||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||||
},
|
},
|
||||||
FlushInterval: -1, // immediate flush for SSE
|
FlushInterval: -1, // immediate flush for SSE
|
||||||
}
|
}
|
||||||
|
|
||||||
if isSSE {
|
if isSSE {
|
||||||
// Add special response handling for SSE connections to rewrite endpoint URLs
|
|
||||||
rp.Transport = &sseTransport{
|
|
||||||
Transport: http.DefaultTransport,
|
|
||||||
proxyHost: r.Host,
|
|
||||||
targetHost: targetURL.Host,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set SSE-specific headers
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
|
|
||||||
// Keep SSE connections open
|
// Keep SSE connections open
|
||||||
HandleSSE(w, r, rp)
|
HandleSSE(w, r, rp)
|
||||||
} else {
|
} else {
|
||||||
|
@ -253,7 +236,6 @@ func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
}
|
}
|
||||||
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||||
logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed)
|
|
||||||
if allowed == origin {
|
if allowed == origin {
|
||||||
return allowed
|
return allowed
|
||||||
}
|
}
|
||||||
|
@ -274,7 +256,6 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
}
|
}
|
||||||
w.Header().Set("Vary", "Origin")
|
w.Header().Set("Vary", "Origin")
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAuthPath(path string) bool {
|
func isAuthPath(path string) bool {
|
||||||
|
@ -292,8 +273,7 @@ func isAuthPath(path string) bool {
|
||||||
|
|
||||||
// isMCPPath checks if the path is an MCP path
|
// isMCPPath checks if the path is an MCP path
|
||||||
func isMCPPath(path string, cfg *config.Config) bool {
|
func isMCPPath(path string, cfg *config.Config) bool {
|
||||||
mcpPaths := cfg.GetMCPPaths()
|
for _, p := range cfg.MCPPaths {
|
||||||
for _, p := range mcpPaths {
|
|
||||||
if strings.HasPrefix(path, p) {
|
if strings.HasPrefix(path, p) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,11 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"log"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandleSSE sets up a go-routine to wait for context cancellation
|
// HandleSSE sets up a go-routine to wait for context cancellation
|
||||||
|
@ -21,7 +16,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -37,73 +32,3 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
|
||||||
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||||
return context.WithTimeout(context.Background(), timeout)
|
return context.WithTimeout(context.Background(), timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
|
|
||||||
type sseTransport struct {
|
|
||||||
Transport http.RoundTripper
|
|
||||||
proxyHost string
|
|
||||||
targetHost string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
// Call the underlying transport
|
|
||||||
resp, err := t.Transport.RoundTrip(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is an SSE response
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
if !strings.Contains(contentType, "text/event-stream") {
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Intercepting SSE response to modify endpoint events")
|
|
||||||
|
|
||||||
// Create a response wrapper that modifies the response body
|
|
||||||
originalBody := resp.Body
|
|
||||||
pr, pw := io.Pipe()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer originalBody.Close()
|
|
||||||
defer pw.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(originalBody)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
|
|
||||||
// Check if this line contains an endpoint event
|
|
||||||
if strings.HasPrefix(line, "event: endpoint") {
|
|
||||||
// Read the data line
|
|
||||||
if scanner.Scan() {
|
|
||||||
dataLine := scanner.Text()
|
|
||||||
if strings.HasPrefix(dataLine, "data: ") {
|
|
||||||
// Extract the endpoint URL
|
|
||||||
endpoint := strings.TrimPrefix(dataLine, "data: ")
|
|
||||||
|
|
||||||
// Replace the host in the endpoint
|
|
||||||
logger.Debug("Original endpoint: %s", endpoint)
|
|
||||||
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
|
|
||||||
logger.Debug("Modified endpoint: %s", endpoint)
|
|
||||||
|
|
||||||
// Write the modified event lines
|
|
||||||
fmt.Fprintln(pw, line)
|
|
||||||
fmt.Fprintln(pw, "data: "+endpoint)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write the original line for non-endpoint events
|
|
||||||
fmt.Fprintln(pw, line)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
logger.Error("Error reading SSE stream: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Replace the response body with our modified pipe
|
|
||||||
resp.Body = pr
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,268 +0,0 @@
|
||||||
package subprocess
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Manager handles starting and graceful shutdown of subprocesses
|
|
||||||
type Manager struct {
|
|
||||||
process *os.Process
|
|
||||||
processGroup int
|
|
||||||
mutex sync.Mutex
|
|
||||||
cmd *exec.Cmd
|
|
||||||
shutdownDelay time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewManager creates a new subprocess manager
|
|
||||||
func NewManager() *Manager {
|
|
||||||
return &Manager{
|
|
||||||
shutdownDelay: 5 * time.Second,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnsureDependenciesAvailable checks and installs required package executors
|
|
||||||
func EnsureDependenciesAvailable(command string) error {
|
|
||||||
// Always ensure npx is available regardless of the command
|
|
||||||
if _, err := exec.LookPath("npx"); err != nil {
|
|
||||||
// npx is not available, check if npm is installed
|
|
||||||
if _, err := exec.LookPath("npm"); err != nil {
|
|
||||||
return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to install npx using npm
|
|
||||||
logger.Info("npx not found, attempting to install...")
|
|
||||||
cmd := exec.Command("npm", "install", "-g", "npx")
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
return fmt.Errorf("failed to install npx: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("npx installed successfully")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if uv is needed based on the command
|
|
||||||
if strings.Contains(command, "uv ") {
|
|
||||||
if _, err := exec.LookPath("uv"); err != nil {
|
|
||||||
return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetShutdownDelay sets the maximum time to wait for graceful shutdown
|
|
||||||
func (m *Manager) SetShutdownDelay(duration time.Duration) {
|
|
||||||
m.shutdownDelay = duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start launches a subprocess based on the configuration
|
|
||||||
func (m *Manager) Start(cfg *config.Config) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
// If a process is already running, return an error
|
|
||||||
if m.process != nil {
|
|
||||||
return os.ErrExist
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" {
|
|
||||||
return nil // Nothing to start
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the full command string
|
|
||||||
execCommand := cfg.BuildExecCommand()
|
|
||||||
if execCommand == "" {
|
|
||||||
return nil // No command to execute
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Starting subprocess with command: %s", execCommand)
|
|
||||||
|
|
||||||
// Use the shell to execute the command
|
|
||||||
cmd := exec.Command("sh", "-c", execCommand)
|
|
||||||
|
|
||||||
// Set working directory if specified
|
|
||||||
if cfg.Stdio.WorkDir != "" {
|
|
||||||
cmd.Dir = cfg.Stdio.WorkDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set environment variables if specified
|
|
||||||
if len(cfg.Stdio.Env) > 0 {
|
|
||||||
cmd.Env = append(os.Environ(), cfg.Stdio.Env...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Capture stdout/stderr
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
// Set the process group for proper termination
|
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
|
||||||
|
|
||||||
// Start the process
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.process = cmd.Process
|
|
||||||
m.cmd = cmd
|
|
||||||
logger.Info("Subprocess started with PID: %d", m.process.Pid)
|
|
||||||
|
|
||||||
// Get and store the process group ID
|
|
||||||
pgid, err := syscall.Getpgid(m.process.Pid)
|
|
||||||
if err == nil {
|
|
||||||
m.processGroup = pgid
|
|
||||||
logger.Debug("Process group ID: %d", m.processGroup)
|
|
||||||
} else {
|
|
||||||
logger.Warn("Failed to get process group ID: %v", err)
|
|
||||||
m.processGroup = m.process.Pid
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle process termination in background
|
|
||||||
go func() {
|
|
||||||
if err := cmd.Wait(); err != nil {
|
|
||||||
logger.Error("Subprocess exited with error: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Subprocess exited successfully")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the process reference when it exits
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.process = nil
|
|
||||||
m.cmd = nil
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsRunning checks if the subprocess is running
|
|
||||||
func (m *Manager) IsRunning() bool {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
return m.process != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown gracefully terminates the subprocess
|
|
||||||
func (m *Manager) Shutdown() {
|
|
||||||
m.mutex.Lock()
|
|
||||||
processToTerminate := m.process // Local copy of the process reference
|
|
||||||
processGroupToTerminate := m.processGroup
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
if processToTerminate == nil {
|
|
||||||
return // No process to terminate
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Terminating subprocess...")
|
|
||||||
terminateComplete := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(terminateComplete)
|
|
||||||
|
|
||||||
// Try graceful termination first with SIGTERM
|
|
||||||
terminatedGracefully := false
|
|
||||||
|
|
||||||
// Try to terminate the process group first
|
|
||||||
if processGroupToTerminate != 0 {
|
|
||||||
err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to send SIGTERM to process group: %v", err)
|
|
||||||
|
|
||||||
// Fallback to terminating just the process
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process != nil {
|
|
||||||
err = m.process.Signal(syscall.SIGTERM)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try to terminate just the process
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process != nil {
|
|
||||||
err := m.process.Signal(syscall.SIGTERM)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to send SIGTERM to process: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the process to exit gracefully
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process == nil {
|
|
||||||
terminatedGracefully = true
|
|
||||||
m.mutex.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if terminatedGracefully {
|
|
||||||
logger.Info("Subprocess terminated gracefully")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the process didn't exit gracefully, force kill
|
|
||||||
logger.Warn("Subprocess didn't exit gracefully, forcing termination...")
|
|
||||||
|
|
||||||
// Try to kill the process group first
|
|
||||||
if processGroupToTerminate != 0 {
|
|
||||||
if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil {
|
|
||||||
logger.Warn("Failed to send SIGKILL to process group: %v", err)
|
|
||||||
|
|
||||||
// Fallback to killing just the process
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process != nil {
|
|
||||||
if err := m.process.Kill(); err != nil {
|
|
||||||
logger.Error("Failed to kill process: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try to kill just the process
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process != nil {
|
|
||||||
if err := m.process.Kill(); err != nil {
|
|
||||||
logger.Error("Failed to kill process: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait a bit more to confirm termination
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.process == nil {
|
|
||||||
logger.Info("Subprocess terminated by force")
|
|
||||||
} else {
|
|
||||||
logger.Warn("Failed to terminate subprocess")
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for termination with timeout
|
|
||||||
select {
|
|
||||||
case <-terminateComplete:
|
|
||||||
// Termination completed
|
|
||||||
case <-time.After(m.shutdownDelay):
|
|
||||||
logger.Warn("Subprocess termination timed out")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -4,12 +4,12 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type JWKS struct {
|
type JWKS struct {
|
||||||
|
@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error {
|
||||||
publicKeys[parsedKey.Kid] = pubKey
|
publicKeys[parsedKey.Kid] = pubKey
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.Info("Loaded %d public keys.", len(publicKeys))
|
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,143 +0,0 @@
|
||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestValidateJWT(t *testing.T) {
|
|
||||||
// Initialize the test JWKS data
|
|
||||||
initTestJWKS(t)
|
|
||||||
|
|
||||||
// Test cases
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
authHeader string
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid JWT token",
|
|
||||||
authHeader: "Bearer " + createValidJWT(t),
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No auth header",
|
|
||||||
authHeader: "",
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid auth header format",
|
|
||||||
authHeader: "InvalidFormat",
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid JWT token",
|
|
||||||
authHeader: "Bearer invalid.jwt.token",
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := ValidateJWT(tc.authHeader)
|
|
||||||
if tc.expectError && err == nil {
|
|
||||||
t.Errorf("Expected error but got none")
|
|
||||||
}
|
|
||||||
if !tc.expectError && err != nil {
|
|
||||||
t.Errorf("Expected no error but got: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchJWKS(t *testing.T) {
|
|
||||||
// Create a mock JWKS server
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Generate a test RSA key
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create JWKS response
|
|
||||||
jwks := map[string]interface{}{
|
|
||||||
"keys": []map[string]interface{}{
|
|
||||||
{
|
|
||||||
"kty": "RSA",
|
|
||||||
"kid": "test-key-id",
|
|
||||||
"n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
|
|
||||||
"e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(jwks)
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// Test fetching JWKS
|
|
||||||
err := FetchJWKS(server.URL)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("FetchJWKS failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that keys were stored
|
|
||||||
if len(publicKeys) == 0 {
|
|
||||||
t.Errorf("Expected publicKeys to be populated")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to initialize test JWKS data
|
|
||||||
func initTestJWKS(t *testing.T) {
|
|
||||||
// Create a test RSA key pair
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the publicKeys map
|
|
||||||
publicKeys = map[string]*rsa.PublicKey{
|
|
||||||
"test-key-id": &privateKey.PublicKey,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to create a valid JWT token for testing
|
|
||||||
func createValidJWT(t *testing.T) string {
|
|
||||||
// Create a test RSA key pair
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the test key is in the publicKeys map
|
|
||||||
if publicKeys == nil {
|
|
||||||
publicKeys = map[string]*rsa.PublicKey{}
|
|
||||||
}
|
|
||||||
publicKeys["test-key-id"] = &privateKey.PublicKey
|
|
||||||
|
|
||||||
// Create token
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
|
||||||
"sub": "1234567890",
|
|
||||||
"name": "Test User",
|
|
||||||
"iat": time.Now().Unix(),
|
|
||||||
"exp": time.Now().Add(time.Hour).Unix(),
|
|
||||||
})
|
|
||||||
token.Header["kid"] = "test-key-id"
|
|
||||||
|
|
||||||
// Sign the token
|
|
||||||
tokenString, err := token.SignedString(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to sign token: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenString
|
|
||||||
}
|
|
|
@ -1,11 +1,52 @@
|
||||||
## Purpose
|
## Purpose
|
||||||
<!-- Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc. -->
|
> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc.
|
||||||
|
|
||||||
## Related Issues
|
## Goals
|
||||||
<!-- List any related issues -->
|
> Describe the solutions that this feature/fix will introduce to resolve the problems described above
|
||||||
|
|
||||||
|
## Approach
|
||||||
|
> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here.
|
||||||
|
|
||||||
|
## User stories
|
||||||
|
> Summary of user stories addressed by this change>
|
||||||
|
|
||||||
|
## Release note
|
||||||
|
> Brief description of the new feature or bug fix as it will appear in the release notes
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why there’s no doc impact
|
||||||
|
|
||||||
|
## Training
|
||||||
|
> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable
|
||||||
|
|
||||||
|
## Certification
|
||||||
|
> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why.
|
||||||
|
|
||||||
|
## Marketing
|
||||||
|
> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable
|
||||||
|
|
||||||
|
## Automation tests
|
||||||
|
- Unit tests
|
||||||
|
> Code coverage information
|
||||||
|
- Integration tests
|
||||||
|
> Details about the test cases and coverage
|
||||||
|
|
||||||
|
## Security checks
|
||||||
|
- Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no
|
||||||
|
- Ran FindSecurityBugs plugin and verified report? yes/no
|
||||||
|
- Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no
|
||||||
|
|
||||||
|
## Samples
|
||||||
|
> Provide high-level details about the samples related to this feature
|
||||||
|
|
||||||
## Related PRs
|
## Related PRs
|
||||||
<!-- List any other related PRs -->
|
> List any other related PRs
|
||||||
|
|
||||||
## Migrations (if applicable)
|
## Migrations (if applicable)
|
||||||
<!-- Describe migration steps and platforms on which migration has been tested -->
|
> Describe migration steps and platforms on which migration has been tested
|
||||||
|
|
||||||
|
## Test environment
|
||||||
|
> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested
|
||||||
|
|
||||||
|
## Learning
|
||||||
|
> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem.
|
|
@ -1 +0,0 @@
|
||||||
fastmcp==0.4.1
|
|
Loading…
Add table
Add a link
Reference in a new issue