Added the ability to pass custom headers, and encrypt specific header values

This commit is contained in:
Charles Robertson 2025-04-11 21:13:24 +01:00
parent 84b87375fb
commit baddd03e0e
3 changed files with 3634 additions and 8 deletions

3495
package-lock.json generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -6,6 +6,9 @@ import express from 'express'
import net from 'net' import net from 'net'
import crypto from 'crypto' import crypto from 'crypto'
const iv = crypto.randomBytes(16);
const algorithm = 'aes-256-cbc';
// Package version from package.json // Package version from package.json
export const MCP_REMOTE_VERSION = require('../../package.json').version export const MCP_REMOTE_VERSION = require('../../package.json').version
@ -64,6 +67,105 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
} }
} }
/**
* Encyrypt data
* @param data The data to be encryppted
* @param secretKey The secret key that is used, along with an IV, to encrypt data
* @returns An encrypted string
*/
export function encrypt(data: string, secretKey: string) {
const key = crypto
.createHash("sha512")
.update(secretKey)
.digest("hex")
.substring(0, 32);
const cipher = crypto.createCipheriv(algorithm, Buffer.from(key), iv);
let encrypted = cipher.update(data, "utf-8", "hex");
encrypted += cipher.final("hex");
// Package the IV and encrypted data together so it can be stored in a single
// column in the database.
return iv.toString("hex") + encrypted;
}
/**
* Decrypt data
* @param data The data to be decrypted
* @param secretKey The secret key that is used, along with an IV, to decrypt data
* @returns A decrypted string
*/
export function decrypt(data: string, secretKey: string) {
const key = crypto
.createHash("sha512")
.update(secretKey)
.digest("hex")
.substring(0, 32);
// Unpackage the combined iv + encrypted message. Since we are using a fixed
// size IV, we can hard code the slice length.
const inputIV = data.slice(0, 32);
const encrypted = data.slice(32);
const decipher = crypto.createDecipheriv(
algorithm,
Buffer.from(key),
Buffer.from(inputIV, "hex"),
);
let decrypted = decipher.update(encrypted, "hex", "utf-8");
decrypted += decipher.final("utf-8");
return decrypted;
}
/**
* Creates a headers object
* @param headers A string that is passed in the arguments, from the AI client config file; the argument is preceded by another argument called --headers
* @param keysForEncryption The header object keys, whose values require encryption
* @param secretKey The secret key that is used, along with an IV, to encrypt/decrypt data
* @returns A headers object
*/
export function parseHeaders(
headers: string,
keysForEncryption: string,
secretKey: string
): any {
const headersArr = headers.split(',');
let credentials: any = {};
if (headersArr.length > 0) {
headersArr.map((val, idx) => {
const keyValArr = val.split(':');
let k = '';
let v = '';
keyValArr.map((val, idx) => {
if (idx === 0) {
k = val.toLowerCase().trim();
} else {
v = val.trim();
}
});
if (k !== '') {
credentials[k] = v;
}
});
const keysForEncryptionArr = keysForEncryption.split(',');
for (const property in credentials) {
const found = keysForEncryptionArr.find(
(element) => element === property
);
if (found && secretKey in credentials) {
const encrypted = encrypt(credentials[property], credentials[secretKey]);
const decrypted = decrypt(encrypted, credentials[secretKey]);
credentials[property] = encrypted;
// now delete the secret so that it is not sent to the remote MCP server via SSE transport
delete credentials[secretKey];
}
}
}
return credentials;
}
/** /**
* Creates and connects to a remote SSE server with OAuth authentication * Creates and connects to a remote SSE server with OAuth authentication
* @param serverUrl The URL of the remote server * @param serverUrl The URL of the remote server
@ -76,11 +178,22 @@ export async function connectToRemoteServer(
serverUrl: string, serverUrl: string,
authProvider: OAuthClientProvider, authProvider: OAuthClientProvider,
waitForAuthCode: () => Promise<string>, waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false, skipBrowserAuth: boolean = false,
headers: string = '',
): Promise<SSEClientTransport> { ): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`) log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl) const url = new URL(serverUrl)
const transport = new SSEClientTransport(url, { authProvider })
const credentials = parseHeaders(headers, 'password', 'secret');
const requestInit = {
body: headers,
headers: credentials
}
const transport = new SSEClientTransport(url, {
authProvider,
requestInit
})
try { try {
await transport.start() await transport.start()
@ -102,7 +215,10 @@ export async function connectToRemoteServer(
await transport.finishAuth(code) await transport.finishAuth(code)
// Create a new transport after auth // Create a new transport after auth
const newTransport = new SSEClientTransport(url, { authProvider }) const newTransport = new SSEClientTransport(url, {
authProvider,
requestInit
})
await newTransport.start() await newTransport.start()
log('Connected to remote server after authentication') log('Connected to remote server after authentication')
return newTransport return newTransport
@ -296,7 +412,22 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
log('Clean mode enabled: config files will be reset before reading') log('Clean mode enabled: config files will be reset before reading')
} }
return { serverUrl, callbackPort, clean } // Check for --header flag
const headerIndex = args.indexOf('--header')
const header = headerIndex !== -1
const headerValueIndex = headerIndex + 1
let headers = ''
// Remove the flag from args if it exists
if (header) {
if (headerValueIndex) {
headers = args[headerValueIndex]
args.splice(headerIndex, 2)
}
}
return { serverUrl, callbackPort, clean, headers }
} }
/** /**

View file

@ -21,7 +21,7 @@ import { coordinateAuth } from './lib/coordination'
/** /**
* Main function to run the proxy * Main function to run the proxy
*/ */
async function runProxy(serverUrl: string, callbackPort: number, clean: boolean = false) { async function runProxy(serverUrl: string, callbackPort: number, clean: boolean = false, headers: string = '') {
// Set up event emitter for auth flow // Set up event emitter for auth flow
const events = new EventEmitter() const events = new EventEmitter()
@ -52,7 +52,7 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
try { try {
// Connect to remote server with authentication // Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth) const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth, headers)
// Set up bidirectional proxy between local and remote transports // Set up bidirectional proxy between local and remote transports
mcpProxy({ mcpProxy({
@ -104,8 +104,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
// Parse command-line arguments and run the proxy // Parse command-line arguments and run the proxy
parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts [--clean] <https://server-url> [callback-port]') parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts [--clean] <https://server-url> [callback-port]')
.then(({ serverUrl, callbackPort, clean }) => { .then(({ serverUrl, callbackPort, clean, headers }) => {
return runProxy(serverUrl, callbackPort, clean) return runProxy(serverUrl, callbackPort, clean, headers)
}) })
.catch((error) => { .catch((error) => {
log('Fatal error:', error) log('Fatal error:', error)