Merge pull request #1 from shashimalcse/proxy_impl

Add OpenMCPAuthProxy
This commit is contained in:
Omindu Rathnaweera 2025-04-02 18:32:05 +05:30 committed by GitHub
commit 43d815769d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 893 additions and 1 deletions

8
.gitignore vendored
View file

@ -22,3 +22,11 @@
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
hs_err_pid*
replay_pid*
# Go module cache files
go.sum
# OS generated files
.DS_Store
openmcpauthproxy

View file

@ -1,2 +1,81 @@
# open-mcp-auth-proxy
Authentication and Authorization Proxy for MCP Servers
## Overview
OpenMCPAuthProxy is a security middleware that implements the Model Context Protocol (MCP) Authorization Specification (2025-03-26). It functions as a proxy between clients and MCP servers, providing robust authentication and authorization capabilities. The proxy intercepts incoming requests, validates authentication tokens, and forwards only authorized requests to the underlying MCP server, enhancing the security posture of your MCP deployment.
## Setup and Installation
### Prerequisites
- Go 1.20 or higher
- A running MCP server (SSE transport supported)
### Installation
```bash
git clone https://github.com/wso2/open-mcp-auth-proxy
cd open-mcp-auth-proxy
go build -o openmcpauthproxy ./cmd/proxy
```
## Configuration
Create a configuration file `config.yaml` with the following parameters:
```yaml
mcp_server_base_url: "http://localhost:8000" # URL of your MCP server
listen_address: ":8080" # Address where the proxy will listen
```
## Usage Example
### 1. Start the MCP Server
Create a file named `echo_server.py`:
```python
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Echo")
@mcp.resource("echo://{message}")
def echo_resource(message: str) -> str:
"""Echo a message as a resource"""
return f"Resource echo: {message}"
@mcp.tool()
def echo_tool(message: str) -> str:
"""Echo a message as a tool"""
return f"Tool echo: {message}"
@mcp.prompt()
def echo_prompt(message: str) -> str:
"""Create an echo prompt"""
return f"Please process this message: {message}"
if __name__ == "__main__":
mcp.run(transport="sse")
```
Run the server:
```bash
python3 echo_server.py
```
### 2. Start the Auth Proxy
```bash
./openmcpauthproxy --demo
```
The `--demo` flag enables a demonstration mode with pre-configured authentication with asgardeo.
### 3. Connect Using an MCP Client
You can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector) to test the connection:
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.

73
cmd/proxy/main.go Normal file
View file

@ -0,0 +1,73 @@
package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/proxy"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
func main() {
demoMode := flag.Bool("demo", false, "Use Asgardeo-based provider (demo).")
flag.Parse()
// 1. Load config
cfg, err := config.LoadConfig("config.yaml")
if err != nil {
log.Fatalf("Error loading config: %v", err)
}
// 2. Create the chosen provider
var provider authz.Provider
if *demoMode {
cfg.AuthServerBaseURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2"
cfg.JWKSURL = "https://api.asgardeo.io/t/" + cfg.Demo.OrgName + "/oauth2/jwks"
provider = authz.NewAsgardeoProvider(cfg)
fmt.Println("Using Asgardeo provider (demo).")
} else {
log.Fatalf("Not supported yet.")
}
// 3. (Optional) Fetch JWKS if you want local JWT validation
if err := util.FetchJWKS(cfg.JWKSURL); err != nil {
log.Fatalf("Failed to fetch JWKS: %v", err)
}
// 4. Build the main router
mux := proxy.NewRouter(cfg, provider)
// 5. Start the server
srv := &http.Server{
Addr: cfg.ListenAddress,
Handler: mux,
}
go func() {
log.Printf("Server listening on %s", cfg.ListenAddress)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server error: %v", err)
}
}()
// 6. Graceful shutdown on Ctrl+C
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt)
<-stop
log.Println("Shutting down...")
shutdownCtx, cancel := proxy.NewShutdownContext(5 * time.Second)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("Shutdown error: %v", err)
}
log.Println("Stopped.")
}

18
config.yaml Normal file
View file

@ -0,0 +1,18 @@
# config.yaml
auth_server_base_url: ""
mcp_server_base_url: "http://localhost:8000"
listen_address: ":8080"
jwks_url: ""
timeout_seconds: 10
mcp_paths:
- /messages/
- /sse
path_mapping:
demo:
org_name: "openmcpauthdemo"
client_id: "N0U9e_NNGr9mP_0fPnPfPI0a6twa"
client_secret: "qFHfiBp5gNGAO9zV4YPnDofBzzfInatfUbHyPZvM0jka"

8
go.mod Normal file
View file

@ -0,0 +1,8 @@
module github.com/wso2/open-mcp-auth-proxy
go 1.22.3
require (
github.com/golang-jwt/jwt/v4 v4.5.2
gopkg.in/yaml.v2 v2.4.0
)

327
internal/authz/asgardeo.go Normal file
View file

@ -0,0 +1,327 @@
package authz
import (
"bytes"
"crypto/tls"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
)
type asgardeoProvider struct {
cfg *config.Config
}
// NewAsgardeoProvider initializes a Provider for Asgardeo (demo mode).
func NewAsgardeoProvider(cfg *config.Config) Provider {
return &asgardeoProvider{cfg: cfg}
}
func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
scheme = forwardedProto
}
host := r.Host
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
baseURL := scheme + "://" + host
issuer := strings.TrimSuffix(p.cfg.AuthServerBaseURL, "/") + "/token"
response := map[string]interface{}{
"issuer": issuer,
"authorization_endpoint": baseURL + "/authorize",
"token_endpoint": baseURL + "/token",
"jwks_uri": p.cfg.JWKSURL,
"response_types_supported": []string{"code"},
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
"registration_endpoint": baseURL + "/register",
"code_challenge_methods_supported": []string{"plain", "S256"},
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
}
func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var regReq RegisterRequest
if err := json.NewDecoder(r.Body).Decode(&regReq); err != nil {
log.Printf("ERROR: reading register request: %v", err)
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if len(regReq.RedirectURIs) == 0 {
http.Error(w, "redirect_uris is required", http.StatusBadRequest)
return
}
// Generate credentials
regReq.ClientID = "client-" + randomString(8)
regReq.ClientSecret = randomString(16)
if err := p.createAsgardeoApplication(regReq); err != nil {
log.Printf("WARN: Asgardeo application creation failed: %v", err)
// Optionally http.Error(...) if you want to fail
// or continue to return partial data.
}
resp := RegisterResponse{
ClientID: regReq.ClientID,
ClientSecret: regReq.ClientSecret,
ClientName: regReq.ClientName,
RedirectURIs: regReq.RedirectURIs,
GrantTypes: regReq.GrantTypes,
ResponseTypes: regReq.ResponseTypes,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Printf("ERROR: encoding /register response: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
}
}
}
// ----------------------------------------------------------------
// Asgardeo-specific helpers
// ----------------------------------------------------------------
type RegisterRequest struct {
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
}
type RegisterResponse struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
}
func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) error {
body := buildAsgardeoPayload(regReq)
reqBytes, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("failed to marshal Asgardeo request: %w", err)
}
asgardeoAppURL := "https://api.asgardeo.io/t/" + p.cfg.Demo.OrgName + "/api/server/v1/applications"
req, err := http.NewRequest("POST", asgardeoAppURL, bytes.NewBuffer(reqBytes))
if err != nil {
return fmt.Errorf("failed to create Asgardeo API request: %w", err)
}
token, err := p.getAsgardeoAdminToken()
if err != nil {
return fmt.Errorf("failed to get Asgardeo admin token: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("asgardeo API call failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
}
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
return nil
}
func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
tokenURL := p.cfg.AuthServerBaseURL + "/token"
formData := "grant_type=client_credentials&scope=internal_application_mgt_create internal_application_mgt_delete " +
"internal_application_mgt_update internal_application_mgt_view"
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(formData))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{
Timeout: time.Duration(p.cfg.TimeoutSeconds) * time.Second,
Transport: tr,
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token request failed (%d): %s", resp.StatusCode, string(body))
}
var tokenResp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to parse token JSON: %w", err)
}
return tokenResp.AccessToken, nil
}
func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
appName := regReq.ClientName
if appName == "" {
appName = "demo-app"
}
appName += "-" + randomString(5)
return map[string]interface{}{
"name": appName,
"templateId": "custom-application-oidc",
"inboundProtocolConfiguration": map[string]interface{}{
"oidc": map[string]interface{}{
"clientId": regReq.ClientID,
"clientSecret": regReq.ClientSecret,
"grantTypes": regReq.GrantTypes,
"callbackURLs": regReq.RedirectURIs,
"allowedOrigins": []string{},
"publicClient": false,
"pkce": map[string]bool{
"mandatory": true,
"supportPlainTransformAlgorithm": true,
},
"accessToken": map[string]interface{}{
"type": "JWT",
"userAccessTokenExpiryInSeconds": 3600,
"applicationAccessTokenExpiryInSeconds": 3600,
"bindingType": "cookie",
"revokeTokensWhenIDPSessionTerminated": true,
"validateTokenBinding": true,
},
"refreshToken": map[string]interface{}{
"expiryInSeconds": 86400,
"renewRefreshToken": true,
},
"idToken": map[string]interface{}{
"expiryInSeconds": 3600,
"audience": []string{},
"encryption": map[string]interface{}{
"enabled": false,
"algorithm": "RSA-OAEP",
"method": "A128CBC+HS256",
},
},
"logout": map[string]interface{}{},
"validateRequestObjectSignature": false,
},
},
"authenticationSequence": map[string]interface{}{
"type": "USER_DEFINED",
"steps": []map[string]interface{}{
{
"id": 1,
"options": []map[string]string{
{
"idp": "Google",
"authenticator": "GoogleOIDCAuthenticator",
},
{
"idp": "GitHub",
"authenticator": "GithubAuthenticator",
},
{
"idp": "Microsoft",
"authenticator": "OpenIDConnectAuthenticator",
},
},
},
},
"script": "var onLoginRequest = function(context) {\n executeStep(1);\n};\n",
"subjectStepId": 1,
"attributeStepId": 1,
},
"advancedConfigurations": map[string]interface{}{
"skipLoginConsent": false,
"skipLogoutConsent": false,
},
}
}
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func randomString(n int) string {
b := make([]byte, n)
for i := 0; i < n; i++ {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}

View file

@ -0,0 +1,10 @@
package authz
import "net/http"
// Provider is an interface describing how each auth provider
// will handle /.well-known/oauth-authorization-server and /register
type Provider interface {
WellKnownHandler() http.HandlerFunc
RegisterHandler() http.HandlerFunc
}

46
internal/config/config.go Normal file
View file

@ -0,0 +1,46 @@
package config
import (
"os"
"gopkg.in/yaml.v2"
)
// AsgardeoConfig groups all Asgardeo-specific fields
type DemoConfig struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
OrgName string `yaml:"org_name"`
}
type Config struct {
AuthServerBaseURL string `yaml:"auth_server_base_url"`
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
ListenAddress string `yaml:"listen_address"`
JWKSURL string `yaml:"jwks_url"`
TimeoutSeconds int `yaml:"timeout_seconds"`
MCPPaths []string `yaml:"mcp_paths"`
PathMapping map[string]string `yaml:"path_mapping"`
// Nested config for Asgardeo
Demo DemoConfig `yaml:"demo"`
}
// LoadConfig reads a YAML config file into Config struct.
func LoadConfig(path string) (*Config, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var cfg Config
decoder := yaml.NewDecoder(f)
if err := decoder.Decode(&cfg); err != nil {
return nil, err
}
if cfg.TimeoutSeconds == 0 {
cfg.TimeoutSeconds = 15 // default
}
return &cfg, nil
}

192
internal/proxy/proxy.go Normal file
View file

@ -0,0 +1,192 @@
package proxy
import (
"context"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/util"
)
// NewRouter builds an http.ServeMux that routes
// * /authorize, /token, /register, /.well-known to the provider or proxy
// * MCP paths to the MCP server, etc.
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
mux := http.NewServeMux()
// 1. Custom well-known
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
// 2. Registration
mux.HandleFunc("/register", provider.RegisterHandler())
// 3. Default "auth" paths, proxied
defaultPaths := []string{"/authorize", "/token"}
for _, path := range defaultPaths {
mux.HandleFunc(path, buildProxyHandler(cfg))
}
// 4. MCP paths
for _, path := range cfg.MCPPaths {
mux.HandleFunc(path, buildProxyHandler(cfg))
}
// 5. If you want to map additional paths from config.PathMapping
// to the same proxy logic:
for path := range cfg.PathMapping {
mux.HandleFunc(path, buildProxyHandler(cfg))
}
return mux
}
func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
// Parse the base URLs up front
authBase, err := url.Parse(cfg.AuthServerBaseURL)
if err != nil {
log.Fatalf("Invalid auth server URL: %v", err)
}
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
if err != nil {
log.Fatalf("Invalid MCP server URL: %v", err)
}
// We'll define sets for known auth paths, SSE paths, etc.
authPaths := map[string]bool{
"/authorize": true,
"/token": true,
"/.well-known/oauth-authorization-server": true,
}
// Detect SSE paths from config
ssePaths := make(map[string]bool)
for _, p := range cfg.MCPPaths {
if p == "/sse" {
ssePaths[p] = true
}
}
return func(w http.ResponseWriter, r *http.Request) {
// Handle OPTIONS
if r.Method == http.MethodOptions {
addCORSHeaders(w)
w.WriteHeader(http.StatusNoContent)
return
}
addCORSHeaders(w)
// Decide whether the request should go to the auth server or MCP
var targetURL *url.URL
isSSE := false
if authPaths[r.URL.Path] {
targetURL = authBase
} else if isMCPPath(r.URL.Path, cfg) {
// Validate JWT if you want
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
targetURL = mcpBase
if ssePaths[r.URL.Path] {
isSSE = true
}
} else {
// If it's not recognized as an auth path or an MCP path
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// Build the reverse proxy
rp := &httputil.ReverseProxy{
Director: func(req *http.Request) {
// Path rewriting if needed
mapped := r.URL.Path
if rewrite, ok := cfg.PathMapping[r.URL.Path]; ok {
mapped = rewrite
}
basePath := strings.TrimRight(targetURL.Path, "/")
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.URL.Path = basePath + mapped
req.URL.RawQuery = r.URL.RawQuery
req.Host = targetURL.Host
for header, values := range r.Header {
// Skip hop-by-hop headers
if strings.EqualFold(header, "Connection") ||
strings.EqualFold(header, "Keep-Alive") ||
strings.EqualFold(header, "Transfer-Encoding") ||
strings.EqualFold(header, "Upgrade") ||
strings.EqualFold(header, "Proxy-Authorization") ||
strings.EqualFold(header, "Proxy-Connection") {
continue
}
for _, value := range values {
req.Header.Set(header, value)
}
}
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
},
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("[proxy] Error proxying: %v", err)
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
},
FlushInterval: -1, // immediate flush for SSE
}
if isSSE {
// Keep SSE connections open
HandleSSE(w, r, rp)
} else {
// Standard requests: enforce a timeout
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(cfg.TimeoutSeconds)*time.Second)
defer cancel()
rp.ServeHTTP(w, r.WithContext(ctx))
}
}
}
func addCORSHeaders(w http.ResponseWriter) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
}
func isMCPPath(path string, cfg *config.Config) bool {
for _, p := range cfg.MCPPaths {
if strings.HasPrefix(path, p) {
return true
}
}
return false
}
func copyHeaders(src http.Header, dst http.Header) {
// Exclude hop-by-hop
hopByHop := map[string]bool{
"Connection": true,
"Keep-Alive": true,
"Transfer-Encoding": true,
"Upgrade": true,
"Proxy-Authorization": true,
"Proxy-Connection": true,
}
for k, vv := range src {
if hopByHop[strings.ToLower(k)] {
continue
}
for _, v := range vv {
dst.Add(k, v)
}
}
}

34
internal/proxy/sse.go Normal file
View file

@ -0,0 +1,34 @@
package proxy
import (
"context"
"log"
"net/http"
"net/http/httputil"
"time"
)
// HandleSSE sets up a go-routine to wait for context cancellation
// and flushes the response if possible.
func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy) {
ctx := r.Context()
done := make(chan struct{})
go func() {
<-ctx.Done()
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
close(done)
}()
rp.ServeHTTP(w, r)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
<-done
}
// NewShutdownContext is a little helper to gracefully shut down
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), timeout)
}

97
internal/util/jwks.go Normal file
View file

@ -0,0 +1,97 @@
package util
import (
"crypto/rsa"
"encoding/json"
"errors"
"log"
"math/big"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
)
type JWKS struct {
Keys []json.RawMessage `json:"keys"`
}
var publicKeys map[string]*rsa.PublicKey
// FetchJWKS downloads JWKS and stores in a package-level map
func FetchJWKS(jwksURL string) error {
resp, err := http.Get(jwksURL)
if err != nil {
return err
}
defer resp.Body.Close()
var jwks JWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return err
}
publicKeys = make(map[string]*rsa.PublicKey)
for _, keyData := range jwks.Keys {
var parsedKey struct {
Kid string `json:"kid"`
N string `json:"n"`
E string `json:"e"`
Kty string `json:"kty"`
}
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
continue
}
if parsedKey.Kty != "RSA" {
continue
}
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
if err == nil {
publicKeys[parsedKey.Kid] = pubKey
}
}
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
return nil
}
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
nBytes, err := jwt.DecodeSegment(nStr)
if err != nil {
return nil, err
}
eBytes, err := jwt.DecodeSegment(eStr)
if err != nil {
return nil, err
}
n := new(big.Int).SetBytes(nBytes)
e := 0
for _, b := range eBytes {
e = e<<8 + int(b)
}
return &rsa.PublicKey{N: n, E: e}, nil
}
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
func ValidateJWT(authHeader string) error {
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
return errors.New("missing or invalid Authorization header")
}
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
kid, _ := token.Header["kid"].(string)
pubKey, ok := publicKeys[kid]
if !ok {
return nil, errors.New("unknown or missing kid in token header")
}
return pubKey, nil
})
if err != nil {
return errors.New("invalid token: " + err.Error())
}
if !token.Valid {
return errors.New("invalid token: token not valid")
}
return nil
}