Compare commits

..

1 Commits

Author SHA1 Message Date
Rodin b0349a22a0 ci: add test job for dot path normalization
PR Ready Gate / clear-labels (pull_request) Successful in 1s
CI / test (pull_request) Successful in 16s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 28s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 58s
CI / test-dot-path (pull_request) Successful in 58s
CI / review (gpt-5, security, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m8s
Uses rodin/security-patterns with patterns-files='.' to verify
the fix works end-to-end.
2026-05-11 07:10:50 -07:00
12 changed files with 173 additions and 684 deletions
+33 -4
View File
@@ -38,8 +38,6 @@ jobs:
- name: security - name: security
token_secret: SECURITY_REVIEW_TOKEN token_secret: SECURITY_REVIEW_TOKEN
model: gpt-5 model: gpt-5
patterns_repo: rodin/security-patterns
patterns_files: "."
system_prompt_file: SECURITY_REVIEW.md system_prompt_file: SECURITY_REVIEW.md
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -62,8 +60,39 @@ jobs:
AICORE_API_URL: ${{ secrets.AICORE_API_URL }} AICORE_API_URL: ${{ secrets.AICORE_API_URL }}
AICORE_RESOURCE_GROUP: ${{ secrets.AICORE_RESOURCE_GROUP }} AICORE_RESOURCE_GROUP: ${{ secrets.AICORE_RESOURCE_GROUP }}
CONVENTIONS_FILE: "CONVENTIONS.md" CONVENTIONS_FILE: "CONVENTIONS.md"
PATTERNS_REPO: ${{ matrix.patterns_repo || 'rodin/go-patterns' }} PATTERNS_REPO: "rodin/go-patterns"
PATTERNS_FILES: ${{ matrix.patterns_files || 'README.md,patterns/' }} PATTERNS_FILES: "README.md,patterns/"
LLM_TIMEOUT: "600" LLM_TIMEOUT: "600"
SYSTEM_PROMPT_FILE: ${{ matrix.system_prompt_file }} SYSTEM_PROMPT_FILE: ${{ matrix.system_prompt_file }}
run: ./review-bot run: ./review-bot
# Test dot path normalization with security-patterns repo
test-dot-path:
runs-on: ubuntu-24.04
if: github.event_name == 'pull_request'
needs: test
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.26'
- run: go build -o review-bot ./cmd/review-bot
- name: Test dot path normalization
env:
GITEA_URL: ${{ github.server_url }}
GITEA_REPO: ${{ github.repository }}
PR_NUMBER: ${{ github.event.pull_request.number }}
REVIEWER_TOKEN: ${{ secrets.GPT_REVIEW_TOKEN }}
REVIEWER_NAME: dot-path-test
LLM_PROVIDER: aicore
LLM_MODEL: gpt-5
AICORE_CLIENT_ID: ${{ secrets.AICORE_CLIENT_ID }}
AICORE_CLIENT_SECRET: ${{ secrets.AICORE_CLIENT_SECRET }}
AICORE_AUTH_URL: ${{ secrets.AICORE_AUTH_URL }}
AICORE_API_URL: ${{ secrets.AICORE_API_URL }}
AICORE_RESOURCE_GROUP: ${{ secrets.AICORE_RESOURCE_GROUP }}
CONVENTIONS_FILE: "CONVENTIONS.md"
PATTERNS_REPO: "rodin/security-patterns"
PATTERNS_FILES: "."
LLM_TIMEOUT: "600"
run: ./review-bot
+1 -1
View File
@@ -9,7 +9,7 @@
| Package | Use Case | Scope | | Package | Use Case | Scope |
|---------|----------|-------| |---------|----------|-------|
| `github.com/goccy/go-yaml` | YAML parsing and AST inspection (subpkgs: `ast`, `parser`) | production | | `gopkg.in/yaml.v3` | YAML parsing (persona files, config) | production |
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only | | `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
**Any import not in this table or the Go standard library is forbidden.** **Any import not in this table or the Go standard library is forbidden.**
+7 -16
View File
@@ -65,7 +65,7 @@ func main() {
conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)") conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)")
systemPromptFile := flag.String("system-prompt-file", envOrDefault("SYSTEM_PROMPT_FILE", ""), "Local file with additional system prompt instructions") systemPromptFile := flag.String("system-prompt-file", envOrDefault("SYSTEM_PROMPT_FILE", ""), "Local file with additional system prompt instructions")
patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)") patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)")
patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", ""), "Comma-separated file paths to fetch from patterns repo (empty = all files)") patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", "README.md"), "Comma-separated file paths to fetch from patterns repo")
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting") dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)") llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)") llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
@@ -523,25 +523,11 @@ func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, re
// patternsRepo is comma-separated list of owner/name repos. // patternsRepo is comma-separated list of owner/name repos.
// patternsFiles is comma-separated list of file paths or directories. // patternsFiles is comma-separated list of file paths or directories.
// If a path ends with / or is a directory, all files within it are fetched recursively. // If a path ends with / or is a directory, all files within it are fetched recursively.
// If patternsFiles is empty, all files from the repo root are fetched.
func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string { func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string {
var sb strings.Builder var sb strings.Builder
repos := strings.Split(patternsRepo, ",") repos := strings.Split(patternsRepo, ",")
paths := strings.Split(patternsFiles, ",")
// Build the list of paths to fetch
var paths []string
if patternsFiles == "" {
// Empty patternsFiles means "fetch all files from repo root"
paths = []string{""}
} else {
for _, p := range strings.Split(patternsFiles, ",") {
p = strings.TrimSpace(p)
if p != "" {
paths = append(paths, p)
}
}
}
for _, repoRef := range repos { for _, repoRef := range repos {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -562,6 +548,11 @@ func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patt
var repoSkippedFiles []string var repoSkippedFiles []string
for _, path := range paths { for _, path := range paths {
path = strings.TrimSpace(path)
if path == "" {
continue
}
files, err := client.GetAllFilesInPath(ctx, owner, repo, path) files, err := client.GetAllFilesInPath(ctx, owner, repo, path)
if err != nil { if err != nil {
slog.Warn("could not fetch patterns", "path", path, "repo", repoRef, "error", err) slog.Warn("could not fetch patterns", "path", path, "repo", repoRef, "error", err)
-46
View File
@@ -504,52 +504,6 @@ func TestIsPatternFile(t *testing.T) {
} }
} }
// TestBuildPatternPaths verifies the path-building logic for fetchPatterns.
// Empty patternsFiles means "fetch all from root" (represented as [""]).
func TestBuildPatternPaths(t *testing.T) {
buildPaths := func(patternsFiles string) []string {
if patternsFiles == "" {
return []string{""}
}
var paths []string
for _, p := range strings.Split(patternsFiles, ",") {
p = strings.TrimSpace(p)
if p != "" {
paths = append(paths, p)
}
}
return paths
}
tests := []struct {
name string
input string
want []string
}{
{"empty fetches root", "", []string{""}},
{"single file", "README.md", []string{"README.md"}},
{"multiple files", "README.md,PATTERNS.md", []string{"README.md", "PATTERNS.md"}},
{"trims whitespace", " foo.md , bar.md ", []string{"foo.md", "bar.md"}},
{"skips empty between commas", "foo.md,,bar.md", []string{"foo.md", "bar.md"}},
{"directory path", "patterns/", []string{"patterns/"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildPaths(tc.input)
if len(got) != len(tc.want) {
t.Errorf("buildPaths(%q) = %v, want %v", tc.input, got, tc.want)
return
}
for i := range got {
if got[i] != tc.want[i] {
t.Errorf("buildPaths(%q)[%d] = %q, want %q", tc.input, i, got[i], tc.want[i])
}
}
})
}
}
func TestEvaluateCIStatus(t *testing.T) { func TestEvaluateCIStatus(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
+31 -10
View File
@@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi
- Backwards compatibility: existing JSON personas must continue to work - Backwards compatibility: existing JSON personas must continue to work
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486) - Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
- Consistency: use `.yaml` extension (not `.yml`) - Consistency: use `.yaml` extension (not `.yml`)
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation - Library: use `gopkg.in/yaml.v3` (approved in CONVENTIONS.md) with explicit depth limiting
## Proposed Approach ## Proposed Approach
@@ -33,16 +33,37 @@ func parsePersona(data []byte, source string) (*Persona, error) {
### YAML Parsing with Depth Protection ### YAML Parsing with Depth Protection
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in ```go
`review/persona.go`) rather than relying on library decoder options. Key design func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
decisions: var node yaml.Node
dec := yaml.NewDecoder(bytes.NewReader(data))
if err := dec.Decode(&node); err != nil {
return err
}
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
return err
}
return node.Decode(out)
}
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection) if depth > maxDepth {
- **Node-count limit:** Conservative overcounting bounds total validation work return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths }
// Handle alias nodes by following the Alias pointer
if node.Kind == yaml.AliasNode && node.Alias != nil {
return checkYAMLDepth(node.Alias, depth, maxDepth)
}
for _, child := range node.Content {
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
return err
}
}
return nil
}
```
See `review/persona.go:checkYAMLDepth` for the authoritative implementation. The `gopkg.in/yaml.v3` library does not have built-in depth protection, so we implement explicit depth checking by first decoding into a `yaml.Node`, walking the tree to verify depth (including alias resolution), then decoding into the target struct.
## State/Data Model ## State/Data Model
@@ -53,7 +74,7 @@ No new state. Same `Persona` struct, just different parsing.
| Error | Handling | | Error | Handling |
|-------|----------| |-------|----------|
| Invalid YAML syntax | Return parse error with source file | | Invalid YAML syntax | Return parse error with source file |
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode | | Deeply nested YAML | Library rejects (v1.16.0+ fix) |
| Unknown extension | Fall back to JSON parsing | | Unknown extension | Fall back to JSON parsing |
| Missing required fields | Validation rejects after parse | | Missing required fields | Validation rejects after parse |
+15 -93
View File
@@ -11,7 +11,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@@ -48,12 +47,6 @@ func IsServerError(err error) bool {
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600 return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
} }
// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB).
const DefaultMaxDiffSize = 10 * 1024 * 1024
// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize.
var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size")
// Client interacts with the Gitea API. // Client interacts with the Gitea API.
// A Client is safe for concurrent use by multiple goroutines. // A Client is safe for concurrent use by multiple goroutines.
type Client struct { type Client struct {
@@ -68,14 +61,6 @@ type Client struct {
// This field must be configured before the first request is made. // This field must be configured before the first request is made.
// Modifying it while requests are in flight is not safe. // Modifying it while requests are in flight is not safe.
RetryBackoff []time.Duration RetryBackoff []time.Duration
// MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff.
// If zero, defaults to DefaultMaxDiffSize (10 MB). Set to any negative value
// (or math.MaxInt64) to disable the limit.
//
// This field must be configured before the first request is made.
// Modifying it while requests are in flight is not safe.
MaxDiffSize int64
} }
// NewClient creates a new Gitea API client. // NewClient creates a new Gitea API client.
@@ -140,28 +125,9 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
} }
// GetPullRequestDiff fetches the unified diff for a PR. // GetPullRequestDiff fetches the unified diff for a PR.
// It enforces MaxDiffSize to prevent unbounded memory allocation.
// Returns ErrDiffTooLarge if the diff exceeds the configured limit.
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL)
maxSize := c.MaxDiffSize
if maxSize == 0 {
maxSize = DefaultMaxDiffSize
}
// When the limit is disabled (negative) or set to math.MaxInt64 (which
// would overflow the +1 detection and silently disable enforcement),
// use the standard unlimited doGet path.
if maxSize < 0 || maxSize == math.MaxInt64 {
body, err := c.doGet(ctx, reqURL)
if err != nil {
return "", fmt.Errorf("fetch diff: %w", err)
}
return string(body), nil
}
body, err := c.doGetLimited(ctx, reqURL, maxSize)
if err != nil { if err != nil {
return "", fmt.Errorf("fetch diff: %w", err) return "", fmt.Errorf("fetch diff: %w", err)
} }
@@ -326,9 +292,9 @@ func isRetriableSyscallError(err error) bool {
return true return true
} }
// redactURL strips query parameters and userinfo credentials from a URL for // redactURL strips query parameters from a URL for safe logging.
// safe logging. This prevents accidental exposure of sensitive data (tokens in // This prevents accidental exposure of sensitive data that future callers
// query strings, or user:pass in the authority) in log output. // might pass via query strings.
func redactURL(rawURL string) string { func redactURL(rawURL string) string {
parsed, err := url.Parse(rawURL) parsed, err := url.Parse(rawURL)
if err != nil { if err != nil {
@@ -336,9 +302,6 @@ func redactURL(rawURL string) string {
// potentially logging something sensitive. // potentially logging something sensitive.
return "[invalid URL]" return "[invalid URL]"
} }
if parsed.User != nil {
parsed.User = url.User("REDACTED")
}
if parsed.RawQuery != "" { if parsed.RawQuery != "" {
parsed.RawQuery = "[redacted]" parsed.RawQuery = "[redacted]"
} }
@@ -359,12 +322,10 @@ func sanitizeErrorForLog(err error) string {
return err.Error() return err.Error()
} }
// doGetWithReader performs an HTTP GET request with retry on 5xx errors and // doGet performs an HTTP GET request with retry on 5xx errors and temporary
// temporary network errors. Retries up to 3 times with exponential backoff // network errors. Retries up to 3 times with exponential backoff (1s, 2s delays
// (1s, 2s delays by default; configurable via Client.RetryBackoff for testing). // by default; configurable via Client.RetryBackoff for testing).
// The readBody function is called with the response body on success (2xx) and func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
// is responsible for reading and closing it.
func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody func(io.ReadCloser) ([]byte, error)) ([]byte, error) {
const maxAttempts = 3 const maxAttempts = 3
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails). // backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
// First attempt (i=0) has no delay; retries wait 1s then 2s by default. // First attempt (i=0) has no delay; retries wait 1s then 2s by default.
@@ -429,7 +390,12 @@ func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody fu
return nil, lastErr return nil, lastErr
} }
if resp.StatusCode >= 200 && resp.StatusCode < 300 { if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return readBody(resp.Body) body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, err
}
return body, nil
} }
// Error path: limit how much we read from potentially malicious server // Error path: limit how much we read from potentially malicious server
@@ -447,39 +413,6 @@ func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody fu
return nil, lastErr return nil, lastErr
} }
// doGet performs an HTTP GET request with retry, reading the full response body.
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
defer body.Close()
return io.ReadAll(body)
})
}
// doGetLimited performs an HTTP GET request with retry but enforces a maximum
// response body size. Returns ErrDiffTooLarge if the response exceeds maxBytes.
// It reads maxBytes+1 (clamped to avoid overflow) to detect truncation without
// buffering the entire body.
func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) {
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
defer body.Close()
// Read up to maxBytes+1 to detect overflow.
// Clamp to prevent integer overflow when maxBytes == math.MaxInt64.
limitBytes := maxBytes + 1
if limitBytes <= 0 {
limitBytes = math.MaxInt64
}
limited := io.LimitReader(body, limitBytes)
data, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, fmt.Errorf("%w: response exceeds %d bytes", ErrDiffTooLarge, maxBytes)
}
return data, nil
})
}
// escapePath escapes each segment of a relative file path for use in URLs. // escapePath escapes each segment of a relative file path for use in URLs.
// Slashes are preserved as path separators; other special characters are escaped. // Slashes are preserved as path separators; other special characters are escaped.
// Input should be a relative path (no leading slash). Already-encoded segments // Input should be a relative path (no leading slash). Already-encoded segments
@@ -501,8 +434,6 @@ type ContentEntry struct {
// ListContents lists files and directories at a given path in a repo. // ListContents lists files and directories at a given path in a repo.
// Pass an empty path to list the repository root. // Pass an empty path to list the repository root.
// If the path points to a file (not a directory), Gitea returns a single
// object instead of an array; this method normalizes both cases to a slice.
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) { func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
// Normalize "." to empty string — Gitea API rejects "." with 500 // Normalize "." to empty string — Gitea API rejects "." with 500
if path == "." { if path == "." {
@@ -520,16 +451,7 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]
} }
var entries []ContentEntry var entries []ContentEntry
if err := json.Unmarshal(body, &entries); err != nil { if err := json.Unmarshal(body, &entries); err != nil {
// Gitea returns a single object (not an array) when path is a file return nil, fmt.Errorf("parse contents JSON: %w", err)
var single ContentEntry
if err2 := json.Unmarshal(body, &single); err2 != nil {
return nil, fmt.Errorf("parse contents JSON: %w", err)
}
// Guard against empty/malformed responses
if single.Name == "" && single.Path == "" {
return nil, fmt.Errorf("parse contents JSON: empty response for path %q", path)
}
entries = []ContentEntry{single}
} }
return entries, nil return entries, nil
} }
+2 -46
View File
@@ -304,40 +304,11 @@ func TestListContents_DotPath(t *testing.T) {
} }
} }
func TestListContents_FilePath(t *testing.T) {
// Gitea returns a single object (not an array) when path is a file
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/repos/owner/repo/contents/README.md" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
// Single object, not an array
fmt.Fprintf(w, `{"name":"README.md","path":"README.md","type":"file"}`)
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Name != "README.md" {
t.Errorf("expected README.md, got %s", entries[0].Name)
}
if entries[0].Type != "file" {
t.Errorf("expected type file, got %s", entries[0].Type)
}
}
func TestGetAllFilesInPath_File(t *testing.T) { func TestGetAllFilesInPath_File(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/repos/owner/repo/contents/README.md" { if r.URL.Path == "/api/v1/repos/owner/repo/contents/README.md" {
// Gitea returns a single object (not array) when path is a file // Gitea returns 404 for contents API on files (it's not a dir)
w.Header().Set("Content-Type", "application/json") http.NotFound(w, r)
fmt.Fprintf(w, `{"name":"README.md","path":"README.md","type":"file"}`)
return return
} }
if r.URL.Path == "/api/v1/repos/owner/repo/raw/README.md" { if r.URL.Path == "/api/v1/repos/owner/repo/raw/README.md" {
@@ -1092,21 +1063,6 @@ func TestRedactURL(t *testing.T) {
input: "", input: "",
want: "", want: "",
}, },
{
name: "with userinfo - redacts credentials",
input: "https://admin:secret@gitea.example.com/api/v1/repos",
want: "https://REDACTED@gitea.example.com/api/v1/repos",
},
{
name: "with userinfo and query params",
input: "https://user:pass@example.com/path?token=abc",
want: "https://REDACTED@example.com/path?[redacted]",
},
{
name: "username only - no password",
input: "https://user@example.com/path",
want: "https://REDACTED@example.com/path",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
-97
View File
@@ -1,97 +0,0 @@
package gitea
import (
"context"
"errors"
"math"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestGetPullRequestDiff_SizeLimits(t *testing.T) {
tests := []struct {
name string
diff string
maxDiffSize int64
wantErr error
wantDiff string
}{
{
name: "exceeds max size",
diff: strings.Repeat("+ added line\n", 1000), // ~13 KB
maxDiffSize: 100,
wantErr: ErrDiffTooLarge,
},
{
name: "within max size",
diff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
maxDiffSize: 1024,
wantDiff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
},
{
name: "exactly at limit",
diff: strings.Repeat("x", 50),
maxDiffSize: 50,
wantDiff: strings.Repeat("x", 50),
},
{
name: "one byte over limit",
diff: strings.Repeat("x", 51),
maxDiffSize: 50,
wantErr: ErrDiffTooLarge,
},
{
name: "disabled limit",
diff: strings.Repeat("x", 10000),
maxDiffSize: -1,
wantDiff: strings.Repeat("x", 10000),
},
{
name: "math.MaxInt64 treated as disabled",
diff: strings.Repeat("x", 10000),
maxDiffSize: math.MaxInt64,
wantDiff: strings.Repeat("x", 10000),
},
{
name: "default limit",
diff: "diff content",
maxDiffSize: 0, // zero means use DefaultMaxDiffSize
wantDiff: "diff content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tt.diff)) //nolint:errcheck // test handler
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
client.MaxDiffSize = tt.maxDiffSize
client.RetryBackoff = []time.Duration{}
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
if tt.wantErr != nil {
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, tt.wantErr) {
t.Errorf("expected %v, got: %v", tt.wantErr, err)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.wantDiff {
t.Errorf("diff mismatch: got length %d, want length %d", len(got), len(tt.wantDiff))
}
})
}
}
+1 -1
View File
@@ -2,4 +2,4 @@ module gitea.weiker.me/rodin/review-bot
go 1.26.2 go 1.26.2
require github.com/goccy/go-yaml v1.19.2 require gopkg.in/yaml.v3 v3.0.1
+4 -2
View File
@@ -1,2 +1,4 @@
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+38 -146
View File
@@ -5,15 +5,12 @@ import (
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"sort" "sort"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"github.com/goccy/go-yaml" "gopkg.in/yaml.v3"
"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/parser"
) )
//go:embed personas/*.yaml //go:embed personas/*.yaml
@@ -121,7 +118,9 @@ func ListBuiltinPersonas() []string {
default: default:
continue continue
} }
seen[personaName] = true if !seen[personaName] {
seen[personaName] = true
}
} }
names := make([]string, 0, len(seen)) names := make([]string, 0, len(seen))
for name := range seen { for name := range seen {
@@ -143,19 +142,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth) err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth)
} else { } else {
// Use json.Decoder with DisallowUnknownFields for consistency with // Use json.Decoder with DisallowUnknownFields for consistency with
// YAML's Strict() - both reject unknown fields to catch typos. // YAML's KnownFields(true) - both reject unknown fields to catch typos.
dec := json.NewDecoder(bytes.NewReader(data)) dec := json.NewDecoder(bytes.NewReader(data))
dec.DisallowUnknownFields() dec.DisallowUnknownFields()
err = dec.Decode(&p) err = dec.Decode(&p)
if err == nil {
// Reject trailing content after the first valid JSON object.
// Without this check, input like `{"name":"x"}garbage` would
// silently succeed because Decoder stops after one object.
var dummy json.RawMessage
if err2 := dec.Decode(&dummy); err2 != io.EOF {
err = fmt.Errorf("unexpected trailing content after JSON object")
}
}
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("parse persona %s: %w", source, err) return nil, fmt.Errorf("parse persona %s: %w", source, err)
@@ -166,164 +156,70 @@ func parsePersona(data []byte, source string) (*Persona, error) {
return &p, nil return &p, nil
} }
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks: // unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion. // and strict field checking. This protects against stack exhaustion from deeply
// - Multi-document rejection: prevents silent data loss from ignored extra documents. // nested structures and catches typos in field names.
// - Strict field checking: rejects unknown YAML keys to catch typos early. // Multi-document YAML files are rejected to prevent silent data loss.
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
// First pass: parse into AST to check depth limits, node counts, and // First pass: decode into a yaml.Node to check depth limits and node counts.
// multi-document rejection. This prevents stack exhaustion before we // This prevents stack exhaustion before we attempt to decode into structs.
// attempt to decode into structs. var node yaml.Node
file, err := parser.ParseBytes(data, 0) dec := yaml.NewDecoder(bytes.NewReader(data))
if err != nil { if err := dec.Decode(&node); err != nil {
return err return err
} }
// Reject empty YAML input (whitespace-only, comment-only, or truly empty files).
// The parser returns a single doc with nil body for these cases.
if len(file.Docs) == 0 || file.Docs[0].Body == nil {
return fmt.Errorf("empty YAML document")
}
// Reject multi-document YAML files - silently ignoring additional documents // Reject multi-document YAML files - silently ignoring additional documents
// could lead to confusing behavior where users think their changes take effect. // could lead to confusing behavior where users think their changes take effect.
if len(file.Docs) > 1 { var extra yaml.Node
if dec.Decode(&extra) == nil {
return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed") return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed")
} }
nodeCount := 0 nodeCount := 0
if err := checkYAMLDepth(file.Docs[0].Body, 0, maxDepth, MaxYAMLNodes, make(map[ast.Node]int), make(map[ast.Node]bool), &nodeCount); err != nil { if err := checkYAMLDepth(&node, 0, maxDepth, MaxYAMLNodes, make(map[*yaml.Node]struct{}), &nodeCount); err != nil {
return err return err
} }
// Second pass: decode with strict field checking enabled. // Second pass: decode with strict field checking enabled.
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy". // KnownFields(true) rejects unknown keys, catching typos like "focuss" or "identiy".
// // We must re-decode from the original data because yaml.Node.Decode() doesn't
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases // support the KnownFields option.
// recursively — it resolves them via the pre-built AST, which our first strictDec := yaml.NewDecoder(bytes.NewReader(data))
// pass already depth-checked. Alias chains that would exceed depth limits strictDec.KnownFields(true)
// are caught above; the decoder merely reads the resolved scalar values. return strictDec.Decode(out)
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
return dec.Decode(out)
} }
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth // checkYAMLDepth recursively checks that YAML nodes don't exceed the depth limit
// limit or the total node count limit. It uses two tracking maps: // or the total node count limit. It also detects alias cycles to prevent infinite
// - validated: maps each node to the maximum depth at which it was previously // recursion from crafted YAML with self-referential aliases.
// checked. If a node is revisited at a deeper depth (e.g., via an alias), func checkYAMLDepth(node *yaml.Node, depth, maxDepth, maxNodes int, seen map[*yaml.Node]struct{}, nodeCount *int) error {
// we re-check it to ensure the combined effective depth doesn't exceed limits.
// - visiting: per-path recursion stack for true cycle detection. A node on the
// current path is a cycle (alias loop); we return nil to avoid infinite recursion.
//
// This design prevents the alias depth bypass where an anchored subtree validated
// at a shallow depth could be referenced via alias at a greater depth, effectively
// exceeding MaxYAMLDepth.
func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[ast.Node]int, visiting map[ast.Node]bool, nodeCount *int) error {
if node == nil {
return nil
}
if depth > maxDepth { if depth > maxDepth {
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth) return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
} }
// Cycle detection: if we're currently visiting this node on the current
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
// Return nil to break the cycle without error — cycles are a structural
// property, not a depth violation.
if visiting[node] {
return nil
}
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks. // Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
// Placed after cycle detection but before the depth-aware short-circuit. This means
// nodes revisited at shallower depths (via aliases) are counted each time they are
// encountered — intentional conservative overcounting. This bounds the total work
// performed during validation rather than tracking unique nodes, which is the safer
// security posture for untrusted YAML input.
*nodeCount++ *nodeCount++
if *nodeCount > maxNodes { if *nodeCount > maxNodes {
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes) return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
} }
// Depth-aware short-circuit: skip re-validation only when the current visit // Cycle detection: if we've seen this node before, we're in a cycle.
// depth is the same or shallower than the depth at which this node was if _, ok := seen[node]; ok {
// previously validated. A shallower (or equal) current depth means the return nil // Already validated this subtree, skip to avoid infinite recursion.
// prior, deeper validation already covered any subtree depth violations.
// If the current depth exceeds the previous validation depth (e.g., an alias
// references this node deeper in the tree), we must re-traverse to ensure
// the combined effective depth doesn't exceed maxDepth.
//
// Note: using ast.Node (interface) as map key relies on pointer identity,
// which is correct because all goccy/go-yaml AST node types are pointer
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
return nil
} }
validated[node] = depth seen[node] = struct{}{}
// Mark as visiting (on the current recursion path) for cycle detection. // Handle alias nodes: follow the alias to its anchor target.
visiting[node] = true // Increment depth when following aliases since they expand the effective structure.
defer func() { visiting[node] = false }() if node.Kind == yaml.AliasNode && node.Alias != nil {
return checkYAMLDepth(node.Alias, depth+1, maxDepth, maxNodes, seen, nodeCount)
}
// Walk children based on node type. for _, child := range node.Content {
switch n := node.(type) { if err := checkYAMLDepth(child, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil {
case *ast.MappingNode:
for _, value := range n.Values {
if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
}
case *ast.MappingValueNode:
// Both Key and Value are visited at depth+1 relative to this
// MappingValueNode. Since MappingNode visits its MappingValueNode
// children at depth+1 as well, keys and values end up at depth+2
// from the parent MappingNode. This is intentional: it mirrors the
// actual nesting structure (mapping → key-value pair → key/value).
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err return err
} }
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.SequenceNode:
for _, value := range n.Values {
if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
}
case *ast.AliasNode:
// Follow alias to its target, incrementing depth since aliases expand
// the effective structure.
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.AnchorNode:
// Increment depth for anchor values as a conservative measure: the
// anchor definition itself is structural, and treating it as a depth
// level ensures that deeply nested anchors are caught at definition
// time rather than only when referenced via alias. This +1 is
// asymmetric with alias (which also increments) — by design, the
// effective depth budget for anchored-then-aliased content is reduced
// because both the definition site and the reference site each consume
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.TagNode:
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.MergeKeyNode:
// MergeKeyNode represents the literal "<<" merge key token. It has no
// child nodes — the value side of a merge (e.g., *alias) lives in the
// parent MappingValueNode.Value, which is already recursed into above.
// Explicitly listed here (rather than in the default case) to prevent
// future library changes from silently bypassing depth checks.
default:
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
// recurse into.
} }
return nil return nil
} }
@@ -331,11 +227,7 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
// ParsePersonaBytes parses persona data from bytes with a source label for errors. // ParsePersonaBytes parses persona data from bytes with a source label for errors.
// This is useful for parsing personas fetched from external sources (e.g., Gitea API) // This is useful for parsing personas fetched from external sources (e.g., Gitea API)
// without requiring filesystem access. Format is detected by source extension. // without requiring filesystem access. Format is detected by source extension.
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
func ParsePersonaBytes(data []byte, source string) (*Persona, error) { func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
if len(data) > MaxPersonaFileSize {
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
}
return parsePersona(data, source) return parsePersona(data, source)
} }
+41 -222
View File
@@ -7,7 +7,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/goccy/go-yaml/ast" "gopkg.in/yaml.v3"
) )
func TestLoadBuiltinPersona(t *testing.T) { func TestLoadBuiltinPersona(t *testing.T) {
@@ -459,14 +459,7 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
path := filepath.Join(dir, "deeply-nested.yaml") path := filepath.Join(dir, "deeply-nested.yaml")
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20). // Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
// Depth accumulation trace for "nested: \n level0: \n level1: ...": // Each level adds 2 to the depth count (key + value mapping).
// - Document root parsed at depth 0
// - Root MappingNode children (MappingValueNodes) visited at depth 1
// - "nested" MappingValueNode: key at depth 2, value at depth 2
// - Each levelN adds depth via MappingValueNode traversal (key + value)
// - Exact depth per level depends on AST structure (MappingNode wrapping),
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
var sb strings.Builder var sb strings.Builder
sb.WriteString("name: test\nidentity: test\nnested:\n") sb.WriteString("name: test\nidentity: test\nnested:\n")
indent := " " indent := " "
@@ -490,35 +483,6 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
} }
} }
func TestYAMLEmptyFileRejection(t *testing.T) {
tests := []struct {
name string
content string
}{
{"completely_empty", ""},
{"whitespace_only", " \n\n "},
{"comment_only", "# just a comment\n"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, tc.name+".yaml")
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for empty YAML input, got nil")
}
if !strings.Contains(err.Error(), "empty YAML document") {
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
}
})
}
}
func TestYAMLFileSizeLimit(t *testing.T) { func TestYAMLFileSizeLimit(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
path := filepath.Join(dir, "huge.yaml") path := filepath.Join(dir, "huge.yaml")
@@ -540,41 +504,41 @@ func TestYAMLFileSizeLimit(t *testing.T) {
func TestYAMLAliasCycleDetection(t *testing.T) { func TestYAMLAliasCycleDetection(t *testing.T) {
// Test that our checkYAMLDepth function handles alias cycles gracefully // Test that our checkYAMLDepth function handles alias cycles gracefully
// by using the visiting map to prevent infinite recursion. // by using the seen map to prevent infinite recursion.
// We test this directly because go-yaml's parser handles most cycles
// at parse time, but we need to ensure our checker is robust.
// Create a node structure where an alias points to a parent node, // Create a node structure where an alias points to a parent node,
// simulating what could happen with crafted input. // simulating what could happen with malicious input that bypasses
parent := &ast.MappingNode{ // go-yaml's cycle detection.
Values: []*ast.MappingValueNode{ parent := &yaml.Node{
{ Kind: yaml.MappingNode,
Key: &ast.StringNode{Value: "name"}, Content: []*yaml.Node{
Value: &ast.StringNode{Value: "test"}, {Kind: yaml.ScalarNode, Value: "name"},
}, {Kind: yaml.ScalarNode, Value: "test"},
{Kind: yaml.ScalarNode, Value: "nested"},
}, },
} }
// Create a child that aliases back to the parent (artificial cycle) // Create a child that aliases back to the parent (artificial cycle)
aliasToParent := &ast.AliasNode{ aliasToParent := &yaml.Node{
Value: parent, Kind: yaml.AliasNode,
Alias: parent,
} }
parent.Values = append(parent.Values, &ast.MappingValueNode{ parent.Content = append(parent.Content, aliasToParent)
Key: &ast.StringNode{Value: "nested"},
Value: aliasToParent,
})
nodeCount := 0 nodeCount := 0
validated := make(map[ast.Node]int) seen := make(map[*yaml.Node]struct{})
visiting := make(map[ast.Node]bool)
// This should NOT hang or stack overflow - cycle detection prevents infinite recursion // This should NOT hang or stack overflow - the seen map prevents infinite recursion
err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount) err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount)
if err != nil { if err != nil {
t.Errorf("unexpected error traversing cyclic structure: %v", err) t.Errorf("unexpected error traversing cyclic structure: %v", err)
} }
// Verify we tracked the parent in the validated map // Verify we tracked the parent in the seen map
if _, ok := validated[parent]; !ok { if _, ok := seen[parent]; !ok {
t.Error("parent node not tracked in validated map") t.Error("parent node not tracked in seen map")
} }
} }
@@ -630,82 +594,36 @@ func TestYAMLNodeCountLimit(t *testing.T) {
func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) { func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) {
// Direct test of cycle detection in checkYAMLDepth by creating // Direct test of cycle detection in checkYAMLDepth by creating
// a node structure with an artificial cycle. // a node structure with an artificial cycle.
node := &ast.MappingNode{ // This tests the seen map logic independent of go-yaml's parsing.
Values: []*ast.MappingValueNode{ node := &yaml.Node{
{ Kind: yaml.MappingNode,
Key: &ast.StringNode{Value: "key"}, Content: []*yaml.Node{
Value: &ast.StringNode{Value: "value"}, {Kind: yaml.ScalarNode, Value: "key"},
}, {Kind: yaml.ScalarNode, Value: "value"},
}, },
} }
// Create a cycle by making a child reference the parent // Create a cycle by making a child reference the parent
cycleChild := &ast.AliasNode{ cycleChild := &yaml.Node{
Value: node, // Points back to the parent Kind: yaml.AliasNode,
Alias: node, // Points back to the parent
} }
node.Values = append(node.Values, &ast.MappingValueNode{ node.Content = append(node.Content,
Key: &ast.StringNode{Value: "cyclic"}, &yaml.Node{Kind: yaml.ScalarNode, Value: "cyclic"},
Value: cycleChild, cycleChild,
}) )
nodeCount := 0 nodeCount := 0
validated := make(map[ast.Node]int) seen := make(map[*yaml.Node]struct{})
visiting := make(map[ast.Node]bool) err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount)
err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount)
// Should complete without infinite recursion due to cycle detection // Should complete without infinite recursion due to cycle detection
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
// The validated map should contain multiple entries // The seen map should contain multiple entries
if len(validated) < 2 { if len(seen) < 2 {
t.Errorf("validated map has %d entries, expected at least 2", len(validated)) t.Errorf("seen map has %d entries, expected at least 2", len(seen))
}
}
func TestYAMLAliasDepthBypass(t *testing.T) {
// Test that an anchored subtree first validated at a shallow depth is
// re-checked when referenced via alias at a deeper position. Without the
// depth-aware validated map, the alias reference would skip re-checking
// and allow the effective nesting to exceed MaxYAMLDepth.
dir := t.TempDir()
path := filepath.Join(dir, "alias-depth-bypass.yaml")
// Build YAML with an anchor at shallow depth containing a subtree near the limit,
// then reference it via alias deep enough that effective depth exceeds MaxYAMLDepth.
var sb strings.Builder
sb.WriteString("name: test\nidentity: test\n")
// Create the anchored subtree at depth 1 (key level) that nests 15 levels deep.
sb.WriteString("anchor_key: &deep_anchor\n")
for i := 0; i < 15; i++ {
sb.WriteString(strings.Repeat(" ", i+1))
sb.WriteString(fmt.Sprintf("level%d:\n", i))
}
sb.WriteString(strings.Repeat(" ", 16))
sb.WriteString("leaf: value\n")
// Create a wrapper that nests 6 levels deep, then references the anchor.
// Effective depth at alias target = 6 (wrapper nesting) + 1 (alias) + 15 (subtree) = 22 > 20
sb.WriteString("wrapper:\n")
for i := 0; i < 6; i++ {
sb.WriteString(strings.Repeat(" ", i+1))
sb.WriteString(fmt.Sprintf("n%d:\n", i))
}
sb.WriteString(strings.Repeat(" ", 7))
sb.WriteString("alias_ref: *deep_anchor\n")
if err := os.WriteFile(path, []byte(sb.String()), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for alias depth bypass, got nil")
}
if !strings.Contains(err.Error(), "nesting depth exceeds") {
t.Errorf("error = %q, want containing 'nesting depth exceeds'", err.Error())
} }
} }
@@ -858,102 +776,3 @@ identity: test identity
t.Errorf("Name = %q, want %q", p.Name, "test") t.Errorf("Name = %q, want %q", p.Name, "test")
} }
} }
func TestJSONTrailingContentRejected(t *testing.T) {
tests := []struct {
name string
content string
}{
{
name: "trailing garbage after object",
content: `{"name":"test","identity":"test identity"}garbage`,
},
{
name: "two JSON objects",
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
},
{
name: "trailing array",
content: `{"name":"test","identity":"test identity"}[]`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for trailing content, got nil")
}
if !strings.Contains(err.Error(), "trailing content") {
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
}
})
}
}
func TestParsePersonaBytesSizeLimit(t *testing.T) {
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
oversized := make([]byte, MaxPersonaFileSize+1)
for i := range oversized {
oversized[i] = 'x'
}
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
if err == nil {
t.Fatal("expected error for oversized input, got nil")
}
if !strings.Contains(err.Error(), "exceeds maximum size") {
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
}
// Just under the limit should not trigger size error (may fail parse, but not size)
underLimit := []byte("name: test\nidentity: test persona\n")
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
if err != nil {
t.Fatalf("unexpected error for valid input: %v", err)
}
if p.Name != "test" {
t.Errorf("Name = %q, want %q", p.Name, "test")
}
}
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
// Verify that YAML merge keys (<<: *alias) are properly handled by the
// depth checker. The merge key content is in the MappingValueNode.Value
// (an AliasNode), not in the MergeKeyNode itself.
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
if err != nil {
t.Fatalf("basic parse failed: %v", err)
}
if p.Name != "merge-test" {
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
}
// Test that deeply nested merge keys still hit depth limit.
// Build YAML with merge key content nested beyond MaxYAMLDepth.
var sb strings.Builder
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
sb.WriteString("anchor: &deep\n")
indent := " "
for i := 0; i < MaxYAMLDepth+5; i++ {
sb.WriteString(indent)
sb.WriteString(fmt.Sprintf("level%d:\n", i))
indent += " "
}
sb.WriteString(indent + "leaf: value\n")
sb.WriteString("target:\n <<: *deep\n")
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
if err == nil {
t.Fatal("expected error for deeply nested merge key content, got nil")
}
if !strings.Contains(err.Error(), "depth") {
t.Errorf("error = %q, want to contain 'depth'", err.Error())
}
}