commit
a077ab1075
8 changed files with 587 additions and 71 deletions
|
@ -41,6 +41,8 @@ python3 echo_server.py
|
||||||
|
|
||||||
Update the following parameters in `config.yaml`.
|
Update the following parameters in `config.yaml`.
|
||||||
|
|
||||||
|
### demo mode configuration:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
||||||
listen_address: ":8080" # Address where the proxy will listen
|
listen_address: ":8080" # Address where the proxy will listen
|
||||||
|
|
|
@ -11,12 +11,14 @@ import (
|
||||||
|
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/constants"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
||||||
|
asgardeoMode := flag.Bool("asgardeo", false, "Use Asgardeo-based provider (asgardeo).")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
// 1. Load config
|
// 1. Load config
|
||||||
|
@ -28,12 +30,20 @@ func main() {
|
||||||
// 2. Create the chosen provider
|
// 2. Create the chosen provider
|
||||||
var provider authz.Provider
|
var provider authz.Provider
|
||||||
if *demoMode {
|
if *demoMode {
|
||||||
cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2"
|
cfg.Mode = "demo"
|
||||||
cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks"
|
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2"
|
||||||
|
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Demo.OrgName + "/oauth2/jwks"
|
||||||
|
provider = authz.NewAsgardeoProvider(cfg)
|
||||||
|
} else if *asgardeoMode {
|
||||||
|
cfg.Mode = "asgardeo"
|
||||||
|
cfg.AuthServerBaseURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2"
|
||||||
|
cfg.JWKSURL = constants.ASGARDEO_BASE_URL + cfg.Asgardeo.OrgName + "/oauth2/jwks"
|
||||||
provider = authz.NewAsgardeoProvider(cfg)
|
provider = authz.NewAsgardeoProvider(cfg)
|
||||||
fmt.Println("Using Asgardeo provider (demo).")
|
|
||||||
} else {
|
} else {
|
||||||
log.Fatalf("Not supported yet.")
|
cfg.Mode = "default"
|
||||||
|
cfg.JWKSURL = cfg.Default.JWKSURL
|
||||||
|
cfg.AuthServerBaseURL = cfg.Default.BaseURL
|
||||||
|
provider = authz.NewDefaultProvider(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
||||||
|
@ -44,14 +54,17 @@ func main() {
|
||||||
// 4. Build the main router
|
// 4. Build the main router
|
||||||
mux := proxy.NewRouter(cfg, provider)
|
mux := proxy.NewRouter(cfg, provider)
|
||||||
|
|
||||||
|
listen_address := fmt.Sprintf(":%d", cfg.ListenPort)
|
||||||
|
|
||||||
// 5. Start the server
|
// 5. Start the server
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: cfg.ListenAddress,
|
|
||||||
|
Addr: listen_address,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("Server listening on %s", cfg.ListenAddress)
|
log.Printf("Server listening on %s", listen_address)
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatalf("Server error: %v", err)
|
log.Fatalf("Server error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
61
config.yaml
61
config.yaml
|
@ -1,9 +1,7 @@
|
||||||
# config.yaml
|
# config.yaml
|
||||||
|
|
||||||
auth_server_base_url: ""
|
mcp_server_base_url: ""
|
||||||
mcp_server_base_url: "http://localhost:8000"
|
listen_port: 8080
|
||||||
listen_address: ":8080"
|
|
||||||
jwks_url: ""
|
|
||||||
timeout_seconds: 10
|
timeout_seconds: 10
|
||||||
|
|
||||||
mcp_paths:
|
mcp_paths:
|
||||||
|
@ -11,8 +9,63 @@ mcp_paths:
|
||||||
- /sse
|
- /sse
|
||||||
|
|
||||||
path_mapping:
|
path_mapping:
|
||||||
|
/token: /token
|
||||||
|
/register: /register
|
||||||
|
/authorize: /authorize
|
||||||
|
/.well-known/oauth-authorization-server: /.well-known/oauth-authorization-server
|
||||||
|
|
||||||
|
cors:
|
||||||
|
allowed_origins:
|
||||||
|
- ""
|
||||||
|
allowed_methods:
|
||||||
|
- "GET"
|
||||||
|
- "POST"
|
||||||
|
- "PUT"
|
||||||
|
- "DELETE"
|
||||||
|
allowed_headers:
|
||||||
|
- "Authorization"
|
||||||
|
- "Content-Type"
|
||||||
|
allow_credentials: true
|
||||||
|
|
||||||
demo:
|
demo:
|
||||||
org_name: "openmcpauthdemo"
|
org_name: "openmcpauthdemo"
|
||||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
||||||
|
|
||||||
|
asgardeo:
|
||||||
|
org_name: "<org_name>"
|
||||||
|
client_id: "<client_id>"
|
||||||
|
client_secret: "<client_secret>"
|
||||||
|
|
||||||
|
default:
|
||||||
|
base_url: "<base_url>"
|
||||||
|
jwks_url: "<jwks_url>"
|
||||||
|
path:
|
||||||
|
/.well-known/oauth-authorization-server:
|
||||||
|
response:
|
||||||
|
issuer: "<issuer>"
|
||||||
|
jwks_uri: "<jwks_uri>"
|
||||||
|
authorization_endpoint: "<authorization_endpoint>" # Optional
|
||||||
|
token_endpoint: "<token_endpoint>" # Optional
|
||||||
|
registration_endpoint: "<registration_endpoint>" # Optional
|
||||||
|
response_types_supported:
|
||||||
|
- "code"
|
||||||
|
grant_types_supported:
|
||||||
|
- "authorization_code"
|
||||||
|
- "refresh_token"
|
||||||
|
code_challenge_methods_supported:
|
||||||
|
- "S256"
|
||||||
|
- "plain"
|
||||||
|
/authroize:
|
||||||
|
addQueryParams:
|
||||||
|
- name: "<name>"
|
||||||
|
value: "<value>"
|
||||||
|
/token:
|
||||||
|
addBodyParams:
|
||||||
|
- name: "<name>"
|
||||||
|
value: "<value>"
|
||||||
|
/register:
|
||||||
|
addBodyParams:
|
||||||
|
- name: "<name>"
|
||||||
|
value: "<value>"
|
||||||
|
|
||||||
|
|
94
internal/authz/default.go
Normal file
94
internal/authz/default.go
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type defaultProvider struct {
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultProvider initializes a Provider for Asgardeo (demo mode).
|
||||||
|
func NewDefaultProvider(cfg *config.Config) Provider {
|
||||||
|
return &defaultProvider{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *defaultProvider) WellKnownHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||||
|
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have a custom response configuration
|
||||||
|
if p.cfg.Default.Path != nil {
|
||||||
|
pathConfig, exists := p.cfg.Default.Path["/.well-known/oauth-authorization-server"]
|
||||||
|
if exists && pathConfig.Response != nil {
|
||||||
|
// Use configured response values
|
||||||
|
responseConfig := pathConfig.Response
|
||||||
|
|
||||||
|
// Get current host for proxy endpoints
|
||||||
|
scheme := "http"
|
||||||
|
if r.TLS != nil {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
|
||||||
|
scheme = forwardedProto
|
||||||
|
}
|
||||||
|
host := r.Host
|
||||||
|
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||||
|
host = forwardedHost
|
||||||
|
}
|
||||||
|
baseURL := scheme + "://" + host
|
||||||
|
|
||||||
|
authorizationEndpoint := responseConfig.AuthorizationEndpoint
|
||||||
|
if authorizationEndpoint == "" {
|
||||||
|
authorizationEndpoint = baseURL + "/authorize"
|
||||||
|
}
|
||||||
|
tokenEndpoint := responseConfig.TokenEndpoint
|
||||||
|
if tokenEndpoint == "" {
|
||||||
|
tokenEndpoint = baseURL + "/token"
|
||||||
|
}
|
||||||
|
registraionEndpoint := responseConfig.RegistrationEndpoint
|
||||||
|
if registraionEndpoint == "" {
|
||||||
|
registraionEndpoint = baseURL + "/register"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build response from config
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"issuer": responseConfig.Issuer,
|
||||||
|
"authorization_endpoint": authorizationEndpoint,
|
||||||
|
"token_endpoint": tokenEndpoint,
|
||||||
|
"jwks_uri": responseConfig.JwksURI,
|
||||||
|
"response_types_supported": responseConfig.ResponseTypesSupported,
|
||||||
|
"grant_types_supported": responseConfig.GrantTypesSupported,
|
||||||
|
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
|
||||||
|
"registration_endpoint": registraionEndpoint,
|
||||||
|
"code_challenge_methods_supported": responseConfig.CodeChallengeMethodsSupported,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *defaultProvider) RegisterHandler() http.HandlerFunc {
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -13,17 +13,67 @@ type DemoConfig struct {
|
||||||
OrgName string `yaml:"org_name"`
|
OrgName string `yaml:"org_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AsgardeoConfig struct {
|
||||||
|
ClientID string `yaml:"client_id"`
|
||||||
|
ClientSecret string `yaml:"client_secret"`
|
||||||
|
OrgName string `yaml:"org_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CORSConfig struct {
|
||||||
|
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||||
|
AllowedMethods []string `yaml:"allowed_methods"`
|
||||||
|
AllowedHeaders []string `yaml:"allowed_headers"`
|
||||||
|
AllowCredentials bool `yaml:"allow_credentials"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ParamConfig struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
Value string `yaml:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseConfig struct {
|
||||||
|
Issuer string `yaml:"issuer,omitempty"`
|
||||||
|
JwksURI string `yaml:"jwks_uri,omitempty"`
|
||||||
|
AuthorizationEndpoint string `yaml:"authorization_endpoint,omitempty"`
|
||||||
|
TokenEndpoint string `yaml:"token_endpoint,omitempty"`
|
||||||
|
RegistrationEndpoint string `yaml:"registration_endpoint,omitempty"`
|
||||||
|
ResponseTypesSupported []string `yaml:"response_types_supported,omitempty"`
|
||||||
|
GrantTypesSupported []string `yaml:"grant_types_supported,omitempty"`
|
||||||
|
CodeChallengeMethodsSupported []string `yaml:"code_challenge_methods_supported,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PathConfig struct {
|
||||||
|
// For well-known endpoint
|
||||||
|
Response *ResponseConfig `yaml:"response,omitempty"`
|
||||||
|
|
||||||
|
// For authorization endpoint
|
||||||
|
AddQueryParams []ParamConfig `yaml:"addQueryParams,omitempty"`
|
||||||
|
|
||||||
|
// For token and register endpoints
|
||||||
|
AddBodyParams []ParamConfig `yaml:"addBodyParams,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultConfig struct {
|
||||||
|
BaseURL string `yaml:"base_url,omitempty"`
|
||||||
|
Path map[string]PathConfig `yaml:"path,omitempty"`
|
||||||
|
JWKSURL string `yaml:"jwks_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AuthServerBaseURL string `yaml:"auth_server_base_url"`
|
AuthServerBaseURL string
|
||||||
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
||||||
ListenAddress string `yaml:"listen_address"`
|
ListenPort int `yaml:"listen_port"`
|
||||||
JWKSURL string `yaml:"jwks_url"`
|
JWKSURL string
|
||||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||||
MCPPaths []string `yaml:"mcp_paths"`
|
MCPPaths []string `yaml:"mcp_paths"`
|
||||||
PathMapping map[string]string `yaml:"path_mapping"`
|
PathMapping map[string]string `yaml:"path_mapping"`
|
||||||
|
Mode string `yaml:"mode"`
|
||||||
|
CORSConfig CORSConfig `yaml:"cors"`
|
||||||
|
|
||||||
// Nested config for Asgardeo
|
// Nested config for Asgardeo
|
||||||
Demo DemoConfig `yaml:"demo"`
|
Demo DemoConfig `yaml:"demo"`
|
||||||
|
Asgardeo AsgardeoConfig `yaml:"asgardeo"`
|
||||||
|
Default DefaultConfig `yaml:"default"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfig reads a YAML config file into Config struct.
|
// LoadConfig reads a YAML config file into Config struct.
|
||||||
|
|
7
internal/constants/constants.go
Normal file
7
internal/constants/constants.go
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package constants
|
||||||
|
|
||||||
|
// Package constant provides constants for the MCP Auth Proxy
|
||||||
|
|
||||||
|
const (
|
||||||
|
ASGARDEO_BASE_URL = "https://api.asgardeo.io/t/"
|
||||||
|
)
|
199
internal/proxy/modifier.go
Normal file
199
internal/proxy/modifier.go
Normal file
|
@ -0,0 +1,199 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestModifier modifies requests before they are proxied
|
||||||
|
type RequestModifier interface {
|
||||||
|
ModifyRequest(req *http.Request) (*http.Request, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizationModifier adds parameters to authorization requests
|
||||||
|
type AuthorizationModifier struct {
|
||||||
|
Config *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenModifier adds parameters to token requests
|
||||||
|
type TokenModifier struct {
|
||||||
|
Config *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegisterModifier struct {
|
||||||
|
Config *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyRequest adds configured parameters to authorization requests
|
||||||
|
func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
|
||||||
|
// Check if we have parameters to add
|
||||||
|
if m.Config.Default.Path == nil {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pathConfig, exists := m.Config.Default.Path["/authorize"]
|
||||||
|
if !exists || len(pathConfig.AddQueryParams) == 0 {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
// Get current query parameters
|
||||||
|
query := req.URL.Query()
|
||||||
|
|
||||||
|
// Add parameters from config
|
||||||
|
for _, param := range pathConfig.AddQueryParams {
|
||||||
|
query.Set(param.Name, param.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the request URL
|
||||||
|
req.URL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyRequest adds configured parameters to token requests
|
||||||
|
func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
|
||||||
|
// Only modify POST requests
|
||||||
|
if req.Method != http.MethodPost {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have parameters to add
|
||||||
|
if m.Config.Default.Path == nil {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pathConfig, exists := m.Config.Default.Path["/token"]
|
||||||
|
if !exists || len(pathConfig.AddBodyParams) == 0 {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := req.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
||||||
|
// Parse form data
|
||||||
|
if err := req.ParseForm(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone form data
|
||||||
|
formData := req.PostForm
|
||||||
|
|
||||||
|
// Add configured parameters
|
||||||
|
for _, param := range pathConfig.AddBodyParams {
|
||||||
|
formData.Set(param.Name, param.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new request body with modified form
|
||||||
|
formEncoded := formData.Encode()
|
||||||
|
req.Body = io.NopCloser(strings.NewReader(formEncoded))
|
||||||
|
req.ContentLength = int64(len(formEncoded))
|
||||||
|
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded)))
|
||||||
|
|
||||||
|
} else if strings.Contains(contentType, "application/json") {
|
||||||
|
// Read body
|
||||||
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON
|
||||||
|
var jsonData map[string]interface{}
|
||||||
|
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add parameters
|
||||||
|
for _, param := range pathConfig.AddBodyParams {
|
||||||
|
jsonData[param.Name] = param.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back to JSON
|
||||||
|
modifiedBody, err := json.Marshal(jsonData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update request
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(modifiedBody))
|
||||||
|
req.ContentLength = int64(len(modifiedBody))
|
||||||
|
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
|
||||||
|
// Only modify POST requests
|
||||||
|
if req.Method != http.MethodPost {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have parameters to add
|
||||||
|
if m.Config.Default.Path == nil {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pathConfig, exists := m.Config.Default.Path["/register"]
|
||||||
|
if !exists || len(pathConfig.AddBodyParams) == 0 {
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := req.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
if strings.Contains(contentType, "application/x-www-form-urlencoded") {
|
||||||
|
// Parse form data
|
||||||
|
if err := req.ParseForm(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone form data
|
||||||
|
formData := req.PostForm
|
||||||
|
|
||||||
|
// Add configured parameters
|
||||||
|
for _, param := range pathConfig.AddBodyParams {
|
||||||
|
formData.Set(param.Name, param.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new request body with modified form
|
||||||
|
formEncoded := formData.Encode()
|
||||||
|
req.Body = io.NopCloser(strings.NewReader(formEncoded))
|
||||||
|
req.ContentLength = int64(len(formEncoded))
|
||||||
|
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(formEncoded)))
|
||||||
|
|
||||||
|
} else if strings.Contains(contentType, "application/json") {
|
||||||
|
// Read body
|
||||||
|
bodyBytes, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON
|
||||||
|
var jsonData map[string]interface{}
|
||||||
|
if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add parameters
|
||||||
|
for _, param := range pathConfig.AddBodyParams {
|
||||||
|
jsonData[param.Name] = param.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal back to JSON
|
||||||
|
modifiedBody, err := json.Marshal(jsonData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update request
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(modifiedBody))
|
||||||
|
req.ContentLength = int64(len(modifiedBody))
|
||||||
|
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(modifiedBody)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
|
@ -20,34 +20,87 @@ import (
|
||||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
// 1. Custom well-known
|
modifiers := map[string]RequestModifier{
|
||||||
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
"/authorize": &AuthorizationModifier{Config: cfg},
|
||||||
|
"/token": &TokenModifier{Config: cfg},
|
||||||
|
"/register": &RegisterModifier{Config: cfg},
|
||||||
|
}
|
||||||
|
|
||||||
// 2. Registration
|
registeredPaths := make(map[string]bool)
|
||||||
mux.HandleFunc("/register", provider.RegisterHandler())
|
|
||||||
|
|
||||||
// 3. Default "auth" paths, proxied
|
var defaultPaths []string
|
||||||
defaultPaths := []string{"/authorize", "/token"}
|
|
||||||
|
// Handle based on mode configuration
|
||||||
|
if cfg.Mode == "demo" || cfg.Mode == "asgardeo" {
|
||||||
|
// Demo/Asgardeo mode: Custom handlers for well-known and register
|
||||||
|
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||||
|
registeredPaths["/.well-known/oauth-authorization-server"] = true
|
||||||
|
|
||||||
|
mux.HandleFunc("/register", provider.RegisterHandler())
|
||||||
|
registeredPaths["/register"] = true
|
||||||
|
|
||||||
|
// Authorize and token will be proxied with parameter modification
|
||||||
|
defaultPaths = []string{"/authorize", "/token"}
|
||||||
|
} else {
|
||||||
|
// Default provider mode
|
||||||
|
if cfg.Default.Path != nil {
|
||||||
|
// Check if we have custom response for well-known
|
||||||
|
wellKnownConfig, exists := cfg.Default.Path["/.well-known/oauth-authorization-server"]
|
||||||
|
if exists && wellKnownConfig.Response != nil {
|
||||||
|
// If there's a custom response defined, use our handler
|
||||||
|
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||||
|
registeredPaths["/.well-known/oauth-authorization-server"] = true
|
||||||
|
} else {
|
||||||
|
// No custom response, add well-known to proxy paths
|
||||||
|
defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server")
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultPaths = append(defaultPaths, "/authorize")
|
||||||
|
defaultPaths = append(defaultPaths, "/token")
|
||||||
|
defaultPaths = append(defaultPaths, "/register")
|
||||||
|
} else {
|
||||||
|
defaultPaths = []string{"/authorize", "/token", "/register", "/.well-known/oauth-authorization-server"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove duplicates from defaultPaths
|
||||||
|
uniquePaths := make(map[string]bool)
|
||||||
|
cleanPaths := []string{}
|
||||||
for _, path := range defaultPaths {
|
for _, path := range defaultPaths {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
if !uniquePaths[path] {
|
||||||
|
uniquePaths[path] = true
|
||||||
|
cleanPaths = append(cleanPaths, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defaultPaths = cleanPaths
|
||||||
|
|
||||||
|
for _, path := range defaultPaths {
|
||||||
|
if !registeredPaths[path] {
|
||||||
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||||
|
registeredPaths[path] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. MCP paths
|
// MCP paths
|
||||||
for _, path := range cfg.MCPPaths {
|
for _, path := range cfg.MCPPaths {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||||
|
registeredPaths[path] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. If you want to map additional paths from config.PathMapping
|
// Register paths from PathMapping that haven't been registered yet
|
||||||
// to the same proxy logic:
|
|
||||||
for path := range cfg.PathMapping {
|
for path := range cfg.PathMapping {
|
||||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
if !registeredPaths[path] {
|
||||||
|
mux.HandleFunc(path, buildProxyHandler(cfg, modifiers))
|
||||||
|
registeredPaths[path] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier) http.HandlerFunc {
|
||||||
// Parse the base URLs up front
|
// Parse the base URLs up front
|
||||||
|
|
||||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Invalid auth server URL: %v", err)
|
log.Fatalf("Invalid auth server URL: %v", err)
|
||||||
|
@ -57,13 +110,6 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We'll define sets for known auth paths, SSE paths, etc.
|
|
||||||
authPaths := map[string]bool{
|
|
||||||
"/authorize": true,
|
|
||||||
"/token": true,
|
|
||||||
"/.well-known/oauth-authorization-server": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect SSE paths from config
|
// Detect SSE paths from config
|
||||||
ssePaths := make(map[string]bool)
|
ssePaths := make(map[string]bool)
|
||||||
for _, p := range cfg.MCPPaths {
|
for _, p := range cfg.MCPPaths {
|
||||||
|
@ -73,23 +119,38 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
allowedOrigin := getAllowedOrigin(origin, cfg)
|
||||||
// Handle OPTIONS
|
// Handle OPTIONS
|
||||||
if r.Method == http.MethodOptions {
|
if r.Method == http.MethodOptions {
|
||||||
addCORSHeaders(w)
|
if allowedOrigin == "" {
|
||||||
|
log.Printf("[proxy] Preflight request from disallowed origin: %s", origin)
|
||||||
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addCORSHeaders(w, cfg, allowedOrigin, r.Header.Get("Access-Control-Request-Headers"))
|
||||||
w.WriteHeader(http.StatusNoContent)
|
w.WriteHeader(http.StatusNoContent)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
addCORSHeaders(w)
|
if allowedOrigin == "" {
|
||||||
|
log.Printf("[proxy] Request from disallowed origin: %s for %s", origin, r.URL.Path)
|
||||||
|
http.Error(w, "CORS origin not allowed", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add CORS headers to all responses
|
||||||
|
addCORSHeaders(w, cfg, allowedOrigin, "")
|
||||||
|
|
||||||
// Decide whether the request should go to the auth server or MCP
|
// Decide whether the request should go to the auth server or MCP
|
||||||
var targetURL *url.URL
|
var targetURL *url.URL
|
||||||
isSSE := false
|
isSSE := false
|
||||||
|
|
||||||
if authPaths[r.URL.Path] {
|
if isAuthPath(r.URL.Path) {
|
||||||
targetURL = authBase
|
targetURL = authBase
|
||||||
} else if isMCPPath(r.URL.Path, cfg) {
|
} else if isMCPPath(r.URL.Path, cfg) {
|
||||||
// Validate JWT if you want
|
// Validate JWT for MCP paths if required
|
||||||
|
// Placeholder for JWT validation logic
|
||||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
@ -100,11 +161,21 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||||
isSSE = true
|
isSSE = true
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// If it's not recognized as an auth path or an MCP path
|
|
||||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply request modifiers to add parameters
|
||||||
|
if modifier, exists := modifiers[r.URL.Path]; exists {
|
||||||
|
var err error
|
||||||
|
r, err = modifier.ModifyRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[proxy] Error modifying request: %v", err)
|
||||||
|
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build the reverse proxy
|
// Build the reverse proxy
|
||||||
rp := &httputil.ReverseProxy{
|
rp := &httputil.ReverseProxy{
|
||||||
Director: func(req *http.Request) {
|
Director: func(req *http.Request) {
|
||||||
|
@ -120,23 +191,27 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||||
req.URL.RawQuery = r.URL.RawQuery
|
req.URL.RawQuery = r.URL.RawQuery
|
||||||
req.Host = targetURL.Host
|
req.Host = targetURL.Host
|
||||||
|
|
||||||
for header, values := range r.Header {
|
cleanHeaders := http.Header{}
|
||||||
|
|
||||||
|
for k, v := range r.Header {
|
||||||
// Skip hop-by-hop headers
|
// Skip hop-by-hop headers
|
||||||
if strings.EqualFold(header, "Connection") ||
|
if skipHeader(k) {
|
||||||
strings.EqualFold(header, "Keep-Alive") ||
|
|
||||||
strings.EqualFold(header, "Transfer-Encoding") ||
|
|
||||||
strings.EqualFold(header, "Upgrade") ||
|
|
||||||
strings.EqualFold(header, "Proxy-Authorization") ||
|
|
||||||
strings.EqualFold(header, "Proxy-Connection") {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range values {
|
// Set only the first value to avoid duplicates
|
||||||
req.Header.Set(header, value)
|
cleanHeaders.Set(k, v[0])
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
req.Header = cleanHeaders
|
||||||
|
|
||||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||||
},
|
},
|
||||||
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
|
log.Printf("[proxy] Response from %s%s: %d", resp.Request.URL.Host, resp.Request.URL.Path, resp.StatusCode)
|
||||||
|
resp.Header.Del("Access-Control-Allow-Origin") // Avoid upstream conflicts
|
||||||
|
return nil
|
||||||
|
},
|
||||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
log.Printf("[proxy] Error proxying: %v", err)
|
log.Printf("[proxy] Error proxying: %v", err)
|
||||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||||
|
@ -156,12 +231,47 @@ func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func addCORSHeaders(w http.ResponseWriter) {
|
func getAllowedOrigin(origin string, cfg *config.Config) string {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
if origin == "" {
|
||||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
return cfg.CORSConfig.AllowedOrigins[0] // Default to first allowed origin
|
||||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
}
|
||||||
|
for _, allowed := range cfg.CORSConfig.AllowedOrigins {
|
||||||
|
if allowed == origin {
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addCORSHeaders adds configurable CORS headers
|
||||||
|
func addCORSHeaders(w http.ResponseWriter, cfg *config.Config, allowedOrigin, requestHeaders string) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(cfg.CORSConfig.AllowedMethods, ", "))
|
||||||
|
if requestHeaders != "" {
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", requestHeaders)
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", strings.Join(cfg.CORSConfig.AllowedHeaders, ", "))
|
||||||
|
}
|
||||||
|
if cfg.CORSConfig.AllowCredentials {
|
||||||
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
w.Header().Set("Vary", "Origin")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAuthPath(path string) bool {
|
||||||
|
authPaths := map[string]bool{
|
||||||
|
"/authorize": true,
|
||||||
|
"/token": true,
|
||||||
|
"/register": true,
|
||||||
|
"/.well-known/oauth-authorization-server": true,
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(path, "/u/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return authPaths[path]
|
||||||
|
}
|
||||||
|
|
||||||
|
// isMCPPath checks if the path is an MCP path
|
||||||
func isMCPPath(path string, cfg *config.Config) bool {
|
func isMCPPath(path string, cfg *config.Config) bool {
|
||||||
for _, p := range cfg.MCPPaths {
|
for _, p := range cfg.MCPPaths {
|
||||||
if strings.HasPrefix(path, p) {
|
if strings.HasPrefix(path, p) {
|
||||||
|
@ -171,22 +281,10 @@ func isMCPPath(path string, cfg *config.Config) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func copyHeaders(src http.Header, dst http.Header) {
|
func skipHeader(h string) bool {
|
||||||
// Exclude hop-by-hop
|
switch strings.ToLower(h) {
|
||||||
hopByHop := map[string]bool{
|
case "connection", "keep-alive", "transfer-encoding", "upgrade", "proxy-authorization", "proxy-connection", "te", "trailer":
|
||||||
"Connection": true,
|
return true
|
||||||
"Keep-Alive": true,
|
|
||||||
"Transfer-Encoding": true,
|
|
||||||
"Upgrade": true,
|
|
||||||
"Proxy-Authorization": true,
|
|
||||||
"Proxy-Connection": true,
|
|
||||||
}
|
|
||||||
for k, vv := range src {
|
|
||||||
if hopByHop[strings.ToLower(k)] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, v := range vv {
|
|
||||||
dst.Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue