From a3b4906afd9e8eb1e396ffc10d31a365c03100b8 Mon Sep 17 00:00:00 2001 From: Glen Maddern Date: Thu, 10 Apr 2025 11:27:25 +1000 Subject: [PATCH] Added --header CLI args support This can include ${ENV_VAR} strings that are replaced with the values in process.env --- README.md | 23 +++++++++++++++++++++++ package.json | 2 +- src/client.ts | 8 ++++---- src/lib/utils.ts | 44 ++++++++++++++++++++++++++++++++++++++++---- src/proxy.ts | 8 ++++---- 5 files changed, 72 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index ac30c88..ded8937 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,29 @@ All the most popular MCP clients (Claude Desktop, Cursor & Windsurf) use the fol } ``` +### Custom Headers + +To bypass authentication, or to emit custom headers on all requests to your remote server, pass `--header` CLI arguments: + +```json +{ + "mcpServers": { + "remote-example": { + "command": "npx", + "args": [ + "mcp-remote", + "https://remote.mcp.server/sse", + "--header", + "Authorization: Bearer ${AUTH_TOKEN}" + ] + }, + "env": { + "AUTH_TOKEN": "..." + } + } +} +``` + ### Flags * If `npx` is producing errors, consider adding `-y` as the first argument to auto-accept the installation of the `mcp-remote` package. diff --git a/package.json b/package.json index d8ab40f..26c26bf 100644 --- a/package.json +++ b/package.json @@ -23,8 +23,8 @@ "mcp-remote-client": "dist/client.js" }, "scripts": { - "dev": "tsup --watch", "build": "tsup", + "build:watch": "tsup --watch", "check": "prettier --check . && tsc" }, "dependencies": { diff --git a/src/client.ts b/src/client.ts index 5b908c3..9a4b7c4 100644 --- a/src/client.ts +++ b/src/client.ts @@ -24,7 +24,7 @@ import { coordinateAuth } from './lib/coordination' /** * Main function to run the client */ -async function runClient(serverUrl: string, callbackPort: number, clean: boolean = false) { +async function runClient(serverUrl: string, callbackPort: number, headers: Record, clean: boolean = false) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -64,7 +64,7 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean // Create the transport factory const url = new URL(serverUrl) function initTransport() { - const transport = new SSEClientTransport(url, { authProvider }) + const transport = new SSEClientTransport(url, { authProvider, requestInit: { headers } }) // Set up message and error handlers transport.onmessage = (message) => { @@ -160,8 +160,8 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean // Parse command-line arguments and run the client parseCommandLineArgs(process.argv.slice(2), 3333, 'Usage: npx tsx client.ts [--clean] [callback-port]') - .then(({ serverUrl, callbackPort, clean }) => { - return runClient(serverUrl, callbackPort, clean) + .then(({ serverUrl, callbackPort, clean, headers }) => { + return runClient(serverUrl, callbackPort, headers, clean) }) .catch((error) => { console.error('Fatal error:', error) diff --git a/src/lib/utils.ts b/src/lib/utils.ts index a37fe61..30688af 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -68,6 +68,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo * Creates and connects to a remote SSE server with OAuth authentication * @param serverUrl The URL of the remote server * @param authProvider The OAuth client provider + * @param headers Additional headers to send with the request * @param waitForAuthCode Function to wait for the auth code * @param skipBrowserAuth Whether to skip browser auth and use shared auth * @returns The connected SSE client transport @@ -75,12 +76,13 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo export async function connectToRemoteServer( serverUrl: string, authProvider: OAuthClientProvider, + headers: Record, waitForAuthCode: () => Promise, skipBrowserAuth: boolean = false, ): Promise { log(`[${pid}] Connecting to remote server: ${serverUrl}`) const url = new URL(serverUrl) - const transport = new SSEClientTransport(url, { authProvider }) + const transport = new SSEClientTransport(url, { authProvider, requestInit: { headers } }) try { await transport.start() @@ -102,7 +104,7 @@ export async function connectToRemoteServer( await transport.finishAuth(code) // Create a new transport after auth - const newTransport = new SSEClientTransport(url, { authProvider }) + const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } }) await newTransport.start() log('Connected to remote server after authentication') return newTransport @@ -255,7 +257,7 @@ export async function findAvailablePort(preferredPort?: number): Promise * @param args Command line arguments * @param defaultPort Default port for the callback server if specified port is unavailable * @param usage Usage message to show on error - * @returns A promise that resolves to an object with parsed serverUrl, callbackPort, and clean flag + * @returns A promise that resolves to an object with parsed serverUrl, callbackPort, clean flag, and headers */ export async function parseCommandLineArgs(args: string[], defaultPort: number, usage: string) { // Check for --clean flag @@ -267,6 +269,21 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, args.splice(cleanIndex, 1) } + // Process headers + const headers: Record = {} + args.forEach((arg, i) => { + if (arg === '--header' && i < args.length - 1) { + const value = args[i + 1] + const match = value.match(/^([A-Za-z0-9_-]+):(.*)$/) + if (match) { + headers[match[1]] = match[2] + } else { + log(`Warning: ignoring invalid header argument: ${value}`) + } + args.splice(i, 2) + } + }) + const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined @@ -296,7 +313,26 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, log('Clean mode enabled: config files will be reset before reading') } - return { serverUrl, callbackPort, clean } + if (Object.keys(headers).length > 0) { + log(`Using custom headers: ${JSON.stringify(headers)}`) + } + // Replace environment variables in headers + // example `Authorization: Bearer ${TOKEN}` will read process.env.TOKEN + for (const [key, value] of Object.entries(headers)) { + headers[key] = value.replace(/\$\{([^}]+)}/g, (match, envVarName) => { + const envVarValue = process.env[envVarName] + + if (envVarValue !== undefined) { + log(`Replacing ${match} with environment value in header '${key}'`) + return envVarValue + } else { + log(`Warning: Environment variable '${envVarName}' not found for header '${key}'.`) + return '' + } + }) + } + + return { serverUrl, callbackPort, clean, headers } } /** diff --git a/src/proxy.ts b/src/proxy.ts index 4c8d75c..f9a415f 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -21,7 +21,7 @@ import { coordinateAuth } from './lib/coordination' /** * Main function to run the proxy */ -async function runProxy(serverUrl: string, callbackPort: number, clean: boolean = false) { +async function runProxy(serverUrl: string, callbackPort: number, headers: Record, clean: boolean = false) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -52,7 +52,7 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean try { // Connect to remote server with authentication - const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth) + const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth) // Set up bidirectional proxy between local and remote transports mcpProxy({ @@ -104,8 +104,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look // Parse command-line arguments and run the proxy parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts [--clean] [callback-port]') - .then(({ serverUrl, callbackPort, clean }) => { - return runProxy(serverUrl, callbackPort, clean) + .then(({ serverUrl, callbackPort, clean, headers }) => { + return runProxy(serverUrl, callbackPort, headers, clean) }) .catch((error) => { log('Fatal error:', error)