feat(gitea): harden GetPullRequestDiff against unbounded diff size

Add a configurable MaxDiffSize field to Client that limits how much
data GetPullRequestDiff will read into memory. The default is 10 MB
(DefaultMaxDiffSize). When the diff exceeds the limit, ErrDiffTooLarge
is returned, allowing callers to skip position translation gracefully.

Implementation uses io.LimitReader to read maxBytes+1, detecting
overflow without buffering the entire response. Setting MaxDiffSize
to -1 disables the limit entirely.

Closes #92
This commit is contained in:
claw
2026-05-13 04:57:30 -07:00
committed by Aaron Weiker
parent 881ce232eb
commit 92b84976cf
2 changed files with 251 additions and 1 deletions
+108 -1
View File
@@ -47,6 +47,12 @@ func IsServerError(err error) bool {
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
}
// DefaultMaxDiffSize is the diff size cap, in bytes, applied when
// Client.MaxDiffSize is left at zero: 10 MB (10 * 1024 * 1024 bytes).
const DefaultMaxDiffSize = 10 << 20
// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize.
// It is returned wrapped with additional context, so callers should match it
// with errors.Is rather than comparing with ==.
var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size")
// Client interacts with the Gitea API.
// A Client is safe for concurrent use by multiple goroutines.
type Client struct {
@@ -61,6 +67,10 @@ type Client struct {
// This field must be configured before the first request is made.
// Modifying it while requests are in flight is not safe.
RetryBackoff []time.Duration
// MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff.
// If zero, defaults to DefaultMaxDiffSize (10 MB). Set to -1 to disable the limit.
MaxDiffSize int64
}
// NewClient creates a new Gitea API client.
@@ -125,9 +135,26 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
}
// GetPullRequestDiff fetches the unified diff for a PR.
// It enforces MaxDiffSize to prevent unbounded memory allocation.
// Returns ErrDiffTooLarge if the diff exceeds the configured limit.
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL)
maxSize := c.MaxDiffSize
if maxSize == 0 {
maxSize = DefaultMaxDiffSize
}
// When the limit is disabled, use the standard doGet path.
if maxSize < 0 {
body, err := c.doGet(ctx, reqURL)
if err != nil {
return "", fmt.Errorf("fetch diff: %w", err)
}
return string(body), nil
}
body, err := c.doGetLimited(ctx, reqURL, maxSize)
if err != nil {
return "", fmt.Errorf("fetch diff: %w", err)
}
@@ -413,6 +440,86 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
return nil, lastErr
}
// doGetLimited performs an authenticated HTTP GET with retries (like doGet) but
// refuses to buffer more than maxBytes of a successful response body.
//
// Up to three attempts are made. A temporary network error or a 5xx status
// triggers another attempt after the matching entry of c.RetryBackoff (1s then
// 2s when RetryBackoff is nil; no wait when the entry is missing or zero). Any
// other transport error or non-2xx status is returned immediately; non-2xx
// statuses are reported as *APIError carrying at most 64 KiB of the body.
//
// On a 2xx response at most maxBytes+1 bytes are read, so an oversized body is
// detected without buffering all of it. In that case the returned error wraps
// ErrDiffTooLarge. A body of exactly maxBytes bytes is accepted.
//
// maxBytes is expected to be non-negative; GetPullRequestDiff routes negative
// (disabled) limits to doGet instead of calling this function.
func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) {
	const (
		maxAttempts       = 3
		maxErrorBodyBytes = 64 * 1024
		maxInt64          = 1<<63 - 1
	)

	backoff := c.RetryBackoff
	if backoff == nil {
		backoff = []time.Duration{1 * time.Second, 2 * time.Second}
	}

	// Read one byte past the limit so "exactly maxBytes" and "more than
	// maxBytes" can be told apart. maxBytes+1 wraps to a negative number when
	// maxBytes is the largest int64, and io.LimitReader treats a non-positive
	// limit as immediate EOF, which would turn every response into an empty
	// success. Clamp instead: a body longer than MaxInt64 cannot be buffered
	// anyway, so no overflow detection is lost.
	readLimit := maxBytes
	if maxBytes < maxInt64 {
		readLimit = maxBytes + 1
	}

	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			// Wait before a retry; attempts beyond the configured backoff
			// entries retry without delay.
			var delay time.Duration
			if attempt-1 < len(backoff) {
				delay = backoff[attempt-1]
			}
			if delay > 0 {
				slog.Warn("retrying request after error",
					"attempt", attempt+1,
					"url", redactURL(reqURL),
					"delay", delay.String(),
					"lastError", sanitizeErrorForLog(lastErr))
				timer := time.NewTimer(delay)
				select {
				case <-timer.C:
				case <-ctx.Done():
					timer.Stop()
					return nil, ctx.Err()
				}
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
		if err != nil {
			return nil, err
		}
		req.Header.Set("Authorization", "token "+c.token)

		resp, err := c.http.Do(req)
		if err != nil {
			lastErr = err
			if attempt < maxAttempts-1 && isTemporaryNetError(err) {
				slog.Warn("temporary network error, will retry",
					"attempt", attempt+1,
					"url", redactURL(reqURL),
					"error", err)
				continue
			}
			return nil, lastErr
		}

		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			body, err := io.ReadAll(io.LimitReader(resp.Body, readLimit))
			resp.Body.Close()
			if err != nil {
				return nil, err
			}
			if int64(len(body)) > maxBytes {
				return nil, fmt.Errorf("%w: response is larger than %d bytes", ErrDiffTooLarge, maxBytes)
			}
			return body, nil
		}

		// Error responses are capped too, so a misbehaving server cannot force
		// a large allocation through the error path. The read error is ignored
		// on purpose: a partial body is still useful in the APIError.
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		resp.Body.Close()
		lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}

		// Only 5xx responses are worth retrying.
		if resp.StatusCode < 500 || resp.StatusCode >= 600 {
			return nil, lastErr
		}
	}
	return nil, lastErr
}
// escapePath escapes each segment of a relative file path for use in URLs.
// Slashes are preserved as path separators; other special characters are escaped.
// Input should be a relative path (no leading slash). Already-encoded segments
+143
View File
@@ -0,0 +1,143 @@
package gitea
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestGetPullRequestDiff_ExceedsMaxSize checks that a response far larger than
// MaxDiffSize is rejected with an error matching ErrDiffTooLarge.
func TestGetPullRequestDiff_ExceedsMaxSize(t *testing.T) {
	// 1000 copies of a 13-byte line: about 13 KB, well past the 100-byte cap.
	payload := []byte(strings.Repeat("+ added line\n", 1000))

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Write(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	c := NewClient(srv.URL, "test-token")
	c.MaxDiffSize = 100
	// Non-nil but empty, so the client's built-in retry delays are not used.
	c.RetryBackoff = []time.Duration{}

	_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
	switch {
	case err == nil:
		t.Fatal("expected error for oversized diff, got nil")
	case !errors.Is(err, ErrDiffTooLarge):
		t.Errorf("expected ErrDiffTooLarge, got: %v", err)
	}
}
// TestGetPullRequestDiff_WithinMaxSize checks that a diff comfortably below
// MaxDiffSize is returned unchanged.
func TestGetPullRequestDiff_WithinMaxSize(t *testing.T) {
	const want = "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n"

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(want))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	c := NewClient(srv.URL, "test-token")
	c.MaxDiffSize = 1024 // 1 KB, far more than the fixture needs
	c.RetryBackoff = []time.Duration{}

	got, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != want {
		t.Errorf("expected diff %q, got %q", want, got)
	}
}
// TestGetPullRequestDiff_ExactlyAtLimit pins the inclusive boundary: a diff
// whose length equals MaxDiffSize must be accepted.
func TestGetPullRequestDiff_ExactlyAtLimit(t *testing.T) {
	const limit = 50
	want := strings.Repeat("x", limit)

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(want))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	c := NewClient(srv.URL, "test-token")
	c.MaxDiffSize = limit // same as the body length
	c.RetryBackoff = []time.Duration{}

	got, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error for diff at exact limit: %v", err)
	}
	if got != want {
		t.Errorf("expected diff to match, got length %d", len(got))
	}
}
// TestGetPullRequestDiff_OneByteOverLimit pins the other side of the boundary:
// a diff a single byte longer than MaxDiffSize must be rejected.
func TestGetPullRequestDiff_OneByteOverLimit(t *testing.T) {
	const limit = 50
	payload := []byte(strings.Repeat("x", limit+1))

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Write(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	c := NewClient(srv.URL, "test-token")
	c.MaxDiffSize = limit
	c.RetryBackoff = []time.Duration{}

	_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
	switch {
	case err == nil:
		t.Fatal("expected error for diff one byte over limit")
	case !errors.Is(err, ErrDiffTooLarge):
		t.Errorf("expected ErrDiffTooLarge, got: %v", err)
	}
}
// TestGetPullRequestDiff_DisabledLimit checks that MaxDiffSize = -1 switches
// the size check off and the whole diff is returned.
func TestGetPullRequestDiff_DisabledLimit(t *testing.T) {
	want := strings.Repeat("x", 10000)

	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(want))
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	c := NewClient(srv.URL, "test-token")
	c.MaxDiffSize = -1 // negative means "no limit"
	c.RetryBackoff = []time.Duration{}

	got, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error with disabled limit: %v", err)
	}
	if got != want {
		t.Errorf("expected full diff with disabled limit, got length %d", len(got))
	}
}
func TestGetPullRequestDiff_DefaultLimit(t *testing.T) {
// With zero MaxDiffSize (default), should use DefaultMaxDiffSize.
// A small diff should succeed without setting MaxDiffSize.
smallDiff := "diff content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(smallDiff))
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
// MaxDiffSize is zero (default) — should use DefaultMaxDiffSize (10 MB)
client.RetryBackoff = []time.Duration{}
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error with default limit: %v", err)
}
if got != smallDiff {
t.Errorf("expected diff %q, got %q", smallDiff, got)
}
}