mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
improve readme
This commit is contained in:
parent
7b727c03a3
commit
4e957e93a2
11 changed files with 889 additions and 1 deletions
323
internal/authz/asgardeo.go
Normal file
323
internal/authz/asgardeo.go
Normal file
|
@ -0,0 +1,323 @@
|
|||
package authz
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
)
|
||||
|
||||
type asgardeoProvider struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAsgardeoProvider initializes a Provider for Asgardeo (demo mode).
|
||||
func NewAsgardeoProvider(cfg *config.Config) Provider {
|
||||
return &asgardeoProvider{cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) WellKnownHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
|
||||
scheme = forwardedProto
|
||||
}
|
||||
host := r.Host
|
||||
if forwardedHost := r.Header.Get("X-Forwarded-Host"); forwardedHost != "" {
|
||||
host = forwardedHost
|
||||
}
|
||||
|
||||
baseURL := scheme + "://" + host
|
||||
|
||||
issuer := strings.TrimSuffix(p.cfg.AuthServerBaseURL, "/") + "/token"
|
||||
|
||||
response := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": baseURL + "/authorize",
|
||||
"token_endpoint": baseURL + "/token",
|
||||
"jwks_uri": p.cfg.JWKSURL,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"token_endpoint_auth_methods_supported": []string{"client_secret_basic"},
|
||||
"registration_endpoint": baseURL + "/register",
|
||||
"code_challenge_methods_supported": []string{"plain", "S256"},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
log.Printf("[asgardeoProvider] Error encoding well-known: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) RegisterHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var regReq RegisterRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(®Req); err != nil {
|
||||
log.Printf("ERROR: reading register request: %v", err)
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(regReq.RedirectURIs) == 0 {
|
||||
http.Error(w, "redirect_uris is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate credentials
|
||||
regReq.ClientID = "client-" + randomString(8)
|
||||
regReq.ClientSecret = randomString(16)
|
||||
|
||||
if err := p.createAsgardeoApplication(regReq); err != nil {
|
||||
log.Printf("WARN: Asgardeo application creation failed: %v", err)
|
||||
// Optionally http.Error(...) if you want to fail
|
||||
// or continue to return partial data.
|
||||
}
|
||||
|
||||
resp := RegisterResponse{
|
||||
ClientID: regReq.ClientID,
|
||||
ClientSecret: regReq.ClientSecret,
|
||||
ClientName: regReq.ClientName,
|
||||
RedirectURIs: regReq.RedirectURIs,
|
||||
GrantTypes: regReq.GrantTypes,
|
||||
ResponseTypes: regReq.ResponseTypes,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
log.Printf("ERROR: encoding /register response: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
// Asgardeo-specific helpers
|
||||
// ----------------------------------------------------------------
|
||||
|
||||
type RegisterRequest struct {
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
}
|
||||
|
||||
type RegisterResponse struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
ClientName string `json:"client_name"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) createAsgardeoApplication(regReq RegisterRequest) error {
|
||||
body := buildAsgardeoPayload(regReq)
|
||||
reqBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal Asgardeo request: %w", err)
|
||||
}
|
||||
|
||||
asgardeoAppURL := "https://api.asgardeo.io/t/" + p.cfg.Demo.OrgName + "/api/server/v1/applications"
|
||||
req, err := http.NewRequest("POST", asgardeoAppURL, bytes.NewBuffer(reqBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Asgardeo API request: %w", err)
|
||||
}
|
||||
|
||||
token, err := p.getAsgardeoAdminToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get Asgardeo admin token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("asgardeo API call failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("Asgardeo creation error (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
log.Printf("INFO: Created Asgardeo application for clientID=%s", regReq.ClientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *asgardeoProvider) getAsgardeoAdminToken() (string, error) {
|
||||
tokenURL := p.cfg.AuthServerBaseURL + "/token"
|
||||
|
||||
formData := "grant_type=client_credentials&scope=internal_application_mgt_create internal_application_mgt_delete " +
|
||||
"internal_application_mgt_update internal_application_mgt_view"
|
||||
|
||||
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
auth := p.cfg.Demo.ClientID + ":" + p.cfg.Demo.ClientSecret
|
||||
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(p.cfg.TimeoutSeconds) * time.Second,
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token request failed (%d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse token JSON: %w", err)
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
func buildAsgardeoPayload(regReq RegisterRequest) map[string]interface{} {
|
||||
appName := regReq.ClientName
|
||||
if appName == "" {
|
||||
appName = "demo-app"
|
||||
}
|
||||
appName += "-" + randomString(5)
|
||||
|
||||
return map[string]interface{}{
|
||||
"name": appName,
|
||||
"templateId": "custom-application-oidc",
|
||||
"inboundProtocolConfiguration": map[string]interface{}{
|
||||
"oidc": map[string]interface{}{
|
||||
"clientId": regReq.ClientID,
|
||||
"clientSecret": regReq.ClientSecret,
|
||||
"grantTypes": regReq.GrantTypes,
|
||||
"callbackURLs": regReq.RedirectURIs,
|
||||
"allowedOrigins": []string{},
|
||||
"publicClient": false,
|
||||
"pkce": map[string]bool{
|
||||
"mandatory": true,
|
||||
"supportPlainTransformAlgorithm": true,
|
||||
},
|
||||
"accessToken": map[string]interface{}{
|
||||
"type": "JWT",
|
||||
"userAccessTokenExpiryInSeconds": 3600,
|
||||
"applicationAccessTokenExpiryInSeconds": 3600,
|
||||
"bindingType": "cookie",
|
||||
"revokeTokensWhenIDPSessionTerminated": true,
|
||||
"validateTokenBinding": true,
|
||||
},
|
||||
"refreshToken": map[string]interface{}{
|
||||
"expiryInSeconds": 86400,
|
||||
"renewRefreshToken": true,
|
||||
},
|
||||
"idToken": map[string]interface{}{
|
||||
"expiryInSeconds": 3600,
|
||||
"audience": []string{},
|
||||
"encryption": map[string]interface{}{
|
||||
"enabled": false,
|
||||
"algorithm": "RSA-OAEP",
|
||||
"method": "A128CBC+HS256",
|
||||
},
|
||||
},
|
||||
"logout": map[string]interface{}{},
|
||||
"validateRequestObjectSignature": false,
|
||||
},
|
||||
},
|
||||
"authenticationSequence": map[string]interface{}{
|
||||
"type": "USER_DEFINED",
|
||||
"steps": []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"options": []map[string]string{
|
||||
{
|
||||
"idp": "Google",
|
||||
"authenticator": "GoogleOIDCAuthenticator",
|
||||
},
|
||||
{
|
||||
"idp": "GitHub",
|
||||
"authenticator": "GithubAuthenticator",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"script": "var onLoginRequest = function(context) {\n executeStep(1);\n};\n",
|
||||
"subjectStepId": 1,
|
||||
"attributeStepId": 1,
|
||||
},
|
||||
"advancedConfigurations": map[string]interface{}{
|
||||
"skipLoginConsent": false,
|
||||
"skipLogoutConsent": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
10
internal/authz/provider.go
Normal file
10
internal/authz/provider.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package authz
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Provider is an interface describing how each auth provider
|
||||
// will handle /.well-known/oauth-authorization-server and /register
|
||||
type Provider interface {
|
||||
WellKnownHandler() http.HandlerFunc
|
||||
RegisterHandler() http.HandlerFunc
|
||||
}
|
46
internal/config/config.go
Normal file
46
internal/config/config.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// AsgardeoConfig groups all Asgardeo-specific fields
|
||||
type DemoConfig struct {
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
OrgName string `yaml:"org_name"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
AuthServerBaseURL string `yaml:"auth_server_base_url"`
|
||||
MCPServerBaseURL string `yaml:"mcp_server_base_url"`
|
||||
ListenAddress string `yaml:"listen_address"`
|
||||
JWKSURL string `yaml:"jwks_url"`
|
||||
TimeoutSeconds int `yaml:"timeout_seconds"`
|
||||
MCPPaths []string `yaml:"mcp_paths"`
|
||||
PathMapping map[string]string `yaml:"path_mapping"`
|
||||
|
||||
// Nested config for Asgardeo
|
||||
Demo DemoConfig `yaml:"demo"`
|
||||
}
|
||||
|
||||
// LoadConfig reads a YAML config file into Config struct.
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var cfg Config
|
||||
decoder := yaml.NewDecoder(f)
|
||||
if err := decoder.Decode(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg.TimeoutSeconds == 0 {
|
||||
cfg.TimeoutSeconds = 15 // default
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
192
internal/proxy/proxy.go
Normal file
192
internal/proxy/proxy.go
Normal file
|
@ -0,0 +1,192 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/authz"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/util"
|
||||
)
|
||||
|
||||
// NewRouter builds an http.ServeMux that routes
|
||||
// * /authorize, /token, /register, /.well-known to the provider or proxy
|
||||
// * MCP paths to the MCP server, etc.
|
||||
func NewRouter(cfg *config.Config, provider authz.Provider) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// 1. Custom well-known
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", provider.WellKnownHandler())
|
||||
|
||||
// 2. Registration
|
||||
mux.HandleFunc("/register", provider.RegisterHandler())
|
||||
|
||||
// 3. Default "auth" paths, proxied
|
||||
defaultPaths := []string{"/authorize", "/token"}
|
||||
for _, path := range defaultPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
// 4. MCP paths
|
||||
for _, path := range cfg.MCPPaths {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
// 5. If you want to map additional paths from config.PathMapping
|
||||
// to the same proxy logic:
|
||||
for path := range cfg.PathMapping {
|
||||
mux.HandleFunc(path, buildProxyHandler(cfg))
|
||||
}
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
func buildProxyHandler(cfg *config.Config) http.HandlerFunc {
|
||||
// Parse the base URLs up front
|
||||
authBase, err := url.Parse(cfg.AuthServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid auth server URL: %v", err)
|
||||
}
|
||||
mcpBase, err := url.Parse(cfg.MCPServerBaseURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid MCP server URL: %v", err)
|
||||
}
|
||||
|
||||
// We'll define sets for known auth paths, SSE paths, etc.
|
||||
authPaths := map[string]bool{
|
||||
"/authorize": true,
|
||||
"/token": true,
|
||||
"/.well-known/oauth-authorization-server": true,
|
||||
}
|
||||
|
||||
// Detect SSE paths from config
|
||||
ssePaths := make(map[string]bool)
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if p == "/sse" {
|
||||
ssePaths[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle OPTIONS
|
||||
if r.Method == http.MethodOptions {
|
||||
addCORSHeaders(w)
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
addCORSHeaders(w)
|
||||
|
||||
// Decide whether the request should go to the auth server or MCP
|
||||
var targetURL *url.URL
|
||||
isSSE := false
|
||||
|
||||
if authPaths[r.URL.Path] {
|
||||
targetURL = authBase
|
||||
} else if isMCPPath(r.URL.Path, cfg) {
|
||||
// Validate JWT if you want
|
||||
if err := util.ValidateJWT(r.Header.Get("Authorization")); err != nil {
|
||||
log.Printf("[proxy] Unauthorized request to %s: %v", r.URL.Path, err)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
targetURL = mcpBase
|
||||
if ssePaths[r.URL.Path] {
|
||||
isSSE = true
|
||||
}
|
||||
} else {
|
||||
// If it's not recognized as an auth path or an MCP path
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the reverse proxy
|
||||
rp := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
// Path rewriting if needed
|
||||
mapped := r.URL.Path
|
||||
if rewrite, ok := cfg.PathMapping[r.URL.Path]; ok {
|
||||
mapped = rewrite
|
||||
}
|
||||
basePath := strings.TrimRight(targetURL.Path, "/")
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = basePath + mapped
|
||||
req.URL.RawQuery = r.URL.RawQuery
|
||||
req.Host = targetURL.Host
|
||||
|
||||
for header, values := range r.Header {
|
||||
// Skip hop-by-hop headers
|
||||
if strings.EqualFold(header, "Connection") ||
|
||||
strings.EqualFold(header, "Keep-Alive") ||
|
||||
strings.EqualFold(header, "Transfer-Encoding") ||
|
||||
strings.EqualFold(header, "Upgrade") ||
|
||||
strings.EqualFold(header, "Proxy-Authorization") ||
|
||||
strings.EqualFold(header, "Proxy-Connection") {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
log.Printf("[proxy] %s -> %s%s", r.URL.Path, req.URL.Host, req.URL.Path)
|
||||
},
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("[proxy] Error proxying: %v", err)
|
||||
http.Error(rw, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
FlushInterval: -1, // immediate flush for SSE
|
||||
}
|
||||
|
||||
if isSSE {
|
||||
// Keep SSE connections open
|
||||
HandleSSE(w, r, rp)
|
||||
} else {
|
||||
// Standard requests: enforce a timeout
|
||||
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(cfg.TimeoutSeconds)*time.Second)
|
||||
defer cancel()
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addCORSHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
}
|
||||
|
||||
func isMCPPath(path string, cfg *config.Config) bool {
|
||||
for _, p := range cfg.MCPPaths {
|
||||
if strings.HasPrefix(path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func copyHeaders(src http.Header, dst http.Header) {
|
||||
// Exclude hop-by-hop
|
||||
hopByHop := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
"Proxy-Authorization": true,
|
||||
"Proxy-Connection": true,
|
||||
}
|
||||
for k, vv := range src {
|
||||
if hopByHop[strings.ToLower(k)] {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
34
internal/proxy/sse.go
Normal file
34
internal/proxy/sse.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HandleSSE sets up a go-routine to wait for context cancellation
|
||||
// and flushes the response if possible.
|
||||
func HandleSSE(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy) {
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
log.Printf("INFO: SSE connection closed from %s (path: %s)", r.RemoteAddr, r.URL.Path)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
rp.ServeHTTP(w, r)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
||||
|
||||
// NewShutdownContext is a little helper to gracefully shut down
|
||||
func NewShutdownContext(timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), timeout)
|
||||
}
|
97
internal/util/jwks.go
Normal file
97
internal/util/jwks.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
type JWKS struct {
|
||||
Keys []json.RawMessage `json:"keys"`
|
||||
}
|
||||
|
||||
var publicKeys map[string]*rsa.PublicKey
|
||||
|
||||
// FetchJWKS downloads JWKS and stores in a package-level map
|
||||
func FetchJWKS(jwksURL string) error {
|
||||
resp, err := http.Get(jwksURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var jwks JWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKeys = make(map[string]*rsa.PublicKey)
|
||||
for _, keyData := range jwks.Keys {
|
||||
var parsedKey struct {
|
||||
Kid string `json:"kid"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
Kty string `json:"kty"`
|
||||
}
|
||||
if err := json.Unmarshal(keyData, &parsedKey); err != nil {
|
||||
continue
|
||||
}
|
||||
if parsedKey.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
pubKey, err := parseRSAPublicKey(parsedKey.N, parsedKey.E)
|
||||
if err == nil {
|
||||
publicKeys[parsedKey.Kid] = pubKey
|
||||
}
|
||||
}
|
||||
log.Printf("[JWKS] Loaded %d public keys.", len(publicKeys))
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRSAPublicKey(nStr, eStr string) (*rsa.PublicKey, error) {
|
||||
nBytes, err := jwt.DecodeSegment(nStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eBytes, err := jwt.DecodeSegment(eStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
|
||||
return &rsa.PublicKey{N: n, E: e}, nil
|
||||
}
|
||||
|
||||
// ValidateJWT checks the Authorization: Bearer token using stored JWKS
|
||||
func ValidateJWT(authHeader string) error {
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return errors.New("missing or invalid Authorization header")
|
||||
}
|
||||
tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
pubKey, ok := publicKeys[kid]
|
||||
if !ok {
|
||||
return nil, errors.New("unknown or missing kid in token header")
|
||||
}
|
||||
return pubKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("invalid token: " + err.Error())
|
||||
}
|
||||
if !token.Valid {
|
||||
return errors.New("invalid token: token not valid")
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue