most of the implementation looking ok but sharing the token the wrong way

This commit is contained in:
Glen Maddern 2025-03-31 19:38:52 +11:00
parent 743b6b207f
commit 412b5d9486
6 changed files with 364 additions and 62 deletions

View file

@ -18,7 +18,8 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js' import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js'
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers, MCP_REMOTE_VERSION } from './lib/utils' import { parseCommandLineArgs, setupSignalHandlers, log, MCP_REMOTE_VERSION, getServerUrlHash } from './lib/utils'
import { coordinateAuth } from './lib/coordination'
/** /**
* Main function to run the client * Main function to run the client
@ -27,6 +28,12 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Set up event emitter for auth flow // Set up event emitter for auth flow
const events = new EventEmitter() const events = new EventEmitter()
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)
// Coordinate authentication with other instances
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
// Create the OAuth client provider // Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({ const authProvider = new NodeOAuthClientProvider({
serverUrl, serverUrl,
@ -35,6 +42,11 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
clean, clean,
}) })
// If we got auth from another instance, pre-populate with the received code
if (skipBrowserAuth) {
log('Using auth code from another instance')
}
// Create the client // Create the client
const client = new Client( const client = new Client(
{ {
@ -53,15 +65,15 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Set up message and error handlers // Set up message and error handlers
transport.onmessage = (message) => { transport.onmessage = (message) => {
console.log('Received message:', JSON.stringify(message, null, 2)) log('Received message:', JSON.stringify(message, null, 2))
} }
transport.onerror = (error) => { transport.onerror = (error) => {
console.error('Transport error:', error) log('Transport error:', error)
} }
transport.onclose = () => { transport.onclose = () => {
console.log('Connection closed.') log('Connection closed.')
process.exit(0) process.exit(0)
} }
return transport return transport
@ -69,16 +81,9 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
const transport = initTransport() const transport = initTransport()
// Set up an HTTP server to handle OAuth callback
const { server, waitForAuthCode } = setupOAuthCallbackServer({
port: callbackPort,
path: '/oauth/callback',
events,
})
// Set up cleanup handler // Set up cleanup handler
const cleanup = async () => { const cleanup = async () => {
console.log('\nClosing connection...') log('\nClosing connection...')
await client.close() await client.close()
server.close() server.close()
} }
@ -86,44 +91,44 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Try to connect // Try to connect
try { try {
console.log('Connecting to server...') log('Connecting to server...')
await client.connect(transport) await client.connect(transport)
console.log('Connected successfully!') log('Connected successfully!')
} catch (error) { } catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
console.log('Authentication required. Waiting for authorization...') log('Authentication required. Waiting for authorization...')
// Wait for the authorization code from the callback // Wait for the authorization code from the callback or another instance
const code = await waitForAuthCode() const code = await waitForAuthCode()
try { try {
console.log('Completing authorization...') log('Completing authorization...')
await transport.finishAuth(code) await transport.finishAuth(code)
// Reconnect after authorization with a new transport // Reconnect after authorization with a new transport
console.log('Connecting after authorization...') log('Connecting after authorization...')
await client.connect(initTransport()) await client.connect(initTransport())
console.log('Connected successfully!') log('Connected successfully!')
// Request tools list after auth // Request tools list after auth
console.log('Requesting tools list...') log('Requesting tools list...')
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
console.log('Tools:', JSON.stringify(tools, null, 2)) log('Tools:', JSON.stringify(tools, null, 2))
// Request resources list after auth // Request resources list after auth
console.log('Requesting resource list...') log('Requesting resource list...')
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
console.log('Resources:', JSON.stringify(resources, null, 2)) log('Resources:', JSON.stringify(resources, null, 2))
console.log('Listening for messages. Press Ctrl+C to exit.') log('Listening for messages. Press Ctrl+C to exit.')
} catch (authError) { } catch (authError) {
console.error('Authorization error:', authError) log('Authorization error:', authError)
server.close() server.close()
process.exit(1) process.exit(1)
} }
} else { } else {
console.error('Connection error:', error) log('Connection error:', error)
server.close() server.close()
process.exit(1) process.exit(1)
} }
@ -131,23 +136,23 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
try { try {
// Request tools list // Request tools list
console.log('Requesting tools list...') log('Requesting tools list...')
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
console.log('Tools:', JSON.stringify(tools, null, 2)) log('Tools:', JSON.stringify(tools, null, 2))
} catch (e) { } catch (e) {
console.log('Error requesting tools list:', e) log('Error requesting tools list:', e)
} }
try { try {
// Request resources list // Request resources list
console.log('Requesting resource list...') log('Requesting resource list...')
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
console.log('Resources:', JSON.stringify(resources, null, 2)) log('Resources:', JSON.stringify(resources, null, 2))
} catch (e) { } catch (e) {
console.log('Error requesting resources list:', e) log('Error requesting resources list:', e)
} }
console.log('Listening for messages. Press Ctrl+C to exit.') log('Listening for messages. Press Ctrl+C to exit.')
} }
// Parse command-line arguments and run the client // Parse command-line arguments and run the client

181
src/lib/coordination.ts Normal file
View file

@ -0,0 +1,181 @@
import { checkLockfile, createLockfile, deleteLockfile, getConfigFilePath, LockfileData } from './mcp-auth-config'
import { EventEmitter } from 'events'
import { Server } from 'http'
import express from 'express'
import { AddressInfo } from 'net'
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
/**
* Checks if a process with the given PID is running
* @param pid The process ID to check
* @returns True if the process is running, false otherwise
*/
export async function isPidRunning(pid: number): Promise<boolean> {
try {
process.kill(pid, 0) // Doesn't kill the process, just checks if it exists
return true
} catch {
return false
}
}
/**
* Checks if a lockfile is valid (process running and endpoint accessible)
* @param lockData The lockfile data
* @returns True if the lockfile is valid, false otherwise
*/
export async function isLockValid(lockData: LockfileData): Promise<boolean> {
// Check if the lockfile is too old (over 30 minutes)
const MAX_LOCK_AGE = 30 * 60 * 1000 // 30 minutes
if (Date.now() - lockData.timestamp > MAX_LOCK_AGE) {
log('Lockfile is too old')
return false
}
// Check if the process is still running
if (!(await isPidRunning(lockData.pid))) {
log('Process from lockfile is not running')
return false
}
// Check if the endpoint is accessible
try {
const controller = new AbortController()
const timeout = setTimeout(() => controller.abort(), 1000)
const response = await fetch(`http://127.0.0.1:${lockData.port}/wait-for-auth?poll=false`, {
signal: controller.signal,
})
clearTimeout(timeout)
return response.status === 200 || response.status === 202
} catch (error) {
log(`Error connecting to auth server: ${(error as Error).message}`)
return false
}
}
/**
* Waits for authentication from another server instance
* @param port The port to connect to
* @returns The auth code if successful, false otherwise
*/
export async function waitForAuthentication(port: number): Promise<string | false> {
log(`Waiting for authentication from the server on port ${port}...`)
try {
while (true) {
const url = `http://127.0.0.1:${port}/wait-for-auth`
log(`Querying: ${url}`)
const response = await fetch(url)
if (response.status === 200) {
const code = await response.text()
log(`Received code: ${code}`)
return code // Return the auth code
} else if (response.status === 202) {
// do nothing, loop
} else {
log(`Unexpected response status: ${response.status}`)
return false
}
}
} catch (error) {
log(`Error waiting for authentication: ${(error as Error).message}`)
return false
}
}
/**
* Coordinates authentication between multiple instances of the client/proxy
* @param serverUrlHash The hash of the server URL
* @param callbackPort The port to use for the callback server
* @param events The event emitter to use for signaling
* @returns An object with the server, waitForAuthCode function, and a flag indicating if browser auth can be skipped
*/
export async function coordinateAuth(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean; authCode?: string }> {
// Check for a lockfile
const lockData = await checkLockfile(serverUrlHash)
// If there's a valid lockfile, try to use the existing auth process
if (lockData && (await isLockValid(lockData))) {
log(`Another instance is handling authentication on port ${lockData.port}`)
try {
// Try to wait for the authentication to complete
const code = await waitForAuthentication(lockData.port)
if (code) {
log('Authentication completed by another instance')
// Setup a dummy server and return a pre-resolved promise for the auth code
const dummyServer = express().listen(0) // Listen on any available port
const dummyWaitForAuthCode = () => Promise.resolve(code)
return {
server: dummyServer,
waitForAuthCode: dummyWaitForAuthCode,
skipBrowserAuth: true,
authCode: code,
}
} else {
log('Taking over authentication process...')
}
} catch (error) {
log(`Error waiting for authentication: ${error}`)
}
// If we get here, the other process didn't complete auth successfully
await deleteLockfile(serverUrlHash)
} else if (lockData) {
// Invalid lockfile, delete its
log('Found invalid lockfile, deleting it')
await deleteLockfile(serverUrlHash)
}
// Create our own lockfile
const { server, waitForAuthCode, authCompletedPromise } = setupOAuthCallbackServerWithLongPoll({
port: callbackPort,
path: '/oauth/callback',
events,
})
// Get the actual port the server is running on
const address = server.address() as AddressInfo
const actualPort = address.port
log(`Creating lockfile for server ${serverUrlHash} with process ${process.pid} on port ${actualPort}`)
await createLockfile(serverUrlHash, process.pid, actualPort)
// Make sure lockfile is deleted on process exit
const cleanupHandler = async () => {
try {
log(`Cleaning up lockfile for server ${serverUrlHash}`)
await deleteLockfile(serverUrlHash)
} catch (error) {
log(`Error cleaning up lockfile: ${error}`)
}
}
process.once('exit', () => {
try {
// Synchronous version for 'exit' event since we can't use async here
const configPath = getConfigFilePath(serverUrlHash, 'lock.json')
require('fs').unlinkSync(configPath)
} catch {}
})
// Also handle SIGINT separately
process.once('SIGINT', async () => {
await cleanupHandler()
})
return {
server,
waitForAuthCode,
skipBrowserAuth: false,
}
}

View file

@ -1,4 +1,3 @@
import crypto from 'crypto'
import path from 'path' import path from 'path'
import os from 'os' import os from 'os'
import fs from 'fs/promises' import fs from 'fs/promises'
@ -27,7 +26,61 @@ import { log, MCP_REMOTE_VERSION } from './utils'
/** /**
* Known configuration file names that might need to be cleaned * Known configuration file names that might need to be cleaned
*/ */
export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt'] export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt', 'lock.json']
/**
* Lockfile data structure
*/
export interface LockfileData {
pid: number
port: number
timestamp: number
}
/**
* Creates a lockfile for the given server
* @param serverUrlHash The hash of the server URL
* @param pid The process ID
* @param port The port the server is running on
*/
export async function createLockfile(serverUrlHash: string, pid: number, port: number): Promise<void> {
const lockData: LockfileData = {
pid,
port,
timestamp: Date.now(),
}
await writeJsonFile(serverUrlHash, 'lock.json', lockData)
}
/**
* Checks if a lockfile exists for the given server
* @param serverUrlHash The hash of the server URL
* @returns The lockfile data or null if it doesn't exist
*/
export async function checkLockfile(serverUrlHash: string): Promise<LockfileData | null> {
try {
const lockfile = await readJsonFile<LockfileData>(serverUrlHash, 'lock.json', {
async parseAsync(data: any) {
if (typeof data !== 'object' || data === null) return null
if (typeof data.pid !== 'number' || typeof data.port !== 'number' || typeof data.timestamp !== 'number') {
return null
}
return data as LockfileData
},
})
return lockfile || null
} catch {
return null
}
}
/**
* Deletes the lockfile for the given server
* @param serverUrlHash The hash of the server URL
*/
export async function deleteLockfile(serverUrlHash: string): Promise<void> {
await deleteConfigFile(serverUrlHash, 'lock.json')
}
/** /**
* Deletes all known configuration files for a specific server * Deletes all known configuration files for a specific server
@ -63,15 +116,6 @@ export async function ensureConfigDir(): Promise<void> {
} }
} }
/**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}
/** /**
* Gets the file path for a config file * Gets the file path for a config file
* @param serverUrlHash The hash of the server URL * @param serverUrlHash The hash of the server URL

View file

@ -8,8 +8,8 @@ import {
OAuthTokensSchema, OAuthTokensSchema,
} from '@modelcontextprotocol/sdk/shared/auth.js' } from '@modelcontextprotocol/sdk/shared/auth.js'
import type { OAuthProviderOptions } from './types' import type { OAuthProviderOptions } from './types'
import { getServerUrlHash, readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config' import { readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config'
import { log } from './utils' import { getServerUrlHash, log } from './utils'
/** /**
* Implements the OAuthClientProvider interface for Node.js environments. * Implements the OAuthClientProvider interface for Node.js environments.

View file

@ -4,6 +4,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { OAuthCallbackServerOptions } from './types' import { OAuthCallbackServerOptions } from './types'
import express from 'express' import express from 'express'
import net from 'net' import net from 'net'
import crypto from 'crypto'
// Package version from package.json
export const MCP_REMOTE_VERSION = require('../../package.json').version
const pid = process.pid const pid = process.pid
export function log(str: string, ...rest: unknown[]) { export function log(str: string, ...rest: unknown[]) {
@ -65,12 +69,14 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
* @param serverUrl The URL of the remote server * @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider * @param authProvider The OAuth client provider
* @param waitForAuthCode Function to wait for the auth code * @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport * @returns The connected SSE client transport
*/ */
export async function connectToRemoteServer( export async function connectToRemoteServer(
serverUrl: string, serverUrl: string,
authProvider: OAuthClientProvider, authProvider: OAuthClientProvider,
waitForAuthCode: () => Promise<string>, waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
): Promise<SSEClientTransport> { ): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`) log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl) const url = new URL(serverUrl)
@ -82,7 +88,11 @@ export async function connectToRemoteServer(
return transport return transport
} catch (error) { } catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
log('Authentication required. Waiting for authorization...') if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth')
} else {
log('Authentication required. Waiting for authorization...')
}
// Wait for the authorization code from the callback // Wait for the authorization code from the callback
const code = await waitForAuthCode() const code = await waitForAuthCode()
@ -112,10 +122,56 @@ export async function connectToRemoteServer(
* @param options The server options * @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function * @returns An object with the server, authCode, and waitForAuthCode function
*/ */
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) { export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
let authCode: string | null = null let authCode: string | null = null
const app = express() const app = express()
// Create a promise to track when auth is completed
let authCompletedResolve: (code: string) => void
const authCompletedPromise = new Promise<string>((resolve) => {
authCompletedResolve = resolve
})
// Long-polling endpoint
app.get('/wait-for-auth', (req, res) => {
if (authCode) {
// Auth already completed
log('Auth already completed, returning immediately')
res.status(200).send(authCode)
return
}
if (req.query.poll === 'false') {
log('Client requested no long poll, responding with 202')
res.status(202).send('Authentication in progress')
return
}
// Long poll - wait for up to 30 seconds
const longPollTimeout = setTimeout(() => {
log('Long poll timeout reached, responding with 202')
res.status(202).send('Authentication in progress')
}, 30000)
// If auth completes while we're waiting, send the response immediately
authCompletedPromise
.then((code) => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth completed during long poll, responding with 200')
res.status(200).send(code)
}
})
.catch(() => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth failed during long poll, responding with 500')
res.status(500).send('Authentication failed')
}
})
})
// OAuth callback endpoint
app.get(options.path, (req, res) => { app.get(options.path, (req, res) => {
const code = req.query.code as string | undefined const code = req.query.code as string | undefined
if (!code) { if (!code) {
@ -124,6 +180,9 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
} }
authCode = code authCode = code
log('Auth code received, resolving promise')
authCompletedResolve(code)
res.send('Authorization successful! You may close this window and return to the CLI.') res.send('Authorization successful! You may close this window and return to the CLI.')
// Notify main flow that auth code is available // Notify main flow that auth code is available
@ -134,10 +193,6 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
log(`OAuth callback server running at http://127.0.0.1:${options.port}`) log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
}) })
/**
* Waits for the OAuth authorization code
* @returns A promise that resolves with the authorization code
*/
const waitForAuthCode = (): Promise<string> => { const waitForAuthCode = (): Promise<string> => {
return new Promise((resolve) => { return new Promise((resolve) => {
if (authCode) { if (authCode) {
@ -151,6 +206,16 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
}) })
} }
return { server, authCode, waitForAuthCode, authCompletedPromise }
}
/**
* Sets up an Express server to handle OAuth callbacks
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
return { server, authCode, waitForAuthCode } return { server, authCode, waitForAuthCode }
} }
@ -248,4 +313,11 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
process.stdin.resume() process.stdin.resume()
} }
export const MCP_REMOTE_VERSION = require('../../package.json').version /**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}

View file

@ -14,8 +14,9 @@
import { EventEmitter } from 'events' import { EventEmitter } from 'events'
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers } from './lib/utils' import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { coordinateAuth } from './lib/coordination'
/** /**
* Main function to run the proxy * Main function to run the proxy
@ -24,6 +25,12 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
// Set up event emitter for auth flow // Set up event emitter for auth flow
const events = new EventEmitter() const events = new EventEmitter()
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)
// Coordinate authentication with other instances
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
// Create the OAuth client provider // Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({ const authProvider = new NodeOAuthClientProvider({
serverUrl, serverUrl,
@ -35,16 +42,9 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
// Create the STDIO transport for local connections // Create the STDIO transport for local connections
const localTransport = new StdioServerTransport() const localTransport = new StdioServerTransport()
// Set up an HTTP server to handle OAuth callback
const { server, waitForAuthCode } = setupOAuthCallbackServer({
port: callbackPort,
path: '/oauth/callback',
events,
})
try { try {
// Connect to remote server with authentication // Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode) const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth)
// Set up bidirectional proxy between local and remote transports // Set up bidirectional proxy between local and remote transports
mcpProxy({ mcpProxy({