most of the implementation looking ok but sharing the token the wrong way
This commit is contained in:
parent
743b6b207f
commit
412b5d9486
6 changed files with 364 additions and 62 deletions
|
@ -4,6 +4,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
|||
import { OAuthCallbackServerOptions } from './types'
|
||||
import express from 'express'
|
||||
import net from 'net'
|
||||
import crypto from 'crypto'
|
||||
|
||||
// Package version from package.json
|
||||
export const MCP_REMOTE_VERSION = require('../../package.json').version
|
||||
|
||||
const pid = process.pid
|
||||
export function log(str: string, ...rest: unknown[]) {
|
||||
|
@ -65,12 +69,14 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
|||
* @param serverUrl The URL of the remote server
|
||||
* @param authProvider The OAuth client provider
|
||||
* @param waitForAuthCode Function to wait for the auth code
|
||||
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
||||
* @returns The connected SSE client transport
|
||||
*/
|
||||
export async function connectToRemoteServer(
|
||||
serverUrl: string,
|
||||
authProvider: OAuthClientProvider,
|
||||
waitForAuthCode: () => Promise<string>,
|
||||
skipBrowserAuth: boolean = false,
|
||||
): Promise<SSEClientTransport> {
|
||||
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
||||
const url = new URL(serverUrl)
|
||||
|
@ -82,7 +88,11 @@ export async function connectToRemoteServer(
|
|||
return transport
|
||||
} catch (error) {
|
||||
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||
log('Authentication required. Waiting for authorization...')
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication required but skipping browser auth - using shared auth')
|
||||
} else {
|
||||
log('Authentication required. Waiting for authorization...')
|
||||
}
|
||||
|
||||
// Wait for the authorization code from the callback
|
||||
const code = await waitForAuthCode()
|
||||
|
@ -112,10 +122,56 @@ export async function connectToRemoteServer(
|
|||
* @param options The server options
|
||||
* @returns An object with the server, authCode, and waitForAuthCode function
|
||||
*/
|
||||
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
||||
export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
|
||||
let authCode: string | null = null
|
||||
const app = express()
|
||||
|
||||
// Create a promise to track when auth is completed
|
||||
let authCompletedResolve: (code: string) => void
|
||||
const authCompletedPromise = new Promise<string>((resolve) => {
|
||||
authCompletedResolve = resolve
|
||||
})
|
||||
|
||||
// Long-polling endpoint
|
||||
app.get('/wait-for-auth', (req, res) => {
|
||||
if (authCode) {
|
||||
// Auth already completed
|
||||
log('Auth already completed, returning immediately')
|
||||
res.status(200).send(authCode)
|
||||
return
|
||||
}
|
||||
|
||||
if (req.query.poll === 'false') {
|
||||
log('Client requested no long poll, responding with 202')
|
||||
res.status(202).send('Authentication in progress')
|
||||
return
|
||||
}
|
||||
|
||||
// Long poll - wait for up to 30 seconds
|
||||
const longPollTimeout = setTimeout(() => {
|
||||
log('Long poll timeout reached, responding with 202')
|
||||
res.status(202).send('Authentication in progress')
|
||||
}, 30000)
|
||||
|
||||
// If auth completes while we're waiting, send the response immediately
|
||||
authCompletedPromise
|
||||
.then((code) => {
|
||||
clearTimeout(longPollTimeout)
|
||||
if (!res.headersSent) {
|
||||
log('Auth completed during long poll, responding with 200')
|
||||
res.status(200).send(code)
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
clearTimeout(longPollTimeout)
|
||||
if (!res.headersSent) {
|
||||
log('Auth failed during long poll, responding with 500')
|
||||
res.status(500).send('Authentication failed')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// OAuth callback endpoint
|
||||
app.get(options.path, (req, res) => {
|
||||
const code = req.query.code as string | undefined
|
||||
if (!code) {
|
||||
|
@ -124,6 +180,9 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
|||
}
|
||||
|
||||
authCode = code
|
||||
log('Auth code received, resolving promise')
|
||||
authCompletedResolve(code)
|
||||
|
||||
res.send('Authorization successful! You may close this window and return to the CLI.')
|
||||
|
||||
// Notify main flow that auth code is available
|
||||
|
@ -134,10 +193,6 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
|||
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
|
||||
})
|
||||
|
||||
/**
|
||||
* Waits for the OAuth authorization code
|
||||
* @returns A promise that resolves with the authorization code
|
||||
*/
|
||||
const waitForAuthCode = (): Promise<string> => {
|
||||
return new Promise((resolve) => {
|
||||
if (authCode) {
|
||||
|
@ -151,6 +206,16 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
|||
})
|
||||
}
|
||||
|
||||
return { server, authCode, waitForAuthCode, authCompletedPromise }
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up an Express server to handle OAuth callbacks
|
||||
* @param options The server options
|
||||
* @returns An object with the server, authCode, and waitForAuthCode function
|
||||
*/
|
||||
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
|
||||
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
|
||||
return { server, authCode, waitForAuthCode }
|
||||
}
|
||||
|
||||
|
@ -248,4 +313,11 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
|
|||
process.stdin.resume()
|
||||
}
|
||||
|
||||
export const MCP_REMOTE_VERSION = require('../../package.json').version
|
||||
/**
|
||||
* Generates a hash for the server URL to use in filenames
|
||||
* @param serverUrl The server URL to hash
|
||||
* @returns The hashed server URL
|
||||
*/
|
||||
export function getServerUrlHash(serverUrl: string): string {
|
||||
return crypto.createHash('md5').update(serverUrl).digest('hex')
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue