mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Merge pull request #1 from shashimalcse/proxy_impl
Add OpenMCPAuthProxy
This commit is contained in:
commit
43d815769d
11 changed files with 893 additions and 1 deletions
8
.gitignore
vendored
8
.gitignore
vendored
|
@ -22,3 +22,11 @@
|
|||
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
|
||||
hs_err_pid*
|
||||
replay_pid*
|
||||
|
||||
# Go module cache files
|
||||
go.sum
|
||||
|
||||
# OS generated files
|
||||
.DS_Store
|
||||
|
||||
openmcpauthproxy
|
||||
|
|
81
README.md
81
README.md
|
@ -1,2 +1,81 @@
|
|||
# open-mcp-auth-proxy
|
||||
Authentication and Authorization Proxy for MCP Servers
|
||||
|
||||
## Overview
|
||||
|
||||
OpenMCPAuthProxy is a security middleware that implements the Model Context Protocol (MCP) Authorization Specification (2025-03-26). It functions as a proxy between clients and MCP servers, providing robust authentication and authorization capabilities. The proxy intercepts incoming requests, validates authentication tokens, and forwards only authorized requests to the underlying MCP server, enhancing the security posture of your MCP deployment.
|
||||
|
||||
## Setup and Installation
|
||||
|
||||
### Prerequisites
|
||||
- Go 1.20 or higher
|
||||
- A running MCP server (SSE transport supported)
|
||||
|
||||
### Installation
|
||||
```bash
|
||||
git clone https://github.com/wso2/open-mcp-auth-proxy
|
||||
cd open-mcp-auth-proxy
|
||||
go build -o openmcpauthproxy ./cmd/proxy
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Create a configuration file `config.yaml` with the following parameters:
|
||||
|
||||
```yaml
|
||||
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
|
||||
listen_address: ":8080" # Address where the proxy will listen
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
### 1. Start the MCP Server
|
||||
|
||||
Create a file named `echo_server.py`:
|
||||
|
||||
```python
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
mcp = FastMCP("Echo")
|
||||
|
||||
|
||||
@mcp.resource("echo://{message}")
|
||||
def echo_resource(message: str) -> str:
|
||||
"""Echo a message as a resource"""
|
||||
return f"Resource echo: {message}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def echo_tool(message: str) -> str:
|
||||
"""Echo a message as a tool"""
|
||||
return f"Tool echo: {message}"
|
||||
|
||||
|
||||
@mcp.prompt()
|
||||
def echo_prompt(message: str) -> str:
|
||||
"""Create an echo prompt"""
|
||||
return f"Please process this message: {message}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run(transport="sse")
|
||||
```
|
||||
|
||||
Run the server:
|
||||
```bash
|
||||
python3 echo_server.py
|
||||
```
|
||||
|
||||
### 2. Start the Auth Proxy
|
||||
|
||||
```bash
|
||||
./openmcpauthproxy --demo
|
||||
```
|
||||
|
||||
The `--demo` flag enables a demonstration mode with pre-configured authentication with asgardeo.
|
||||
|
||||
### 3. Connect Using an MCP Client
|
||||
|
||||
You can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) to test the connection:
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
|
73
cmd/proxy/main.go
Normal file
73
cmd/proxy/main.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
func main() {
|
||||
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
|
||||
flag.Parse()
|
||||
|
||||
// 1. Load config
|
||||
cfg, err := config.LoadConfig("config.yaml")
|
||||
if err != nil {
|
||||
log.Fatalf("Error loading config: %v", err)
|
||||
}
|
||||
|
||||
// 2. Create the chosen provider
|
||||
var provider authz.Provider
|
||||
if *demoMode {
|
||||
cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2"
|
||||
cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks"
|
||||
provider = authz.NewAsgardeoProvider(cfg)
|
||||
fmt.Println("Using Asgardeo provider (demo).")
|
||||
} else {
|
||||
log.Fatalf("Not supported yet.")
|
||||
}
|
||||
|
||||
// 3. (Optional) Fetch JWKS if you want local JWT validation
|
||||
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
|
||||
log.Fatalf("Failed to fetch JWKS: %v", err)
|
||||
}
|
||||
|
||||
// 4. Build the main router
|
||||
mux := proxy.NewRouter(cfg, provider)
|
||||
|
||||
// 5. Start the server
|
||||
srv := &http.Server{
|
||||
Addr: cfg.ListenAddress,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("Server listening on %s", cfg.ListenAddress)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 6. Graceful shutdown on Ctrl+C
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, os.Interrupt)
|
||||
<-stop
|
||||
log.Println("Shutting down...")
|
||||
|
||||
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("Shutdown error: %v", err)
|
||||
}
|
||||
log.Println("Stopped.")
|
||||
}
|
18
config.yaml
Normal file
18
config.yaml
Normal file
|
@ -0,0 +1,18 @@
|
|||
# config.yaml
|
||||
|
||||
auth_server_base_url: ""
|
||||
mcp_server_base_url: "http://localhost:8000"
|
||||
listen_address: ":8080"
|
||||
jwks_url: ""
|
||||
timeout_seconds: 10
|
||||
|
||||
mcp_paths:
|
||||
- /messages/
|
||||
- /sse
|
||||
|
||||
path_mapping:
|
||||
|
||||
demo:
|
||||
org_name: "openmcpauthdemo"
|
||||
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
|
||||
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"
|
8
go.mod
Normal file
8
go.mod
Normal file
|
@ -0,0 +1,8 @@
|
|||
module github.com/wso2/open-mcp-auth-proxy
|
||||
|
||||
go 1.22.3
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
327
internal/authz/asgardeo.go
Normal file
327
internal/authz/asgardeo.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package authz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
)
|
||||
|
||||
type asgardeoProvider struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAsgardeoProvider initializes a Provider for Asgardeo (demo mode).
|
||||
func NewAsgardeoProvider(cfg *config.Config) Provider {
|
||||
return &asgardeoProvider{cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
|
||||
scheme = forwardedProto
|
||||
}
|
||||
host := r.Host
|
||||
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
|
||||
baseURL := scheme + "://" + host
|
||||
|
||||
issuer := strings.TrimSuffix(p.cfg.AuthServerBaseURL, "/") + "/token"
|
||||
|
||||
response := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": baseURL + "/authorize",
|
||||
"token_endpoint": baseURL + "/token",
|
||||
"jwks_uri": p.cfg.JWKSURL,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
|
||||
"registration_endpoint": baseURL + "/register",
|
||||
"code_challenge_methods_supported": []string{"plain", "S256"},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var regReq RegisterRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
||||
log.Printf("ERROR: reading register request: %v", err)
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(regReq.RedirectURIs) == 0 {
|
||||
http.Error(w, "redirect_uris is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate credentials
|
||||
regReq.ClientID = "client-" + randomString(8)
|
||||
regReq.ClientSecret = randomString(16)
|
||||
|
||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||
log.Printf("WARN: Asgardeo application creation failed: %v", err)
|
||||
// Optionally http.Error(...) if you want to fail
|
||||
// or continue to return partial data.
|
||||
}
|
||||
|
||||
resp := RegisterResponse{
|
||||
ClientID: regReq.ClientID,
|
||||
ClientSecret: regReq.ClientSecret,
|
||||
ClientName: regReq.ClientName,
|
||||
RedirectURIs: regReq.RedirectURIs,
|
||||
GrantTypes: regReq.GrantTypes,
|
||||
ResponseTypes: regReq.ResponseTypes,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
log.Printf("ERROR: encoding /register response: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
// Asgardeo-specific helpers
|
||||
// ----------------------------------------------------------------
|
||||
|
||||
type RegisterRequest struct {
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) error {
|
||||
body := buildAsgardeoPayload(regReq)
|
||||
reqBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal Asgardeo request: %w", err)
|
||||
}
|
||||
|
||||
asgardeoAppURL := "https://api.asgardeo.io/t/" + p.cfg.Demo.OrgName + "/api/server/v1/applications"
|
||||
req, err := http.NewRequest("POST", asgardeoAppURL, bytes.NewBuffer(reqBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Asgardeo API request: %w", err)
|
||||
}
|
||||
|
||||
token, err := p.getAsgardeoAdminToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get Asgardeo admin token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("asgardeo API call failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||
tokenURL := p.cfg.AuthServerBaseURL + "/token"
|
||||
|
||||
formData := "grant_type=client_credentials&scope=internal_application_mgt_create internal_application_mgt_delete " +
|
||||
"internal_application_mgt_update internal_application_mgt_view"
|
||||
|
||||
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(p.cfg.TimeoutSeconds) * time.Second,
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token request failed (%d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
|
||||
appName := regReq.ClientName
|
||||
if appName == "" {
|
||||
appName = "demo-app"
|
||||
}
|
||||
appName += "-" + randomString(5)
|
||||
|
||||
return map[string]interface{}{
|
||||
"name": appName,
|
||||
"templateId": "custom-application-oidc",
|
||||
"inboundProtocolConfiguration": map[string]interface{}{
|
||||
"oidc": map[string]interface{}{
|
||||
"clientId": regReq.ClientID,
|
||||
"clientSecret": regReq.ClientSecret,
|
||||
"grantTypes": regReq.GrantTypes,
|
||||
"callbackURLs": regReq.RedirectURIs,
|
||||
"allowedOrigins": []string{},
|
||||
"publicClient": false,
|
||||
"pkce": map[string]bool{
|
||||
"mandatory": true,
|
||||
"supportPlainTransformAlgorithm": true,
|
||||
},
|
||||
"accessToken": map[string]interface{}{
|
||||
"type": "JWT",
|
||||
"userAccessTokenExpiryInSeconds": 3600,
|
||||
"applicationAccessTokenExpiryInSeconds": 3600,
|
||||
"bindingType": "cookie",
|
||||
"revokeTokensWhenIDPSessionTerminated": true,
|
||||
"validateTokenBinding": true,
|
||||
},
|
||||
"refreshToken": map[string]interface{}{
|
||||
"expiryInSeconds": 86400,
|
||||
"renewRefreshToken": true,
|
||||
},
|
||||
"idToken": map[string]interface{}{
|
||||
"expiryInSeconds": 3600,
|
||||
"audience": []string{},
|
||||
"encryption": map[string]interface{}{
|
||||
"enabled": false,
|
||||
"algorithm": "RSA-OAEP",
|
||||
"method": "A128CBC+HS256",
|
||||
},
|
||||
},
|
||||
"logout": map[string]interface{}{},
|
||||
"validateRequestObjectSignature": false,
|
||||
},
|
||||
},
|
||||
"authenticationSequence": map[string]interface{}{
|
||||
"type": "USER_DEFINED",
|
||||
"steps": []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"options": []map[string]string{
|
||||
{
|
||||
"idp": "Google",
|
||||
"authenticator": "GoogleOIDCAuthenticator",
|
||||
},
|
||||
{
|
||||
"idp": "GitHub",
|
||||
"authenticator": "GithubAuthenticator",
|
||||
},
|
||||
{
|
||||
"idp": "Microsoft",
|
||||
"authenticator": "OpenIDConnectAuthenticator",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"script": "var onLoginRequest = function(context) {\n executeStep(1);\n};\n",
|
||||
"subjectStepId": 1,
|
||||
"attributeStepId": 1,
|
||||
},
|
||||
"advancedConfigurations": map[string]interface{}{
|
||||
"skipLoginConsent": false,
|
||||
"skipLogoutConsent": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
10
internal/authz/provider.go
Normal file
10
internal/authz/provider.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package authz
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Provider is an interface describing how each auth provider
|
||||
// will handle /.well-known/oauth-authorization-server and /register
|
||||
type Provider interface {
|
||||
WellKnownHandler() http.HandlerFunc
|
||||
RegisterHandler() http.HandlerFunc
|
||||
}
|
46
internal/config/config.go
Normal file
46
internal/config/config.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// AsgardeoConfig groups all Asgardeo-specific fields
|
||||
type DemoConfig struct {
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
OrgName string `yaml:"org_name"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AuthServerBaseURL string `yaml:"auth_server_base_url"`
|
||||
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
||||
ListenAddress string `yaml:"listen_address"`
|
||||
JWKSURL string `yaml:"jwks_url"`
|
||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||
MCPPaths []string `yaml:"mcp_paths"`
|
||||
PathMapping map[string]string `yaml:"path_mapping"`
|
||||
|
||||
// Nested config for Asgardeo
|
||||
Demo DemoConfig `yaml:"demo"`
|
||||
}
|
||||
|
||||
// LoadConfig reads a YAML config file into Config struct.
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var cfg Config
|
||||
decoder := yaml.NewDecoder(f)
|
||||
if err := decoder.Decode(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.TimeoutSeconds == 0 {
|
||||
cfg.TimeoutSeconds = 15 // default
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
192
internal/proxy/proxy.go
Normal file
192
internal/proxy/proxy.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
// NewRouter builds an http.ServeMux that routes
|
||||
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
||||
// * MCP paths to the MCP server, etc.
|
||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 1. Custom well-known
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||
|
||||
// 2. Registration
|
||||
mux.HandleFunc("/register", provider.RegisterHandler())
|
||||
|
||||
// 3. Default "auth" paths, proxied
|
||||
defaultPaths := []string{"/authorize", "/token"}
|
||||
for _, path := range defaultPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
// 4. MCP paths
|
||||
for _, path := range cfg.MCPPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
// 5. If you want to map additional paths from config.PathMapping
|
||||
// to the same proxy logic:
|
||||
for path := range cfg.PathMapping {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||
// Parse the base URLs up front
|
||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid auth server URL: %v", err)
|
||||
}
|
||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||
}
|
||||
|
||||
// We'll define sets for known auth paths, SSE paths, etc.
|
||||
authPaths := map[string]bool{
|
||||
"/authorize": true,
|
||||
"/token": true,
|
||||
"/.well-known/oauth-authorization-server": true,
|
||||
}
|
||||
|
||||
// Detect SSE paths from config
|
||||
ssePaths := make(map[string]bool)
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if p == "/sse" {
|
||||
ssePaths[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle OPTIONS
|
||||
if r.Method == http.MethodOptions {
|
||||
addCORSHeaders(w)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
addCORSHeaders(w)
|
||||
|
||||
// Decide whether the request should go to the auth server or MCP
|
||||
var targetURL *url.URL
|
||||
isSSE := false
|
||||
|
||||
if authPaths[r.URL.Path] {
|
||||
targetURL = authBase
|
||||
} else if isMCPPath(r.URL.Path, cfg) {
|
||||
// Validate JWT if you want
|
||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
targetURL = mcpBase
|
||||
if ssePaths[r.URL.Path] {
|
||||
isSSE = true
|
||||
}
|
||||
} else {
|
||||
// If it's not recognized as an auth path or an MCP path
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the reverse proxy
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
// Path rewriting if needed
|
||||
mapped := r.URL.Path
|
||||
if rewrite, ok := cfg.PathMapping[r.URL.Path]; ok {
|
||||
mapped = rewrite
|
||||
}
|
||||
basePath := strings.TrimRight(targetURL.Path, "/")
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = basePath + mapped
|
||||
req.URL.RawQuery = r.URL.RawQuery
|
||||
req.Host = targetURL.Host
|
||||
|
||||
for header, values := range r.Header {
|
||||
// Skip hop-by-hop headers
|
||||
if strings.EqualFold(header, "Connection") ||
|
||||
strings.EqualFold(header, "Keep-Alive") ||
|
||||
strings.EqualFold(header, "Transfer-Encoding") ||
|
||||
strings.EqualFold(header, "Upgrade") ||
|
||||
strings.EqualFold(header, "Proxy-Authorization") ||
|
||||
strings.EqualFold(header, "Proxy-Connection") {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[proxy] Error proxying: %v", err)
|
||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
FlushInterval: -1, // immediate flush for SSE
|
||||
}
|
||||
|
||||
if isSSE {
|
||||
// Keep SSE connections open
|
||||
HandleSSE(w, r, rp)
|
||||
} else {
|
||||
// Standard requests: enforce a timeout
|
||||
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(cfg.TimeoutSeconds)*time.Second)
|
||||
defer cancel()
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addCORSHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
}
|
||||
|
||||
func isMCPPath(path string, cfg *config.Config) bool {
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func copyHeaders(src http.Header, dst http.Header) {
|
||||
// Exclude hop-by-hop
|
||||
hopByHop := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
"Proxy-Authorization": true,
|
||||
"Proxy-Connection": true,
|
||||
}
|
||||
for k, vv := range src {
|
||||
if hopByHop[strings.ToLower(k)] {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
34
internal/proxy/sse.go
Normal file
34
internal/proxy/sse.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HandleSSE sets up a go-routine to wait for context cancellation
|
||||
// and flushes the response if possible.
|
||||
func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy) {
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
rp.ServeHTTP(w, r)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// NewShutdownContext is a little helper to gracefully shut down
|
||||
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), timeout)
|
||||
}
|
97
internal/util/jwks.go
Normal file
97
internal/util/jwks.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
type JWKS struct {
|
||||
Keys []json.RawMessage `json:"keys"`
|
||||
}
|
||||
|
||||
var publicKeys map[string]*rsa.PublicKey
|
||||
|
||||
// FetchJWKS downloads JWKS and stores in a package-level map
|
||||
func FetchJWKS(jwksURL string) error {
|
||||
resp, err := http.Get(jwksURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var jwks JWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKeys = make(map[string]*rsa.PublicKey)
|
||||
for _, keyData := range jwks.Keys {
|
||||
var parsedKey struct {
|
||||
Kid string `json:"kid"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Kty string `json:"kty"`
|
||||
}
|
||||
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
|
||||
continue
|
||||
}
|
||||
if parsedKey.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
|
||||
if err == nil {
|
||||
publicKeys[parsedKey.Kid] = pubKey
|
||||
}
|
||||
}
|
||||
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||
nBytes, err := jwt.DecodeSegment(nStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eBytes, err := jwt.DecodeSegment(eStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
|
||||
return &rsa.PublicKey{N: n, E: e}, nil
|
||||
}
|
||||
|
||||
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
|
||||
func ValidateJWT(authHeader string) error {
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return errors.New("missing or invalid Authorization header")
|
||||
}
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
pubKey, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, errors.New("unknown or missing kid in token header")
|
||||
}
|
||||
return pubKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("invalid token: " + err.Error())
|
||||
}
|
||||
if !token.Valid {
|
||||
return errors.New("invalid token: token not valid")
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue