diff --git a/docs/DESIGN-57-yaml-persona.md b/docs/DESIGN-57-yaml-persona.md index 719a473..f8e981f 100644 --- a/docs/DESIGN-57-yaml-persona.md +++ b/docs/DESIGN-57-yaml-persona.md @@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi - Backwards compatibility: existing JSON personas must continue to work - Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486) - Consistency: use `.yaml` extension (not `.yml`) -- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); has built-in depth protection via `MaxYAMLDepth`/`MaxYAMLNodes` constants +- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation ## Proposed Approach @@ -63,7 +63,7 @@ func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error { } ``` -The `github.com/goccy/go-yaml` library provides built-in depth protection via `MaxYAMLDepth` and `MaxYAMLNodes` decoder options. We use these instead of a manual depth-checking walk. +We implement a custom AST-based depth/node-count walk (`checkYAMLDepth`) rather than relying on library decoder options. This gives us precise control over how depth is counted across aliases and anchors, with a depth-aware validated map to prevent alias depth bypass. 
## State/Data Model diff --git a/review/persona.go b/review/persona.go index 57dbd0c..29f2336 100644 --- a/review/persona.go +++ b/review/persona.go @@ -184,7 +184,7 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { } nodeCount := 0 - if err := checkYAMLDepth(file.Docs[0].Body, 0, maxDepth, MaxYAMLNodes, make(map[ast.Node]struct{}), &nodeCount); err != nil { + if err := checkYAMLDepth(file.Docs[0].Body, 0, maxDepth, MaxYAMLNodes, make(map[ast.Node]int), make(map[ast.Node]bool), &nodeCount); err != nil { return err } @@ -195,9 +195,17 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { } // checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth -// limit or the total node count limit. It also detects alias cycles to prevent -// infinite recursion from crafted YAML with self-referential aliases. -func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, seen map[ast.Node]struct{}, nodeCount *int) error { +// limit or the total node count limit. It uses two tracking maps: +// - validated: maps each node to the greatest depth at which it was previously +// checked. If a node is revisited at a deeper depth (e.g., via an alias), +// we re-check it to ensure the combined effective depth doesn't exceed limits. +// - visiting: per-path recursion stack for true cycle detection. A node on the +// current path is a cycle (alias loop); we return nil to avoid infinite recursion. +// +// This design prevents the alias depth bypass where an anchored subtree validated +// at a shallow depth could be referenced via alias at a greater depth, effectively +// exceeding MaxYAMLDepth. 
+func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[ast.Node]int, visiting map[ast.Node]bool, nodeCount *int) error { if node == nil { return nil } @@ -212,48 +220,60 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, seen map[ast.N return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes) } - // Cycle detection: uses pointer identity (ast.Node is an interface, but all - // concrete node types are pointers) to detect revisits. This intentionally - // compares pointer identity, not structural equality, since we want to track - // specific node instances in the parsed AST graph. - if _, ok := seen[node]; ok { - return nil // Already validated this subtree, skip to avoid infinite recursion. + // Cycle detection: if we're currently visiting this node on the current + // recursion path, it's a cycle (e.g., alias pointing to an ancestor). + // Return nil to break the cycle without error — cycles are a structural + // property, not a depth violation. + if visiting[node] { + return nil } - seen[node] = struct{}{} + + // Depth-aware short-circuit: only skip re-checking a node if we previously + // validated it at the same or deeper effective depth. If this visit is at a + // greater depth than before (e.g., alias referenced deeper in the tree), + // we must re-traverse to catch depth limit violations. + if prevDepth, ok := validated[node]; ok && depth <= prevDepth { + return nil + } + validated[node] = depth + + // Mark as visiting (on the current recursion path) for cycle detection. + visiting[node] = true + defer func() { visiting[node] = false }() // Walk children based on node type. 
switch n := node.(type) { case *ast.MappingNode: for _, value := range n.Values { - if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } } case *ast.MappingValueNode: - if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } - if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.SequenceNode: for _, value := range n.Values { - if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } } case *ast.AliasNode: // Follow alias to its target, incrementing depth since aliases expand // the effective structure. 
- if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.AnchorNode: - if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.TagNode: - if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { + if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } // Scalar types (StringNode, IntegerNode, FloatNode, BoolNode, NullNode, diff --git a/review/persona_test.go b/review/persona_test.go index 69be577..344596d 100644 --- a/review/persona_test.go +++ b/review/persona_test.go @@ -484,7 +484,6 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) { } } - func TestYAMLEmptyFileRejection(t *testing.T) { dir := t.TempDir() @@ -536,7 +535,7 @@ func TestYAMLFileSizeLimit(t *testing.T) { func TestYAMLAliasCycleDetection(t *testing.T) { // Test that our checkYAMLDepth function handles alias cycles gracefully - // by using the seen map to prevent infinite recursion. + // by using the visiting map to prevent infinite recursion. // Create a node structure where an alias points to a parent node, // simulating what could happen with crafted input. 
@@ -559,17 +558,18 @@ func TestYAMLAliasCycleDetection(t *testing.T) { }) nodeCount := 0 - seen := make(map[ast.Node]struct{}) + validated := make(map[ast.Node]int) + visiting := make(map[ast.Node]bool) - // This should NOT hang or stack overflow - the seen map prevents infinite recursion - err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount) + // This should NOT hang or stack overflow - cycle detection prevents infinite recursion + err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount) if err != nil { t.Errorf("unexpected error traversing cyclic structure: %v", err) } - // Verify we tracked the parent in the seen map - if _, ok := seen[parent]; !ok { - t.Error("parent node not tracked in seen map") + // Verify we tracked the parent in the validated map + if _, ok := validated[parent]; !ok { + t.Error("parent node not tracked in validated map") } } @@ -644,16 +644,63 @@ func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) { }) nodeCount := 0 - seen := make(map[ast.Node]struct{}) - err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount) + validated := make(map[ast.Node]int) + visiting := make(map[ast.Node]bool) + err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount) // Should complete without infinite recursion due to cycle detection if err != nil { t.Errorf("unexpected error: %v", err) } - // The seen map should contain multiple entries - if len(seen) < 2 { - t.Errorf("seen map has %d entries, expected at least 2", len(seen)) + // The validated map should contain multiple entries + if len(validated) < 2 { + t.Errorf("validated map has %d entries, expected at least 2", len(validated)) + } +} + +func TestYAMLAliasDepthBypass(t *testing.T) { + // Test that an anchored subtree first validated at a shallow depth is + // re-checked when referenced via alias at a deeper position. 
Without the + // depth-aware validated map, the alias reference would skip re-checking + // and allow the effective nesting to exceed MaxYAMLDepth. + + dir := t.TempDir() + path := filepath.Join(dir, "alias-depth-bypass.yaml") + + // Build YAML with an anchor at shallow depth containing a subtree near the limit, + // then reference it via alias deep enough that effective depth exceeds MaxYAMLDepth. + var sb strings.Builder + sb.WriteString("name: test\nidentity: test\n") + + // Create the anchored subtree at depth 1 (key level) that nests 15 levels deep. + sb.WriteString("anchor_key: &deep_anchor\n") + for i := 0; i < 15; i++ { + sb.WriteString(strings.Repeat(" ", i+1)) + sb.WriteString(fmt.Sprintf("level%d:\n", i)) + } + sb.WriteString(strings.Repeat(" ", 16)) + sb.WriteString("leaf: value\n") + + // Create a wrapper that nests 6 levels deep, then references the anchor. + // Effective depth at alias target = 6 (wrapper nesting) + 1 (alias) + 15 (subtree) = 22 > 20 + sb.WriteString("wrapper:\n") + for i := 0; i < 6; i++ { + sb.WriteString(strings.Repeat(" ", i+1)) + sb.WriteString(fmt.Sprintf("n%d:\n", i)) + } + sb.WriteString(strings.Repeat(" ", 7)) + sb.WriteString("alias_ref: *deep_anchor\n") + + if err := os.WriteFile(path, []byte(sb.String()), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + _, err := LoadPersona(path) + if err == nil { + t.Fatal("expected error for alias depth bypass, got nil") + } + if !strings.Contains(err.Error(), "nesting depth exceeds") { + t.Errorf("error = %q, want containing 'nesting depth exceeds'", err.Error()) } }