Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8ebfa80c14 |
+1
-1
@@ -9,7 +9,7 @@
|
||||
|
||||
| Package | Use Case | Scope |
|
||||
|---------|----------|-------|
|
||||
| `github.com/goccy/go-yaml` | YAML parsing and AST inspection (subpkgs: `ast`, `parser`) | production |
|
||||
| `github.com/goccy/go-yaml` | YAML parsing (persona files, config) | production |
|
||||
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
|
||||
|
||||
**Any import not in this table or the Go standard library is forbidden.**
|
||||
|
||||
@@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi
|
||||
- Backwards compatibility: existing JSON personas must continue to work
|
||||
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
|
||||
- Consistency: use `.yaml` extension (not `.yml`)
|
||||
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation
|
||||
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); has built-in depth protection via `MaxYAMLDepth`/`MaxYAMLNodes` constants
|
||||
|
||||
## Proposed Approach
|
||||
|
||||
@@ -33,16 +33,37 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
||||
|
||||
### YAML Parsing with Depth Protection
|
||||
|
||||
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in
|
||||
`review/persona.go`) rather than relying on library decoder options. Key design
|
||||
decisions:
|
||||
```go
|
||||
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||
var node yaml.Node
|
||||
dec := yaml.NewDecoder(bytes.NewReader(data))
|
||||
if err := dec.Decode(&node); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
|
||||
return err
|
||||
}
|
||||
return node.Decode(out)
|
||||
}
|
||||
|
||||
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal
|
||||
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection)
|
||||
- **Node-count limit:** Conservative overcounting bounds total validation work
|
||||
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths
|
||||
func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
|
||||
if depth > maxDepth {
|
||||
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||
}
|
||||
// Handle alias nodes by following the Alias pointer
|
||||
if node.Kind == yaml.AliasNode && node.Alias != nil {
|
||||
return checkYAMLDepth(node.Alias, depth, maxDepth)
|
||||
}
|
||||
for _, child := range node.Content {
|
||||
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
See `review/persona.go:checkYAMLDepth` for the authoritative implementation.
|
||||
The `github.com/goccy/go-yaml` library provides built-in depth protection via `MaxYAMLDepth` and `MaxYAMLNodes` decoder options. We use these instead of a manual depth-checking walk.
|
||||
|
||||
## State/Data Model
|
||||
|
||||
@@ -53,7 +74,7 @@ No new state. Same `Persona` struct, just different parsing.
|
||||
| Error | Handling |
|
||||
|-------|----------|
|
||||
| Invalid YAML syntax | Return parse error with source file |
|
||||
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode |
|
||||
| Deeply nested YAML | Library rejects (v1.16.0+ fix) |
|
||||
| Unknown extension | Fall back to JSON parsing |
|
||||
| Missing required fields | Validation rejects after parse |
|
||||
|
||||
|
||||
+14
-81
@@ -11,7 +11,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -48,12 +47,6 @@ func IsServerError(err error) bool {
|
||||
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
|
||||
}
|
||||
|
||||
// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB).
|
||||
const DefaultMaxDiffSize = 10 * 1024 * 1024
|
||||
|
||||
// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize.
|
||||
var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size")
|
||||
|
||||
// Client interacts with the Gitea API.
|
||||
// A Client is safe for concurrent use by multiple goroutines.
|
||||
type Client struct {
|
||||
@@ -68,14 +61,6 @@ type Client struct {
|
||||
// This field must be configured before the first request is made.
|
||||
// Modifying it while requests are in flight is not safe.
|
||||
RetryBackoff []time.Duration
|
||||
|
||||
// MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff.
|
||||
// If zero, defaults to DefaultMaxDiffSize (10 MB). Set to any negative value
|
||||
// (or math.MaxInt64) to disable the limit.
|
||||
//
|
||||
// This field must be configured before the first request is made.
|
||||
// Modifying it while requests are in flight is not safe.
|
||||
MaxDiffSize int64
|
||||
}
|
||||
|
||||
// NewClient creates a new Gitea API client.
|
||||
@@ -140,28 +125,9 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
|
||||
}
|
||||
|
||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||
// It enforces MaxDiffSize to prevent unbounded memory allocation.
|
||||
// Returns ErrDiffTooLarge if the diff exceeds the configured limit.
|
||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
|
||||
maxSize := c.MaxDiffSize
|
||||
if maxSize == 0 {
|
||||
maxSize = DefaultMaxDiffSize
|
||||
}
|
||||
|
||||
// When the limit is disabled (negative) or set to math.MaxInt64 (which
|
||||
// would overflow the +1 detection and silently disable enforcement),
|
||||
// use the standard unlimited doGet path.
|
||||
if maxSize < 0 || maxSize == math.MaxInt64 {
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
body, err := c.doGetLimited(ctx, reqURL, maxSize)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
@@ -326,9 +292,9 @@ func isRetriableSyscallError(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// redactURL strips query parameters and userinfo credentials from a URL for
|
||||
// safe logging. This prevents accidental exposure of sensitive data (tokens in
|
||||
// query strings, or user:pass in the authority) in log output.
|
||||
// redactURL strips query parameters from a URL for safe logging.
|
||||
// This prevents accidental exposure of sensitive data that future callers
|
||||
// might pass via query strings.
|
||||
func redactURL(rawURL string) string {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
@@ -336,9 +302,6 @@ func redactURL(rawURL string) string {
|
||||
// potentially logging something sensitive.
|
||||
return "[invalid URL]"
|
||||
}
|
||||
if parsed.User != nil {
|
||||
parsed.User = url.User("REDACTED")
|
||||
}
|
||||
if parsed.RawQuery != "" {
|
||||
parsed.RawQuery = "[redacted]"
|
||||
}
|
||||
@@ -359,12 +322,10 @@ func sanitizeErrorForLog(err error) string {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// doGetWithReader performs an HTTP GET request with retry on 5xx errors and
|
||||
// temporary network errors. Retries up to 3 times with exponential backoff
|
||||
// (1s, 2s delays by default; configurable via Client.RetryBackoff for testing).
|
||||
// The readBody function is called with the response body on success (2xx) and
|
||||
// is responsible for reading and closing it.
|
||||
func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody func(io.ReadCloser) ([]byte, error)) ([]byte, error) {
|
||||
// doGet performs an HTTP GET request with retry on 5xx errors and temporary
|
||||
// network errors. Retries up to 3 times with exponential backoff (1s, 2s delays
|
||||
// by default; configurable via Client.RetryBackoff for testing).
|
||||
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||
const maxAttempts = 3
|
||||
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
|
||||
// First attempt (i=0) has no delay; retries wait 1s then 2s by default.
|
||||
@@ -429,7 +390,12 @@ func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody fu
|
||||
return nil, lastErr
|
||||
}
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return readBody(resp.Body)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// Error path: limit how much we read from potentially malicious server
|
||||
@@ -447,39 +413,6 @@ func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody fu
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// doGet performs an HTTP GET request with retry, reading the full response body.
|
||||
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
|
||||
defer body.Close()
|
||||
return io.ReadAll(body)
|
||||
})
|
||||
}
|
||||
|
||||
// doGetLimited performs an HTTP GET request with retry but enforces a maximum
|
||||
// response body size. Returns ErrDiffTooLarge if the response exceeds maxBytes.
|
||||
// It reads maxBytes+1 (clamped to avoid overflow) to detect truncation without
|
||||
// buffering the entire body.
|
||||
func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) {
|
||||
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
|
||||
defer body.Close()
|
||||
// Read up to maxBytes+1 to detect overflow.
|
||||
// Clamp to prevent integer overflow when maxBytes == math.MaxInt64.
|
||||
limitBytes := maxBytes + 1
|
||||
if limitBytes <= 0 {
|
||||
limitBytes = math.MaxInt64
|
||||
}
|
||||
limited := io.LimitReader(body, limitBytes)
|
||||
data, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return nil, fmt.Errorf("%w: response exceeds %d bytes", ErrDiffTooLarge, maxBytes)
|
||||
}
|
||||
return data, nil
|
||||
})
|
||||
}
|
||||
|
||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||
// Slashes are preserved as path separators; other special characters are escaped.
|
||||
// Input should be a relative path (no leading slash). Already-encoded segments
|
||||
|
||||
@@ -1092,21 +1092,6 @@ func TestRedactURL(t *testing.T) {
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "with userinfo - redacts credentials",
|
||||
input: "https://admin:secret@gitea.example.com/api/v1/repos",
|
||||
want: "https://REDACTED@gitea.example.com/api/v1/repos",
|
||||
},
|
||||
{
|
||||
name: "with userinfo and query params",
|
||||
input: "https://user:pass@example.com/path?token=abc",
|
||||
want: "https://REDACTED@example.com/path?[redacted]",
|
||||
},
|
||||
{
|
||||
name: "username only - no password",
|
||||
input: "https://user@example.com/path",
|
||||
want: "https://REDACTED@example.com/path",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetPullRequestDiff_SizeLimits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
diff string
|
||||
maxDiffSize int64
|
||||
wantErr error
|
||||
wantDiff string
|
||||
}{
|
||||
{
|
||||
name: "exceeds max size",
|
||||
diff: strings.Repeat("+ added line\n", 1000), // ~13 KB
|
||||
maxDiffSize: 100,
|
||||
wantErr: ErrDiffTooLarge,
|
||||
},
|
||||
{
|
||||
name: "within max size",
|
||||
diff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
|
||||
maxDiffSize: 1024,
|
||||
wantDiff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
|
||||
},
|
||||
{
|
||||
name: "exactly at limit",
|
||||
diff: strings.Repeat("x", 50),
|
||||
maxDiffSize: 50,
|
||||
wantDiff: strings.Repeat("x", 50),
|
||||
},
|
||||
{
|
||||
name: "one byte over limit",
|
||||
diff: strings.Repeat("x", 51),
|
||||
maxDiffSize: 50,
|
||||
wantErr: ErrDiffTooLarge,
|
||||
},
|
||||
{
|
||||
name: "disabled limit",
|
||||
diff: strings.Repeat("x", 10000),
|
||||
maxDiffSize: -1,
|
||||
wantDiff: strings.Repeat("x", 10000),
|
||||
},
|
||||
{
|
||||
name: "math.MaxInt64 treated as disabled",
|
||||
diff: strings.Repeat("x", 10000),
|
||||
maxDiffSize: math.MaxInt64,
|
||||
wantDiff: strings.Repeat("x", 10000),
|
||||
},
|
||||
{
|
||||
name: "default limit",
|
||||
diff: "diff content",
|
||||
maxDiffSize: 0, // zero means use DefaultMaxDiffSize
|
||||
wantDiff: "diff content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(tt.diff)) //nolint:errcheck // test handler
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = tt.maxDiffSize
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("expected %v, got: %v", tt.wantErr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tt.wantDiff {
|
||||
t.Errorf("diff mismatch: got length %d, want length %d", len(got), len(tt.wantDiff))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+20
-70
@@ -5,7 +5,6 @@ import (
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -121,7 +120,9 @@ func ListBuiltinPersonas() []string {
|
||||
default:
|
||||
continue
|
||||
}
|
||||
seen[personaName] = true
|
||||
if !seen[personaName] {
|
||||
seen[personaName] = true
|
||||
}
|
||||
}
|
||||
names := make([]string, 0, len(seen))
|
||||
for name := range seen {
|
||||
@@ -147,15 +148,6 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
||||
dec := json.NewDecoder(bytes.NewReader(data))
|
||||
dec.DisallowUnknownFields()
|
||||
err = dec.Decode(&p)
|
||||
if err == nil {
|
||||
// Reject trailing content after the first valid JSON object.
|
||||
// Without this check, input like `{"name":"x"}garbage` would
|
||||
// silently succeed because Decoder stops after one object.
|
||||
var dummy json.RawMessage
|
||||
if err2 := dec.Decode(&dummy); err2 != io.EOF {
|
||||
err = fmt.Errorf("unexpected trailing content after JSON object")
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
||||
@@ -166,10 +158,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks:
|
||||
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion.
|
||||
// - Multi-document rejection: prevents silent data loss from ignored extra documents.
|
||||
// - Strict field checking: rejects unknown YAML keys to catch typos early.
|
||||
// unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
|
||||
// and strict field checking. This protects against stack exhaustion from deeply
|
||||
// nested structures and catches typos in field names.
|
||||
// Multi-document YAML files are rejected to prevent silent data loss.
|
||||
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||
// First pass: parse into AST to check depth limits, node counts, and
|
||||
// multi-document rejection. This prevents stack exhaustion before we
|
||||
@@ -198,18 +190,13 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||
|
||||
// Second pass: decode with strict field checking enabled.
|
||||
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
|
||||
//
|
||||
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases
|
||||
// recursively — it resolves them via the pre-built AST, which our first
|
||||
// pass already depth-checked. Alias chains that would exceed depth limits
|
||||
// are caught above; the decoder merely reads the resolved scalar values.
|
||||
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
|
||||
return dec.Decode(out)
|
||||
}
|
||||
|
||||
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
|
||||
// limit or the total node count limit. It uses two tracking maps:
|
||||
// - validated: maps each node to the maximum depth at which it was previously
|
||||
// - validated: maps each node to the minimum depth at which it was previously
|
||||
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
|
||||
// we re-check it to ensure the combined effective depth doesn't exceed limits.
|
||||
// - visiting: per-path recursion stack for true cycle detection. A node on the
|
||||
@@ -227,6 +214,12 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
||||
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||
}
|
||||
|
||||
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
||||
*nodeCount++
|
||||
if *nodeCount > maxNodes {
|
||||
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
||||
}
|
||||
|
||||
// Cycle detection: if we're currently visiting this node on the current
|
||||
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
|
||||
// Return nil to break the cycle without error — cycles are a structural
|
||||
@@ -235,28 +228,10 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
||||
return nil
|
||||
}
|
||||
|
||||
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
||||
// Placed after cycle detection but before the depth-aware short-circuit. This means
|
||||
// nodes revisited at shallower depths (via aliases) are counted each time they are
|
||||
// encountered — intentional conservative overcounting. This bounds the total work
|
||||
// performed during validation rather than tracking unique nodes, which is the safer
|
||||
// security posture for untrusted YAML input.
|
||||
*nodeCount++
|
||||
if *nodeCount > maxNodes {
|
||||
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
||||
}
|
||||
|
||||
// Depth-aware short-circuit: skip re-validation only when the current visit
|
||||
// depth is the same or shallower than the depth at which this node was
|
||||
// previously validated. A shallower (or equal) current depth means the
|
||||
// prior, deeper validation already covered any subtree depth violations.
|
||||
// If the current depth exceeds the previous validation depth (e.g., an alias
|
||||
// references this node deeper in the tree), we must re-traverse to ensure
|
||||
// the combined effective depth doesn't exceed maxDepth.
|
||||
//
|
||||
// Note: using ast.Node (interface) as map key relies on pointer identity,
|
||||
// which is correct because all goccy/go-yaml AST node types are pointer
|
||||
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
|
||||
// Depth-aware short-circuit: only skip re-checking a node if we previously
|
||||
// validated it at the same or deeper effective depth. If this visit is at a
|
||||
// greater depth than before (e.g., alias referenced deeper in the tree),
|
||||
// we must re-traverse to catch depth limit violations.
|
||||
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
|
||||
return nil
|
||||
}
|
||||
@@ -275,11 +250,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
||||
}
|
||||
}
|
||||
case *ast.MappingValueNode:
|
||||
// Both Key and Value are visited at depth+1 relative to this
|
||||
// MappingValueNode. Since MappingNode visits its MappingValueNode
|
||||
// children at depth+1 as well, keys and values end up at depth+2
|
||||
// from the parent MappingNode. This is intentional: it mirrors the
|
||||
// actual nesting structure (mapping → key-value pair → key/value).
|
||||
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -299,14 +269,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
||||
return err
|
||||
}
|
||||
case *ast.AnchorNode:
|
||||
// Increment depth for anchor values as a conservative measure: the
|
||||
// anchor definition itself is structural, and treating it as a depth
|
||||
// level ensures that deeply nested anchors are caught at definition
|
||||
// time rather than only when referenced via alias. This +1 is
|
||||
// asymmetric with alias (which also increments) — by design, the
|
||||
// effective depth budget for anchored-then-aliased content is reduced
|
||||
// because both the definition site and the reference site each consume
|
||||
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
|
||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -314,16 +276,8 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||
return err
|
||||
}
|
||||
case *ast.MergeKeyNode:
|
||||
// MergeKeyNode represents the literal "<<" merge key token. It has no
|
||||
// child nodes — the value side of a merge (e.g., *alias) lives in the
|
||||
// parent MappingValueNode.Value, which is already recursed into above.
|
||||
// Explicitly listed here (rather than in the default case) to prevent
|
||||
// future library changes from silently bypassing depth checks.
|
||||
default:
|
||||
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
|
||||
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
|
||||
// recurse into.
|
||||
// Scalar types (StringNode, IntegerNode, FloatNode, BoolNode, NullNode,
|
||||
// InfinityNode, NanNode, LiteralNode, MergeKeyNode) are leaf nodes.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -331,11 +285,7 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
||||
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
||||
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
||||
// without requiring filesystem access. Format is detected by source extension.
|
||||
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
|
||||
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
||||
if len(data) > MaxPersonaFileSize {
|
||||
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
|
||||
}
|
||||
return parsePersona(data, source)
|
||||
}
|
||||
|
||||
|
||||
+10
-113
@@ -459,14 +459,8 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
||||
path := filepath.Join(dir, "deeply-nested.yaml")
|
||||
|
||||
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
||||
// Depth accumulation trace for "nested: \n level0: \n level1: ...":
|
||||
// - Document root parsed at depth 0
|
||||
// - Root MappingNode children (MappingValueNodes) visited at depth 1
|
||||
// - "nested" MappingValueNode: key at depth 2, value at depth 2
|
||||
// - Each levelN adds depth via MappingValueNode traversal (key + value)
|
||||
// - Exact depth per level depends on AST structure (MappingNode wrapping),
|
||||
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
|
||||
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
|
||||
// Each nested mapping key generates a MappingValueNode, incrementing depth
|
||||
// by 1 per level in the AST walk. With 25 levels, we exceed MaxYAMLDepth (20).
|
||||
var sb strings.Builder
|
||||
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
||||
indent := " "
|
||||
@@ -490,19 +484,21 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestYAMLEmptyFileRejection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{"completely_empty", ""},
|
||||
{"whitespace_only", " \n\n "},
|
||||
{"comment_only", "# just a comment\n"},
|
||||
{"completely empty", ""},
|
||||
{"whitespace only", " \n\n "},
|
||||
{"comment only", "# just a comment\n"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, tc.name+".yaml")
|
||||
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
@@ -510,9 +506,9 @@ func TestYAMLEmptyFileRejection(t *testing.T) {
|
||||
|
||||
_, err := LoadPersona(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty YAML input, got nil")
|
||||
t.Error("expected error for empty YAML input, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "empty YAML document") {
|
||||
if err != nil && !strings.Contains(err.Error(), "empty YAML document") {
|
||||
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
|
||||
}
|
||||
})
|
||||
@@ -858,102 +854,3 @@ identity: test identity
|
||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONTrailingContentRejected(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
name: "trailing garbage after object",
|
||||
content: `{"name":"test","identity":"test identity"}garbage`,
|
||||
},
|
||||
{
|
||||
name: "two JSON objects",
|
||||
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
|
||||
},
|
||||
{
|
||||
name: "trailing array",
|
||||
content: `{"name":"test","identity":"test identity"}[]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadPersona(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for trailing content, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "trailing content") {
|
||||
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePersonaBytesSizeLimit(t *testing.T) {
|
||||
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
|
||||
oversized := make([]byte, MaxPersonaFileSize+1)
|
||||
for i := range oversized {
|
||||
oversized[i] = 'x'
|
||||
}
|
||||
|
||||
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized input, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "exceeds maximum size") {
|
||||
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
|
||||
}
|
||||
|
||||
// Just under the limit should not trigger size error (may fail parse, but not size)
|
||||
underLimit := []byte("name: test\nidentity: test persona\n")
|
||||
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for valid input: %v", err)
|
||||
}
|
||||
if p.Name != "test" {
|
||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
|
||||
// Verify that YAML merge keys (<<: *alias) are properly handled by the
|
||||
// depth checker. The merge key content is in the MappingValueNode.Value
|
||||
// (an AliasNode), not in the MergeKeyNode itself.
|
||||
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
|
||||
if err != nil {
|
||||
t.Fatalf("basic parse failed: %v", err)
|
||||
}
|
||||
if p.Name != "merge-test" {
|
||||
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
|
||||
}
|
||||
|
||||
// Test that deeply nested merge keys still hit depth limit.
|
||||
// Build YAML with merge key content nested beyond MaxYAMLDepth.
|
||||
var sb strings.Builder
|
||||
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
|
||||
sb.WriteString("anchor: &deep\n")
|
||||
indent := " "
|
||||
for i := 0; i < MaxYAMLDepth+5; i++ {
|
||||
sb.WriteString(indent)
|
||||
sb.WriteString(fmt.Sprintf("level%d:\n", i))
|
||||
indent += " "
|
||||
}
|
||||
sb.WriteString(indent + "leaf: value\n")
|
||||
sb.WriteString("target:\n <<: *deep\n")
|
||||
|
||||
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for deeply nested merge key content, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "depth") {
|
||||
t.Errorf("error = %q, want to contain 'depth'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user