diff --git a/src/client.ts b/src/client.ts index d620884..ac58fb8 100644 --- a/src/client.ts +++ b/src/client.ts @@ -11,17 +11,28 @@ import { EventEmitter } from 'events' import { Client } from '@modelcontextprotocol/sdk/client/index.js' -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js' -import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' -import { parseCommandLineArgs, setupSignalHandlers, log, MCP_REMOTE_VERSION, getServerUrlHash } from './lib/utils' +import { + parseCommandLineArgs, + setupSignalHandlers, + log, + MCP_REMOTE_VERSION, + getServerUrlHash, + connectToRemoteServer, + TransportStrategy, +} from './lib/utils' import { coordinateAuth } from './lib/coordination' /** * Main function to run the client */ -async function runClient(serverUrl: string, callbackPort: number, headers: Record) { +async function runClient( + serverUrl: string, + callbackPort: number, + headers: Record, + transportStrategy: TransportStrategy = 'http-first', +) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -57,10 +68,9 @@ async function runClient(serverUrl: string, callbackPort: number, headers: Recor }, ) - // Create the transport factory - const url = new URL(serverUrl) - function initTransport() { - const transport = new SSEClientTransport(url, { authProvider, requestInit: { headers } }) + try { + // Connect to remote server with authentication + const transport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy) // Set up message and error handlers transport.onmessage = (message) => { @@ -75,89 +85,50 @@ async function runClient(serverUrl: string, callbackPort: number, headers: Recor log('Connection closed.') process.exit(0) } - return transport - } - const transport = initTransport() + // Set up cleanup handler + const cleanup = async () => { + log('\nClosing connection...') + await client.close() + server.close() + } + setupSignalHandlers(cleanup) - // Set up cleanup handler - const cleanup = async () => { - log('\nClosing connection...') - await client.close() - server.close() - } - setupSignalHandlers(cleanup) - - // Try to connect - try { - log('Connecting to server...') + // Connect the client + log('Connecting client...') await client.connect(transport) log('Connected successfully!') - } catch (error) { - if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { - log('Authentication required. Waiting for authorization...') - // Wait for the authorization code from the callback or another instance - const code = await waitForAuthCode() - - try { - log('Completing authorization...') - await transport.finishAuth(code) - - // Reconnect after authorization with a new transport - log('Connecting after authorization...') - await client.connect(initTransport()) - - log('Connected successfully!') - - // Request tools list after auth - log('Requesting tools list...') - const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) - log('Tools:', JSON.stringify(tools, null, 2)) - - // Request resources list after auth - log('Requesting resource list...') - const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) - log('Resources:', JSON.stringify(resources, null, 2)) - - log('Listening for messages. Press Ctrl+C to exit.') - } catch (authError) { - log('Authorization error:', authError) - server.close() - process.exit(1) - } - } else { - log('Connection error:', error) - server.close() - process.exit(1) + try { + // Request tools list + log('Requesting tools list...') + const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) + log('Tools:', JSON.stringify(tools, null, 2)) + } catch (e) { + log('Error requesting tools list:', e) } - } - try { - // Request tools list - log('Requesting tools list...') - const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) - log('Tools:', JSON.stringify(tools, null, 2)) - } catch (e) { - log('Error requesting tools list:', e) - } + try { + // Request resources list + log('Requesting resource list...') + const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) + log('Resources:', JSON.stringify(resources, null, 2)) + } catch (e) { + log('Error requesting resources list:', e) + } - try { - // Request resources list - log('Requesting resource list...') - const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) - log('Resources:', JSON.stringify(resources, null, 2)) - } catch (e) { - log('Error requesting resources list:', e) + log('Listening for messages. Press Ctrl+C to exit.') + } catch (error) { + log('Fatal error:', error) + server.close() + process.exit(1) } - - log('Listening for messages. Press Ctrl+C to exit.') } // Parse command-line arguments and run the client parseCommandLineArgs(process.argv.slice(2), 3333, 'Usage: npx tsx client.ts [callback-port]') - .then(({ serverUrl, callbackPort, headers }) => { - return runClient(serverUrl, callbackPort, headers) + .then(({ serverUrl, callbackPort, headers, transportStrategy }) => { + return runClient(serverUrl, callbackPort, headers, transportStrategy) }) .catch((error) => { console.error('Fatal error:', error)