diff --git a/src/react/index.ts b/src/react/index.ts index fa72917..d1062f4 100644 --- a/src/react/index.ts +++ b/src/react/index.ts @@ -2,7 +2,12 @@ import { CallToolResultSchema, JSONRPCMessage, ListToolsResultSchema, Tool } fro import { useCallback, useEffect, useRef, useState } from 'react' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js' -import { discoverOAuthMetadata, exchangeAuthorization, startAuthorization } from '@modelcontextprotocol/sdk/client/auth.js' +import { + OAuthClientProvider, + discoverOAuthMetadata, + exchangeAuthorization, + startAuthorization, +} from '@modelcontextprotocol/sdk/client/auth.js' import { OAuthClientInformation, OAuthMetadata, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth.js' function assert(condition: unknown, message: string): asserts condition { @@ -86,7 +91,7 @@ type StoredState = { /** * Browser-compatible OAuth client provider for MCP */ -class BrowserOAuthClientProvider { +class BrowserOAuthClientProvider implements OAuthClientProvider { private storageKeyPrefix: string private serverUrlHash: string private clientName: string @@ -288,156 +293,401 @@ class BrowserOAuthClientProvider { } /** - * useMcp is a React hook that connects to a remote MCP server, negotiates auth - * (including opening a popup window or new tab to complete the OAuth flow), - * and enables passing a list of tools (once loaded) to ai-sdk (using `useChat`). + * Class to encapsulate all MCP client functionality, + * including authentication flow and connection management */ -export function useMcp(options: UseMcpOptions): UseMcpResult { - const [state, setState] = useState('discovering') - const [tools, setTools] = useState([]) - const [error, setError] = useState(undefined) - const [log, setLog] = useState([]) - const [authUrl, setAuthUrl] = useState(undefined) +class McpClient { + // State + private _state: UseMcpResult['state'] = 'discovering' + private _error?: string + private _tools: Tool[] = [] + private _log: UseMcpResult['log'] = [] + private _authUrl?: string - const clientRef = useRef(null) - const transportRef = useRef(null) - const authProviderRef = useRef(null) - const metadataRef = useRef(undefined) - const authUrlRef = useRef(undefined) - const codeVerifierRef = useRef(undefined) - const connectingRef = useRef(false) - const isInitialMount = useRef(true) + // Client and transport + private client: Client | null = null + private transport: SSEClientTransport | null = null + private authProvider: BrowserOAuthClientProvider | null = null - // Set up default options - const { - url, - clientName = 'MCP React Client', - clientUri = window.location.origin, - callbackUrl = new URL('/oauth/callback', window.location.origin).toString(), - storageKeyPrefix = 'mcp:auth', - clientConfig = { - name: 'mcp-react-client', - version: '0.1.0', + // Authentication state + private metadata?: OAuthMetadata + private authUrlRef?: URL + private codeVerifier?: string + private connecting = false + + // Update callbacks + private onStateChange: (state: UseMcpResult['state']) => void + private onToolsChange: (tools: Tool[]) => void + private onErrorChange: (error?: string) => void + private onLogChange: (log: UseMcpResult['log']) => void + private onAuthUrlChange: (authUrl?: string) => void + + constructor( + private url: string, + private options: { + clientName?: string + clientUri?: string + callbackUrl?: string + storageKeyPrefix?: string + clientConfig?: { + name?: string + version?: string + } + debug?: boolean + autoRetry?: boolean | number + autoReconnect?: boolean | number + popupFeatures?: string }, - debug = false, - autoRetry = false, - autoReconnect = 3000, - popupFeatures = 'width=600,height=700,resizable=yes,scrollbars=yes', - } = options - - // Add to log - const addLog = useCallback( - (level: 'debug' | 'info' | 'warn' | 'error', message: string) => { - if (level === 'debug' && !debug) return - setLog((prevLog) => [...prevLog, { level, message }]) + callbacks: { + onStateChange: (state: UseMcpResult['state']) => void + onToolsChange: (tools: Tool[]) => void + onErrorChange: (error?: string) => void + onLogChange: (log: UseMcpResult['log']) => void + onAuthUrlChange: (authUrl?: string) => void }, - [debug], - ) + ) { + // Initialize callbacks + this.onStateChange = callbacks.onStateChange + this.onToolsChange = callbacks.onToolsChange + this.onErrorChange = callbacks.onErrorChange + this.onLogChange = callbacks.onLogChange + this.onAuthUrlChange = callbacks.onAuthUrlChange - // Call a tool on the MCP server - const callTool = useCallback( - async (name: string, args?: Record) => { - if (!clientRef.current || state !== 'ready') { - throw new Error('MCP client not ready') + // Initialize auth provider + this.initAuthProvider() + } + + get state(): UseMcpResult['state'] { + return this._state + } + + get tools(): Tool[] { + return this._tools + } + + get error(): string | undefined { + return this._error + } + + get log(): UseMcpResult['log'] { + return this._log + } + + get authUrl(): string | undefined { + return this._authUrl + } + + /** + * Initialize the auth provider + */ + private initAuthProvider(): void { + if (!this.authProvider) { + this.authProvider = new BrowserOAuthClientProvider(this.url, { + storageKeyPrefix: this.options.storageKeyPrefix, + clientName: this.options.clientName, + clientUri: this.options.clientUri, + callbackUrl: this.options.callbackUrl, + }) + } + } + + /** + * Add a log entry + */ + private addLog(level: 'debug' | 'info' | 'warn' | 'error', message: string): void { + if (level === 'debug' && !this.options.debug) return + this._log = [...this._log, { level, message }] + this.onLogChange(this._log) + } + + /** + * Update the state + */ + private setState(state: UseMcpResult['state']): void { + this._state = state + this.onStateChange(state) + } + + /** + * Update the error + */ + private setError(error?: string): void { + this._error = error + this.onErrorChange(error) + } + + /** + * Update the tools + */ + private setTools(tools: Tool[]): void { + this._tools = tools + this.onToolsChange(tools) + } + + /** + * Update the auth URL + */ + private setAuthUrl(authUrl?: string): void { + this._authUrl = authUrl + this.onAuthUrlChange(authUrl) + } + + /** + * Handle OAuth discovery and authentication + */ + private async discoverOAuthAndAuthenticate(error: Error): Promise { + try { + // Discover OAuth metadata now that we know we need it + if (!this.metadata) { + this.addLog('info', 'Discovering OAuth metadata...') + this.metadata = await discoverOAuthMetadata(this.url) + this.addLog('debug', `OAuth metadata: ${this.metadata ? 'Found' : 'Not available'}`) } - try { - console.log('CALLING TOOL') - const result = await clientRef.current.request( - { - method: 'tools/call', - params: { name, arguments: args }, + // If metadata is found, start auth flow + if (this.metadata) { + this.setState('authenticating') + + try { + // Start authentication process + await this.handleAuthentication() + + // After successful auth, retry connection + // Important: We need to fully disconnect and reconnect + await this.disconnect() + await this.connect() + } catch (authErr) { + this.addLog('error', `Authentication error: ${authErr instanceof Error ? authErr.message : String(authErr)}`) + this.setState('failed') + this.setError(`Authentication failed: ${authErr instanceof Error ? authErr.message : String(authErr)}`) + this.connecting = false + } + } else { + // No OAuth metadata available + this.setState('failed') + this.setError(`Authentication required but no OAuth metadata found: ${error.message}`) + this.connecting = false + } + } catch (oauthErr) { + this.addLog('error', `OAuth discovery error: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) + this.setState('failed') + this.setError(`Authentication setup failed: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) + this.connecting = false + } + } + + /** + * Connect to the MCP server + */ + async connect(): Promise { + // Prevent multiple simultaneous connection attempts + if (this.connecting) return + this.connecting = true + + try { + this.setState('discovering') + this.setError(undefined) + + // Create MCP client + this.client = new Client( + { + name: this.options.clientConfig?.name || 'mcp-react-client', + version: this.options.clientConfig?.version || '0.1.0', + }, + { + capabilities: { + sampling: {}, }, - CallToolResultSchema, - ) - return result - } catch (err) { - addLog('error', `Error calling tool ${name}: ${err instanceof Error ? err.message : String(err)}`) - throw err - } - }, - [state, addLog], - ) + }, + ) - // Disconnect from the MCP server - const disconnect = useCallback(async () => { - if (clientRef.current) { - try { - await clientRef.current.close() - } catch (err) { - addLog('error', `Error closing client: ${err instanceof Error ? err.message : String(err)}`) + // Create SSE transport + this.setState('connecting') + this.addLog('info', 'Creating transport...') + + const serverUrl = new URL(this.url) + this.transport = new SSEClientTransport(serverUrl, { + authProvider: this.authProvider, + }) + + // Set up transport handlers + this.transport.onmessage = (message: JSONRPCMessage) => { + // @ts-expect-error TODO: fix this type + this.addLog('debug', `Received message: ${message.method || message.id}`) } - clientRef.current = null + + this.transport.onerror = (err: Error) => { + this.addLog('error', `Transport error: ${err.message}`) + + if (err.message.includes('Unauthorized')) { + // Only discover OAuth metadata and authenticate if we get a 401 + this.discoverOAuthAndAuthenticate(err) + } else { + this.setState('failed') + this.setError(`Connection error: ${err.message}`) + this.connecting = false + } + } + + this.transport.onclose = () => { + this.addLog('info', 'Connection closed') + // If we were previously connected, try to reconnect + if (this.state === 'ready' && this.options.autoReconnect) { + const delay = typeof this.options.autoReconnect === 'number' ? this.options.autoReconnect : 3000 + this.addLog('info', `Will reconnect in ${delay}ms...`) + setTimeout(() => { + this.disconnect().then(() => this.connect()) + }, delay) + } + } + + // Try connecting transport + try { + this.addLog('info', 'Starting transport...') + // await this.transport.start() + } catch (err) { + this.addLog('error', `Transport start error: ${err instanceof Error ? err.message : String(err)}`) + + if (err instanceof Error && err.message.includes('Unauthorized')) { + // Only discover OAuth and authenticate if we get a 401 + await this.discoverOAuthAndAuthenticate(err) + return // Important: Return here to avoid proceeding with the unauthorized connection + } else { + this.setState('failed') + this.setError(`Connection error: ${err instanceof Error ? err.message : String(err)}`) + this.connecting = false + return + } + } + + // Connect client + try { + this.addLog('info', 'Connecting client...') + this.setState('loading') + await this.client.connect(this.transport) + this.addLog('info', 'Client connected') + + // Load tools + try { + this.addLog('info', 'Loading tools...') + const toolsResponse = await this.client.request({ method: 'tools/list' }, ListToolsResultSchema) + this.setTools(toolsResponse.tools) + this.addLog('info', `Loaded ${toolsResponse.tools.length} tools`) + + // Connection completed successfully + this.setState('ready') + this.connecting = false + } catch (toolErr) { + this.addLog('error', `Error loading tools: ${toolErr instanceof Error ? toolErr.message : String(toolErr)}`) + // We're still connected, just couldn't load tools + this.setState('ready') + this.connecting = false + } + } catch (connectErr) { + this.addLog('error', `Client connect error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) + + if (connectErr instanceof Error && connectErr.message.includes('Unauthorized')) { + // Only discover OAuth and authenticate if we get a 401 + await this.discoverOAuthAndAuthenticate(connectErr) + } else { + this.setState('failed') + this.setError(`Connection error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) + this.connecting = false + } + } + } catch (err) { + this.addLog('error', `Unexpected error: ${err instanceof Error ? err.message : String(err)}`) + this.setState('failed') + this.setError(`Unexpected error: ${err instanceof Error ? err.message : String(err)}`) + this.connecting = false + } + } + + /** + * Disconnect from the MCP server + */ + async disconnect(): Promise { + if (this.client) { + try { + await this.client.close() + } catch (err) { + this.addLog('error', `Error closing client: ${err instanceof Error ? err.message : String(err)}`) + } + this.client = null } - if (transportRef.current) { + if (this.transport) { try { - await transportRef.current.close() + await this.transport.close() } catch (err) { - addLog('error', `Error closing transport: ${err instanceof Error ? err.message : String(err)}`) + this.addLog('error', `Error closing transport: ${err instanceof Error ? err.message : String(err)}`) } - transportRef.current = null + this.transport = null } - connectingRef.current = false - setState('discovering') - setTools([]) - setError(undefined) - }, [addLog]) + this.connecting = false + this.setState('discovering') + this.setTools([]) + this.setError(undefined) + } - // Start the auth flow and get the auth URL - const startAuthFlow = useCallback(async (): Promise => { - if (!authProviderRef.current || !metadataRef.current) { + /** + * Start the auth flow and get the auth URL + */ + async startAuthFlow(): Promise { + if (!this.authProvider || !this.metadata) { throw new Error('Auth provider or metadata not available') } - addLog('info', 'Starting authentication flow...') + this.addLog('info', 'Starting authentication flow...') // Check if we have client info - let clientInfo = await authProviderRef.current.clientInformation() + let clientInfo = await this.authProvider.clientInformation() if (!clientInfo) { // Register client dynamically - addLog('info', 'No client information found, registering...') + this.addLog('info', 'No client information found, registering...') // Note: In a complete implementation, you'd register the client here // This would be done server-side in a real application throw new Error('Dynamic client registration not implemented in this example') } // Start authorization flow - addLog('info', 'Preparing authorization...') - const { authorizationUrl, codeVerifier } = await startAuthorization(url, { - metadata: metadataRef.current, + this.addLog('info', 'Preparing authorization...') + const { authorizationUrl, codeVerifier } = await startAuthorization(this.url, { + metadata: this.metadata, clientInformation: clientInfo, - redirectUrl: authProviderRef.current.redirectUrl, + redirectUrl: this.authProvider.redirectUrl, }) // Save code verifier and auth URL for later use - await authProviderRef.current.saveCodeVerifier(codeVerifier) - codeVerifierRef.current = codeVerifier - authUrlRef.current = authorizationUrl - setAuthUrl(authorizationUrl.toString()) + await this.authProvider.saveCodeVerifier(codeVerifier) + this.codeVerifier = codeVerifier + this.authUrlRef = authorizationUrl + this.setAuthUrl(authorizationUrl.toString()) return authorizationUrl - }, [url, addLog]) + } - // Handle authentication flow - const handleAuthentication = useCallback(async () => { - if (!authProviderRef.current) { + /** + * Handle authentication flow + */ + async handleAuthentication(): Promise { + if (!this.authProvider) { throw new Error('Auth provider not available') } // Get or create the auth URL - if (!authUrlRef.current) { + if (!this.authUrlRef) { try { - await startAuthFlow() + await this.startAuthFlow() } catch (err) { - addLog('error', `Failed to start auth flow: ${err instanceof Error ? err.message : String(err)}`) + this.addLog('error', `Failed to start auth flow: ${err instanceof Error ? err.message : String(err)}`) throw err } } - if (!authUrlRef.current) { + if (!this.authUrlRef) { throw new Error('Failed to create authorization URL') } @@ -463,7 +713,6 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { clearTimeout(timeoutId) if (pollIntervalId) clearTimeout(pollIntervalId) - // TODO: not this, obviously resolve(event.data.code) } } @@ -474,7 +723,7 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { const pollForTokens = () => { try { // Check if tokens have appeared in localStorage - const tokensKey = authProviderRef.current!.getKey('tokens') + const tokensKey = this.authProvider!.getKey('tokens') const storedTokens = localStorage.getItem(tokensKey) if (storedTokens) { @@ -486,7 +735,7 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { // Parse tokens to make sure they're valid const tokens = JSON.parse(storedTokens) if (tokens.access_token) { - addLog('info', 'Found tokens in localStorage via polling') + this.addLog('info', 'Found tokens in localStorage via polling') resolve(tokens.access_token) } } @@ -506,263 +755,217 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { }) // Redirect to authorization - addLog('info', 'Opening authorization window...') - assert(metadataRef.current, 'Metadata not available') - const redirectResult = await authProviderRef.current.redirectToAuthorization(authUrlRef.current, metadataRef.current, { - popupFeatures, + this.addLog('info', 'Opening authorization window...') + assert(this.metadata, 'Metadata not available') + const redirectResult = await this.authProvider.redirectToAuthorization(this.authUrlRef, this.metadata, { + popupFeatures: this.options.popupFeatures, }) if (!redirectResult.success) { // Popup was blocked - setState('failed') - setError('Authentication popup was blocked by the browser. Please click the link to authenticate in a new window.') - setAuthUrl(redirectResult.url) - addLog('warn', 'Authentication popup was blocked. User needs to manually authorize.') + this.setState('failed') + this.setError('Authentication popup was blocked by the browser. Please click the link to authenticate in a new window.') + this.setAuthUrl(redirectResult.url) + this.addLog('warn', 'Authentication popup was blocked. User needs to manually authorize.') throw new Error('Authentication popup blocked') } // Wait for auth to complete - addLog('info', 'Waiting for authorization...') + this.addLog('info', 'Waiting for authorization...') const code = await authPromise - addLog('info', 'Authorization code received') + this.addLog('info', 'Authorization code received') return code - }, [url, addLog, popupFeatures, startAuthFlow]) + } - // Initialize connection to MCP server - const connect = useCallback(async () => { - // Prevent multiple simultaneous connection attempts - if (connectingRef.current) return - connectingRef.current = true + /** + * Handle authentication completion + */ + async handleAuthCompletion(code: string): Promise { + if (!this.authProvider || !this.transport) { + throw new Error('Authentication context not available') + } try { - setState('discovering') - setError(undefined) + this.addLog('info', 'Finishing authorization...') + await this.transport.finishAuth(code) + this.addLog('info', 'Authorization completed') - // Create auth provider if not already created - if (!authProviderRef.current) { - authProviderRef.current = new BrowserOAuthClientProvider(url, { - storageKeyPrefix, - clientName, - clientUri, - callbackUrl, - }) - } + // Reset auth URL state + this.authUrlRef = undefined + this.setAuthUrl(undefined) - // Create MCP client - clientRef.current = new Client( - { - name: clientConfig.name || 'mcp-react-client', - version: clientConfig.version || '0.1.0', - }, - { - capabilities: { - sampling: {}, - }, - }, - ) - - // Create SSE transport - try connecting without auth first - setState('connecting') - addLog('info', 'Creating transport...') - - const serverUrl = new URL(url) - transportRef.current = new SSEClientTransport(serverUrl, { - // @ts-expect-error TODO: fix this type, expect BrowserOAuthClientProvider - authProvider: authProviderRef.current, - }) - - // Set up transport handlers - transportRef.current.onmessage = (message: JSONRPCMessage) => { - // @ts-expect-error TODO: fix this type - addLog('debug', `Received message: ${message.method || message.id}`) - } - - transportRef.current.onerror = (err: Error) => { - addLog('error', `Transport error: ${err.message}`) - - if (err.message.includes('Unauthorized')) { - // Only discover OAuth metadata and authenticate if we get a 401 - discoverOAuthAndAuthenticate(err) - } else { - setState('failed') - setError(`Connection error: ${err.message}`) - connectingRef.current = false - } - } - - transportRef.current.onclose = () => { - addLog('info', 'Connection closed') - // If we were previously connected, try to reconnect - if (state === 'ready' && autoReconnect) { - const delay = typeof autoReconnect === 'number' ? autoReconnect : 3000 - addLog('info', `Will reconnect in ${delay}ms...`) - setTimeout(() => { - disconnect().then(() => connect()) - }, delay) - } - } - - // Helper function to handle OAuth discovery and authentication - const discoverOAuthAndAuthenticate = async (error: Error) => { - try { - // Discover OAuth metadata now that we know we need it - if (!metadataRef.current) { - addLog('info', 'Discovering OAuth metadata...') - metadataRef.current = await discoverOAuthMetadata(url) - addLog('debug', `OAuth metadata: ${metadataRef.current ? 'Found' : 'Not available'}`) - } - - // If metadata is found, start auth flow - if (metadataRef.current) { - setState('authenticating') - // Start authentication process - await handleAuthentication() - // After successful auth, retry connection - return connect() - } else { - // No OAuth metadata available - setState('failed') - setError(`Authentication required but no OAuth metadata found: ${error.message}`) - connectingRef.current = false - } - } catch (oauthErr) { - addLog('error', `OAuth discovery error: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) - setState('failed') - setError(`Authentication setup failed: ${oauthErr instanceof Error ? oauthErr.message : String(oauthErr)}`) - connectingRef.current = false - } - } - - // Try connecting transport first without OAuth discovery - try { - addLog('info', 'Starting transport...') - // await transportRef.current.start() - } catch (err) { - addLog('error', `Transport start error: ${err instanceof Error ? err.message : String(err)}`) - - if (err instanceof Error && err.message.includes('Unauthorized')) { - // Only discover OAuth and authenticate if we get a 401 - await discoverOAuthAndAuthenticate(err) - } else { - setState('failed') - setError(`Connection error: ${err instanceof Error ? err.message : String(err)}`) - connectingRef.current = false - return - } - } - - // Connect client - try { - addLog('info', 'Connecting client...') - setState('loading') - await clientRef.current.connect(transportRef.current) - addLog('info', 'Client connected') - - // Load tools - try { - addLog('info', 'Loading tools...') - const toolsResponse = await clientRef.current.request({ method: 'tools/list' }, ListToolsResultSchema) - setTools(toolsResponse.tools) - addLog('info', `Loaded ${toolsResponse.tools.length} tools`) - - // Connection completed successfully - setState('ready') - connectingRef.current = false - } catch (toolErr) { - addLog('error', `Error loading tools: ${toolErr instanceof Error ? toolErr.message : String(toolErr)}`) - // We're still connected, just couldn't load tools - setState('ready') - connectingRef.current = false - } - } catch (connectErr) { - addLog('error', `Client connect error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) - - if (connectErr instanceof Error && connectErr.message.includes('Unauthorized')) { - // Only discover OAuth and authenticate if we get a 401 - await discoverOAuthAndAuthenticate(connectErr) - } else { - setState('failed') - setError(`Connection error: ${connectErr instanceof Error ? connectErr.message : String(connectErr)}`) - connectingRef.current = false - } - } + // Reconnect with the new auth token - important to do a full disconnect/connect cycle + await this.disconnect() + await this.connect() } catch (err) { - addLog('error', `Unexpected error: ${err instanceof Error ? err.message : String(err)}`) - setState('failed') - setError(`Unexpected error: ${err instanceof Error ? err.message : String(err)}`) - connectingRef.current = false + this.addLog('error', `Auth completion error: ${err instanceof Error ? err.message : String(err)}`) + this.setState('failed') + this.setError(`Authentication failed: ${err instanceof Error ? err.message : String(err)}`) } - }, [ - url, - clientName, - clientUri, - callbackUrl, - storageKeyPrefix, - clientConfig, - debug, - autoReconnect, - addLog, - handleAuthentication, - disconnect, - ]) + } - // Provide public authenticate method - const authenticate = useCallback(async (): Promise => { - if (!authUrlRef.current) { - await startAuthFlow() + /** + * Call a tool on the MCP server + */ + async callTool(name: string, args?: Record): Promise { + if (!this.client || this.state !== 'ready') { + throw new Error('MCP client not ready') } - if (authUrlRef.current) { - return authUrlRef.current.toString() + try { + const result = await this.client.request( + { + method: 'tools/call', + params: { name, arguments: args }, + }, + CallToolResultSchema, + ) + return result + } catch (err) { + this.addLog('error', `Error calling tool ${name}: ${err instanceof Error ? err.message : String(err)}`) + throw err + } + } + + /** + * Retry connection + */ + retry(): void { + if (this.state === 'failed') { + this.disconnect().then(() => this.connect()) + } + } + + /** + * Manually trigger authentication + */ + async authenticate(): Promise { + if (!this.authUrlRef) { + await this.startAuthFlow() + } + + if (this.authUrlRef) { + return this.authUrlRef.toString() } return undefined - }, []) + } - // Handle auth completion - this is called when we receive a message from the popup - const handleAuthCompletion = useCallback( - async (code: string) => { - if (!authProviderRef.current || !transportRef.current) { - throw new Error('Authentication context not available') - } - - try { - addLog('info', 'Finishing authorization...') - await transportRef.current.finishAuth(code) - addLog('info', 'Authorization completed') - - // Reset auth URL state - authUrlRef.current = undefined - setAuthUrl(undefined) - - // Reconnect with the new auth token - await disconnect() - connect() - } catch (err) { - addLog('error', `Auth completion error: ${err instanceof Error ? err.message : String(err)}`) - setState('failed') - setError(`Authentication failed: ${err instanceof Error ? err.message : String(err)}`) - } - }, - [addLog, disconnect, connect], - ) - - // Retry connection - const retry = useCallback(() => { - if (state === 'failed') { - disconnect().then(() => connect()) + /** + * Clear all localStorage items for this server + */ + clearStorage(): number { + if (!this.authProvider) { + this.addLog('warn', 'Cannot clear storage: auth provider not initialized') + return 0 } - }, [state, disconnect, connect]) + + // Use the provider's method to clear storage + const clearedCount = this.authProvider.clearStorage() + + // Clear auth-related state in the class + this.authUrlRef = undefined + this.setAuthUrl(undefined) + this.metadata = undefined + this.codeVerifier = undefined + + this.addLog('info', `Cleared ${clearedCount} storage items for server`) + + return clearedCount + } +} + +/** + * useMcp is a React hook that connects to a remote MCP server, negotiates auth + * (including opening a popup window or new tab to complete the OAuth flow), + * and enables passing a list of tools (once loaded) to ai-sdk (using `useChat`). + */ +export function useMcp(options: UseMcpOptions): UseMcpResult { + const [state, setState] = useState('discovering') + const [tools, setTools] = useState([]) + const [error, setError] = useState(undefined) + const [log, setLog] = useState([]) + const [authUrl, setAuthUrl] = useState(undefined) + + // Use a ref to maintain a single instance of the McpClient + const clientRef = useRef(null) + const isInitialMount = useRef(true) + + // Initialize the client if it doesn't exist yet + const getClient = useCallback(() => { + if (!clientRef.current) { + clientRef.current = new McpClient( + options.url, + { + clientName: options.clientName || 'MCP React Client', + clientUri: options.clientUri || window.location.origin, + callbackUrl: options.callbackUrl || new URL('/oauth/callback', window.location.origin).toString(), + storageKeyPrefix: options.storageKeyPrefix || 'mcp:auth', + clientConfig: options.clientConfig || { + name: 'mcp-react-client', + version: '0.1.0', + }, + debug: options.debug || false, + autoRetry: options.autoRetry || false, + autoReconnect: options.autoReconnect || 3000, + popupFeatures: options.popupFeatures || 'width=600,height=700,resizable=yes,scrollbars=yes', + }, + { + onStateChange: setState, + onToolsChange: setTools, + onErrorChange: setError, + onLogChange: setLog, + onAuthUrlChange: setAuthUrl, + }, + ) + } + return clientRef.current + }, [ + options.url, + options.clientName, + options.clientUri, + options.callbackUrl, + options.storageKeyPrefix, + options.clientConfig, + options.debug, + options.autoRetry, + options.autoReconnect, + options.popupFeatures, + ]) + + // Connect on initial mount + useEffect(() => { + if (isInitialMount.current) { + isInitialMount.current = false + const client = getClient() + client.connect() + } + }, [getClient]) + + // Auto-retry on failure + useEffect(() => { + if (state === 'failed' && options.autoRetry) { + const delay = typeof options.autoRetry === 'number' ? options.autoRetry : 5000 + const timeoutId = setTimeout(() => { + const client = getClient() + client.retry() + }, delay) + + return () => { + clearTimeout(timeoutId) + } + } + }, [state, options.autoRetry, getClient]) // Set up message listener for auth callback useEffect(() => { const messageHandler = (event: MessageEvent) => { - // Verify origin for security if (event.origin !== window.location.origin) return if (event.data && event.data.type === 'mcp_auth_callback' && event.data.code) { - handleAuthCompletion(event.data.code).catch((err) => { - addLog('error', `Auth callback error: ${err.message}`) + const client = getClient() + client.handleAuthCompletion(event.data.code).catch((err) => { + console.error('Auth callback error:', err) }) } } @@ -771,53 +974,45 @@ export function useMcp(options: UseMcpOptions): UseMcpResult { return () => { window.removeEventListener('message', messageHandler) } - }, [handleAuthCompletion, addLog]) - - // Initial connection and auto-retry - useEffect(() => { - if (isInitialMount.current) { - isInitialMount.current = false - connect() - } else if (state === 'failed' && autoRetry) { - const delay = typeof autoRetry === 'number' ? autoRetry : 5000 - const timeoutId = setTimeout(() => { - addLog('info', 'Auto-retrying connection...') - disconnect().then(() => connect()) - }, delay) - - return () => { - clearTimeout(timeoutId) - } - } - }, [state, autoRetry, connect, disconnect, addLog]) + }, [getClient]) // Clean up on unmount useEffect(() => { return () => { - if (clientRef.current || transportRef.current) { - disconnect() + if (clientRef.current) { + clientRef.current.disconnect() } } - }, [disconnect]) + }, []) + + // Public methods - proxied to the client + const callTool = useCallback( + async (name: string, args?: Record) => { + const client = getClient() + return client.callTool(name, args) + }, + [getClient], + ) + + const retry = useCallback(() => { + const client = getClient() + client.retry() + }, [getClient]) + + const disconnect = useCallback(async () => { + const client = getClient() + await client.disconnect() + }, [getClient]) + + const authenticate = useCallback(async (): Promise => { + const client = getClient() + return client.authenticate() + }, [getClient]) - // Clear all localStorage items for this server const clearStorage = useCallback(() => { - if (!authProviderRef.current) { - addLog('warn', 'Cannot clear storage: auth provider not initialized') - return - } - - // Use the provider's method to clear storage - const clearedCount = authProviderRef.current.clearStorage() - - // Clear auth-related state in the hook - authUrlRef.current = undefined - setAuthUrl(undefined) - metadataRef.current = undefined - codeVerifierRef.current = undefined - - addLog('info', `Cleared ${clearedCount} storage items for server`) - }, [addLog]) + const client = getClient() + client.clearStorage() + }, [getClient]) return { state, @@ -864,8 +1059,6 @@ export async function onMcpAuthorization( } // Find the matching auth state in localStorage - // const storageKeys = Object.keys(localStorage).filter((key) => key.includes('_auth_state') && localStorage.getItem(key) === state) - const stateKey = `${storageKeyPrefix}:state_${state}` const storedState = localStorage.getItem(stateKey) console.log({ stateKey, storedState }) @@ -913,6 +1106,7 @@ export async function onMcpAuthorization( window.opener.postMessage( { type: 'mcp_auth_callback', + code: code, }, window.location.origin, )