open-mcp-auth-proxy/internal/proxy/modifier.go
Chiran Fernando 32c9378aad
Add transport mode support for stdio, SSE stability fixes (#13)
Add transport mode support for stdio, SSE stability fixes
2025-04-08 12:46:00 +05:30

204 lines
5.1 KiB
Go

package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/wso2/open-mcp-auth-proxy/internal/config"
"github.com/wso2/open-mcp-auth-proxy/internal/logging"
)
// RequestModifier modifies requests before they are proxied.
//
// ModifyRequest receives the incoming request and returns the request that
// should be used from then on, or an error. The implementations in this file
// mutate req in place and return that same pointer on success; on failure
// they return a nil request together with the error.
type RequestModifier interface {
ModifyRequest(req *http.Request) (*http.Request, error)
}
// AuthorizationModifier adds parameters to authorization requests.
//
// Its ModifyRequest method sets the query parameters listed under the
// "/authorize" key of Config.Default.Path on the request URL.
type AuthorizationModifier struct {
// Config is the proxy configuration the parameters are read from.
Config *config.Config
}
// TokenModifier adds parameters to token requests.
//
// Its ModifyRequest method adds the body parameters listed under the
// "/token" key of Config.Default.Path to POST requests with a form-encoded
// or JSON body.
type TokenModifier struct {
// Config is the proxy configuration the parameters are read from.
Config *config.Config
}
// RegisterModifier adds parameters to register requests.
//
// Its ModifyRequest method adds the body parameters listed under the
// "/register" key of Config.Default.Path to POST requests with a
// form-encoded or JSON body.
type RegisterModifier struct {
// Config is the proxy configuration the parameters are read from.
Config *config.Config
}
// ModifyRequest sets every query parameter configured for "/authorize" on the
// request URL, replacing any value already present under the same name.
//
// The request is updated in place and returned. When nothing is configured
// for "/authorize" the request is returned untouched. The returned error is
// always nil.
func (m *AuthorizationModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
	// Looking up a key in a nil map simply reports "not found", so a missing
	// Path table and a missing "/authorize" entry are handled by one check.
	authorizeCfg, found := m.Config.Default.Path["/authorize"]
	if !found {
		return req, nil
	}

	extraParams := authorizeCfg.AddQueryParams
	if len(extraParams) == 0 {
		return req, nil
	}

	// Start from the query the client sent and overlay the configured values.
	values := req.URL.Query()
	for _, extra := range extraParams {
		values.Set(extra.Name, extra.Value)
	}
	req.URL.RawQuery = values.Encode()

	return req, nil
}
// ModifyRequest adds the body parameters configured for "/token" to POST
// token requests.
//
// Bodies of type application/x-www-form-urlencoded and application/json are
// supported; the media type is matched case-insensitively. Requests that are
// not POST, that have no parameters configured for "/token", or that carry
// any other content type are returned untouched. A configured parameter
// overwrites a value the client sent under the same name.
//
// The request is modified in place: Body, ContentLength and the
// Content-Length header are replaced so that they describe the rewritten
// payload. On failure the returned request is nil and the error says which
// step failed (the underlying error is wrapped). A JSON body that is not an
// object, including a literal null, is rejected with an error.
func (m *TokenModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
	// Only modify POST requests.
	if req.Method != http.MethodPost {
		return req, nil
	}

	// Indexing a nil map reports "not found", so no separate nil check is
	// required for the Path table.
	pathConfig, exists := m.Config.Default.Path["/token"]
	if !exists || len(pathConfig.AddBodyParams) == 0 {
		return req, nil
	}

	// Media types are case-insensitive, so normalise before matching.
	contentType := strings.ToLower(req.Header.Get("Content-Type"))

	var body []byte
	switch {
	case strings.Contains(contentType, "application/x-www-form-urlencoded"):
		// ParseForm consumes the body and populates req.PostForm.
		if err := req.ParseForm(); err != nil {
			return nil, fmt.Errorf("parsing token request form: %w", err)
		}

		// PostForm is updated directly (it is not copied), so code that
		// inspects req.PostForm afterwards also sees the added parameters.
		formData := req.PostForm
		for _, param := range pathConfig.AddBodyParams {
			formData.Set(param.Name, param.Value)
		}
		body = []byte(formData.Encode())

	case strings.Contains(contentType, "application/json"):
		bodyBytes, err := io.ReadAll(req.Body)
		if err != nil {
			return nil, fmt.Errorf("reading token request body: %w", err)
		}

		var jsonData map[string]interface{}
		if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
			return nil, fmt.Errorf("parsing token request JSON body: %w", err)
		}
		// A literal JSON null decodes without error but leaves the map nil;
		// assigning into it would panic. Reject it like any other non-object.
		if jsonData == nil {
			return nil, fmt.Errorf("parsing token request JSON body: expected an object, got null")
		}

		for _, param := range pathConfig.AddBodyParams {
			jsonData[param.Name] = param.Value
		}

		body, err = json.Marshal(jsonData)
		if err != nil {
			return nil, fmt.Errorf("encoding modified token request body: %w", err)
		}

	default:
		// Unsupported content type: leave the request as it is.
		return req, nil
	}

	// The original body has been consumed; install the rewritten payload and
	// keep the length metadata in sync with it.
	req.Body = io.NopCloser(bytes.NewReader(body))
	req.ContentLength = int64(len(body))
	req.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))

	return req, nil
}
// ModifyRequest adds the body parameters configured for "/register" to POST
// register requests.
//
// Bodies of type application/x-www-form-urlencoded and application/json are
// supported; the media type is matched case-insensitively. Requests that are
// not POST, that have no parameters configured for "/register", or that carry
// any other content type are returned untouched. A configured parameter
// overwrites a value the client sent under the same name.
//
// The request is modified in place: Body, ContentLength and the
// Content-Length header are replaced so that they describe the rewritten
// payload. On failure the error is logged and returned together with a nil
// request. A JSON body that is not an object, including a literal null, is
// rejected with an error.
func (m *RegisterModifier) ModifyRequest(req *http.Request) (*http.Request, error) {
	// Only modify POST requests.
	if req.Method != http.MethodPost {
		return req, nil
	}

	// Indexing a nil map reports "not found", so no separate nil check is
	// required for the Path table.
	pathConfig, exists := m.Config.Default.Path["/register"]
	if !exists || len(pathConfig.AddBodyParams) == 0 {
		return req, nil
	}

	// Media types are case-insensitive, so normalise before matching.
	contentType := strings.ToLower(req.Header.Get("Content-Type"))

	var body []byte
	switch {
	case strings.Contains(contentType, "application/x-www-form-urlencoded"):
		// ParseForm consumes the body and populates req.PostForm.
		if err := req.ParseForm(); err != nil {
			logger.Error("Failed to parse form data: %v", err)
			return nil, err
		}

		// PostForm is updated directly (it is not copied), so code that
		// inspects req.PostForm afterwards also sees the added parameters.
		formData := req.PostForm
		for _, param := range pathConfig.AddBodyParams {
			formData.Set(param.Name, param.Value)
		}
		body = []byte(formData.Encode())

	case strings.Contains(contentType, "application/json"):
		bodyBytes, err := io.ReadAll(req.Body)
		if err != nil {
			logger.Error("Failed to read request body: %v", err)
			return nil, err
		}

		var jsonData map[string]interface{}
		if err := json.Unmarshal(bodyBytes, &jsonData); err != nil {
			logger.Error("Failed to parse JSON body: %v", err)
			return nil, err
		}
		// A literal JSON null decodes without error but leaves the map nil;
		// assigning into it would panic. Reject it like any other non-object.
		if jsonData == nil {
			err := fmt.Errorf("register request JSON body must be an object, got null")
			logger.Error("Failed to parse JSON body: %v", err)
			return nil, err
		}

		for _, param := range pathConfig.AddBodyParams {
			jsonData[param.Name] = param.Value
		}

		body, err = json.Marshal(jsonData)
		if err != nil {
			logger.Error("Failed to marshal modified JSON: %v", err)
			return nil, err
		}

	default:
		// Unsupported content type: leave the request as it is.
		return req, nil
	}

	// The original body has been consumed; install the rewritten payload and
	// keep the length metadata in sync with it.
	req.Body = io.NopCloser(bytes.NewReader(body))
	req.ContentLength = int64(len(body))
	req.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))

	return req, nil
}