diff --git a/src/lib/types.ts b/src/lib/types.ts index 723b93f..93e5d54 100644 --- a/src/lib/types.ts +++ b/src/lib/types.ts @@ -33,3 +33,8 @@ export interface OAuthCallbackServerOptions { /** Event emitter to signal when auth code is received */ events: EventEmitter } + +/* + * Connection status types used for logging (via local transport, in proxy mode) + */ +export type ConnStatus = 'connected' | 'connecting' | 'reconnecting' | 'authenticating' | 'error' | 'error_final' diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 572550c..ee92f39 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -3,6 +3,9 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { LoggingLevel } from '@modelcontextprotocol/sdk/types.js' +import { ConnStatus } from './types' // Connection constants export const REASON_AUTH_NEEDED = 'authentication-needed' @@ -80,6 +83,29 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo } } +/** + * Extended StdioServerTransport class + */ +export class StdioServerTransportExt extends StdioServerTransport { + /** + * Send a log message through the transport + * @param level The log level ('error' | 'debug' | 'info' | 'notice' | 'warning' | 'critical' | 'alert' | 'emergency') + * @param data The data object to send (should be JSON serializable) + * @param logger Optional logger name, defaults to 'mcp-remote' + */ + sendMessage(level: LoggingLevel, data: any, logger: string = 'mcp-remote') { + return this.send({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { + level, + logger, + data, + }, + }) + } +} + /** * Type for the auth initialization function */ @@ -106,9 +132,21 @@ export async function connectToRemoteServer( headers: Record, authInitializer: AuthInitializer, transportStrategy: TransportStrategy = 'http-first', + localTransport: StdioServerTransportExt | null = null, recursionReasons: Set = new Set(), ): Promise { - log(`[${pid}] Connecting to remote server: ${serverUrl}`) + const _log = (level: LoggingLevel, message: any, status: ConnStatus) => { + // If localTransport is provided (proxy mode), send the message to it + if (localTransport) { + localTransport.sendMessage(level, { + status, + message, + }) + } + log(message) + } + + _log('info', `[${pid}] Connecting to remote server: ${serverUrl}`, 'connecting') const url = new URL(serverUrl) // Create transport with eventSourceInit to pass Authorization header if present @@ -128,7 +166,7 @@ export async function connectToRemoteServer( }, } - log(`Using transport strategy: ${transportStrategy}`) + _log('info', `Using transport strategy: ${transportStrategy}`, 'connecting') // Determine if we should attempt to fallback on error // Choose transport based on user strategy and recursion history const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first' @@ -161,7 +199,7 @@ export async function connectToRemoteServer( await testClient.connect(testTransport) } } - log(`Connected to remote server using ${transport.constructor.name}`) + _log('info', `Connected to remote server using ${transport.constructor.name}`, 'connected') return transport } catch (error) { @@ -174,16 +212,16 @@ export async function connectToRemoteServer( error.message.includes('404') || error.message.includes('Not Found')) ) { - log(`Received error: ${error.message}`) + _log('error', `Received error: ${error.message}`, 'error') // If we've already tried falling back once, throw an error if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) { const errorMessage = `Already attempted transport fallback. Giving up.` - log(errorMessage) + _log('error', errorMessage, 'error_final') throw new Error(errorMessage) } - log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`) + _log('info', `Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`, 'reconnecting') // Add to recursion reasons set recursionReasons.add(REASON_TRANSPORT_FALLBACK) @@ -196,45 +234,55 @@ export async function connectToRemoteServer( headers, authInitializer, sseTransport ? 'http-only' : 'sse-only', + localTransport, recursionReasons, ) } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { - log('Authentication required. Initializing auth...') + _log('info', 'Authentication required. Initializing auth...', 'authenticating') // Initialize authentication on-demand const { waitForAuthCode, skipBrowserAuth } = await authInitializer() if (skipBrowserAuth) { - log('Authentication required but skipping browser auth - using shared auth') + _log('info', 'Authentication required but skipping browser auth - using shared auth', 'authenticating') } else { - log('Authentication required. Waiting for authorization...') + _log('info', 'Authentication required. Waiting for authorization...', 'authenticating') } // Wait for the authorization code from the callback const code = await waitForAuthCode() try { - log('Completing authorization...') + _log('info', 'Completing authorization...', 'authenticating') await transport.finishAuth(code) if (recursionReasons.has(REASON_AUTH_NEEDED)) { const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.` - log(errorMessage) + _log('error', errorMessage, 'error_final') throw new Error(errorMessage) } // Track this reason for recursion recursionReasons.add(REASON_AUTH_NEEDED) - log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) + _log('info', `Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`, 'reconnecting') // Recursively call connectToRemoteServer with the updated recursion tracking - return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, recursionReasons) + return connectToRemoteServer( + client, + serverUrl, + authProvider, + headers, + authInitializer, + transportStrategy, + localTransport, + recursionReasons, + ) } catch (authError) { - log('Authorization error:', authError) + _log('error', `Authorization error: ${authError}`, 'error_final') throw authError } } else { - log('Connection error:', error) + _log('error', `Connection error: ${error}`, 'error_final') throw error } } diff --git a/src/proxy.ts b/src/proxy.ts index 7263a95..1bb0ebc 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -10,7 +10,7 @@ */ import { EventEmitter } from 'events' -import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { StdioServerTransportExt } from './lib/utils' import { connectToRemoteServer, log, @@ -50,7 +50,7 @@ async function runProxy( }) // Create the STDIO transport for local connections - const localTransport = new StdioServerTransport() + const localTransport = new StdioServerTransportExt() // Keep track of the server instance for cleanup let server: any = null @@ -78,7 +78,15 @@ async function runProxy( try { // Connect to remote server with lazy authentication - const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy) + const remoteTransport = await connectToRemoteServer( + null, + serverUrl, + authProvider, + headers, + authInitializer, + transportStrategy, + localTransport, + ) // Set up bidirectional proxy between local and remote transports mcpProxy({