diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 562e7aa..2583b75 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -92,11 +92,11 @@ func main() { os.Exit(1) } - // 5. (Optional) Build the policy engine - engine := &authz.DefaultPolicyEngine{} + // 5. (Optional) Build the access controler + accessController := &authz.ScopeValidator{} // 6. Build the main router - mux := proxy.NewRouter(cfg, provider, engine) + mux := proxy.NewRouter(cfg, provider, accessController) listen_address := fmt.Sprintf(":%d", cfg.ListenPort) diff --git a/internal/authz/access_control.go b/internal/authz/access_control.go new file mode 100644 index 0000000..2e321c3 --- /dev/null +++ b/internal/authz/access_control.go @@ -0,0 +1,19 @@ +package authz + +import "net/http" + +type Decision int + +const ( + DecisionAllow Decision = iota + DecisionDeny +) + +type AccessControlResult struct { + Decision Decision + Message string +} + +type AccessControl interface { + ValidateAccess(r *http.Request, claims *TokenClaims, requiredScopes any) AccessControlResult +} diff --git a/internal/authz/policy.go b/internal/authz/policy.go deleted file mode 100644 index 5995250..0000000 --- a/internal/authz/policy.go +++ /dev/null @@ -1,19 +0,0 @@ -package authz - -import "net/http" - -type Decision int - -const ( - DecisionAllow Decision = iota - DecisionDeny -) - -type PolicyResult struct { - Decision Decision - Message string -} - -type PolicyEngine interface { - Evaluate(r *http.Request, claims *TokenClaims, requiredScopes any) PolicyResult -} diff --git a/internal/authz/default_policy_engine.go b/internal/authz/scope_validator.go similarity index 83% rename from internal/authz/default_policy_engine.go rename to internal/authz/scope_validator.go index 6f002e6..248cf8a 100644 --- a/internal/authz/default_policy_engine.go +++ b/internal/authz/scope_validator.go @@ -12,14 +12,14 @@ type TokenClaims struct { Scopes []string } -type DefaultPolicyEngine struct{} +type ScopeValidator struct{} // Evaluate and checks the token claims against one or more required scopes. -func (d *DefaultPolicyEngine) Evaluate( +func (d *ScopeValidator) ValidateAccess( _ *http.Request, claims *TokenClaims, requiredScopes any, -) PolicyResult { +) AccessControlResult { logger.Info("Required scopes: %v", requiredScopes) @@ -32,7 +32,7 @@ func (d *DefaultPolicyEngine) Evaluate( } if strings.TrimSpace(scopeStr) == "" { - return PolicyResult{DecisionAllow, ""} + return AccessControlResult{DecisionAllow, ""} } scopes := strings.FieldsFunc(scopeStr, func(r rune) bool { @@ -48,7 +48,7 @@ func (d *DefaultPolicyEngine) Evaluate( logger.Info("Token scopes: %v", claims.Scopes) for _, tokenScope := range claims.Scopes { if _, ok := required[tokenScope]; ok { - return PolicyResult{DecisionAllow, ""} + return AccessControlResult{DecisionAllow, ""} } } @@ -56,7 +56,7 @@ func (d *DefaultPolicyEngine) Evaluate( for s := range required { list = append(list, s) } - return PolicyResult{ + return AccessControlResult{ DecisionDeny, fmt.Sprintf("missing required scope(s): %s", strings.Join(list, ", ")), } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 9ec8211..83aeb6e 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -18,7 +18,7 @@ import ( // NewRouter builds an http.ServeMux that routes // * /authorize, /token, /register, /.well-known to the provider or proxy // * MCP paths to the MCP server, etc. -func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.PolicyEngine) http.Handler { +func NewRouter(cfg *config.Config, provider authz.Provider, accessController authz.AccessControl) http.Handler { mux := http.NewServeMux() modifiers := map[string]RequestModifier{ @@ -56,20 +56,6 @@ func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.P defaultPaths = append(defaultPaths, "/.well-known/oauth-authorization-server") } - mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - allowed := getAllowedOrigin(origin, cfg) - if r.Method == http.MethodOptions { - addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers")) - w.WriteHeader(http.StatusNoContent) - return - } - - addCORSHeaders(w, cfg, allowed, "") - provider.ProtectedResourceMetadataHandler()(w, r) - }) - registeredPaths["/.well-known/oauth-protected-resource"] = true - defaultPaths = append(defaultPaths, "/authorize") defaultPaths = append(defaultPaths, "/token") defaultPaths = append(defaultPaths, "/register") @@ -78,6 +64,20 @@ func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.P } } + mux.HandleFunc("/.well-known/oauth-protected-resource", func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowed := getAllowedOrigin(origin, cfg) + if r.Method == http.MethodOptions { + addCORSHeaders(w, cfg, allowed, r.Header.Get("Access-Control-Request-Headers")) + w.WriteHeader(http.StatusNoContent) + return + } + + addCORSHeaders(w, cfg, allowed, "") + provider.ProtectedResourceMetadataHandler()(w, r) + }) + registeredPaths["/.well-known/oauth-protected-resource"] = true + // Remove duplicates from defaultPaths uniquePaths := make(map[string]bool) cleanPaths := []string{} @@ -91,7 +91,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.P for _, path := range defaultPaths { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController)) registeredPaths[path] = true } } @@ -99,14 +99,14 @@ func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.P // MCP paths mcpPaths := cfg.GetMCPPaths() for _, path := range mcpPaths { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController)) registeredPaths[path] = true } // Register paths from PathMapping that haven't been registered yet for path := range cfg.PathMapping { if !registeredPaths[path] { - mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, policyEngine)) + mux.HandleFunc(path, buildProxyHandler(cfg, modifiers, accessController)) registeredPaths[path] = true } } @@ -114,7 +114,7 @@ func NewRouter(cfg *config.Config, provider authz.Provider, policyEngine authz.P return mux } -func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, policyEngine authz.PolicyEngine) http.HandlerFunc { +func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, accessController authz.AccessControl) http.HandlerFunc { // Parse the base URLs up front authBase, err := url.Parse(cfg.AuthServerBaseURL) if err != nil { @@ -175,7 +175,7 @@ func buildProxyHandler(cfg *config.Config, modifiers map[string]RequestModifier, } isSSE = true } else { - if err := authorizeMCP(w, r, isLatestSpec, cfg, policyEngine); err != nil { + if err := authorizeMCP(w, r, isLatestSpec, cfg, accessController); err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } @@ -300,7 +300,7 @@ func authorizeSSE(w http.ResponseWriter, r *http.Request, isLatestSpec bool, res } // Handles both v1 (just signature) and v2 (aud + scope) flows -func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, policyEngine authz.PolicyEngine) error { +func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg *config.Config, accessController authz.AccessControl) error { authzHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authzHeader, "Bearer ") { if isLatestSpec { @@ -340,7 +340,7 @@ func authorizeMCP(w http.ResponseWriter, r *http.Request, isLatestSpec bool, cfg if len(requiredScopes) == 0 { return nil } - pr := policyEngine.Evaluate(r, claims, requiredScopes) + pr := accessController.ValidateAccess(r, claims, requiredScopes) if pr.Decision == authz.DecisionDeny { http.Error(w, "Forbidden: "+pr.Message, http.StatusForbidden) return fmt.Errorf("forbidden — %s", pr.Message) diff --git a/internal/util/version.go b/internal/util/version.go index f330016..230ef1d 100644 --- a/internal/util/version.go +++ b/internal/util/version.go @@ -19,7 +19,8 @@ func ParseVersionDate(version string) (time.Time, error) { // This function returns the version string, using the cutover date if empty func GetVersionWithDefault(version string) string { if version == "" { - return constants.SpecCutoverDate.Format("2006-01-02") + defaultTime, _ := time.Parse(constants.TimeLayout, "2025-05-15") + return defaultTime.Format(constants.TimeLayout) } return version }