Added --header CLI args support

This can include ${ENV_VAR} strings that are replaced with the values in process.env
This commit is contained in:
Glen Maddern 2025-04-10 11:27:25 +10:00
parent 84b87375fb
commit a3b4906afd
5 changed files with 72 additions and 13 deletions

View file

@ -24,7 +24,7 @@ import { coordinateAuth } from './lib/coordination'
/**
* Main function to run the client
*/
async function runClient(serverUrl: string, callbackPort: number, clean: boolean = false) {
async function runClient(serverUrl: string, callbackPort: number, headers: Record<string, string>, clean: boolean = false) {
// Set up event emitter for auth flow
const events = new EventEmitter()
@ -64,7 +64,7 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Create the transport factory
const url = new URL(serverUrl)
function initTransport() {
const transport = new SSEClientTransport(url, { authProvider })
const transport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
// Set up message and error handlers
transport.onmessage = (message) => {
@ -160,8 +160,8 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Parse command-line arguments and run the client
parseCommandLineArgs(process.argv.slice(2), 3333, 'Usage: npx tsx client.ts [--clean] <https://server-url> [callback-port]')
.then(({ serverUrl, callbackPort, clean }) => {
return runClient(serverUrl, callbackPort, clean)
.then(({ serverUrl, callbackPort, clean, headers }) => {
return runClient(serverUrl, callbackPort, headers, clean)
})
.catch((error) => {
console.error('Fatal error:', error)

View file

@ -68,6 +68,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
* Creates and connects to a remote SSE server with OAuth authentication
* @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider
* @param headers Additional headers to send with the request
* @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport
@ -75,12 +76,13 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
export async function connectToRemoteServer(
serverUrl: string,
authProvider: OAuthClientProvider,
headers: Record<string, string>,
waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl)
const transport = new SSEClientTransport(url, { authProvider })
const transport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
try {
await transport.start()
@ -102,7 +104,7 @@ export async function connectToRemoteServer(
await transport.finishAuth(code)
// Create a new transport after auth
const newTransport = new SSEClientTransport(url, { authProvider })
const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
await newTransport.start()
log('Connected to remote server after authentication')
return newTransport
@ -255,7 +257,7 @@ export async function findAvailablePort(preferredPort?: number): Promise<number>
* @param args Command line arguments
* @param defaultPort Default port for the callback server if specified port is unavailable
* @param usage Usage message to show on error
* @returns A promise that resolves to an object with parsed serverUrl, callbackPort, and clean flag
* @returns A promise that resolves to an object with parsed serverUrl, callbackPort, clean flag, and headers
*/
export async function parseCommandLineArgs(args: string[], defaultPort: number, usage: string) {
// Check for --clean flag
@ -267,6 +269,21 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
args.splice(cleanIndex, 1)
}
// Process headers
const headers: Record<string, string> = {}
args.forEach((arg, i) => {
if (arg === '--header' && i < args.length - 1) {
const value = args[i + 1]
const match = value.match(/^([A-Za-z0-9_-]+):(.*)$/)
if (match) {
headers[match[1]] = match[2]
} else {
log(`Warning: ignoring invalid header argument: ${value}`)
}
args.splice(i, 2)
}
})
const serverUrl = args[0]
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
@ -296,7 +313,26 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
log('Clean mode enabled: config files will be reset before reading')
}
return { serverUrl, callbackPort, clean }
if (Object.keys(headers).length > 0) {
log(`Using custom headers: ${JSON.stringify(headers)}`)
}
// Replace environment variables in headers
// example `Authorization: Bearer ${TOKEN}` will read process.env.TOKEN
for (const [key, value] of Object.entries(headers)) {
headers[key] = value.replace(/\$\{([^}]+)}/g, (match, envVarName) => {
const envVarValue = process.env[envVarName]
if (envVarValue !== undefined) {
log(`Replacing ${match} with environment value in header '${key}'`)
return envVarValue
} else {
log(`Warning: Environment variable '${envVarName}' not found for header '${key}'.`)
return ''
}
})
}
return { serverUrl, callbackPort, clean, headers }
}
/**

View file

@ -21,7 +21,7 @@ import { coordinateAuth } from './lib/coordination'
/**
* Main function to run the proxy
*/
async function runProxy(serverUrl: string, callbackPort: number, clean: boolean = false) {
async function runProxy(serverUrl: string, callbackPort: number, headers: Record<string, string>, clean: boolean = false) {
// Set up event emitter for auth flow
const events = new EventEmitter()
@ -52,7 +52,7 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
try {
// Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth)
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
// Set up bidirectional proxy between local and remote transports
mcpProxy({
@ -104,8 +104,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
// Parse command-line arguments and run the proxy
parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts [--clean] <https://server-url> [callback-port]')
.then(({ serverUrl, callbackPort, clean }) => {
return runProxy(serverUrl, callbackPort, clean)
.then(({ serverUrl, callbackPort, clean, headers }) => {
return runProxy(serverUrl, callbackPort, headers, clean)
})
.catch((error) => {
log('Fatal error:', error)