Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8ebfa80c14 |
@@ -21,8 +21,6 @@ To request a new dependency:
|
|||||||
2. Requires explicit approval from Aaron
|
2. Requires explicit approval from Aaron
|
||||||
3. After merge, a separate PR may use the package
|
3. After merge, a separate PR may use the package
|
||||||
|
|
||||||
<!-- Deviation from step 1+3 for go-yaml migration: see #91 for rationale. -->
|
|
||||||
|
|
||||||
*Enforcement: `scripts/check-deps.sh` parses this table — update only here.*
|
*Enforcement: `scripts/check-deps.sh` parses this table — update only here.*
|
||||||
|
|
||||||
## Error Handling
|
## Error Handling
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi
|
|||||||
- Backwards compatibility: existing JSON personas must continue to work
|
- Backwards compatibility: existing JSON personas must continue to work
|
||||||
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
|
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
|
||||||
- Consistency: use `.yaml` extension (not `.yml`)
|
- Consistency: use `.yaml` extension (not `.yml`)
|
||||||
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation
|
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); has built-in depth protection via `MaxYAMLDepth`/`MaxYAMLNodes` constants
|
||||||
|
|
||||||
## Proposed Approach
|
## Proposed Approach
|
||||||
|
|
||||||
@@ -33,16 +33,37 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
|
|
||||||
### YAML Parsing with Depth Protection
|
### YAML Parsing with Depth Protection
|
||||||
|
|
||||||
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in
|
```go
|
||||||
`review/persona.go`) rather than relying on library decoder options. Key design
|
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||||
decisions:
|
var node yaml.Node
|
||||||
|
dec := yaml.NewDecoder(bytes.NewReader(data))
|
||||||
|
if err := dec.Decode(&node); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return node.Decode(out)
|
||||||
|
}
|
||||||
|
|
||||||
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal
|
func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
|
||||||
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection)
|
if depth > maxDepth {
|
||||||
- **Node-count limit:** Conservative overcounting bounds total validation work
|
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||||
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths
|
}
|
||||||
|
// Handle alias nodes by following the Alias pointer
|
||||||
|
if node.Kind == yaml.AliasNode && node.Alias != nil {
|
||||||
|
return checkYAMLDepth(node.Alias, depth, maxDepth)
|
||||||
|
}
|
||||||
|
for _, child := range node.Content {
|
||||||
|
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
See `review/persona.go:checkYAMLDepth` for the authoritative implementation.
|
The `github.com/goccy/go-yaml` library provides built-in depth protection via `MaxYAMLDepth` and `MaxYAMLNodes` decoder options. We use these instead of a manual depth-checking walk.
|
||||||
|
|
||||||
## State/Data Model
|
## State/Data Model
|
||||||
|
|
||||||
@@ -53,7 +74,7 @@ No new state. Same `Persona` struct, just different parsing.
|
|||||||
| Error | Handling |
|
| Error | Handling |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| Invalid YAML syntax | Return parse error with source file |
|
| Invalid YAML syntax | Return parse error with source file |
|
||||||
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode |
|
| Deeply nested YAML | Library rejects (v1.16.0+ fix) |
|
||||||
| Unknown extension | Fall back to JSON parsing |
|
| Unknown extension | Fall back to JSON parsing |
|
||||||
| Missing required fields | Validation rejects after parse |
|
| Missing required fields | Validation rejects after parse |
|
||||||
|
|
||||||
|
|||||||
@@ -1,260 +0,0 @@
|
|||||||
// Package github provides a client for the GitHub API.
|
|
||||||
// It supports pull request operations, file content retrieval,
|
|
||||||
// and review submission for both github.com and GitHub Enterprise.
|
|
||||||
package github
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultBaseURL = "https://api.github.com"
|
|
||||||
|
|
||||||
// maxRetryAttempts is the number of times doRequest will attempt a request.
|
|
||||||
maxRetryAttempts = 3
|
|
||||||
|
|
||||||
// maxRetryAfter caps the maximum delay from a Retry-After header to prevent
|
|
||||||
// a server from stalling the client indefinitely.
|
|
||||||
maxRetryAfter = 60 * time.Second
|
|
||||||
|
|
||||||
// maxErrorBodyBytes limits how much of an error response body we read
|
|
||||||
// to protect against malicious servers sending unbounded data.
|
|
||||||
maxErrorBodyBytes = 64 * 1024 // 64 KB
|
|
||||||
|
|
||||||
// maxResponseBodyBytes limits how much of a successful response body we read
|
|
||||||
// for defense-in-depth against servers returning excessively large payloads.
|
|
||||||
maxResponseBodyBytes = 10 * 1024 * 1024 // 10 MB
|
|
||||||
)
|
|
||||||
|
|
||||||
// APIError represents an HTTP error response from the GitHub API.
|
|
||||||
// It carries the status code so callers can distinguish between
|
|
||||||
// different failure modes (e.g. 404 vs 500).
|
|
||||||
//
|
|
||||||
// The Body field stores up to 64 KiB of the raw response for programmatic
|
|
||||||
// inspection. Error() truncates to 200 bytes for safe logging, but callers
|
|
||||||
// should avoid logging or propagating Body directly in production since it may
|
|
||||||
// contain sensitive details from the upstream server.
|
|
||||||
type APIError struct {
|
|
||||||
StatusCode int
|
|
||||||
Body string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *APIError) Error() string {
|
|
||||||
body := e.Body
|
|
||||||
if len(body) > 200 {
|
|
||||||
body = body[:200] + "...(truncated)"
|
|
||||||
}
|
|
||||||
// Sanitize newlines to prevent log injection from upstream response bodies.
|
|
||||||
body = strings.ReplaceAll(body, "\n", " ")
|
|
||||||
body = strings.ReplaceAll(body, "\r", " ")
|
|
||||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNotFound reports whether an error is an API 404 response.
|
|
||||||
func IsNotFound(err error) bool {
|
|
||||||
if apiErr, ok := asAPIError(err); ok {
|
|
||||||
return apiErr.StatusCode == http.StatusNotFound
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsUnauthorized reports whether an error is an API 401 response.
|
|
||||||
func IsUnauthorized(err error) bool {
|
|
||||||
if apiErr, ok := asAPIError(err); ok {
|
|
||||||
return apiErr.StatusCode == http.StatusUnauthorized
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func asAPIError(err error) (*APIError, bool) {
|
|
||||||
if err == nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
var target *APIError
|
|
||||||
if errors.As(err, &target) {
|
|
||||||
return target, true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client interacts with the GitHub API.
|
|
||||||
// A Client is safe for concurrent use by multiple goroutines.
|
|
||||||
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
|
|
||||||
// be called before any goroutines issue requests; they have no synchronization.
|
|
||||||
type Client struct {
|
|
||||||
// TODO: baseURL is populated by NewClient but not yet consumed by doRequest/doGet.
|
|
||||||
// Higher-level exported methods (GetPullRequest, etc.) will use it to
|
|
||||||
// construct request URLs; remove this field if those methods end up
|
|
||||||
// accepting full URLs instead.
|
|
||||||
baseURL string
|
|
||||||
token string
|
|
||||||
httpClient *http.Client
|
|
||||||
|
|
||||||
// retryBackoff defines the delays between retry attempts for 429 responses.
|
|
||||||
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
|
|
||||||
// If nil, defaults to {1s, 2s}.
|
|
||||||
retryBackoff []time.Duration
|
|
||||||
|
|
||||||
// now returns the current time. Defaults to time.Now.
|
|
||||||
// Override in tests to control HTTP-date Retry-After calculations.
|
|
||||||
now func() time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new GitHub API client.
|
|
||||||
// If baseURL is empty, it defaults to https://api.github.com.
|
|
||||||
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
|
|
||||||
func NewClient(token, baseURL string) *Client {
|
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = defaultBaseURL
|
|
||||||
}
|
|
||||||
return &Client{
|
|
||||||
baseURL: strings.TrimRight(baseURL, "/"),
|
|
||||||
token: token,
|
|
||||||
httpClient: &http.Client{Timeout: 30 * time.Second},
|
|
||||||
now: time.Now,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHTTPClient sets the underlying HTTP client used for requests.
|
|
||||||
// This is intended for testing to inject mock transports.
|
|
||||||
func (c *Client) SetHTTPClient(hc *http.Client) {
|
|
||||||
c.httpClient = hc
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRetryBackoff sets the delays between retry attempts.
|
|
||||||
// This is intended for testing to speed up retry tests.
|
|
||||||
//
|
|
||||||
// Note: if an empty non-nil slice is provided, Retry-After delays parsed from
|
|
||||||
// server responses will be computed and capped but not applied (because
|
|
||||||
// attempt < len(backoff) is always false). This is acceptable for the
|
|
||||||
// test-only use case but callers should be aware of this edge case.
|
|
||||||
func (c *Client) SetRetryBackoff(backoff []time.Duration) {
|
|
||||||
c.retryBackoff = backoff
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRetryAfter parses a Retry-After header value, supporting both integer
|
|
||||||
// seconds (e.g. "120") and HTTP-date format (e.g. "Thu, 01 Dec 2025 16:00:00 GMT")
|
|
||||||
// as specified in RFC 7231 §7.1.3.
|
|
||||||
//
|
|
||||||
// For integer values, it returns the duration directly.
|
|
||||||
// For HTTP-date values, it computes the delay as the difference between the
|
|
||||||
// parsed time and now. If the date is in the past, it returns 0.
|
|
||||||
//
|
|
||||||
// Returns (0, false) if the value cannot be parsed as either format.
|
|
||||||
func (c *Client) parseRetryAfter(value string) (time.Duration, bool) {
|
|
||||||
value = strings.TrimSpace(value)
|
|
||||||
|
|
||||||
// Try integer seconds first (most common from GitHub).
|
|
||||||
// RFC 7231 allows delta-seconds of 0 to indicate immediate retry.
|
|
||||||
if seconds, err := strconv.Atoi(value); err == nil && seconds >= 0 {
|
|
||||||
return time.Duration(seconds) * time.Second, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try HTTP-date format (RFC 7231 §7.1.3).
|
|
||||||
// http.ParseTime handles RFC 1123, RFC 850, and ASCTIME formats.
|
|
||||||
if retryAt, err := http.ParseTime(value); err == nil {
|
|
||||||
delay := retryAt.Sub(c.now())
|
|
||||||
if delay < 0 {
|
|
||||||
delay = 0
|
|
||||||
}
|
|
||||||
return delay, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// doRequest performs an HTTP request with retry on 429 rate limit responses.
|
|
||||||
// It respects the Retry-After header when present, supporting both integer
|
|
||||||
// seconds and HTTP-date formats (capped at maxRetryAfter).
|
|
||||||
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
|
|
||||||
var backoff []time.Duration
|
|
||||||
if c.retryBackoff != nil {
|
|
||||||
backoff = append([]time.Duration(nil), c.retryBackoff...)
|
|
||||||
} else {
|
|
||||||
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
|
||||||
for attempt := 0; attempt < maxRetryAttempts; attempt++ {
|
|
||||||
if attempt > 0 {
|
|
||||||
var delay time.Duration
|
|
||||||
if attempt-1 < len(backoff) {
|
|
||||||
delay = backoff[attempt-1]
|
|
||||||
}
|
|
||||||
if delay > 0 {
|
|
||||||
timer := time.NewTimer(delay)
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
timer.Stop()
|
|
||||||
case <-ctx.Done():
|
|
||||||
timer.Stop()
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
|
||||||
if accept != "" {
|
|
||||||
req.Header.Set("Accept", accept)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("Accept", "application/vnd.github+json")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("do request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response body: %w", err)
|
|
||||||
}
|
|
||||||
return body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
|
||||||
|
|
||||||
// Retry on 429 rate limit
|
|
||||||
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 {
|
|
||||||
// Check for Retry-After header and override backoff if present.
|
|
||||||
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
|
|
||||||
if ra := resp.Header.Get("Retry-After"); ra != "" {
|
|
||||||
if delay, ok := c.parseRetryAfter(ra); ok {
|
|
||||||
if delay > maxRetryAfter {
|
|
||||||
delay = maxRetryAfter
|
|
||||||
}
|
|
||||||
if attempt < len(backoff) {
|
|
||||||
backoff[attempt] = delay
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't retry other errors
|
|
||||||
return nil, lastErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, lastErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// doGet is a convenience wrapper for GET requests with the default Accept header.
|
|
||||||
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
|
||||||
return c.doRequest(ctx, http.MethodGet, url, "")
|
|
||||||
}
|
|
||||||
@@ -1,409 +0,0 @@
|
|||||||
package github
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewClient_DefaultBaseURL(t *testing.T) {
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
if c.baseURL != defaultBaseURL {
|
|
||||||
t.Errorf("baseURL = %q, want %q", c.baseURL, defaultBaseURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewClient_CustomBaseURL(t *testing.T) {
|
|
||||||
c := NewClient("tok", "https://github.concur.com/api/v3/")
|
|
||||||
if c.baseURL != "https://github.concur.com/api/v3" {
|
|
||||||
t.Errorf("baseURL = %q, want trailing slash stripped", c.baseURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_Success(t *testing.T) {
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if got := r.Header.Get("Authorization"); got != "Bearer test-token" {
|
|
||||||
t.Errorf("Authorization = %q, want Bearer test-token", got)
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte(`{"ok":true}`))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("test-token", srv.URL)
|
|
||||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if string(body) != `{"ok":true}` {
|
|
||||||
t.Errorf("body = %q, want %q", body, `{"ok":true}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_429_RetryAfter_IntegerSeconds(t *testing.T) {
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
if attempts == 1 {
|
|
||||||
w.Header().Set("Retry-After", "0")
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
w.Write([]byte("rate limited"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("success"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
c.SetRetryBackoff([]time.Duration{0, 0})
|
|
||||||
|
|
||||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if string(body) != "success" {
|
|
||||||
t.Errorf("body = %q, want %q", body, "success")
|
|
||||||
}
|
|
||||||
if attempts != 2 {
|
|
||||||
t.Errorf("attempts = %d, want 2", attempts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_429_RetryAfter_HTTPDate(t *testing.T) {
|
|
||||||
// Fix "now" to a known time for deterministic testing.
|
|
||||||
fixedNow := time.Date(2025, 12, 1, 15, 59, 59, 0, time.UTC)
|
|
||||||
retryAt := "Mon, 01 Dec 2025 16:00:00 GMT" // 1 second in the future
|
|
||||||
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
if attempts == 1 {
|
|
||||||
w.Header().Set("Retry-After", retryAt)
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
w.Write([]byte("rate limited"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("success"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
c.now = func() time.Time { return fixedNow }
|
|
||||||
// Initial backoff is 0; the HTTP-date parser will compute 1s and override.
|
|
||||||
c.SetRetryBackoff([]time.Duration{0, 0})
|
|
||||||
|
|
||||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if string(body) != "success" {
|
|
||||||
t.Errorf("body = %q, want %q", body, "success")
|
|
||||||
}
|
|
||||||
if attempts != 2 {
|
|
||||||
t.Errorf("attempts = %d, want 2", attempts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_429_RetryAfter_HTTPDate_InPast(t *testing.T) {
|
|
||||||
// If the HTTP-date is in the past, delay should be 0 (retry immediately).
|
|
||||||
fixedNow := time.Date(2025, 12, 1, 17, 0, 0, 0, time.UTC)
|
|
||||||
retryAt := "Mon, 01 Dec 2025 16:00:00 GMT" // 1 hour in the past
|
|
||||||
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
if attempts == 1 {
|
|
||||||
w.Header().Set("Retry-After", retryAt)
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
w.Write([]byte("rate limited"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("success"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
c.now = func() time.Time { return fixedNow }
|
|
||||||
c.SetRetryBackoff([]time.Duration{0, 0})
|
|
||||||
|
|
||||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if string(body) != "success" {
|
|
||||||
t.Errorf("body = %q, want %q", body, "success")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_429_NoRetryAfter_UsesDefaultBackoff(t *testing.T) {
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
if attempts == 1 {
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
w.Write([]byte("rate limited"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("success"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
c.SetRetryBackoff([]time.Duration{0, 0})
|
|
||||||
|
|
||||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if string(body) != "success" {
|
|
||||||
t.Errorf("body = %q, want %q", body, "success")
|
|
||||||
}
|
|
||||||
if attempts != 2 {
|
|
||||||
t.Errorf("attempts = %d, want 2", attempts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_429_InvalidRetryAfter_UsesDefaultBackoff(t *testing.T) {
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
if attempts == 1 {
|
|
||||||
w.Header().Set("Retry-After", "not-a-number-or-date")
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
w.Write([]byte("rate limited"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("success"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
c.SetRetryBackoff([]time.Duration{0, 0})
|
|
||||||
|
|
||||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if string(body) != "success" {
|
|
||||||
t.Errorf("body = %q, want %q", body, "success")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_404_NoRetry(t *testing.T) {
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
w.Write([]byte("not found"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
_, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
|
||||||
if !IsNotFound(err) {
|
|
||||||
t.Errorf("expected IsNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
if attempts != 1 {
|
|
||||||
t.Errorf("attempts = %d, want 1 (no retry on 404)", attempts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_401_NoRetry(t *testing.T) {
|
|
||||||
attempts := 0
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
attempts++
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
w.Write([]byte("unauthorized"))
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
_, err := c.doGet(context.Background(), srv.URL+"/test")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
|
||||||
if !IsUnauthorized(err) {
|
|
||||||
t.Errorf("expected IsUnauthorized, got %v", err)
|
|
||||||
}
|
|
||||||
if attempts != 1 {
|
|
||||||
t.Errorf("attempts = %d, want 1 (no retry on 401)", attempts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDoRequest_ContextCanceled(t *testing.T) {
|
|
||||||
// This test exercises the timer-cancel path in the retry select:
|
|
||||||
// select { case <-timer.C; case <-ctx.Done() }
|
|
||||||
// The server returns 429 with a long Retry-After, and we cancel the
|
|
||||||
// context shortly after the first response so that cancellation races
|
|
||||||
// against the timer rather than preventing the initial HTTP round-trip.
|
|
||||||
requestReceived := make(chan struct{}, 1)
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
select {
|
|
||||||
case requestReceived <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
w.Header().Set("Retry-After", "10")
|
|
||||||
w.WriteHeader(http.StatusTooManyRequests)
|
|
||||||
}))
|
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
c := NewClient("tok", srv.URL)
|
|
||||||
c.SetRetryBackoff([]time.Duration{10 * time.Second, 10 * time.Second})
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Cancel the context after the first request completes, while the
|
|
||||||
// client is blocked in the retry timer select.
|
|
||||||
go func() {
|
|
||||||
<-requestReceived
|
|
||||||
// Small delay to ensure we're inside the timer select.
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err := c.doGet(ctx, srv.URL+"/test")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, context.Canceled) {
|
|
||||||
t.Errorf("err = %v, want context.Canceled", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_IntegerSeconds(t *testing.T) {
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
delay, ok := c.parseRetryAfter("42")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected ok=true")
|
|
||||||
}
|
|
||||||
if delay != 42*time.Second {
|
|
||||||
t.Errorf("delay = %v, want 42s", delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_ZeroSeconds(t *testing.T) {
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
delay, ok := c.parseRetryAfter("0")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected ok=true for zero seconds (RFC 7231 allows immediate retry)")
|
|
||||||
}
|
|
||||||
if delay != 0 {
|
|
||||||
t.Errorf("delay = %v, want 0", delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_NegativeSeconds(t *testing.T) {
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
_, ok := c.parseRetryAfter("-5")
|
|
||||||
if ok {
|
|
||||||
t.Error("expected ok=false for negative seconds")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_HTTPDate_Future(t *testing.T) {
|
|
||||||
fixedNow := time.Date(2025, 12, 1, 15, 59, 50, 0, time.UTC)
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
c.now = func() time.Time { return fixedNow }
|
|
||||||
|
|
||||||
delay, ok := c.parseRetryAfter("Mon, 01 Dec 2025 16:00:00 GMT")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected ok=true")
|
|
||||||
}
|
|
||||||
// Should be 10 seconds in the future.
|
|
||||||
if delay != 10*time.Second {
|
|
||||||
t.Errorf("delay = %v, want 10s", delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_HTTPDate_Past(t *testing.T) {
|
|
||||||
fixedNow := time.Date(2025, 12, 1, 17, 0, 0, 0, time.UTC)
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
c.now = func() time.Time { return fixedNow }
|
|
||||||
|
|
||||||
delay, ok := c.parseRetryAfter("Mon, 01 Dec 2025 16:00:00 GMT")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected ok=true")
|
|
||||||
}
|
|
||||||
if delay != 0 {
|
|
||||||
t.Errorf("delay = %v, want 0 (past date)", delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_RFC850_Format(t *testing.T) {
|
|
||||||
fixedNow := time.Date(2025, 12, 1, 15, 59, 50, 0, time.UTC)
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
c.now = func() time.Time { return fixedNow }
|
|
||||||
|
|
||||||
// RFC 850 format
|
|
||||||
delay, ok := c.parseRetryAfter("Monday, 01-Dec-25 16:00:00 GMT")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected ok=true for RFC 850 format")
|
|
||||||
}
|
|
||||||
if delay != 10*time.Second {
|
|
||||||
t.Errorf("delay = %v, want 10s", delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_Invalid(t *testing.T) {
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
_, ok := c.parseRetryAfter("not-valid")
|
|
||||||
if ok {
|
|
||||||
t.Error("expected ok=false for invalid value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_EmptyString(t *testing.T) {
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
_, ok := c.parseRetryAfter("")
|
|
||||||
if ok {
|
|
||||||
t.Error("expected ok=false for empty string")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseRetryAfter_MaxCap(t *testing.T) {
|
|
||||||
// Verify that parseRetryAfter returns the raw value (capping is done by caller).
|
|
||||||
c := NewClient("tok", "")
|
|
||||||
delay, ok := c.parseRetryAfter("3600")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected ok=true")
|
|
||||||
}
|
|
||||||
if delay != 3600*time.Second {
|
|
||||||
t.Errorf("delay = %v, want 3600s (caller is responsible for capping)", delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIError_Error_Truncation(t *testing.T) {
|
|
||||||
longBody := make([]byte, 300)
|
|
||||||
for i := range longBody {
|
|
||||||
longBody[i] = 'x'
|
|
||||||
}
|
|
||||||
apiErr := &APIError{StatusCode: 500, Body: string(longBody)}
|
|
||||||
msg := apiErr.Error()
|
|
||||||
if len(msg) > 250 {
|
|
||||||
// "HTTP 500: " (10) + 200 + "...(truncated)" (14) = 224
|
|
||||||
t.Errorf("error message too long: %d chars", len(msg))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAPIError_Error_NewlineSanitized(t *testing.T) {
|
|
||||||
apiErr := &APIError{StatusCode: 400, Body: "line1\nline2\rline3"}
|
|
||||||
msg := apiErr.Error()
|
|
||||||
for _, c := range msg {
|
|
||||||
if c == '\n' || c == '\r' {
|
|
||||||
t.Errorf("error message contains unsanitized newline: %q", msg)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+19
-69
@@ -5,7 +5,6 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -121,8 +120,10 @@ func ListBuiltinPersonas() []string {
|
|||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !seen[personaName] {
|
||||||
seen[personaName] = true
|
seen[personaName] = true
|
||||||
}
|
}
|
||||||
|
}
|
||||||
names := make([]string, 0, len(seen))
|
names := make([]string, 0, len(seen))
|
||||||
for name := range seen {
|
for name := range seen {
|
||||||
names = append(names, name)
|
names = append(names, name)
|
||||||
@@ -147,15 +148,6 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
dec := json.NewDecoder(bytes.NewReader(data))
|
dec := json.NewDecoder(bytes.NewReader(data))
|
||||||
dec.DisallowUnknownFields()
|
dec.DisallowUnknownFields()
|
||||||
err = dec.Decode(&p)
|
err = dec.Decode(&p)
|
||||||
if err == nil {
|
|
||||||
// Reject trailing content after the first valid JSON object.
|
|
||||||
// Without this check, input like `{"name":"x"}garbage` would
|
|
||||||
// silently succeed because Decoder stops after one object.
|
|
||||||
var dummy json.RawMessage
|
|
||||||
if err2 := dec.Decode(&dummy); err2 != io.EOF {
|
|
||||||
err = fmt.Errorf("unexpected trailing content after JSON object")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
||||||
@@ -166,10 +158,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks:
|
// unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
|
||||||
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion.
|
// and strict field checking. This protects against stack exhaustion from deeply
|
||||||
// - Multi-document rejection: prevents silent data loss from ignored extra documents.
|
// nested structures and catches typos in field names.
|
||||||
// - Strict field checking: rejects unknown YAML keys to catch typos early.
|
// Multi-document YAML files are rejected to prevent silent data loss.
|
||||||
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||||
// First pass: parse into AST to check depth limits, node counts, and
|
// First pass: parse into AST to check depth limits, node counts, and
|
||||||
// multi-document rejection. This prevents stack exhaustion before we
|
// multi-document rejection. This prevents stack exhaustion before we
|
||||||
@@ -198,18 +190,13 @@ func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
|||||||
|
|
||||||
// Second pass: decode with strict field checking enabled.
|
// Second pass: decode with strict field checking enabled.
|
||||||
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
|
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
|
||||||
//
|
|
||||||
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases
|
|
||||||
// recursively — it resolves them via the pre-built AST, which our first
|
|
||||||
// pass already depth-checked. Alias chains that would exceed depth limits
|
|
||||||
// are caught above; the decoder merely reads the resolved scalar values.
|
|
||||||
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
|
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
|
||||||
return dec.Decode(out)
|
return dec.Decode(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
|
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
|
||||||
// limit or the total node count limit. It uses two tracking maps:
|
// limit or the total node count limit. It uses two tracking maps:
|
||||||
// - validated: maps each node to the maximum depth at which it was previously
|
// - validated: maps each node to the minimum depth at which it was previously
|
||||||
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
|
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
|
||||||
// we re-check it to ensure the combined effective depth doesn't exceed limits.
|
// we re-check it to ensure the combined effective depth doesn't exceed limits.
|
||||||
// - visiting: per-path recursion stack for true cycle detection. A node on the
|
// - visiting: per-path recursion stack for true cycle detection. A node on the
|
||||||
@@ -227,6 +214,12 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
||||||
|
*nodeCount++
|
||||||
|
if *nodeCount > maxNodes {
|
||||||
|
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
||||||
|
}
|
||||||
|
|
||||||
// Cycle detection: if we're currently visiting this node on the current
|
// Cycle detection: if we're currently visiting this node on the current
|
||||||
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
|
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
|
||||||
// Return nil to break the cycle without error — cycles are a structural
|
// Return nil to break the cycle without error — cycles are a structural
|
||||||
@@ -235,28 +228,10 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
// Depth-aware short-circuit: only skip re-checking a node if we previously
|
||||||
// Placed after cycle detection but before the depth-aware short-circuit. This means
|
// validated it at the same or deeper effective depth. If this visit is at a
|
||||||
// nodes revisited at shallower depths (via aliases) are counted each time they are
|
// greater depth than before (e.g., alias referenced deeper in the tree),
|
||||||
// encountered — intentional conservative overcounting. This bounds the total work
|
// we must re-traverse to catch depth limit violations.
|
||||||
// performed during validation rather than tracking unique nodes, which is the safer
|
|
||||||
// security posture for untrusted YAML input.
|
|
||||||
*nodeCount++
|
|
||||||
if *nodeCount > maxNodes {
|
|
||||||
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Depth-aware short-circuit: skip re-validation only when the current visit
|
|
||||||
// depth is the same or shallower than the depth at which this node was
|
|
||||||
// previously validated. A shallower (or equal) current depth means the
|
|
||||||
// prior, deeper validation already covered any subtree depth violations.
|
|
||||||
// If the current depth exceeds the previous validation depth (e.g., an alias
|
|
||||||
// references this node deeper in the tree), we must re-traverse to ensure
|
|
||||||
// the combined effective depth doesn't exceed maxDepth.
|
|
||||||
//
|
|
||||||
// Note: using ast.Node (interface) as map key relies on pointer identity,
|
|
||||||
// which is correct because all goccy/go-yaml AST node types are pointer
|
|
||||||
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
|
|
||||||
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
|
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -275,11 +250,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *ast.MappingValueNode:
|
case *ast.MappingValueNode:
|
||||||
// Both Key and Value are visited at depth+1 relative to this
|
|
||||||
// MappingValueNode. Since MappingNode visits its MappingValueNode
|
|
||||||
// children at depth+1 as well, keys and values end up at depth+2
|
|
||||||
// from the parent MappingNode. This is intentional: it mirrors the
|
|
||||||
// actual nesting structure (mapping → key-value pair → key/value).
|
|
||||||
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -299,14 +269,6 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case *ast.AnchorNode:
|
case *ast.AnchorNode:
|
||||||
// Increment depth for anchor values as a conservative measure: the
|
|
||||||
// anchor definition itself is structural, and treating it as a depth
|
|
||||||
// level ensures that deeply nested anchors are caught at definition
|
|
||||||
// time rather than only when referenced via alias. This +1 is
|
|
||||||
// asymmetric with alias (which also increments) — by design, the
|
|
||||||
// effective depth budget for anchored-then-aliased content is reduced
|
|
||||||
// because both the definition site and the reference site each consume
|
|
||||||
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
|
|
||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -314,16 +276,8 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case *ast.MergeKeyNode:
|
// Scalar types (StringNode, IntegerNode, FloatNode, BoolNode, NullNode,
|
||||||
// MergeKeyNode represents the literal "<<" merge key token. It has no
|
// InfinityNode, NanNode, LiteralNode, MergeKeyNode) are leaf nodes.
|
||||||
// child nodes — the value side of a merge (e.g., *alias) lives in the
|
|
||||||
// parent MappingValueNode.Value, which is already recursed into above.
|
|
||||||
// Explicitly listed here (rather than in the default case) to prevent
|
|
||||||
// future library changes from silently bypassing depth checks.
|
|
||||||
default:
|
|
||||||
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
|
|
||||||
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
|
|
||||||
// recurse into.
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -331,11 +285,7 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
||||||
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
||||||
// without requiring filesystem access. Format is detected by source extension.
|
// without requiring filesystem access. Format is detected by source extension.
|
||||||
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
|
|
||||||
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
||||||
if len(data) > MaxPersonaFileSize {
|
|
||||||
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
|
|
||||||
}
|
|
||||||
return parsePersona(data, source)
|
return parsePersona(data, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+10
-113
@@ -459,14 +459,8 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
|||||||
path := filepath.Join(dir, "deeply-nested.yaml")
|
path := filepath.Join(dir, "deeply-nested.yaml")
|
||||||
|
|
||||||
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
||||||
// Depth accumulation trace for "nested: \n level0: \n level1: ...":
|
// Each nested mapping key generates a MappingValueNode, incrementing depth
|
||||||
// - Document root parsed at depth 0
|
// by 1 per level in the AST walk. With 25 levels, we exceed MaxYAMLDepth (20).
|
||||||
// - Root MappingNode children (MappingValueNodes) visited at depth 1
|
|
||||||
// - "nested" MappingValueNode: key at depth 2, value at depth 2
|
|
||||||
// - Each levelN adds depth via MappingValueNode traversal (key + value)
|
|
||||||
// - Exact depth per level depends on AST structure (MappingNode wrapping),
|
|
||||||
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
|
|
||||||
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
||||||
indent := " "
|
indent := " "
|
||||||
@@ -490,19 +484,21 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func TestYAMLEmptyFileRejection(t *testing.T) {
|
func TestYAMLEmptyFileRejection(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
content string
|
content string
|
||||||
}{
|
}{
|
||||||
{"completely_empty", ""},
|
{"completely empty", ""},
|
||||||
{"whitespace_only", " \n\n "},
|
{"whitespace only", " \n\n "},
|
||||||
{"comment_only", "# just a comment\n"},
|
{"comment only", "# just a comment\n"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, tc.name+".yaml")
|
path := filepath.Join(dir, tc.name+".yaml")
|
||||||
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
|
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
|
||||||
t.Fatalf("failed to write test file: %v", err)
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
@@ -510,9 +506,9 @@ func TestYAMLEmptyFileRejection(t *testing.T) {
|
|||||||
|
|
||||||
_, err := LoadPersona(path)
|
_, err := LoadPersona(path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for empty YAML input, got nil")
|
t.Error("expected error for empty YAML input, got nil")
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), "empty YAML document") {
|
if err != nil && !strings.Contains(err.Error(), "empty YAML document") {
|
||||||
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
|
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -858,102 +854,3 @@ identity: test identity
|
|||||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
t.Errorf("Name = %q, want %q", p.Name, "test")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJSONTrailingContentRejected(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
content string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "trailing garbage after object",
|
|
||||||
content: `{"name":"test","identity":"test identity"}garbage`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "two JSON objects",
|
|
||||||
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trailing array",
|
|
||||||
content: `{"name":"test","identity":"test identity"}[]`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, "test.json")
|
|
||||||
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
|
|
||||||
t.Fatalf("failed to write test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := LoadPersona(path)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for trailing content, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "trailing content") {
|
|
||||||
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsePersonaBytesSizeLimit(t *testing.T) {
|
|
||||||
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
|
|
||||||
oversized := make([]byte, MaxPersonaFileSize+1)
|
|
||||||
for i := range oversized {
|
|
||||||
oversized[i] = 'x'
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for oversized input, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "exceeds maximum size") {
|
|
||||||
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Just under the limit should not trigger size error (may fail parse, but not size)
|
|
||||||
underLimit := []byte("name: test\nidentity: test persona\n")
|
|
||||||
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for valid input: %v", err)
|
|
||||||
}
|
|
||||||
if p.Name != "test" {
|
|
||||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
|
|
||||||
// Verify that YAML merge keys (<<: *alias) are properly handled by the
|
|
||||||
// depth checker. The merge key content is in the MappingValueNode.Value
|
|
||||||
// (an AliasNode), not in the MergeKeyNode itself.
|
|
||||||
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("basic parse failed: %v", err)
|
|
||||||
}
|
|
||||||
if p.Name != "merge-test" {
|
|
||||||
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that deeply nested merge keys still hit depth limit.
|
|
||||||
// Build YAML with merge key content nested beyond MaxYAMLDepth.
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
|
|
||||||
sb.WriteString("anchor: &deep\n")
|
|
||||||
indent := " "
|
|
||||||
for i := 0; i < MaxYAMLDepth+5; i++ {
|
|
||||||
sb.WriteString(indent)
|
|
||||||
sb.WriteString(fmt.Sprintf("level%d:\n", i))
|
|
||||||
indent += " "
|
|
||||||
}
|
|
||||||
sb.WriteString(indent + "leaf: value\n")
|
|
||||||
sb.WriteString("target:\n <<: *deep\n")
|
|
||||||
|
|
||||||
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for deeply nested merge key content, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "depth") {
|
|
||||||
t.Errorf("error = %q, want to contain 'depth'", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user