Merge pull request #13 from geelen/no-multiball
Tried to stop multiple windows opening at once
This commit is contained in:
commit
c6f98ff4b7
7 changed files with 397 additions and 65 deletions
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "mcp-remote",
|
"name": "mcp-remote",
|
||||||
"version": "0.0.13",
|
"version": "0.0.15",
|
||||||
"description": "Remote proxy for Model Context Protocol, allowing local-only clients to connect to remote servers using oAuth",
|
"description": "Remote proxy for Model Context Protocol, allowing local-only clients to connect to remote servers using oAuth",
|
||||||
"keywords": [
|
"keywords": [
|
||||||
"mcp",
|
"mcp",
|
||||||
|
@ -10,7 +10,7 @@
|
||||||
"oauth"
|
"oauth"
|
||||||
],
|
],
|
||||||
"author": "Glen Maddern <glen@cloudflare.com>",
|
"author": "Glen Maddern <glen@cloudflare.com>",
|
||||||
"repository": "https://github.com/geelen/remote-mcp",
|
"repository": "https://github.com/geelen/mcp-remote",
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"files": [
|
"files": [
|
||||||
"dist",
|
"dist",
|
||||||
|
|
|
@ -18,7 +18,8 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||||
import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js'
|
import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||||
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
|
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
|
||||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||||
import { parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers, MCP_REMOTE_VERSION } from './lib/utils'
|
import { parseCommandLineArgs, setupSignalHandlers, log, MCP_REMOTE_VERSION, getServerUrlHash } from './lib/utils'
|
||||||
|
import { coordinateAuth } from './lib/coordination'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Main function to run the client
|
* Main function to run the client
|
||||||
|
@ -27,6 +28,12 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
// Set up event emitter for auth flow
|
// Set up event emitter for auth flow
|
||||||
const events = new EventEmitter()
|
const events = new EventEmitter()
|
||||||
|
|
||||||
|
// Get the server URL hash for lockfile operations
|
||||||
|
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||||
|
|
||||||
|
// Coordinate authentication with other instances
|
||||||
|
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||||
|
|
||||||
// Create the OAuth client provider
|
// Create the OAuth client provider
|
||||||
const authProvider = new NodeOAuthClientProvider({
|
const authProvider = new NodeOAuthClientProvider({
|
||||||
serverUrl,
|
serverUrl,
|
||||||
|
@ -35,6 +42,14 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
clean,
|
clean,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||||
|
if (skipBrowserAuth) {
|
||||||
|
log('Authentication was completed by another instance - will use tokens from disk...')
|
||||||
|
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||||
|
// so we're slightly too early
|
||||||
|
await new Promise((res) => setTimeout(res, 1_000))
|
||||||
|
}
|
||||||
|
|
||||||
// Create the client
|
// Create the client
|
||||||
const client = new Client(
|
const client = new Client(
|
||||||
{
|
{
|
||||||
|
@ -53,15 +68,15 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
|
|
||||||
// Set up message and error handlers
|
// Set up message and error handlers
|
||||||
transport.onmessage = (message) => {
|
transport.onmessage = (message) => {
|
||||||
console.log('Received message:', JSON.stringify(message, null, 2))
|
log('Received message:', JSON.stringify(message, null, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
transport.onerror = (error) => {
|
transport.onerror = (error) => {
|
||||||
console.error('Transport error:', error)
|
log('Transport error:', error)
|
||||||
}
|
}
|
||||||
|
|
||||||
transport.onclose = () => {
|
transport.onclose = () => {
|
||||||
console.log('Connection closed.')
|
log('Connection closed.')
|
||||||
process.exit(0)
|
process.exit(0)
|
||||||
}
|
}
|
||||||
return transport
|
return transport
|
||||||
|
@ -69,16 +84,9 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
|
|
||||||
const transport = initTransport()
|
const transport = initTransport()
|
||||||
|
|
||||||
// Set up an HTTP server to handle OAuth callback
|
|
||||||
const { server, waitForAuthCode } = setupOAuthCallbackServer({
|
|
||||||
port: callbackPort,
|
|
||||||
path: '/oauth/callback',
|
|
||||||
events,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Set up cleanup handler
|
// Set up cleanup handler
|
||||||
const cleanup = async () => {
|
const cleanup = async () => {
|
||||||
console.log('\nClosing connection...')
|
log('\nClosing connection...')
|
||||||
await client.close()
|
await client.close()
|
||||||
server.close()
|
server.close()
|
||||||
}
|
}
|
||||||
|
@ -86,44 +94,44 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
|
|
||||||
// Try to connect
|
// Try to connect
|
||||||
try {
|
try {
|
||||||
console.log('Connecting to server...')
|
log('Connecting to server...')
|
||||||
await client.connect(transport)
|
await client.connect(transport)
|
||||||
console.log('Connected successfully!')
|
log('Connected successfully!')
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||||
console.log('Authentication required. Waiting for authorization...')
|
log('Authentication required. Waiting for authorization...')
|
||||||
|
|
||||||
// Wait for the authorization code from the callback
|
// Wait for the authorization code from the callback or another instance
|
||||||
const code = await waitForAuthCode()
|
const code = await waitForAuthCode()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
console.log('Completing authorization...')
|
log('Completing authorization...')
|
||||||
await transport.finishAuth(code)
|
await transport.finishAuth(code)
|
||||||
|
|
||||||
// Reconnect after authorization with a new transport
|
// Reconnect after authorization with a new transport
|
||||||
console.log('Connecting after authorization...')
|
log('Connecting after authorization...')
|
||||||
await client.connect(initTransport())
|
await client.connect(initTransport())
|
||||||
|
|
||||||
console.log('Connected successfully!')
|
log('Connected successfully!')
|
||||||
|
|
||||||
// Request tools list after auth
|
// Request tools list after auth
|
||||||
console.log('Requesting tools list...')
|
log('Requesting tools list...')
|
||||||
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
|
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
|
||||||
console.log('Tools:', JSON.stringify(tools, null, 2))
|
log('Tools:', JSON.stringify(tools, null, 2))
|
||||||
|
|
||||||
// Request resources list after auth
|
// Request resources list after auth
|
||||||
console.log('Requesting resource list...')
|
log('Requesting resource list...')
|
||||||
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
|
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
|
||||||
console.log('Resources:', JSON.stringify(resources, null, 2))
|
log('Resources:', JSON.stringify(resources, null, 2))
|
||||||
|
|
||||||
console.log('Listening for messages. Press Ctrl+C to exit.')
|
log('Listening for messages. Press Ctrl+C to exit.')
|
||||||
} catch (authError) {
|
} catch (authError) {
|
||||||
console.error('Authorization error:', authError)
|
log('Authorization error:', authError)
|
||||||
server.close()
|
server.close()
|
||||||
process.exit(1)
|
process.exit(1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
console.error('Connection error:', error)
|
log('Connection error:', error)
|
||||||
server.close()
|
server.close()
|
||||||
process.exit(1)
|
process.exit(1)
|
||||||
}
|
}
|
||||||
|
@ -131,23 +139,23 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Request tools list
|
// Request tools list
|
||||||
console.log('Requesting tools list...')
|
log('Requesting tools list...')
|
||||||
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
|
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
|
||||||
console.log('Tools:', JSON.stringify(tools, null, 2))
|
log('Tools:', JSON.stringify(tools, null, 2))
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log('Error requesting tools list:', e)
|
log('Error requesting tools list:', e)
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Request resources list
|
// Request resources list
|
||||||
console.log('Requesting resource list...')
|
log('Requesting resource list...')
|
||||||
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
|
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
|
||||||
console.log('Resources:', JSON.stringify(resources, null, 2))
|
log('Resources:', JSON.stringify(resources, null, 2))
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log('Error requesting resources list:', e)
|
log('Error requesting resources list:', e)
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log('Listening for messages. Press Ctrl+C to exit.')
|
log('Listening for messages. Press Ctrl+C to exit.')
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse command-line arguments and run the client
|
// Parse command-line arguments and run the client
|
||||||
|
|
188
src/lib/coordination.ts
Normal file
188
src/lib/coordination.ts
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
import { checkLockfile, createLockfile, deleteLockfile, getConfigFilePath, LockfileData } from './mcp-auth-config'
|
||||||
|
import { EventEmitter } from 'events'
|
||||||
|
import { Server } from 'http'
|
||||||
|
import express from 'express'
|
||||||
|
import { AddressInfo } from 'net'
|
||||||
|
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a process with the given PID is running
|
||||||
|
* @param pid The process ID to check
|
||||||
|
* @returns True if the process is running, false otherwise
|
||||||
|
*/
|
||||||
|
export async function isPidRunning(pid: number): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
process.kill(pid, 0) // Doesn't kill the process, just checks if it exists
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a lockfile is valid (process running and endpoint accessible)
|
||||||
|
* @param lockData The lockfile data
|
||||||
|
* @returns True if the lockfile is valid, false otherwise
|
||||||
|
*/
|
||||||
|
export async function isLockValid(lockData: LockfileData): Promise<boolean> {
|
||||||
|
// Check if the lockfile is too old (over 30 minutes)
|
||||||
|
const MAX_LOCK_AGE = 30 * 60 * 1000 // 30 minutes
|
||||||
|
if (Date.now() - lockData.timestamp > MAX_LOCK_AGE) {
|
||||||
|
log('Lockfile is too old')
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the process is still running
|
||||||
|
if (!(await isPidRunning(lockData.pid))) {
|
||||||
|
log('Process from lockfile is not running')
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the endpoint is accessible
|
||||||
|
try {
|
||||||
|
const controller = new AbortController()
|
||||||
|
const timeout = setTimeout(() => controller.abort(), 1000)
|
||||||
|
|
||||||
|
const response = await fetch(`http://127.0.0.1:${lockData.port}/wait-for-auth?poll=false`, {
|
||||||
|
signal: controller.signal,
|
||||||
|
})
|
||||||
|
|
||||||
|
clearTimeout(timeout)
|
||||||
|
return response.status === 200 || response.status === 202
|
||||||
|
} catch (error) {
|
||||||
|
log(`Error connecting to auth server: ${(error as Error).message}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Waits for authentication from another server instance
|
||||||
|
* @param port The port to connect to
|
||||||
|
* @returns True if authentication completed successfully, false otherwise
|
||||||
|
*/
|
||||||
|
export async function waitForAuthentication(port: number): Promise<boolean> {
|
||||||
|
log(`Waiting for authentication from the server on port ${port}...`)
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const url = `http://127.0.0.1:${port}/wait-for-auth`
|
||||||
|
log(`Querying: ${url}`)
|
||||||
|
const response = await fetch(url)
|
||||||
|
|
||||||
|
if (response.status === 200) {
|
||||||
|
// Auth completed, but we don't return the code anymore
|
||||||
|
log(`Authentication completed by other instance`)
|
||||||
|
return true
|
||||||
|
} else if (response.status === 202) {
|
||||||
|
// Continue polling
|
||||||
|
log(`Authentication still in progress`)
|
||||||
|
await new Promise(resolve => setTimeout(resolve, 1000))
|
||||||
|
} else {
|
||||||
|
log(`Unexpected response status: ${response.status}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
log(`Error waiting for authentication: ${(error as Error).message}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Coordinates authentication between multiple instances of the client/proxy
|
||||||
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
* @param callbackPort The port to use for the callback server
|
||||||
|
* @param events The event emitter to use for signaling
|
||||||
|
* @returns An object with the server, waitForAuthCode function, and a flag indicating if browser auth can be skipped
|
||||||
|
*/
|
||||||
|
export async function coordinateAuth(
|
||||||
|
serverUrlHash: string,
|
||||||
|
callbackPort: number,
|
||||||
|
events: EventEmitter,
|
||||||
|
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }> {
|
||||||
|
// Check for a lockfile
|
||||||
|
const lockData = await checkLockfile(serverUrlHash)
|
||||||
|
|
||||||
|
// If there's a valid lockfile, try to use the existing auth process
|
||||||
|
if (lockData && (await isLockValid(lockData))) {
|
||||||
|
log(`Another instance is handling authentication on port ${lockData.port}`)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Try to wait for the authentication to complete
|
||||||
|
const authCompleted = await waitForAuthentication(lockData.port)
|
||||||
|
if (authCompleted) {
|
||||||
|
log('Authentication completed by another instance')
|
||||||
|
|
||||||
|
// Setup a dummy server - the client will use tokens directly from disk
|
||||||
|
const dummyServer = express().listen(0) // Listen on any available port
|
||||||
|
|
||||||
|
// This shouldn't actually be called in normal operation, but provide it for API compatibility
|
||||||
|
const dummyWaitForAuthCode = () => {
|
||||||
|
log('WARNING: waitForAuthCode called in secondary instance - this is unexpected')
|
||||||
|
// Return a promise that never resolves - the client should use the tokens from disk instead
|
||||||
|
return new Promise<string>(() => {})
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
server: dummyServer,
|
||||||
|
waitForAuthCode: dummyWaitForAuthCode,
|
||||||
|
skipBrowserAuth: true,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log('Taking over authentication process...')
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
log(`Error waiting for authentication: ${error}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we get here, the other process didn't complete auth successfully
|
||||||
|
await deleteLockfile(serverUrlHash)
|
||||||
|
} else if (lockData) {
|
||||||
|
// Invalid lockfile, delete its
|
||||||
|
log('Found invalid lockfile, deleting it')
|
||||||
|
await deleteLockfile(serverUrlHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create our own lockfile
|
||||||
|
const { server, waitForAuthCode, authCompletedPromise } = setupOAuthCallbackServerWithLongPoll({
|
||||||
|
port: callbackPort,
|
||||||
|
path: '/oauth/callback',
|
||||||
|
events,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Get the actual port the server is running on
|
||||||
|
const address = server.address() as AddressInfo
|
||||||
|
const actualPort = address.port
|
||||||
|
|
||||||
|
log(`Creating lockfile for server ${serverUrlHash} with process ${process.pid} on port ${actualPort}`)
|
||||||
|
await createLockfile(serverUrlHash, process.pid, actualPort)
|
||||||
|
|
||||||
|
// Make sure lockfile is deleted on process exit
|
||||||
|
const cleanupHandler = async () => {
|
||||||
|
try {
|
||||||
|
log(`Cleaning up lockfile for server ${serverUrlHash}`)
|
||||||
|
await deleteLockfile(serverUrlHash)
|
||||||
|
} catch (error) {
|
||||||
|
log(`Error cleaning up lockfile: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
process.once('exit', () => {
|
||||||
|
try {
|
||||||
|
// Synchronous version for 'exit' event since we can't use async here
|
||||||
|
const configPath = getConfigFilePath(serverUrlHash, 'lock.json')
|
||||||
|
require('fs').unlinkSync(configPath)
|
||||||
|
} catch {}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Also handle SIGINT separately
|
||||||
|
process.once('SIGINT', async () => {
|
||||||
|
await cleanupHandler()
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
server,
|
||||||
|
waitForAuthCode,
|
||||||
|
skipBrowserAuth: false
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,3 @@
|
||||||
import crypto from 'crypto'
|
|
||||||
import path from 'path'
|
import path from 'path'
|
||||||
import os from 'os'
|
import os from 'os'
|
||||||
import fs from 'fs/promises'
|
import fs from 'fs/promises'
|
||||||
|
@ -27,7 +26,61 @@ import { log, MCP_REMOTE_VERSION } from './utils'
|
||||||
/**
|
/**
|
||||||
* Known configuration file names that might need to be cleaned
|
* Known configuration file names that might need to be cleaned
|
||||||
*/
|
*/
|
||||||
export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt']
|
export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt', 'lock.json']
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lockfile data structure
|
||||||
|
*/
|
||||||
|
export interface LockfileData {
|
||||||
|
pid: number
|
||||||
|
port: number
|
||||||
|
timestamp: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a lockfile for the given server
|
||||||
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
* @param pid The process ID
|
||||||
|
* @param port The port the server is running on
|
||||||
|
*/
|
||||||
|
export async function createLockfile(serverUrlHash: string, pid: number, port: number): Promise<void> {
|
||||||
|
const lockData: LockfileData = {
|
||||||
|
pid,
|
||||||
|
port,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
}
|
||||||
|
await writeJsonFile(serverUrlHash, 'lock.json', lockData)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a lockfile exists for the given server
|
||||||
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
* @returns The lockfile data or null if it doesn't exist
|
||||||
|
*/
|
||||||
|
export async function checkLockfile(serverUrlHash: string): Promise<LockfileData | null> {
|
||||||
|
try {
|
||||||
|
const lockfile = await readJsonFile<LockfileData>(serverUrlHash, 'lock.json', {
|
||||||
|
async parseAsync(data: any) {
|
||||||
|
if (typeof data !== 'object' || data === null) return null
|
||||||
|
if (typeof data.pid !== 'number' || typeof data.port !== 'number' || typeof data.timestamp !== 'number') {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return data as LockfileData
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return lockfile || null
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deletes the lockfile for the given server
|
||||||
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
*/
|
||||||
|
export async function deleteLockfile(serverUrlHash: string): Promise<void> {
|
||||||
|
await deleteConfigFile(serverUrlHash, 'lock.json')
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deletes all known configuration files for a specific server
|
* Deletes all known configuration files for a specific server
|
||||||
|
@ -63,15 +116,6 @@ export async function ensureConfigDir(): Promise<void> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Generates a hash for the server URL to use in filenames
|
|
||||||
* @param serverUrl The server URL to hash
|
|
||||||
* @returns The hashed server URL
|
|
||||||
*/
|
|
||||||
export function getServerUrlHash(serverUrl: string): string {
|
|
||||||
return crypto.createHash('md5').update(serverUrl).digest('hex')
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the file path for a config file
|
* Gets the file path for a config file
|
||||||
* @param serverUrlHash The hash of the server URL
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
@ -125,11 +169,15 @@ export async function readJsonFile<T>(
|
||||||
|
|
||||||
const filePath = getConfigFilePath(serverUrlHash, filename)
|
const filePath = getConfigFilePath(serverUrlHash, filename)
|
||||||
const content = await fs.readFile(filePath, 'utf-8')
|
const content = await fs.readFile(filePath, 'utf-8')
|
||||||
return await schema.parseAsync(JSON.parse(content))
|
const result = await schema.parseAsync(JSON.parse(content))
|
||||||
|
// console.log({ filename: result })
|
||||||
|
return result
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
|
if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
|
||||||
|
// console.log(`File ${filename} does not exist`)
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
log(`Error reading ${filename}:`, error)
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ import {
|
||||||
OAuthTokensSchema,
|
OAuthTokensSchema,
|
||||||
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||||
import type { OAuthProviderOptions } from './types'
|
import type { OAuthProviderOptions } from './types'
|
||||||
import { getServerUrlHash, readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config'
|
import { readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config'
|
||||||
import { log } from './utils'
|
import { getServerUrlHash, log } from './utils'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implements the OAuthClientProvider interface for Node.js environments.
|
* Implements the OAuthClientProvider interface for Node.js environments.
|
||||||
|
@ -59,6 +59,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
|
||||||
* @returns The client information or undefined
|
* @returns The client information or undefined
|
||||||
*/
|
*/
|
||||||
async clientInformation(): Promise<OAuthClientInformation | undefined> {
|
async clientInformation(): Promise<OAuthClientInformation | undefined> {
|
||||||
|
// log('Reading client info')
|
||||||
return readJsonFile<OAuthClientInformation>(this.serverUrlHash, 'client_info.json', OAuthClientInformationSchema, this.options.clean)
|
return readJsonFile<OAuthClientInformation>(this.serverUrlHash, 'client_info.json', OAuthClientInformationSchema, this.options.clean)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,6 +68,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
|
||||||
* @param clientInformation The client information to save
|
* @param clientInformation The client information to save
|
||||||
*/
|
*/
|
||||||
async saveClientInformation(clientInformation: OAuthClientInformationFull): Promise<void> {
|
async saveClientInformation(clientInformation: OAuthClientInformationFull): Promise<void> {
|
||||||
|
// log('Saving client info')
|
||||||
await writeJsonFile(this.serverUrlHash, 'client_info.json', clientInformation)
|
await writeJsonFile(this.serverUrlHash, 'client_info.json', clientInformation)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,6 +77,8 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
|
||||||
* @returns The OAuth tokens or undefined
|
* @returns The OAuth tokens or undefined
|
||||||
*/
|
*/
|
||||||
async tokens(): Promise<OAuthTokens | undefined> {
|
async tokens(): Promise<OAuthTokens | undefined> {
|
||||||
|
// log('Reading tokens')
|
||||||
|
// console.log(new Error().stack)
|
||||||
return readJsonFile<OAuthTokens>(this.serverUrlHash, 'tokens.json', OAuthTokensSchema, this.options.clean)
|
return readJsonFile<OAuthTokens>(this.serverUrlHash, 'tokens.json', OAuthTokensSchema, this.options.clean)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,6 +87,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
|
||||||
* @param tokens The tokens to save
|
* @param tokens The tokens to save
|
||||||
*/
|
*/
|
||||||
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||||
|
// log('Saving tokens')
|
||||||
await writeJsonFile(this.serverUrlHash, 'tokens.json', tokens)
|
await writeJsonFile(this.serverUrlHash, 'tokens.json', tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,6 +110,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
|
||||||
* @param codeVerifier The code verifier to save
|
* @param codeVerifier The code verifier to save
|
||||||
*/
|
*/
|
||||||
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
||||||
|
// log('Saving code verifier')
|
||||||
await writeTextFile(this.serverUrlHash, 'code_verifier.txt', codeVerifier)
|
await writeTextFile(this.serverUrlHash, 'code_verifier.txt', codeVerifier)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +119,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
|
||||||
* @returns The code verifier
|
* @returns The code verifier
|
||||||
*/
|
*/
|
||||||
async codeVerifier(): Promise<string> {
|
async codeVerifier(): Promise<string> {
|
||||||
|
// log('Reading code verifier')
|
||||||
return await readTextFile(this.serverUrlHash, 'code_verifier.txt', 'No code verifier saved for session', this.options.clean)
|
return await readTextFile(this.serverUrlHash, 'code_verifier.txt', 'No code verifier saved for session', this.options.clean)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||||
import { OAuthCallbackServerOptions } from './types'
|
import { OAuthCallbackServerOptions } from './types'
|
||||||
import express from 'express'
|
import express from 'express'
|
||||||
import net from 'net'
|
import net from 'net'
|
||||||
|
import crypto from 'crypto'
|
||||||
|
|
||||||
|
// Package version from package.json
|
||||||
|
export const MCP_REMOTE_VERSION = require('../../package.json').version
|
||||||
|
|
||||||
const pid = process.pid
|
const pid = process.pid
|
||||||
export function log(str: string, ...rest: unknown[]) {
|
export function log(str: string, ...rest: unknown[]) {
|
||||||
|
@ -65,12 +69,14 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
||||||
* @param serverUrl The URL of the remote server
|
* @param serverUrl The URL of the remote server
|
||||||
* @param authProvider The OAuth client provider
|
* @param authProvider The OAuth client provider
|
||||||
* @param waitForAuthCode Function to wait for the auth code
|
* @param waitForAuthCode Function to wait for the auth code
|
||||||
|
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
||||||
* @returns The connected SSE client transport
|
* @returns The connected SSE client transport
|
||||||
*/
|
*/
|
||||||
export async function connectToRemoteServer(
|
export async function connectToRemoteServer(
|
||||||
serverUrl: string,
|
serverUrl: string,
|
||||||
authProvider: OAuthClientProvider,
|
authProvider: OAuthClientProvider,
|
||||||
waitForAuthCode: () => Promise<string>,
|
waitForAuthCode: () => Promise<string>,
|
||||||
|
skipBrowserAuth: boolean = false,
|
||||||
): Promise<SSEClientTransport> {
|
): Promise<SSEClientTransport> {
|
||||||
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
||||||
const url = new URL(serverUrl)
|
const url = new URL(serverUrl)
|
||||||
|
@ -82,7 +88,11 @@ export async function connectToRemoteServer(
|
||||||
return transport
|
return transport
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||||
log('Authentication required. Waiting for authorization...')
|
if (skipBrowserAuth) {
|
||||||
|
log('Authentication required but skipping browser auth - using shared auth')
|
||||||
|
} else {
|
||||||
|
log('Authentication required. Waiting for authorization...')
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for the authorization code from the callback
|
// Wait for the authorization code from the callback
|
||||||
const code = await waitForAuthCode()
|
const code = await waitForAuthCode()
|
||||||
|
@ -112,10 +122,57 @@ export async function connectToRemoteServer(
|
||||||
* @param options The server options
|
* @param options The server options
|
||||||
* @returns An object with the server, authCode, and waitForAuthCode function
|
* @returns An object with the server, authCode, and waitForAuthCode function
|
||||||
*/
|
*/
|
||||||
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
|
||||||
let authCode: string | null = null
|
let authCode: string | null = null
|
||||||
const app = express()
|
const app = express()
|
||||||
|
|
||||||
|
// Create a promise to track when auth is completed
|
||||||
|
let authCompletedResolve: (code: string) => void
|
||||||
|
const authCompletedPromise = new Promise<string>((resolve) => {
|
||||||
|
authCompletedResolve = resolve
|
||||||
|
})
|
||||||
|
|
||||||
|
// Long-polling endpoint
|
||||||
|
app.get('/wait-for-auth', (req, res) => {
|
||||||
|
if (authCode) {
|
||||||
|
// Auth already completed - just return 200 without the actual code
|
||||||
|
// Secondary instances will read tokens from disk
|
||||||
|
log('Auth already completed, returning 200')
|
||||||
|
res.status(200).send('Authentication completed')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.query.poll === 'false') {
|
||||||
|
log('Client requested no long poll, responding with 202')
|
||||||
|
res.status(202).send('Authentication in progress')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Long poll - wait for up to 30 seconds
|
||||||
|
const longPollTimeout = setTimeout(() => {
|
||||||
|
log('Long poll timeout reached, responding with 202')
|
||||||
|
res.status(202).send('Authentication in progress')
|
||||||
|
}, 30000)
|
||||||
|
|
||||||
|
// If auth completes while we're waiting, send the response immediately
|
||||||
|
authCompletedPromise
|
||||||
|
.then(() => {
|
||||||
|
clearTimeout(longPollTimeout)
|
||||||
|
if (!res.headersSent) {
|
||||||
|
log('Auth completed during long poll, responding with 200')
|
||||||
|
res.status(200).send('Authentication completed')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
clearTimeout(longPollTimeout)
|
||||||
|
if (!res.headersSent) {
|
||||||
|
log('Auth failed during long poll, responding with 500')
|
||||||
|
res.status(500).send('Authentication failed')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// OAuth callback endpoint
|
||||||
app.get(options.path, (req, res) => {
|
app.get(options.path, (req, res) => {
|
||||||
const code = req.query.code as string | undefined
|
const code = req.query.code as string | undefined
|
||||||
if (!code) {
|
if (!code) {
|
||||||
|
@ -124,6 +181,9 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
||||||
}
|
}
|
||||||
|
|
||||||
authCode = code
|
authCode = code
|
||||||
|
log('Auth code received, resolving promise')
|
||||||
|
authCompletedResolve(code)
|
||||||
|
|
||||||
res.send('Authorization successful! You may close this window and return to the CLI.')
|
res.send('Authorization successful! You may close this window and return to the CLI.')
|
||||||
|
|
||||||
// Notify main flow that auth code is available
|
// Notify main flow that auth code is available
|
||||||
|
@ -134,10 +194,6 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
||||||
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
|
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
|
||||||
})
|
})
|
||||||
|
|
||||||
/**
|
|
||||||
* Waits for the OAuth authorization code
|
|
||||||
* @returns A promise that resolves with the authorization code
|
|
||||||
*/
|
|
||||||
const waitForAuthCode = (): Promise<string> => {
|
const waitForAuthCode = (): Promise<string> => {
|
||||||
return new Promise((resolve) => {
|
return new Promise((resolve) => {
|
||||||
if (authCode) {
|
if (authCode) {
|
||||||
|
@ -151,6 +207,16 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return { server, authCode, waitForAuthCode, authCompletedPromise }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets up an Express server to handle OAuth callbacks
|
||||||
|
* @param options The server options
|
||||||
|
* @returns An object with the server, authCode, and waitForAuthCode function
|
||||||
|
*/
|
||||||
|
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
||||||
|
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
|
||||||
return { server, authCode, waitForAuthCode }
|
return { server, authCode, waitForAuthCode }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,4 +314,11 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
|
||||||
process.stdin.resume()
|
process.stdin.resume()
|
||||||
}
|
}
|
||||||
|
|
||||||
export const MCP_REMOTE_VERSION = require('../../package.json').version
|
/**
|
||||||
|
* Generates a hash for the server URL to use in filenames
|
||||||
|
* @param serverUrl The server URL to hash
|
||||||
|
* @returns The hashed server URL
|
||||||
|
*/
|
||||||
|
export function getServerUrlHash(serverUrl: string): string {
|
||||||
|
return crypto.createHash('md5').update(serverUrl).digest('hex')
|
||||||
|
}
|
||||||
|
|
26
src/proxy.ts
26
src/proxy.ts
|
@ -14,8 +14,9 @@
|
||||||
|
|
||||||
import { EventEmitter } from 'events'
|
import { EventEmitter } from 'events'
|
||||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
||||||
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers } from './lib/utils'
|
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils'
|
||||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||||
|
import { coordinateAuth } from './lib/coordination'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Main function to run the proxy
|
* Main function to run the proxy
|
||||||
|
@ -24,6 +25,12 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
// Set up event emitter for auth flow
|
// Set up event emitter for auth flow
|
||||||
const events = new EventEmitter()
|
const events = new EventEmitter()
|
||||||
|
|
||||||
|
// Get the server URL hash for lockfile operations
|
||||||
|
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||||
|
|
||||||
|
// Coordinate authentication with other instances
|
||||||
|
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||||
|
|
||||||
// Create the OAuth client provider
|
// Create the OAuth client provider
|
||||||
const authProvider = new NodeOAuthClientProvider({
|
const authProvider = new NodeOAuthClientProvider({
|
||||||
serverUrl,
|
serverUrl,
|
||||||
|
@ -32,19 +39,20 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
|
||||||
clean,
|
clean,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||||
|
if (skipBrowserAuth) {
|
||||||
|
log('Authentication was completed by another instance - will use tokens from disk')
|
||||||
|
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||||
|
// so we're slightly too early
|
||||||
|
await new Promise((res) => setTimeout(res, 1_000))
|
||||||
|
}
|
||||||
|
|
||||||
// Create the STDIO transport for local connections
|
// Create the STDIO transport for local connections
|
||||||
const localTransport = new StdioServerTransport()
|
const localTransport = new StdioServerTransport()
|
||||||
|
|
||||||
// Set up an HTTP server to handle OAuth callback
|
|
||||||
const { server, waitForAuthCode } = setupOAuthCallbackServer({
|
|
||||||
port: callbackPort,
|
|
||||||
path: '/oauth/callback',
|
|
||||||
events,
|
|
||||||
})
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Connect to remote server with authentication
|
// Connect to remote server with authentication
|
||||||
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode)
|
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth)
|
||||||
|
|
||||||
// Set up bidirectional proxy between local and remote transports
|
// Set up bidirectional proxy between local and remote transports
|
||||||
mcpProxy({
|
mcpProxy({
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue