Compare commits

..

1 Commits

Author SHA1 Message Date
claw 8ebfa80c14 fix(security): prevent alias depth bypass in YAML validator
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 42s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m56s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m31s
The global 'seen' set allowed anchored subtrees validated at a shallow
depth to be skipped when later referenced via alias at a greater depth.
This could let effective nesting exceed MaxYAMLDepth, enabling DoS.

Fix: replace the single 'seen' set with two tracking maps:
- validated (node -> min depth): only short-circuits when current depth
  <= previously validated depth; re-checks at deeper contexts.
- visiting (node -> bool): per-path recursion stack for true cycle
  detection (breaks alias loops without suppressing depth checks).

Add TestYAMLAliasDepthBypass that constructs a document with an
anchored 15-level subtree referenced via alias under 6 levels of
nesting, verifying the combined effective depth (22) is rejected.

Addresses security-review-bot findings on review #2774.
2026-05-12 14:05:26 -07:00
6 changed files with 61 additions and 864 deletions
-2
View File
@@ -21,8 +21,6 @@ To request a new dependency:
2. Requires explicit approval from Aaron
3. After merge, a separate PR may use the package
<!-- Deviation from step 1+3 for go-yaml migration: see #91 for rationale. -->
*Enforcement: `scripts/check-deps.sh` parses this table — update only here.*
## Error Handling
+31 -10
View File
@@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi
- Backwards compatibility: existing JSON personas must continue to work
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
- Consistency: use `.yaml` extension (not `.yml`)
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); has built-in depth protection via `MaxYAMLDepth`/`MaxYAMLNodes` constants
## Proposed Approach
@@ -33,16 +33,37 @@ func parsePersona(data []byte, source string) (*Persona, error) {
### YAML Parsing with Depth Protection
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in
`review/persona.go`) rather than relying on library decoder options. Key design
decisions:
```go
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
var node yaml.Node
dec := yaml.NewDecoder(bytes.NewReader(data))
if err := dec.Decode(&node); err != nil {
return err
}
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
return err
}
return node.Decode(out)
}
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection)
- **Node-count limit:** Conservative overcounting bounds total validation work
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths
func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
if depth > maxDepth {
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
}
// Handle alias nodes by following the Alias pointer
if node.Kind == yaml.AliasNode && node.Alias != nil {
return checkYAMLDepth(node.Alias, depth, maxDepth)
}
for _, child := range node.Content {
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
return err
}
}
return nil
}
```
See `review/persona.go:checkYAMLDepth` for the authoritative implementation.
The `github.com/goccy/go-yaml` library provides built-in depth protection via `MaxYAMLDepth` and `MaxYAMLNodes` decoder options. We use these instead of a manual depth-checking walk.
## State/Data Model
@@ -53,7 +74,7 @@ No new state. Same `Persona` struct, just different parsing.
| Error | Handling |
|-------|----------|
| Invalid YAML syntax | Return parse error with source file |
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode |
| Deeply nested YAML | Library rejects (v1.16.0+ fix) |
| Unknown extension | Fall back to JSON parsing |
| Missing required fields | Validation rejects after parse |
-260
View File
@@ -1,260 +0,0 @@
// Package github provides a client for the GitHub API.
// It supports pull request operations, file content retrieval,
// and review submission for both github.com and GitHub Enterprise.
package github
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
const (
defaultBaseURL = "https://api.github.com"
// maxRetryAttempts is the number of times doRequest will attempt a request.
maxRetryAttempts = 3
// maxRetryAfter caps the maximum delay from a Retry-After header to prevent
// a server from stalling the client indefinitely.
maxRetryAfter = 60 * time.Second
// maxErrorBodyBytes limits how much of an error response body we read
// to protect against malicious servers sending unbounded data.
maxErrorBodyBytes = 64 * 1024 // 64 KB
// maxResponseBodyBytes limits how much of a successful response body we read
// for defense-in-depth against servers returning excessively large payloads.
maxResponseBodyBytes = 10 * 1024 * 1024 // 10 MB
)
// APIError represents an HTTP error response from the GitHub API.
// It carries the status code so callers can distinguish between
// different failure modes (e.g. 404 vs 500).
//
// The Body field stores up to 64 KiB of the raw response for programmatic
// inspection. Error() truncates to 200 bytes for safe logging, but callers
// should avoid logging or propagating Body directly in production since it may
// contain sensitive details from the upstream server.
type APIError struct {
StatusCode int
Body string
}
func (e *APIError) Error() string {
body := e.Body
if len(body) > 200 {
body = body[:200] + "...(truncated)"
}
// Sanitize newlines to prevent log injection from upstream response bodies.
body = strings.ReplaceAll(body, "\n", " ")
body = strings.ReplaceAll(body, "\r", " ")
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
}
// IsNotFound reports whether an error is an API 404 response.
func IsNotFound(err error) bool {
if apiErr, ok := asAPIError(err); ok {
return apiErr.StatusCode == http.StatusNotFound
}
return false
}
// IsUnauthorized reports whether an error is an API 401 response.
func IsUnauthorized(err error) bool {
if apiErr, ok := asAPIError(err); ok {
return apiErr.StatusCode == http.StatusUnauthorized
}
return false
}
func asAPIError(err error) (*APIError, bool) {
if err == nil {
return nil, false
}
var target *APIError
if errors.As(err, &target) {
return target, true
}
return nil, false
}
// Client interacts with the GitHub API.
// A Client is safe for concurrent use by multiple goroutines.
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
// be called before any goroutines issue requests; they have no synchronization.
type Client struct {
// TODO: baseURL is populated by NewClient but not yet consumed by doRequest/doGet.
// Higher-level exported methods (GetPullRequest, etc.) will use it to
// construct request URLs; remove this field if those methods end up
// accepting full URLs instead.
baseURL string
token string
httpClient *http.Client
// retryBackoff defines the delays between retry attempts for 429 responses.
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
// If nil, defaults to {1s, 2s}.
retryBackoff []time.Duration
// now returns the current time. Defaults to time.Now.
// Override in tests to control HTTP-date Retry-After calculations.
now func() time.Time
}
// NewClient creates a new GitHub API client.
// If baseURL is empty, it defaults to https://api.github.com.
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
func NewClient(token, baseURL string) *Client {
if baseURL == "" {
baseURL = defaultBaseURL
}
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
token: token,
httpClient: &http.Client{Timeout: 30 * time.Second},
now: time.Now,
}
}
// SetHTTPClient sets the underlying HTTP client used for requests.
// This is intended for testing to inject mock transports.
func (c *Client) SetHTTPClient(hc *http.Client) {
c.httpClient = hc
}
// SetRetryBackoff sets the delays between retry attempts.
// This is intended for testing to speed up retry tests.
//
// Note: if an empty non-nil slice is provided, Retry-After delays parsed from
// server responses will be computed and capped but not applied (because
// attempt < len(backoff) is always false). This is acceptable for the
// test-only use case but callers should be aware of this edge case.
func (c *Client) SetRetryBackoff(backoff []time.Duration) {
c.retryBackoff = backoff
}
// parseRetryAfter parses a Retry-After header value, supporting both integer
// seconds (e.g. "120") and HTTP-date format (e.g. "Thu, 01 Dec 2025 16:00:00 GMT")
// as specified in RFC 7231 §7.1.3.
//
// For integer values, it returns the duration directly.
// For HTTP-date values, it computes the delay as the difference between the
// parsed time and now. If the date is in the past, it returns 0.
//
// Returns (0, false) if the value cannot be parsed as either format.
func (c *Client) parseRetryAfter(value string) (time.Duration, bool) {
value = strings.TrimSpace(value)
// Try integer seconds first (most common from GitHub).
// RFC 7231 allows delta-seconds of 0 to indicate immediate retry.
if seconds, err := strconv.Atoi(value); err == nil && seconds >= 0 {
return time.Duration(seconds) * time.Second, true
}
// Try HTTP-date format (RFC 7231 §7.1.3).
// http.ParseTime handles RFC 1123, RFC 850, and ASCTIME formats.
if retryAt, err := http.ParseTime(value); err == nil {
delay := retryAt.Sub(c.now())
if delay < 0 {
delay = 0
}
return delay, true
}
return 0, false
}
// doRequest performs an HTTP request with retry on 429 rate limit responses.
// It respects the Retry-After header when present, supporting both integer
// seconds and HTTP-date formats (capped at maxRetryAfter).
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
var backoff []time.Duration
if c.retryBackoff != nil {
backoff = append([]time.Duration(nil), c.retryBackoff...)
} else {
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
}
var lastErr error
for attempt := 0; attempt < maxRetryAttempts; attempt++ {
if attempt > 0 {
var delay time.Duration
if attempt-1 < len(backoff) {
delay = backoff[attempt-1]
}
if delay > 0 {
timer := time.NewTimer(delay)
select {
case <-timer.C:
timer.Stop()
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
}
}
}
req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.token)
if accept != "" {
req.Header.Set("Accept", accept)
} else {
req.Header.Set("Accept", "application/vnd.github+json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do request: %w", err)
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
return body, nil
}
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
resp.Body.Close()
lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
// Retry on 429 rate limit
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 {
// Check for Retry-After header and override backoff if present.
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
if ra := resp.Header.Get("Retry-After"); ra != "" {
if delay, ok := c.parseRetryAfter(ra); ok {
if delay > maxRetryAfter {
delay = maxRetryAfter
}
if attempt < len(backoff) {
backoff[attempt] = delay
}
}
}
continue
}
// Don't retry other errors
return nil, lastErr
}
return nil, lastErr
}
// doGet is a convenience wrapper for GET requests with the default Accept header.
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
return c.doRequest(ctx, http.MethodGet, url, "")
}
-409
View File
@@ -1,409 +0,0 @@
package github
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewClient_DefaultBaseURL(t *testing.T) {
c := NewClient("tok", "")
if c.baseURL != defaultBaseURL {
t.Errorf("baseURL = %q, want %q", c.baseURL, defaultBaseURL)
}
}
func TestNewClient_CustomBaseURL(t *testing.T) {
c := NewClient("tok", "https://github.concur.com/api/v3/")
if c.baseURL != "https://github.concur.com/api/v3" {
t.Errorf("baseURL = %q, want trailing slash stripped", c.baseURL)
}
}
func TestDoRequest_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
t.Errorf("Authorization = %q, want Bearer test-token", got)
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("test-token", srv.URL)
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("body = %q, want %q", body, `{"ok":true}`)
}
}
func TestDoRequest_429_RetryAfter_IntegerSeconds(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "0")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
c.SetRetryBackoff([]time.Duration{0, 0})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != "success" {
t.Errorf("body = %q, want %q", body, "success")
}
if attempts != 2 {
t.Errorf("attempts = %d, want 2", attempts)
}
}
func TestDoRequest_429_RetryAfter_HTTPDate(t *testing.T) {
// Fix "now" to a known time for deterministic testing.
fixedNow := time.Date(2025, 12, 1, 15, 59, 59, 0, time.UTC)
retryAt := "Mon, 01 Dec 2025 16:00:00 GMT" // 1 second in the future
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", retryAt)
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
c.now = func() time.Time { return fixedNow }
// Initial backoff is 0; the HTTP-date parser will compute 1s and override.
c.SetRetryBackoff([]time.Duration{0, 0})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != "success" {
t.Errorf("body = %q, want %q", body, "success")
}
if attempts != 2 {
t.Errorf("attempts = %d, want 2", attempts)
}
}
func TestDoRequest_429_RetryAfter_HTTPDate_InPast(t *testing.T) {
// If the HTTP-date is in the past, delay should be 0 (retry immediately).
fixedNow := time.Date(2025, 12, 1, 17, 0, 0, 0, time.UTC)
retryAt := "Mon, 01 Dec 2025 16:00:00 GMT" // 1 hour in the past
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", retryAt)
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
c.now = func() time.Time { return fixedNow }
c.SetRetryBackoff([]time.Duration{0, 0})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != "success" {
t.Errorf("body = %q, want %q", body, "success")
}
}
func TestDoRequest_429_NoRetryAfter_UsesDefaultBackoff(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
c.SetRetryBackoff([]time.Duration{0, 0})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != "success" {
t.Errorf("body = %q, want %q", body, "success")
}
if attempts != 2 {
t.Errorf("attempts = %d, want 2", attempts)
}
}
func TestDoRequest_429_InvalidRetryAfter_UsesDefaultBackoff(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "not-a-number-or-date")
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte("rate limited"))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
c.SetRetryBackoff([]time.Duration{0, 0})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != "success" {
t.Errorf("body = %q, want %q", body, "success")
}
}
func TestDoRequest_404_NoRetry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("not found"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error, got nil")
}
if !IsNotFound(err) {
t.Errorf("expected IsNotFound, got %v", err)
}
if attempts != 1 {
t.Errorf("attempts = %d, want 1 (no retry on 404)", attempts)
}
}
func TestDoRequest_401_NoRetry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("unauthorized"))
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error, got nil")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized, got %v", err)
}
if attempts != 1 {
t.Errorf("attempts = %d, want 1 (no retry on 401)", attempts)
}
}
func TestDoRequest_ContextCanceled(t *testing.T) {
// This test exercises the timer-cancel path in the retry select:
// select { case <-timer.C; case <-ctx.Done() }
// The server returns 429 with a long Retry-After, and we cancel the
// context shortly after the first response so that cancellation races
// against the timer rather than preventing the initial HTTP round-trip.
requestReceived := make(chan struct{}, 1)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case requestReceived <- struct{}{}:
default:
}
w.Header().Set("Retry-After", "10")
w.WriteHeader(http.StatusTooManyRequests)
}))
defer srv.Close()
c := NewClient("tok", srv.URL)
c.SetRetryBackoff([]time.Duration{10 * time.Second, 10 * time.Second})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Cancel the context after the first request completes, while the
// client is blocked in the retry timer select.
go func() {
<-requestReceived
// Small delay to ensure we're inside the timer select.
time.Sleep(50 * time.Millisecond)
cancel()
}()
_, err := c.doGet(ctx, srv.URL+"/test")
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, context.Canceled) {
t.Errorf("err = %v, want context.Canceled", err)
}
}
func TestParseRetryAfter_IntegerSeconds(t *testing.T) {
c := NewClient("tok", "")
delay, ok := c.parseRetryAfter("42")
if !ok {
t.Fatal("expected ok=true")
}
if delay != 42*time.Second {
t.Errorf("delay = %v, want 42s", delay)
}
}
func TestParseRetryAfter_ZeroSeconds(t *testing.T) {
c := NewClient("tok", "")
delay, ok := c.parseRetryAfter("0")
if !ok {
t.Fatal("expected ok=true for zero seconds (RFC 7231 allows immediate retry)")
}
if delay != 0 {
t.Errorf("delay = %v, want 0", delay)
}
}
func TestParseRetryAfter_NegativeSeconds(t *testing.T) {
c := NewClient("tok", "")
_, ok := c.parseRetryAfter("-5")
if ok {
t.Error("expected ok=false for negative seconds")
}
}
func TestParseRetryAfter_HTTPDate_Future(t *testing.T) {
fixedNow := time.Date(2025, 12, 1, 15, 59, 50, 0, time.UTC)
c := NewClient("tok", "")
c.now = func() time.Time { return fixedNow }
delay, ok := c.parseRetryAfter("Mon, 01 Dec 2025 16:00:00 GMT")
if !ok {
t.Fatal("expected ok=true")
}
// Should be 10 seconds in the future.
if delay != 10*time.Second {
t.Errorf("delay = %v, want 10s", delay)
}
}
func TestParseRetryAfter_HTTPDate_Past(t *testing.T) {
fixedNow := time.Date(2025, 12, 1, 17, 0, 0, 0, time.UTC)
c := NewClient("tok", "")
c.now = func() time.Time { return fixedNow }
delay, ok := c.parseRetryAfter("Mon, 01 Dec 2025 16:00:00 GMT")
if !ok {
t.Fatal("expected ok=true")
}
if delay != 0 {
t.Errorf("delay = %v, want 0 (past date)", delay)
}
}
func TestParseRetryAfter_RFC850_Format(t *testing.T) {
fixedNow := time.Date(2025, 12, 1, 15, 59, 50, 0, time.UTC)
c := NewClient("tok", "")
c.now = func() time.Time { return fixedNow }
// RFC 850 format
delay, ok := c.parseRetryAfter("Monday, 01-Dec-25 16:00:00 GMT")
if !ok {
t.Fatal("expected ok=true for RFC 850 format")
}
if delay != 10*time.Second {
t.Errorf("delay = %v, want 10s", delay)
}
}
func TestParseRetryAfter_Invalid(t *testing.T) {
c := NewClient("tok", "")
_, ok := c.parseRetryAfter("not-valid")
if ok {
t.Error("expected ok=false for invalid value")
}
}
func TestParseRetryAfter_EmptyString(t *testing.T) {
c := NewClient("tok", "")
_, ok := c.parseRetryAfter("")
if ok {
t.Error("expected ok=false for empty string")
}
}
func TestParseRetryAfter_MaxCap(t *testing.T) {
// Verify that parseRetryAfter returns the raw value (capping is done by caller).
c := NewClient("tok", "")
delay, ok := c.parseRetryAfter("3600")
if !ok {
t.Fatal("expected ok=true")
}
if delay != 3600*time.Second {
t.Errorf("delay = %v, want 3600s (caller is responsible for capping)", delay)
}
}
func TestAPIError_Error_Truncation(t *testing.T) {
longBody := make([]byte, 300)
for i := range longBody {
longBody[i] = 'x'
}
apiErr := &APIError{StatusCode: 500, Body: string(longBody)}
msg := apiErr.Error()
if len(msg) > 250 {
// "HTTP 500: " (10) + 200 + "...(truncated)" (14) = 224
t.Errorf("error message too long: %d chars", len(msg))
}
}
func TestAPIError_Error_NewlineSanitized(t *testing.T) {
apiErr := &APIError{StatusCode: 400, Body: "line1\nline2\rline3"}
msg := apiErr.Error()
for _, c := range msg {
if c == '\n' || c == '\r' {
t.Errorf("error message contains unsanitized newline: %q", msg)
break
}
}
}
+20 -70
View File
@@ -5,7 +5,6 @@ import (
"embed"
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strings"
@@ -121,7 +120,9 @@ func ListBuiltinPersonas() []string {
default:
continue
}
seen[personaName] = true
if !seen[personaName] {
seen[personaName] = true
}
}
names := make([]string, 0, len(seen))
for name := range seen {
@@ -147,15 +148,6 @@ func parsePersona(data []byte, source string) (*Persona, error) {
dec := json.NewDecoder(bytes.NewReader(data))
dec.DisallowUnknownFields()
err = dec.Decode(&p)
if err == nil {
// Reject trailing content after the first valid JSON object.
// Without this check, input like `{"name":"x"}garbage` would
// silently succeed because Decoder stops after one object.
var dummy json.RawMessage
if err2 := dec.Decode(&dummy); err2 != io.EOF {
err = fmt.Errorf("unexpected trailing content after JSON object")
}
}
}
if err != nil {
return nil, fmt.Errorf("parse persona %s: %w", source, err)
@@ -166,10 +158,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
return &p, nil
}
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks:
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion.
// - Multi-document rejection: prevents silent data loss from ignored extra documents.
// - Strict field checking: rejects unknown YAML keys to catch typos early.
// unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
// and strict field checking. This protects against stack exhaustion from deeply
// nested structures and catches typos in field names.
// Multi-document YAML files are rejected to prevent silent data loss.
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
// First pass: parse into AST to check depth limits, node counts, and
// multi-document rejection. This prevents stack exhaustion before we
@@ -198,18 +190,13 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
// Second pass: decode with strict field checking enabled.
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
//
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases
// recursively — it resolves them via the pre-built AST, which our first
// pass already depth-checked. Alias chains that would exceed depth limits
// are caught above; the decoder merely reads the resolved scalar values.
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
return dec.Decode(out)
}
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
// limit or the total node count limit. It uses two tracking maps:
// - validated: maps each node to the maximum depth at which it was previously
// - validated: maps each node to the minimum depth at which it was previously
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
// we re-check it to ensure the combined effective depth doesn't exceed limits.
// - visiting: per-path recursion stack for true cycle detection. A node on the
@@ -227,6 +214,12 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
}
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
*nodeCount++
if *nodeCount > maxNodes {
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
}
// Cycle detection: if we're currently visiting this node on the current
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
// Return nil to break the cycle without error — cycles are a structural
@@ -235,28 +228,10 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
return nil
}
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
// Placed after cycle detection but before the depth-aware short-circuit. This means
// nodes revisited at shallower depths (via aliases) are counted each time they are
// encountered — intentional conservative overcounting. This bounds the total work
// performed during validation rather than tracking unique nodes, which is the safer
// security posture for untrusted YAML input.
*nodeCount++
if *nodeCount > maxNodes {
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
}
// Depth-aware short-circuit: skip re-validation only when the current visit
// depth is the same or shallower than the depth at which this node was
// previously validated. A shallower (or equal) current depth means the
// prior, deeper validation already covered any subtree depth violations.
// If the current depth exceeds the previous validation depth (e.g., an alias
// references this node deeper in the tree), we must re-traverse to ensure
// the combined effective depth doesn't exceed maxDepth.
//
// Note: using ast.Node (interface) as map key relies on pointer identity,
// which is correct because all goccy/go-yaml AST node types are pointer
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
// Depth-aware short-circuit: only skip re-checking a node if we previously
// validated it at the same or deeper effective depth. If this visit is at a
// greater depth than before (e.g., alias referenced deeper in the tree),
// we must re-traverse to catch depth limit violations.
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
return nil
}
@@ -275,11 +250,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
}
}
case *ast.MappingValueNode:
// Both Key and Value are visited at depth+1 relative to this
// MappingValueNode. Since MappingNode visits its MappingValueNode
// children at depth+1 as well, keys and values end up at depth+2
// from the parent MappingNode. This is intentional: it mirrors the
// actual nesting structure (mapping → key-value pair → key/value).
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
@@ -299,14 +269,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
return err
}
case *ast.AnchorNode:
// Increment depth for anchor values as a conservative measure: the
// anchor definition itself is structural, and treating it as a depth
// level ensures that deeply nested anchors are caught at definition
// time rather than only when referenced via alias. This +1 is
// asymmetric with alias (which also increments) — by design, the
// effective depth budget for anchored-then-aliased content is reduced
// because both the definition site and the reference site each consume
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
@@ -314,16 +276,8 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.MergeKeyNode:
// MergeKeyNode represents the literal "<<" merge key token. It has no
// child nodes — the value side of a merge (e.g., *alias) lives in the
// parent MappingValueNode.Value, which is already recursed into above.
// Explicitly listed here (rather than in the default case) to prevent
// future library changes from silently bypassing depth checks.
default:
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
// recurse into.
// Scalar types (StringNode, IntegerNode, FloatNode, BoolNode, NullNode,
// InfinityNode, NanNode, LiteralNode, MergeKeyNode) are leaf nodes.
}
return nil
}
@@ -331,11 +285,7 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
// without requiring filesystem access. Format is detected by source extension.
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
if len(data) > MaxPersonaFileSize {
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
}
return parsePersona(data, source)
}
+10 -113
View File
@@ -459,14 +459,8 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
path := filepath.Join(dir, "deeply-nested.yaml")
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
// Depth accumulation trace for "nested: \n level0: \n level1: ...":
// - Document root parsed at depth 0
// - Root MappingNode children (MappingValueNodes) visited at depth 1
// - "nested" MappingValueNode: key at depth 2, value at depth 2
// - Each levelN adds depth via MappingValueNode traversal (key + value)
// - Exact depth per level depends on AST structure (MappingNode wrapping),
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
// Each nested mapping key generates a MappingValueNode, incrementing depth
// by 1 per level in the AST walk. With 25 levels, we exceed MaxYAMLDepth (20).
var sb strings.Builder
sb.WriteString("name: test\nidentity: test\nnested:\n")
indent := " "
@@ -490,19 +484,21 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
}
}
func TestYAMLEmptyFileRejection(t *testing.T) {
dir := t.TempDir()
tests := []struct {
name string
content string
}{
{"completely_empty", ""},
{"whitespace_only", " \n\n "},
{"comment_only", "# just a comment\n"},
{"completely empty", ""},
{"whitespace only", " \n\n "},
{"comment only", "# just a comment\n"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, tc.name+".yaml")
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
@@ -510,9 +506,9 @@ func TestYAMLEmptyFileRejection(t *testing.T) {
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for empty YAML input, got nil")
t.Error("expected error for empty YAML input, got nil")
}
if !strings.Contains(err.Error(), "empty YAML document") {
if err != nil && !strings.Contains(err.Error(), "empty YAML document") {
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
}
})
@@ -858,102 +854,3 @@ identity: test identity
t.Errorf("Name = %q, want %q", p.Name, "test")
}
}
func TestJSONTrailingContentRejected(t *testing.T) {
tests := []struct {
name string
content string
}{
{
name: "trailing garbage after object",
content: `{"name":"test","identity":"test identity"}garbage`,
},
{
name: "two JSON objects",
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
},
{
name: "trailing array",
content: `{"name":"test","identity":"test identity"}[]`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for trailing content, got nil")
}
if !strings.Contains(err.Error(), "trailing content") {
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
}
})
}
}
func TestParsePersonaBytesSizeLimit(t *testing.T) {
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
oversized := make([]byte, MaxPersonaFileSize+1)
for i := range oversized {
oversized[i] = 'x'
}
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
if err == nil {
t.Fatal("expected error for oversized input, got nil")
}
if !strings.Contains(err.Error(), "exceeds maximum size") {
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
}
// Just under the limit should not trigger size error (may fail parse, but not size)
underLimit := []byte("name: test\nidentity: test persona\n")
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
if err != nil {
t.Fatalf("unexpected error for valid input: %v", err)
}
if p.Name != "test" {
t.Errorf("Name = %q, want %q", p.Name, "test")
}
}
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
// Verify that YAML merge keys (<<: *alias) are properly handled by the
// depth checker. The merge key content is in the MappingValueNode.Value
// (an AliasNode), not in the MergeKeyNode itself.
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
if err != nil {
t.Fatalf("basic parse failed: %v", err)
}
if p.Name != "merge-test" {
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
}
// Test that deeply nested merge keys still hit depth limit.
// Build YAML with merge key content nested beyond MaxYAMLDepth.
var sb strings.Builder
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
sb.WriteString("anchor: &deep\n")
indent := " "
for i := 0; i < MaxYAMLDepth+5; i++ {
sb.WriteString(indent)
sb.WriteString(fmt.Sprintf("level%d:\n", i))
indent += " "
}
sb.WriteString(indent + "leaf: value\n")
sb.WriteString("target:\n <<: *deep\n")
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
if err == nil {
t.Fatal("expected error for deeply nested merge key content, got nil")
}
if !strings.Contains(err.Error(), "depth") {
t.Errorf("error = %q, want to contain 'depth'", err.Error())
}
}