Refactor auth coordination to be lazy
- Create lazy auth coordinator that only initializes when needed - Modify connectToRemoteServer to only use auth when receiving an Unauthorized error - Update client.ts and proxy.ts to use the lazy auth approach - Add refactoring plan documentation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
14109a309f
commit
a6e6d0f1e8
5 changed files with 169 additions and 41 deletions
40
refactoring-plan.md
Normal file
40
refactoring-plan.md
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# Auth Coordination Refactoring Plan
|
||||||
|
|
||||||
|
Currently, both `src/proxy.ts` and `src/client.ts` always run auth coordination before attempting to connect to the server. However, in some cases authentication is not required, and we already have the ability to catch and handle authorization errors in the `connectToRemoteServer` function.
|
||||||
|
|
||||||
|
The plan is to refactor the code so that auth coordination is only invoked when we actually receive an "Unauthorized" error, rather than preemptively setting up auth for all connections.
|
||||||
|
|
||||||
|
## Tasks
|
||||||
|
|
||||||
|
1. [x] **Create a lazy auth coordinator**: Modify `coordinateAuth` function to support lazy initialization, so we can set it up but only use it when needed.
|
||||||
|
- Added `createLazyAuthCoordinator` function that returns an object with `initializeAuth` method
|
||||||
|
- Kept original `coordinateAuth` function intact for backward compatibility
|
||||||
|
|
||||||
|
2. [x] **Refactor `connectToRemoteServer`**: Update this function to handle auth lazily:
|
||||||
|
- Removed the `waitForAuthCode` and `skipBrowserAuth` parameters
|
||||||
|
- Added a new `authInitializer` parameter that initializes auth when needed
|
||||||
|
- Only call this initializer when we encounter an "Unauthorized" error
|
||||||
|
- Created a new type `AuthInitializer` to define the expected interface
|
||||||
|
|
||||||
|
3. [x] **Update client.ts**: Refactor the client to use the new lazy auth approach.
|
||||||
|
- No longer calling `coordinateAuth` at the beginning
|
||||||
|
- Created function to initiate auth only when needed
|
||||||
|
- Pass this function to `connectToRemoteServer`
|
||||||
|
- Added proper handling of server cleanup
|
||||||
|
|
||||||
|
4. [x] **Update proxy.ts**: Similarly refactor the proxy to use the lazy auth approach.
|
||||||
|
- No longer calling `coordinateAuth` at the beginning
|
||||||
|
- Created function to initiate auth only when needed
|
||||||
|
- Pass this function to `connectToRemoteServer`
|
||||||
|
- Added proper handling of server cleanup
|
||||||
|
|
||||||
|
5. [ ] **Test both flows**:
|
||||||
|
- Test with servers requiring authentication
|
||||||
|
- Test with servers that don't require authentication
|
||||||
|
|
||||||
|
## Benefits
|
||||||
|
|
||||||
|
- Improved efficiency by avoiding unnecessary auth setup when not needed
|
||||||
|
- Faster startup for connections that don't require auth
|
||||||
|
- Cleaner separation of concerns
|
||||||
|
- Reduced complexity in the call flow
|
|
@ -22,7 +22,7 @@ import {
|
||||||
connectToRemoteServer,
|
connectToRemoteServer,
|
||||||
TransportStrategy,
|
TransportStrategy,
|
||||||
} from './lib/utils'
|
} from './lib/utils'
|
||||||
import { coordinateAuth } from './lib/coordination'
|
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Main function to run the client
|
* Main function to run the client
|
||||||
|
@ -39,8 +39,8 @@ async function runClient(
|
||||||
// Get the server URL hash for lockfile operations
|
// Get the server URL hash for lockfile operations
|
||||||
const serverUrlHash = getServerUrlHash(serverUrl)
|
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||||
|
|
||||||
// Coordinate authentication with other instances
|
// Create a lazy auth coordinator
|
||||||
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
|
||||||
|
|
||||||
// Create the OAuth client provider
|
// Create the OAuth client provider
|
||||||
const authProvider = new NodeOAuthClientProvider({
|
const authProvider = new NodeOAuthClientProvider({
|
||||||
|
@ -49,14 +49,6 @@ async function runClient(
|
||||||
clientName: 'MCP CLI Client',
|
clientName: 'MCP CLI Client',
|
||||||
})
|
})
|
||||||
|
|
||||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
|
||||||
if (skipBrowserAuth) {
|
|
||||||
log('Authentication was completed by another instance - will use tokens from disk...')
|
|
||||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
|
||||||
// so we're slightly too early
|
|
||||||
await new Promise((res) => setTimeout(res, 1_000))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the client
|
// Create the client
|
||||||
const client = new Client(
|
const client = new Client(
|
||||||
{
|
{
|
||||||
|
@ -68,15 +60,38 @@ async function runClient(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Keep track of the server instance for cleanup
|
||||||
|
let server: any = null
|
||||||
|
|
||||||
|
// Define an auth initializer function
|
||||||
|
const authInitializer = async () => {
|
||||||
|
const authState = await authCoordinator.initializeAuth()
|
||||||
|
|
||||||
|
// Store server in outer scope for cleanup
|
||||||
|
server = authState.server
|
||||||
|
|
||||||
|
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||||
|
if (authState.skipBrowserAuth) {
|
||||||
|
log('Authentication was completed by another instance - will use tokens from disk...')
|
||||||
|
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||||
|
// so we're slightly too early
|
||||||
|
await new Promise((res) => setTimeout(res, 1_000))
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
waitForAuthCode: authState.waitForAuthCode,
|
||||||
|
skipBrowserAuth: authState.skipBrowserAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Connect to remote server with authentication
|
// Connect to remote server with lazy authentication
|
||||||
const transport = await connectToRemoteServer(
|
const transport = await connectToRemoteServer(
|
||||||
client,
|
client,
|
||||||
serverUrl,
|
serverUrl,
|
||||||
authProvider,
|
authProvider,
|
||||||
headers,
|
headers,
|
||||||
waitForAuthCode,
|
authInitializer,
|
||||||
skipBrowserAuth,
|
|
||||||
transportStrategy,
|
transportStrategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -98,8 +113,11 @@ async function runClient(
|
||||||
const cleanup = async () => {
|
const cleanup = async () => {
|
||||||
log('\nClosing connection...')
|
log('\nClosing connection...')
|
||||||
await client.close()
|
await client.close()
|
||||||
|
// If auth was initialized and server was created, close it
|
||||||
|
if (server) {
|
||||||
server.close()
|
server.close()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
setupSignalHandlers(cleanup)
|
setupSignalHandlers(cleanup)
|
||||||
|
|
||||||
log('Connected successfully!')
|
log('Connected successfully!')
|
||||||
|
@ -124,11 +142,17 @@ async function runClient(
|
||||||
|
|
||||||
// log('Listening for messages. Press Ctrl+C to exit.')
|
// log('Listening for messages. Press Ctrl+C to exit.')
|
||||||
log('Exiting OK...')
|
log('Exiting OK...')
|
||||||
|
// Only close the server if it was initialized
|
||||||
|
if (server) {
|
||||||
server.close()
|
server.close()
|
||||||
|
}
|
||||||
process.exit(0)
|
process.exit(0)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log('Fatal error:', error)
|
log('Fatal error:', error)
|
||||||
|
// Only close the server if it was initialized
|
||||||
|
if (server) {
|
||||||
server.close()
|
server.close()
|
||||||
|
}
|
||||||
process.exit(1)
|
process.exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,10 @@ import express from 'express'
|
||||||
import { AddressInfo } from 'net'
|
import { AddressInfo } from 'net'
|
||||||
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
|
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
|
||||||
|
|
||||||
|
export type AuthCoordinator = {
|
||||||
|
initializeAuth: () => Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }>
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if a process with the given PID is running
|
* Checks if a process with the given PID is running
|
||||||
* @param pid The process ID to check
|
* @param pid The process ID to check
|
||||||
|
@ -88,6 +92,36 @@ export async function waitForAuthentication(port: number): Promise<boolean> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a lazy auth coordinator that will only initiate auth when needed
|
||||||
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
* @param callbackPort The port to use for the callback server
|
||||||
|
* @param events The event emitter to use for signaling
|
||||||
|
* @returns An AuthCoordinator object with an initializeAuth method
|
||||||
|
*/
|
||||||
|
export function createLazyAuthCoordinator(
|
||||||
|
serverUrlHash: string,
|
||||||
|
callbackPort: number,
|
||||||
|
events: EventEmitter
|
||||||
|
): AuthCoordinator {
|
||||||
|
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean } | null = null
|
||||||
|
|
||||||
|
return {
|
||||||
|
initializeAuth: async () => {
|
||||||
|
// If auth has already been initialized, return the existing state
|
||||||
|
if (authState) {
|
||||||
|
return authState
|
||||||
|
}
|
||||||
|
|
||||||
|
log('Initializing auth coordination on-demand')
|
||||||
|
|
||||||
|
// Initialize auth using the existing coordinateAuth logic
|
||||||
|
authState = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||||
|
return authState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Coordinates authentication between multiple instances of the client/proxy
|
* Coordinates authentication between multiple instances of the client/proxy
|
||||||
* @param serverUrlHash The hash of the server URL
|
* @param serverUrlHash The hash of the server URL
|
||||||
|
|
|
@ -73,14 +73,21 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type for the auth initialization function
|
||||||
|
*/
|
||||||
|
export type AuthInitializer = () => Promise<{
|
||||||
|
waitForAuthCode: () => Promise<string>
|
||||||
|
skipBrowserAuth: boolean
|
||||||
|
}>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and connects to a remote server with OAuth authentication
|
* Creates and connects to a remote server with OAuth authentication
|
||||||
* @param client The client to connect with
|
* @param client The client to connect with
|
||||||
* @param serverUrl The URL of the remote server
|
* @param serverUrl The URL of the remote server
|
||||||
* @param authProvider The OAuth client provider
|
* @param authProvider The OAuth client provider
|
||||||
* @param headers Additional headers to send with the request
|
* @param headers Additional headers to send with the request
|
||||||
* @param waitForAuthCode Function to wait for the auth code
|
* @param authInitializer Function to initialize authentication when needed
|
||||||
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
|
||||||
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
|
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
|
||||||
* @param recursionReasons Set of reasons for recursive calls (internal use)
|
* @param recursionReasons Set of reasons for recursive calls (internal use)
|
||||||
* @returns The connected transport
|
* @returns The connected transport
|
||||||
|
@ -90,8 +97,7 @@ export async function connectToRemoteServer(
|
||||||
serverUrl: string,
|
serverUrl: string,
|
||||||
authProvider: OAuthClientProvider,
|
authProvider: OAuthClientProvider,
|
||||||
headers: Record<string, string>,
|
headers: Record<string, string>,
|
||||||
waitForAuthCode: () => Promise<string>,
|
authInitializer: AuthInitializer,
|
||||||
skipBrowserAuth: boolean = false,
|
|
||||||
transportStrategy: TransportStrategy = 'http-first',
|
transportStrategy: TransportStrategy = 'http-first',
|
||||||
recursionReasons: Set<string> = new Set(),
|
recursionReasons: Set<string> = new Set(),
|
||||||
): Promise<Transport> {
|
): Promise<Transport> {
|
||||||
|
@ -143,7 +149,6 @@ export async function connectToRemoteServer(
|
||||||
|
|
||||||
return transport
|
return transport
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('DID I CATCH OR WHAT?')
|
|
||||||
// Check if it's a protocol error and we should attempt fallback
|
// Check if it's a protocol error and we should attempt fallback
|
||||||
if (
|
if (
|
||||||
error instanceof Error &&
|
error instanceof Error &&
|
||||||
|
@ -172,12 +177,16 @@ export async function connectToRemoteServer(
|
||||||
serverUrl,
|
serverUrl,
|
||||||
authProvider,
|
authProvider,
|
||||||
headers,
|
headers,
|
||||||
waitForAuthCode,
|
authInitializer,
|
||||||
skipBrowserAuth,
|
|
||||||
sseTransport ? 'http-only' : 'sse-only',
|
sseTransport ? 'http-only' : 'sse-only',
|
||||||
recursionReasons,
|
recursionReasons,
|
||||||
)
|
)
|
||||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||||
|
log('Authentication required. Initializing auth...')
|
||||||
|
|
||||||
|
// Initialize authentication on-demand
|
||||||
|
const { waitForAuthCode, skipBrowserAuth } = await authInitializer()
|
||||||
|
|
||||||
if (skipBrowserAuth) {
|
if (skipBrowserAuth) {
|
||||||
log('Authentication required but skipping browser auth - using shared auth')
|
log('Authentication required but skipping browser auth - using shared auth')
|
||||||
} else {
|
} else {
|
||||||
|
@ -207,8 +216,7 @@ export async function connectToRemoteServer(
|
||||||
serverUrl,
|
serverUrl,
|
||||||
authProvider,
|
authProvider,
|
||||||
headers,
|
headers,
|
||||||
waitForAuthCode,
|
authInitializer,
|
||||||
skipBrowserAuth,
|
|
||||||
transportStrategy,
|
transportStrategy,
|
||||||
recursionReasons,
|
recursionReasons,
|
||||||
)
|
)
|
||||||
|
|
38
src/proxy.ts
38
src/proxy.ts
|
@ -21,7 +21,7 @@ import {
|
||||||
MCP_REMOTE_VERSION,
|
MCP_REMOTE_VERSION,
|
||||||
} from './lib/utils'
|
} from './lib/utils'
|
||||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||||
import { coordinateAuth } from './lib/coordination'
|
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -34,8 +34,8 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
||||||
// Get the server URL hash for lockfile operations
|
// Get the server URL hash for lockfile operations
|
||||||
const serverUrlHash = getServerUrlHash(serverUrl)
|
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||||
|
|
||||||
// Coordinate authentication with other instances
|
// Create a lazy auth coordinator
|
||||||
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
|
||||||
|
|
||||||
// Create the OAuth client provider
|
// Create the OAuth client provider
|
||||||
const authProvider = new NodeOAuthClientProvider({
|
const authProvider = new NodeOAuthClientProvider({
|
||||||
|
@ -44,16 +44,32 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
||||||
clientName: 'MCP CLI Proxy',
|
clientName: 'MCP CLI Proxy',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Create the STDIO transport for local connections
|
||||||
|
const localTransport = new StdioServerTransport()
|
||||||
|
|
||||||
|
// Keep track of the server instance for cleanup
|
||||||
|
let server: any = null
|
||||||
|
|
||||||
|
// Define an auth initializer function
|
||||||
|
const authInitializer = async () => {
|
||||||
|
const authState = await authCoordinator.initializeAuth()
|
||||||
|
|
||||||
|
// Store server in outer scope for cleanup
|
||||||
|
server = authState.server
|
||||||
|
|
||||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||||
if (skipBrowserAuth) {
|
if (authState.skipBrowserAuth) {
|
||||||
log('Authentication was completed by another instance - will use tokens from disk')
|
log('Authentication was completed by another instance - will use tokens from disk')
|
||||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||||
// so we're slightly too early
|
// so we're slightly too early
|
||||||
await new Promise((res) => setTimeout(res, 1_000))
|
await new Promise((res) => setTimeout(res, 1_000))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the STDIO transport for local connections
|
return {
|
||||||
const localTransport = new StdioServerTransport()
|
waitForAuthCode: authState.waitForAuthCode,
|
||||||
|
skipBrowserAuth: authState.skipBrowserAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const client = new Client(
|
const client = new Client(
|
||||||
|
@ -65,8 +81,8 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
||||||
capabilities: {},
|
capabilities: {},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// Connect to remote server with authentication
|
// Connect to remote server with lazy authentication
|
||||||
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
|
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer)
|
||||||
|
|
||||||
// Set up bidirectional proxy between local and remote transports
|
// Set up bidirectional proxy between local and remote transports
|
||||||
mcpProxy({
|
mcpProxy({
|
||||||
|
@ -84,8 +100,11 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
||||||
const cleanup = async () => {
|
const cleanup = async () => {
|
||||||
await remoteTransport.close()
|
await remoteTransport.close()
|
||||||
await localTransport.close()
|
await localTransport.close()
|
||||||
|
// Only close the server if it was initialized
|
||||||
|
if (server) {
|
||||||
server.close()
|
server.close()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
setupSignalHandlers(cleanup)
|
setupSignalHandlers(cleanup)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log('Fatal error:', error)
|
log('Fatal error:', error)
|
||||||
|
@ -111,7 +130,10 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
|
||||||
}
|
}
|
||||||
`)
|
`)
|
||||||
}
|
}
|
||||||
|
// Only close the server if it was initialized
|
||||||
|
if (server) {
|
||||||
server.close()
|
server.close()
|
||||||
|
}
|
||||||
process.exit(1)
|
process.exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue