diff --git a/package.json b/package.json index 5ba0f68..e9f0be6 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "mcp-remote", - "version": "0.0.9", + "version": "0.0.10", "type": "module", "bin": { "mcp-remote": "dist/cli/proxy.js" diff --git a/src/cli/shared.ts b/src/cli/shared.ts index a4101aa..b146abd 100644 --- a/src/cli/shared.ts +++ b/src/cli/shared.ts @@ -44,7 +44,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { } get redirectUrl(): string { - return `http://localhost:${this.options.callbackPort}${this.callbackPath}` + return `http://127.0.0.1:${this.options.callbackPort}${this.callbackPath}` } get clientMetadata() { @@ -229,7 +229,7 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) { }) const server = app.listen(options.port, () => { - console.error(`OAuth callback server running at http://localhost:${options.port}`) + console.error(`OAuth callback server running at http://127.0.0.1:${options.port}`) }) /** @@ -299,7 +299,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, } const url = new URL(serverUrl) - const isLocalhost = url.hostname === 'localhost' && url.protocol === 'http:' + const isLocalhost = (url.hostname === 'localhost' || url.hostname === '127.0.0.1') && url.protocol === 'http:' if (!(url.protocol == 'https:' || isLocalhost)) { console.error(usage) diff --git a/src/react/index.ts b/src/react/index.ts index 40948c2..43f8a4f 100644 --- a/src/react/index.ts +++ b/src/react/index.ts @@ -2,7 +2,12 @@ import { CallToolResultSchema, JSONRPCMessage, ListToolsResultSchema, Tool } fro import { useCallback, useEffect, useRef, useState } from 'react' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js' -import { discoverOAuthMetadata, exchangeAuthorization, startAuthorization } from '@modelcontextprotocol/sdk/client/auth.js' +import { + OAuthClientProvider, + discoverOAuthMetadata, + exchangeAuthorization, + startAuthorization, +} from '@modelcontextprotocol/sdk/client/auth.js' import { OAuthClientInformation, OAuthMetadata, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js' function assert(condition: unknown, message: string): asserts condition { @@ -70,6 +75,10 @@ export type UseMcpResult = { * @returns Auth URL that can be used to manually open a new window */ authenticate: () => Promise + /** + * Clear all localStorage items for this server + */ + clearStorage: () => void } type StoredState = { @@ -82,12 +91,14 @@ type StoredState = { /** * Browser-compatible OAuth client provider for MCP */ -class BrowserOAuthClientProvider { +class BrowserOAuthClientProvider implements OAuthClientProvider { private storageKeyPrefix: string - private serverUrlHash: string + serverUrlHash: string private clientName: string private clientUri: string private callbackUrl: string + // Store additional options for popup windows + private popupFeatures: string constructor( readonly serverUrl: string, @@ -96,6 +107,7 @@ class BrowserOAuthClientProvider { clientName?: string clientUri?: string callbackUrl?: string + popupFeatures?: string } = {}, ) { this.storageKeyPrefix = options.storageKeyPrefix || 'mcp:auth' @@ -103,6 +115,7 @@ class BrowserOAuthClientProvider { this.clientName = options.clientName || 'MCP Browser Client' this.clientUri = options.clientUri || window.location.origin this.callbackUrl = options.callbackUrl || new URL('/oauth/callback', window.location.origin).toString() + this.popupFeatures = options.popupFeatures || 'width=600,height=700,resizable=yes,scrollbars=yes' } get redirectUrl(): string { @@ -120,6 +133,44 @@ class BrowserOAuthClientProvider { } } + /** + * Clears all storage items related to this server + * @returns The number of items cleared + */ + clearStorage(): number { + const prefix = `${this.storageKeyPrefix}_${this.serverUrlHash}` + const keysToRemove = [] + + // Find all keys that match the prefix + for (let i = 0; i < localStorage.length; i++) { + const key = localStorage.key(i) + if (key && key.startsWith(prefix)) { + keysToRemove.push(key) + } + } + + // Also check for any state keys + for (let i = 0; i < localStorage.length; i++) { + const key = localStorage.key(i) + if (key && key.startsWith(`${this.storageKeyPrefix}:state_`)) { + // Load state to check if it's for this server + try { + const state = JSON.parse(localStorage.getItem(key) || '{}') + if (state.serverUrlHash === this.serverUrlHash) { + keysToRemove.push(key) + } + } catch (e) { + // Ignore JSON parse errors + } + } + } + + // Remove all matching keys + keysToRemove.forEach((key) => localStorage.removeItem(key)) + + return keysToRemove.length + } + private hashString(str: string): string { // Simple hash function for browser environments let hash = 0 @@ -131,7 +182,7 @@ class BrowserOAuthClientProvider { return Math.abs(hash).toString(16) } - private getKey(key: string): string { + getKey(key: string): string { return `${this.storageKeyPrefix}_${this.serverUrlHash}_${key}` } @@ -169,36 +220,52 @@ class BrowserOAuthClientProvider { localStorage.setItem(key, JSON.stringify(tokens)) } - async redirectToAuthorization( + /** + * Redirect method that matches the interface expected by OAuthClientProvider + */ + async redirectToAuthorization(authorizationUrl: URL): Promise { + // Simply open the URL in the current window + console.log('WE WERE ABOUT TO REDIRECT BUT WE DONT DO THAT HERE') + // window.location.href = authorizationUrl.toString() + } + + /** + * Extended popup-based authorization method specific to browser environments + */ + async openAuthorizationPopup( authorizationUrl: URL, metadata: OAuthMetadata, - options?: { - popupFeatures?: string - }, ): Promise<{ success: boolean; popupBlocked?: boolean; url: string }> { - // Store the auth state for the popup flow - const state = Math.random().toString(36).substring(2) - const stateKey = `${this.storageKeyPrefix}:state_${state}` - localStorage.setItem( - stateKey, - JSON.stringify({ - authorizationUrl: authorizationUrl.toString(), - metadata, - serverUrlHash: this.serverUrlHash, - expiry: +new Date() + 1000 * 60 * 5 /* 5 minutes */, - } as StoredState), - ) - authorizationUrl.searchParams.set('state', state) + // Use existing state parameter if it exists in the URL + const existingState = authorizationUrl.searchParams.get('state') + + if (!existingState) { + // This should not happen as startAuthFlow should've added state + // But if it doesn't exist, add it as a fallback + const state = Math.random().toString(36).substring(2) + const stateKey = `${this.storageKeyPrefix}:state_${state}` + + localStorage.setItem( + stateKey, + JSON.stringify({ + authorizationUrl: authorizationUrl.toString(), + metadata, + serverUrlHash: this.serverUrlHash, + expiry: +new Date() + 1000 * 60 * 5 /* 5 minutes */, + } as StoredState), + ) + + authorizationUrl.searchParams.set('state', state) + } const authUrl = authorizationUrl.toString() - const popupFeatures = options?.popupFeatures || 'width=600,height=700,resizable=yes,scrollbars=yes' // Store the auth URL in case we need it for manual authentication localStorage.setItem(this.getKey('auth_url'), authUrl) try { // Open the authorization URL in a popup window - const popup = window.open(authUrl, 'mcp_auth', popupFeatures) + const popup = window.open(authUrl, 'mcp_auth', this.popupFeatures) // Check if popup was blocked or closed immediately if (!popup || popup.closed || popup.closed === undefined) { @@ -246,164 +313,436 @@ class BrowserOAuthClientProvider { } /** - * useMcp is a React hook that connects to a remote MCP server, negotiates auth - * (including opening a popup window or new tab to complete the OAuth flow), - * and enables passing a list of tools (once loaded) to ai-sdk (using `useChat`). + * Class to encapsulate all MCP client functionality, + * including authentication flow and connection management */ -export function useMcp(options: UseMcpOptions): UseMcpResult { - const [state, setState] = useState('discovering') - const [tools, setTools] = useState([]) - const [error, setError] = useState(undefined) - const [log, setLog] = useState([]) - const [authUrl, setAuthUrl] = useState(undefined) +class McpClient { + // State + private _state: UseMcpResult['state'] = 'discovering' + private _error?: string + private _tools: Tool[] = [] + private _log: UseMcpResult['log'] = [] + private _authUrl?: string - const clientRef = useRef(null) - const transportRef = useRef(null) - const authProviderRef = useRef(null) - const metadataRef = useRef(undefined) - const authUrlRef = useRef(undefined) - const codeVerifierRef = useRef(undefined) - const connectingRef = useRef(false) - const isInitialMount = useRef(true) + // Client and transport + private client: Client | null = null + private transport: SSEClientTransport | null = null + private authProvider: BrowserOAuthClientProvider | undefined = undefined - // Set up default options - const { - url, - clientName = 'MCP React Client', - clientUri = window.location.origin, - callbackUrl = new URL('/oauth/callback', window.location.origin).toString(), - storageKeyPrefix = 'mcp:auth', - clientConfig = { - name: 'mcp-react-client', - version: '0.1.0', + // Authentication state + private metadata?: OAuthMetadata + private authUrlRef?: URL + private authState?: string + private codeVerifier?: string + private connecting = false + + // Update callbacks + private onStateChange: (state: UseMcpResult['state']) => void + private onToolsChange: (tools: Tool[]) => void + private onErrorChange: (error?: string) => void + private onLogChange: (log: UseMcpResult['log']) => void + private onAuthUrlChange: (authUrl?: string) => void + + constructor( + private url: string, + private options: { + clientName?: string + clientUri?: string + callbackUrl?: string + storageKeyPrefix?: string + clientConfig?: { + name?: string + version?: string + } + debug?: boolean + autoRetry?: boolean | number + autoReconnect?: boolean | number + popupFeatures?: string }, - debug = false, - autoRetry = false, - autoReconnect = 3000, - popupFeatures = 'width=600,height=700,resizable=yes,scrollbars=yes', - } = options - - // Add to log - const addLog = useCallback( - (level: 'debug' | 'info' | 'warn' | 'error', message: string) => { - if (level === 'debug' && !debug) return - setLog((prevLog) => [...prevLog, { level, message }]) + callbacks: { + onStateChange: (state: UseMcpResult['state']) => void + onToolsChange: (tools: Tool[]) => void + onErrorChange: (error?: string) => void + onLogChange: (log: UseMcpResult['log']) => void + onAuthUrlChange: (authUrl?: string) => void }, - [debug], - ) + ) { + // Initialize callbacks + this.onStateChange = callbacks.onStateChange + this.onToolsChange = callbacks.onToolsChange + this.onErrorChange = callbacks.onErrorChange + this.onLogChange = callbacks.onLogChange + this.onAuthUrlChange = callbacks.onAuthUrlChange - // Call a tool on the MCP server - const callTool = useCallback( - async (name: string, args?: Record) => { - if (!clientRef.current || state !== 'ready') { - throw new Error('MCP client not ready') + // Initialize auth provider + this.initAuthProvider() + } + + get state(): UseMcpResult['state'] { + return this._state + } + + get tools(): Tool[] { + return this._tools + } + + get error(): string | undefined { + return this._error + } + + get log(): UseMcpResult['log'] { + return this._log + } + + get authUrl(): string | undefined { + return this._authUrl + } + + /** + * Initialize the auth provider + */ + private initAuthProvider(): void { + if (!this.authProvider) { + this.authProvider = new BrowserOAuthClientProvider(this.url, { + storageKeyPrefix: this.options.storageKeyPrefix, + clientName: this.options.clientName, + clientUri: this.options.clientUri, + callbackUrl: this.options.callbackUrl, + }) + } + } + + /** + * Add a log entry + */ + private addLog(level: 'debug' | 'info' | 'warn' | 'error', message: string): void { + if (level === 'debug' && !this.options.debug) return + this._log = [...this._log, { level, message }] + this.onLogChange(this._log) + } + + /** + * Update the state + */ + private setState(state: UseMcpResult['state']): void { + this._state = state + this.onStateChange(state) + } + + /** + * Update the error + */ + private setError(error?: string): void { + this._error = error + this.onErrorChange(error) + } + + /** + * Update the tools + */ + private setTools(tools: Tool[]): void { + this._tools = tools + this.onToolsChange(tools) + } + + /** + * Update the auth URL + */ + private setAuthUrl(authUrl?: string): void { + this._authUrl = authUrl + this.onAuthUrlChange(authUrl) + } + + /** + * Handle OAuth discovery and authentication + */ + private async discoverOAuthAndAuthenticate(error: Error): Promise { + try { + // Discover OAuth metadata now that we know we need it + if (!this.metadata) { + this.addLog('info', 'Discovering OAuth metadata...') + this.metadata = await discoverOAuthMetadata(this.url) + this.addLog('debug', `OAuth metadata: ${this.metadata ? 'Found' : 'Not available'}`) } - try { - console.log('CALLING TOOL') - const result = await clientRef.current.request( - { - method: 'tools/call', - params: { name, arguments: args }, + // If metadata is found, start auth flow + if (this.metadata) { + this.setState('authenticating') + + try { + // Start authentication process + await this.handleAuthentication() + + // After successful auth, retry connection + // Important: We need to fully disconnect and reconnect + await this.disconnect() + await this.connect() + } catch (authErr) { + this.addLog('error', `Authentication error: ${authErr instanceof Error ? authErr.message : String(authErr)}`) + this.setState('failed') + this.setError(`Authentication failed: ${authErr instanceof Error ? authErr.message : String(authErr)}`) + this.connecting = false + } + } else { + // No OAuth metadata available + this.setState('failed') + this.setError(`Authentication required but no OAuth metadata found: ${error.message}`) + this.connecting = false + } + } catch (oauthErr) { + this.addLog('error', `OAuth discovery error: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) + this.setState('failed') + this.setError(`Authentication setup failed: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) + this.connecting = false + } + } + + /** + * Connect to the MCP server + */ + async connect(): Promise { + // Prevent multiple simultaneous connection attempts + if (this.connecting) return + this.connecting = true + + try { + this.setState('discovering') + this.setError(undefined) + + // Create MCP client + this.client = new Client( + { + name: this.options.clientConfig?.name || 'mcp-react-client', + version: this.options.clientConfig?.version || '0.1.0', + }, + { + capabilities: { + sampling: {}, }, - CallToolResultSchema, - ) - return result - } catch (err) { - addLog('error', `Error calling tool ${name}: ${err instanceof Error ? err.message : String(err)}`) - throw err - } - }, - [state, addLog], - ) + }, + ) - // Disconnect from the MCP server - const disconnect = useCallback(async () => { - if (clientRef.current) { - try { - await clientRef.current.close() - } catch (err) { - addLog('error', `Error closing client: ${err instanceof Error ? err.message : String(err)}`) + // Create SSE transport + this.setState('connecting') + this.addLog('info', 'Creating transport...') + + const serverUrl = new URL(this.url) + this.transport = new SSEClientTransport(serverUrl, { + authProvider: this.authProvider, + }) + + // Set up transport handlers + this.transport.onmessage = (message: JSONRPCMessage) => { + // @ts-expect-error TODO: fix this type + this.addLog('debug', `Received message: ${message.method || message.id}`) } - clientRef.current = null + + this.transport.onerror = (err: Error) => { + this.addLog('error', `Transport error: ${err.message}`) + + if (err.message.includes('Unauthorized')) { + // Only discover OAuth metadata and authenticate if we get a 401 + this.discoverOAuthAndAuthenticate(err) + } else { + this.setState('failed') + this.setError(`Connection error: ${err.message}`) + this.connecting = false + } + } + + this.transport.onclose = () => { + this.addLog('info', 'Connection closed') + // If we were previously connected, try to reconnect + if (this.state === 'ready' && this.options.autoReconnect) { + const delay = typeof this.options.autoReconnect === 'number' ? this.options.autoReconnect : 3000 + this.addLog('info', `Will reconnect in ${delay}ms...`) + setTimeout(() => { + this.disconnect().then(() => this.connect()) + }, delay) + } + } + + // Try connecting transport + try { + this.addLog('info', 'Starting transport...') + // await this.transport.start() + } catch (err) { + this.addLog('error', `Transport start error: ${err instanceof Error ? err.message : String(err)}`) + + if (err instanceof Error && err.message.includes('Unauthorized')) { + // Only discover OAuth and authenticate if we get a 401 + await this.discoverOAuthAndAuthenticate(err) + return // Important: Return here to avoid proceeding with the unauthorized connection + } else { + this.setState('failed') + this.setError(`Connection error: ${err instanceof Error ? err.message : String(err)}`) + this.connecting = false + return + } + } + + // Connect client + try { + this.addLog('info', 'Connecting client...') + this.setState('loading') + await this.client.connect(this.transport) + this.addLog('info', 'Client connected') + + // Load tools + try { + this.addLog('info', 'Loading tools...') + const toolsResponse = await this.client.request({ method: 'tools/list' }, ListToolsResultSchema) + this.setTools(toolsResponse.tools) + this.addLog('info', `Loaded ${toolsResponse.tools.length} tools`) + + // Connection completed successfully + this.setState('ready') + this.connecting = false + } catch (toolErr) { + this.addLog('error', `Error loading tools: ${toolErr instanceof Error ? toolErr.message : String(toolErr)}`) + // We're still connected, just couldn't load tools + this.setState('ready') + this.connecting = false + } + } catch (connectErr) { + this.addLog('error', `Client connect error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) + + if (connectErr instanceof Error && connectErr.message.includes('Unauthorized')) { + // Only discover OAuth and authenticate if we get a 401 + await this.discoverOAuthAndAuthenticate(connectErr) + } else { + this.setState('failed') + this.setError(`Connection error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) + this.connecting = false + } + } + } catch (err) { + this.addLog('error', `Unexpected error: ${err instanceof Error ? err.message : String(err)}`) + this.setState('failed') + this.setError(`Unexpected error: ${err instanceof Error ? err.message : String(err)}`) + this.connecting = false + } + } + + /** + * Disconnect from the MCP server + */ + async disconnect(): Promise { + if (this.client) { + try { + await this.client.close() + } catch (err) { + this.addLog('error', `Error closing client: ${err instanceof Error ? err.message : String(err)}`) + } + this.client = null } - if (transportRef.current) { + if (this.transport) { try { - await transportRef.current.close() + await this.transport.close() } catch (err) { - addLog('error', `Error closing transport: ${err instanceof Error ? err.message : String(err)}`) + this.addLog('error', `Error closing transport: ${err instanceof Error ? err.message : String(err)}`) } - transportRef.current = null + this.transport = null } - connectingRef.current = false - setState('discovering') - setTools([]) - setError(undefined) - }, [addLog]) + this.connecting = false + this.setState('discovering') + this.setTools([]) + this.setError(undefined) + } - // Start the auth flow and get the auth URL - const startAuthFlow = useCallback(async (): Promise => { - if (!authProviderRef.current || !metadataRef.current) { + /** + * Start the auth flow and get the auth URL + */ + async startAuthFlow(): Promise { + if (!this.authProvider || !this.metadata) { throw new Error('Auth provider or metadata not available') } - addLog('info', 'Starting authentication flow...') + this.addLog('info', 'Starting authentication flow...') // Check if we have client info - let clientInfo = await authProviderRef.current.clientInformation() + let clientInfo = await this.authProvider.clientInformation() if (!clientInfo) { // Register client dynamically - addLog('info', 'No client information found, registering...') + this.addLog('info', 'No client information found, registering...') // Note: In a complete implementation, you'd register the client here // This would be done server-side in a real application throw new Error('Dynamic client registration not implemented in this example') } // Start authorization flow - addLog('info', 'Preparing authorization...') - const { authorizationUrl, codeVerifier } = await startAuthorization(url, { - metadata: metadataRef.current, + this.addLog('info', 'Preparing authorization...') + const { authorizationUrl, codeVerifier } = await startAuthorization(this.url, { + metadata: this.metadata, clientInformation: clientInfo, - redirectUrl: authProviderRef.current.redirectUrl, + redirectUrl: this.authProvider.redirectUrl, }) // Save code verifier and auth URL for later use - await authProviderRef.current.saveCodeVerifier(codeVerifier) - codeVerifierRef.current = codeVerifier - authUrlRef.current = authorizationUrl - setAuthUrl(authorizationUrl.toString()) + await this.authProvider.saveCodeVerifier(codeVerifier) + this.codeVerifier = codeVerifier + + // Generate state parameter that will be used for both popup and manual flows + const state = Math.random().toString(36).substring(2) + const stateKey = `${this.options.storageKeyPrefix}:state_${state}` + + // Store state for later retrieval + localStorage.setItem( + stateKey, + JSON.stringify({ + authorizationUrl: authorizationUrl.toString(), + metadata: this.metadata, + serverUrlHash: this.authProvider.serverUrlHash, + expiry: +new Date() + 1000 * 60 * 5 /* 5 minutes */, + } as StoredState), + ) + + // Add state to the URL + authorizationUrl.searchParams.set('state', state) + + // Store the state and URL for later use + this.authState = state + this.authUrlRef = authorizationUrl + + // Set manual auth URL (already includes state parameter) + this.setAuthUrl(authorizationUrl.toString()) return authorizationUrl - }, [url, addLog]) + } - // Handle authentication flow - const handleAuthentication = useCallback(async () => { - if (!authProviderRef.current) { + /** + * Handle authentication flow + */ + async handleAuthentication(): Promise { + if (!this.authProvider) { throw new Error('Auth provider not available') } // Get or create the auth URL - if (!authUrlRef.current) { + if (!this.authUrlRef) { try { - await startAuthFlow() + await this.startAuthFlow() } catch (err) { - addLog('error', `Failed to start auth flow: ${err instanceof Error ? err.message : String(err)}`) + this.addLog('error', `Failed to start auth flow: ${err instanceof Error ? err.message : String(err)}`) throw err } } - if (!authUrlRef.current) { + if (!this.authUrlRef) { throw new Error('Failed to create authorization URL') } // Set up listener for post-auth message const authPromise = new Promise((resolve, reject) => { + let pollIntervalId: number | undefined + const timeoutId = setTimeout( () => { window.removeEventListener('message', messageHandler) + if (pollIntervalId) clearTimeout(pollIntervalId) reject(new Error('Authentication timeout after 5 minutes')) }, 5 * 60 * 1000, @@ -416,276 +755,308 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { if (event.data && event.data.type === 'mcp_auth_callback' && event.data.code) { window.removeEventListener('message', messageHandler) clearTimeout(timeoutId) + if (pollIntervalId) clearTimeout(pollIntervalId) - // TODO: not this, obviously - // reload window, we should find the token in local storage - window.location.reload() - // resolve(event.data.code) + resolve(event.data.code) } } window.addEventListener('message', messageHandler) + + // Add polling fallback to check for tokens in localStorage + const pollForTokens = () => { + try { + // Check if tokens have appeared in localStorage + const tokensKey = this.authProvider!.getKey('tokens') + const storedTokens = localStorage.getItem(tokensKey) + + if (storedTokens) { + // Tokens found, clean up and resolve + window.removeEventListener('message', messageHandler) + clearTimeout(timeoutId) + if (pollIntervalId) clearTimeout(pollIntervalId) + + // Parse tokens to make sure they're valid + const tokens = JSON.parse(storedTokens) + if (tokens.access_token) { + console.log('Found tokens in localStorage via polling') + // Resolve with an object that indicates tokens are already available + // This will signal to handleAuthCompletion that no token exchange is needed + resolve('TOKENS_ALREADY_EXCHANGED') + } + } + } catch (err) { + // Error during polling, continue anyway + console.error(err) + } + } + + // Start polling every 500ms using setTimeout for recursive polling + const poll = () => { + pollIntervalId = setTimeout(poll, 500) as unknown as number + pollForTokens() + } + + poll() // Start the polling }) // Redirect to authorization - addLog('info', 'Opening authorization window...') - assert(metadataRef.current, 'Metadata not available') - const redirectResult = await authProviderRef.current.redirectToAuthorization(authUrlRef.current, metadataRef.current, { - popupFeatures, - }) + this.addLog('info', 'Opening authorization window...') + assert(this.metadata, 'Metadata not available') + const redirectResult = await this.authProvider.openAuthorizationPopup(this.authUrlRef, this.metadata) if (!redirectResult.success) { // Popup was blocked - setState('failed') - setError('Authentication popup was blocked by the browser. Please click the link to authenticate in a new window.') - setAuthUrl(redirectResult.url) - addLog('warn', 'Authentication popup was blocked. User needs to manually authorize.') + this.setState('failed') + this.setError('Authentication popup was blocked by the browser. Please click the link to authenticate in a new window.') + this.setAuthUrl(redirectResult.url) + this.addLog('warn', 'Authentication popup was blocked. User needs to manually authorize.') throw new Error('Authentication popup blocked') } // Wait for auth to complete - addLog('info', 'Waiting for authorization...') + this.addLog('info', 'Waiting for authorization...') const code = await authPromise - addLog('info', 'Authorization code received') + this.addLog('info', 'Authorization code received') return code - }, [url, addLog, popupFeatures, startAuthFlow]) + } - // Initialize connection to MCP server - const connect = useCallback(async () => { - // Prevent multiple simultaneous connection attempts - if (connectingRef.current) return - connectingRef.current = true + /** + * Handle authentication completion + * @param code - The authorization code or special token indicator + */ + async handleAuthCompletion(code: string): Promise { + if (!this.authProvider || !this.transport) { + throw new Error('Authentication context not available') + } try { - setState('discovering') - setError(undefined) - - // Create auth provider if not already created - if (!authProviderRef.current) { - authProviderRef.current = new BrowserOAuthClientProvider(url, { - storageKeyPrefix, - clientName, - clientUri, - callbackUrl, - }) + // Check if this is our special token indicator + if (code === 'TOKENS_ALREADY_EXCHANGED') { + this.addLog('info', 'Using already exchanged tokens from localStorage') + // No need to exchange tokens, they're already in localStorage + } else { + // We received an authorization code that needs to be exchanged + this.addLog('info', 'Finishing authorization with code exchange...') + await this.transport.finishAuth(code) + this.addLog('info', 'Authorization code exchanged for tokens') } - // Create MCP client - clientRef.current = new Client( + this.addLog('info', 'Authorization completed') + + // Reset auth URL state + this.authUrlRef = undefined + this.setAuthUrl(undefined) + + // Reconnect with the new auth token - important to do a full disconnect/connect cycle + await this.disconnect() + await this.connect() + } catch (err) { + this.addLog('error', `Auth completion error: ${err instanceof Error ? err.message : String(err)}`) + this.setState('failed') + this.setError(`Authentication failed: ${err instanceof Error ? err.message : String(err)}`) + } + } + + /** + * Call a tool on the MCP server + */ + async callTool(name: string, args?: Record): Promise { + if (!this.client || this.state !== 'ready') { + throw new Error('MCP client not ready') + } + + try { + const result = await this.client.request( { - name: clientConfig.name || 'mcp-react-client', - version: clientConfig.version || '0.1.0', + method: 'tools/call', + params: { name, arguments: args }, + }, + CallToolResultSchema, + ) + return result + } catch (err) { + this.addLog('error', `Error calling tool ${name}: ${err instanceof Error ? err.message : String(err)}`) + throw err + } + } + + /** + * Retry connection + */ + retry(): void { + if (this.state === 'failed') { + this.disconnect().then(() => this.connect()) + } + } + + /** + * Manually trigger authentication + */ + async authenticate(): Promise { + if (!this.authProvider) { + try { + // Discover OAuth metadata if we don't have it yet + this.addLog('info', 'Discovering OAuth metadata...') + this.metadata = await discoverOAuthMetadata(this.url) + this.addLog('debug', `OAuth metadata: ${this.metadata ? 'Found' : 'Not available'}`) + + if (!this.metadata) { + throw new Error('No OAuth metadata available') + } + + // Initialize the auth provider now that we have metadata + this.initAuthProvider() + } catch (err) { + this.addLog('error', `Failed to discover OAuth metadata: ${err instanceof Error ? err.message : String(err)}`) + return undefined + } + } + + try { + // If we don't have an auth URL yet with state param, start a new flow + if (!this.authUrlRef || !this.authUrlRef.searchParams.get('state')) { + await this.startAuthFlow() + } + + if (!this.authUrlRef) { + throw new Error('Failed to create authorization URL') + } + + // The URL already has the state parameter from startAuthFlow + return this.authUrlRef.toString() + } catch (err) { + this.addLog('error', `Error preparing manual authentication: ${err instanceof Error ? err.message : String(err)}`) + return undefined + } + } + + /** + * Clear all localStorage items for this server + */ + clearStorage(): number { + if (!this.authProvider) { + this.addLog('warn', 'Cannot clear storage: auth provider not initialized') + return 0 + } + + // Use the provider's method to clear storage + const clearedCount = this.authProvider.clearStorage() + + // Clear auth-related state in the class + this.authUrlRef = undefined + this.setAuthUrl(undefined) + this.metadata = undefined + this.codeVerifier = undefined + + this.addLog('info', `Cleared ${clearedCount} storage items for server`) + + return clearedCount + } +} + +/** + * useMcp is a React hook that connects to a remote MCP server, negotiates auth + * (including opening a popup window or new tab to complete the OAuth flow), + * and enables passing a list of tools (once loaded) to ai-sdk (using `useChat`). + */ +export function useMcp(options: UseMcpOptions): UseMcpResult { + const [state, setState] = useState('discovering') + const [tools, setTools] = useState([]) + const [error, setError] = useState(undefined) + const [log, setLog] = useState([]) + const [authUrl, setAuthUrl] = useState(undefined) + + // Use a ref to maintain a single instance of the McpClient + const clientRef = useRef(null) + const isInitialMount = useRef(true) + + // Initialize the client if it doesn't exist yet + const getClient = useCallback(() => { + if (!clientRef.current) { + clientRef.current = new McpClient( + options.url, + { + clientName: options.clientName || 'MCP React Client', + clientUri: options.clientUri || window.location.origin, + callbackUrl: options.callbackUrl || new URL('/oauth/callback', window.location.origin).toString(), + storageKeyPrefix: options.storageKeyPrefix || 'mcp:auth', + clientConfig: options.clientConfig || { + name: 'mcp-react-client', + version: '0.1.0', + }, + debug: options.debug || false, + autoRetry: options.autoRetry || false, + autoReconnect: options.autoReconnect || 3000, + popupFeatures: options.popupFeatures || 'width=600,height=700,resizable=yes,scrollbars=yes', }, { - capabilities: { - sampling: {}, - }, + onStateChange: setState, + onToolsChange: setTools, + onErrorChange: setError, + onLogChange: setLog, + onAuthUrlChange: setAuthUrl, }, ) - - // Create SSE transport - try connecting without auth first - setState('connecting') - addLog('info', 'Creating transport...') - - const serverUrl = new URL(url) - transportRef.current = new SSEClientTransport(serverUrl, { - // @ts-expect-error TODO: fix this type, expect BrowserOAuthClientProvider - authProvider: authProviderRef.current, - }) - - // Set up transport handlers - transportRef.current.onmessage = (message: JSONRPCMessage) => { - // @ts-expect-error TODO: fix this type - addLog('debug', `Received message: ${message.method || message.id}`) - } - - transportRef.current.onerror = (err: Error) => { - addLog('error', `Transport error: ${err.message}`) - - if (err.message.includes('Unauthorized')) { - // Only discover OAuth metadata and authenticate if we get a 401 - discoverOAuthAndAuthenticate(err) - } else { - setState('failed') - setError(`Connection error: ${err.message}`) - connectingRef.current = false - } - } - - transportRef.current.onclose = () => { - addLog('info', 'Connection closed') - // If we were previously connected, try to reconnect - if (state === 'ready' && autoReconnect) { - const delay = typeof autoReconnect === 'number' ? autoReconnect : 3000 - addLog('info', `Will reconnect in ${delay}ms...`) - setTimeout(() => { - disconnect().then(() => connect()) - }, delay) - } - } - - // Helper function to handle OAuth discovery and authentication - const discoverOAuthAndAuthenticate = async (error: Error) => { - try { - // Discover OAuth metadata now that we know we need it - if (!metadataRef.current) { - addLog('info', 'Discovering OAuth metadata...') - metadataRef.current = await discoverOAuthMetadata(url) - addLog('debug', `OAuth metadata: ${metadataRef.current ? 'Found' : 'Not available'}`) - } - - // If metadata is found, start auth flow - if (metadataRef.current) { - setState('authenticating') - // Start authentication process - await handleAuthentication() - // After successful auth, retry connection - return connect() - } else { - // No OAuth metadata available - setState('failed') - setError(`Authentication required but no OAuth metadata found: ${error.message}`) - connectingRef.current = false - } - } catch (oauthErr) { - addLog('error', `OAuth discovery error: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) - setState('failed') - setError(`Authentication setup failed: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) - connectingRef.current = false - } - } - - // Try connecting transport first without OAuth discovery - try { - addLog('info', 'Starting transport...') - // await transportRef.current.start() - } catch (err) { - addLog('error', `Transport start error: ${err instanceof Error ? err.message : String(err)}`) - - if (err instanceof Error && err.message.includes('Unauthorized')) { - // Only discover OAuth and authenticate if we get a 401 - await discoverOAuthAndAuthenticate(err) - } else { - setState('failed') - setError(`Connection error: ${err instanceof Error ? err.message : String(err)}`) - connectingRef.current = false - return - } - } - - // Connect client - try { - addLog('info', 'Connecting client...') - setState('loading') - await clientRef.current.connect(transportRef.current) - addLog('info', 'Client connected') - - // Load tools - try { - addLog('info', 'Loading tools...') - const toolsResponse = await clientRef.current.request({ method: 'tools/list' }, ListToolsResultSchema) - setTools(toolsResponse.tools) - addLog('info', `Loaded ${toolsResponse.tools.length} tools`) - - // Connection completed successfully - setState('ready') - connectingRef.current = false - } catch (toolErr) { - addLog('error', `Error loading tools: ${toolErr instanceof Error ? toolErr.message : String(toolErr)}`) - // We're still connected, just couldn't load tools - setState('ready') - connectingRef.current = false - } - } catch (connectErr) { - addLog('error', `Client connect error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) - - if (connectErr instanceof Error && connectErr.message.includes('Unauthorized')) { - // Only discover OAuth and authenticate if we get a 401 - await discoverOAuthAndAuthenticate(connectErr) - } else { - setState('failed') - setError(`Connection error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) - connectingRef.current = false - } - } - } catch (err) { - addLog('error', `Unexpected error: ${err instanceof Error ? err.message : String(err)}`) - setState('failed') - setError(`Unexpected error: ${err instanceof Error ? err.message : String(err)}`) - connectingRef.current = false } + return clientRef.current }, [ - url, - clientName, - clientUri, - callbackUrl, - storageKeyPrefix, - clientConfig, - debug, - autoReconnect, - addLog, - handleAuthentication, - disconnect, + options.url, + options.clientName, + options.clientUri, + options.callbackUrl, + options.storageKeyPrefix, + options.clientConfig, + options.debug, + options.autoRetry, + options.autoReconnect, + options.popupFeatures, ]) - // Provide public authenticate method - const authenticate = useCallback(async (): Promise => { - if (!authUrlRef.current) { - await startAuthFlow() + // Connect on initial mount + useEffect(() => { + if (isInitialMount.current) { + isInitialMount.current = false + const client = getClient() + client.connect() } + }, [getClient]) - if (authUrlRef.current) { - return authUrlRef.current.toString() - } - return undefined - }, []) + // Auto-retry on failure + useEffect(() => { + if (state === 'failed' && options.autoRetry) { + const delay = typeof options.autoRetry === 'number' ? options.autoRetry : 5000 + const timeoutId = setTimeout(() => { + const client = getClient() + client.retry() + }, delay) - // Handle auth completion - this is called when we receive a message from the popup - const handleAuthCompletion = useCallback( - async (code: string) => { - if (!authProviderRef.current || !transportRef.current) { - throw new Error('Authentication context not available') + return () => { + clearTimeout(timeoutId) } - - try { - addLog('info', 'Finishing authorization...') - await transportRef.current.finishAuth(code) - addLog('info', 'Authorization completed') - - // Reset auth URL state - authUrlRef.current = undefined - setAuthUrl(undefined) - - // Reconnect with the new auth token - await disconnect() - connect() - } catch (err) { - addLog('error', `Auth completion error: ${err instanceof Error ? err.message : String(err)}`) - setState('failed') - setError(`Authentication failed: ${err instanceof Error ? err.message : String(err)}`) - } - }, - [addLog, disconnect, connect], - ) - - // Retry connection - const retry = useCallback(() => { - if (state === 'failed') { - disconnect().then(() => connect()) } - }, [state, disconnect, connect]) + }, [state, options.autoRetry, getClient]) // Set up message listener for auth callback useEffect(() => { const messageHandler = (event: MessageEvent) => { - // Verify origin for security if (event.origin !== window.location.origin) return - if (event.data && event.data.type === 'mcp_auth_callback' && event.data.code) { - handleAuthCompletion(event.data.code).catch((err) => { - addLog('error', `Auth callback error: ${err.message}`) - }) + if (event.data && event.data.type === 'mcp_auth_callback') { + const client = getClient() + + // If code is provided, use it; otherwise, assume tokens are already in localStorage + if (event.data.code) { + client.handleAuthCompletion(event.data.code).catch((err) => { + console.error('Auth callback error:', err) + }) + } else { + // Tokens were already exchanged by the popup + client.handleAuthCompletion('TOKENS_ALREADY_EXCHANGED').catch((err) => { + console.error('Auth callback error:', err) + }) + } } } @@ -693,34 +1064,45 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { return () => { window.removeEventListener('message', messageHandler) } - }, [handleAuthCompletion, addLog]) - - // Initial connection and auto-retry - useEffect(() => { - if (isInitialMount.current) { - isInitialMount.current = false - connect() - } else if (state === 'failed' && autoRetry) { - const delay = typeof autoRetry === 'number' ? autoRetry : 5000 - const timeoutId = setTimeout(() => { - addLog('info', 'Auto-retrying connection...') - disconnect().then(() => connect()) - }, delay) - - return () => { - clearTimeout(timeoutId) - } - } - }, [state, autoRetry, connect, disconnect, addLog]) + }, [getClient]) // Clean up on unmount useEffect(() => { return () => { - if (clientRef.current || transportRef.current) { - disconnect() + if (clientRef.current) { + clientRef.current.disconnect() } } - }, [disconnect]) + }, []) + + // Public methods - proxied to the client + const callTool = useCallback( + async (name: string, args?: Record) => { + const client = getClient() + return client.callTool(name, args) + }, + [getClient], + ) + + const retry = useCallback(() => { + const client = getClient() + client.retry() + }, [getClient]) + + const disconnect = useCallback(async () => { + const client = getClient() + await client.disconnect() + }, [getClient]) + + const authenticate = useCallback(async (): Promise => { + const client = getClient() + return client.authenticate() + }, [getClient]) + + const clearStorage = useCallback(() => { + const client = getClient() + client.clearStorage() + }, [getClient]) return { state, @@ -732,6 +1114,7 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { retry, disconnect, authenticate, + clearStorage, } } @@ -766,8 +1149,6 @@ export async function onMcpAuthorization( } // Find the matching auth state in localStorage - // const storageKeys = Object.keys(localStorage).filter((key) => key.includes('_auth_state') && localStorage.getItem(key) === state) - const stateKey = `${storageKeyPrefix}:state_${state}` const storedState = localStorage.getItem(stateKey) console.log({ stateKey, storedState }) @@ -815,6 +1196,8 @@ export async function onMcpAuthorization( window.opener.postMessage( { type: 'mcp_auth_callback', + // Don't send the code back since we've already done the token exchange + // This signals to the main window that tokens are already in localStorage }, window.location.origin, )