feat: mcp proxy keep alive (ping) mechanism
This commit is contained in:
parent
7eecc9ca3f
commit
cc84c2ce10
3 changed files with 85 additions and 5 deletions
|
@ -33,3 +33,11 @@ export interface OAuthCallbackServerOptions {
|
||||||
/** Event emitter to signal when auth code is received */
|
/** Event emitter to signal when auth code is received */
|
||||||
events: EventEmitter
|
events: EventEmitter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Configuration for the ping mechanism
|
||||||
|
*/
|
||||||
|
export interface PingConfig {
|
||||||
|
enabled: boolean
|
||||||
|
interval: number
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||||
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||||
import { OAuthClientInformationFull, OAuthClientInformationFullSchema } from '@modelcontextprotocol/sdk/shared/auth.js'
|
import { OAuthClientInformationFull, OAuthClientInformationFullSchema } from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||||
import { OAuthCallbackServerOptions } from './types'
|
import { OAuthCallbackServerOptions, PingConfig } from './types'
|
||||||
import { getConfigFilePath, readJsonFile } from './mcp-auth-config'
|
import { getConfigFilePath, readJsonFile } from './mcp-auth-config'
|
||||||
import express from 'express'
|
import express from 'express'
|
||||||
import net from 'net'
|
import net from 'net'
|
||||||
|
@ -14,6 +14,7 @@ import fs from 'fs/promises'
|
||||||
// Connection constants
|
// Connection constants
|
||||||
export const REASON_AUTH_NEEDED = 'authentication-needed'
|
export const REASON_AUTH_NEEDED = 'authentication-needed'
|
||||||
export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport'
|
export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport'
|
||||||
|
export const PING_INTERVAL_DEFAULT = 30000
|
||||||
|
|
||||||
// Transport strategy types
|
// Transport strategy types
|
||||||
export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first'
|
export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first'
|
||||||
|
@ -91,6 +92,44 @@ export type AuthInitializer = () => Promise<{
|
||||||
skipBrowserAuth: boolean
|
skipBrowserAuth: boolean
|
||||||
}>
|
}>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets up periodic ping to keep the connection alive
|
||||||
|
* @param transport The transport to ping
|
||||||
|
* @param config Ping configuration
|
||||||
|
* @returns A cleanup function to stop pinging
|
||||||
|
*/
|
||||||
|
export function setupPing(transport: Transport, config: PingConfig): () => void {
|
||||||
|
if (!config.enabled) {
|
||||||
|
return () => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let pingTimeout: NodeJS.Timeout | null = null
|
||||||
|
let lastPingId = 0
|
||||||
|
|
||||||
|
const interval = config.interval * 1000 // convert ms to s
|
||||||
|
const pingInterval = setInterval(async () => {
|
||||||
|
const pingId = ++lastPingId
|
||||||
|
try {
|
||||||
|
// Docs: https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/ping
|
||||||
|
await transport.send({
|
||||||
|
jsonrpc: '2.0',
|
||||||
|
id: `ping-${pingId}`,
|
||||||
|
method: 'ping',
|
||||||
|
})
|
||||||
|
log(`Ping ${pingId} successful`)
|
||||||
|
} catch (error) {
|
||||||
|
log(`Ping ${pingId} failed:`, error)
|
||||||
|
}
|
||||||
|
}, interval)
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
if (pingTimeout) {
|
||||||
|
clearTimeout(pingTimeout)
|
||||||
|
}
|
||||||
|
clearInterval(pingInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and connects to a remote server with OAuth authentication
|
* Creates and connects to a remote server with OAuth authentication
|
||||||
* @param client The client to connect with
|
* @param client The client to connect with
|
||||||
|
@ -432,6 +471,21 @@ export async function parseCommandLineArgs(args: string[], usage: string) {
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse ping configuration
|
||||||
|
const keepAlive = args.includes('--keep-alive')
|
||||||
|
const pingIntervalIndex = args.indexOf('--ping-interval')
|
||||||
|
let pingInterval = PING_INTERVAL_DEFAULT
|
||||||
|
if (pingIntervalIndex !== -1 && pingIntervalIndex < args.length - 1) {
|
||||||
|
const intervalStr = args[pingIntervalIndex + 1]
|
||||||
|
const interval = parseInt(intervalStr)
|
||||||
|
if (!isNaN(interval) && interval > 0) {
|
||||||
|
pingInterval = interval
|
||||||
|
log(`Using ping interval: ${pingInterval} seconds`)
|
||||||
|
} else {
|
||||||
|
log(`Warning: Invalid ping interval "${args[pingIntervalIndex + 1]}". Using default: ${PING_INTERVAL_DEFAULT} seconds`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const serverUrl = args[0]
|
const serverUrl = args[0]
|
||||||
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
||||||
const allowHttp = args.includes('--allow-http')
|
const allowHttp = args.includes('--allow-http')
|
||||||
|
@ -505,7 +559,16 @@ export async function parseCommandLineArgs(args: string[], usage: string) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return { serverUrl, callbackPort, headers, transportStrategy }
|
return {
|
||||||
|
serverUrl,
|
||||||
|
callbackPort,
|
||||||
|
headers,
|
||||||
|
transportStrategy,
|
||||||
|
pingConfig: {
|
||||||
|
enabled: keepAlive,
|
||||||
|
interval: pingInterval,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
15
src/proxy.ts
15
src/proxy.ts
|
@ -18,11 +18,12 @@ import {
|
||||||
parseCommandLineArgs,
|
parseCommandLineArgs,
|
||||||
setupSignalHandlers,
|
setupSignalHandlers,
|
||||||
getServerUrlHash,
|
getServerUrlHash,
|
||||||
MCP_REMOTE_VERSION,
|
|
||||||
TransportStrategy,
|
TransportStrategy,
|
||||||
|
setupPing,
|
||||||
} from './lib/utils'
|
} from './lib/utils'
|
||||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||||
import { createLazyAuthCoordinator } from './lib/coordination'
|
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||||
|
import { PingConfig } from './lib/types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Main function to run the proxy
|
* Main function to run the proxy
|
||||||
|
@ -32,6 +33,7 @@ async function runProxy(
|
||||||
callbackPort: number,
|
callbackPort: number,
|
||||||
headers: Record<string, string>,
|
headers: Record<string, string>,
|
||||||
transportStrategy: TransportStrategy = 'http-first',
|
transportStrategy: TransportStrategy = 'http-first',
|
||||||
|
pingConfig: PingConfig,
|
||||||
) {
|
) {
|
||||||
// Set up event emitter for auth flow
|
// Set up event emitter for auth flow
|
||||||
const events = new EventEmitter()
|
const events = new EventEmitter()
|
||||||
|
@ -80,6 +82,9 @@ async function runProxy(
|
||||||
// Connect to remote server with lazy authentication
|
// Connect to remote server with lazy authentication
|
||||||
const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy)
|
const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy)
|
||||||
|
|
||||||
|
// Set up ping mechanism for remote transport
|
||||||
|
const stopPing = setupPing(remoteTransport, pingConfig)
|
||||||
|
|
||||||
// Set up bidirectional proxy between local and remote transports
|
// Set up bidirectional proxy between local and remote transports
|
||||||
mcpProxy({
|
mcpProxy({
|
||||||
transportToClient: localTransport,
|
transportToClient: localTransport,
|
||||||
|
@ -89,11 +94,15 @@ async function runProxy(
|
||||||
// Start the local STDIO server
|
// Start the local STDIO server
|
||||||
await localTransport.start()
|
await localTransport.start()
|
||||||
log('Local STDIO server running')
|
log('Local STDIO server running')
|
||||||
|
if (pingConfig.enabled) {
|
||||||
|
log(`Automatic ping enabled with ${pingConfig.interval} second interval`)
|
||||||
|
}
|
||||||
log(`Proxy established successfully between local STDIO and remote ${remoteTransport.constructor.name}`)
|
log(`Proxy established successfully between local STDIO and remote ${remoteTransport.constructor.name}`)
|
||||||
log('Press Ctrl+C to exit')
|
log('Press Ctrl+C to exit')
|
||||||
|
|
||||||
// Setup cleanup handler
|
// Setup cleanup handler
|
||||||
const cleanup = async () => {
|
const cleanup = async () => {
|
||||||
|
stopPing()
|
||||||
await remoteTransport.close()
|
await remoteTransport.close()
|
||||||
await localTransport.close()
|
await localTransport.close()
|
||||||
// Only close the server if it was initialized
|
// Only close the server if it was initialized
|
||||||
|
@ -136,8 +145,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
|
||||||
|
|
||||||
// Parse command-line arguments and run the proxy
|
// Parse command-line arguments and run the proxy
|
||||||
parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts <https://server-url> [callback-port]')
|
parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts <https://server-url> [callback-port]')
|
||||||
.then(({ serverUrl, callbackPort, headers, transportStrategy }) => {
|
.then(({ serverUrl, callbackPort, headers, transportStrategy, pingConfig }) => {
|
||||||
return runProxy(serverUrl, callbackPort, headers, transportStrategy)
|
return runProxy(serverUrl, callbackPort, headers, transportStrategy, pingConfig)
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
log('Fatal error:', error)
|
log('Fatal error:', error)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue