package review import ( "bytes" "embed" "encoding/json" "fmt" "io" "os" "sort" "strings" "unicode/utf8" "github.com/goccy/go-yaml" "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/parser" ) //go:embed personas/*.yaml var embeddedPersonas embed.FS // MaxPersonaFileSize is the maximum size for persona files (64 KB). // This prevents denial-of-service via excessively large files. const MaxPersonaFileSize = 64 * 1024 // MaxYAMLDepth is the maximum nesting depth allowed in YAML persona files. // This prevents stack exhaustion from deeply nested structures. const MaxYAMLDepth = 20 // MaxYAMLNodes is the maximum number of YAML nodes allowed in persona files. // This prevents DoS via wide-but-shallow structures that bypass depth limits. const MaxYAMLNodes = 1000 // Persona defines a specialized review role with focused expertise. type Persona struct { Name string `json:"name" yaml:"name"` DisplayName string `json:"display_name" yaml:"display_name"` ModelPref string `json:"model_preference,omitempty" yaml:"model_preference,omitempty"` Identity string `json:"identity" yaml:"identity"` Focus []string `json:"focus" yaml:"focus"` Ignore []string `json:"ignore" yaml:"ignore"` Severity Severity `json:"severity" yaml:"severity"` OutputFormat string `json:"output_format,omitempty" yaml:"output_format,omitempty"` } // Severity defines what constitutes each severity level for this persona. // These are prompt guidance for the LLM, not output format changes. type Severity struct { Major string `json:"major" yaml:"major"` Minor string `json:"minor" yaml:"minor"` Nit string `json:"nit" yaml:"nit"` } // LoadPersona loads a persona from a JSON or YAML file path. // Format is detected by file extension: .yaml/.yml for YAML, .json or other for JSON. // Files larger than MaxPersonaFileSize are rejected. // // Symlinks are supported: os.Stat follows symlinks, so a symlink pointing to // a regular file will pass the IsRegular() check. Symlinks to non-regular files // (directories, FIFOs, devices) are still rejected. func LoadPersona(path string) (*Persona, error) { // os.Stat follows symlinks, so symlinks to regular files are supported. // The IsRegular() check operates on the target, not the symlink itself. info, err := os.Stat(path) if err != nil { return nil, fmt.Errorf("read persona file %s: %w", path, err) } if !info.Mode().IsRegular() { return nil, fmt.Errorf("persona file %s is not a regular file", path) } if info.Size() > MaxPersonaFileSize { return nil, fmt.Errorf("persona file %s exceeds maximum size (%d bytes)", path, MaxPersonaFileSize) } data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read persona file %s: %w", path, err) } // Re-check size after read to defend against TOCTOU races where file // grows between stat and read (e.g., appending process, replaced file). if len(data) > MaxPersonaFileSize { return nil, fmt.Errorf("persona file %s exceeds maximum size (%d bytes)", path, MaxPersonaFileSize) } return parsePersona(data, path) } // LoadBuiltinPersona loads a built-in persona by name. // Returns an error if the persona doesn't exist. // Built-in personas are stored in YAML format only (see embed directive). func LoadBuiltinPersona(name string) (*Persona, error) { yamlFile := name + ".yaml" data, err := embeddedPersonas.ReadFile("personas/" + yamlFile) if err != nil { available := ListBuiltinPersonas() return nil, fmt.Errorf("unknown built-in persona %q (available: %s)", name, strings.Join(available, ", ")) } return parsePersona(data, "builtin:"+yamlFile) } // ListBuiltinPersonas returns the names of all built-in personas in sorted order. // Returns an empty slice if the embedded directory cannot be read. func ListBuiltinPersonas() []string { entries, err := embeddedPersonas.ReadDir("personas") if err != nil { return []string{} } seen := make(map[string]bool) for _, e := range entries { if e.IsDir() { continue } name := e.Name() // Strip extension to get persona name var personaName string switch { case strings.HasSuffix(name, ".yaml"): personaName = strings.TrimSuffix(name, ".yaml") case strings.HasSuffix(name, ".yml"): personaName = strings.TrimSuffix(name, ".yml") case strings.HasSuffix(name, ".json"): personaName = strings.TrimSuffix(name, ".json") default: continue } seen[personaName] = true } names := make([]string, 0, len(seen)) for name := range seen { names = append(names, name) } sort.Strings(names) return names } // parsePersona parses persona data from JSON or YAML format. // Format is detected by the source file extension. func parsePersona(data []byte, source string) (*Persona, error) { lowerSource := strings.ToLower(source) isYAML := strings.HasSuffix(lowerSource, ".yaml") || strings.HasSuffix(lowerSource, ".yml") var p Persona var err error if isYAML { err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth) } else { // Use json.Decoder with DisallowUnknownFields for consistency with // YAML's Strict() - both reject unknown fields to catch typos. dec := json.NewDecoder(bytes.NewReader(data)) dec.DisallowUnknownFields() err = dec.Decode(&p) if err == nil { // Reject trailing content after the first valid JSON object. // Without this check, input like `{"name":"x"}garbage` would // silently succeed because Decoder stops after one object. var dummy json.RawMessage if err2 := dec.Decode(&dummy); err2 != io.EOF { err = fmt.Errorf("unexpected trailing content after JSON object") } } } if err != nil { return nil, fmt.Errorf("parse persona %s: %w", source, err) } if err := validatePersona(&p, source); err != nil { return nil, err } return &p, nil } // unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks: // - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion. // - Multi-document rejection: prevents silent data loss from ignored extra documents. // - Strict field checking: rejects unknown YAML keys to catch typos early. func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { // First pass: parse into AST to check depth limits, node counts, and // multi-document rejection. This prevents stack exhaustion before we // attempt to decode into structs. file, err := parser.ParseBytes(data, 0) if err != nil { return err } // Reject empty YAML input (whitespace-only, comment-only, or truly empty files). // The parser returns a single doc with nil body for these cases. if len(file.Docs) == 0 || file.Docs[0].Body == nil { return fmt.Errorf("empty YAML document") } // Reject multi-document YAML files - silently ignoring additional documents // could lead to confusing behavior where users think their changes take effect. if len(file.Docs) > 1 { return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed") } nodeCount := 0 if err := checkYAMLDepth(file.Docs[0].Body, 0, maxDepth, MaxYAMLNodes, make(map[ast.Node]int), make(map[ast.Node]bool), &nodeCount); err != nil { return err } // Second pass: decode with strict field checking enabled. // Strict() rejects unknown keys, catching typos like "focuss" or "identiy". // // Safety note: goccy/go-yaml's decoder does not expand YAML aliases // recursively — it resolves them via the pre-built AST, which our first // pass already depth-checked. Alias chains that would exceed depth limits // are caught above; the decoder merely reads the resolved scalar values. dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict()) return dec.Decode(out) } // checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth // limit or the total node count limit. It uses two tracking maps: // - validated: maps each node to the maximum depth at which it was previously // checked. If a node is revisited at a deeper depth (e.g., via an alias), // we re-check it to ensure the combined effective depth doesn't exceed limits. // - visiting: per-path recursion stack for true cycle detection. A node on the // current path is a cycle (alias loop); we return nil to avoid infinite recursion. // // This design prevents the alias depth bypass where an anchored subtree validated // at a shallow depth could be referenced via alias at a greater depth, effectively // exceeding MaxYAMLDepth. func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[ast.Node]int, visiting map[ast.Node]bool, nodeCount *int) error { if node == nil { return nil } if depth > maxDepth { return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth) } // Cycle detection: if we're currently visiting this node on the current // recursion path, it's a cycle (e.g., alias pointing to an ancestor). // Return nil to break the cycle without error — cycles are a structural // property, not a depth violation. if visiting[node] { return nil } // Track total nodes visited as defense-in-depth against wide-but-shallow attacks. // Placed after cycle detection but before the depth-aware short-circuit. This means // nodes revisited at shallower depths (via aliases) are counted each time they are // encountered — intentional conservative overcounting. This bounds the total work // performed during validation rather than tracking unique nodes, which is the safer // security posture for untrusted YAML input. *nodeCount++ if *nodeCount > maxNodes { return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes) } // Depth-aware short-circuit: skip re-validation only when the current visit // depth is the same or shallower than the depth at which this node was // previously validated. A shallower (or equal) current depth means the // prior, deeper validation already covered any subtree depth violations. // If the current depth exceeds the previous validation depth (e.g., an alias // references this node deeper in the tree), we must re-traverse to ensure // the combined effective depth doesn't exceed maxDepth. // // Note: using ast.Node (interface) as map key relies on pointer identity, // which is correct because all goccy/go-yaml AST node types are pointer // receivers (*MappingNode, *SequenceNode, etc.), never value types. if prevDepth, ok := validated[node]; ok && depth <= prevDepth { return nil } validated[node] = depth // Mark as visiting (on the current recursion path) for cycle detection. visiting[node] = true defer func() { visiting[node] = false }() // Walk children based on node type. switch n := node.(type) { case *ast.MappingNode: for _, value := range n.Values { if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } } case *ast.MappingValueNode: // Both Key and Value are visited at depth+1 relative to this // MappingValueNode. Since MappingNode visits its MappingValueNode // children at depth+1 as well, keys and values end up at depth+2 // from the parent MappingNode. This is intentional: it mirrors the // actual nesting structure (mapping → key-value pair → key/value). if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.SequenceNode: for _, value := range n.Values { if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } } case *ast.AliasNode: // Follow alias to its target, incrementing depth since aliases expand // the effective structure. if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.AnchorNode: // Increment depth for anchor values as a conservative measure: the // anchor definition itself is structural, and treating it as a depth // level ensures that deeply nested anchors are caught at definition // time rather than only when referenced via alias. This +1 is // asymmetric with alias (which also increments) — by design, the // effective depth budget for anchored-then-aliased content is reduced // because both the definition site and the reference site each consume // a level, making deeply nested anchor/alias pairs hit the limit sooner. if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.TagNode: if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil { return err } case *ast.MergeKeyNode: // MergeKeyNode represents the literal "<<" merge key token. It has no // child nodes — the value side of a merge (e.g., *alias) lives in the // parent MappingValueNode.Value, which is already recursed into above. // Explicitly listed here (rather than in the default case) to prevent // future library changes from silently bypassing depth checks. default: // Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode, // NullNode, InfinityNode, NanNode, LiteralNode) have no children to // recurse into. } return nil } // ParsePersonaBytes parses persona data from bytes with a source label for errors. // This is useful for parsing personas fetched from external sources (e.g., Gitea API) // without requiring filesystem access. Format is detected by source extension. // Input is bounded by MaxPersonaFileSize to prevent resource exhaustion. func ParsePersonaBytes(data []byte, source string) (*Persona, error) { if len(data) > MaxPersonaFileSize { return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize) } return parsePersona(data, source) } func validatePersona(p *Persona, source string) error { if p.Name == "" { return fmt.Errorf("persona %s: name is required", source) } if p.Identity == "" { return fmt.Errorf("persona %s: identity is required", source) } // DisplayName defaults to Name if not set if p.DisplayName == "" { p.DisplayName = p.Name } return nil } // CapitalizeFirst capitalizes the first rune of a string in a Unicode-safe way. // Returns the original string if it's empty. func CapitalizeFirst(s string) string { if s == "" { return s } r, size := utf8.DecodeRuneInString(s) if r == utf8.RuneError { return s } return strings.ToUpper(string(r)) + s[size:] }