Added the ability to pass custom headers, and encrypt specific header values

This commit is contained in:
Charles Robertson 2025-04-11 21:13:24 +01:00
parent 84b87375fb
commit baddd03e0e
3 changed files with 3634 additions and 8 deletions

3495
package-lock.json generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -6,6 +6,9 @@ import express from 'express'
import net from 'net'
import crypto from 'crypto'
const iv = crypto.randomBytes(16);
const algorithm = 'aes-256-cbc';
// Package version from package.json
export const MCP_REMOTE_VERSION = require('../../package.json').version
@ -64,6 +67,105 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
}
}
/**
* Encyrypt data
* @param data The data to be encryppted
* @param secretKey The secret key that is used, along with an IV, to encrypt data
* @returns An encrypted string
*/
export function encrypt(data: string, secretKey: string) {
const key = crypto
.createHash("sha512")
.update(secretKey)
.digest("hex")
.substring(0, 32);
const cipher = crypto.createCipheriv(algorithm, Buffer.from(key), iv);
let encrypted = cipher.update(data, "utf-8", "hex");
encrypted += cipher.final("hex");
// Package the IV and encrypted data together so it can be stored in a single
// column in the database.
return iv.toString("hex") + encrypted;
}
/**
* Decrypt data
* @param data The data to be decrypted
* @param secretKey The secret key that is used, along with an IV, to decrypt data
* @returns A decrypted string
*/
export function decrypt(data: string, secretKey: string) {
const key = crypto
.createHash("sha512")
.update(secretKey)
.digest("hex")
.substring(0, 32);
// Unpackage the combined iv + encrypted message. Since we are using a fixed
// size IV, we can hard code the slice length.
const inputIV = data.slice(0, 32);
const encrypted = data.slice(32);
const decipher = crypto.createDecipheriv(
algorithm,
Buffer.from(key),
Buffer.from(inputIV, "hex"),
);
let decrypted = decipher.update(encrypted, "hex", "utf-8");
decrypted += decipher.final("utf-8");
return decrypted;
}
/**
* Creates a headers object
* @param headers A string that is passed in the arguments, from the AI client config file; the argument is preceded by another argument called --headers
* @param keysForEncryption The header object keys, whose values require encryption
* @param secretKey The secret key that is used, along with an IV, to encrypt/decrypt data
* @returns A headers object
*/
export function parseHeaders(
headers: string,
keysForEncryption: string,
secretKey: string
): any {
const headersArr = headers.split(',');
let credentials: any = {};
if (headersArr.length > 0) {
headersArr.map((val, idx) => {
const keyValArr = val.split(':');
let k = '';
let v = '';
keyValArr.map((val, idx) => {
if (idx === 0) {
k = val.toLowerCase().trim();
} else {
v = val.trim();
}
});
if (k !== '') {
credentials[k] = v;
}
});
const keysForEncryptionArr = keysForEncryption.split(',');
for (const property in credentials) {
const found = keysForEncryptionArr.find(
(element) => element === property
);
if (found && secretKey in credentials) {
const encrypted = encrypt(credentials[property], credentials[secretKey]);
const decrypted = decrypt(encrypted, credentials[secretKey]);
credentials[property] = encrypted;
// now delete the secret so that it is not sent to the remote MCP server via SSE transport
delete credentials[secretKey];
}
}
}
return credentials;
}
/**
* Creates and connects to a remote SSE server with OAuth authentication
* @param serverUrl The URL of the remote server
@ -77,10 +179,21 @@ export async function connectToRemoteServer(
authProvider: OAuthClientProvider,
waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
headers: string = '',
): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl)
const transport = new SSEClientTransport(url, { authProvider })
const credentials = parseHeaders(headers, 'password', 'secret');
const requestInit = {
body: headers,
headers: credentials
}
const transport = new SSEClientTransport(url, {
authProvider,
requestInit
})
try {
await transport.start()
@ -102,7 +215,10 @@ export async function connectToRemoteServer(
await transport.finishAuth(code)
// Create a new transport after auth
const newTransport = new SSEClientTransport(url, { authProvider })
const newTransport = new SSEClientTransport(url, {
authProvider,
requestInit
})
await newTransport.start()
log('Connected to remote server after authentication')
return newTransport
@ -296,7 +412,22 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
log('Clean mode enabled: config files will be reset before reading')
}
return { serverUrl, callbackPort, clean }
// Check for --header flag
const headerIndex = args.indexOf('--header')
const header = headerIndex !== -1
const headerValueIndex = headerIndex + 1
let headers = ''
// Remove the flag from args if it exists
if (header) {
if (headerValueIndex) {
headers = args[headerValueIndex]
args.splice(headerIndex, 2)
}
}
return { serverUrl, callbackPort, clean, headers }
}
/**

View file

@ -21,7 +21,7 @@ import { coordinateAuth } from './lib/coordination'
/**
* Main function to run the proxy
*/
async function runProxy(serverUrl: string, callbackPort: number, clean: boolean = false) {
async function runProxy(serverUrl: string, callbackPort: number, clean: boolean = false, headers: string = '') {
// Set up event emitter for auth flow
const events = new EventEmitter()
@ -52,7 +52,7 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
try {
// Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth)
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth, headers)
// Set up bidirectional proxy between local and remote transports
mcpProxy({
@ -104,8 +104,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
// Parse command-line arguments and run the proxy
parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts [--clean] <https://server-url> [callback-port]')
.then(({ serverUrl, callbackPort, clean }) => {
return runProxy(serverUrl, callbackPort, clean)
.then(({ serverUrl, callbackPort, clean, headers }) => {
return runProxy(serverUrl, callbackPort, clean, headers)
})
.catch((error) => {
log('Fatal error:', error)