Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 49d6ca77a3 | |||
| 6ebf66aefb | |||
| 004343d05f | |||
| 92b84976cf | |||
| 881ce232eb | |||
| bf52fceea0 | |||
| d722035629 | |||
| b9b7be3b4e | |||
| baa917f228 | |||
| b0352ba1c9 | |||
| 0b16c4143a | |||
| 493349e11a | |||
| 5cedeee9f4 | |||
| 01b6af03a8 | |||
| 80091fb080 |
+1
-1
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
| Package | Use Case | Scope |
|
| Package | Use Case | Scope |
|
||||||
|---------|----------|-------|
|
|---------|----------|-------|
|
||||||
| `github.com/goccy/go-yaml` | YAML parsing (persona files, config) | production |
|
| `github.com/goccy/go-yaml` | YAML parsing and AST inspection (subpkgs: `ast`, `parser`) | production |
|
||||||
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
|
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
|
||||||
|
|
||||||
**Any import not in this table or the Go standard library is forbidden.**
|
**Any import not in this table or the Go standard library is forbidden.**
|
||||||
|
|||||||
@@ -33,37 +33,16 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
|
|
||||||
### YAML Parsing with Depth Protection
|
### YAML Parsing with Depth Protection
|
||||||
|
|
||||||
```go
|
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in
|
||||||
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
`review/persona.go`) rather than relying on library decoder options. Key design
|
||||||
var node yaml.Node
|
decisions:
|
||||||
dec := yaml.NewDecoder(bytes.NewReader(data))
|
|
||||||
if err := dec.Decode(&node); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return node.Decode(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
|
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal
|
||||||
if depth > maxDepth {
|
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection)
|
||||||
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
- **Node-count limit:** Conservative overcounting bounds total validation work
|
||||||
}
|
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths
|
||||||
// Handle alias nodes by following the Alias pointer
|
|
||||||
if node.Kind == yaml.AliasNode && node.Alias != nil {
|
|
||||||
return checkYAMLDepth(node.Alias, depth, maxDepth)
|
|
||||||
}
|
|
||||||
for _, child := range node.Content {
|
|
||||||
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth`) rather than relying on library decoder options. This gives us precise control over how depth is counted across aliases and anchors, with a depth-aware validated map to prevent alias depth bypass.
|
See `review/persona.go:checkYAMLDepth` for the authoritative implementation.
|
||||||
|
|
||||||
## State/Data Model
|
## State/Data Model
|
||||||
|
|
||||||
@@ -74,7 +53,7 @@ No new state. Same `Persona` struct, just different parsing.
|
|||||||
| Error | Handling |
|
| Error | Handling |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| Invalid YAML syntax | Return parse error with source file |
|
| Invalid YAML syntax | Return parse error with source file |
|
||||||
| Deeply nested YAML | Library rejects (v1.16.0+ fix) |
|
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode |
|
||||||
| Unknown extension | Fall back to JSON parsing |
|
| Unknown extension | Fall back to JSON parsing |
|
||||||
| Missing required fields | Validation rejects after parse |
|
| Missing required fields | Validation rejects after parse |
|
||||||
|
|
||||||
|
|||||||
+81
-14
@@ -11,6 +11,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -47,6 +48,12 @@ func IsServerError(err error) bool {
|
|||||||
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
|
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB).
|
||||||
|
const DefaultMaxDiffSize = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize.
|
||||||
|
var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size")
|
||||||
|
|
||||||
// Client interacts with the Gitea API.
|
// Client interacts with the Gitea API.
|
||||||
// A Client is safe for concurrent use by multiple goroutines.
|
// A Client is safe for concurrent use by multiple goroutines.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -61,6 +68,14 @@ type Client struct {
|
|||||||
// This field must be configured before the first request is made.
|
// This field must be configured before the first request is made.
|
||||||
// Modifying it while requests are in flight is not safe.
|
// Modifying it while requests are in flight is not safe.
|
||||||
RetryBackoff []time.Duration
|
RetryBackoff []time.Duration
|
||||||
|
|
||||||
|
// MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff.
|
||||||
|
// If zero, defaults to DefaultMaxDiffSize (10 MB). Set to any negative value
|
||||||
|
// (or math.MaxInt64) to disable the limit.
|
||||||
|
//
|
||||||
|
// This field must be configured before the first request is made.
|
||||||
|
// Modifying it while requests are in flight is not safe.
|
||||||
|
MaxDiffSize int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new Gitea API client.
|
// NewClient creates a new Gitea API client.
|
||||||
@@ -125,9 +140,28 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||||
|
// It enforces MaxDiffSize to prevent unbounded memory allocation.
|
||||||
|
// Returns ErrDiffTooLarge if the diff exceeds the configured limit.
|
||||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
|
maxSize := c.MaxDiffSize
|
||||||
|
if maxSize == 0 {
|
||||||
|
maxSize = DefaultMaxDiffSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// When the limit is disabled (negative) or set to math.MaxInt64 (which
|
||||||
|
// would overflow the +1 detection and silently disable enforcement),
|
||||||
|
// use the standard unlimited doGet path.
|
||||||
|
if maxSize < 0 || maxSize == math.MaxInt64 {
|
||||||
|
body, err := c.doGet(ctx, reqURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("fetch diff: %w", err)
|
||||||
|
}
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := c.doGetLimited(ctx, reqURL, maxSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("fetch diff: %w", err)
|
return "", fmt.Errorf("fetch diff: %w", err)
|
||||||
}
|
}
|
||||||
@@ -292,9 +326,9 @@ func isRetriableSyscallError(err error) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// redactURL strips query parameters from a URL for safe logging.
|
// redactURL strips query parameters and userinfo credentials from a URL for
|
||||||
// This prevents accidental exposure of sensitive data that future callers
|
// safe logging. This prevents accidental exposure of sensitive data (tokens in
|
||||||
// might pass via query strings.
|
// query strings, or user:pass in the authority) in log output.
|
||||||
func redactURL(rawURL string) string {
|
func redactURL(rawURL string) string {
|
||||||
parsed, err := url.Parse(rawURL)
|
parsed, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -302,6 +336,9 @@ func redactURL(rawURL string) string {
|
|||||||
// potentially logging something sensitive.
|
// potentially logging something sensitive.
|
||||||
return "[invalid URL]"
|
return "[invalid URL]"
|
||||||
}
|
}
|
||||||
|
if parsed.User != nil {
|
||||||
|
parsed.User = url.User("REDACTED")
|
||||||
|
}
|
||||||
if parsed.RawQuery != "" {
|
if parsed.RawQuery != "" {
|
||||||
parsed.RawQuery = "[redacted]"
|
parsed.RawQuery = "[redacted]"
|
||||||
}
|
}
|
||||||
@@ -322,10 +359,12 @@ func sanitizeErrorForLog(err error) string {
|
|||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
// doGet performs an HTTP GET request with retry on 5xx errors and temporary
|
// doGetWithReader performs an HTTP GET request with retry on 5xx errors and
|
||||||
// network errors. Retries up to 3 times with exponential backoff (1s, 2s delays
|
// temporary network errors. Retries up to 3 times with exponential backoff
|
||||||
// by default; configurable via Client.RetryBackoff for testing).
|
// (1s, 2s delays by default; configurable via Client.RetryBackoff for testing).
|
||||||
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
// The readBody function is called with the response body on success (2xx) and
|
||||||
|
// is responsible for reading and closing it.
|
||||||
|
func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody func(io.ReadCloser) ([]byte, error)) ([]byte, error) {
|
||||||
const maxAttempts = 3
|
const maxAttempts = 3
|
||||||
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
|
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
|
||||||
// First attempt (i=0) has no delay; retries wait 1s then 2s by default.
|
// First attempt (i=0) has no delay; retries wait 1s then 2s by default.
|
||||||
@@ -390,12 +429,7 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
body, err := io.ReadAll(resp.Body)
|
return readBody(resp.Body)
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return body, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error path: limit how much we read from potentially malicious server
|
// Error path: limit how much we read from potentially malicious server
|
||||||
@@ -413,6 +447,39 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// doGet performs an HTTP GET request with retry, reading the full response body.
|
||||||
|
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||||
|
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
|
||||||
|
defer body.Close()
|
||||||
|
return io.ReadAll(body)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// doGetLimited performs an HTTP GET request with retry but enforces a maximum
|
||||||
|
// response body size. Returns ErrDiffTooLarge if the response exceeds maxBytes.
|
||||||
|
// It reads maxBytes+1 (clamped to avoid overflow) to detect truncation without
|
||||||
|
// buffering the entire body.
|
||||||
|
func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) {
|
||||||
|
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
|
||||||
|
defer body.Close()
|
||||||
|
// Read up to maxBytes+1 to detect overflow.
|
||||||
|
// Clamp to prevent integer overflow when maxBytes == math.MaxInt64.
|
||||||
|
limitBytes := maxBytes + 1
|
||||||
|
if limitBytes <= 0 {
|
||||||
|
limitBytes = math.MaxInt64
|
||||||
|
}
|
||||||
|
limited := io.LimitReader(body, limitBytes)
|
||||||
|
data, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if int64(len(data)) > maxBytes {
|
||||||
|
return nil, fmt.Errorf("%w: response exceeds %d bytes", ErrDiffTooLarge, maxBytes)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||||
// Slashes are preserved as path separators; other special characters are escaped.
|
// Slashes are preserved as path separators; other special characters are escaped.
|
||||||
// Input should be a relative path (no leading slash). Already-encoded segments
|
// Input should be a relative path (no leading slash). Already-encoded segments
|
||||||
|
|||||||
@@ -1092,6 +1092,21 @@ func TestRedactURL(t *testing.T) {
|
|||||||
input: "",
|
input: "",
|
||||||
want: "",
|
want: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "with userinfo - redacts credentials",
|
||||||
|
input: "https://admin:secret@gitea.example.com/api/v1/repos",
|
||||||
|
want: "https://REDACTED@gitea.example.com/api/v1/repos",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with userinfo and query params",
|
||||||
|
input: "https://user:pass@example.com/path?token=abc",
|
||||||
|
want: "https://REDACTED@example.com/path?[redacted]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username only - no password",
|
||||||
|
input: "https://user@example.com/path",
|
||||||
|
want: "https://REDACTED@example.com/path",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package gitea
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetPullRequestDiff_SizeLimits(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
diff string
|
||||||
|
maxDiffSize int64
|
||||||
|
wantErr error
|
||||||
|
wantDiff string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exceeds max size",
|
||||||
|
diff: strings.Repeat("+ added line\n", 1000), // ~13 KB
|
||||||
|
maxDiffSize: 100,
|
||||||
|
wantErr: ErrDiffTooLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "within max size",
|
||||||
|
diff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
|
||||||
|
maxDiffSize: 1024,
|
||||||
|
wantDiff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exactly at limit",
|
||||||
|
diff: strings.Repeat("x", 50),
|
||||||
|
maxDiffSize: 50,
|
||||||
|
wantDiff: strings.Repeat("x", 50),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one byte over limit",
|
||||||
|
diff: strings.Repeat("x", 51),
|
||||||
|
maxDiffSize: 50,
|
||||||
|
wantErr: ErrDiffTooLarge,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled limit",
|
||||||
|
diff: strings.Repeat("x", 10000),
|
||||||
|
maxDiffSize: -1,
|
||||||
|
wantDiff: strings.Repeat("x", 10000),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "math.MaxInt64 treated as disabled",
|
||||||
|
diff: strings.Repeat("x", 10000),
|
||||||
|
maxDiffSize: math.MaxInt64,
|
||||||
|
wantDiff: strings.Repeat("x", 10000),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default limit",
|
||||||
|
diff: "diff content",
|
||||||
|
maxDiffSize: 0, // zero means use DefaultMaxDiffSize
|
||||||
|
wantDiff: "diff content",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte(tt.diff)) //nolint:errcheck // test handler
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := NewClient(server.URL, "test-token")
|
||||||
|
client.MaxDiffSize = tt.maxDiffSize
|
||||||
|
client.RetryBackoff = []time.Duration{}
|
||||||
|
|
||||||
|
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||||
|
|
||||||
|
if tt.wantErr != nil {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("expected %v, got: %v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.wantDiff {
|
||||||
|
t.Errorf("diff mismatch: got length %d, want length %d", len(got), len(tt.wantDiff))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
+70
-20
@@ -5,6 +5,7 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -120,9 +121,7 @@ func ListBuiltinPersonas() []string {
|
|||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !seen[personaName] {
|
seen[personaName] = true
|
||||||
seen[personaName] = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
names := make([]string, 0, len(seen))
|
names := make([]string, 0, len(seen))
|
||||||
for name := range seen {
|
for name := range seen {
|
||||||
@@ -148,6 +147,15 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
dec := json.NewDecoder(bytes.NewReader(data))
|
dec := json.NewDecoder(bytes.NewReader(data))
|
||||||
dec.DisallowUnknownFields()
|
dec.DisallowUnknownFields()
|
||||||
err = dec.Decode(&p)
|
err = dec.Decode(&p)
|
||||||
|
if err == nil {
|
||||||
|
// Reject trailing content after the first valid JSON object.
|
||||||
|
// Without this check, input like `{"name":"x"}garbage` would
|
||||||
|
// silently succeed because Decoder stops after one object.
|
||||||
|
var dummy json.RawMessage
|
||||||
|
if err2 := dec.Decode(&dummy); err2 != io.EOF {
|
||||||
|
err = fmt.Errorf("unexpected trailing content after JSON object")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
||||||
@@ -158,10 +166,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
|
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks:
|
||||||
// and strict field checking. This protects against stack exhaustion from deeply
|
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion.
|
||||||
// nested structures and catches typos in field names.
|
// - Multi-document rejection: prevents silent data loss from ignored extra documents.
|
||||||
// Multi-document YAML files are rejected to prevent silent data loss.
|
// - Strict field checking: rejects unknown YAML keys to catch typos early.
|
||||||
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||||
// First pass: parse into AST to check depth limits, node counts, and
|
// First pass: parse into AST to check depth limits, node counts, and
|
||||||
// multi-document rejection. This prevents stack exhaustion before we
|
// multi-document rejection. This prevents stack exhaustion before we
|
||||||
@@ -190,13 +198,18 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
|||||||
|
|
||||||
// Second pass: decode with strict field checking enabled.
|
// Second pass: decode with strict field checking enabled.
|
||||||
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
|
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
|
||||||
|
//
|
||||||
|
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases
|
||||||
|
// recursively — it resolves them via the pre-built AST, which our first
|
||||||
|
// pass already depth-checked. Alias chains that would exceed depth limits
|
||||||
|
// are caught above; the decoder merely reads the resolved scalar values.
|
||||||
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
|
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
|
||||||
return dec.Decode(out)
|
return dec.Decode(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
|
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
|
||||||
// limit or the total node count limit. It uses two tracking maps:
|
// limit or the total node count limit. It uses two tracking maps:
|
||||||
// - validated: maps each node to the minimum depth at which it was previously
|
// - validated: maps each node to the maximum depth at which it was previously
|
||||||
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
|
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
|
||||||
// we re-check it to ensure the combined effective depth doesn't exceed limits.
|
// we re-check it to ensure the combined effective depth doesn't exceed limits.
|
||||||
// - visiting: per-path recursion stack for true cycle detection. A node on the
|
// - visiting: per-path recursion stack for true cycle detection. A node on the
|
||||||
@@ -214,12 +227,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
|
||||||
*nodeCount++
|
|
||||||
if *nodeCount > maxNodes {
|
|
||||||
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cycle detection: if we're currently visiting this node on the current
|
// Cycle detection: if we're currently visiting this node on the current
|
||||||
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
|
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
|
||||||
// Return nil to break the cycle without error — cycles are a structural
|
// Return nil to break the cycle without error — cycles are a structural
|
||||||
@@ -228,10 +235,28 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Depth-aware short-circuit: only skip re-checking a node if we previously
|
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
||||||
// validated it at the same or deeper effective depth. If this visit is at a
|
// Placed after cycle detection but before the depth-aware short-circuit. This means
|
||||||
// greater depth than before (e.g., alias referenced deeper in the tree),
|
// nodes revisited at shallower depths (via aliases) are counted each time they are
|
||||||
// we must re-traverse to catch depth limit violations.
|
// encountered — intentional conservative overcounting. This bounds the total work
|
||||||
|
// performed during validation rather than tracking unique nodes, which is the safer
|
||||||
|
// security posture for untrusted YAML input.
|
||||||
|
*nodeCount++
|
||||||
|
if *nodeCount > maxNodes {
|
||||||
|
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Depth-aware short-circuit: skip re-validation only when the current visit
|
||||||
|
// depth is the same or shallower than the depth at which this node was
|
||||||
|
// previously validated. A shallower (or equal) current depth means the
|
||||||
|
// prior, deeper validation already covered any subtree depth violations.
|
||||||
|
// If the current depth exceeds the previous validation depth (e.g., an alias
|
||||||
|
// references this node deeper in the tree), we must re-traverse to ensure
|
||||||
|
// the combined effective depth doesn't exceed maxDepth.
|
||||||
|
//
|
||||||
|
// Note: using ast.Node (interface) as map key relies on pointer identity,
|
||||||
|
// which is correct because all goccy/go-yaml AST node types are pointer
|
||||||
|
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
|
||||||
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
|
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -250,6 +275,11 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *ast.MappingValueNode:
|
case *ast.MappingValueNode:
|
||||||
|
// Both Key and Value are visited at depth+1 relative to this
|
||||||
|
// MappingValueNode. Since MappingNode visits its MappingValueNode
|
||||||
|
// children at depth+1 as well, keys and values end up at depth+2
|
||||||
|
// from the parent MappingNode. This is intentional: it mirrors the
|
||||||
|
// actual nesting structure (mapping → key-value pair → key/value).
|
||||||
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -269,6 +299,14 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case *ast.AnchorNode:
|
case *ast.AnchorNode:
|
||||||
|
// Increment depth for anchor values as a conservative measure: the
|
||||||
|
// anchor definition itself is structural, and treating it as a depth
|
||||||
|
// level ensures that deeply nested anchors are caught at definition
|
||||||
|
// time rather than only when referenced via alias. This +1 is
|
||||||
|
// asymmetric with alias (which also increments) — by design, the
|
||||||
|
// effective depth budget for anchored-then-aliased content is reduced
|
||||||
|
// because both the definition site and the reference site each consume
|
||||||
|
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
|
||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -276,8 +314,16 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Scalar types (StringNode, IntegerNode, FloatNode, BoolNode, NullNode,
|
case *ast.MergeKeyNode:
|
||||||
// InfinityNode, NanNode, LiteralNode, MergeKeyNode) are leaf nodes.
|
// MergeKeyNode represents the literal "<<" merge key token. It has no
|
||||||
|
// child nodes — the value side of a merge (e.g., *alias) lives in the
|
||||||
|
// parent MappingValueNode.Value, which is already recursed into above.
|
||||||
|
// Explicitly listed here (rather than in the default case) to prevent
|
||||||
|
// future library changes from silently bypassing depth checks.
|
||||||
|
default:
|
||||||
|
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
|
||||||
|
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
|
||||||
|
// recurse into.
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -285,7 +331,11 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
||||||
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
||||||
// without requiring filesystem access. Format is detected by source extension.
|
// without requiring filesystem access. Format is detected by source extension.
|
||||||
|
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
|
||||||
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
||||||
|
if len(data) > MaxPersonaFileSize {
|
||||||
|
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
|
||||||
|
}
|
||||||
return parsePersona(data, source)
|
return parsePersona(data, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+113
-9
@@ -459,8 +459,14 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
|||||||
path := filepath.Join(dir, "deeply-nested.yaml")
|
path := filepath.Join(dir, "deeply-nested.yaml")
|
||||||
|
|
||||||
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
||||||
// Each nested mapping key generates a MappingValueNode, incrementing depth
|
// Depth accumulation trace for "nested: \n level0: \n level1: ...":
|
||||||
// by 1 per level in the AST walk. With 25 levels, we exceed MaxYAMLDepth (20).
|
// - Document root parsed at depth 0
|
||||||
|
// - Root MappingNode children (MappingValueNodes) visited at depth 1
|
||||||
|
// - "nested" MappingValueNode: key at depth 2, value at depth 2
|
||||||
|
// - Each levelN adds depth via MappingValueNode traversal (key + value)
|
||||||
|
// - Exact depth per level depends on AST structure (MappingNode wrapping),
|
||||||
|
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
|
||||||
|
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
||||||
indent := " "
|
indent := " "
|
||||||
@@ -485,19 +491,18 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestYAMLEmptyFileRejection(t *testing.T) {
|
func TestYAMLEmptyFileRejection(t *testing.T) {
|
||||||
dir := t.TempDir()
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
content string
|
content string
|
||||||
}{
|
}{
|
||||||
{"completely empty", ""},
|
{"completely_empty", ""},
|
||||||
{"whitespace only", " \n\n "},
|
{"whitespace_only", " \n\n "},
|
||||||
{"comment only", "# just a comment\n"},
|
{"comment_only", "# just a comment\n"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
path := filepath.Join(dir, tc.name+".yaml")
|
path := filepath.Join(dir, tc.name+".yaml")
|
||||||
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
|
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
|
||||||
t.Fatalf("failed to write test file: %v", err)
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
@@ -505,9 +510,9 @@ func TestYAMLEmptyFileRejection(t *testing.T) {
|
|||||||
|
|
||||||
_, err := LoadPersona(path)
|
_, err := LoadPersona(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for empty YAML input, got nil")
|
t.Fatal("expected error for empty YAML input, got nil")
|
||||||
}
|
}
|
||||||
if err != nil && !strings.Contains(err.Error(), "empty YAML document") {
|
if !strings.Contains(err.Error(), "empty YAML document") {
|
||||||
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
|
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -853,3 +858,102 @@ identity: test identity
|
|||||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
t.Errorf("Name = %q, want %q", p.Name, "test")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONTrailingContentRejected(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trailing garbage after object",
|
||||||
|
content: `{"name":"test","identity":"test identity"}garbage`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two JSON objects",
|
||||||
|
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing array",
|
||||||
|
content: `{"name":"test","identity":"test identity"}[]`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.json")
|
||||||
|
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := LoadPersona(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for trailing content, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "trailing content") {
|
||||||
|
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePersonaBytesSizeLimit(t *testing.T) {
|
||||||
|
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
|
||||||
|
oversized := make([]byte, MaxPersonaFileSize+1)
|
||||||
|
for i := range oversized {
|
||||||
|
oversized[i] = 'x'
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for oversized input, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "exceeds maximum size") {
|
||||||
|
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just under the limit should not trigger size error (may fail parse, but not size)
|
||||||
|
underLimit := []byte("name: test\nidentity: test persona\n")
|
||||||
|
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error for valid input: %v", err)
|
||||||
|
}
|
||||||
|
if p.Name != "test" {
|
||||||
|
t.Errorf("Name = %q, want %q", p.Name, "test")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
|
||||||
|
// Verify that YAML merge keys (<<: *alias) are properly handled by the
|
||||||
|
// depth checker. The merge key content is in the MappingValueNode.Value
|
||||||
|
// (an AliasNode), not in the MergeKeyNode itself.
|
||||||
|
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("basic parse failed: %v", err)
|
||||||
|
}
|
||||||
|
if p.Name != "merge-test" {
|
||||||
|
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that deeply nested merge keys still hit depth limit.
|
||||||
|
// Build YAML with merge key content nested beyond MaxYAMLDepth.
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
|
||||||
|
sb.WriteString("anchor: &deep\n")
|
||||||
|
indent := " "
|
||||||
|
for i := 0; i < MaxYAMLDepth+5; i++ {
|
||||||
|
sb.WriteString(indent)
|
||||||
|
sb.WriteString(fmt.Sprintf("level%d:\n", i))
|
||||||
|
indent += " "
|
||||||
|
}
|
||||||
|
sb.WriteString(indent + "leaf: value\n")
|
||||||
|
sb.WriteString("target:\n <<: *deep\n")
|
||||||
|
|
||||||
|
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for deeply nested merge key content, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "depth") {
|
||||||
|
t.Errorf("error = %q, want to contain 'depth'", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user