Refactor auth coordination to be lazy
- Create lazy auth coordinator that only initializes when needed - Modify connectToRemoteServer to only use auth when receiving an Unauthorized error - Update client.ts and proxy.ts to use the lazy auth approach - Add refactoring plan documentation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
14109a309f
commit
a6e6d0f1e8
5 changed files with 169 additions and 41 deletions
|
@ -22,7 +22,7 @@ import {
|
|||
connectToRemoteServer,
|
||||
TransportStrategy,
|
||||
} from './lib/utils'
|
||||
import { coordinateAuth } from './lib/coordination'
|
||||
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||
|
||||
/**
|
||||
* Main function to run the client
|
||||
|
@ -39,8 +39,8 @@ async function runClient(
|
|||
// Get the server URL hash for lockfile operations
|
||||
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||
|
||||
// Coordinate authentication with other instances
|
||||
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||
// Create a lazy auth coordinator
|
||||
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
|
||||
|
||||
// Create the OAuth client provider
|
||||
const authProvider = new NodeOAuthClientProvider({
|
||||
|
@ -49,14 +49,6 @@ async function runClient(
|
|||
clientName: 'MCP CLI Client',
|
||||
})
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk...')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
// Create the client
|
||||
const client = new Client(
|
||||
{
|
||||
|
@ -68,15 +60,38 @@ async function runClient(
|
|||
},
|
||||
)
|
||||
|
||||
// Keep track of the server instance for cleanup
|
||||
let server: any = null
|
||||
|
||||
// Define an auth initializer function
|
||||
const authInitializer = async () => {
|
||||
const authState = await authCoordinator.initializeAuth()
|
||||
|
||||
// Store server in outer scope for cleanup
|
||||
server = authState.server
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (authState.skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk...')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
return {
|
||||
waitForAuthCode: authState.waitForAuthCode,
|
||||
skipBrowserAuth: authState.skipBrowserAuth
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// Connect to remote server with authentication
|
||||
// Connect to remote server with lazy authentication
|
||||
const transport = await connectToRemoteServer(
|
||||
client,
|
||||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
authInitializer,
|
||||
transportStrategy,
|
||||
)
|
||||
|
||||
|
@ -98,7 +113,10 @@ async function runClient(
|
|||
const cleanup = async () => {
|
||||
log('\nClosing connection...')
|
||||
await client.close()
|
||||
server.close()
|
||||
// If auth was initialized and server was created, close it
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
}
|
||||
setupSignalHandlers(cleanup)
|
||||
|
||||
|
@ -124,11 +142,17 @@ async function runClient(
|
|||
|
||||
// log('Listening for messages. Press Ctrl+C to exit.')
|
||||
log('Exiting OK...')
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
log('Fatal error:', error)
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,10 @@ import express from 'express'
|
|||
import { AddressInfo } from 'net'
|
||||
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
|
||||
|
||||
export type AuthCoordinator = {
|
||||
initializeAuth: () => Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }>
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a process with the given PID is running
|
||||
* @param pid The process ID to check
|
||||
|
@ -88,6 +92,36 @@ export async function waitForAuthentication(port: number): Promise<boolean> {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a lazy auth coordinator that will only initiate auth when needed
|
||||
* @param serverUrlHash The hash of the server URL
|
||||
* @param callbackPort The port to use for the callback server
|
||||
* @param events The event emitter to use for signaling
|
||||
* @returns An AuthCoordinator object with an initializeAuth method
|
||||
*/
|
||||
export function createLazyAuthCoordinator(
|
||||
serverUrlHash: string,
|
||||
callbackPort: number,
|
||||
events: EventEmitter
|
||||
): AuthCoordinator {
|
||||
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean } | null = null
|
||||
|
||||
return {
|
||||
initializeAuth: async () => {
|
||||
// If auth has already been initialized, return the existing state
|
||||
if (authState) {
|
||||
return authState
|
||||
}
|
||||
|
||||
log('Initializing auth coordination on-demand')
|
||||
|
||||
// Initialize auth using the existing coordinateAuth logic
|
||||
authState = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||
return authState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Coordinates authentication between multiple instances of the client/proxy
|
||||
* @param serverUrlHash The hash of the server URL
|
||||
|
|
|
@ -73,14 +73,21 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Type for the auth initialization function
|
||||
*/
|
||||
export type AuthInitializer = () => Promise<{
|
||||
waitForAuthCode: () => Promise<string>
|
||||
skipBrowserAuth: boolean
|
||||
}>
|
||||
|
||||
/**
|
||||
* Creates and connects to a remote server with OAuth authentication
|
||||
* @param client The client to connect with
|
||||
* @param serverUrl The URL of the remote server
|
||||
* @param authProvider The OAuth client provider
|
||||
* @param headers Additional headers to send with the request
|
||||
* @param waitForAuthCode Function to wait for the auth code
|
||||
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
||||
* @param authInitializer Function to initialize authentication when needed
|
||||
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
|
||||
* @param recursionReasons Set of reasons for recursive calls (internal use)
|
||||
* @returns The connected transport
|
||||
|
@ -90,8 +97,7 @@ export async function connectToRemoteServer(
|
|||
serverUrl: string,
|
||||
authProvider: OAuthClientProvider,
|
||||
headers: Record<string, string>,
|
||||
waitForAuthCode: () => Promise<string>,
|
||||
skipBrowserAuth: boolean = false,
|
||||
authInitializer: AuthInitializer,
|
||||
transportStrategy: TransportStrategy = 'http-first',
|
||||
recursionReasons: Set<string> = new Set(),
|
||||
): Promise<Transport> {
|
||||
|
@ -143,7 +149,6 @@ export async function connectToRemoteServer(
|
|||
|
||||
return transport
|
||||
} catch (error) {
|
||||
console.log('DID I CATCH OR WHAT?')
|
||||
// Check if it's a protocol error and we should attempt fallback
|
||||
if (
|
||||
error instanceof Error &&
|
||||
|
@ -172,12 +177,16 @@ export async function connectToRemoteServer(
|
|||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
authInitializer,
|
||||
sseTransport ? 'http-only' : 'sse-only',
|
||||
recursionReasons,
|
||||
)
|
||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||
log('Authentication required. Initializing auth...')
|
||||
|
||||
// Initialize authentication on-demand
|
||||
const { waitForAuthCode, skipBrowserAuth } = await authInitializer()
|
||||
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication required but skipping browser auth - using shared auth')
|
||||
} else {
|
||||
|
@ -207,8 +216,7 @@ export async function connectToRemoteServer(
|
|||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
authInitializer,
|
||||
transportStrategy,
|
||||
recursionReasons,
|
||||
)
|
||||
|
|
52
src/proxy.ts
52
src/proxy.ts
|
@ -21,7 +21,7 @@ import {
|
|||
MCP_REMOTE_VERSION,
|
||||
} from './lib/utils'
|
||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||
import { coordinateAuth } from './lib/coordination'
|
||||
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
|
||||
/**
|
||||
|
@ -34,8 +34,8 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
// Get the server URL hash for lockfile operations
|
||||
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||
|
||||
// Coordinate authentication with other instances
|
||||
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||
// Create a lazy auth coordinator
|
||||
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
|
||||
|
||||
// Create the OAuth client provider
|
||||
const authProvider = new NodeOAuthClientProvider({
|
||||
|
@ -44,17 +44,33 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
clientName: 'MCP CLI Proxy',
|
||||
})
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
// Create the STDIO transport for local connections
|
||||
const localTransport = new StdioServerTransport()
|
||||
|
||||
// Keep track of the server instance for cleanup
|
||||
let server: any = null
|
||||
|
||||
// Define an auth initializer function
|
||||
const authInitializer = async () => {
|
||||
const authState = await authCoordinator.initializeAuth()
|
||||
|
||||
// Store server in outer scope for cleanup
|
||||
server = authState.server
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (authState.skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
return {
|
||||
waitForAuthCode: authState.waitForAuthCode,
|
||||
skipBrowserAuth: authState.skipBrowserAuth
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const client = new Client(
|
||||
{
|
||||
|
@ -65,8 +81,8 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
capabilities: {},
|
||||
},
|
||||
)
|
||||
// Connect to remote server with authentication
|
||||
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
|
||||
// Connect to remote server with lazy authentication
|
||||
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer)
|
||||
|
||||
// Set up bidirectional proxy between local and remote transports
|
||||
mcpProxy({
|
||||
|
@ -84,7 +100,10 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
const cleanup = async () => {
|
||||
await remoteTransport.close()
|
||||
await localTransport.close()
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
}
|
||||
setupSignalHandlers(cleanup)
|
||||
} catch (error) {
|
||||
|
@ -111,7 +130,10 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
|
|||
}
|
||||
`)
|
||||
}
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue