mirror of
https://github.com/wso2/open-mcp-auth-proxy.git
synced 2025-06-27 17:13:31 +00:00
Add unit tests
This commit is contained in:
parent
32c9378aad
commit
b2b2124b76
8 changed files with 756 additions and 0 deletions
62
.github/workflows/go.yml
vendored
Normal file
62
.github/workflows/go.yml
vendored
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
name: Go CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go-version: ['1.20', '1.21']
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go-version }}
|
||||||
|
|
||||||
|
- name: Get dependencies
|
||||||
|
run: go get -v -t -d ./...
|
||||||
|
|
||||||
|
- name: Verify dependencies
|
||||||
|
run: go mod verify
|
||||||
|
|
||||||
|
- name: Run go vet
|
||||||
|
run: go vet ./...
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||||
|
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v3
|
||||||
|
with:
|
||||||
|
files: ./coverage.txt
|
||||||
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
build:
|
||||||
|
name: Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go-version: ['1.20', '1.21']
|
||||||
|
os: [ubuntu-latest, macos-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go-version }}
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: go build -v ./cmd/proxy
|
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -31,3 +31,7 @@ go.sum
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
openmcpauthproxy
|
openmcpauthproxy
|
||||||
|
|
||||||
|
# test out files
|
||||||
|
coverage.out
|
||||||
|
coverage.html
|
||||||
|
|
4
.vscode/settings.json
vendored
Normal file
4
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
{
|
||||||
|
"github.copilot.chat.codesearch.enabled": true,
|
||||||
|
"github.copilot.chat.newWorkspaceCreation.enabled": true
|
||||||
|
}
|
74
Makefile
Normal file
74
Makefile
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# Makefile for open-mcp-auth-proxy
|
||||||
|
|
||||||
|
# Variables
|
||||||
|
BINARY_NAME := openmcpauthproxy
|
||||||
|
GO := go
|
||||||
|
GOFMT := gofmt
|
||||||
|
GOVET := go vet
|
||||||
|
GOTEST := go test
|
||||||
|
GOLINT := golangci-lint
|
||||||
|
GOCOV := go tool cover
|
||||||
|
BUILD_DIR := build
|
||||||
|
|
||||||
|
# Source files
|
||||||
|
SRC := $(shell find . -name "*.go" -not -path "./vendor/*")
|
||||||
|
PKGS := $(shell go list ./... | grep -v /vendor/)
|
||||||
|
|
||||||
|
# Set build options
|
||||||
|
BUILD_OPTS := -v
|
||||||
|
|
||||||
|
# Set test options
|
||||||
|
TEST_OPTS := -v -race
|
||||||
|
|
||||||
|
.PHONY: all build clean test fmt lint vet coverage help
|
||||||
|
|
||||||
|
# Default target
|
||||||
|
all: lint test build
|
||||||
|
|
||||||
|
# Build the application
|
||||||
|
build:
|
||||||
|
@echo "Building $(BINARY_NAME)..."
|
||||||
|
@mkdir -p $(BUILD_DIR)
|
||||||
|
$(GO) build $(BUILD_OPTS) -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/proxy
|
||||||
|
|
||||||
|
# Clean build artifacts
|
||||||
|
clean:
|
||||||
|
@echo "Cleaning build artifacts..."
|
||||||
|
@rm -rf $(BUILD_DIR)
|
||||||
|
@rm -f coverage.out
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
test:
|
||||||
|
@echo "Running tests..."
|
||||||
|
$(GOTEST) $(TEST_OPTS) ./...
|
||||||
|
|
||||||
|
# Run tests with coverage report
|
||||||
|
coverage:
|
||||||
|
@echo "Running tests with coverage..."
|
||||||
|
@$(GOTEST) -coverprofile=coverage.out ./...
|
||||||
|
@$(GOCOV) -func=coverage.out
|
||||||
|
@$(GOCOV) -html=coverage.out -o coverage.html
|
||||||
|
@echo "Coverage report generated in coverage.html"
|
||||||
|
|
||||||
|
# Run gofmt
|
||||||
|
fmt:
|
||||||
|
@echo "Running gofmt..."
|
||||||
|
@$(GOFMT) -w -s $(SRC)
|
||||||
|
|
||||||
|
# Run go vet
|
||||||
|
vet:
|
||||||
|
@echo "Running go vet..."
|
||||||
|
@$(GOVET) ./...
|
||||||
|
|
||||||
|
# Show help
|
||||||
|
help:
|
||||||
|
@echo "Available targets:"
|
||||||
|
@echo " all : Run lint, test, and build"
|
||||||
|
@echo " build : Build the application"
|
||||||
|
@echo " clean : Clean build artifacts"
|
||||||
|
@echo " test : Run tests"
|
||||||
|
@echo " coverage : Run tests with coverage report"
|
||||||
|
@echo " fmt : Run gofmt"
|
||||||
|
@echo " lint : Run golangci-lint"
|
||||||
|
@echo " vet : Run go vet"
|
||||||
|
@echo " help : Show this help message"
|
125
internal/authz/default_test.go
Normal file
125
internal/authz/default_test.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package authz
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDefaultProvider(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
provider := NewDefaultProvider(cfg)
|
||||||
|
|
||||||
|
if provider == nil {
|
||||||
|
t.Fatal("Expected non-nil provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it implements the Provider interface
|
||||||
|
var _ Provider = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultProviderWellKnownHandler(t *testing.T) {
|
||||||
|
// Create a config with a custom well-known response
|
||||||
|
cfg := &config.Config{
|
||||||
|
Default: config.DefaultConfig{
|
||||||
|
Path: map[string]config.PathConfig{
|
||||||
|
"/.well-known/oauth-authorization-server": {
|
||||||
|
Response: &config.ResponseConfig{
|
||||||
|
Issuer: "https://test-issuer.com",
|
||||||
|
JwksURI: "https://test-issuer.com/jwks",
|
||||||
|
ResponseTypesSupported: []string{"code"},
|
||||||
|
GrantTypesSupported: []string{"authorization_code"},
|
||||||
|
CodeChallengeMethodsSupported: []string{"S256"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := NewDefaultProvider(cfg)
|
||||||
|
handler := provider.WellKnownHandler()
|
||||||
|
|
||||||
|
// Create a test request
|
||||||
|
req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)
|
||||||
|
req.Host = "test-host.com"
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
|
||||||
|
// Create a response recorder
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
// Check response status
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status OK, got %v", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify content type
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if contentType != "application/json" {
|
||||||
|
t.Errorf("Expected Content-Type: application/json, got %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode and check the response body
|
||||||
|
var response map[string]interface{}
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||||
|
t.Fatalf("Failed to decode response JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expected values
|
||||||
|
if response["issuer"] != "https://test-issuer.com" {
|
||||||
|
t.Errorf("Expected issuer=https://test-issuer.com, got %v", response["issuer"])
|
||||||
|
}
|
||||||
|
if response["jwks_uri"] != "https://test-issuer.com/jwks" {
|
||||||
|
t.Errorf("Expected jwks_uri=https://test-issuer.com/jwks, got %v", response["jwks_uri"])
|
||||||
|
}
|
||||||
|
if response["authorization_endpoint"] != "https://test-host.com/authorize" {
|
||||||
|
t.Errorf("Expected authorization_endpoint=https://test-host.com/authorize, got %v", response["authorization_endpoint"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultProviderHandleOPTIONS(t *testing.T) {
|
||||||
|
provider := NewDefaultProvider(&config.Config{})
|
||||||
|
handler := provider.WellKnownHandler()
|
||||||
|
|
||||||
|
// Create OPTIONS request
|
||||||
|
req := httptest.NewRequest("OPTIONS", "/.well-known/oauth-authorization-server", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
if w.Code != http.StatusNoContent {
|
||||||
|
t.Errorf("Expected status NoContent for OPTIONS request, got %v", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check CORS headers
|
||||||
|
if w.Header().Get("Access-Control-Allow-Origin") != "*" {
|
||||||
|
t.Errorf("Expected Access-Control-Allow-Origin: *, got %s", w.Header().Get("Access-Control-Allow-Origin"))
|
||||||
|
}
|
||||||
|
if w.Header().Get("Access-Control-Allow-Methods") != "GET, OPTIONS" {
|
||||||
|
t.Errorf("Expected Access-Control-Allow-Methods: GET, OPTIONS, got %s", w.Header().Get("Access-Control-Allow-Methods"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultProviderInvalidMethod(t *testing.T) {
|
||||||
|
provider := NewDefaultProvider(&config.Config{})
|
||||||
|
handler := provider.WellKnownHandler()
|
||||||
|
|
||||||
|
// Create POST request (which should be rejected)
|
||||||
|
req := httptest.NewRequest("POST", "/.well-known/oauth-authorization-server", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the handler
|
||||||
|
handler(w, req)
|
||||||
|
|
||||||
|
// Check response
|
||||||
|
if w.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("Expected status MethodNotAllowed for POST request, got %v", w.Code)
|
||||||
|
}
|
||||||
|
}
|
196
internal/config/config_test.go
Normal file
196
internal/config/config_test.go
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfig(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tempDir, "test_config.yaml")
|
||||||
|
|
||||||
|
// Basic valid config
|
||||||
|
validConfig := `
|
||||||
|
listen_port: 8080
|
||||||
|
base_url: "http://localhost:8000"
|
||||||
|
transport_mode: "sse"
|
||||||
|
paths:
|
||||||
|
sse: "/sse"
|
||||||
|
messages: "/messages"
|
||||||
|
cors:
|
||||||
|
allowed_origins:
|
||||||
|
- "http://localhost:5173"
|
||||||
|
allowed_methods:
|
||||||
|
- "GET"
|
||||||
|
- "POST"
|
||||||
|
allowed_headers:
|
||||||
|
- "Authorization"
|
||||||
|
- "Content-Type"
|
||||||
|
allow_credentials: true
|
||||||
|
`
|
||||||
|
err := os.WriteFile(configPath, []byte(validConfig), 0644)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test loading the valid config
|
||||||
|
cfg, err := LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load valid config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected values from the config
|
||||||
|
if cfg.ListenPort != 8080 {
|
||||||
|
t.Errorf("Expected ListenPort=8080, got %d", cfg.ListenPort)
|
||||||
|
}
|
||||||
|
if cfg.BaseURL != "http://localhost:8000" {
|
||||||
|
t.Errorf("Expected BaseURL=http://localhost:8000, got %s", cfg.BaseURL)
|
||||||
|
}
|
||||||
|
if cfg.TransportMode != SSETransport {
|
||||||
|
t.Errorf("Expected TransportMode=sse, got %s", cfg.TransportMode)
|
||||||
|
}
|
||||||
|
if cfg.Paths.SSE != "/sse" {
|
||||||
|
t.Errorf("Expected Paths.SSE=/sse, got %s", cfg.Paths.SSE)
|
||||||
|
}
|
||||||
|
if cfg.Paths.Messages != "/messages" {
|
||||||
|
t.Errorf("Expected Paths.Messages=/messages, got %s", cfg.Paths.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test default values
|
||||||
|
if cfg.TimeoutSeconds != 15 {
|
||||||
|
t.Errorf("Expected default TimeoutSeconds=15, got %d", cfg.TimeoutSeconds)
|
||||||
|
}
|
||||||
|
if cfg.Port != 8000 {
|
||||||
|
t.Errorf("Expected default Port=8000, got %d", cfg.Port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid SSE config",
|
||||||
|
config: Config{
|
||||||
|
TransportMode: SSETransport,
|
||||||
|
Paths: PathsConfig{
|
||||||
|
SSE: "/sse",
|
||||||
|
Messages: "/messages",
|
||||||
|
},
|
||||||
|
BaseURL: "http://localhost:8000",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid stdio config",
|
||||||
|
config: Config{
|
||||||
|
TransportMode: StdioTransport,
|
||||||
|
Stdio: StdioConfig{
|
||||||
|
Enabled: true,
|
||||||
|
UserCommand: "some-command",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid stdio config - not enabled",
|
||||||
|
config: Config{
|
||||||
|
TransportMode: StdioTransport,
|
||||||
|
Stdio: StdioConfig{
|
||||||
|
Enabled: false,
|
||||||
|
UserCommand: "some-command",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid stdio config - no command",
|
||||||
|
config: Config{
|
||||||
|
TransportMode: StdioTransport,
|
||||||
|
Stdio: StdioConfig{
|
||||||
|
Enabled: true,
|
||||||
|
UserCommand: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := tc.config.Validate()
|
||||||
|
if tc.expectError && err == nil {
|
||||||
|
t.Errorf("Expected validation error but got none")
|
||||||
|
}
|
||||||
|
if !tc.expectError && err != nil {
|
||||||
|
t.Errorf("Expected no validation error but got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMCPPaths(t *testing.T) {
|
||||||
|
cfg := Config{
|
||||||
|
Paths: PathsConfig{
|
||||||
|
SSE: "/custom-sse",
|
||||||
|
Messages: "/custom-messages",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
paths := cfg.GetMCPPaths()
|
||||||
|
if len(paths) != 2 {
|
||||||
|
t.Errorf("Expected 2 MCP paths, got %d", len(paths))
|
||||||
|
}
|
||||||
|
if paths[0] != "/custom-sse" {
|
||||||
|
t.Errorf("Expected first path=/custom-sse, got %s", paths[0])
|
||||||
|
}
|
||||||
|
if paths[1] != "/custom-messages" {
|
||||||
|
t.Errorf("Expected second path=/custom-messages, got %s", paths[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildExecCommand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
expectedResult string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid command",
|
||||||
|
config: Config{
|
||||||
|
Stdio: StdioConfig{
|
||||||
|
UserCommand: "test-command",
|
||||||
|
},
|
||||||
|
Port: 8080,
|
||||||
|
BaseURL: "http://example.com",
|
||||||
|
Paths: PathsConfig{
|
||||||
|
SSE: "/sse-path",
|
||||||
|
Messages: "/msgs",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedResult: `npx -y supergateway --stdio "test-command" --port 8080 --baseUrl http://example.com --ssePath /sse-path --messagePath /msgs`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty command",
|
||||||
|
config: Config{
|
||||||
|
Stdio: StdioConfig{
|
||||||
|
UserCommand: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedResult: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := tc.config.BuildExecCommand()
|
||||||
|
if result != tc.expectedResult {
|
||||||
|
t.Errorf("Expected command=%s, got %s", tc.expectedResult, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
148
internal/proxy/modifier_test.go
Normal file
148
internal/proxy/modifier_test.go
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/wso2/open-mcp-auth-proxy/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthorizationModifier(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Default: config.DefaultConfig{
|
||||||
|
Path: map[string]config.PathConfig{
|
||||||
|
"/authorize": {
|
||||||
|
AddQueryParams: []config.ParamConfig{
|
||||||
|
{Name: "client_id", Value: "test-client-id"},
|
||||||
|
{Name: "scope", Value: "openid"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := &AuthorizationModifier{Config: cfg}
|
||||||
|
|
||||||
|
// Create a test request
|
||||||
|
req, err := http.NewRequest("GET", "/authorize?response_type=code", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify the request
|
||||||
|
modifiedReq, err := modifier.ModifyRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ModifyRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the query parameters were added
|
||||||
|
query := modifiedReq.URL.Query()
|
||||||
|
if query.Get("client_id") != "test-client-id" {
|
||||||
|
t.Errorf("Expected client_id=test-client-id, got %s", query.Get("client_id"))
|
||||||
|
}
|
||||||
|
if query.Get("scope") != "openid" {
|
||||||
|
t.Errorf("Expected scope=openid, got %s", query.Get("scope"))
|
||||||
|
}
|
||||||
|
if query.Get("response_type") != "code" {
|
||||||
|
t.Errorf("Expected response_type=code, got %s", query.Get("response_type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenModifier(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Default: config.DefaultConfig{
|
||||||
|
Path: map[string]config.PathConfig{
|
||||||
|
"/token": {
|
||||||
|
AddBodyParams: []config.ParamConfig{
|
||||||
|
{Name: "audience", Value: "test-audience"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := &TokenModifier{Config: cfg}
|
||||||
|
|
||||||
|
// Create a test request with form data
|
||||||
|
form := url.Values{}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", "/token", strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
// Modify the request
|
||||||
|
modifiedReq, err := modifier.ModifyRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ModifyRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := make([]byte, 1024)
|
||||||
|
n, err := modifiedReq.Body.Read(body)
|
||||||
|
if err != nil && err.Error() != "EOF" {
|
||||||
|
t.Fatalf("Failed to read body: %v", err)
|
||||||
|
}
|
||||||
|
bodyStr := string(body[:n])
|
||||||
|
print(bodyStr)
|
||||||
|
|
||||||
|
// Parse the form data from the modified request
|
||||||
|
if err := modifiedReq.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("Failed to parse form data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the body parameters were added
|
||||||
|
if !strings.Contains(bodyStr, "audience") {
|
||||||
|
t.Errorf("Expected body to contain audience, got %s", bodyStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterModifier(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Default: config.DefaultConfig{
|
||||||
|
Path: map[string]config.PathConfig{
|
||||||
|
"/register": {
|
||||||
|
AddBodyParams: []config.ParamConfig{
|
||||||
|
{Name: "client_name", Value: "test-client"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
modifier := &RegisterModifier{Config: cfg}
|
||||||
|
|
||||||
|
// Create a test request with JSON data
|
||||||
|
jsonBody := `{"redirect_uris":["https://example.com/callback"]}`
|
||||||
|
req, err := http.NewRequest("POST", "/register", strings.NewReader(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Modify the request
|
||||||
|
modifiedReq, err := modifier.ModifyRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ModifyRequest failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the body and check that it still contains the original data
|
||||||
|
// This test would need to be enhanced with a proper JSON parsing to verify
|
||||||
|
// the added parameters
|
||||||
|
body := make([]byte, 1024)
|
||||||
|
n, err := modifiedReq.Body.Read(body)
|
||||||
|
if err != nil && err.Error() != "EOF" {
|
||||||
|
t.Fatalf("Failed to read body: %v", err)
|
||||||
|
}
|
||||||
|
bodyStr := string(body[:n])
|
||||||
|
|
||||||
|
// Simple check to see if the modified body contains the expected fields
|
||||||
|
if !strings.Contains(bodyStr, "client_name") {
|
||||||
|
t.Errorf("Expected body to contain client_name, got %s", bodyStr)
|
||||||
|
}
|
||||||
|
if !strings.Contains(bodyStr, "redirect_uris") {
|
||||||
|
t.Errorf("Expected body to contain redirect_uris, got %s", bodyStr)
|
||||||
|
}
|
||||||
|
}
|
143
internal/util/jwks_test.go
Normal file
143
internal/util/jwks_test.go
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateJWT(t *testing.T) {
|
||||||
|
// Initialize the test JWKS data
|
||||||
|
initTestJWKS(t)
|
||||||
|
|
||||||
|
// Test cases
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
authHeader string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid JWT token",
|
||||||
|
authHeader: "Bearer " + createValidJWT(t),
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No auth header",
|
||||||
|
authHeader: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid auth header format",
|
||||||
|
authHeader: "InvalidFormat",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid JWT token",
|
||||||
|
authHeader: "Bearer invalid.jwt.token",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := ValidateJWT(tc.authHeader)
|
||||||
|
if tc.expectError && err == nil {
|
||||||
|
t.Errorf("Expected error but got none")
|
||||||
|
}
|
||||||
|
if !tc.expectError && err != nil {
|
||||||
|
t.Errorf("Expected no error but got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFetchJWKS(t *testing.T) {
|
||||||
|
// Create a mock JWKS server
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Generate a test RSA key
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create JWKS response
|
||||||
|
jwks := map[string]interface{}{
|
||||||
|
"keys": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"kty": "RSA",
|
||||||
|
"kid": "test-key-id",
|
||||||
|
"n": base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()),
|
||||||
|
"e": base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}), // Default exponent 65537
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(jwks)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Test fetching JWKS
|
||||||
|
err := FetchJWKS(server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FetchJWKS failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that keys were stored
|
||||||
|
if len(publicKeys) == 0 {
|
||||||
|
t.Errorf("Expected publicKeys to be populated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to initialize test JWKS data
|
||||||
|
func initTestJWKS(t *testing.T) {
|
||||||
|
// Create a test RSA key pair
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the publicKeys map
|
||||||
|
publicKeys = map[string]*rsa.PublicKey{
|
||||||
|
"test-key-id": &privateKey.PublicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to create a valid JWT token for testing
|
||||||
|
func createValidJWT(t *testing.T) string {
|
||||||
|
// Create a test RSA key pair
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the test key is in the publicKeys map
|
||||||
|
if publicKeys == nil {
|
||||||
|
publicKeys = map[string]*rsa.PublicKey{}
|
||||||
|
}
|
||||||
|
publicKeys["test-key-id"] = &privateKey.PublicKey
|
||||||
|
|
||||||
|
// Create token
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
|
||||||
|
"sub": "1234567890",
|
||||||
|
"name": "Test User",
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
})
|
||||||
|
token.Header["kid"] = "test-key-id"
|
||||||
|
|
||||||
|
// Sign the token
|
||||||
|
tokenString, err := token.SignedString(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to sign token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue