feat: relay logs from connectToRemoteServer to proxy local transport

This commit is contained in:
justin 2025-05-07 21:08:33 -04:00
parent e642a56563
commit 86aee1c203
3 changed files with 76 additions and 51 deletions

View file

@ -1,5 +1,4 @@
import { EventEmitter } from 'events' import { EventEmitter } from 'events'
import { LoggingLevel } from '@modelcontextprotocol/sdk/types.js'
/** /**
* Options for creating an OAuth client provider * Options for creating an OAuth client provider
@ -36,10 +35,6 @@ export interface OAuthCallbackServerOptions {
} }
/* /*
* Message sending helper type * Connection status types used for logging (via local transport, in proxy mode)
*/ */
export interface MCPLogMessageParams { export type ConnStatus = 'connected' | 'connecting' | 'reconnecting' | 'authenticating' | 'error' | 'error_final'
level: LoggingLevel
logger: string
data: Record<string, any>
}

View file

@ -3,7 +3,9 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import type { MCPLogMessageParams } from './types' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
import { LoggingLevel } from '@modelcontextprotocol/sdk/types.js'
import { ConnStatus } from './types'
// Connection constants // Connection constants
export const REASON_AUTH_NEEDED = 'authentication-needed' export const REASON_AUTH_NEEDED = 'authentication-needed'
@ -74,6 +76,29 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
} }
} }
/**
* Extended StdioServerTransport class
*/
export class StdioServerTransportExt extends StdioServerTransport {
/**
* Send a log message through the transport
* @param level The log level ('error' | 'debug' | 'info' | 'notice' | 'warning' | 'critical' | 'alert' | 'emergency')
* @param data The data object to send (should be JSON serializable)
* @param logger Optional logger name, defaults to 'mcp-remote'
*/
sendMessage(level: LoggingLevel, data: any, logger: string = 'mcp-remote') {
return this.send({
jsonrpc: '2.0',
method: 'notifications/message',
params: {
level,
logger,
data,
},
})
}
}
/** /**
* Type for the auth initialization function * Type for the auth initialization function
*/ */
@ -100,9 +125,21 @@ export async function connectToRemoteServer(
headers: Record<string, string>, headers: Record<string, string>,
authInitializer: AuthInitializer, authInitializer: AuthInitializer,
transportStrategy: TransportStrategy = 'http-first', transportStrategy: TransportStrategy = 'http-first',
localTransport: StdioServerTransportExt | null = null,
recursionReasons: Set<string> = new Set(), recursionReasons: Set<string> = new Set(),
): Promise<Transport> { ): Promise<Transport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`) const _log = (level: LoggingLevel, message: any, status: ConnStatus) => {
// If localTransport is provided (proxy mode), send the message to it
if (localTransport) {
localTransport.sendMessage(level, {
status,
message,
})
}
log(message)
}
_log('info', `[${pid}] Connecting to remote server: ${serverUrl}`, 'connecting')
const url = new URL(serverUrl) const url = new URL(serverUrl)
// Create transport with eventSourceInit to pass Authorization header if present // Create transport with eventSourceInit to pass Authorization header if present
@ -122,7 +159,7 @@ export async function connectToRemoteServer(
}, },
} }
log(`Using transport strategy: ${transportStrategy}`) _log('info', `Using transport strategy: ${transportStrategy}`, 'connecting')
// Determine if we should attempt to fallback on error // Determine if we should attempt to fallback on error
// Choose transport based on user strategy and recursion history // Choose transport based on user strategy and recursion history
const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first' const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'
@ -155,7 +192,7 @@ export async function connectToRemoteServer(
await testClient.connect(testTransport) await testClient.connect(testTransport)
} }
} }
log(`Connected to remote server using ${transport.constructor.name}`) _log('info', `Connected to remote server using ${transport.constructor.name}`, 'connected')
return transport return transport
} catch (error) { } catch (error) {
@ -168,16 +205,16 @@ export async function connectToRemoteServer(
error.message.includes('404') || error.message.includes('404') ||
error.message.includes('Not Found')) error.message.includes('Not Found'))
) { ) {
log(`Received error: ${error.message}`) _log('error', `Received error: ${error.message}`, 'error')
// If we've already tried falling back once, throw an error // If we've already tried falling back once, throw an error
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) { if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
const errorMessage = `Already attempted transport fallback. Giving up.` const errorMessage = `Already attempted transport fallback. Giving up.`
log(errorMessage) _log('error', errorMessage, 'error_final')
throw new Error(errorMessage) throw new Error(errorMessage)
} }
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`) _log('info', `Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`, 'reconnecting')
// Add to recursion reasons set // Add to recursion reasons set
recursionReasons.add(REASON_TRANSPORT_FALLBACK) recursionReasons.add(REASON_TRANSPORT_FALLBACK)
@ -190,45 +227,55 @@ export async function connectToRemoteServer(
headers, headers,
authInitializer, authInitializer,
sseTransport ? 'http-only' : 'sse-only', sseTransport ? 'http-only' : 'sse-only',
localTransport,
recursionReasons, recursionReasons,
) )
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
log('Authentication required. Initializing auth...') _log('info', 'Authentication required. Initializing auth...', 'authenticating')
// Initialize authentication on-demand // Initialize authentication on-demand
const { waitForAuthCode, skipBrowserAuth } = await authInitializer() const { waitForAuthCode, skipBrowserAuth } = await authInitializer()
if (skipBrowserAuth) { if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth') _log('info', 'Authentication required but skipping browser auth - using shared auth', 'authenticating')
} else { } else {
log('Authentication required. Waiting for authorization...') _log('info', 'Authentication required. Waiting for authorization...', 'authenticating')
} }
// Wait for the authorization code from the callback // Wait for the authorization code from the callback
const code = await waitForAuthCode() const code = await waitForAuthCode()
try { try {
log('Completing authorization...') _log('info', 'Completing authorization...', 'authenticating')
await transport.finishAuth(code) await transport.finishAuth(code)
if (recursionReasons.has(REASON_AUTH_NEEDED)) { if (recursionReasons.has(REASON_AUTH_NEEDED)) {
const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.` const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.`
log(errorMessage) _log('error', errorMessage, 'error_final')
throw new Error(errorMessage) throw new Error(errorMessage)
} }
// Track this reason for recursion // Track this reason for recursion
recursionReasons.add(REASON_AUTH_NEEDED) recursionReasons.add(REASON_AUTH_NEEDED)
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) _log('info', `Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`, 'reconnecting')
// Recursively call connectToRemoteServer with the updated recursion tracking // Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, recursionReasons) return connectToRemoteServer(
client,
serverUrl,
authProvider,
headers,
authInitializer,
transportStrategy,
localTransport,
recursionReasons,
)
} catch (authError) { } catch (authError) {
log('Authorization error:', authError) _log('error', `Authorization error: ${authError}`, 'error_final')
throw authError throw authError
} }
} else { } else {
log('Connection error:', error) _log('error', `Connection error: ${error}`, 'error_final')
throw error throw error
} }
} }
@ -488,19 +535,3 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
export function getServerUrlHash(serverUrl: string): string { export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex') return crypto.createHash('md5').update(serverUrl).digest('hex')
} }
/**
* Helper function to send log messages through stdio MCP transport, for proxy logging purposes
* @param transport The transport to send the message through
* @param params Message parameters including level, logger name, and data
*
* In accordance with the official MCP specification
* @see https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#log-message-notifications
*/
export function sendLog(transport: Transport, params: MCPLogMessageParams) {
return transport.send({
jsonrpc: '2.0',
method: 'notifications/message',
params: { ...params },
})
}

View file

@ -10,7 +10,7 @@
*/ */
import { EventEmitter } from 'events' import { EventEmitter } from 'events'
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' import { StdioServerTransportExt } from './lib/utils'
import { import {
connectToRemoteServer, connectToRemoteServer,
log, log,
@ -20,7 +20,6 @@ import {
getServerUrlHash, getServerUrlHash,
MCP_REMOTE_VERSION, MCP_REMOTE_VERSION,
TransportStrategy, TransportStrategy,
sendLog,
} from './lib/utils' } from './lib/utils'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { createLazyAuthCoordinator } from './lib/coordination' import { createLazyAuthCoordinator } from './lib/coordination'
@ -51,7 +50,7 @@ async function runProxy(
}) })
// Create the STDIO transport for local connections // Create the STDIO transport for local connections
const localTransport = new StdioServerTransport() const localTransport = new StdioServerTransportExt()
// Keep track of the server instance for cleanup // Keep track of the server instance for cleanup
let server: any = null let server: any = null
@ -77,17 +76,17 @@ async function runProxy(
} }
} }
sendLog(localTransport, {
level: 'debug',
logger: 'mcp-remote',
data: {
message: 'Connecting to remote server...',
},
})
try { try {
// Connect to remote server with lazy authentication // Connect to remote server with lazy authentication
const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy) const remoteTransport = await connectToRemoteServer(
null,
serverUrl,
authProvider,
headers,
authInitializer,
transportStrategy,
localTransport,
)
// Set up bidirectional proxy between local and remote transports // Set up bidirectional proxy between local and remote transports
mcpProxy({ mcpProxy({