Merge pull request #16 from shashimalcse/shashimalcse-patch-0003
This commit is contained in:
commit
11fc3cdfcd
6 changed files with 692 additions and 0 deletions
8
.gitignore
vendored
8
.gitignore
vendored
|
@ -30,4 +30,12 @@ go.sum
|
|||
# OS generated files
|
||||
.DS_Store
|
||||
|
||||
# builds
|
||||
openmcpauthproxy
|
||||
|
||||
# test out files
|
||||
coverage.out
|
||||
coverage.html
|
||||
|
||||
# IDE files
|
||||
.vscode
|
||||
|
|
73
Makefile
Normal file
73
Makefile
Normal file
|
@ -0,0 +1,73 @@
|
|||
# Makefile for open-mcp-auth-proxy
|
||||
|
||||
# Variables
|
||||
BINARY_NAME := openmcpauthproxy
|
||||
GO := go
|
||||
GOFMT := gofmt
|
||||
GOVET := go vet
|
||||
GOTEST := go test
|
||||
GOLINT := golangci-lint
|
||||
GOCOV := go tool cover
|
||||
BUILD_DIR := build
|
||||
|
||||
# Source files
|
||||
SRC := $(shell find . -name "*.go" -not -path "./vendor/*")
|
||||
PKGS := $(shell go list ./... | grep -v /vendor/)
|
||||
|
||||
# Set build options
|
||||
BUILD_OPTS := -v
|
||||
|
||||
# Set test options
|
||||
TEST_OPTS := -v -race
|
||||
|
||||
.PHONY: all build clean test fmt lint vet coverage help
|
||||
|
||||
# Default target
|
||||
all: lint test build
|
||||
|
||||
# Build the application
|
||||
build:
|
||||
@echo "Building $(BINARY_NAME)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
$(GO) build $(BUILD_OPTS) -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/proxy
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning build artifacts..."
|
||||
@rm -rf $(BUILD_DIR)
|
||||
@rm -f coverage.out
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
$(GOTEST) $(TEST_OPTS) ./...
|
||||
|
||||
# Run tests with coverage report
|
||||
coverage:
|
||||
@echo "Running tests with coverage..."
|
||||
@$(GOTEST) -coverprofile=coverage.out ./...
|
||||
@$(GOCOV) -func=coverage.out
|
||||
@$(GOCOV) -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated in coverage.html"
|
||||
|
||||
# Run gofmt
|
||||
fmt:
|
||||
@echo "Running gofmt..."
|
||||
@$(GOFMT) -w -s $(SRC)
|
||||
|
||||
# Run go vet
|
||||
vet:
|
||||
@echo "Running go vet..."
|
||||
@$(GOVET) ./...
|
||||
|
||||
# Show help
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " all : Run lint, test, and build"
|
||||
@echo " build : Build the application"
|
||||
@echo " clean : Clean build artifacts"
|
||||
@echo " test : Run tests"
|
||||
@echo " coverage : Run tests with coverage report"
|
||||
@echo " fmt : Run gofmt"
|
||||
@echo " vet : Run go vet"
|
||||
@echo " help : Show this help message"
|
125
internal/authz/default_test.go
Normal file
125
internal/authz/default_test.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package authz
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestNewDefaultProvider(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
provider := NewDefaultProvider(cfg)
|
||||
|
||||
if provider == nil {
|
||||
t.Fatal("Expected non-nil provider")
|
||||
}
|
||||
|
||||
// Ensure it implements the Provider interface
|
||||
var _ Provider = provider
|
||||
}
|
||||
|
||||
func TestDefaultProviderWellKnownHandler(t *testing.T) {
|
||||
// Create a config with a custom well-known response
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
Path: map[string]config.PathConfig{
|
||||
"/.well-known/oauth-authorization-server": {
|
||||
Response: &config.ResponseConfig{
|
||||
Issuer: "https://test-issuer.com",
|
||||
JwksURI: "https://test-issuer.com/jwks",
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
GrantTypesSupported: []string{"authorization_code"},
|
||||
CodeChallengeMethodsSupported: []string{"S256"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewDefaultProvider(cfg)
|
||||
handler := provider.WellKnownHandler()
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)
|
||||
req.Host = "test-host.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
// Create a response recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
handler(w, req)
|
||||
|
||||
// Check response status
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %v", w.Code)
|
||||
}
|
||||
|
||||
// Verify content type
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
||||
}
|
||||
|
||||
// Decode and check the response body
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check expected values
|
||||
if response["issuer"] != "https://test-issuer.com" {
|
||||
t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"])
|
||||
}
|
||||
if response["jwks_uri"] != "https://test-issuer.com/jwks" {
|
||||
t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"])
|
||||
}
|
||||
if response["authorization_endpoint"] != "https://test-host.com/authorize" {
|
||||
t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultProviderHandleOPTIONS(t *testing.T) {
|
||||
provider := NewDefaultProvider(&config.Config{})
|
||||
handler := provider.WellKnownHandler()
|
||||
|
||||
// Create OPTIONS request
|
||||
req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
handler(w, req)
|
||||
|
||||
// Check response
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code)
|
||||
}
|
||||
|
||||
// Check CORS headers
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" {
|
||||
t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultProviderInvalidMethod(t *testing.T) {
|
||||
provider := NewDefaultProvider(&config.Config{})
|
||||
handler := provider.WellKnownHandler()
|
||||
|
||||
// Create POST request (which should be rejected)
|
||||
req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the handler
|
||||
handler(w, req)
|
||||
|
||||
// Check response
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code)
|
||||
}
|
||||
}
|
196
internal/config/config_test.go
Normal file
196
internal/config/config_test.go
Normal file
|
@ -0,0 +1,196 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "test_config.yaml")
|
||||
|
||||
// Basic valid config
|
||||
validConfig := `
|
||||
listen_port: 8080
|
||||
base_url: "http://localhost:8000"
|
||||
transport_mode: "sse"
|
||||
paths:
|
||||
sse: "/sse"
|
||||
messages: "/messages"
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "http://localhost:5173"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
allowed_headers:
|
||||
- "Authorization"
|
||||
- "Content-Type"
|
||||
allow_credentials: true
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(validConfig), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test config file: %v", err)
|
||||
}
|
||||
|
||||
// Test loading the valid config
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load valid config: %v", err)
|
||||
}
|
||||
|
||||
// Verify expected values from the config
|
||||
if cfg.ListenPort != 8080 {
|
||||
t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort)
|
||||
}
|
||||
if cfg.BaseURL != "http://localhost:8000" {
|
||||
t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL)
|
||||
}
|
||||
if cfg.TransportMode != SSETransport {
|
||||
t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode)
|
||||
}
|
||||
if cfg.Paths.SSE != "/sse" {
|
||||
t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE)
|
||||
}
|
||||
if cfg.Paths.Messages != "/messages" {
|
||||
t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages)
|
||||
}
|
||||
|
||||
// Test default values
|
||||
if cfg.TimeoutSeconds != 15 {
|
||||
t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds)
|
||||
}
|
||||
if cfg.Port != 8000 {
|
||||
t.Errorf("Expected default Port=8000, got %d", cfg.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid SSE config",
|
||||
config: Config{
|
||||
TransportMode: SSETransport,
|
||||
Paths: PathsConfig{
|
||||
SSE: "/sse",
|
||||
Messages: "/messages",
|
||||
},
|
||||
BaseURL: "http://localhost:8000",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Valid stdio config",
|
||||
config: Config{
|
||||
TransportMode: StdioTransport,
|
||||
Stdio: StdioConfig{
|
||||
Enabled: true,
|
||||
UserCommand: "some-command",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid stdio config - not enabled",
|
||||
config: Config{
|
||||
TransportMode: StdioTransport,
|
||||
Stdio: StdioConfig{
|
||||
Enabled: false,
|
||||
UserCommand: "some-command",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid stdio config - no command",
|
||||
config: Config{
|
||||
TransportMode: StdioTransport,
|
||||
Stdio: StdioConfig{
|
||||
Enabled: true,
|
||||
UserCommand: "",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.config.Validate()
|
||||
if tc.expectError && err == nil {
|
||||
t.Errorf("Expected validation error but got none")
|
||||
}
|
||||
if !tc.expectError && err != nil {
|
||||
t.Errorf("Expected no validation error but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMCPPaths(t *testing.T) {
|
||||
cfg := Config{
|
||||
Paths: PathsConfig{
|
||||
SSE: "/custom-sse",
|
||||
Messages: "/custom-messages",
|
||||
},
|
||||
}
|
||||
|
||||
paths := cfg.GetMCPPaths()
|
||||
if len(paths) != 2 {
|
||||
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
|
||||
}
|
||||
if paths[0] != "/custom-sse" {
|
||||
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
|
||||
}
|
||||
if paths[1] != "/custom-messages" {
|
||||
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildExecCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
name: "Valid command",
|
||||
config: Config{
|
||||
Stdio: StdioConfig{
|
||||
UserCommand: "test-command",
|
||||
},
|
||||
Port: 8080,
|
||||
BaseURL: "http://example.com",
|
||||
Paths: PathsConfig{
|
||||
SSE: "/sse-path",
|
||||
Messages: "/msgs",
|
||||
},
|
||||
},
|
||||
expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`,
|
||||
},
|
||||
{
|
||||
name: "Empty command",
|
||||
config: Config{
|
||||
Stdio: StdioConfig{
|
||||
UserCommand: "",
|
||||
},
|
||||
},
|
||||
expectedResult: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := tc.config.BuildExecCommand()
|
||||
if result != tc.expectedResult {
|
||||
t.Errorf("Expected command=%s, got %s", tc.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
147
internal/proxy/modifier_test.go
Normal file
147
internal/proxy/modifier_test.go
Normal file
|
@ -0,0 +1,147 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestAuthorizationModifier(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
Path: map[string]config.PathConfig{
|
||||
"/authorize": {
|
||||
AddQueryParams: []config.ParamConfig{
|
||||
{Name: "client_id", Value: "test-client-id"},
|
||||
{Name: "scope", Value: "openid"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modifier := &AuthorizationModifier{Config: cfg}
|
||||
|
||||
// Create a test request
|
||||
req, err := http.NewRequest("GET", "/authorize?response_type=code", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
// Modify the request
|
||||
modifiedReq, err := modifier.ModifyRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ModifyRequest failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that the query parameters were added
|
||||
query := modifiedReq.URL.Query()
|
||||
if query.Get("client_id") != "test-client-id" {
|
||||
t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id"))
|
||||
}
|
||||
if query.Get("scope") != "openid" {
|
||||
t.Errorf("Expected scope=openid, got %s", query.Get("scope"))
|
||||
}
|
||||
if query.Get("response_type") != "code" {
|
||||
t.Errorf("Expected response_type=code, got %s", query.Get("response_type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenModifier(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
Path: map[string]config.PathConfig{
|
||||
"/token": {
|
||||
AddBodyParams: []config.ParamConfig{
|
||||
{Name: "audience", Value: "test-audience"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modifier := &TokenModifier{Config: cfg}
|
||||
|
||||
// Create a test request with form data
|
||||
form := url.Values{}
|
||||
|
||||
req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Modify the request
|
||||
modifiedReq, err := modifier.ModifyRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ModifyRequest failed: %v", err)
|
||||
}
|
||||
|
||||
body := make([]byte, 1024)
|
||||
n, err := modifiedReq.Body.Read(body)
|
||||
if err != nil && err.Error() != "EOF" {
|
||||
t.Fatalf("Failed to read body: %v", err)
|
||||
}
|
||||
bodyStr := string(body[:n])
|
||||
|
||||
// Parse the form data from the modified request
|
||||
if err := modifiedReq.ParseForm(); err != nil {
|
||||
t.Fatalf("Failed to parse form data: %v", err)
|
||||
}
|
||||
|
||||
// Check that the body parameters were added
|
||||
if !strings.Contains(bodyStr, "audience") {
|
||||
t.Errorf("Expected body to contain audience, got %s", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterModifier(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
Path: map[string]config.PathConfig{
|
||||
"/register": {
|
||||
AddBodyParams: []config.ParamConfig{
|
||||
{Name: "client_name", Value: "test-client"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modifier := &RegisterModifier{Config: cfg}
|
||||
|
||||
// Create a test request with JSON data
|
||||
jsonBody := `{"redirect_uris":["https://example.com/callback"]}`
|
||||
req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Modify the request
|
||||
modifiedReq, err := modifier.ModifyRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("ModifyRequest failed: %v", err)
|
||||
}
|
||||
|
||||
// Read the body and check that it still contains the original data
|
||||
// This test would need to be enhanced with a proper JSON parsing to verify
|
||||
// the added parameters
|
||||
body := make([]byte, 1024)
|
||||
n, err := modifiedReq.Body.Read(body)
|
||||
if err != nil && err.Error() != "EOF" {
|
||||
t.Fatalf("Failed to read body: %v", err)
|
||||
}
|
||||
bodyStr := string(body[:n])
|
||||
|
||||
// Simple check to see if the modified body contains the expected fields
|
||||
if !strings.Contains(bodyStr, "client_name") {
|
||||
t.Errorf("Expected body to contain client_name, got %s", bodyStr)
|
||||
}
|
||||
if !strings.Contains(bodyStr, "redirect_uris") {
|
||||
t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr)
|
||||
}
|
||||
}
|
143
internal/util/jwks_test.go
Normal file
143
internal/util/jwks_test.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
func TestValidateJWT(t *testing.T) {
|
||||
// Initialize the test JWKS data
|
||||
initTestJWKS(t)
|
||||
|
||||
// Test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid JWT token",
|
||||
authHeader: "Bearer " + createValidJWT(t),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "No auth header",
|
||||
authHeader: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid auth header format",
|
||||
authHeader: "InvalidFormat",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT token",
|
||||
authHeader: "Bearer invalid.jwt.token",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateJWT(tc.authHeader)
|
||||
if tc.expectError && err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
if !tc.expectError && err != nil {
|
||||
t.Errorf("Expected no error but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchJWKS(t *testing.T) {
|
||||
// Create a mock JWKS server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate a test RSA key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
|
||||
// Create JWKS response
|
||||
jwks := map[string]interface{}{
|
||||
"keys": []map[string]interface{}{
|
||||
{
|
||||
"kty": "RSA",
|
||||
"kid": "test-key-id",
|
||||
"n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
|
||||
"e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Test fetching JWKS
|
||||
err := FetchJWKS(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("FetchJWKS failed: %v", err)
|
||||
}
|
||||
|
||||
// Check that keys were stored
|
||||
if len(publicKeys) == 0 {
|
||||
t.Errorf("Expected publicKeys to be populated")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to initialize test JWKS data
|
||||
func initTestJWKS(t *testing.T) {
|
||||
// Create a test RSA key pair
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
|
||||
// Initialize the publicKeys map
|
||||
publicKeys = map[string]*rsa.PublicKey{
|
||||
"test-key-id": &privateKey.PublicKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a valid JWT token for testing
|
||||
func createValidJWT(t *testing.T) string {
|
||||
// Create a test RSA key pair
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
|
||||
// Ensure the test key is in the publicKeys map
|
||||
if publicKeys == nil {
|
||||
publicKeys = map[string]*rsa.PublicKey{}
|
||||
}
|
||||
publicKeys["test-key-id"] = &privateKey.PublicKey
|
||||
|
||||
// Create token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||
"sub": "1234567890",
|
||||
"name": "Test User",
|
||||
"iat": time.Now().Unix(),
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
// Sign the token
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
return tokenString
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue