diff --git a/.github/scripts/release.sh b/.github/scripts/release.sh new file mode 100644 index 0000000..2a1f6a9 --- /dev/null +++ b/.github/scripts/release.sh @@ -0,0 +1,124 @@ +#!/bin/bash + +# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). +# +# This software is the property of WSO2 LLC. and its suppliers, if any. +# Dissemination of any information or reproduction of any material contained +# herein in any form is strictly forbidden, unless permitted by WSO2 expressly. +# You may not alter or remove any copyright or other notice from copies of this content. +# + +# Exit the script on any command with non-zero exit status. +set -e +set -o pipefail + +UPSTREAM_BRANCH="main" + +# Assign command line arguments to variables. +GIT_TOKEN=$1 +WORK_DIR=$2 +VERSION_TYPE=$3 # possible values: major, minor, patch + +# Check if GIT_TOKEN is empty +if [ -z "$GIT_TOKEN" ]; then + echo "❌ Error: GIT_TOKEN is not set." + exit 1 +fi + +# Check if WORK_DIR is empty +if [ -z "$WORK_DIR" ]; then + echo "❌ Error: WORK_DIR is not set." + exit 1 +fi + +# Validate VERSION_TYPE +if [[ "$VERSION_TYPE" != "major" && "$VERSION_TYPE" != "minor" && "$VERSION_TYPE" != "patch" ]]; then + echo "❌ Error: VERSION_TYPE must be one of: major, minor, or patch." + exit 1 +fi + +BUILD_DIRECTORY="$WORK_DIR/build" +RELEASE_DIRECTORY="$BUILD_DIRECTORY/releases" + +# Navigate to the working directory. +cd "${WORK_DIR}" + +# Create the release directory. +if [ ! -d "$RELEASE_DIRECTORY" ]; then + mkdir -p "$RELEASE_DIRECTORY" +else + rm -rf "$RELEASE_DIRECTORY"/* +fi + +# Extract current version. +CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0") +IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}" + +# Determine which part to increment +case "$VERSION_TYPE" in + major) + MAJOR=$((MAJOR + 1)) + MINOR=0 + PATCH=0 + ;; + minor) + MINOR=$((MINOR + 1)) + PATCH=0 + ;; + patch|*) + PATCH=$((PATCH + 1)) + ;; +esac + +NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}" + +echo "Creating release packages for version $NEW_VERSION..." + +# List of supported OSes. +oses=("linux" "linux-arm" "darwin") + +# Navigate to the release directory. +cd "${RELEASE_DIRECTORY}" + +for os in "${oses[@]}"; do + os_dir="../$os" + + if [ -d "$os_dir" ]; then + release_artifact_folder="openmcpauthproxy_${os}-v${NEW_VERSION}" + mkdir -p "$release_artifact_folder" + + cp -r $os_dir/* "$release_artifact_folder" + + # Zip the release package. + zip_file="$release_artifact_folder.zip" + echo "Creating $zip_file..." + zip -r "$zip_file" "$release_artifact_folder" + + # Delete the folder after zipping. + rm -rf "$release_artifact_folder" + + # Generate checksum file. + sha256sum "$zip_file" | sed "s|target/releases/||" > "$zip_file.sha256" + echo "Checksum generated for the $os package." + + echo "Release packages created successfully for $os." + else + echo "Skipping $os release package creation as the build artifacts are not available." + fi +done + +echo "Release packages created successfully in $RELEASE_DIRECTORY." + +# Navigate back to the project root directory. +cd "${WORK_DIR}" + +# Collect all ZIP and .sha256 files in the target/releases directory. +FILES_TO_UPLOAD=$(find build/releases -type f \( -name "*.zip" -o -name "*.sha256" \)) + +# Create a release with the current version. +TAG_NAME="v${NEW_VERSION}" +export GITHUB_TOKEN="${GIT_TOKEN}" +gh release create "${TAG_NAME}" ${FILES_TO_UPLOAD} --title "${TAG_NAME}" --notes "OpenMCPAuthProxy - ${TAG_NAME}" --target "${UPSTREAM_BRANCH}" || { echo "Failed to create release"; exit 1; } + + +echo "Release ${TAG_NAME} created successfully." diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..775003e --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,71 @@ +name: Build and Push container +run-name: Build and Push container +on: + workflow_dispatch: + #schedule: + # - cron: "0 10 * * *" + push: + branches: + - 'main' + - 'master' + tags: + - 'v*' + pull_request: + branches: + - 'main' + - 'master' +env: + IMAGE: git.kvant.cloud/${{github.repository}} +jobs: + build_concierge_backend: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set current time + uses: https://github.com/gerred/actions/current-time@master + id: current_time + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to git.kvant.cloud registry + uses: docker/login-action@v3 + with: + registry: git.kvant.cloud + username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }} + password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }} + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + # list of Docker images to use as base name for tags + images: | + ${{env.IMAGE}} + # generate Docker tags based on the following events/attributes + tags: | + type=schedule + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + + - name: Build and push to gitea registry + uses: docker/build-push-action@v6 + with: + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + context: . + provenance: mode=max + sbom: true + build-args: | + BUILD_DATE=${{ steps.current_time.outputs.time }} + cache-from: | + type=registry,ref=${{ env.IMAGE }}:buildcache + type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }} + type=registry,ref=${{ env.IMAGE }}:main + cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true diff --git a/.github/workflows/pr-builder.yml b/.github/workflows/pr-builder.yml new file mode 100644 index 0000000..a055e0d --- /dev/null +++ b/.github/workflows/pr-builder.yml @@ -0,0 +1,62 @@ +name: Go CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + name: Test + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ['1.20', '1.21'] + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go-version }} + + - name: Get dependencies + run: go get -v -t -d ./... + + - name: Verify dependencies + run: go mod verify + + - name: Run go vet + run: go vet ./... + + - name: Run tests + run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./coverage.txt + fail_ci_if_error: false + + build: + name: Build + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ['1.20', '1.21'] + os: [ubuntu-latest, macos-latest] + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go-version }} + + - name: Build + run: go build -v ./cmd/proxy diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..0c51bc7 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,64 @@ +# +# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). +# +# This software is the property of WSO2 LLC. and its suppliers, if any. +# Dissemination of any information or reproduction of any material contained +# herein in any form is strictly forbidden, unless permitted by WSO2 expressly. +# You may not alter or remove any copyright or other notice from copies of this content. +# + +name: Release + +on: + workflow_dispatch: + inputs: + version_type: + type: choice + description: Choose the type of version update + options: + - 'major' + - 'minor' + - 'patch' + required: true + +jobs: + update-and-release: + runs-on: ubuntu-latest + env: + GOPROXY: https://proxy.golang.org + if: github.event.pull_request.merged == true || github.event_name == 'workflow_dispatch' + steps: + - uses: actions/checkout@v2 + with: + ref: 'main' + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + - uses: actions/checkout@v2 + + - name: Set up Go 1.x + uses: actions/setup-go@v3 + with: + go-version: "^1.x" + + - name: Cache Go modules + id: cache-go-modules + uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-modules-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-modules- + + - name: Install dependencies + run: go mod download + + - name: Build and test + run: make build + working-directory: . + + - name: Update artifact version, package, commit, and create release. + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: bash ./.github/scripts/release.sh $GITHUB_TOKEN ${{ github.workspace }} ${{ github.event.inputs.version_type }} diff --git a/.gitignore b/.gitignore index 6c1dd97..d200b58 100644 --- a/.gitignore +++ b/.gitignore @@ -18,15 +18,21 @@ *.zip *.tar.gz *.rar +.venv # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml hs_err_pid* replay_pid* -# Go module cache files -go.sum - # OS generated files .DS_Store -openmcpauthproxy +# builds +build + +# test out files +coverage.out +coverage.html + +# IDE files +.vscode diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..dc468b1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,48 @@ +FROM --platform=${BUILDPLATFORM:-linux/amd64} golang:1.24@sha256:d9db32125db0c3a680cfb7a1afcaefb89c898a075ec148fdc2f0f646cc2ed509 AS build + +ARG TARGETPLATFORM +ARG BUILDPLATFORM +ARG TARGETOS +ARG TARGETARCH + +WORKDIR /workspace + +RUN apt update -qq && apt install -qq -y git bash curl g++ + +# Download libraries +ADD go.* . +RUN go mod download + +# Build +ADD cmd cmd +ADD internal internal +RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o webhook -ldflags '-w -extldflags "-static"' -o openmcpauthproxy ./cmd/proxy + +#Test +RUN CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go test -v -race ./... + + +# Build production container +FROM --platform=${BUILDPLATFORM:-linux/amd64} ubuntu:24.04 + +RUN apt-get update \ + && apt-get install --no-install-recommends -y \ + python3-pip \ + python-is-python3 \ + npm \ + && apt-get autoremove \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN pip install uvenv --break-system-packages + +WORKDIR /app +COPY --from=build /workspace/openmcpauthproxy /app/ + +ADD config.yaml /app + + +ENTRYPOINT ["/app/openmcpauthproxy"] + +ARG IMAGE_SOURCE +LABEL org.opencontainers.image.source=$IMAGE_SOURCE diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b0d0926 --- /dev/null +++ b/Makefile @@ -0,0 +1,88 @@ +# Makefile for open-mcp-auth-proxy + +# Variables +PROJECT_ROOT := $(realpath $(dir $(abspath $(lastword $(MAKEFILE_LIST))))) +BINARY_NAME := openmcpauthproxy +GO := go +GOFMT := gofmt +GOVET := go vet +GOTEST := go test +GOLINT := golangci-lint +GOCOV := go tool cover +BUILD_DIR := build + +# Source files +SRC := $(shell find . -name "*.go" -not -path "./vendor/*") +PKGS := $(shell go list ./... | grep -v /vendor/) + +# Set build options +BUILD_OPTS := -v + +# Set test options +TEST_OPTS := -v -race + +.PHONY: all clean test fmt lint vet coverage help + +# Default target +all: lint test build-linux build-linux-arm build-darwin + +build: clean test build-linux build-linux-arm build-darwin + +build-linux: + mkdir -p $(BUILD_DIR)/linux + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ + -o $(BUILD_DIR)/linux/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy + cp config.yaml $(BUILD_DIR)/linux + +build-linux-arm: + mkdir -p $(BUILD_DIR)/linux-arm + GOOS=linux GOARCH=arm CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ + -o $(BUILD_DIR)/linux-arm/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy + cp config.yaml $(BUILD_DIR)/linux-arm + +build-darwin: + mkdir -p $(BUILD_DIR)/darwin + GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ + -o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy + cp config.yaml $(BUILD_DIR)/darwin + +# Clean build artifacts +clean: + @echo "Cleaning build artifacts..." + @rm -rf $(BUILD_DIR) + @rm -f coverage.out + +# Run tests +test: + @echo "Running tests..." + $(GOTEST) $(TEST_OPTS) ./... + +# Run tests with coverage report +coverage: + @echo "Running tests with coverage..." + @$(GOTEST) -coverprofile=coverage.out ./... + @$(GOCOV) -func=coverage.out + @$(GOCOV) -html=coverage.out -o coverage.html + @echo "Coverage report generated in coverage.html" + +# Run gofmt +fmt: + @echo "Running gofmt..." + @$(GOFMT) -w -s $(SRC) + +# Run go vet +vet: + @echo "Running go vet..." + @$(GOVET) ./... + +# Show help +help: + @echo "Available targets:" + @echo " all : Run lint, test, and build" + @echo " build : Build the application" + @echo " clean : Clean build artifacts" + @echo " test : Run tests" + @echo " coverage : Run tests with coverage report" + @echo " fmt : Run gofmt" + @echo " vet : Run go vet" + @echo " help : Show this help message" diff --git a/README.md b/README.md index b2fb23c..6be3ece 100644 --- a/README.md +++ b/README.md @@ -1,81 +1,88 @@ # Open MCP Auth Proxy -The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider. +A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) -![image](https://github.com/user-attachments/assets/fc728670-2fdb-4a63-bcc4-b9b6a6c8b4ba) +[![🚀 Release](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml/badge.svg)](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml) +[![💬 Stackoverflow](https://img.shields.io/badge/Ask%20for%20help%20on-Stackoverflow-orange)](https://stackoverflow.com/questions/tagged/wso2is) +[![💬 Discord](https://img.shields.io/badge/Join%20us%20on-Discord-%23e01563.svg)](https://discord.gg/wso2) +[![🐦 Twitter](https://img.shields.io/twitter/follow/wso2.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=wso2) +[![📝 License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/wso2/product-is/blob/master/LICENSE) -## **Setup and Installation** +![Architecture Diagram](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) -### **Prerequisites** +## What it Does -* Go 1.20 or higher -* A running MCP server (SSE transport supported) -* An MCP client that supports MCP authorization +Open MCP Auth Proxy sits between MCP clients and your MCP server to: -### **Installation** +- Intercept incoming requests +- Validate authorization tokens +- Offload authentication and authorization to OAuth-compliant Identity Providers +- Support the MCP authorization protocol -```bash -git clone https://github.com/wso2/open-mcp-auth-proxy -cd open-mcp-auth-proxy +## Quick Start -go get github.com/golang-jwt/jwt/v4 -go get gopkg.in/yaml.v2 +### Prerequisites -go build -o openmcpauthproxy ./cmd/proxy -``` +* Go 1.20 or higher +* A running MCP server -## Using Open MCP Auth Proxy +> If you don't have an MCP server, you can use the included example: +> +> 1. Navigate to the `resources` directory +> 2. Set up a Python environment: +> +> ```bash +> python3 -m venv .venv +> source .venv/bin/activate +> pip3 install -r requirements.txt +> ``` +> +> 3. Start the example server: +> +> ```bash +> python3 echo_server.py +> ``` -### Quick Start +* An MCP client that supports MCP authorization -Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo. +### Basic Usage -If you don’t have an MCP server, follow the instructions given here to start your own MCP server for testing purposes. -1. Download [sample MCP server](resources/echo_server.py) -2. Run the server with -```bash -python3 echo_server.py -``` +1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest). -#### Configure the Auth Proxy - -Update the following parameters in `config.yaml`. - -```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_address: ":8080" # Address where the proxy will listen -``` - -#### Start the Auth Proxy +2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox): ```bash ./openmcpauthproxy --demo ``` -The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/). +> The repository comes with a default `config.yaml` file that contains the basic configuration: +> +> ```yaml +> listen_port: 8080 +> base_url: "http://localhost:8000" # Your MCP server URL +> paths: +> sse: "/sse" +> messages: "/messages/" +> ``` -#### Connect Using an MCP Client +3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) -You can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) to test the connection and try out the complete authorization flow. +## Connect an Identity Provider -### Use with Asgardeo +### Asgardeo -Enable authorization for the MCP server through your own Asgardeo organization +To enable authorization through your Asgardeo organization: -1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo -2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that, - 1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) - 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope. - ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) - 2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy - -#### Configure the Auth Proxy - -Create a configuration file config.yaml with the following parameters: +1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo +2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) + 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope + ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) + +3. Update `config.yaml` with the following parameters. ```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_address: ":8080" # Address where the proxy will listen +base_url: "http://localhost:8000" # URL of your MCP server +listen_port: 8080 # Address where the proxy will listen asgardeo: org_name: "" # Your Asgardeo org name @@ -83,53 +90,137 @@ asgardeo: client_secret: "" # Client secret of the M2M app ``` -#### Start the Auth Proxy +4. Start the proxy with Asgardeo integration: ```bash ./openmcpauthproxy --asgardeo ``` -### Use with Auth0 +### Other OAuth Providers -Enable authorization for the MCP server through your Auth0 organization +- [Auth0](docs/integrations/Auth0.md) +- [Keycloak](docs/integrations/keycloak.md) -**TODO**: Add instructions +# Advanced Configuration -[Enable dynamic application registration](https://auth0.com/docs/get-started/applications/dynamic-client-registration#enable-dynamic-client-registration) in your Auth0 organization +### Transport Modes -#### Configure the Auth Proxy +The proxy supports two transport modes: -Create a configuration file config.yaml with the following parameters: +- **SSE Mode (Default)**: For Server-Sent Events transport +- **stdio Mode**: For MCP servers that use stdio transport -```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_address: ":8080" # Address where the proxy will listen -``` +When using stdio mode, the proxy: +- Starts an MCP server as a subprocess using the command specified in the configuration +- Communicates with the subprocess through standard input/output (stdio) +- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first -**TODO**: Update the configs for Auth0. - -#### Start the Auth Proxy +To use stdio mode: ```bash -./openmcpauthproxy --auth0 +./openmcpauthproxy --demo --stdio ``` -### Use with a standard OAuth Server +#### Example: Running an MCP Server as a Subprocess -Enable authorization for the MCP server with a compliant OAuth server - -#### Configuration - -Create a configuration file config.yaml with the following parameters: +1. Configure stdio mode in your `config.yaml`: ```yaml -mcp_server_base_url: "http://localhost:8000" # URL of your MCP server -listen_address: ":8080" # Address where the proxy will listen -``` -**TODO**: Update the configs for a standard OAuth Server. +listen_port: 8080 +base_url: "http://localhost:8000" -#### Start the Auth Proxy +stdio: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server + env: # Environment variables (optional) + - "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT" + +# CORS configuration +cors: + allowed_origins: + - "http://localhost:5173" # Origin of your client application + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true + +# Demo configuration for Asgardeo +demo: + org_name: "openmcpauthdemo" + client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" + client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" +``` + +2. Run the proxy with stdio mode: ```bash -./openmcpauthproxy +./openmcpauthproxy --demo +``` + +The proxy will: +- Start the MCP server as a subprocess using the specified command +- Handle all authorization requirements +- Forward messages between clients and the server + +### Complete Configuration Reference + +```yaml +# Common configuration +listen_port: 8080 +base_url: "http://localhost:8000" +port: 8000 + +# Path configuration +paths: + sse: "/sse" + messages: "/messages/" + +# Transport mode +transport_mode: "sse" # Options: "sse" or "stdio" + +# stdio-specific configuration (used only in stdio mode) +stdio: + enabled: true + user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed) + work_dir: "" # Optional working directory for the subprocess + +# CORS configuration +cors: + allowed_origins: + - "http://localhost:5173" + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true + +# Demo configuration for Asgardeo +demo: + org_name: "openmcpauthdemo" + client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" + client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +# Asgardeo configuration (used with --asgardeo flag) +asgardeo: + org_name: "" + client_id: "" + client_secret: "" +``` + +### Build from source + +```bash +git clone https://github.com/wso2/open-mcp-auth-proxy +cd open-mcp-auth-proxy +go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2 +go build -o openmcpauthproxy ./cmd/proxy ``` diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 9a4b472..c43dd7d 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -3,71 +3,132 @@ package main import ( "flag" "fmt" - "log" "net/http" "os" "os/signal" + "syscall" "time" "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/constants" + logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" + "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) func main() { demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") + asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).") + debugMode := flag.Bool("debug", false, "Enable debug logging") + stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE") flag.Parse() + logger.SetDebug(*debugMode) + // 1. Load config cfg, err := config.LoadConfig("config.yaml") if err != nil { - log.Fatalf("Error loading config: %v", err) + logger.Error("Error loading config: %v", err) + os.Exit(1) } - // 2. Create the chosen provider + // Override transport mode if stdio flag is set + if *stdioMode { + cfg.TransportMode = config.StdioTransport + // Ensure stdio is enabled + cfg.Stdio.Enabled = true + // Re-validate config + if err := cfg.Validate(); err != nil { + logger.Error("Configuration error: %v", err) + os.Exit(1) + } + } + + logger.Info("Using transport mode: %s", cfg.TransportMode) + logger.Info("Using MCP server base URL: %s", cfg.BaseURL) + logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages) + + // 2. Start subprocess if configured and in stdio mode + var procManager *subprocess.Manager + if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled { + // Ensure all required dependencies are available + if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil { + logger.Warn("%v", err) + logger.Warn("Subprocess may fail to start due to missing dependencies") + } + + procManager = subprocess.NewManager() + if err := procManager.Start(cfg); err != nil { + logger.Warn("Failed to start subprocess: %v", err) + } + } else if cfg.TransportMode == config.SSETransport { + logger.Info("Using SSE transport mode, not starting subprocess") + } + + // 3. Create the chosen provider var provider authz.Provider if *demoMode { - cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" - cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" + cfg.Mode = "demo" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks" + provider = authz.NewAsgardeoProvider(cfg) + } else if *asgardeoMode { + cfg.Mode = "asgardeo" + cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2" + cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks" provider = authz.NewAsgardeoProvider(cfg) - fmt.Println("Using Asgardeo provider (demo).") } else { - log.Fatalf("Not supported yet.") + cfg.Mode = "default" + cfg.JWKSURL = cfg.Default.JWKSURL + cfg.AuthServerBaseURL = cfg.Default.BaseURL + provider = authz.NewDefaultProvider(cfg) } - // 3. (Optional) Fetch JWKS if you want local JWT validation + // 4. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { - log.Fatalf("Failed to fetch JWKS: %v", err) + logger.Error("Failed to fetch JWKS: %v", err) + os.Exit(1) } - // 4. Build the main router + // 5. Build the main router mux := proxy.NewRouter(cfg, provider) - // 5. Start the server + listen_address := fmt.Sprintf("0.0.0.0:%d", cfg.ListenPort) + + // 6. Start the server srv := &http.Server{ - Addr: cfg.ListenAddress, + Addr: listen_address, Handler: mux, } go func() { - log.Printf("Server listening on %s", cfg.ListenAddress) + logger.Info("Server listening on %s", listen_address) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server error: %v", err) + logger.Error("Server error: %v", err) + os.Exit(1) } }() - // 6. Graceful shutdown on Ctrl+C + // 7. Wait for shutdown signal stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt) + signal.Notify(stop, os.Interrupt, syscall.SIGTERM) <-stop - log.Println("Shutting down...") + logger.Info("Shutting down...") + // 8. First terminate subprocess if running + if procManager != nil && procManager.IsRunning() { + procManager.Shutdown() + } + + // 9. Then shutdown the server + logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { - log.Printf("Shutdown error: %v", err) + logger.Error("HTTP server shutdown error: %v", err) } - log.Println("Stopped.") + logger.Info("Stopped.") } diff --git a/config.yaml b/config.yaml index 9725f14..ef70fbb 100644 --- a/config.yaml +++ b/config.yaml @@ -1,18 +1,65 @@ # config.yaml -auth_server_base_url: "" -mcp_server_base_url: "http://localhost:8000" -listen_address: ":8080" -jwks_url: "" +# Common configuration for all transport modes +listen_port: 8080 +base_url: "http://localhost:8000" # Base URL for the MCP server +port: 8000 # Port for the MCP server timeout_seconds: 10 -mcp_paths: - - /messages/ - - /sse +# Transport mode configuration +transport_mode: "stdio" # Options: "sse" or "stdio" + +# stdio-specific configuration (used only when transport_mode is "stdio") +stdio: + enabled: true + user_command: uvx mcp-server-time --local-timezone=Europe/Zurich + #user_command: "npx -y @modelcontextprotocol/server-github" + work_dir: "" # Working directory (optional) + # env: # Environment variables (optional) + # - "NODE_ENV=development" + +# CORS settings +cors: + allowed_origins: + - "http://localhost:6274" # Origin of your frontend/client app + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + - "mcp-protocol-version" + allow_credentials: true + +# Keycloak endpoint path mappings path_mapping: + sse: "/sse" # SSE endpoint path + messages: "/messages/" # Messages endpoint path + /token: /realms/master/protocol/openid-connect/token + /register: /realms/master/clients-registrations/openid-connect -demo: - org_name: "openmcpauthdemo" - client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" - client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" +# Keycloak configuration block +default: + base_url: "https://iam.phoenix-systems.ch" + jwks_url: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs" + path: + /.well-known/oauth-authorization-server: + response: + issuer: "https://iam.phoenix-systems.ch/realms/kvant" + jwks_uri: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs" + authorization_endpoint: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/auth" + response_types_supported: + - "code" + grant_types_supported: + - "authorization_code" + - "refresh_token" + code_challenge_methods_supported: + - "S256" + - "plain" + /token: + addBodyParams: + - name: "audience" + value: "mcp_proxy" \ No newline at end of file diff --git a/docs/integrations/Auth0.md b/docs/integrations/Auth0.md new file mode 100644 index 0000000..9195659 --- /dev/null +++ b/docs/integrations/Auth0.md @@ -0,0 +1,93 @@ +## Integrating with Auth0 + +This guide will help you configure Open MCP Auth Proxy to use Auth0 as your identity provider. + +### Prerequisites + +- An Auth0 organization (sign up [here](https://auth0.com) if you don't have one) +- Open MCP Auth Proxy installed + +### Setting Up Auth0 +1. [Enable Dynamic Client Registration](https://auth0.com/docs/get-started/applications/dynamic-client-registration) + - Go to your Auth0 dashboard + - Navigate to Settings > Advanced + - Enable "OIDC Dynamic Application Registration" +2. In order to setup connections in dynamically created clients [promote Connections to Domain Level](https://auth0.com/docs/authenticate/identity-providers/promote-connections-to-domain-level) +3. Create an API in Auth0: + - Go to your Auth0 dashboard + - Navigate to Applications > APIs + - Click on "Create API" + - Set a Name (e.g., "MCP API") + - Set an Identifier (e.g., "mcp_proxy") + - Keep the default signing algorithm (RS256) + - Click "Create" + +### Configuring the Open MCP Auth Proxy + +Update your `config.yaml` with Auth0 settings: + +```yaml +# Basic proxy configuration +listen_port: 8080 +base_url: "http://localhost:8000" +port: 8000 + +# Path configuration +paths: + sse: "/sse" + messages: "/messages/" + +# Transport mode +transport_mode: "sse" + +# CORS configuration +cors: + allowed_origins: + - "http://localhost:5173" # Your client application origin + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true + +# Path mappings for Auth0 endpoints +path_mapping: + /token: /oauth/token + /register: /oidc/register + +# Auth0 configuration +default: + base_url: "https://YOUR_AUTH0_DOMAIN" # e.g., https://dev-123456.us.auth0.com + jwks_url: "https://YOUR_AUTH0_DOMAIN/.well-known/jwks.json" + path: + /.well-known/oauth-authorization-server: + response: + issuer: "https://YOUR_AUTH0_DOMAIN/" + jwks_uri: "https://YOUR_AUTH0_DOMAIN/.well-known/jwks.json" + authorization_endpoint: "https://YOUR_AUTH0_DOMAIN/authorize?audience=mcp_proxy" # Only if you created an API with this identifier + response_types_supported: + - "code" + grant_types_supported: + - "authorization_code" + - "refresh_token" + code_challenge_methods_supported: + - "S256" + - "plain" + /token: + addBodyParams: + - name: "audience" + value: "mcp_proxy" # Only if you created an API with this identifier +``` + +Replace YOUR_AUTH0_DOMAIN with your Auth0 domain (e.g., dev-abc123.us.auth0.com). + +## Starting the Proxy with Auth0 Integration +Start the proxy in default mode (which will use Auth0 based on your configuration): + +```bash +./openmcpauthproxy +``` diff --git a/docs/integrations/keycloak.md b/docs/integrations/keycloak.md new file mode 100644 index 0000000..5e338cc --- /dev/null +++ b/docs/integrations/keycloak.md @@ -0,0 +1,92 @@ +## Integrating Open MCP Auth Proxy with Keycloak + +This guide walks you through configuring the Open MCP Auth Proxy to authenticate using Keycloak as the identity provider. + +--- + +### Prerequisites + +Before you begin, ensure you have the following: + +- A running Keycloak instance +- Open MCP Auth Proxy installed and accessible + +--- + +### Step 1: Configure Keycloak for Client Registration + +Set up dynamic client registration in your Keycloak realm by following the [Keycloak client registration guide](https://www.keycloak.org/securing-apps/client-registration). + +--- + +### Step 2: Configure Open MCP Auth Proxy + +Update the `config.yaml` file in your Open MCP Auth Proxy setup using your Keycloak realm's [OIDC settings](https://www.keycloak.org/securing-apps/oidc-layers). Below is an example configuration: + +```yaml +# Proxy server configuration +listen_port: 8081 # Port for the auth proxy +base_url: "http://localhost:8000" # Base URL of the MCP server +port: 8000 # MCP server port + +# Define path mappings +paths: + sse: "/sse" + messages: "/messages/" + +# Set the transport mode +transport_mode: "sse" + +# CORS settings +cors: + allowed_origins: + - "http://localhost:5173" # Origin of your frontend/client app + allowed_methods: + - "GET" + - "POST" + - "PUT" + - "DELETE" + allowed_headers: + - "Authorization" + - "Content-Type" + - "mcp-protocol-version" + allow_credentials: true + +# Keycloak endpoint path mappings +path_mapping: + /token: /realms/master/protocol/openid-connect/token + /register: /realms/master/clients-registrations/openid-connect + +# Keycloak configuration block +default: + base_url: "http://localhost:8080" + jwks_url: "http://localhost:8080/realms/master/protocol/openid-connect/certs" + path: + /.well-known/oauth-authorization-server: + response: + issuer: "http://localhost:8080/realms/master" + jwks_uri: "http://localhost:8080/realms/master/protocol/openid-connect/certs" + authorization_endpoint: "http://localhost:8080/realms/master/protocol/openid-connect/auth" + response_types_supported: + - "code" + grant_types_supported: + - "authorization_code" + - "refresh_token" + code_challenge_methods_supported: + - "S256" + - "plain" + /token: + addBodyParams: + - name: "audience" + value: "mcp_proxy" +``` + +### Step 3: Start the Auth Proxy + +Launch the proxy with the updated Keycloak configuration: + +```bash +./openmcpauthproxy +``` + +Once running, the proxy will handle authentication requests through your configured Keycloak realm. diff --git a/go.mod b/go.mod index 2d26216..0bceb4f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/wso2/open-mcp-auth-proxy -go 1.22.3 +go 1.21 require ( github.com/golang-jwt/jwt/v4 v4.5.2 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9d27ad1 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index 7408f79..a3c812c 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -7,13 +7,13 @@ import ( "encoding/json" "fmt" "io" - "log" "math/rand" "net/http" "strings" "time" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type asgardeoProvider struct { @@ -31,6 +31,7 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("X-Accel-Buffering", "no") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -70,8 +71,9 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Accel-Buffering", "no") if err := json.NewEncoder(w).Encode(response); err != nil { - log.Printf("[asgardeoProvider] Error encoding well-known: %v", err) + logger.Error("Error encoding well-known: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -83,6 +85,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + w.Header().Set("X-Accel-Buffering", "no") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -95,7 +98,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { var regReq RegisterRequest if err := json.NewDecoder(r.Body).Decode(®Req); err != nil { - log.Printf("ERROR: reading register request: %v", err) + logger.Error("Reading register request: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -109,7 +112,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { regReq.ClientSecret = randomString(16) if err := p.createAsgardeoApplication(regReq); err != nil { - log.Printf("WARN: Asgardeo application creation failed: %v", err) + logger.Warn("Asgardeo application creation failed: %v", err) // Optionally http.Error(...) if you want to fail // or continue to return partial data. } @@ -124,9 +127,10 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Accel-Buffering", "no") w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("ERROR: encoding /register response: %v", err) + logger.Error("Encoding /register response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -186,7 +190,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) } - log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID) + logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) return nil } @@ -202,8 +206,11 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Sensitive data - should not be logged at INFO level auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + + logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID) tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -234,6 +241,10 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { return "", fmt.Errorf("failed to parse token JSON: %w", err) } + // Don't log the actual token at info level, only at debug level + logger.Debug("Received access token: %s", tokenResp.AccessToken) + logger.Info("Successfully obtained admin token from Asgardeo") + return tokenResp.AccessToken, nil } diff --git a/internal/authz/default.go b/internal/authz/default.go new file mode 100644 index 0000000..929f586 --- /dev/null +++ b/internal/authz/default.go @@ -0,0 +1,96 @@ +package authz + +import ( + "encoding/json" + "net/http" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" +) + +type defaultProvider struct { + cfg *config.Config +} + +// NewDefaultProvider initializes a Provider for Asgardeo (demo mode). +func NewDefaultProvider(cfg *config.Config) Provider { + return &defaultProvider{cfg: cfg} +} + +func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check if we have a custom response configuration + if p.cfg.Default.Path != nil { + pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && pathConfig.Response != nil { + // Use configured response values + responseConfig := pathConfig.Response + + // Get current host for proxy endpoints + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { + scheme = forwardedProto + } + host := r.Host + if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" { + host = forwardedHost + } + baseURL := scheme + "://" + host + + authorizationEndpoint := responseConfig.AuthorizationEndpoint + if authorizationEndpoint == "" { + authorizationEndpoint = baseURL + "/authorize" + } + tokenEndpoint := responseConfig.TokenEndpoint + if tokenEndpoint == "" { + tokenEndpoint = baseURL + "/token" + } + registraionEndpoint := responseConfig.RegistrationEndpoint + if registraionEndpoint == "" { + registraionEndpoint = baseURL + "/register" + } + + // Build response from config + response := map[string]interface{}{ + "issuer": responseConfig.Issuer, + "authorization_endpoint": authorizationEndpoint, + "token_endpoint": tokenEndpoint, + "jwks_uri": responseConfig.JwksURI, + "response_types_supported": responseConfig.ResponseTypesSupported, + "grant_types_supported": responseConfig.GrantTypesSupported, + "token_endpoint_auth_methods_supported": []string{"client_secret_basic"}, + "registration_endpoint": registraionEndpoint, + "code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Error("Error encoding well-known response: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + } + return + } + } + } +} + +func (p *defaultProvider) RegisterHandler() http.HandlerFunc { + return nil +} diff --git a/internal/authz/default_test.go b/internal/authz/default_test.go new file mode 100644 index 0000000..f40030f --- /dev/null +++ b/internal/authz/default_test.go @@ -0,0 +1,125 @@ +package authz + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +func TestNewDefaultProvider(t *testing.T) { + cfg := &config.Config{} + provider := NewDefaultProvider(cfg) + + if provider == nil { + t.Fatal("Expected non-nil provider") + } + + // Ensure it implements the Provider interface + var _ Provider = provider +} + +func TestDefaultProviderWellKnownHandler(t *testing.T) { + // Create a config with a custom well-known response + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/.well-known/oauth-authorization-server": { + Response: &config.ResponseConfig{ + Issuer: "https://test-issuer.com", + JwksURI: "https://test-issuer.com/jwks", + ResponseTypesSupported: []string{"code"}, + GrantTypesSupported: []string{"authorization_code"}, + CodeChallengeMethodsSupported: []string{"S256"}, + }, + }, + }, + }, + } + + provider := NewDefaultProvider(cfg) + handler := provider.WellKnownHandler() + + // Create a test request + req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil) + req.Host = "test-host.com" + req.Header.Set("X-Forwarded-Proto", "https") + + // Create a response recorder + w := httptest.NewRecorder() + + // Call the handler + handler(w, req) + + // Check response status + if w.Code != http.StatusOK { + t.Errorf("Expected status OK, got %v", w.Code) + } + + // Verify content type + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Expected Content-Type: application/json, got %s", contentType) + } + + // Decode and check the response body + var response map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { + t.Fatalf("Failed to decode response JSON: %v", err) + } + + // Check expected values + if response["issuer"] != "https://test-issuer.com" { + t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"]) + } + if response["jwks_uri"] != "https://test-issuer.com/jwks" { + t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"]) + } + if response["authorization_endpoint"] != "https://test-host.com/authorize" { + t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"]) + } +} + +func TestDefaultProviderHandleOPTIONS(t *testing.T) { + provider := NewDefaultProvider(&config.Config{}) + handler := provider.WellKnownHandler() + + // Create OPTIONS request + req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil) + w := httptest.NewRecorder() + + // Call the handler + handler(w, req) + + // Check response + if w.Code != http.StatusNoContent { + t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code) + } + + // Check CORS headers + if w.Header().Get("Access-Control-Allow-Origin") != "*" { + t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin")) + } + if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" { + t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods")) + } +} + +func TestDefaultProviderInvalidMethod(t *testing.T) { + provider := NewDefaultProvider(&config.Config{}) + handler := provider.WellKnownHandler() + + // Create POST request (which should be rejected) + req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil) + w := httptest.NewRecorder() + + // Call the handler + handler(w, req) + + // Check response + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 3a3b231..fc6743c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,29 +1,155 @@ package config import ( + "fmt" "os" "gopkg.in/yaml.v2" ) -// AsgardeoConfig groups all Asgardeo-specific fields +// Transport mode for MCP server +type TransportMode string + +const ( + SSETransport TransportMode = "sse" + StdioTransport TransportMode = "stdio" +) + +// Common path configuration for all transport modes +type PathsConfig struct { + SSE string `yaml:"sse"` + Messages string `yaml:"messages"` +} + +// StdioConfig contains stdio-specific configuration +type StdioConfig struct { + Enabled bool `yaml:"enabled"` + UserCommand string `yaml:"user_command"` // The command provided by the user + WorkDir string `yaml:"work_dir"` // Working directory (optional) + Args []string `yaml:"args,omitempty"` // Additional arguments + Env []string `yaml:"env,omitempty"` // Environment variables +} + type DemoConfig struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` OrgName string `yaml:"org_name"` } +type AsgardeoConfig struct { + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + OrgName string `yaml:"org_name"` +} + +type CORSConfig struct { + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + AllowCredentials bool `yaml:"allow_credentials"` +} + +type ParamConfig struct { + Name string `yaml:"name"` + Value string `yaml:"value"` +} + +type ResponseConfig struct { + Issuer string `yaml:"issuer,omitempty"` + JwksURI string `yaml:"jwks_uri,omitempty"` + AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"` + TokenEndpoint string `yaml:"token_endpoint,omitempty"` + RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"` + ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"` + GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"` + CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"` +} + +type PathConfig struct { + // For well-known endpoint + Response *ResponseConfig `yaml:"response,omitempty"` + + // For authorization endpoint + AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"` + + // For token and register endpoints + AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"` +} + +type DefaultConfig struct { + BaseURL string `yaml:"base_url,omitempty"` + Path map[string]PathConfig `yaml:"path,omitempty"` + JWKSURL string `yaml:"jwks_url,omitempty"` +} + type Config struct { - AuthServerBaseURL string `yaml:"auth_server_base_url"` - MCPServerBaseURL string `yaml:"mcp_server_base_url"` - ListenAddress string `yaml:"listen_address"` - JWKSURL string `yaml:"jwks_url"` - TimeoutSeconds int `yaml:"timeout_seconds"` - MCPPaths []string `yaml:"mcp_paths"` - PathMapping map[string]string `yaml:"path_mapping"` + AuthServerBaseURL string + ListenPort int `yaml:"listen_port"` + BaseURL string `yaml:"base_url"` + Port int `yaml:"port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` + TransportMode TransportMode `yaml:"transport_mode"` + Paths PathsConfig `yaml:"paths"` + Stdio StdioConfig `yaml:"stdio"` // Nested config for Asgardeo - Demo DemoConfig `yaml:"demo"` + Demo DemoConfig `yaml:"demo"` + Asgardeo AsgardeoConfig `yaml:"asgardeo"` + Default DefaultConfig `yaml:"default"` +} + +// Validate checks if the config is valid based on transport mode +func (c *Config) Validate() error { + // Validate based on transport mode + if c.TransportMode == StdioTransport { + if !c.Stdio.Enabled { + return fmt.Errorf("stdio.enabled must be true in stdio transport mode") + } + if c.Stdio.UserCommand == "" { + return fmt.Errorf("stdio.user_command is required in stdio transport mode") + } + } + + // Validate paths + if c.Paths.SSE == "" { + c.Paths.SSE = "/sse" // Default value + } + if c.Paths.Messages == "" { + c.Paths.Messages = "/messages" // Default value + } + + // Validate base URL + if c.BaseURL == "" { + if c.Port > 0 { + c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port) + } else { + c.BaseURL = "http://localhost:8000" // Default value + } + } + + return nil +} + +// GetMCPPaths returns the list of paths that should be proxied to the MCP server +func (c *Config) GetMCPPaths() []string { + return []string{c.Paths.SSE, c.Paths.Messages} +} + +// BuildExecCommand constructs the full command string for execution in stdio mode +func (c *Config) BuildExecCommand() string { + if c.Stdio.UserCommand == "" { + return "" + } + + // Construct the full command + return fmt.Sprintf( + `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, + c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, + ) } // LoadConfig reads a YAML config file into Config struct. @@ -39,8 +165,26 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } + + // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } + + // Set default transport mode if not specified + if cfg.TransportMode == "" { + cfg.TransportMode = SSETransport // Default to SSE + } + + // Set default port if not specified + if cfg.Port == 0 { + cfg.Port = 8000 // default + } + + // Validate the configuration + if err := cfg.Validate(); err != nil { + return nil, err + } + return &cfg, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..20c0893 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,196 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadConfig(t *testing.T) { + // Create a temporary config file + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "test_config.yaml") + + // Basic valid config + validConfig := ` +listen_port: 8080 +base_url: "http://localhost:8000" +transport_mode: "sse" +paths: + sse: "/sse" + messages: "/messages" +cors: + allowed_origins: + - "http://localhost:5173" + allowed_methods: + - "GET" + - "POST" + allowed_headers: + - "Authorization" + - "Content-Type" + allow_credentials: true +` + err := os.WriteFile(configPath, []byte(validConfig), 0644) + if err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + + // Test loading the valid config + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("Failed to load valid config: %v", err) + } + + // Verify expected values from the config + if cfg.ListenPort != 8080 { + t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort) + } + if cfg.BaseURL != "http://localhost:8000" { + t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL) + } + if cfg.TransportMode != SSETransport { + t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode) + } + if cfg.Paths.SSE != "/sse" { + t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE) + } + if cfg.Paths.Messages != "/messages" { + t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages) + } + + // Test default values + if cfg.TimeoutSeconds != 15 { + t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds) + } + if cfg.Port != 8000 { + t.Errorf("Expected default Port=8000, got %d", cfg.Port) + } +} + +func TestValidate(t *testing.T) { + tests := []struct { + name string + config Config + expectError bool + }{ + { + name: "Valid SSE config", + config: Config{ + TransportMode: SSETransport, + Paths: PathsConfig{ + SSE: "/sse", + Messages: "/messages", + }, + BaseURL: "http://localhost:8000", + }, + expectError: false, + }, + { + name: "Valid stdio config", + config: Config{ + TransportMode: StdioTransport, + Stdio: StdioConfig{ + Enabled: true, + UserCommand: "some-command", + }, + }, + expectError: false, + }, + { + name: "Invalid stdio config - not enabled", + config: Config{ + TransportMode: StdioTransport, + Stdio: StdioConfig{ + Enabled: false, + UserCommand: "some-command", + }, + }, + expectError: true, + }, + { + name: "Invalid stdio config - no command", + config: Config{ + TransportMode: StdioTransport, + Stdio: StdioConfig{ + Enabled: true, + UserCommand: "", + }, + }, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.config.Validate() + if tc.expectError && err == nil { + t.Errorf("Expected validation error but got none") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no validation error but got: %v", err) + } + }) + } +} + +func TestGetMCPPaths(t *testing.T) { + cfg := Config{ + Paths: PathsConfig{ + SSE: "/custom-sse", + Messages: "/custom-messages", + }, + } + + paths := cfg.GetMCPPaths() + if len(paths) != 2 { + t.Errorf("Expected 2 MCP paths, got %d", len(paths)) + } + if paths[0] != "/custom-sse" { + t.Errorf("Expected first path=/custom-sse, got %s", paths[0]) + } + if paths[1] != "/custom-messages" { + t.Errorf("Expected second path=/custom-messages, got %s", paths[1]) + } +} + +func TestBuildExecCommand(t *testing.T) { + tests := []struct { + name string + config Config + expectedResult string + }{ + { + name: "Valid command", + config: Config{ + Stdio: StdioConfig{ + UserCommand: "test-command", + }, + Port: 8080, + BaseURL: "http://example.com", + Paths: PathsConfig{ + SSE: "/sse-path", + Messages: "/msgs", + }, + }, + expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`, + }, + { + name: "Empty command", + config: Config{ + Stdio: StdioConfig{ + UserCommand: "", + }, + }, + expectedResult: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := tc.config.BuildExecCommand() + if result != tc.expectedResult { + t.Errorf("Expected command=%s, got %s", tc.expectedResult, result) + } + }) + } +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..1e5808e --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,7 @@ +package constants + +// Package constant provides constants for the MCP Auth Proxy + +const ( + ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/" +) diff --git a/internal/logging/logger.go b/internal/logging/logger.go new file mode 100644 index 0000000..57bec27 --- /dev/null +++ b/internal/logging/logger.go @@ -0,0 +1,34 @@ +package logger + +import ( + "log" +) + +var isDebug = false + +// SetDebug enables or disables debug logging +func SetDebug(debug bool) { + isDebug = debug +} + +// Debug logs a debug-level message +func Debug(format string, v ...interface{}) { + if isDebug { + log.Printf("DEBUG: "+format, v...) + } +} + +// Info logs an info-level message +func Info(format string, v ...interface{}) { + log.Printf("INFO: "+format, v...) +} + +// Warn logs a warning-level message +func Warn(format string, v ...interface{}) { + log.Printf("WARN: "+format, v...) +} + +// Error logs an error-level message +func Error(format string, v ...interface{}) { + log.Printf("ERROR: "+format, v...) +} diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go new file mode 100644 index 0000000..6662b2c --- /dev/null +++ b/internal/proxy/modifier.go @@ -0,0 +1,204 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" +) + +// RequestModifier modifies requests before they are proxied +type RequestModifier interface { + ModifyRequest(req *http.Request) (*http.Request, error) +} + +// AuthorizationModifier adds parameters to authorization requests +type AuthorizationModifier struct { + Config *config.Config +} + +// TokenModifier adds parameters to token requests +type TokenModifier struct { + Config *config.Config +} + +type RegisterModifier struct { + Config *config.Config +} + +// ModifyRequest adds configured parameters to authorization requests +func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/authorize"] + if !exists || len(pathConfig.AddQueryParams) == 0 { + return req, nil + } + // Get current query parameters + query := req.URL.Query() + + // Add parameters from config + for _, param := range pathConfig.AddQueryParams { + query.Set(param.Name, param.Value) + } + + // Update the request URL + req.URL.RawQuery = query.Encode() + + return req, nil +} + +// ModifyRequest adds configured parameters to token requests +func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/token"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} + +func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) { + // Only modify POST requests + if req.Method != http.MethodPost { + return req, nil + } + + // Check if we have parameters to add + if m.Config.Default.Path == nil { + return req, nil + } + + pathConfig, exists := m.Config.Default.Path["/register"] + if !exists || len(pathConfig.AddBodyParams) == 0 { + return req, nil + } + + contentType := req.Header.Get("Content-Type") + + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + // Parse form data + if err := req.ParseForm(); err != nil { + logger.Error("Failed to parse form data: %v", err) + return nil, err + } + + // Clone form data + formData := req.PostForm + + // Add configured parameters + for _, param := range pathConfig.AddBodyParams { + formData.Set(param.Name, param.Value) + } + + // Create new request body with modified form + formEncoded := formData.Encode() + req.Body = io.NopCloser(strings.NewReader(formEncoded)) + req.ContentLength = int64(len(formEncoded)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded))) + + } else if strings.Contains(contentType, "application/json") { + // Read body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + logger.Error("Failed to read request body: %v", err) + return nil, err + } + + // Parse JSON + var jsonData map[string]interface{} + if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { + logger.Error("Failed to parse JSON body: %v", err) + return nil, err + } + + // Add parameters + for _, param := range pathConfig.AddBodyParams { + jsonData[param.Name] = param.Value + } + + // Marshal back to JSON + modifiedBody, err := json.Marshal(jsonData) + if err != nil { + logger.Error("Failed to marshal modified JSON: %v", err) + return nil, err + } + + // Update request + req.Body = io.NopCloser(bytes.NewReader(modifiedBody)) + req.ContentLength = int64(len(modifiedBody)) + req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody))) + } + + return req, nil +} diff --git a/internal/proxy/modifier_test.go b/internal/proxy/modifier_test.go new file mode 100644 index 0000000..3d2fd44 --- /dev/null +++ b/internal/proxy/modifier_test.go @@ -0,0 +1,147 @@ +package proxy + +import ( + "net/http" + "net/url" + "strings" + "testing" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" +) + +func TestAuthorizationModifier(t *testing.T) { + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/authorize": { + AddQueryParams: []config.ParamConfig{ + {Name: "client_id", Value: "test-client-id"}, + {Name: "scope", Value: "openid"}, + }, + }, + }, + }, + } + + modifier := &AuthorizationModifier{Config: cfg} + + // Create a test request + req, err := http.NewRequest("GET", "/authorize?response_type=code", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Modify the request + modifiedReq, err := modifier.ModifyRequest(req) + if err != nil { + t.Fatalf("ModifyRequest failed: %v", err) + } + + // Check that the query parameters were added + query := modifiedReq.URL.Query() + if query.Get("client_id") != "test-client-id" { + t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id")) + } + if query.Get("scope") != "openid" { + t.Errorf("Expected scope=openid, got %s", query.Get("scope")) + } + if query.Get("response_type") != "code" { + t.Errorf("Expected response_type=code, got %s", query.Get("response_type")) + } +} + +func TestTokenModifier(t *testing.T) { + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/token": { + AddBodyParams: []config.ParamConfig{ + {Name: "audience", Value: "test-audience"}, + }, + }, + }, + }, + } + + modifier := &TokenModifier{Config: cfg} + + // Create a test request with form data + form := url.Values{} + + req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Modify the request + modifiedReq, err := modifier.ModifyRequest(req) + if err != nil { + t.Fatalf("ModifyRequest failed: %v", err) + } + + body := make([]byte, 1024) + n, err := modifiedReq.Body.Read(body) + if err != nil && err.Error() != "EOF" { + t.Fatalf("Failed to read body: %v", err) + } + bodyStr := string(body[:n]) + + // Parse the form data from the modified request + if err := modifiedReq.ParseForm(); err != nil { + t.Fatalf("Failed to parse form data: %v", err) + } + + // Check that the body parameters were added + if !strings.Contains(bodyStr, "audience") { + t.Errorf("Expected body to contain audience, got %s", bodyStr) + } +} + +func TestRegisterModifier(t *testing.T) { + cfg := &config.Config{ + Default: config.DefaultConfig{ + Path: map[string]config.PathConfig{ + "/register": { + AddBodyParams: []config.ParamConfig{ + {Name: "client_name", Value: "test-client"}, + }, + }, + }, + }, + } + + modifier := &RegisterModifier{Config: cfg} + + // Create a test request with JSON data + jsonBody := `{"redirect_uris":["https://example.com/callback"]}` + req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + // Modify the request + modifiedReq, err := modifier.ModifyRequest(req) + if err != nil { + t.Fatalf("ModifyRequest failed: %v", err) + } + + // Read the body and check that it still contains the original data + // This test would need to be enhanced with a proper JSON parsing to verify + // the added parameters + body := make([]byte, 1024) + n, err := modifiedReq.Body.Read(body) + if err != nil && err.Error() != "EOF" { + t.Fatalf("Failed to read body: %v", err) + } + bodyStr := string(body[:n]) + + // Simple check to see if the modified body contains the expected fields + if !strings.Contains(bodyStr, "client_name") { + t.Errorf("Expected body to contain client_name, got %s", bodyStr) + } + if !strings.Contains(bodyStr, "redirect_uris") { + t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr) + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 382c8f3..33a9ea3 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "log" "net/http" "net/http/httputil" "net/url" @@ -11,6 +10,7 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -20,78 +20,138 @@ import ( func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { mux := http.NewServeMux() - // 1. Custom well-known - mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + modifiers := map[string]RequestModifier{ + "/authorize": &AuthorizationModifier{Config: cfg}, + "/token": &TokenModifier{Config: cfg}, + "/register": &RegisterModifier{Config: cfg}, + } - // 2. Registration - mux.HandleFunc("/register", provider.RegisterHandler()) + registeredPaths := make(map[string]bool) - // 3. Default "auth" paths, proxied - defaultPaths := []string{"/authorize", "/token"} + var defaultPaths []string + + // Handle based on mode configuration + if cfg.Mode == "demo" || cfg.Mode == "asgardeo" { + // Demo/Asgardeo mode: Custom handlers for well-known and register + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + + mux.HandleFunc("/register", provider.RegisterHandler()) + registeredPaths["/register"] = true + + // Authorize and token will be proxied with parameter modification + defaultPaths = []string{"/authorize", "/token"} + } else { + // Default provider mode + if cfg.Default.Path != nil { + // Check if we have custom response for well-known + wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"] + if exists && wellKnownConfig.Response != nil { + // If there's a custom response defined, use our handler + mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) + registeredPaths["/.well-known/oauth-authorization-server"] = true + } else { + // No custom response, add well-known to proxy paths + defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") + } + + defaultPaths = append(defaultPaths, "/authorize") + defaultPaths = append(defaultPaths, "/token") + defaultPaths = append(defaultPaths, "/register") + } else { + defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"} + } + } + + // Remove duplicates from defaultPaths + uniquePaths := make(map[string]bool) + cleanPaths := []string{} for _, path := range defaultPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) + if !uniquePaths[path] { + uniquePaths[path] = true + cleanPaths = append(cleanPaths, path) + } + } + defaultPaths = cleanPaths + + for _, path := range defaultPaths { + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true + } } - // 4. MCP paths - for _, path := range cfg.MCPPaths { - mux.HandleFunc(path, buildProxyHandler(cfg)) + // MCP paths + mcpPaths := cfg.GetMCPPaths() + for _, path := range mcpPaths { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true } - // 5. If you want to map additional paths from config.PathMapping - // to the same proxy logic: + // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { - mux.HandleFunc(path, buildProxyHandler(cfg)) + if !registeredPaths[path] { + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) + registeredPaths[path] = true + } } return mux } -func buildProxyHandler(cfg *config.Config) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { // Parse the base URLs up front authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { - log.Fatalf("Invalid auth server URL: %v", err) + logger.Error("Invalid auth server URL: %v", err) + panic(err) // Fatal error that prevents startup } - mcpBase, err := url.Parse(cfg.MCPServerBaseURL) + + mcpBase, err := url.Parse(cfg.BaseURL) if err != nil { - log.Fatalf("Invalid MCP server URL: %v", err) - } - - // We'll define sets for known auth paths, SSE paths, etc. - authPaths := map[string]bool{ - "/authorize": true, - "/token": true, - "/.well-known/oauth-authorization-server": true, + logger.Error("Invalid MCP server URL: %v", err) + panic(err) // Fatal error that prevents startup } // Detect SSE paths from config ssePaths := make(map[string]bool) - for _, p := range cfg.MCPPaths { - if p == "/sse" { - ssePaths[p] = true - } - } + ssePaths[cfg.Paths.SSE] = true return func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowedOrigin := getAllowedOrigin(origin, cfg) // Handle OPTIONS if r.Method == http.MethodOptions { - addCORSHeaders(w) + if allowedOrigin == "" { + logger.Warn("Preflight request from disallowed origin: %s", origin) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers")) w.WriteHeader(http.StatusNoContent) return } - addCORSHeaders(w) + if allowedOrigin == "" { + logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path) + http.Error(w, "CORS origin not allowed", http.StatusForbidden) + return + } + + // Add CORS headers to all responses + addCORSHeaders(w, cfg, allowedOrigin, "") // Decide whether the request should go to the auth server or MCP var targetURL *url.URL isSSE := false - if authPaths[r.URL.Path] { + if isAuthPath(r.URL.Path) { targetURL = authBase } else if isMCPPath(r.URL.Path, cfg) { - // Validate JWT if you want + // Validate JWT for MCP paths if required + // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { - log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) + logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -100,11 +160,21 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { isSSE = true } } else { - // If it's not recognized as an auth path or an MCP path http.Error(w, "Forbidden", http.StatusForbidden) return } + // Apply request modifiers to add parameters + if modifier, exists := modifiers[r.URL.Path]; exists { + var err error + r, err = modifier.ModifyRequest(r) + if err != nil { + logger.Error("Error modifying request: %v", err) + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + } + // Build the reverse proxy rp := &httputil.ReverseProxy{ Director: func(req *http.Request) { @@ -120,31 +190,53 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { req.URL.RawQuery = r.URL.RawQuery req.Host = targetURL.Host - for header, values := range r.Header { + cleanHeaders := http.Header{} + + // Set proper origin header to match the target + if isSSE { + // For SSE, ensure origin matches the target + req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) + } + + for k, v := range r.Header { // Skip hop-by-hop headers - if strings.EqualFold(header, "Connection") || - strings.EqualFold(header, "Keep-Alive") || - strings.EqualFold(header, "Transfer-Encoding") || - strings.EqualFold(header, "Upgrade") || - strings.EqualFold(header, "Proxy-Authorization") || - strings.EqualFold(header, "Proxy-Connection") { + if skipHeader(k) { continue } - for _, value := range values { - req.Header.Set(header, value) - } + // Set only the first value to avoid duplicates + cleanHeaders.Set(k, v[0]) } - log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) + + req.Header = cleanHeaders + + logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) + }, + ModifyResponse: func(resp *http.Response) error { + logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts + return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - log.Printf("[proxy] Error proxying: %v", err) + logger.Error("Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) }, FlushInterval: -1, // immediate flush for SSE } if isSSE { + // Add special response handling for SSE connections to rewrite endpoint URLs + rp.Transport = &sseTransport{ + Transport: http.DefaultTransport, + proxyHost: r.Host, + targetHost: targetURL.Host, + } + + // Set SSE-specific headers + w.Header().Set("X-Accel-Buffering", "no") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + // Keep SSE connections open HandleSSE(w, r, rp) } else { @@ -156,14 +248,52 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc { } } -func addCORSHeaders(w http.ResponseWriter) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") +func getAllowedOrigin(origin string, cfg *config.Config) string { + if origin == "" { + return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin + } + for _, allowed := range cfg.CORSConfig.AllowedOrigins { + logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed) + if allowed == origin { + return allowed + } + } + return "" } +// addCORSHeaders adds configurable CORS headers +func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", ")) + if requestHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", requestHeaders) + } else { + w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", ")) + } + if cfg.CORSConfig.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + w.Header().Set("Vary", "Origin") + w.Header().Set("X-Accel-Buffering", "no") +} + +func isAuthPath(path string) bool { + authPaths := map[string]bool{ + "/authorize": true, + "/token": true, + "/register": true, + "/.well-known/oauth-authorization-server": true, + } + if strings.HasPrefix(path, "/u/") { + return true + } + return authPaths[path] +} + +// isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { - for _, p := range cfg.MCPPaths { + mcpPaths := cfg.GetMCPPaths() + for _, p := range mcpPaths { if strings.HasPrefix(path, p) { return true } @@ -171,22 +301,10 @@ func isMCPPath(path string, cfg *config.Config) bool { return false } -func copyHeaders(src http.Header, dst http.Header) { - // Exclude hop-by-hop - hopByHop := map[string]bool{ - "Connection": true, - "Keep-Alive": true, - "Transfer-Encoding": true, - "Upgrade": true, - "Proxy-Authorization": true, - "Proxy-Connection": true, - } - for k, vv := range src { - if hopByHop[strings.ToLower(k)] { - continue - } - for _, v := range vv { - dst.Add(k, v) - } +func skipHeader(h string) bool { + switch strings.ToLower(h) { + case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer": + return true } + return false } diff --git a/internal/proxy/sse.go b/internal/proxy/sse.go index 44d6558..ce72e04 100644 --- a/internal/proxy/sse.go +++ b/internal/proxy/sse.go @@ -1,11 +1,16 @@ package proxy import ( + "bufio" "context" - "log" + "fmt" + "io" "net/http" "net/http/httputil" + "strings" "time" + + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // HandleSSE sets up a go-routine to wait for context cancellation @@ -16,7 +21,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy go func() { <-ctx.Done() - log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) + logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) close(done) }() @@ -32,3 +37,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), timeout) } + +// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses +type sseTransport struct { + Transport http.RoundTripper + proxyHost string + targetHost string +} + +func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Call the underlying transport + resp, err := t.Transport.RoundTrip(req) + if err != nil { + return nil, err + } + + // Check if this is an SSE response + contentType := resp.Header.Get("Content-Type") + if !strings.Contains(contentType, "text/event-stream") { + return resp, nil + } + + logger.Info("Intercepting SSE response to modify endpoint events") + + // Create a response wrapper that modifies the response body + originalBody := resp.Body + pr, pw := io.Pipe() + + go func() { + defer originalBody.Close() + defer pw.Close() + + scanner := bufio.NewScanner(originalBody) + for scanner.Scan() { + line := scanner.Text() + + // Check if this line contains an endpoint event + if strings.HasPrefix(line, "event: endpoint") { + // Read the data line + if scanner.Scan() { + dataLine := scanner.Text() + if strings.HasPrefix(dataLine, "data: ") { + // Extract the endpoint URL + endpoint := strings.TrimPrefix(dataLine, "data: ") + + // Replace the host in the endpoint + logger.Debug("Original endpoint: %s", endpoint) + endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1) + logger.Debug("Modified endpoint: %s", endpoint) + + // Write the modified event lines + fmt.Fprintln(pw, line) + fmt.Fprintln(pw, "data: "+endpoint) + continue + } + } + } + + // Write the original line for non-endpoint events + fmt.Fprintln(pw, line) + } + + if err := scanner.Err(); err != nil { + logger.Error("Error reading SSE stream: %v", err) + } + }() + + // Replace the response body with our modified pipe + resp.Body = pr + return resp, nil +} diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go new file mode 100644 index 0000000..fa64337 --- /dev/null +++ b/internal/subprocess/manager.go @@ -0,0 +1,268 @@ +package subprocess + +import ( + "fmt" + "os" + "os/exec" + "sync" + "syscall" + "time" + "strings" + + "github.com/wso2/open-mcp-auth-proxy/internal/config" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" +) + +// Manager handles starting and graceful shutdown of subprocesses +type Manager struct { + process *os.Process + processGroup int + mutex sync.Mutex + cmd *exec.Cmd + shutdownDelay time.Duration +} + +// NewManager creates a new subprocess manager +func NewManager() *Manager { + return &Manager{ + shutdownDelay: 5 * time.Second, + } +} + +// EnsureDependenciesAvailable checks and installs required package executors +func EnsureDependenciesAvailable(command string) error { + // Always ensure npx is available regardless of the command + if _, err := exec.LookPath("npx"); err != nil { + // npx is not available, check if npm is installed + if _, err := exec.LookPath("npm"); err != nil { + return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") + } + + // Try to install npx using npm + logger.Info("npx not found, attempting to install...") + cmd := exec.Command("npm", "install", "-g", "npx") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install npx: %w", err) + } + + logger.Info("npx installed successfully") + } + + // Check if uv is needed based on the command + if strings.Contains(command, "uv ") { + if _, err := exec.LookPath("uv"); err != nil { + return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") + } + } + + return nil +} + +// SetShutdownDelay sets the maximum time to wait for graceful shutdown +func (m *Manager) SetShutdownDelay(duration time.Duration) { + m.shutdownDelay = duration +} + +// Start launches a subprocess based on the configuration +func (m *Manager) Start(cfg *config.Config) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + // If a process is already running, return an error + if m.process != nil { + return os.ErrExist + } + + if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" { + return nil // Nothing to start + } + + // Get the full command string + execCommand := cfg.BuildExecCommand() + if execCommand == "" { + return nil // No command to execute + } + + logger.Info("Starting subprocess with command: %s", execCommand) + + // Use the shell to execute the command + cmd := exec.Command("sh", "-c", execCommand) + + // Set working directory if specified + if cfg.Stdio.WorkDir != "" { + cmd.Dir = cfg.Stdio.WorkDir + } + + // Set environment variables if specified + if len(cfg.Stdio.Env) > 0 { + cmd.Env = append(os.Environ(), cfg.Stdio.Env...) + } + + // Capture stdout/stderr + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Set the process group for proper termination + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + // Start the process + if err := cmd.Start(); err != nil { + return err + } + + m.process = cmd.Process + m.cmd = cmd + logger.Info("Subprocess started with PID: %d", m.process.Pid) + + // Get and store the process group ID + pgid, err := syscall.Getpgid(m.process.Pid) + if err == nil { + m.processGroup = pgid + logger.Debug("Process group ID: %d", m.processGroup) + } else { + logger.Warn("Failed to get process group ID: %v", err) + m.processGroup = m.process.Pid + } + + // Handle process termination in background + go func() { + if err := cmd.Wait(); err != nil { + logger.Error("Subprocess exited with error: %v", err) + } else { + logger.Info("Subprocess exited successfully") + } + + // Clear the process reference when it exits + m.mutex.Lock() + m.process = nil + m.cmd = nil + m.mutex.Unlock() + }() + + return nil +} + +// IsRunning checks if the subprocess is running +func (m *Manager) IsRunning() bool { + m.mutex.Lock() + defer m.mutex.Unlock() + return m.process != nil +} + +// Shutdown gracefully terminates the subprocess +func (m *Manager) Shutdown() { + m.mutex.Lock() + processToTerminate := m.process // Local copy of the process reference + processGroupToTerminate := m.processGroup + m.mutex.Unlock() + + if processToTerminate == nil { + return // No process to terminate + } + + logger.Info("Terminating subprocess...") + terminateComplete := make(chan struct{}) + + go func() { + defer close(terminateComplete) + + // Try graceful termination first with SIGTERM + terminatedGracefully := false + + // Try to terminate the process group first + if processGroupToTerminate != 0 { + err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process group: %v", err) + + // Fallback to terminating just the process + m.mutex.Lock() + if m.process != nil { + err = m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to terminate just the process + m.mutex.Lock() + if m.process != nil { + err := m.process.Signal(syscall.SIGTERM) + if err != nil { + logger.Warn("Failed to send SIGTERM to process: %v", err) + } + } + m.mutex.Unlock() + } + + // Wait for the process to exit gracefully + for i := 0; i < 10; i++ { + time.Sleep(200 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + terminatedGracefully = true + m.mutex.Unlock() + break + } + m.mutex.Unlock() + } + + if terminatedGracefully { + logger.Info("Subprocess terminated gracefully") + return + } + + // If the process didn't exit gracefully, force kill + logger.Warn("Subprocess didn't exit gracefully, forcing termination...") + + // Try to kill the process group first + if processGroupToTerminate != 0 { + if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { + logger.Warn("Failed to send SIGKILL to process group: %v", err) + + // Fallback to killing just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + } else { + // Try to kill just the process + m.mutex.Lock() + if m.process != nil { + if err := m.process.Kill(); err != nil { + logger.Error("Failed to kill process: %v", err) + } + } + m.mutex.Unlock() + } + + // Wait a bit more to confirm termination + time.Sleep(500 * time.Millisecond) + + m.mutex.Lock() + if m.process == nil { + logger.Info("Subprocess terminated by force") + } else { + logger.Warn("Failed to terminate subprocess") + } + m.mutex.Unlock() + }() + + // Wait for termination with timeout + select { + case <-terminateComplete: + // Termination completed + case <-time.After(m.shutdownDelay): + logger.Warn("Subprocess termination timed out") + } +} diff --git a/internal/util/jwks.go b/internal/util/jwks.go index 4832bf8..f80d82e 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -4,12 +4,12 @@ import ( "crypto/rsa" "encoding/json" "errors" - "log" "math/big" "net/http" "strings" "github.com/golang-jwt/jwt/v4" + "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type JWKS struct { @@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error { publicKeys[parsedKey.Kid] = pubKey } } - log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys)) + logger.Info("Loaded %d public keys.", len(publicKeys)) return nil } diff --git a/internal/util/jwks_test.go b/internal/util/jwks_test.go new file mode 100644 index 0000000..3b00c68 --- /dev/null +++ b/internal/util/jwks_test.go @@ -0,0 +1,143 @@ +package util + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +func TestValidateJWT(t *testing.T) { + // Initialize the test JWKS data + initTestJWKS(t) + + // Test cases + tests := []struct { + name string + authHeader string + expectError bool + }{ + { + name: "Valid JWT token", + authHeader: "Bearer " + createValidJWT(t), + expectError: false, + }, + { + name: "No auth header", + authHeader: "", + expectError: true, + }, + { + name: "Invalid auth header format", + authHeader: "InvalidFormat", + expectError: true, + }, + { + name: "Invalid JWT token", + authHeader: "Bearer invalid.jwt.token", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateJWT(tc.authHeader) + if tc.expectError && err == nil { + t.Errorf("Expected error but got none") + } + if !tc.expectError && err != nil { + t.Errorf("Expected no error but got: %v", err) + } + }) + } +} + +func TestFetchJWKS(t *testing.T) { + // Create a mock JWKS server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Generate a test RSA key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Create JWKS response + jwks := map[string]interface{}{ + "keys": []map[string]interface{}{ + { + "kty": "RSA", + "kid": "test-key-id", + "n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()), + "e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537 + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) + })) + defer server.Close() + + // Test fetching JWKS + err := FetchJWKS(server.URL) + if err != nil { + t.Fatalf("FetchJWKS failed: %v", err) + } + + // Check that keys were stored + if len(publicKeys) == 0 { + t.Errorf("Expected publicKeys to be populated") + } +} + +// Helper function to initialize test JWKS data +func initTestJWKS(t *testing.T) { + // Create a test RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Initialize the publicKeys map + publicKeys = map[string]*rsa.PublicKey{ + "test-key-id": &privateKey.PublicKey, + } +} + +// Helper function to create a valid JWT token for testing +func createValidJWT(t *testing.T) string { + // Create a test RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + // Ensure the test key is in the publicKeys map + if publicKeys == nil { + publicKeys = map[string]*rsa.PublicKey{} + } + publicKeys["test-key-id"] = &privateKey.PublicKey + + // Create token + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "sub": "1234567890", + "name": "Test User", + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), + }) + token.Header["kid"] = "test-key-id" + + // Sign the token + tokenString, err := token.SignedString(privateKey) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + + return tokenString +} diff --git a/pull_request_template.md b/pull_request_template.md index 9b32185..c401a06 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,52 +1,11 @@ ## Purpose -> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc. + -## Goals -> Describe the solutions that this feature/fix will introduce to resolve the problems described above - -## Approach -> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here. - -## User stories -> Summary of user stories addressed by this change> - -## Release note -> Brief description of the new feature or bug fix as it will appear in the release notes - -## Documentation -> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why there’s no doc impact - -## Training -> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable - -## Certification -> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why. - -## Marketing -> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable - -## Automation tests - - Unit tests - > Code coverage information - - Integration tests - > Details about the test cases and coverage - -## Security checks - - Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no - - Ran FindSecurityBugs plugin and verified report? yes/no - - Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no - -## Samples -> Provide high-level details about the samples related to this feature +## Related Issues + ## Related PRs -> List any other related PRs + ## Migrations (if applicable) -> Describe migration steps and platforms on which migration has been tested - -## Test environment -> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested - -## Learning -> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem. \ No newline at end of file + diff --git a/resources/requirements.txt b/resources/requirements.txt new file mode 100644 index 0000000..102b728 --- /dev/null +++ b/resources/requirements.txt @@ -0,0 +1 @@ +fastmcp==0.4.1 \ No newline at end of file