mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
improve readme
This commit is contained in:
parent
7b727c03a3
commit
4e957e93a2
11 changed files with 889 additions and 1 deletions
192
internal/proxy/proxy.go
Normal file
192
internal/proxy/proxy.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
// NewRouter builds an http.ServeMux that routes
|
||||
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
||||
// * MCP paths to the MCP server, etc.
|
||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 1. Custom well-known
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||
|
||||
// 2. Registration
|
||||
mux.HandleFunc("/register", provider.RegisterHandler())
|
||||
|
||||
// 3. Default "auth" paths, proxied
|
||||
defaultPaths := []string{"/authorize", "/token"}
|
||||
for _, path := range defaultPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
// 4. MCP paths
|
||||
for _, path := range cfg.MCPPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
// 5. If you want to map additional paths from config.PathMapping
|
||||
// to the same proxy logic:
|
||||
for path := range cfg.PathMapping {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||
// Parse the base URLs up front
|
||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid auth server URL: %v", err)
|
||||
}
|
||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||
}
|
||||
|
||||
// We'll define sets for known auth paths, SSE paths, etc.
|
||||
authPaths := map[string]bool{
|
||||
"/authorize": true,
|
||||
"/token": true,
|
||||
"/.well-known/oauth-authorization-server": true,
|
||||
}
|
||||
|
||||
// Detect SSE paths from config
|
||||
ssePaths := make(map[string]bool)
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if p == "/sse" {
|
||||
ssePaths[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle OPTIONS
|
||||
if r.Method == http.MethodOptions {
|
||||
addCORSHeaders(w)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
addCORSHeaders(w)
|
||||
|
||||
// Decide whether the request should go to the auth server or MCP
|
||||
var targetURL *url.URL
|
||||
isSSE := false
|
||||
|
||||
if authPaths[r.URL.Path] {
|
||||
targetURL = authBase
|
||||
} else if isMCPPath(r.URL.Path, cfg) {
|
||||
// Validate JWT if you want
|
||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
targetURL = mcpBase
|
||||
if ssePaths[r.URL.Path] {
|
||||
isSSE = true
|
||||
}
|
||||
} else {
|
||||
// If it's not recognized as an auth path or an MCP path
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the reverse proxy
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
// Path rewriting if needed
|
||||
mapped := r.URL.Path
|
||||
if rewrite, ok := cfg.PathMapping[r.URL.Path]; ok {
|
||||
mapped = rewrite
|
||||
}
|
||||
basePath := strings.TrimRight(targetURL.Path, "/")
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = basePath + mapped
|
||||
req.URL.RawQuery = r.URL.RawQuery
|
||||
req.Host = targetURL.Host
|
||||
|
||||
for header, values := range r.Header {
|
||||
// Skip hop-by-hop headers
|
||||
if strings.EqualFold(header, "Connection") ||
|
||||
strings.EqualFold(header, "Keep-Alive") ||
|
||||
strings.EqualFold(header, "Transfer-Encoding") ||
|
||||
strings.EqualFold(header, "Upgrade") ||
|
||||
strings.EqualFold(header, "Proxy-Authorization") ||
|
||||
strings.EqualFold(header, "Proxy-Connection") {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[proxy] Error proxying: %v", err)
|
||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
FlushInterval: -1, // immediate flush for SSE
|
||||
}
|
||||
|
||||
if isSSE {
|
||||
// Keep SSE connections open
|
||||
HandleSSE(w, r, rp)
|
||||
} else {
|
||||
// Standard requests: enforce a timeout
|
||||
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(cfg.TimeoutSeconds)*time.Second)
|
||||
defer cancel()
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addCORSHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
}
|
||||
|
||||
func isMCPPath(path string, cfg *config.Config) bool {
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func copyHeaders(src http.Header, dst http.Header) {
|
||||
// Exclude hop-by-hop
|
||||
hopByHop := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
"Proxy-Authorization": true,
|
||||
"Proxy-Connection": true,
|
||||
}
|
||||
for k, vv := range src {
|
||||
if hopByHop[strings.ToLower(k)] {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
34
internal/proxy/sse.go
Normal file
34
internal/proxy/sse.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HandleSSE sets up a go-routine to wait for context cancellation
|
||||
// and flushes the response if possible.
|
||||
func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy) {
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
rp.ServeHTTP(w, r)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// NewShutdownContext is a little helper to gracefully shut down
|
||||
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), timeout)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue