diff --git a/gitea/client.go b/gitea/client.go index 55835ac..acd710c 100644 --- a/gitea/client.go +++ b/gitea/client.go @@ -47,6 +47,12 @@ func IsServerError(err error) bool { return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600 } +// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB). +const DefaultMaxDiffSize = 10 * 1024 * 1024 + +// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize. +var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size") + // Client interacts with the Gitea API. // A Client is safe for concurrent use by multiple goroutines. type Client struct { @@ -61,6 +67,10 @@ type Client struct { // This field must be configured before the first request is made. // Modifying it while requests are in flight is not safe. RetryBackoff []time.Duration + + // MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff. + // If zero, defaults to DefaultMaxDiffSize (10 MB). Set to -1 to disable the limit. + MaxDiffSize int64 } // NewClient creates a new Gitea API client. @@ -125,9 +135,26 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number } // GetPullRequestDiff fetches the unified diff for a PR. +// It enforces MaxDiffSize to prevent unbounded memory allocation. +// Returns ErrDiffTooLarge if the diff exceeds the configured limit. func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) - body, err := c.doGet(ctx, reqURL) + + maxSize := c.MaxDiffSize + if maxSize == 0 { + maxSize = DefaultMaxDiffSize + } + + // When the limit is disabled, use the standard doGet path. + if maxSize < 0 { + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil + } + + body, err := c.doGetLimited(ctx, reqURL, maxSize) if err != nil { return "", fmt.Errorf("fetch diff: %w", err) } @@ -413,6 +440,86 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { return nil, lastErr } +// doGetLimited performs an HTTP GET request with retry (like doGet) but enforces +// a maximum response body size. Returns ErrDiffTooLarge if the response exceeds +// maxBytes. It reads maxBytes+1 to detect overflow without buffering the entire body. +func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) { + const maxAttempts = 3 + backoff := c.RetryBackoff + if backoff == nil { + backoff = []time.Duration{1 * time.Second, 2 * time.Second} + } + const maxErrorBodyBytes = 64 * 1024 + + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + var delay time.Duration + if attempt-1 < len(backoff) { + delay = backoff[attempt-1] + } + if delay > 0 { + slog.Warn("retrying request after error", + "attempt", attempt+1, + "url", redactURL(reqURL), + "delay", delay.String(), + "lastError", sanitizeErrorForLog(lastErr)) + + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + } + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "token "+c.token) + + resp, err := c.http.Do(req) + if err != nil { + lastErr = err + if attempt < maxAttempts-1 && isTemporaryNetError(err) { + slog.Warn("temporary network error, will retry", + "attempt", attempt+1, + "url", redactURL(reqURL), + "error", err) + continue + } + return nil, lastErr + } + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + // Read up to maxBytes+1 to detect overflow. + limited := io.LimitReader(resp.Body, maxBytes+1) + body, err := io.ReadAll(limited) + resp.Body.Close() + if err != nil { + return nil, err + } + if int64(len(body)) > maxBytes { + return nil, fmt.Errorf("%w: response is larger than %d bytes", ErrDiffTooLarge, maxBytes) + } + return body, nil + } + + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + resp.Body.Close() + + lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} + + if resp.StatusCode < 500 || resp.StatusCode >= 600 { + return nil, lastErr + } + } + + return nil, lastErr +} + // escapePath escapes each segment of a relative file path for use in URLs. // Slashes are preserved as path separators; other special characters are escaped. // Input should be a relative path (no leading slash). Already-encoded segments diff --git a/gitea/diff_size_test.go b/gitea/diff_size_test.go new file mode 100644 index 0000000..6601143 --- /dev/null +++ b/gitea/diff_size_test.go @@ -0,0 +1,143 @@ +package gitea + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetPullRequestDiff_ExceedsMaxSize(t *testing.T) { + // Create a diff that exceeds a small limit + largeDiff := strings.Repeat("+ added line\n", 1000) // ~13 KB + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(largeDiff)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + client.MaxDiffSize = 100 // 100 bytes limit + client.RetryBackoff = []time.Duration{} // no delay in tests + + _, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for oversized diff, got nil") + } + if !errors.Is(err, ErrDiffTooLarge) { + t.Errorf("expected ErrDiffTooLarge, got: %v", err) + } +} + +func TestGetPullRequestDiff_WithinMaxSize(t *testing.T) { + smallDiff := "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(smallDiff)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + client.MaxDiffSize = 1024 // 1 KB limit — more than enough + client.RetryBackoff = []time.Duration{} + + got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != smallDiff { + t.Errorf("expected diff %q, got %q", smallDiff, got) + } +} + +func TestGetPullRequestDiff_ExactlyAtLimit(t *testing.T) { + // A diff that is exactly at the limit should succeed + exactDiff := strings.Repeat("x", 50) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(exactDiff)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + client.MaxDiffSize = 50 // exactly the size of the diff + client.RetryBackoff = []time.Duration{} + + got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error for diff at exact limit: %v", err) + } + if got != exactDiff { + t.Errorf("expected diff to match, got length %d", len(got)) + } +} + +func TestGetPullRequestDiff_OneByteOverLimit(t *testing.T) { + // A diff that is one byte over the limit should fail + overDiff := strings.Repeat("x", 51) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(overDiff)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + client.MaxDiffSize = 50 + client.RetryBackoff = []time.Duration{} + + _, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for diff one byte over limit") + } + if !errors.Is(err, ErrDiffTooLarge) { + t.Errorf("expected ErrDiffTooLarge, got: %v", err) + } +} + +func TestGetPullRequestDiff_DisabledLimit(t *testing.T) { + // When MaxDiffSize is -1, no limit is enforced + largeDiff := strings.Repeat("x", 10000) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(largeDiff)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + client.MaxDiffSize = -1 // disabled + client.RetryBackoff = []time.Duration{} + + got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error with disabled limit: %v", err) + } + if got != largeDiff { + t.Errorf("expected full diff with disabled limit, got length %d", len(got)) + } +} + +func TestGetPullRequestDiff_DefaultLimit(t *testing.T) { + // With zero MaxDiffSize (default), should use DefaultMaxDiffSize. + // A small diff should succeed without setting MaxDiffSize. + smallDiff := "diff content" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(smallDiff)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + // MaxDiffSize is zero (default) — should use DefaultMaxDiffSize (10 MB) + client.RetryBackoff = []time.Duration{} + + got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error with default limit: %v", err) + } + if got != smallDiff { + t.Errorf("expected diff %q, got %q", smallDiff, got) + } +}