feat(gitea): harden GetPullRequestDiff against unbounded diff size
Add a configurable MaxDiffSize field to Client that limits how much data GetPullRequestDiff will read into memory. The default is 10 MB (DefaultMaxDiffSize). When the diff exceeds the limit, ErrDiffTooLarge is returned, allowing callers to skip position translation gracefully. Implementation uses io.LimitReader to read maxBytes+1, detecting overflow without buffering the entire response. Setting MaxDiffSize to -1 disables the limit entirely. Closes #92
This commit is contained in:
+108
-1
@@ -47,6 +47,12 @@ func IsServerError(err error) bool {
|
||||
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
|
||||
}
|
||||
|
||||
// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB).
// GetPullRequestDiff applies it when Client.MaxDiffSize is left at zero.
|
||||
const DefaultMaxDiffSize = 10 * 1024 * 1024
|
||||
|
||||
// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize.
// It is returned wrapped with additional context, so callers should test for
// it with errors.Is rather than comparing with ==.
|
||||
var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size")
|
||||
|
||||
// Client interacts with the Gitea API.
|
||||
// A Client is safe for concurrent use by multiple goroutines.
|
||||
type Client struct {
|
||||
@@ -61,6 +67,10 @@ type Client struct {
|
||||
// This field must be configured before the first request is made.
|
||||
// Modifying it while requests are in flight is not safe.
|
||||
RetryBackoff []time.Duration
|
||||
|
||||
// MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff.
|
||||
// If zero, defaults to DefaultMaxDiffSize (10 MB). Set to -1 to disable the limit.
|
||||
MaxDiffSize int64
|
||||
}
|
||||
|
||||
// NewClient creates a new Gitea API client.
|
||||
@@ -125,9 +135,26 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
|
||||
}
|
||||
|
||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||
// It enforces MaxDiffSize to prevent unbounded memory allocation.
|
||||
// Returns ErrDiffTooLarge if the diff exceeds the configured limit.
|
||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
|
||||
maxSize := c.MaxDiffSize
|
||||
if maxSize == 0 {
|
||||
maxSize = DefaultMaxDiffSize
|
||||
}
|
||||
|
||||
// When the limit is disabled, use the standard doGet path.
|
||||
if maxSize < 0 {
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
body, err := c.doGetLimited(ctx, reqURL, maxSize)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
@@ -413,6 +440,86 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// doGetLimited performs an HTTP GET request with retry (like doGet) but enforces
|
||||
// a maximum response body size. Returns ErrDiffTooLarge if the response exceeds
|
||||
// maxBytes. It reads maxBytes+1 to detect overflow without buffering the entire body.
|
||||
func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) {
|
||||
const maxAttempts = 3
|
||||
backoff := c.RetryBackoff
|
||||
if backoff == nil {
|
||||
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
|
||||
}
|
||||
const maxErrorBodyBytes = 64 * 1024
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
var delay time.Duration
|
||||
if attempt-1 < len(backoff) {
|
||||
delay = backoff[attempt-1]
|
||||
}
|
||||
if delay > 0 {
|
||||
slog.Warn("retrying request after error",
|
||||
"attempt", attempt+1,
|
||||
"url", redactURL(reqURL),
|
||||
"delay", delay.String(),
|
||||
"lastError", sanitizeErrorForLog(lastErr))
|
||||
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "token "+c.token)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
if attempt < maxAttempts-1 && isTemporaryNetError(err) {
|
||||
slog.Warn("temporary network error, will retry",
|
||||
"attempt", attempt+1,
|
||||
"url", redactURL(reqURL),
|
||||
"error", err)
|
||||
continue
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
// Read up to maxBytes+1 to detect overflow.
|
||||
limited := io.LimitReader(resp.Body, maxBytes+1)
|
||||
body, err := io.ReadAll(limited)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, fmt.Errorf("%w: response is larger than %d bytes", ErrDiffTooLarge, maxBytes)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
||||
resp.Body.Close()
|
||||
|
||||
lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
||||
|
||||
if resp.StatusCode < 500 || resp.StatusCode >= 600 {
|
||||
return nil, lastErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||
// Slashes are preserved as path separators; other special characters are escaped.
|
||||
// Input should be a relative path (no leading slash). Already-encoded segments
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetPullRequestDiff_ExceedsMaxSize(t *testing.T) {
|
||||
// Create a diff that exceeds a small limit
|
||||
largeDiff := strings.Repeat("+ added line\n", 1000) // ~13 KB
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(largeDiff))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = 100 // 100 bytes limit
|
||||
client.RetryBackoff = []time.Duration{} // no delay in tests
|
||||
|
||||
_, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized diff, got nil")
|
||||
}
|
||||
if !errors.Is(err, ErrDiffTooLarge) {
|
||||
t.Errorf("expected ErrDiffTooLarge, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestDiff_WithinMaxSize(t *testing.T) {
|
||||
smallDiff := "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(smallDiff))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = 1024 // 1 KB limit — more than enough
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != smallDiff {
|
||||
t.Errorf("expected diff %q, got %q", smallDiff, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestDiff_ExactlyAtLimit(t *testing.T) {
|
||||
// A diff that is exactly at the limit should succeed
|
||||
exactDiff := strings.Repeat("x", 50)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(exactDiff))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = 50 // exactly the size of the diff
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for diff at exact limit: %v", err)
|
||||
}
|
||||
if got != exactDiff {
|
||||
t.Errorf("expected diff to match, got length %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestDiff_OneByteOverLimit(t *testing.T) {
|
||||
// A diff that is one byte over the limit should fail
|
||||
overDiff := strings.Repeat("x", 51)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(overDiff))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = 50
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
_, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for diff one byte over limit")
|
||||
}
|
||||
if !errors.Is(err, ErrDiffTooLarge) {
|
||||
t.Errorf("expected ErrDiffTooLarge, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestDiff_DisabledLimit(t *testing.T) {
|
||||
// When MaxDiffSize is -1, no limit is enforced
|
||||
largeDiff := strings.Repeat("x", 10000)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(largeDiff))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = -1 // disabled
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with disabled limit: %v", err)
|
||||
}
|
||||
if got != largeDiff {
|
||||
t.Errorf("expected full diff with disabled limit, got length %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestDiff_DefaultLimit(t *testing.T) {
|
||||
// With zero MaxDiffSize (default), should use DefaultMaxDiffSize.
|
||||
// A small diff should succeed without setting MaxDiffSize.
|
||||
smallDiff := "diff content"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(smallDiff))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
// MaxDiffSize is zero (default) — should use DefaultMaxDiffSize (10 MB)
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with default limit: %v", err)
|
||||
}
|
||||
if got != smallDiff {
|
||||
t.Errorf("expected diff %q, got %q", smallDiff, got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user