Refactor auth coordination to be lazy
- Create lazy auth coordinator that only initializes when needed - Modify connectToRemoteServer to only use auth when receiving an Unauthorized error - Update client.ts and proxy.ts to use the lazy auth approach - Add refactoring plan documentation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
14109a309f
commit
a6e6d0f1e8
5 changed files with 169 additions and 41 deletions
40
refactoring-plan.md
Normal file
40
refactoring-plan.md
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Auth Coordination Refactoring Plan
|
||||
|
||||
Currently, both `src/proxy.ts` and `src/client.ts` always run auth coordination before attempting to connect to the server. However, in some cases authentication is not required, and we already have the ability to catch and handle authorization errors in the `connectToRemoteServer` function.
|
||||
|
||||
The plan is to refactor the code so that auth coordination is only invoked when we actually receive an "Unauthorized" error, rather than preemptively setting up auth for all connections.
|
||||
|
||||
## Tasks
|
||||
|
||||
1. [x] **Create a lazy auth coordinator**: Modify `coordinateAuth` function to support lazy initialization, so we can set it up but only use it when needed.
|
||||
- Added `createLazyAuthCoordinator` function that returns an object with `initializeAuth` method
|
||||
- Kept original `coordinateAuth` function intact for backward compatibility
|
||||
|
||||
2. [x] **Refactor `connectToRemoteServer`**: Update this function to handle auth lazily:
|
||||
- Removed the `waitForAuthCode` and `skipBrowserAuth` parameters
|
||||
- Added a new `authInitializer` parameter that initializes auth when needed
|
||||
- Only call this initializer when we encounter an "Unauthorized" error
|
||||
- Created a new type `AuthInitializer` to define the expected interface
|
||||
|
||||
3. [x] **Update client.ts**: Refactor the client to use the new lazy auth approach.
|
||||
- No longer calling `coordinateAuth` at the beginning
|
||||
- Created function to initiate auth only when needed
|
||||
- Pass this function to `connectToRemoteServer`
|
||||
- Added proper handling of server cleanup
|
||||
|
||||
4. [x] **Update proxy.ts**: Similarly refactor the proxy to use the lazy auth approach.
|
||||
- No longer calling `coordinateAuth` at the beginning
|
||||
- Created function to initiate auth only when needed
|
||||
- Pass this function to `connectToRemoteServer`
|
||||
- Added proper handling of server cleanup
|
||||
|
||||
5. [ ] **Test both flows**:
|
||||
- Test with servers requiring authentication
|
||||
- Test with servers that don't require authentication
|
||||
|
||||
## Benefits
|
||||
|
||||
- Improved efficiency by avoiding unnecessary auth setup when not needed
|
||||
- Faster startup for connections that don't require auth
|
||||
- Cleaner separation of concerns
|
||||
- Reduced complexity in the call flow
|
|
@ -22,7 +22,7 @@ import {
|
|||
connectToRemoteServer,
|
||||
TransportStrategy,
|
||||
} from './lib/utils'
|
||||
import { coordinateAuth } from './lib/coordination'
|
||||
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||
|
||||
/**
|
||||
* Main function to run the client
|
||||
|
@ -39,8 +39,8 @@ async function runClient(
|
|||
// Get the server URL hash for lockfile operations
|
||||
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||
|
||||
// Coordinate authentication with other instances
|
||||
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||
// Create a lazy auth coordinator
|
||||
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
|
||||
|
||||
// Create the OAuth client provider
|
||||
const authProvider = new NodeOAuthClientProvider({
|
||||
|
@ -49,14 +49,6 @@ async function runClient(
|
|||
clientName: 'MCP CLI Client',
|
||||
})
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk...')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
// Create the client
|
||||
const client = new Client(
|
||||
{
|
||||
|
@ -68,15 +60,38 @@ async function runClient(
|
|||
},
|
||||
)
|
||||
|
||||
// Keep track of the server instance for cleanup
|
||||
let server: any = null
|
||||
|
||||
// Define an auth initializer function
|
||||
const authInitializer = async () => {
|
||||
const authState = await authCoordinator.initializeAuth()
|
||||
|
||||
// Store server in outer scope for cleanup
|
||||
server = authState.server
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (authState.skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk...')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
return {
|
||||
waitForAuthCode: authState.waitForAuthCode,
|
||||
skipBrowserAuth: authState.skipBrowserAuth
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// Connect to remote server with authentication
|
||||
// Connect to remote server with lazy authentication
|
||||
const transport = await connectToRemoteServer(
|
||||
client,
|
||||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
authInitializer,
|
||||
transportStrategy,
|
||||
)
|
||||
|
||||
|
@ -98,7 +113,10 @@ async function runClient(
|
|||
const cleanup = async () => {
|
||||
log('\nClosing connection...')
|
||||
await client.close()
|
||||
server.close()
|
||||
// If auth was initialized and server was created, close it
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
}
|
||||
setupSignalHandlers(cleanup)
|
||||
|
||||
|
@ -124,11 +142,17 @@ async function runClient(
|
|||
|
||||
// log('Listening for messages. Press Ctrl+C to exit.')
|
||||
log('Exiting OK...')
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
process.exit(0)
|
||||
} catch (error) {
|
||||
log('Fatal error:', error)
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,10 @@ import express from 'express'
|
|||
import { AddressInfo } from 'net'
|
||||
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
|
||||
|
||||
export type AuthCoordinator = {
|
||||
initializeAuth: () => Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }>
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a process with the given PID is running
|
||||
* @param pid The process ID to check
|
||||
|
@ -88,6 +92,36 @@ export async function waitForAuthentication(port: number): Promise<boolean> {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a lazy auth coordinator that will only initiate auth when needed
|
||||
* @param serverUrlHash The hash of the server URL
|
||||
* @param callbackPort The port to use for the callback server
|
||||
* @param events The event emitter to use for signaling
|
||||
* @returns An AuthCoordinator object with an initializeAuth method
|
||||
*/
|
||||
export function createLazyAuthCoordinator(
|
||||
serverUrlHash: string,
|
||||
callbackPort: number,
|
||||
events: EventEmitter
|
||||
): AuthCoordinator {
|
||||
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean } | null = null
|
||||
|
||||
return {
|
||||
initializeAuth: async () => {
|
||||
// If auth has already been initialized, return the existing state
|
||||
if (authState) {
|
||||
return authState
|
||||
}
|
||||
|
||||
log('Initializing auth coordination on-demand')
|
||||
|
||||
// Initialize auth using the existing coordinateAuth logic
|
||||
authState = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||
return authState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Coordinates authentication between multiple instances of the client/proxy
|
||||
* @param serverUrlHash The hash of the server URL
|
||||
|
|
|
@ -73,14 +73,21 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Type for the auth initialization function
|
||||
*/
|
||||
export type AuthInitializer = () => Promise<{
|
||||
waitForAuthCode: () => Promise<string>
|
||||
skipBrowserAuth: boolean
|
||||
}>
|
||||
|
||||
/**
|
||||
* Creates and connects to a remote server with OAuth authentication
|
||||
* @param client The client to connect with
|
||||
* @param serverUrl The URL of the remote server
|
||||
* @param authProvider The OAuth client provider
|
||||
* @param headers Additional headers to send with the request
|
||||
* @param waitForAuthCode Function to wait for the auth code
|
||||
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
||||
* @param authInitializer Function to initialize authentication when needed
|
||||
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
|
||||
* @param recursionReasons Set of reasons for recursive calls (internal use)
|
||||
* @returns The connected transport
|
||||
|
@ -90,8 +97,7 @@ export async function connectToRemoteServer(
|
|||
serverUrl: string,
|
||||
authProvider: OAuthClientProvider,
|
||||
headers: Record<string, string>,
|
||||
waitForAuthCode: () => Promise<string>,
|
||||
skipBrowserAuth: boolean = false,
|
||||
authInitializer: AuthInitializer,
|
||||
transportStrategy: TransportStrategy = 'http-first',
|
||||
recursionReasons: Set<string> = new Set(),
|
||||
): Promise<Transport> {
|
||||
|
@ -143,7 +149,6 @@ export async function connectToRemoteServer(
|
|||
|
||||
return transport
|
||||
} catch (error) {
|
||||
console.log('DID I CATCH OR WHAT?')
|
||||
// Check if it's a protocol error and we should attempt fallback
|
||||
if (
|
||||
error instanceof Error &&
|
||||
|
@ -172,12 +177,16 @@ export async function connectToRemoteServer(
|
|||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
authInitializer,
|
||||
sseTransport ? 'http-only' : 'sse-only',
|
||||
recursionReasons,
|
||||
)
|
||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||
log('Authentication required. Initializing auth...')
|
||||
|
||||
// Initialize authentication on-demand
|
||||
const { waitForAuthCode, skipBrowserAuth } = await authInitializer()
|
||||
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication required but skipping browser auth - using shared auth')
|
||||
} else {
|
||||
|
@ -207,8 +216,7 @@ export async function connectToRemoteServer(
|
|||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
authInitializer,
|
||||
transportStrategy,
|
||||
recursionReasons,
|
||||
)
|
||||
|
|
52
src/proxy.ts
52
src/proxy.ts
|
@ -21,7 +21,7 @@ import {
|
|||
MCP_REMOTE_VERSION,
|
||||
} from './lib/utils'
|
||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||
import { coordinateAuth } from './lib/coordination'
|
||||
import { createLazyAuthCoordinator } from './lib/coordination'
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
|
||||
/**
|
||||
|
@ -34,8 +34,8 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
// Get the server URL hash for lockfile operations
|
||||
const serverUrlHash = getServerUrlHash(serverUrl)
|
||||
|
||||
// Coordinate authentication with other instances
|
||||
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
|
||||
// Create a lazy auth coordinator
|
||||
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
|
||||
|
||||
// Create the OAuth client provider
|
||||
const authProvider = new NodeOAuthClientProvider({
|
||||
|
@ -44,17 +44,33 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
clientName: 'MCP CLI Proxy',
|
||||
})
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
// Create the STDIO transport for local connections
|
||||
const localTransport = new StdioServerTransport()
|
||||
|
||||
// Keep track of the server instance for cleanup
|
||||
let server: any = null
|
||||
|
||||
// Define an auth initializer function
|
||||
const authInitializer = async () => {
|
||||
const authState = await authCoordinator.initializeAuth()
|
||||
|
||||
// Store server in outer scope for cleanup
|
||||
server = authState.server
|
||||
|
||||
// If auth was completed by another instance, just log that we'll use the auth from disk
|
||||
if (authState.skipBrowserAuth) {
|
||||
log('Authentication was completed by another instance - will use tokens from disk')
|
||||
// TODO: remove, the callback is happening before the tokens are exchanged
|
||||
// so we're slightly too early
|
||||
await new Promise((res) => setTimeout(res, 1_000))
|
||||
}
|
||||
|
||||
return {
|
||||
waitForAuthCode: authState.waitForAuthCode,
|
||||
skipBrowserAuth: authState.skipBrowserAuth
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const client = new Client(
|
||||
{
|
||||
|
@ -65,8 +81,8 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
capabilities: {},
|
||||
},
|
||||
)
|
||||
// Connect to remote server with authentication
|
||||
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
|
||||
// Connect to remote server with lazy authentication
|
||||
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer)
|
||||
|
||||
// Set up bidirectional proxy between local and remote transports
|
||||
mcpProxy({
|
||||
|
@ -84,7 +100,10 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
|||
const cleanup = async () => {
|
||||
await remoteTransport.close()
|
||||
await localTransport.close()
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
}
|
||||
setupSignalHandlers(cleanup)
|
||||
} catch (error) {
|
||||
|
@ -111,7 +130,10 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
|
|||
}
|
||||
`)
|
||||
}
|
||||
server.close()
|
||||
// Only close the server if it was initialized
|
||||
if (server) {
|
||||
server.close()
|
||||
}
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue