Merge pull request #2 from shashimalcse/asgardeo

Add asgardeo mode
This commit is contained in:
Thilina Shashimal Senarath 2025-04-03 14:07:44 +05:30 committed by GitHub
commit a077ab1075
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 587 additions and 71 deletions

View file

@ -41,6 +41,8 @@ python3 echo_server.py
Update the following parameters in `config.yaml`. Update the following parameters in `config.yaml`.
### demo mode configuration:
```yaml ```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
listen_address: ":8080" # Address where the proxy will listen listen_address: ":8080" # Address where the proxy will listen

View file

@ -11,12 +11,14 @@ import (
"github.com/wso2/open-mcp-auth-proxy/internal/authz" "github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config" "github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
"github.com/wso2/open-mcp-auth-proxy/internal/proxy" "github.com/wso2/open-mcp-auth-proxy/internal/proxy"
"github.com/wso2/open-mcp-auth-proxy/internal/util" "github.com/wso2/open-mcp-auth-proxy/internal/util"
) )
func main() { func main() {
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).") demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
flag.Parse() flag.Parse()
// 1. Load config // 1. Load config
@ -28,12 +30,20 @@ func main() {
// 2. Create the chosen provider // 2. Create the chosen provider
var provider authz.Provider var provider authz.Provider
if *demoMode { if *demoMode {
cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2" cfg.Mode = "demo"
cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks" cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
provider = authz.NewAsgardeoProvider(cfg)
} else if *asgardeoMode {
cfg.Mode = "asgardeo"
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
provider = authz.NewAsgardeoProvider(cfg) provider = authz.NewAsgardeoProvider(cfg)
fmt.Println("Using Asgardeo provider (demo).")
} else { } else {
log.Fatalf("Not supported yet.") cfg.Mode = "default"
cfg.JWKSURL = cfg.Default.JWKSURL
cfg.AuthServerBaseURL = cfg.Default.BaseURL
provider = authz.NewDefaultProvider(cfg)
} }
// 3. (Optional) Fetch JWKS if you want local JWT validation // 3. (Optional) Fetch JWKS if you want local JWT validation
@ -44,14 +54,17 @@ func main() {
// 4. Build the main router // 4. Build the main router
mux := proxy.NewRouter(cfg, provider) mux := proxy.NewRouter(cfg, provider)
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
// 5. Start the server // 5. Start the server
srv := &http.Server{ srv := &http.Server{
Addr: cfg.ListenAddress,
Addr: listen_address,
Handler: mux, Handler: mux,
} }
go func() { go func() {
log.Printf("Server listening on %s", cfg.ListenAddress) log.Printf("Server listening on %s", listen_address)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err) log.Fatalf("Server error: %v", err)
} }

View file

@ -1,9 +1,7 @@
# config.yaml # config.yaml
auth_server_base_url: "" mcp_server_base_url: ""
mcp_server_base_url: "http://localhost:8000" listen_port: 8080
listen_address: ":8080"
jwks_url: ""
timeout_seconds: 10 timeout_seconds: 10
mcp_paths: mcp_paths:
@ -11,8 +9,63 @@ mcp_paths:
- /sse - /sse
path_mapping: path_mapping:
/token: /token
/register: /register
/authorize: /authorize
/.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server
cors:
allowed_origins:
- ""
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
allowed_headers:
- "Authorization"
- "Content-Type"
allow_credentials: true
demo: demo:
org_name: "openmcpauthdemo" org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa" client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka" client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
asgardeo:
org_name: "<org_name>"
client_id: "<client_id>"
client_secret: "<client_secret>"
default:
base_url: "<base_url>"
jwks_url: "<jwks_url>"
path:
/.well-known/oauth-authorization-server:
response:
issuer: "<issuer>"
jwks_uri: "<jwks_uri>"
authorization_endpoint: "<authorization_endpoint>" # Optional
token_endpoint: "<token_endpoint>" # Optional
registration_endpoint: "<registration_endpoint>" # Optional
response_types_supported:
- "code"
grant_types_supported:
- "authorization_code"
- "refresh_token"
code_challenge_methods_supported:
- "S256"
- "plain"
/authroize:
addQueryParams:
- name: "<name>"
value: "<value>"
/token:
addBodyParams:
- name: "<name>"
value: "<value>"
/register:
addBodyParams:
- name: "<name>"
value: "<value>"

94
internal/authz/default.go Normal file
View file

@ -0,0 +1,94 @@
package authz
import (
"encoding/json"
"net/http"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
)
type defaultProvider struct {
cfg *config.Config
}
// NewDefaultProvider initializes a Provider for Asgardeo (demo mode).
func NewDefaultProvider(cfg *config.Config) Provider {
return &defaultProvider{cfg: cfg}
}
func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Check if we have a custom response configuration
if p.cfg.Default.Path != nil {
pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"]
if exists && pathConfig.Response != nil {
// Use configured response values
responseConfig := pathConfig.Response
// Get current host for proxy endpoints
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
scheme = forwardedProto
}
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
baseURL := scheme + "://" + host
authorizationEndpoint := responseConfig.AuthorizationEndpoint
if authorizationEndpoint == "" {
authorizationEndpoint = baseURL + "/authorize"
}
tokenEndpoint := responseConfig.TokenEndpoint
if tokenEndpoint == "" {
tokenEndpoint = baseURL + "/token"
}
registraionEndpoint := responseConfig.RegistrationEndpoint
if registraionEndpoint == "" {
registraionEndpoint = baseURL + "/register"
}
// Build response from config
response := map[string]interface{}{
"issuer": responseConfig.Issuer,
"authorization_endpoint": authorizationEndpoint,
"token_endpoint": tokenEndpoint,
"jwks_uri": responseConfig.JwksURI,
"response_types_supported": responseConfig.ResponseTypesSupported,
"grant_types_supported": responseConfig.GrantTypesSupported,
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
"registration_endpoint": registraionEndpoint,
"code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
return
}
}
}
}
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
return nil
}

View file

@ -13,17 +13,67 @@ type DemoConfig struct {
OrgName string `yaml:"org_name"` OrgName string `yaml:"org_name"`
} }
type AsgardeoConfig struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
OrgName string `yaml:"org_name"`
}
type CORSConfig struct {
AllowedOrigins []string `yaml:"allowed_origins"`
AllowedMethods []string `yaml:"allowed_methods"`
AllowedHeaders []string `yaml:"allowed_headers"`
AllowCredentials bool `yaml:"allow_credentials"`
}
type ParamConfig struct {
Name string `yaml:"name"`
Value string `yaml:"value"`
}
type ResponseConfig struct {
Issuer string `yaml:"issuer,omitempty"`
JwksURI string `yaml:"jwks_uri,omitempty"`
AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"`
TokenEndpoint string `yaml:"token_endpoint,omitempty"`
RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"`
ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"`
GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"`
CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
}
type PathConfig struct {
// For well-known endpoint
Response *ResponseConfig `yaml:"response,omitempty"`
// For authorization endpoint
AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"`
// For token and register endpoints
AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"`
}
type DefaultConfig struct {
BaseURL string `yaml:"base_url,omitempty"`
Path map[string]PathConfig `yaml:"path,omitempty"`
JWKSURL string `yaml:"jwks_url,omitempty"`
}
type Config struct { type Config struct {
AuthServerBaseURL string `yaml:"auth_server_base_url"` AuthServerBaseURL string
MCPServerBaseURL string `yaml:"mcp_server_base_url"` MCPServerBaseURL string `yaml:"mcp_server_base_url"`
ListenAddress string `yaml:"listen_address"` ListenPort int `yaml:"listen_port"`
JWKSURL string `yaml:"jwks_url"` JWKSURL string
TimeoutSeconds int `yaml:"timeout_seconds"` TimeoutSeconds int `yaml:"timeout_seconds"`
MCPPaths []string `yaml:"mcp_paths"` MCPPaths []string `yaml:"mcp_paths"`
PathMapping map[string]string `yaml:"path_mapping"` PathMapping map[string]string `yaml:"path_mapping"`
Mode string `yaml:"mode"`
CORSConfig CORSConfig `yaml:"cors"`
// Nested config for Asgardeo // Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"` Demo DemoConfig `yaml:"demo"`
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
Default DefaultConfig `yaml:"default"`
} }
// LoadConfig reads a YAML config file into Config struct. // LoadConfig reads a YAML config file into Config struct.

View file

@ -0,0 +1,7 @@
package constants
// Package constant provides constants for the MCP Auth Proxy
const (
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
)

199
internal/proxy/modifier.go Normal file
View file

@ -0,0 +1,199 @@
package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
)
// RequestModifier modifies requests before they are proxied
type RequestModifier interface {
ModifyRequest(req *http.Request) (*http.Request, error)
}
// AuthorizationModifier adds parameters to authorization requests
type AuthorizationModifier struct {
Config *config.Config
}
// TokenModifier adds parameters to token requests
type TokenModifier struct {
Config *config.Config
}
type RegisterModifier struct {
Config *config.Config
}
// ModifyRequest adds configured parameters to authorization requests
func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
// Check if we have parameters to add
if m.Config.Default.Path == nil {
return req, nil
}
pathConfig, exists := m.Config.Default.Path["/authorize"]
if !exists || len(pathConfig.AddQueryParams) == 0 {
return req, nil
}
// Get current query parameters
query := req.URL.Query()
// Add parameters from config
for _, param := range pathConfig.AddQueryParams {
query.Set(param.Name, param.Value)
}
// Update the request URL
req.URL.RawQuery = query.Encode()
return req, nil
}
// ModifyRequest adds configured parameters to token requests
func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
// Only modify POST requests
if req.Method != http.MethodPost {
return req, nil
}
// Check if we have parameters to add
if m.Config.Default.Path == nil {
return req, nil
}
pathConfig, exists := m.Config.Default.Path["/token"]
if !exists || len(pathConfig.AddBodyParams) == 0 {
return req, nil
}
contentType := req.Header.Get("Content-Type")
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
// Parse form data
if err := req.ParseForm(); err != nil {
return nil, err
}
// Clone form data
formData := req.PostForm
// Add configured parameters
for _, param := range pathConfig.AddBodyParams {
formData.Set(param.Name, param.Value)
}
// Create new request body with modified form
formEncoded := formData.Encode()
req.Body = io.NopCloser(strings.NewReader(formEncoded))
req.ContentLength = int64(len(formEncoded))
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded)))
} else if strings.Contains(contentType, "application/json") {
// Read body
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
// Parse JSON
var jsonData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
return nil, err
}
// Add parameters
for _, param := range pathConfig.AddBodyParams {
jsonData[param.Name] = param.Value
}
// Marshal back to JSON
modifiedBody, err := json.Marshal(jsonData)
if err != nil {
return nil, err
}
// Update request
req.Body = io.NopCloser(bytes.NewReader(modifiedBody))
req.ContentLength = int64(len(modifiedBody))
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
}
return req, nil
}
func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
// Only modify POST requests
if req.Method != http.MethodPost {
return req, nil
}
// Check if we have parameters to add
if m.Config.Default.Path == nil {
return req, nil
}
pathConfig, exists := m.Config.Default.Path["/register"]
if !exists || len(pathConfig.AddBodyParams) == 0 {
return req, nil
}
contentType := req.Header.Get("Content-Type")
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
// Parse form data
if err := req.ParseForm(); err != nil {
return nil, err
}
// Clone form data
formData := req.PostForm
// Add configured parameters
for _, param := range pathConfig.AddBodyParams {
formData.Set(param.Name, param.Value)
}
// Create new request body with modified form
formEncoded := formData.Encode()
req.Body = io.NopCloser(strings.NewReader(formEncoded))
req.ContentLength = int64(len(formEncoded))
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded)))
} else if strings.Contains(contentType, "application/json") {
// Read body
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, err
}
// Parse JSON
var jsonData map[string]interface{}
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
return nil, err
}
// Add parameters
for _, param := range pathConfig.AddBodyParams {
jsonData[param.Name] = param.Value
}
// Marshal back to JSON
modifiedBody, err := json.Marshal(jsonData)
if err != nil {
return nil, err
}
// Update request
req.Body = io.NopCloser(bytes.NewReader(modifiedBody))
req.ContentLength = int64(len(modifiedBody))
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
}
return req, nil
}

View file

@ -20,34 +20,87 @@ import (
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler { func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
mux := http.NewServeMux() mux := http.NewServeMux()
// 1. Custom well-known modifiers := map[string]RequestModifier{
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler()) "/authorize": &AuthorizationModifier{Config: cfg},
"/token": &TokenModifier{Config: cfg},
"/register": &RegisterModifier{Config: cfg},
}
// 2. Registration registeredPaths := make(map[string]bool)
mux.HandleFunc("/register", provider.RegisterHandler())
// 3. Default "auth" paths, proxied var defaultPaths []string
defaultPaths := []string{"/authorize", "/token"}
// Handle based on mode configuration
if cfg.Mode == "demo" || cfg.Mode == "asgardeo" {
// Demo/Asgardeo mode: Custom handlers for well-known and register
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
registeredPaths["/.well-known/oauth-authorization-server"] = true
mux.HandleFunc("/register", provider.RegisterHandler())
registeredPaths["/register"] = true
// Authorize and token will be proxied with parameter modification
defaultPaths = []string{"/authorize", "/token"}
} else {
// Default provider mode
if cfg.Default.Path != nil {
// Check if we have custom response for well-known
wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"]
if exists && wellKnownConfig.Response != nil {
// If there's a custom response defined, use our handler
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
registeredPaths["/.well-known/oauth-authorization-server"] = true
} else {
// No custom response, add well-known to proxy paths
defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server")
}
defaultPaths = append(defaultPaths, "/authorize")
defaultPaths = append(defaultPaths, "/token")
defaultPaths = append(defaultPaths, "/register")
} else {
defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"}
}
}
// Remove duplicates from defaultPaths
uniquePaths := make(map[string]bool)
cleanPaths := []string{}
for _, path := range defaultPaths { for _, path := range defaultPaths {
mux.HandleFunc(path, buildProxyHandler(cfg)) if !uniquePaths[path] {
uniquePaths[path] = true
cleanPaths = append(cleanPaths, path)
}
}
defaultPaths = cleanPaths
for _, path := range defaultPaths {
if !registeredPaths[path] {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true
}
} }
// 4. MCP paths // MCP paths
for _, path := range cfg.MCPPaths { for _, path := range cfg.MCPPaths {
mux.HandleFunc(path, buildProxyHandler(cfg)) mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true
} }
// 5. If you want to map additional paths from config.PathMapping // Register paths from PathMapping that haven't been registered yet
// to the same proxy logic:
for path := range cfg.PathMapping { for path := range cfg.PathMapping {
mux.HandleFunc(path, buildProxyHandler(cfg)) if !registeredPaths[path] {
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
registeredPaths[path] = true
}
} }
return mux return mux
} }
func buildProxyHandler(cfg *config.Config) http.HandlerFunc { func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
// Parse the base URLs up front // Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL) authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil { if err != nil {
log.Fatalf("Invalid auth server URL: %v", err) log.Fatalf("Invalid auth server URL: %v", err)
@ -57,13 +110,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
log.Fatalf("Invalid MCP server URL: %v", err) log.Fatalf("Invalid MCP server URL: %v", err)
} }
// We'll define sets for known auth paths, SSE paths, etc.
authPaths := map[string]bool{
"/authorize": true,
"/token": true,
"/.well-known/oauth-authorization-server": true,
}
// Detect SSE paths from config // Detect SSE paths from config
ssePaths := make(map[string]bool) ssePaths := make(map[string]bool)
for _, p := range cfg.MCPPaths { for _, p := range cfg.MCPPaths {
@ -73,23 +119,38 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
allowedOrigin := getAllowedOrigin(origin, cfg)
// Handle OPTIONS // Handle OPTIONS
if r.Method == http.MethodOptions { if r.Method == http.MethodOptions {
addCORSHeaders(w) if allowedOrigin == "" {
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers"))
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
return return
} }
addCORSHeaders(w) if allowedOrigin == "" {
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
return
}
// Add CORS headers to all responses
addCORSHeaders(w, cfg, allowedOrigin, "")
// Decide whether the request should go to the auth server or MCP // Decide whether the request should go to the auth server or MCP
var targetURL *url.URL var targetURL *url.URL
isSSE := false isSSE := false
if authPaths[r.URL.Path] { if isAuthPath(r.URL.Path) {
targetURL = authBase targetURL = authBase
} else if isMCPPath(r.URL.Path, cfg) { } else if isMCPPath(r.URL.Path, cfg) {
// Validate JWT if you want // Validate JWT for MCP paths if required
// Placeholder for JWT validation logic
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil { if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err) log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
@ -100,11 +161,21 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
isSSE = true isSSE = true
} }
} else { } else {
// If it's not recognized as an auth path or an MCP path
http.Error(w, "Forbidden", http.StatusForbidden) http.Error(w, "Forbidden", http.StatusForbidden)
return return
} }
// Apply request modifiers to add parameters
if modifier, exists := modifiers[r.URL.Path]; exists {
var err error
r, err = modifier.ModifyRequest(r)
if err != nil {
log.Printf("[proxy] Error modifying request: %v", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
}
// Build the reverse proxy // Build the reverse proxy
rp := &httputil.ReverseProxy{ rp := &httputil.ReverseProxy{
Director: func(req *http.Request) { Director: func(req *http.Request) {
@ -120,23 +191,27 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
req.URL.RawQuery = r.URL.RawQuery req.URL.RawQuery = r.URL.RawQuery
req.Host = targetURL.Host req.Host = targetURL.Host
for header, values := range r.Header { cleanHeaders := http.Header{}
for k, v := range r.Header {
// Skip hop-by-hop headers // Skip hop-by-hop headers
if strings.EqualFold(header, "Connection") || if skipHeader(k) {
strings.EqualFold(header, "Keep-Alive") ||
strings.EqualFold(header, "Transfer-Encoding") ||
strings.EqualFold(header, "Upgrade") ||
strings.EqualFold(header, "Proxy-Authorization") ||
strings.EqualFold(header, "Proxy-Connection") {
continue continue
} }
for _, value := range values { // Set only the first value to avoid duplicates
req.Header.Set(header, value) cleanHeaders.Set(k, v[0])
}
} }
req.Header = cleanHeaders
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path) log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
}, },
ModifyResponse: func(resp *http.Response) error {
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
return nil
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[proxy] Error proxying: %v", err) log.Printf("[proxy] Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway) http.Error(rw, "Bad Gateway", http.StatusBadGateway)
@ -156,12 +231,47 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
} }
} }
func addCORSHeaders(w http.ResponseWriter) { func getAllowedOrigin(origin string, cfg *config.Config) string {
w.Header().Set("Access-Control-Allow-Origin", "*") if origin == "" {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") }
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
if allowed == origin {
return allowed
}
}
return ""
} }
// addCORSHeaders adds configurable CORS headers
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
if requestHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
} else {
w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", "))
}
if cfg.CORSConfig.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
w.Header().Set("Vary", "Origin")
}
func isAuthPath(path string) bool {
authPaths := map[string]bool{
"/authorize": true,
"/token": true,
"/register": true,
"/.well-known/oauth-authorization-server": true,
}
if strings.HasPrefix(path, "/u/") {
return true
}
return authPaths[path]
}
// isMCPPath checks if the path is an MCP path
func isMCPPath(path string, cfg *config.Config) bool { func isMCPPath(path string, cfg *config.Config) bool {
for _, p := range cfg.MCPPaths { for _, p := range cfg.MCPPaths {
if strings.HasPrefix(path, p) { if strings.HasPrefix(path, p) {
@ -171,22 +281,10 @@ func isMCPPath(path string, cfg *config.Config) bool {
return false return false
} }
func copyHeaders(src http.Header, dst http.Header) { func skipHeader(h string) bool {
// Exclude hop-by-hop switch strings.ToLower(h) {
hopByHop := map[string]bool{ case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer":
"Connection": true, return true
"Keep-Alive": true,
"Transfer-Encoding": true,
"Upgrade": true,
"Proxy-Authorization": true,
"Proxy-Connection": true,
}
for k, vv := range src {
if hopByHop[strings.ToLower(k)] {
continue
}
for _, v := range vv {
dst.Add(k, v)
}
} }
return false
} }