diff --git a/src/lib/types.ts b/src/lib/types.ts index 723b93f..dbf824e 100644 --- a/src/lib/types.ts +++ b/src/lib/types.ts @@ -33,3 +33,11 @@ export interface OAuthCallbackServerOptions { /** Event emitter to signal when auth code is received */ events: EventEmitter } + +/* + * Configuration for the ping mechanism + */ +export interface PingConfig { + enabled: boolean + interval: number +} diff --git a/src/lib/utils.ts b/src/lib/utils.ts index a0a60dc..2feb5cf 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -4,7 +4,7 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' import { OAuthClientInformationFull, OAuthClientInformationFullSchema } from '@modelcontextprotocol/sdk/shared/auth.js' -import { OAuthCallbackServerOptions } from './types' +import { OAuthCallbackServerOptions, PingConfig } from './types' import { getConfigFilePath, readJsonFile } from './mcp-auth-config' import express from 'express' import net from 'net' @@ -14,6 +14,7 @@ import fs from 'fs/promises' // Connection constants export const REASON_AUTH_NEEDED = 'authentication-needed' export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport' +export const PING_INTERVAL_DEFAULT = 30000 // Transport strategy types export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first' @@ -91,6 +92,44 @@ export type AuthInitializer = () => Promise<{ skipBrowserAuth: boolean }> +/** + * Sets up periodic ping to keep the connection alive + * @param transport The transport to ping + * @param config Ping configuration + * @returns A cleanup function to stop pinging + */ +export function setupPing(transport: Transport, config: PingConfig): () => void { + if (!config.enabled) { + return () => {} + } + + let pingTimeout: NodeJS.Timeout | null = null + let lastPingId = 0 + + const interval = config.interval * 1000 // convert ms to s + const pingInterval = setInterval(async () => { + const pingId = ++lastPingId + try { + // Docs: https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/ping + await transport.send({ + jsonrpc: '2.0', + id: `ping-${pingId}`, + method: 'ping', + }) + log(`Ping ${pingId} successful`) + } catch (error) { + log(`Ping ${pingId} failed:`, error) + } + }, interval) + + return () => { + if (pingTimeout) { + clearTimeout(pingTimeout) + } + clearInterval(pingInterval) + } +} + /** * Creates and connects to a remote server with OAuth authentication * @param client The client to connect with @@ -432,6 +471,21 @@ export async function parseCommandLineArgs(args: string[], usage: string) { i++ } + // Parse ping configuration + const keepAlive = args.includes('--keep-alive') + const pingIntervalIndex = args.indexOf('--ping-interval') + let pingInterval = PING_INTERVAL_DEFAULT + if (pingIntervalIndex !== -1 && pingIntervalIndex < args.length - 1) { + const intervalStr = args[pingIntervalIndex + 1] + const interval = parseInt(intervalStr) + if (!isNaN(interval) && interval > 0) { + pingInterval = interval + log(`Using ping interval: ${pingInterval} seconds`) + } else { + log(`Warning: Invalid ping interval "${args[pingIntervalIndex + 1]}". Using default: ${PING_INTERVAL_DEFAULT} seconds`) + } + } + const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') @@ -505,7 +559,16 @@ export async function parseCommandLineArgs(args: string[], usage: string) { }) } - return { serverUrl, callbackPort, headers, transportStrategy } + return { + serverUrl, + callbackPort, + headers, + transportStrategy, + pingConfig: { + enabled: keepAlive, + interval: pingInterval, + }, + } } /** diff --git a/src/proxy.ts b/src/proxy.ts index 535bfe2..199da09 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -18,11 +18,12 @@ import { parseCommandLineArgs, setupSignalHandlers, getServerUrlHash, - MCP_REMOTE_VERSION, TransportStrategy, + setupPing, } from './lib/utils' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { createLazyAuthCoordinator } from './lib/coordination' +import { PingConfig } from './lib/types' /** * Main function to run the proxy @@ -32,6 +33,7 @@ async function runProxy( callbackPort: number, headers: Record, transportStrategy: TransportStrategy = 'http-first', + pingConfig: PingConfig, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -80,6 +82,9 @@ async function runProxy( // Connect to remote server with lazy authentication const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy) + // Set up ping mechanism for remote transport + const stopPing = setupPing(remoteTransport, pingConfig) + // Set up bidirectional proxy between local and remote transports mcpProxy({ transportToClient: localTransport, @@ -89,11 +94,15 @@ async function runProxy( // Start the local STDIO server await localTransport.start() log('Local STDIO server running') + if (pingConfig.enabled) { + log(`Automatic ping enabled with ${pingConfig.interval} second interval`) + } log(`Proxy established successfully between local STDIO and remote ${remoteTransport.constructor.name}`) log('Press Ctrl+C to exit') // Setup cleanup handler const cleanup = async () => { + stopPing() await remoteTransport.close() await localTransport.close() // Only close the server if it was initialized @@ -136,8 +145,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look // Parse command-line arguments and run the proxy parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts [callback-port]') - .then(({ serverUrl, callbackPort, headers, transportStrategy }) => { - return runProxy(serverUrl, callbackPort, headers, transportStrategy) + .then(({ serverUrl, callbackPort, headers, transportStrategy, pingConfig }) => { + return runProxy(serverUrl, callbackPort, headers, transportStrategy, pingConfig) }) .catch((error) => { log('Fatal error:', error)