adding polling fallback for tokens

This commit is contained in:
Glen Maddern 2025-03-25 16:11:01 +11:00
parent c4779d082a
commit 397fcd3e1b

View file

@ -129,37 +129,37 @@ class BrowserOAuthClientProvider {
* @returns The number of items cleared * @returns The number of items cleared
*/ */
clearStorage(): number { clearStorage(): number {
const prefix = `${this.storageKeyPrefix}_${this.serverUrlHash}`; const prefix = `${this.storageKeyPrefix}_${this.serverUrlHash}`
const keysToRemove = []; const keysToRemove = []
// Find all keys that match the prefix // Find all keys that match the prefix
for (let i = 0; i < localStorage.length; i++) { for (let i = 0; i < localStorage.length; i++) {
const key = localStorage.key(i); const key = localStorage.key(i)
if (key && key.startsWith(prefix)) { if (key && key.startsWith(prefix)) {
keysToRemove.push(key); keysToRemove.push(key)
} }
} }
// Also check for any state keys // Also check for any state keys
for (let i = 0; i < localStorage.length; i++) { for (let i = 0; i < localStorage.length; i++) {
const key = localStorage.key(i); const key = localStorage.key(i)
if (key && key.startsWith(`${this.storageKeyPrefix}:state_`)) { if (key && key.startsWith(`${this.storageKeyPrefix}:state_`)) {
// Load state to check if it's for this server // Load state to check if it's for this server
try { try {
const state = JSON.parse(localStorage.getItem(key) || '{}'); const state = JSON.parse(localStorage.getItem(key) || '{}')
if (state.serverUrlHash === this.serverUrlHash) { if (state.serverUrlHash === this.serverUrlHash) {
keysToRemove.push(key); keysToRemove.push(key)
} }
} catch (e) { } catch (e) {
// Ignore JSON parse errors // Ignore JSON parse errors
} }
} }
} }
// Remove all matching keys // Remove all matching keys
keysToRemove.forEach(key => localStorage.removeItem(key)); keysToRemove.forEach((key) => localStorage.removeItem(key))
return keysToRemove.length; return keysToRemove.length
} }
private hashString(str: string): string { private hashString(str: string): string {
@ -173,7 +173,7 @@ class BrowserOAuthClientProvider {
return Math.abs(hash).toString(16) return Math.abs(hash).toString(16)
} }
private getKey(key: string): string { getKey(key: string): string {
return `${this.storageKeyPrefix}_${this.serverUrlHash}_${key}` return `${this.storageKeyPrefix}_${this.serverUrlHash}_${key}`
} }
@ -443,9 +443,12 @@ export function useMcp(options: UseMcpOptions): UseMcpResult {
// Set up listener for post-auth message // Set up listener for post-auth message
const authPromise = new Promise<string>((resolve, reject) => { const authPromise = new Promise<string>((resolve, reject) => {
let pollIntervalId: number | undefined
const timeoutId = setTimeout( const timeoutId = setTimeout(
() => { () => {
window.removeEventListener('message', messageHandler) window.removeEventListener('message', messageHandler)
if (pollIntervalId) clearTimeout(pollIntervalId)
reject(new Error('Authentication timeout after 5 minutes')) reject(new Error('Authentication timeout after 5 minutes'))
}, },
5 * 60 * 1000, 5 * 60 * 1000,
@ -458,15 +461,48 @@ export function useMcp(options: UseMcpOptions): UseMcpResult {
if (event.data && event.data.type === 'mcp_auth_callback' && event.data.code) { if (event.data && event.data.type === 'mcp_auth_callback' && event.data.code) {
window.removeEventListener('message', messageHandler) window.removeEventListener('message', messageHandler)
clearTimeout(timeoutId) clearTimeout(timeoutId)
if (pollIntervalId) clearTimeout(pollIntervalId)
// TODO: not this, obviously // TODO: not this, obviously
// reload window, we should find the token in local storage resolve(event.data.code)
window.location.reload()
// resolve(event.data.code)
} }
} }
window.addEventListener('message', messageHandler) window.addEventListener('message', messageHandler)
// Add polling fallback to check for tokens in localStorage
const pollForTokens = () => {
try {
// Check if tokens have appeared in localStorage
const tokensKey = authProviderRef.current!.getKey('tokens')
const storedTokens = localStorage.getItem(tokensKey)
if (storedTokens) {
// Tokens found, clean up and resolve
window.removeEventListener('message', messageHandler)
clearTimeout(timeoutId)
if (pollIntervalId) clearTimeout(pollIntervalId)
// Parse tokens to make sure they're valid
const tokens = JSON.parse(storedTokens)
if (tokens.access_token) {
addLog('info', 'Found tokens in localStorage via polling')
resolve(tokens.access_token)
}
}
} catch (err) {
// Error during polling, continue anyway
console.error(err)
}
}
// Start polling every 500ms using setTimeout for recursive polling
const poll = () => {
pollForTokens()
pollIntervalId = setTimeout(poll, 500) as unknown as number
}
poll() // Start the polling
}) })
// Redirect to authorization // Redirect to authorization
@ -767,21 +803,21 @@ export function useMcp(options: UseMcpOptions): UseMcpResult {
// Clear all localStorage items for this server // Clear all localStorage items for this server
const clearStorage = useCallback(() => { const clearStorage = useCallback(() => {
if (!authProviderRef.current) { if (!authProviderRef.current) {
addLog('warn', 'Cannot clear storage: auth provider not initialized'); addLog('warn', 'Cannot clear storage: auth provider not initialized')
return; return
} }
// Use the provider's method to clear storage // Use the provider's method to clear storage
const clearedCount = authProviderRef.current.clearStorage(); const clearedCount = authProviderRef.current.clearStorage()
// Clear auth-related state in the hook // Clear auth-related state in the hook
authUrlRef.current = undefined; authUrlRef.current = undefined
setAuthUrl(undefined); setAuthUrl(undefined)
metadataRef.current = undefined; metadataRef.current = undefined
codeVerifierRef.current = undefined; codeVerifierRef.current = undefined
addLog('info', `Cleared ${clearedCount} storage items for server`); addLog('info', `Cleared ${clearedCount} storage items for server`)
}, [addLog]); }, [addLog])
return { return {
state, state,