mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Intercept and modify SSE response
This commit is contained in:
parent
527e8d0293
commit
d3fcddb18a
1 changed files with 74 additions and 0 deletions
|
@ -1,10 +1,14 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -32,3 +36,73 @@ func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy
|
|||
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), timeout)
|
||||
}
|
||||
|
||||
// sseTransport is a custom http.RoundTripper that intercepts and modifies SSE responses
|
||||
type sseTransport struct {
|
||||
Transport http.RoundTripper
|
||||
proxyHost string
|
||||
targetHost string
|
||||
}
|
||||
|
||||
func (t *sseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Call the underlying transport
|
||||
resp, err := t.Transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if this is an SSE response
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !strings.Contains(contentType, "text/event-stream") {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
log.Printf("INFO: Intercepting SSE response to modify endpoint events")
|
||||
|
||||
// Create a response wrapper that modifies the response body
|
||||
originalBody := resp.Body
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
go func() {
|
||||
defer originalBody.Close()
|
||||
defer pw.Close()
|
||||
|
||||
scanner := bufio.NewScanner(originalBody)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Check if this line contains an endpoint event
|
||||
if strings.HasPrefix(line, "event: endpoint") {
|
||||
// Read the data line
|
||||
if scanner.Scan() {
|
||||
dataLine := scanner.Text()
|
||||
if strings.HasPrefix(dataLine, "data: ") {
|
||||
// Extract the endpoint URL
|
||||
endpoint := strings.TrimPrefix(dataLine, "data: ")
|
||||
|
||||
// Replace the host in the endpoint
|
||||
log.Printf("DEBUG: Original endpoint: %s", endpoint)
|
||||
endpoint = strings.Replace(endpoint, t.targetHost, t.proxyHost, 1)
|
||||
log.Printf("DEBUG: Modified endpoint: %s", endpoint)
|
||||
|
||||
// Write the modified event lines
|
||||
fmt.Fprintln(pw, line)
|
||||
fmt.Fprintln(pw, "data: "+endpoint)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the original line for non-endpoint events
|
||||
fmt.Fprintln(pw, line)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error reading SSE stream: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Replace the response body with our modified pipe
|
||||
resp.Body = pr
|
||||
return resp, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue