diff --git a/review/persona.go b/review/persona.go index 3172735..8e037a0 100644 --- a/review/persona.go +++ b/review/persona.go @@ -24,6 +24,10 @@ const MaxPersonaFileSize = 64 * 1024 // This prevents stack exhaustion from deeply nested structures. const MaxYAMLDepth = 20 +// MaxYAMLNodes is the maximum number of YAML nodes allowed in persona files. +// This prevents DoS via wide-but-shallow structures that bypass depth limits. +const MaxYAMLNodes = 1000 + // Persona defines a specialized review role with focused expertise. type Persona struct { Name string `json:"name" yaml:"name"` @@ -153,11 +157,9 @@ func parsePersona(data []byte, source string) (*Persona, error) { // unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting // and strict field checking. This protects against stack exhaustion from deeply // nested structures and catches typos in field names. -// Note: Multi-document YAML files are accepted but only the first document is -// parsed; additional documents are silently ignored. This is acceptable for -// persona files where multi-document support is not a use case. +// Multi-document YAML files are rejected to prevent silent data loss. func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { - // First pass: decode into a yaml.Node to check depth limits. + // First pass: decode into a yaml.Node to check depth limits and node counts. // This prevents stack exhaustion before we attempt to decode into structs. var node yaml.Node dec := yaml.NewDecoder(bytes.NewReader(data)) @@ -165,7 +167,15 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { return err } - if err := checkYAMLDepth(&node, 0, maxDepth); err != nil { + // Reject multi-document YAML files - silently ignoring additional documents + // could lead to confusing behavior where users think their changes take effect. + var extra yaml.Node + if dec.Decode(&extra) == nil { + return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed") + } + + nodeCount := 0 + if err := checkYAMLDepth(&node, 0, maxDepth, MaxYAMLNodes, make(map[*yaml.Node]struct{}), &nodeCount); err != nil { return err } @@ -178,19 +188,34 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { return strictDec.Decode(out) } -// checkYAMLDepth recursively checks that YAML nodes don't exceed the depth limit. -// Handles alias nodes by following the Alias pointer to check the target's depth. -func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error { +// checkYAMLDepth recursively checks that YAML nodes don't exceed the depth limit +// or the total node count limit. It also detects alias cycles to prevent infinite +// recursion from crafted YAML with self-referential aliases. +func checkYAMLDepth(node *yaml.Node, depth, maxDepth, maxNodes int, seen map[*yaml.Node]struct{}, nodeCount *int) error { if depth > maxDepth { return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth) } - // Handle alias nodes: follow the alias to its anchor target. - // The alias itself doesn't add depth, but we must check the target. - if node.Kind == yaml.AliasNode && node.Alias != nil { - return checkYAMLDepth(node.Alias, depth, maxDepth) + + // Track total nodes visited as defense-in-depth against wide-but-shallow attacks. + *nodeCount++ + if *nodeCount > maxNodes { + return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes) } + + // Cycle detection: if we've seen this node before, we're in a cycle. + if _, ok := seen[node]; ok { + return nil // Already validated this subtree, skip to avoid infinite recursion. + } + seen[node] = struct{}{} + + // Handle alias nodes: follow the alias to its anchor target. + // Increment depth when following aliases since they expand the effective structure. + if node.Kind == yaml.AliasNode && node.Alias != nil { + return checkYAMLDepth(node.Alias, depth+1, maxDepth, maxNodes, seen, nodeCount) + } + for _, child := range node.Content { - if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil { + if err := checkYAMLDepth(child, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil { return err } } diff --git a/review/persona_test.go b/review/persona_test.go index cd77434..f1f40db 100644 --- a/review/persona_test.go +++ b/review/persona_test.go @@ -6,6 +6,8 @@ import ( "path/filepath" "strings" "testing" + + "gopkg.in/yaml.v3" ) func TestLoadBuiltinPersona(t *testing.T) { @@ -174,6 +176,7 @@ func TestLoadPersonaFromJSONFile(t *testing.T) { "display_name": "Test Persona", "identity": "You are a test persona.\nMulti-line identity works.", "focus": ["testing", "validation"], + "ignore": ["nothing"], "severity": { "major": "Big problems", @@ -499,6 +502,131 @@ func TestYAMLFileSizeLimit(t *testing.T) { } } +func TestYAMLAliasCycleDetection(t *testing.T) { + // Test that our checkYAMLDepth function handles alias cycles gracefully + // by using the seen map to prevent infinite recursion. + // We test this directly because go-yaml's parser handles most cycles + // at parse time, but we need to ensure our checker is robust. + + // Create a node structure where an alias points to a parent node, + // simulating what could happen with malicious input that bypasses + // go-yaml's cycle detection. + parent := &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Value: "name"}, + {Kind: yaml.ScalarNode, Value: "test"}, + {Kind: yaml.ScalarNode, Value: "nested"}, + }, + } + + // Create a child that aliases back to the parent (artificial cycle) + aliasToParent := &yaml.Node{ + Kind: yaml.AliasNode, + Alias: parent, + } + parent.Content = append(parent.Content, aliasToParent) + + nodeCount := 0 + seen := make(map[*yaml.Node]struct{}) + + // This should NOT hang or stack overflow - the seen map prevents infinite recursion + err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount) + if err != nil { + t.Errorf("unexpected error traversing cyclic structure: %v", err) + } + + // Verify we tracked the parent in the seen map + if _, ok := seen[parent]; !ok { + t.Error("parent node not tracked in seen map") + } +} + +func TestYAMLMultiDocumentRejection(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "multi.yaml") + + // Multi-document YAML (documents separated by ---) + content := `name: first +identity: first document +--- +name: second +identity: second document +` + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + _, err := LoadPersona(path) + if err == nil { + t.Error("expected error for multi-document YAML, got nil") + } + if !strings.Contains(err.Error(), "multi-document") { + t.Errorf("error = %q, want containing 'multi-document'", err.Error()) + } +} + +func TestYAMLNodeCountLimit(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "wide.yaml") + + // Build a YAML structure that's shallow but wide - many keys at the same level + // to test the node count limit (should exceed MaxYAMLNodes = 1000) + var sb strings.Builder + sb.WriteString("name: test\nidentity: test\n") + for i := 0; i < 600; i++ { + sb.WriteString(fmt.Sprintf("key%d: value%d\n", i, i)) + } + + if err := os.WriteFile(path, []byte(sb.String()), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + _, err := LoadPersona(path) + if err == nil { + t.Error("expected error for wide YAML exceeding node count, got nil") + } + if !strings.Contains(err.Error(), "node count exceeds") { + t.Errorf("error = %q, want containing 'node count exceeds'", err.Error()) + } +} + +func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) { + // Direct test of cycle detection in checkYAMLDepth by creating + // a node structure with an artificial cycle. + // This tests the seen map logic independent of go-yaml's parsing. + node := &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Value: "key"}, + {Kind: yaml.ScalarNode, Value: "value"}, + }, + } + + // Create a cycle by making a child reference the parent + cycleChild := &yaml.Node{ + Kind: yaml.AliasNode, + Alias: node, // Points back to the parent + } + node.Content = append(node.Content, + &yaml.Node{Kind: yaml.ScalarNode, Value: "cyclic"}, + cycleChild, + ) + + nodeCount := 0 + seen := make(map[*yaml.Node]struct{}) + err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount) + + // Should complete without infinite recursion due to cycle detection + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // The seen map should contain multiple entries + if len(seen) < 2 { + t.Errorf("seen map has %d entries, expected at least 2", len(seen)) + } +} + func TestListBuiltinPersonasSortedOrder(t *testing.T) { names := ListBuiltinPersonas() if len(names) < 2 {