most of the implementation looking ok but sharing the token the wrong way

This commit is contained in:
Glen Maddern 2025-03-31 19:38:52 +11:00
parent 743b6b207f
commit 412b5d9486
6 changed files with 364 additions and 62 deletions

View file

@ -4,6 +4,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { OAuthCallbackServerOptions } from './types'
import express from 'express'
import net from 'net'
import crypto from 'crypto'
// Package version from package.json
export const MCP_REMOTE_VERSION = require('../../package.json').version
const pid = process.pid
export function log(str: string, ...rest: unknown[]) {
@ -65,12 +69,14 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
* @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider
* @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport
*/
export async function connectToRemoteServer(
serverUrl: string,
authProvider: OAuthClientProvider,
waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl)
@ -82,7 +88,11 @@ export async function connectToRemoteServer(
return transport
} catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
log('Authentication required. Waiting for authorization...')
if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth')
} else {
log('Authentication required. Waiting for authorization...')
}
// Wait for the authorization code from the callback
const code = await waitForAuthCode()
@ -112,10 +122,56 @@ export async function connectToRemoteServer(
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
let authCode: string | null = null
const app = express()
// Create a promise to track when auth is completed
let authCompletedResolve: (code: string) => void
const authCompletedPromise = new Promise<string>((resolve) => {
authCompletedResolve = resolve
})
// Long-polling endpoint
app.get('/wait-for-auth', (req, res) => {
if (authCode) {
// Auth already completed
log('Auth already completed, returning immediately')
res.status(200).send(authCode)
return
}
if (req.query.poll === 'false') {
log('Client requested no long poll, responding with 202')
res.status(202).send('Authentication in progress')
return
}
// Long poll - wait for up to 30 seconds
const longPollTimeout = setTimeout(() => {
log('Long poll timeout reached, responding with 202')
res.status(202).send('Authentication in progress')
}, 30000)
// If auth completes while we're waiting, send the response immediately
authCompletedPromise
.then((code) => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth completed during long poll, responding with 200')
res.status(200).send(code)
}
})
.catch(() => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth failed during long poll, responding with 500')
res.status(500).send('Authentication failed')
}
})
})
// OAuth callback endpoint
app.get(options.path, (req, res) => {
const code = req.query.code as string | undefined
if (!code) {
@ -124,6 +180,9 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
}
authCode = code
log('Auth code received, resolving promise')
authCompletedResolve(code)
res.send('Authorization successful! You may close this window and return to the CLI.')
// Notify main flow that auth code is available
@ -134,10 +193,6 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
})
/**
* Waits for the OAuth authorization code
* @returns A promise that resolves with the authorization code
*/
const waitForAuthCode = (): Promise<string> => {
return new Promise((resolve) => {
if (authCode) {
@ -151,6 +206,16 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
})
}
return { server, authCode, waitForAuthCode, authCompletedPromise }
}
/**
* Sets up an Express server to handle OAuth callbacks
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
return { server, authCode, waitForAuthCode }
}
@ -248,4 +313,11 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
process.stdin.resume()
}
export const MCP_REMOTE_VERSION = require('../../package.json').version
/**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}