diff --git a/.github/scripts/release.sh b/.github/scripts/release.sh deleted file mode 100644 index 2a1f6a9..0000000 --- a/.github/scripts/release.sh +++ /dev/null @@ -1,124 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). -# -# This software is the property of WSO2 LLC. and its suppliers, if any. -# Dissemination of any information or reproduction of any material contained -# herein in any form is strictly forbidden, unless permitted by WSO2 expressly. -# You may not alter or remove any copyright or other notice from copies of this content. -# - -# Exit the script on any command with non-zero exit status. -set -e -set -o pipefail - -UPSTREAM_BRANCH="main" - -# Assign command line arguments to variables. -GIT_TOKEN=$1 -WORK_DIR=$2 -VERSION_TYPE=$3 # possible values: major, minor, patch - -# Check if GIT_TOKEN is empty -if [ -z "$GIT_TOKEN" ]; then - echo "❌ Error: GIT_TOKEN is not set." - exit 1 -fi - -# Check if WORK_DIR is empty -if [ -z "$WORK_DIR" ]; then - echo "❌ Error: WORK_DIR is not set." - exit 1 -fi - -# Validate VERSION_TYPE -if [[ "$VERSION_TYPE" != "major" && "$VERSION_TYPE" != "minor" && "$VERSION_TYPE" != "patch" ]]; then - echo "❌ Error: VERSION_TYPE must be one of: major, minor, or patch." - exit 1 -fi - -BUILD_DIRECTORY="$WORK_DIR/build" -RELEASE_DIRECTORY="$BUILD_DIRECTORY/releases" - -# Navigate to the working directory. -cd "${WORK_DIR}" - -# Create the release directory. -if [ ! -d "$RELEASE_DIRECTORY" ]; then - mkdir -p "$RELEASE_DIRECTORY" -else - rm -rf "$RELEASE_DIRECTORY"/* -fi - -# Extract current version. -CURRENT_VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo "0.0.0") -IFS='.' read -r MAJOR MINOR PATCH <<< "${CURRENT_VERSION}" - -# Determine which part to increment -case "$VERSION_TYPE" in - major) - MAJOR=$((MAJOR + 1)) - MINOR=0 - PATCH=0 - ;; - minor) - MINOR=$((MINOR + 1)) - PATCH=0 - ;; - patch|*) - PATCH=$((PATCH + 1)) - ;; -esac - -NEW_VERSION="${MAJOR}.${MINOR}.${PATCH}" - -echo "Creating release packages for version $NEW_VERSION..." - -# List of supported OSes. -oses=("linux" "linux-arm" "darwin") - -# Navigate to the release directory. -cd "${RELEASE_DIRECTORY}" - -for os in "${oses[@]}"; do - os_dir="../$os" - - if [ -d "$os_dir" ]; then - release_artifact_folder="openmcpauthproxy_${os}-v${NEW_VERSION}" - mkdir -p "$release_artifact_folder" - - cp -r $os_dir/* "$release_artifact_folder" - - # Zip the release package. - zip_file="$release_artifact_folder.zip" - echo "Creating $zip_file..." - zip -r "$zip_file" "$release_artifact_folder" - - # Delete the folder after zipping. - rm -rf "$release_artifact_folder" - - # Generate checksum file. - sha256sum "$zip_file" | sed "s|target/releases/||" > "$zip_file.sha256" - echo "Checksum generated for the $os package." - - echo "Release packages created successfully for $os." - else - echo "Skipping $os release package creation as the build artifacts are not available." - fi -done - -echo "Release packages created successfully in $RELEASE_DIRECTORY." - -# Navigate back to the project root directory. -cd "${WORK_DIR}" - -# Collect all ZIP and .sha256 files in the target/releases directory. -FILES_TO_UPLOAD=$(find build/releases -type f \( -name "*.zip" -o -name "*.sha256" \)) - -# Create a release with the current version. -TAG_NAME="v${NEW_VERSION}" -export GITHUB_TOKEN="${GIT_TOKEN}" -gh release create "${TAG_NAME}" ${FILES_TO_UPLOAD} --title "${TAG_NAME}" --notes "OpenMCPAuthProxy - ${TAG_NAME}" --target "${UPSTREAM_BRANCH}" || { echo "Failed to create release"; exit 1; } - - -echo "Release ${TAG_NAME} created successfully." diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml deleted file mode 100644 index 775003e..0000000 --- a/.github/workflows/ci.yaml +++ /dev/null @@ -1,71 +0,0 @@ -name: Build and Push container -run-name: Build and Push container -on: - workflow_dispatch: - #schedule: - # - cron: "0 10 * * *" - push: - branches: - - 'main' - - 'master' - tags: - - 'v*' - pull_request: - branches: - - 'main' - - 'master' -env: - IMAGE: git.kvant.cloud/${{github.repository}} -jobs: - build_concierge_backend: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Set current time - uses: https://github.com/gerred/actions/current-time@master - id: current_time - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Login to git.kvant.cloud registry - uses: docker/login-action@v3 - with: - registry: git.kvant.cloud - username: ${{ vars.ORG_PACKAGE_WRITER_USERNAME }} - password: ${{ secrets.ORG_PACKAGE_WRITER_TOKEN }} - - - name: Docker meta - id: meta - uses: docker/metadata-action@v5 - with: - # list of Docker images to use as base name for tags - images: | - ${{env.IMAGE}} - # generate Docker tags based on the following events/attributes - tags: | - type=schedule - type=ref,event=branch - type=ref,event=pr - type=semver,pattern={{version}} - - - name: Build and push to gitea registry - uses: docker/build-push-action@v6 - with: - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - context: . - provenance: mode=max - sbom: true - build-args: | - BUILD_DATE=${{ steps.current_time.outputs.time }} - cache-from: | - type=registry,ref=${{ env.IMAGE }}:buildcache - type=registry,ref=${{ env.IMAGE }}:${{ github.ref_name }} - type=registry,ref=${{ env.IMAGE }}:main - cache-to: type=registry,ref=${{ env.IMAGE }}:buildcache,mode=max,image-manifest=true diff --git a/.github/workflows/pr-builder.yml b/.github/workflows/pr-builder.yml deleted file mode 100644 index a055e0d..0000000 --- a/.github/workflows/pr-builder.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: Go CI - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - test: - name: Test - runs-on: ubuntu-latest - strategy: - matrix: - go-version: ['1.20', '1.21'] - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up Go - uses: actions/setup-go@v4 - with: - go-version: ${{ matrix.go-version }} - - - name: Get dependencies - run: go get -v -t -d ./... - - - name: Verify dependencies - run: go mod verify - - - name: Run go vet - run: go vet ./... - - - name: Run tests - run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - with: - files: ./coverage.txt - fail_ci_if_error: false - - build: - name: Build - runs-on: ubuntu-latest - strategy: - matrix: - go-version: ['1.20', '1.21'] - os: [ubuntu-latest, macos-latest] - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Set up Go - uses: actions/setup-go@v4 - with: - go-version: ${{ matrix.go-version }} - - - name: Build - run: go build -v ./cmd/proxy diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index 0c51bc7..0000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,64 +0,0 @@ -# -# Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). -# -# This software is the property of WSO2 LLC. and its suppliers, if any. -# Dissemination of any information or reproduction of any material contained -# herein in any form is strictly forbidden, unless permitted by WSO2 expressly. -# You may not alter or remove any copyright or other notice from copies of this content. -# - -name: Release - -on: - workflow_dispatch: - inputs: - version_type: - type: choice - description: Choose the type of version update - options: - - 'major' - - 'minor' - - 'patch' - required: true - -jobs: - update-and-release: - runs-on: ubuntu-latest - env: - GOPROXY: https://proxy.golang.org - if: github.event.pull_request.merged == true || github.event_name == 'workflow_dispatch' - steps: - - uses: actions/checkout@v2 - with: - ref: 'main' - fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/checkout@v2 - - - name: Set up Go 1.x - uses: actions/setup-go@v3 - with: - go-version: "^1.x" - - - name: Cache Go modules - id: cache-go-modules - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-modules-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-modules- - - - name: Install dependencies - run: go mod download - - - name: Build and test - run: make build - working-directory: . - - - name: Update artifact version, package, commit, and create release. - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: bash ./.github/scripts/release.sh $GITHUB_TOKEN ${{ github.workspace }} ${{ github.event.inputs.version_type }} diff --git a/.gitignore b/.gitignore index d200b58..6c1dd97 100644 --- a/.gitignore +++ b/.gitignore @@ -18,21 +18,15 @@ *.zip *.tar.gz *.rar -.venv # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml hs_err_pid* replay_pid* +# Go module cache files +go.sum + # OS generated files .DS_Store -# builds -build - -# test out files -coverage.out -coverage.html - -# IDE files -.vscode +openmcpauthproxy diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index dc468b1..0000000 --- a/Dockerfile +++ /dev/null @@ -1,48 +0,0 @@ -FROM --platform=${BUILDPLATFORM:-linux/amd64} golang:1.24@sha256:d9db32125db0c3a680cfb7a1afcaefb89c898a075ec148fdc2f0f646cc2ed509 AS build - -ARG TARGETPLATFORM -ARG BUILDPLATFORM -ARG TARGETOS -ARG TARGETARCH - -WORKDIR /workspace - -RUN apt update -qq && apt install -qq -y git bash curl g++ - -# Download libraries -ADD go.* . -RUN go mod download - -# Build -ADD cmd cmd -ADD internal internal -RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -o webhook -ldflags '-w -extldflags "-static"' -o openmcpauthproxy ./cmd/proxy - -#Test -RUN CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go test -v -race ./... - - -# Build production container -FROM --platform=${BUILDPLATFORM:-linux/amd64} ubuntu:24.04 - -RUN apt-get update \ - && apt-get install --no-install-recommends -y \ - python3-pip \ - python-is-python3 \ - npm \ - && apt-get autoremove \ - && apt-get clean \ - && rm -rf /var/lib/apt/lists/* - -RUN pip install uvenv --break-system-packages - -WORKDIR /app -COPY --from=build /workspace/openmcpauthproxy /app/ - -ADD config.yaml /app - - -ENTRYPOINT ["/app/openmcpauthproxy"] - -ARG IMAGE_SOURCE -LABEL org.opencontainers.image.source=$IMAGE_SOURCE diff --git a/Makefile b/Makefile deleted file mode 100644 index b0d0926..0000000 --- a/Makefile +++ /dev/null @@ -1,88 +0,0 @@ -# Makefile for open-mcp-auth-proxy - -# Variables -PROJECT_ROOT := $(realpath $(dir $(abspath $(lastword $(MAKEFILE_LIST))))) -BINARY_NAME := openmcpauthproxy -GO := go -GOFMT := gofmt -GOVET := go vet -GOTEST := go test -GOLINT := golangci-lint -GOCOV := go tool cover -BUILD_DIR := build - -# Source files -SRC := $(shell find . -name "*.go" -not -path "./vendor/*") -PKGS := $(shell go list ./... | grep -v /vendor/) - -# Set build options -BUILD_OPTS := -v - -# Set test options -TEST_OPTS := -v -race - -.PHONY: all clean test fmt lint vet coverage help - -# Default target -all: lint test build-linux build-linux-arm build-darwin - -build: clean test build-linux build-linux-arm build-darwin - -build-linux: - mkdir -p $(BUILD_DIR)/linux - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ - -o $(BUILD_DIR)/linux/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy - cp config.yaml $(BUILD_DIR)/linux - -build-linux-arm: - mkdir -p $(BUILD_DIR)/linux-arm - GOOS=linux GOARCH=arm CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ - -o $(BUILD_DIR)/linux-arm/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy - cp config.yaml $(BUILD_DIR)/linux-arm - -build-darwin: - mkdir -p $(BUILD_DIR)/darwin - GOOS=darwin GOARCH=amd64 CGO_ENABLED=0 go build -x -ldflags "-X main.version=$(BUILD_VERSION)" \ - -o $(BUILD_DIR)/darwin/openmcpauthproxy $(PROJECT_ROOT)/cmd/proxy - cp config.yaml $(BUILD_DIR)/darwin - -# Clean build artifacts -clean: - @echo "Cleaning build artifacts..." - @rm -rf $(BUILD_DIR) - @rm -f coverage.out - -# Run tests -test: - @echo "Running tests..." - $(GOTEST) $(TEST_OPTS) ./... - -# Run tests with coverage report -coverage: - @echo "Running tests with coverage..." - @$(GOTEST) -coverprofile=coverage.out ./... - @$(GOCOV) -func=coverage.out - @$(GOCOV) -html=coverage.out -o coverage.html - @echo "Coverage report generated in coverage.html" - -# Run gofmt -fmt: - @echo "Running gofmt..." - @$(GOFMT) -w -s $(SRC) - -# Run go vet -vet: - @echo "Running go vet..." - @$(GOVET) ./... - -# Show help -help: - @echo "Available targets:" - @echo " all : Run lint, test, and build" - @echo " build : Build the application" - @echo " clean : Clean build artifacts" - @echo " test : Run tests" - @echo " coverage : Run tests with coverage report" - @echo " fmt : Run gofmt" - @echo " vet : Run go vet" - @echo " help : Show this help message" diff --git a/README.md b/README.md index 6be3ece..dbfe989 100644 --- a/README.md +++ b/README.md @@ -1,87 +1,82 @@ # Open MCP Auth Proxy -A lightweight authorization proxy for Model Context Protocol (MCP) servers that enforces authorization according to the [MCP authorization specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) +The Open MCP Auth Proxy is a lightweight proxy designed to sit in front of MCP servers and enforce authorization in compliance with the [Model Context Protocol authorization](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) requirements. It intercepts incoming requests, validates tokens, and offloads authentication and authorization to an OAuth-compliant Identity Provider. -[![🚀 Release](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml/badge.svg)](https://github.com/wso2/open-mcp-auth-proxy/actions/workflows/release.yml) -[![💬 Stackoverflow](https://img.shields.io/badge/Ask%20for%20help%20on-Stackoverflow-orange)](https://stackoverflow.com/questions/tagged/wso2is) -[![💬 Discord](https://img.shields.io/badge/Join%20us%20on-Discord-%23e01563.svg)](https://discord.gg/wso2) -[![🐦 Twitter](https://img.shields.io/twitter/follow/wso2.svg?style=social&label=Follow)](https://twitter.com/intent/follow?screen_name=wso2) -[![📝 License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/wso2/product-is/blob/master/LICENSE) +![image](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) -![Architecture Diagram](https://github.com/user-attachments/assets/41cf6723-c488-4860-8640-8fec45006f92) +## **Setup and Installation** -## What it Does +### **Prerequisites** -Open MCP Auth Proxy sits between MCP clients and your MCP server to: +* Go 1.20 or higher +* A running MCP server (SSE transport supported) +* An MCP client that supports MCP authorization -- Intercept incoming requests -- Validate authorization tokens -- Offload authentication and authorization to OAuth-compliant Identity Providers -- Support the MCP authorization protocol +### **Installation** -## Quick Start +```bash +git clone https://github.com/wso2/open-mcp-auth-proxy +cd open-mcp-auth-proxy -### Prerequisites +go get github.com/golang-jwt/jwt/v4 +go get gopkg.in/yaml.v2 -* Go 1.20 or higher -* A running MCP server +go build -o openmcpauthproxy ./cmd/proxy +``` -> If you don't have an MCP server, you can use the included example: -> -> 1. Navigate to the `resources` directory -> 2. Set up a Python environment: -> -> ```bash -> python3 -m venv .venv -> source .venv/bin/activate -> pip3 install -r requirements.txt -> ``` -> -> 3. Start the example server: -> -> ```bash -> python3 echo_server.py -> ``` +## Using Open MCP Auth Proxy -* An MCP client that supports MCP authorization +### Quick Start -### Basic Usage +Allows you to just enable authentication and authorization for your MCP server with the preconfigured auth provider powered by Asgardeo. -1. Download the latest release from [Github releases](https://github.com/wso2/open-mcp-auth-proxy/releases/latest). +If you don’t have an MCP server, follow the instructions given here to start your own MCP server for testing purposes. +1. Download [sample MCP server](resources/echo_server.py) +2. Run the server with +```bash +python3 echo_server.py +``` -2. Start the proxy in demo mode (uses pre-configured authentication with Asgardeo sandbox): +#### Configure the Auth Proxy + +Update the following parameters in `config.yaml`. + +### demo mode configuration: + +```yaml +mcp_server_base_url: "http://localhost:8000" # URL of your MCP server +listen_port: 8080 # Address where the proxy will listen +``` + +#### Start the Auth Proxy ```bash ./openmcpauthproxy --demo ``` -> The repository comes with a default `config.yaml` file that contains the basic configuration: -> -> ```yaml -> listen_port: 8080 -> base_url: "http://localhost:8000" # Your MCP server URL -> paths: -> sse: "/sse" -> messages: "/messages/" -> ``` +The `--demo` flag enables a demonstration mode with pre-configured authentication and authorization with a sandbox powered by [Asgardeo](https://asgardeo.io/). -3. Connect using an MCP client like [MCP Inspector](https://github.com/shashimalcse/inspector)(This is a temporary fork with fixes for authentication [issues](https://github.com/modelcontextprotocol/typescript-sdk/issues/257) in the original implementation) +#### Connect Using an MCP Client -## Connect an Identity Provider +You can use this improved fork of [MCP Inspector](https://github.com/shashimalcse/inspector) to test the connection and try out the complete authorization flow. -### Asgardeo +### Use with Asgardeo -To enable authorization through your Asgardeo organization: +Enable authorization for the MCP server through your own Asgardeo organization -1. [Register](https://asgardeo.io/signup) and create an organization in Asgardeo -2. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) - 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke "Application Management API" with the `internal_application_mgt_create` scope - ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) - -3. Update `config.yaml` with the following parameters. +1. [Register]([url](https://asgardeo.io/signup)) and create an organization in Asgardeo +2. Now, you need to authorize the OpenMCPAuthProxy to allow dynamically registering MCP Clients as applications in your organization. To do that, + 1. Create an [M2M application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/) + 1. [Authorize this application](https://wso2.com/asgardeo/docs/guides/applications/register-machine-to-machine-app/#authorize-the-api-resources-for-the-app) to invoke “Application Management API” with the `internal_application_mgt_create` scope. + ![image](https://github.com/user-attachments/assets/0bd57cac-1904-48cc-b7aa-0530224bc41a) + 2. Note the **Client ID** and **Client secret** of this application. This is required by the auth proxy + +#### Configure the Auth Proxy + +Create a configuration file config.yaml with the following parameters: ```yaml -base_url: "http://localhost:8000" # URL of your MCP server +mcp_server_base_url: "http://localhost:8000" # URL of your MCP server listen_port: 8080 # Address where the proxy will listen asgardeo: @@ -90,137 +85,31 @@ asgardeo: client_secret: "" # Client secret of the M2M app ``` -4. Start the proxy with Asgardeo integration: +#### Start the Auth Proxy ```bash ./openmcpauthproxy --asgardeo ``` -### Other OAuth Providers +### Use with any standard OAuth Server -- [Auth0](docs/integrations/Auth0.md) -- [Keycloak](docs/integrations/keycloak.md) +Enable authorization for the MCP server with a compliant OAuth server -# Advanced Configuration +#### Configuration -### Transport Modes - -The proxy supports two transport modes: - -- **SSE Mode (Default)**: For Server-Sent Events transport -- **stdio Mode**: For MCP servers that use stdio transport - -When using stdio mode, the proxy: -- Starts an MCP server as a subprocess using the command specified in the configuration -- Communicates with the subprocess through standard input/output (stdio) -- **Note**: Any commands specified (like `npx` in the example below) must be installed on your system first - -To use stdio mode: - -```bash -./openmcpauthproxy --demo --stdio -``` - -#### Example: Running an MCP Server as a Subprocess - -1. Configure stdio mode in your `config.yaml`: +Create a configuration file config.yaml with the following parameters: ```yaml -listen_port: 8080 -base_url: "http://localhost:8000" - -stdio: - enabled: true - user_command: "npx -y @modelcontextprotocol/server-github" # Example using a GitHub MCP server - env: # Environment variables (optional) - - "GITHUB_PERSONAL_ACCESS_TOKEN=gitPAT" - -# CORS configuration -cors: - allowed_origins: - - "http://localhost:5173" # Origin of your client application - allowed_methods: - - "GET" - - "POST" - - "PUT" - - "DELETE" - allowed_headers: - - "Authorization" - - "Content-Type" - allow_credentials: true - -# Demo configuration for Asgardeo -demo: - org_name: "openmcpauthdemo" - client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" - client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" +mcp_server_base_url: "http://localhost:8000" # URL of your MCP server +listen_port: 8080 # Address where the proxy will listen ``` +**TODO**: Update the configs for a standard OAuth Server. -2. Run the proxy with stdio mode: +#### Start the Auth Proxy ```bash -./openmcpauthproxy --demo +./openmcpauthproxy ``` +#### Integrating with existing OAuth Providers -The proxy will: -- Start the MCP server as a subprocess using the specified command -- Handle all authorization requirements -- Forward messages between clients and the server - -### Complete Configuration Reference - -```yaml -# Common configuration -listen_port: 8080 -base_url: "http://localhost:8000" -port: 8000 - -# Path configuration -paths: - sse: "/sse" - messages: "/messages/" - -# Transport mode -transport_mode: "sse" # Options: "sse" or "stdio" - -# stdio-specific configuration (used only in stdio mode) -stdio: - enabled: true - user_command: "npx -y @modelcontextprotocol/server-github" # Command to start the MCP server (requires npx to be installed) - work_dir: "" # Optional working directory for the subprocess - -# CORS configuration -cors: - allowed_origins: - - "http://localhost:5173" - allowed_methods: - - "GET" - - "POST" - - "PUT" - - "DELETE" - allowed_headers: - - "Authorization" - - "Content-Type" - allow_credentials: true - -# Demo configuration for Asgardeo -demo: - org_name: "openmcpauthdemo" - client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" - client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" - -# Asgardeo configuration (used with --asgardeo flag) -asgardeo: - org_name: "" - client_id: "" - client_secret: "" -``` - -### Build from source - -```bash -git clone https://github.com/wso2/open-mcp-auth-proxy -cd open-mcp-auth-proxy -go get github.com/golang-jwt/jwt/v4 gopkg.in/yaml.v2 -go build -o openmcpauthproxy ./cmd/proxy -``` + - [Auth0](docs/Auth0.md) - Enable authorization for the MCP server through your Auth0 organization. diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index c43dd7d..cde3cf3 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -3,71 +3,31 @@ package main import ( "flag" "fmt" + "log" "net/http" "os" "os/signal" - "syscall" "time" "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/constants" - logger "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/proxy" - "github.com/wso2/open-mcp-auth-proxy/internal/subprocess" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) func main() { demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).") - debugMode := flag.Bool("debug", false, "Enable debug logging") - stdioMode := flag.Bool("stdio", false, "Use stdio transport mode instead of SSE") flag.Parse() - logger.SetDebug(*debugMode) - // 1. Load config cfg, err := config.LoadConfig("config.yaml") if err != nil { - logger.Error("Error loading config: %v", err) - os.Exit(1) + log.Fatalf("Error loading config: %v", err) } - // Override transport mode if stdio flag is set - if *stdioMode { - cfg.TransportMode = config.StdioTransport - // Ensure stdio is enabled - cfg.Stdio.Enabled = true - // Re-validate config - if err := cfg.Validate(); err != nil { - logger.Error("Configuration error: %v", err) - os.Exit(1) - } - } - - logger.Info("Using transport mode: %s", cfg.TransportMode) - logger.Info("Using MCP server base URL: %s", cfg.BaseURL) - logger.Info("Using MCP paths: SSE=%s, Messages=%s", cfg.Paths.SSE, cfg.Paths.Messages) - - // 2. Start subprocess if configured and in stdio mode - var procManager *subprocess.Manager - if cfg.TransportMode == config.StdioTransport && cfg.Stdio.Enabled { - // Ensure all required dependencies are available - if err := subprocess.EnsureDependenciesAvailable(cfg.Stdio.UserCommand); err != nil { - logger.Warn("%v", err) - logger.Warn("Subprocess may fail to start due to missing dependencies") - } - - procManager = subprocess.NewManager() - if err := procManager.Start(cfg); err != nil { - logger.Warn("Failed to start subprocess: %v", err) - } - } else if cfg.TransportMode == config.SSETransport { - logger.Info("Using SSE transport mode, not starting subprocess") - } - - // 3. Create the chosen provider + // 2. Create the chosen provider var provider authz.Provider if *demoMode { cfg.Mode = "demo" @@ -86,49 +46,41 @@ func main() { provider = authz.NewDefaultProvider(cfg) } - // 4. (Optional) Fetch JWKS if you want local JWT validation + // 3. (Optional) Fetch JWKS if you want local JWT validation if err := util.FetchJWKS(cfg.JWKSURL); err != nil { - logger.Error("Failed to fetch JWKS: %v", err) - os.Exit(1) + log.Fatalf("Failed to fetch JWKS: %v", err) } - // 5. Build the main router + // 4. Build the main router mux := proxy.NewRouter(cfg, provider) - listen_address := fmt.Sprintf("0.0.0.0:%d", cfg.ListenPort) + listen_address := fmt.Sprintf(":%d", cfg.ListenPort) - // 6. Start the server + // 5. Start the server srv := &http.Server{ + Addr: listen_address, Handler: mux, } go func() { - logger.Info("Server listening on %s", listen_address) + log.Printf("Server listening on %s", listen_address) if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Error("Server error: %v", err) - os.Exit(1) + log.Fatalf("Server error: %v", err) } }() - // 7. Wait for shutdown signal + // 6. Graceful shutdown on Ctrl+C stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt, syscall.SIGTERM) + signal.Notify(stop, os.Interrupt) <-stop - logger.Info("Shutting down...") + log.Println("Shutting down...") - // 8. First terminate subprocess if running - if procManager != nil && procManager.IsRunning() { - procManager.Shutdown() - } - - // 9. Then shutdown the server - logger.Info("Shutting down HTTP server...") shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { - logger.Error("HTTP server shutdown error: %v", err) + log.Printf("Shutdown error: %v", err) } - logger.Info("Stopped.") + log.Println("Stopped.") } diff --git a/config.yaml b/config.yaml index ef70fbb..0b0ade4 100644 --- a/config.yaml +++ b/config.yaml @@ -1,28 +1,22 @@ # config.yaml -# Common configuration for all transport modes +mcp_server_base_url: "" listen_port: 8080 -base_url: "http://localhost:8000" # Base URL for the MCP server -port: 8000 # Port for the MCP server timeout_seconds: 10 +mcp_paths: + - /messages/ + - /sse -# Transport mode configuration -transport_mode: "stdio" # Options: "sse" or "stdio" +path_mapping: + /token: /token + /register: /register + /authorize: /authorize + /.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server -# stdio-specific configuration (used only when transport_mode is "stdio") -stdio: - enabled: true - user_command: uvx mcp-server-time --local-timezone=Europe/Zurich - #user_command: "npx -y @modelcontextprotocol/server-github" - work_dir: "" # Working directory (optional) - # env: # Environment variables (optional) - # - "NODE_ENV=development" - -# CORS settings cors: allowed_origins: - - "http://localhost:6274" # Origin of your frontend/client app + - "" allowed_methods: - "GET" - "POST" @@ -31,26 +25,29 @@ cors: allowed_headers: - "Authorization" - "Content-Type" - - "mcp-protocol-version" allow_credentials: true -# Keycloak endpoint path mappings -path_mapping: - sse: "/sse" # SSE endpoint path - messages: "/messages/" # Messages endpoint path - /token: /realms/master/protocol/openid-connect/token - /register: /realms/master/clients-registrations/openid-connect +demo: + org_name: "openmcpauthdemo" + client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" + client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" + +asgardeo: + org_name: "" + client_id: "" + client_secret: "" -# Keycloak configuration block default: - base_url: "https://iam.phoenix-systems.ch" - jwks_url: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs" + base_url: "" + jwks_url: "" path: /.well-known/oauth-authorization-server: response: - issuer: "https://iam.phoenix-systems.ch/realms/kvant" - jwks_uri: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/certs" - authorization_endpoint: "https://iam.phoenix-systems.ch/realms/kvant/protocol/openid-connect/auth" + issuer: "" + jwks_uri: "" + authorization_endpoint: "" # Optional + token_endpoint: "" # Optional + registration_endpoint: "" # Optional response_types_supported: - "code" grant_types_supported: @@ -58,8 +55,17 @@ default: - "refresh_token" code_challenge_methods_supported: - "S256" - - "plain" + - "plain" + /authroize: + addQueryParams: + - name: "" + value: "" /token: addBodyParams: - - name: "audience" - value: "mcp_proxy" \ No newline at end of file + - name: "" + value: "" + /register: + addBodyParams: + - name: "" + value: "" + diff --git a/docs/integrations/Auth0.md b/docs/Auth0.md similarity index 91% rename from docs/integrations/Auth0.md rename to docs/Auth0.md index 9195659..fe55edc 100644 --- a/docs/integrations/Auth0.md +++ b/docs/Auth0.md @@ -4,7 +4,7 @@ This guide will help you configure Open MCP Auth Proxy to use Auth0 as your iden ### Prerequisites -- An Auth0 organization (sign up [here](https://auth0.com) if you don't have one) +- An Auth0 organization (sign up here if you don't have one) - Open MCP Auth Proxy installed ### Setting Up Auth0 @@ -28,17 +28,9 @@ Update your `config.yaml` with Auth0 settings: ```yaml # Basic proxy configuration +mcp_server_base_url: "http://localhost:8000" listen_port: 8080 -base_url: "http://localhost:8000" -port: 8000 - -# Path configuration -paths: - sse: "/sse" - messages: "/messages/" - -# Transport mode -transport_mode: "sse" +timeout_seconds: 10 # CORS configuration cors: diff --git a/docs/integrations/keycloak.md b/docs/integrations/keycloak.md deleted file mode 100644 index 5e338cc..0000000 --- a/docs/integrations/keycloak.md +++ /dev/null @@ -1,92 +0,0 @@ -## Integrating Open MCP Auth Proxy with Keycloak - -This guide walks you through configuring the Open MCP Auth Proxy to authenticate using Keycloak as the identity provider. - ---- - -### Prerequisites - -Before you begin, ensure you have the following: - -- A running Keycloak instance -- Open MCP Auth Proxy installed and accessible - ---- - -### Step 1: Configure Keycloak for Client Registration - -Set up dynamic client registration in your Keycloak realm by following the [Keycloak client registration guide](https://www.keycloak.org/securing-apps/client-registration). - ---- - -### Step 2: Configure Open MCP Auth Proxy - -Update the `config.yaml` file in your Open MCP Auth Proxy setup using your Keycloak realm's [OIDC settings](https://www.keycloak.org/securing-apps/oidc-layers). Below is an example configuration: - -```yaml -# Proxy server configuration -listen_port: 8081 # Port for the auth proxy -base_url: "http://localhost:8000" # Base URL of the MCP server -port: 8000 # MCP server port - -# Define path mappings -paths: - sse: "/sse" - messages: "/messages/" - -# Set the transport mode -transport_mode: "sse" - -# CORS settings -cors: - allowed_origins: - - "http://localhost:5173" # Origin of your frontend/client app - allowed_methods: - - "GET" - - "POST" - - "PUT" - - "DELETE" - allowed_headers: - - "Authorization" - - "Content-Type" - - "mcp-protocol-version" - allow_credentials: true - -# Keycloak endpoint path mappings -path_mapping: - /token: /realms/master/protocol/openid-connect/token - /register: /realms/master/clients-registrations/openid-connect - -# Keycloak configuration block -default: - base_url: "http://localhost:8080" - jwks_url: "http://localhost:8080/realms/master/protocol/openid-connect/certs" - path: - /.well-known/oauth-authorization-server: - response: - issuer: "http://localhost:8080/realms/master" - jwks_uri: "http://localhost:8080/realms/master/protocol/openid-connect/certs" - authorization_endpoint: "http://localhost:8080/realms/master/protocol/openid-connect/auth" - response_types_supported: - - "code" - grant_types_supported: - - "authorization_code" - - "refresh_token" - code_challenge_methods_supported: - - "S256" - - "plain" - /token: - addBodyParams: - - name: "audience" - value: "mcp_proxy" -``` - -### Step 3: Start the Auth Proxy - -Launch the proxy with the updated Keycloak configuration: - -```bash -./openmcpauthproxy -``` - -Once running, the proxy will handle authentication requests through your configured Keycloak realm. diff --git a/go.mod b/go.mod index 0bceb4f..2d26216 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/wso2/open-mcp-auth-proxy -go 1.21 +go 1.22.3 require ( github.com/golang-jwt/jwt/v4 v4.5.2 diff --git a/go.sum b/go.sum deleted file mode 100644 index 9d27ad1..0000000 --- a/go.sum +++ /dev/null @@ -1,6 +0,0 @@ -github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= -github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/authz/asgardeo.go b/internal/authz/asgardeo.go index a3c812c..7408f79 100644 --- a/internal/authz/asgardeo.go +++ b/internal/authz/asgardeo.go @@ -7,13 +7,13 @@ import ( "encoding/json" "fmt" "io" + "log" "math/rand" "net/http" "strings" "time" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type asgardeoProvider struct { @@ -31,7 +31,6 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("X-Accel-Buffering", "no") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -71,9 +70,8 @@ func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Accel-Buffering", "no") if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Error("Error encoding well-known: %v", err) + log.Printf("[asgardeoProvider] Error encoding well-known: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -85,7 +83,6 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("X-Accel-Buffering", "no") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) @@ -98,7 +95,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { var regReq RegisterRequest if err := json.NewDecoder(r.Body).Decode(®Req); err != nil { - logger.Error("Reading register request: %v", err) + log.Printf("ERROR: reading register request: %v", err) http.Error(w, "Invalid request body", http.StatusBadRequest) return } @@ -112,7 +109,7 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { regReq.ClientSecret = randomString(16) if err := p.createAsgardeoApplication(regReq); err != nil { - logger.Warn("Asgardeo application creation failed: %v", err) + log.Printf("WARN: Asgardeo application creation failed: %v", err) // Optionally http.Error(...) if you want to fail // or continue to return partial data. } @@ -127,10 +124,9 @@ func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc { } w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Accel-Buffering", "no") w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(resp); err != nil { - logger.Error("Encoding /register response: %v", err) + log.Printf("ERROR: encoding /register response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } } @@ -190,7 +186,7 @@ func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) err return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody)) } - logger.Info("Created Asgardeo application for clientID=%s", regReq.ClientID) + log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID) return nil } @@ -206,11 +202,8 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - // Sensitive data - should not be logged at INFO level auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) - - logger.Debug("Requesting admin token for Asgardeo with client ID: %s", p.cfg.Demo.ClientID) tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, @@ -241,10 +234,6 @@ func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) { return "", fmt.Errorf("failed to parse token JSON: %w", err) } - // Don't log the actual token at info level, only at debug level - logger.Debug("Received access token: %s", tokenResp.AccessToken) - logger.Info("Successfully obtained admin token from Asgardeo") - return tokenResp.AccessToken, nil } diff --git a/internal/authz/default.go b/internal/authz/default.go index 929f586..9230d39 100644 --- a/internal/authz/default.go +++ b/internal/authz/default.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type defaultProvider struct { @@ -82,7 +81,6 @@ func (p *defaultProvider) WellKnownHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(response); err != nil { - logger.Error("Error encoding well-known response: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) } return diff --git a/internal/authz/default_test.go b/internal/authz/default_test.go deleted file mode 100644 index f40030f..0000000 --- a/internal/authz/default_test.go +++ /dev/null @@ -1,125 +0,0 @@ -package authz - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/wso2/open-mcp-auth-proxy/internal/config" -) - -func TestNewDefaultProvider(t *testing.T) { - cfg := &config.Config{} - provider := NewDefaultProvider(cfg) - - if provider == nil { - t.Fatal("Expected non-nil provider") - } - - // Ensure it implements the Provider interface - var _ Provider = provider -} - -func TestDefaultProviderWellKnownHandler(t *testing.T) { - // Create a config with a custom well-known response - cfg := &config.Config{ - Default: config.DefaultConfig{ - Path: map[string]config.PathConfig{ - "/.well-known/oauth-authorization-server": { - Response: &config.ResponseConfig{ - Issuer: "https://test-issuer.com", - JwksURI: "https://test-issuer.com/jwks", - ResponseTypesSupported: []string{"code"}, - GrantTypesSupported: []string{"authorization_code"}, - CodeChallengeMethodsSupported: []string{"S256"}, - }, - }, - }, - }, - } - - provider := NewDefaultProvider(cfg) - handler := provider.WellKnownHandler() - - // Create a test request - req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil) - req.Host = "test-host.com" - req.Header.Set("X-Forwarded-Proto", "https") - - // Create a response recorder - w := httptest.NewRecorder() - - // Call the handler - handler(w, req) - - // Check response status - if w.Code != http.StatusOK { - t.Errorf("Expected status OK, got %v", w.Code) - } - - // Verify content type - contentType := w.Header().Get("Content-Type") - if contentType != "application/json" { - t.Errorf("Expected Content-Type: application/json, got %s", contentType) - } - - // Decode and check the response body - var response map[string]interface{} - if err := json.NewDecoder(w.Body).Decode(&response); err != nil { - t.Fatalf("Failed to decode response JSON: %v", err) - } - - // Check expected values - if response["issuer"] != "https://test-issuer.com" { - t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"]) - } - if response["jwks_uri"] != "https://test-issuer.com/jwks" { - t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"]) - } - if response["authorization_endpoint"] != "https://test-host.com/authorize" { - t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"]) - } -} - -func TestDefaultProviderHandleOPTIONS(t *testing.T) { - provider := NewDefaultProvider(&config.Config{}) - handler := provider.WellKnownHandler() - - // Create OPTIONS request - req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil) - w := httptest.NewRecorder() - - // Call the handler - handler(w, req) - - // Check response - if w.Code != http.StatusNoContent { - t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code) - } - - // Check CORS headers - if w.Header().Get("Access-Control-Allow-Origin") != "*" { - t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin")) - } - if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" { - t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods")) - } -} - -func TestDefaultProviderInvalidMethod(t *testing.T) { - provider := NewDefaultProvider(&config.Config{}) - handler := provider.WellKnownHandler() - - // Create POST request (which should be rejected) - req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil) - w := httptest.NewRecorder() - - // Call the handler - handler(w, req) - - // Check response - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code) - } -} diff --git a/internal/config/config.go b/internal/config/config.go index fc6743c..01c3a6f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,35 +1,12 @@ package config import ( - "fmt" "os" "gopkg.in/yaml.v2" ) -// Transport mode for MCP server -type TransportMode string - -const ( - SSETransport TransportMode = "sse" - StdioTransport TransportMode = "stdio" -) - -// Common path configuration for all transport modes -type PathsConfig struct { - SSE string `yaml:"sse"` - Messages string `yaml:"messages"` -} - -// StdioConfig contains stdio-specific configuration -type StdioConfig struct { - Enabled bool `yaml:"enabled"` - UserCommand string `yaml:"user_command"` // The command provided by the user - WorkDir string `yaml:"work_dir"` // Working directory (optional) - Args []string `yaml:"args,omitempty"` // Additional arguments - Env []string `yaml:"env,omitempty"` // Environment variables -} - +// AsgardeoConfig groups all Asgardeo-specific fields type DemoConfig struct { ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` @@ -83,18 +60,15 @@ type DefaultConfig struct { } type Config struct { - AuthServerBaseURL string - ListenPort int `yaml:"listen_port"` - BaseURL string `yaml:"base_url"` - Port int `yaml:"port"` - JWKSURL string - TimeoutSeconds int `yaml:"timeout_seconds"` - PathMapping map[string]string `yaml:"path_mapping"` - Mode string `yaml:"mode"` - CORSConfig CORSConfig `yaml:"cors"` - TransportMode TransportMode `yaml:"transport_mode"` - Paths PathsConfig `yaml:"paths"` - Stdio StdioConfig `yaml:"stdio"` + AuthServerBaseURL string + MCPServerBaseURL string `yaml:"mcp_server_base_url"` + ListenPort int `yaml:"listen_port"` + JWKSURL string + TimeoutSeconds int `yaml:"timeout_seconds"` + MCPPaths []string `yaml:"mcp_paths"` + PathMapping map[string]string `yaml:"path_mapping"` + Mode string `yaml:"mode"` + CORSConfig CORSConfig `yaml:"cors"` // Nested config for Asgardeo Demo DemoConfig `yaml:"demo"` @@ -102,56 +76,6 @@ type Config struct { Default DefaultConfig `yaml:"default"` } -// Validate checks if the config is valid based on transport mode -func (c *Config) Validate() error { - // Validate based on transport mode - if c.TransportMode == StdioTransport { - if !c.Stdio.Enabled { - return fmt.Errorf("stdio.enabled must be true in stdio transport mode") - } - if c.Stdio.UserCommand == "" { - return fmt.Errorf("stdio.user_command is required in stdio transport mode") - } - } - - // Validate paths - if c.Paths.SSE == "" { - c.Paths.SSE = "/sse" // Default value - } - if c.Paths.Messages == "" { - c.Paths.Messages = "/messages" // Default value - } - - // Validate base URL - if c.BaseURL == "" { - if c.Port > 0 { - c.BaseURL = fmt.Sprintf("http://localhost:%d", c.Port) - } else { - c.BaseURL = "http://localhost:8000" // Default value - } - } - - return nil -} - -// GetMCPPaths returns the list of paths that should be proxied to the MCP server -func (c *Config) GetMCPPaths() []string { - return []string{c.Paths.SSE, c.Paths.Messages} -} - -// BuildExecCommand constructs the full command string for execution in stdio mode -func (c *Config) BuildExecCommand() string { - if c.Stdio.UserCommand == "" { - return "" - } - - // Construct the full command - return fmt.Sprintf( - `npx -y supergateway --stdio "%s" --port %d --baseUrl %s --ssePath %s --messagePath %s`, - c.Stdio.UserCommand, c.Port, c.BaseURL, c.Paths.SSE, c.Paths.Messages, - ) -} - // LoadConfig reads a YAML config file into Config struct. func LoadConfig(path string) (*Config, error) { f, err := os.Open(path) @@ -165,26 +89,8 @@ func LoadConfig(path string) (*Config, error) { if err := decoder.Decode(&cfg); err != nil { return nil, err } - - // Set default values if cfg.TimeoutSeconds == 0 { cfg.TimeoutSeconds = 15 // default } - - // Set default transport mode if not specified - if cfg.TransportMode == "" { - cfg.TransportMode = SSETransport // Default to SSE - } - - // Set default port if not specified - if cfg.Port == 0 { - cfg.Port = 8000 // default - } - - // Validate the configuration - if err := cfg.Validate(); err != nil { - return nil, err - } - return &cfg, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go deleted file mode 100644 index 20c0893..0000000 --- a/internal/config/config_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package config - -import ( - "os" - "path/filepath" - "testing" -) - -func TestLoadConfig(t *testing.T) { - // Create a temporary config file - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "test_config.yaml") - - // Basic valid config - validConfig := ` -listen_port: 8080 -base_url: "http://localhost:8000" -transport_mode: "sse" -paths: - sse: "/sse" - messages: "/messages" -cors: - allowed_origins: - - "http://localhost:5173" - allowed_methods: - - "GET" - - "POST" - allowed_headers: - - "Authorization" - - "Content-Type" - allow_credentials: true -` - err := os.WriteFile(configPath, []byte(validConfig), 0644) - if err != nil { - t.Fatalf("Failed to create test config file: %v", err) - } - - // Test loading the valid config - cfg, err := LoadConfig(configPath) - if err != nil { - t.Fatalf("Failed to load valid config: %v", err) - } - - // Verify expected values from the config - if cfg.ListenPort != 8080 { - t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort) - } - if cfg.BaseURL != "http://localhost:8000" { - t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL) - } - if cfg.TransportMode != SSETransport { - t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode) - } - if cfg.Paths.SSE != "/sse" { - t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE) - } - if cfg.Paths.Messages != "/messages" { - t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages) - } - - // Test default values - if cfg.TimeoutSeconds != 15 { - t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds) - } - if cfg.Port != 8000 { - t.Errorf("Expected default Port=8000, got %d", cfg.Port) - } -} - -func TestValidate(t *testing.T) { - tests := []struct { - name string - config Config - expectError bool - }{ - { - name: "Valid SSE config", - config: Config{ - TransportMode: SSETransport, - Paths: PathsConfig{ - SSE: "/sse", - Messages: "/messages", - }, - BaseURL: "http://localhost:8000", - }, - expectError: false, - }, - { - name: "Valid stdio config", - config: Config{ - TransportMode: StdioTransport, - Stdio: StdioConfig{ - Enabled: true, - UserCommand: "some-command", - }, - }, - expectError: false, - }, - { - name: "Invalid stdio config - not enabled", - config: Config{ - TransportMode: StdioTransport, - Stdio: StdioConfig{ - Enabled: false, - UserCommand: "some-command", - }, - }, - expectError: true, - }, - { - name: "Invalid stdio config - no command", - config: Config{ - TransportMode: StdioTransport, - Stdio: StdioConfig{ - Enabled: true, - UserCommand: "", - }, - }, - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := tc.config.Validate() - if tc.expectError && err == nil { - t.Errorf("Expected validation error but got none") - } - if !tc.expectError && err != nil { - t.Errorf("Expected no validation error but got: %v", err) - } - }) - } -} - -func TestGetMCPPaths(t *testing.T) { - cfg := Config{ - Paths: PathsConfig{ - SSE: "/custom-sse", - Messages: "/custom-messages", - }, - } - - paths := cfg.GetMCPPaths() - if len(paths) != 2 { - t.Errorf("Expected 2 MCP paths, got %d", len(paths)) - } - if paths[0] != "/custom-sse" { - t.Errorf("Expected first path=/custom-sse, got %s", paths[0]) - } - if paths[1] != "/custom-messages" { - t.Errorf("Expected second path=/custom-messages, got %s", paths[1]) - } -} - -func TestBuildExecCommand(t *testing.T) { - tests := []struct { - name string - config Config - expectedResult string - }{ - { - name: "Valid command", - config: Config{ - Stdio: StdioConfig{ - UserCommand: "test-command", - }, - Port: 8080, - BaseURL: "http://example.com", - Paths: PathsConfig{ - SSE: "/sse-path", - Messages: "/msgs", - }, - }, - expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`, - }, - { - name: "Empty command", - config: Config{ - Stdio: StdioConfig{ - UserCommand: "", - }, - }, - expectedResult: "", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := tc.config.BuildExecCommand() - if result != tc.expectedResult { - t.Errorf("Expected command=%s, got %s", tc.expectedResult, result) - } - }) - } -} diff --git a/internal/logging/logger.go b/internal/logging/logger.go deleted file mode 100644 index 57bec27..0000000 --- a/internal/logging/logger.go +++ /dev/null @@ -1,34 +0,0 @@ -package logger - -import ( - "log" -) - -var isDebug = false - -// SetDebug enables or disables debug logging -func SetDebug(debug bool) { - isDebug = debug -} - -// Debug logs a debug-level message -func Debug(format string, v ...interface{}) { - if isDebug { - log.Printf("DEBUG: "+format, v...) - } -} - -// Info logs an info-level message -func Info(format string, v ...interface{}) { - log.Printf("INFO: "+format, v...) -} - -// Warn logs a warning-level message -func Warn(format string, v ...interface{}) { - log.Printf("WARN: "+format, v...) -} - -// Error logs an error-level message -func Error(format string, v ...interface{}) { - log.Printf("ERROR: "+format, v...) -} diff --git a/internal/proxy/modifier.go b/internal/proxy/modifier.go index 6662b2c..8e2268b 100644 --- a/internal/proxy/modifier.go +++ b/internal/proxy/modifier.go @@ -9,7 +9,6 @@ import ( "strings" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // RequestModifier modifies requests before they are proxied @@ -149,7 +148,6 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro if strings.Contains(contentType, "application/x-www-form-urlencoded") { // Parse form data if err := req.ParseForm(); err != nil { - logger.Error("Failed to parse form data: %v", err) return nil, err } @@ -171,14 +169,12 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro // Read body bodyBytes, err := io.ReadAll(req.Body) if err != nil { - logger.Error("Failed to read request body: %v", err) return nil, err } // Parse JSON var jsonData map[string]interface{} if err := json.Unmarshal(bodyBytes, &jsonData); err != nil { - logger.Error("Failed to parse JSON body: %v", err) return nil, err } @@ -190,7 +186,6 @@ func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, erro // Marshal back to JSON modifiedBody, err := json.Marshal(jsonData) if err != nil { - logger.Error("Failed to marshal modified JSON: %v", err) return nil, err } diff --git a/internal/proxy/modifier_test.go b/internal/proxy/modifier_test.go deleted file mode 100644 index 3d2fd44..0000000 --- a/internal/proxy/modifier_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package proxy - -import ( - "net/http" - "net/url" - "strings" - "testing" - - "github.com/wso2/open-mcp-auth-proxy/internal/config" -) - -func TestAuthorizationModifier(t *testing.T) { - cfg := &config.Config{ - Default: config.DefaultConfig{ - Path: map[string]config.PathConfig{ - "/authorize": { - AddQueryParams: []config.ParamConfig{ - {Name: "client_id", Value: "test-client-id"}, - {Name: "scope", Value: "openid"}, - }, - }, - }, - }, - } - - modifier := &AuthorizationModifier{Config: cfg} - - // Create a test request - req, err := http.NewRequest("GET", "/authorize?response_type=code", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - - // Modify the request - modifiedReq, err := modifier.ModifyRequest(req) - if err != nil { - t.Fatalf("ModifyRequest failed: %v", err) - } - - // Check that the query parameters were added - query := modifiedReq.URL.Query() - if query.Get("client_id") != "test-client-id" { - t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id")) - } - if query.Get("scope") != "openid" { - t.Errorf("Expected scope=openid, got %s", query.Get("scope")) - } - if query.Get("response_type") != "code" { - t.Errorf("Expected response_type=code, got %s", query.Get("response_type")) - } -} - -func TestTokenModifier(t *testing.T) { - cfg := &config.Config{ - Default: config.DefaultConfig{ - Path: map[string]config.PathConfig{ - "/token": { - AddBodyParams: []config.ParamConfig{ - {Name: "audience", Value: "test-audience"}, - }, - }, - }, - }, - } - - modifier := &TokenModifier{Config: cfg} - - // Create a test request with form data - form := url.Values{} - - req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode())) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - // Modify the request - modifiedReq, err := modifier.ModifyRequest(req) - if err != nil { - t.Fatalf("ModifyRequest failed: %v", err) - } - - body := make([]byte, 1024) - n, err := modifiedReq.Body.Read(body) - if err != nil && err.Error() != "EOF" { - t.Fatalf("Failed to read body: %v", err) - } - bodyStr := string(body[:n]) - - // Parse the form data from the modified request - if err := modifiedReq.ParseForm(); err != nil { - t.Fatalf("Failed to parse form data: %v", err) - } - - // Check that the body parameters were added - if !strings.Contains(bodyStr, "audience") { - t.Errorf("Expected body to contain audience, got %s", bodyStr) - } -} - -func TestRegisterModifier(t *testing.T) { - cfg := &config.Config{ - Default: config.DefaultConfig{ - Path: map[string]config.PathConfig{ - "/register": { - AddBodyParams: []config.ParamConfig{ - {Name: "client_name", Value: "test-client"}, - }, - }, - }, - }, - } - - modifier := &RegisterModifier{Config: cfg} - - // Create a test request with JSON data - jsonBody := `{"redirect_uris":["https://example.com/callback"]}` - req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody)) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - req.Header.Set("Content-Type", "application/json") - - // Modify the request - modifiedReq, err := modifier.ModifyRequest(req) - if err != nil { - t.Fatalf("ModifyRequest failed: %v", err) - } - - // Read the body and check that it still contains the original data - // This test would need to be enhanced with a proper JSON parsing to verify - // the added parameters - body := make([]byte, 1024) - n, err := modifiedReq.Body.Read(body) - if err != nil && err.Error() != "EOF" { - t.Fatalf("Failed to read body: %v", err) - } - bodyStr := string(body[:n]) - - // Simple check to see if the modified body contains the expected fields - if !strings.Contains(bodyStr, "client_name") { - t.Errorf("Expected body to contain client_name, got %s", bodyStr) - } - if !strings.Contains(bodyStr, "redirect_uris") { - t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr) - } -} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 33a9ea3..c999be4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "log" "net/http" "net/http/httputil" "net/url" @@ -10,7 +11,6 @@ import ( "github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" "github.com/wso2/open-mcp-auth-proxy/internal/util" ) @@ -82,8 +82,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { } // MCP paths - mcpPaths := cfg.GetMCPPaths() - for _, path := range mcpPaths { + for _, path := range cfg.MCPPaths { mux.HandleFunc(path, buildProxyHandler(cfg, modifiers)) registeredPaths[path] = true } @@ -101,21 +100,23 @@ func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc { // Parse the base URLs up front + authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { - logger.Error("Invalid auth server URL: %v", err) - panic(err) // Fatal error that prevents startup + log.Fatalf("Invalid auth server URL: %v", err) } - - mcpBase, err := url.Parse(cfg.BaseURL) + mcpBase, err := url.Parse(cfg.MCPServerBaseURL) if err != nil { - logger.Error("Invalid MCP server URL: %v", err) - panic(err) // Fatal error that prevents startup + log.Fatalf("Invalid MCP server URL: %v", err) } // Detect SSE paths from config ssePaths := make(map[string]bool) - ssePaths[cfg.Paths.SSE] = true + for _, p := range cfg.MCPPaths { + if p == "/sse" { + ssePaths[p] = true + } + } return func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") @@ -123,7 +124,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Handle OPTIONS if r.Method == http.MethodOptions { if allowedOrigin == "" { - logger.Warn("Preflight request from disallowed origin: %s", origin) + log.Printf("[proxy] Preflight request from disallowed origin: %s", origin) http.Error(w, "CORS origin not allowed", http.StatusForbidden) return } @@ -133,7 +134,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) } if allowedOrigin == "" { - logger.Warn("Request from disallowed origin: %s for %s", origin, r.URL.Path) + log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path) http.Error(w, "CORS origin not allowed", http.StatusForbidden) return } @@ -151,7 +152,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) // Validate JWT for MCP paths if required // Placeholder for JWT validation logic if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { - logger.Warn("Unauthorized request to %s: %v", r.URL.Path, err) + log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -169,7 +170,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) var err error r, err = modifier.ModifyRequest(r) if err != nil { - logger.Error("Error modifying request: %v", err) + log.Printf("[proxy] Error modifying request: %v", err) http.Error(w, "Bad Request", http.StatusBadRequest) return } @@ -191,13 +192,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Host = targetURL.Host cleanHeaders := http.Header{} - - // Set proper origin header to match the target - if isSSE { - // For SSE, ensure origin matches the target - req.Header.Set("Origin", targetURL.Scheme+"://"+targetURL.Host) - } - + for k, v := range r.Header { // Skip hop-by-hop headers if skipHeader(k) { @@ -210,33 +205,21 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) req.Header = cleanHeaders - logger.Debug("%s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) + log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) }, ModifyResponse: func(resp *http.Response) error { - logger.Debug("Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) + log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode) resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - logger.Error("Error proxying: %v", err) + log.Printf("[proxy] Error proxying: %v", err) http.Error(rw, "Bad Gateway", http.StatusBadGateway) }, FlushInterval: -1, // immediate flush for SSE } if isSSE { - // Add special response handling for SSE connections to rewrite endpoint URLs - rp.Transport = &sseTransport{ - Transport: http.DefaultTransport, - proxyHost: r.Host, - targetHost: targetURL.Host, - } - - // Set SSE-specific headers - w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - // Keep SSE connections open HandleSSE(w, r, rp) } else { @@ -253,7 +236,6 @@ func getAllowedOrigin(origin string, cfg *config.Config) string { return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin } for _, allowed := range cfg.CORSConfig.AllowedOrigins { - logger.Debug("Checking CORS origin: %s against allowed: %s", origin, allowed) if allowed == origin { return allowed } @@ -274,7 +256,6 @@ func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, re w.Header().Set("Access-Control-Allow-Credentials", "true") } w.Header().Set("Vary", "Origin") - w.Header().Set("X-Accel-Buffering", "no") } func isAuthPath(path string) bool { @@ -292,8 +273,7 @@ func isAuthPath(path string) bool { // isMCPPath checks if the path is an MCP path func isMCPPath(path string, cfg *config.Config) bool { - mcpPaths := cfg.GetMCPPaths() - for _, p := range mcpPaths { + for _, p := range cfg.MCPPaths { if strings.HasPrefix(path, p) { return true } diff --git a/internal/proxy/sse.go b/internal/proxy/sse.go index ce72e04..44d6558 100644 --- a/internal/proxy/sse.go +++ b/internal/proxy/sse.go @@ -1,16 +1,11 @@ package proxy import ( - "bufio" "context" - "fmt" - "io" + "log" "net/http" "net/http/httputil" - "strings" "time" - - "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) // HandleSSE sets up a go-routine to wait for context cancellation @@ -21,7 +16,7 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy go func() { <-ctx.Done() - logger.Info("SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) + log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path) close(done) }() @@ -37,73 +32,3 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), timeout) } - -// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses -type sseTransport struct { - Transport http.RoundTripper - proxyHost string - targetHost string -} - -func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Call the underlying transport - resp, err := t.Transport.RoundTrip(req) - if err != nil { - return nil, err - } - - // Check if this is an SSE response - contentType := resp.Header.Get("Content-Type") - if !strings.Contains(contentType, "text/event-stream") { - return resp, nil - } - - logger.Info("Intercepting SSE response to modify endpoint events") - - // Create a response wrapper that modifies the response body - originalBody := resp.Body - pr, pw := io.Pipe() - - go func() { - defer originalBody.Close() - defer pw.Close() - - scanner := bufio.NewScanner(originalBody) - for scanner.Scan() { - line := scanner.Text() - - // Check if this line contains an endpoint event - if strings.HasPrefix(line, "event: endpoint") { - // Read the data line - if scanner.Scan() { - dataLine := scanner.Text() - if strings.HasPrefix(dataLine, "data: ") { - // Extract the endpoint URL - endpoint := strings.TrimPrefix(dataLine, "data: ") - - // Replace the host in the endpoint - logger.Debug("Original endpoint: %s", endpoint) - endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1) - logger.Debug("Modified endpoint: %s", endpoint) - - // Write the modified event lines - fmt.Fprintln(pw, line) - fmt.Fprintln(pw, "data: "+endpoint) - continue - } - } - } - - // Write the original line for non-endpoint events - fmt.Fprintln(pw, line) - } - - if err := scanner.Err(); err != nil { - logger.Error("Error reading SSE stream: %v", err) - } - }() - - // Replace the response body with our modified pipe - resp.Body = pr - return resp, nil -} diff --git a/internal/subprocess/manager.go b/internal/subprocess/manager.go deleted file mode 100644 index fa64337..0000000 --- a/internal/subprocess/manager.go +++ /dev/null @@ -1,268 +0,0 @@ -package subprocess - -import ( - "fmt" - "os" - "os/exec" - "sync" - "syscall" - "time" - "strings" - - "github.com/wso2/open-mcp-auth-proxy/internal/config" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" -) - -// Manager handles starting and graceful shutdown of subprocesses -type Manager struct { - process *os.Process - processGroup int - mutex sync.Mutex - cmd *exec.Cmd - shutdownDelay time.Duration -} - -// NewManager creates a new subprocess manager -func NewManager() *Manager { - return &Manager{ - shutdownDelay: 5 * time.Second, - } -} - -// EnsureDependenciesAvailable checks and installs required package executors -func EnsureDependenciesAvailable(command string) error { - // Always ensure npx is available regardless of the command - if _, err := exec.LookPath("npx"); err != nil { - // npx is not available, check if npm is installed - if _, err := exec.LookPath("npm"); err != nil { - return fmt.Errorf("npx not found and npm not available; please install Node.js from https://nodejs.org/") - } - - // Try to install npx using npm - logger.Info("npx not found, attempting to install...") - cmd := exec.Command("npm", "install", "-g", "npx") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install npx: %w", err) - } - - logger.Info("npx installed successfully") - } - - // Check if uv is needed based on the command - if strings.Contains(command, "uv ") { - if _, err := exec.LookPath("uv"); err != nil { - return fmt.Errorf("command requires uv but it's not installed; please install it following instructions at https://github.com/astral-sh/uv") - } - } - - return nil -} - -// SetShutdownDelay sets the maximum time to wait for graceful shutdown -func (m *Manager) SetShutdownDelay(duration time.Duration) { - m.shutdownDelay = duration -} - -// Start launches a subprocess based on the configuration -func (m *Manager) Start(cfg *config.Config) error { - m.mutex.Lock() - defer m.mutex.Unlock() - - // If a process is already running, return an error - if m.process != nil { - return os.ErrExist - } - - if !cfg.Stdio.Enabled || cfg.Stdio.UserCommand == "" { - return nil // Nothing to start - } - - // Get the full command string - execCommand := cfg.BuildExecCommand() - if execCommand == "" { - return nil // No command to execute - } - - logger.Info("Starting subprocess with command: %s", execCommand) - - // Use the shell to execute the command - cmd := exec.Command("sh", "-c", execCommand) - - // Set working directory if specified - if cfg.Stdio.WorkDir != "" { - cmd.Dir = cfg.Stdio.WorkDir - } - - // Set environment variables if specified - if len(cfg.Stdio.Env) > 0 { - cmd.Env = append(os.Environ(), cfg.Stdio.Env...) - } - - // Capture stdout/stderr - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - // Set the process group for proper termination - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - - // Start the process - if err := cmd.Start(); err != nil { - return err - } - - m.process = cmd.Process - m.cmd = cmd - logger.Info("Subprocess started with PID: %d", m.process.Pid) - - // Get and store the process group ID - pgid, err := syscall.Getpgid(m.process.Pid) - if err == nil { - m.processGroup = pgid - logger.Debug("Process group ID: %d", m.processGroup) - } else { - logger.Warn("Failed to get process group ID: %v", err) - m.processGroup = m.process.Pid - } - - // Handle process termination in background - go func() { - if err := cmd.Wait(); err != nil { - logger.Error("Subprocess exited with error: %v", err) - } else { - logger.Info("Subprocess exited successfully") - } - - // Clear the process reference when it exits - m.mutex.Lock() - m.process = nil - m.cmd = nil - m.mutex.Unlock() - }() - - return nil -} - -// IsRunning checks if the subprocess is running -func (m *Manager) IsRunning() bool { - m.mutex.Lock() - defer m.mutex.Unlock() - return m.process != nil -} - -// Shutdown gracefully terminates the subprocess -func (m *Manager) Shutdown() { - m.mutex.Lock() - processToTerminate := m.process // Local copy of the process reference - processGroupToTerminate := m.processGroup - m.mutex.Unlock() - - if processToTerminate == nil { - return // No process to terminate - } - - logger.Info("Terminating subprocess...") - terminateComplete := make(chan struct{}) - - go func() { - defer close(terminateComplete) - - // Try graceful termination first with SIGTERM - terminatedGracefully := false - - // Try to terminate the process group first - if processGroupToTerminate != 0 { - err := syscall.Kill(-processGroupToTerminate, syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process group: %v", err) - - // Fallback to terminating just the process - m.mutex.Lock() - if m.process != nil { - err = m.process.Signal(syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process: %v", err) - } - } - m.mutex.Unlock() - } - } else { - // Try to terminate just the process - m.mutex.Lock() - if m.process != nil { - err := m.process.Signal(syscall.SIGTERM) - if err != nil { - logger.Warn("Failed to send SIGTERM to process: %v", err) - } - } - m.mutex.Unlock() - } - - // Wait for the process to exit gracefully - for i := 0; i < 10; i++ { - time.Sleep(200 * time.Millisecond) - - m.mutex.Lock() - if m.process == nil { - terminatedGracefully = true - m.mutex.Unlock() - break - } - m.mutex.Unlock() - } - - if terminatedGracefully { - logger.Info("Subprocess terminated gracefully") - return - } - - // If the process didn't exit gracefully, force kill - logger.Warn("Subprocess didn't exit gracefully, forcing termination...") - - // Try to kill the process group first - if processGroupToTerminate != 0 { - if err := syscall.Kill(-processGroupToTerminate, syscall.SIGKILL); err != nil { - logger.Warn("Failed to send SIGKILL to process group: %v", err) - - // Fallback to killing just the process - m.mutex.Lock() - if m.process != nil { - if err := m.process.Kill(); err != nil { - logger.Error("Failed to kill process: %v", err) - } - } - m.mutex.Unlock() - } - } else { - // Try to kill just the process - m.mutex.Lock() - if m.process != nil { - if err := m.process.Kill(); err != nil { - logger.Error("Failed to kill process: %v", err) - } - } - m.mutex.Unlock() - } - - // Wait a bit more to confirm termination - time.Sleep(500 * time.Millisecond) - - m.mutex.Lock() - if m.process == nil { - logger.Info("Subprocess terminated by force") - } else { - logger.Warn("Failed to terminate subprocess") - } - m.mutex.Unlock() - }() - - // Wait for termination with timeout - select { - case <-terminateComplete: - // Termination completed - case <-time.After(m.shutdownDelay): - logger.Warn("Subprocess termination timed out") - } -} diff --git a/internal/util/jwks.go b/internal/util/jwks.go index f80d82e..4832bf8 100644 --- a/internal/util/jwks.go +++ b/internal/util/jwks.go @@ -4,12 +4,12 @@ import ( "crypto/rsa" "encoding/json" "errors" + "log" "math/big" "net/http" "strings" "github.com/golang-jwt/jwt/v4" - "github.com/wso2/open-mcp-auth-proxy/internal/logging" ) type JWKS struct { @@ -50,7 +50,7 @@ func FetchJWKS(jwksURL string) error { publicKeys[parsedKey.Kid] = pubKey } } - logger.Info("Loaded %d public keys.", len(publicKeys)) + log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys)) return nil } diff --git a/internal/util/jwks_test.go b/internal/util/jwks_test.go deleted file mode 100644 index 3b00c68..0000000 --- a/internal/util/jwks_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package util - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/golang-jwt/jwt/v4" -) - -func TestValidateJWT(t *testing.T) { - // Initialize the test JWKS data - initTestJWKS(t) - - // Test cases - tests := []struct { - name string - authHeader string - expectError bool - }{ - { - name: "Valid JWT token", - authHeader: "Bearer " + createValidJWT(t), - expectError: false, - }, - { - name: "No auth header", - authHeader: "", - expectError: true, - }, - { - name: "Invalid auth header format", - authHeader: "InvalidFormat", - expectError: true, - }, - { - name: "Invalid JWT token", - authHeader: "Bearer invalid.jwt.token", - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := ValidateJWT(tc.authHeader) - if tc.expectError && err == nil { - t.Errorf("Expected error but got none") - } - if !tc.expectError && err != nil { - t.Errorf("Expected no error but got: %v", err) - } - }) - } -} - -func TestFetchJWKS(t *testing.T) { - // Create a mock JWKS server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Generate a test RSA key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("Failed to generate RSA key: %v", err) - } - - // Create JWKS response - jwks := map[string]interface{}{ - "keys": []map[string]interface{}{ - { - "kty": "RSA", - "kid": "test-key-id", - "n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()), - "e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537 - }, - }, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(jwks) - })) - defer server.Close() - - // Test fetching JWKS - err := FetchJWKS(server.URL) - if err != nil { - t.Fatalf("FetchJWKS failed: %v", err) - } - - // Check that keys were stored - if len(publicKeys) == 0 { - t.Errorf("Expected publicKeys to be populated") - } -} - -// Helper function to initialize test JWKS data -func initTestJWKS(t *testing.T) { - // Create a test RSA key pair - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("Failed to generate RSA key: %v", err) - } - - // Initialize the publicKeys map - publicKeys = map[string]*rsa.PublicKey{ - "test-key-id": &privateKey.PublicKey, - } -} - -// Helper function to create a valid JWT token for testing -func createValidJWT(t *testing.T) string { - // Create a test RSA key pair - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("Failed to generate RSA key: %v", err) - } - - // Ensure the test key is in the publicKeys map - if publicKeys == nil { - publicKeys = map[string]*rsa.PublicKey{} - } - publicKeys["test-key-id"] = &privateKey.PublicKey - - // Create token - token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ - "sub": "1234567890", - "name": "Test User", - "iat": time.Now().Unix(), - "exp": time.Now().Add(time.Hour).Unix(), - }) - token.Header["kid"] = "test-key-id" - - // Sign the token - tokenString, err := token.SignedString(privateKey) - if err != nil { - t.Fatalf("Failed to sign token: %v", err) - } - - return tokenString -} diff --git a/pull_request_template.md b/pull_request_template.md index c401a06..9b32185 100644 --- a/pull_request_template.md +++ b/pull_request_template.md @@ -1,11 +1,52 @@ ## Purpose - +> Describe the problems, issues, or needs driving this feature/fix and include links to related issues in the following format: Resolves issue1, issue2, etc. -## Related Issues - +## Goals +> Describe the solutions that this feature/fix will introduce to resolve the problems described above + +## Approach +> Describe how you are implementing the solutions. Include an animated GIF or screenshot if the change affects the UI (email documentation@wso2.com to review all UI text). Include a link to a Markdown file or Google doc if the feature write-up is too long to paste here. + +## User stories +> Summary of user stories addressed by this change> + +## Release note +> Brief description of the new feature or bug fix as it will appear in the release notes + +## Documentation +> Link(s) to product documentation that addresses the changes of this PR. If no doc impact, enter “N/A” plus brief explanation of why there’s no doc impact + +## Training +> Link to the PR for changes to the training content in https://github.com/wso2/WSO2-Training, if applicable + +## Certification +> Type “Sent” when you have provided new/updated certification questions, plus four answers for each question (correct answer highlighted in bold), based on this change. Certification questions/answers should be sent to certification@wso2.com and NOT pasted in this PR. If there is no impact on certification exams, type “N/A” and explain why. + +## Marketing +> Link to drafts of marketing content that will describe and promote this feature, including product page changes, technical articles, blog posts, videos, etc., if applicable + +## Automation tests + - Unit tests + > Code coverage information + - Integration tests + > Details about the test cases and coverage + +## Security checks + - Followed secure coding standards in http://wso2.com/technical-reports/wso2-secure-engineering-guidelines? yes/no + - Ran FindSecurityBugs plugin and verified report? yes/no + - Confirmed that this PR doesn't commit any keys, passwords, tokens, usernames, or other secrets? yes/no + +## Samples +> Provide high-level details about the samples related to this feature ## Related PRs - +> List any other related PRs ## Migrations (if applicable) - +> Describe migration steps and platforms on which migration has been tested + +## Test environment +> List all JDK versions, operating systems, databases, and browser/versions on which this feature/fix was tested + +## Learning +> Describe the research phase and any blog posts, patterns, libraries, or add-ons you used to solve the problem. \ No newline at end of file diff --git a/resources/requirements.txt b/resources/requirements.txt deleted file mode 100644 index 102b728..0000000 --- a/resources/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -fastmcp==0.4.1 \ No newline at end of file