diff --git a/gitea/client.go b/gitea/client.go index 55835ac..32c3cae 100644 --- a/gitea/client.go +++ b/gitea/client.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log/slog" + "math" "net" "net/http" "net/url" @@ -47,6 +48,12 @@ func IsServerError(err error) bool { return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600 } +// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB). +const DefaultMaxDiffSize = 10 * 1024 * 1024 + +// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize. +var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size") + // Client interacts with the Gitea API. // A Client is safe for concurrent use by multiple goroutines. type Client struct { @@ -61,6 +68,14 @@ type Client struct { // This field must be configured before the first request is made. // Modifying it while requests are in flight is not safe. RetryBackoff []time.Duration + + // MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff. + // If zero, defaults to DefaultMaxDiffSize (10 MB). Set to any negative value + // (or math.MaxInt64) to disable the limit. + // + // This field must be configured before the first request is made. + // Modifying it while requests are in flight is not safe. + MaxDiffSize int64 } // NewClient creates a new Gitea API client. @@ -125,9 +140,28 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number } // GetPullRequestDiff fetches the unified diff for a PR. +// It enforces MaxDiffSize to prevent unbounded memory allocation. +// Returns ErrDiffTooLarge if the diff exceeds the configured limit. 
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) - body, err := c.doGet(ctx, reqURL) + + maxSize := c.MaxDiffSize + if maxSize == 0 { + maxSize = DefaultMaxDiffSize + } + + // When the limit is disabled (negative) or set to math.MaxInt64 (which + // would overflow the +1 detection and silently disable enforcement), + // use the standard unlimited doGet path. + if maxSize < 0 || maxSize == math.MaxInt64 { + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil + } + + body, err := c.doGetLimited(ctx, reqURL, maxSize) if err != nil { return "", fmt.Errorf("fetch diff: %w", err) } @@ -292,9 +326,9 @@ func isRetriableSyscallError(err error) bool { return true } -// redactURL strips query parameters from a URL for safe logging. -// This prevents accidental exposure of sensitive data that future callers -// might pass via query strings. +// redactURL masks query parameters and userinfo credentials in a URL for +// safe logging. This prevents accidental exposure of sensitive data (tokens in +// query strings, or user:pass in the authority) in log output. func redactURL(rawURL string) string { parsed, err := url.Parse(rawURL) if err != nil { @@ -302,6 +336,9 @@ func redactURL(rawURL string) string { // potentially logging something sensitive. return "[invalid URL]" } + if parsed.User != nil { + parsed.User = url.User("REDACTED") + } if parsed.RawQuery != "" { parsed.RawQuery = "[redacted]" } @@ -322,10 +359,12 @@ func sanitizeErrorForLog(err error) string { return err.Error() } -// doGet performs an HTTP GET request with retry on 5xx errors and temporary -// network errors. Retries up to 3 times with exponential backoff (1s, 2s delays -// by default; configurable via Client.RetryBackoff for testing). 
-func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { +// doGetWithReader performs an HTTP GET request with retry on 5xx errors and +// temporary network errors. Makes up to 3 attempts with exponential backoff +// (1s, 2s delays by default; configurable via Client.RetryBackoff for testing). +// The readBody function is called with the response body on success (2xx) and +// is responsible for reading and closing it. +func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody func(io.ReadCloser) ([]byte, error)) ([]byte, error) { const maxAttempts = 3 // backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails). // First attempt (i=0) has no delay; retries wait 1s then 2s by default. @@ -390,12 +429,7 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { return nil, lastErr } if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return nil, err - } - return body, nil + return readBody(resp.Body) } // Error path: limit how much we read from potentially malicious server @@ -413,6 +447,39 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { return nil, lastErr } +// doGet performs an HTTP GET request with retry, reading the full response body. +func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { + return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) { + defer body.Close() + return io.ReadAll(body) + }) +} + +// doGetLimited performs an HTTP GET request with retry but enforces a maximum +// response body size. Returns ErrDiffTooLarge if the response exceeds maxBytes. +// It reads maxBytes+1 (clamped to avoid overflow) to detect truncation without +// buffering the entire body. 
+func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) { + return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) { + defer body.Close() + // Read up to maxBytes+1 to detect overflow. + // Clamp to prevent integer overflow when maxBytes == math.MaxInt64. + limitBytes := maxBytes + 1 + if limitBytes <= 0 { + limitBytes = math.MaxInt64 + } + limited := io.LimitReader(body, limitBytes) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("%w: response exceeds %d bytes", ErrDiffTooLarge, maxBytes) + } + return data, nil + }) +} + // escapePath escapes each segment of a relative file path for use in URLs. // Slashes are preserved as path separators; other special characters are escaped. // Input should be a relative path (no leading slash). Already-encoded segments diff --git a/gitea/client_test.go b/gitea/client_test.go index 2637c2e..eec0b7d 100644 --- a/gitea/client_test.go +++ b/gitea/client_test.go @@ -1092,6 +1092,21 @@ func TestRedactURL(t *testing.T) { input: "", want: "", }, + { + name: "with userinfo - redacts credentials", + input: "https://admin:secret@gitea.example.com/api/v1/repos", + want: "https://REDACTED@gitea.example.com/api/v1/repos", + }, + { + name: "with userinfo and query params", + input: "https://user:pass@example.com/path?token=abc", + want: "https://REDACTED@example.com/path?[redacted]", + }, + { + name: "username only - no password", + input: "https://user@example.com/path", + want: "https://REDACTED@example.com/path", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/gitea/diff_size_test.go b/gitea/diff_size_test.go new file mode 100644 index 0000000..005f87c --- /dev/null +++ b/gitea/diff_size_test.go @@ -0,0 +1,97 @@ +package gitea + +import ( + "context" + "errors" + "math" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func 
TestGetPullRequestDiff_SizeLimits(t *testing.T) { + tests := []struct { + name string + diff string + maxDiffSize int64 + wantErr error + wantDiff string + }{ + { + name: "exceeds max size", + diff: strings.Repeat("+ added line\n", 1000), // ~13 KB + maxDiffSize: 100, + wantErr: ErrDiffTooLarge, + }, + { + name: "within max size", + diff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n", + maxDiffSize: 1024, + wantDiff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n", + }, + { + name: "exactly at limit", + diff: strings.Repeat("x", 50), + maxDiffSize: 50, + wantDiff: strings.Repeat("x", 50), + }, + { + name: "one byte over limit", + diff: strings.Repeat("x", 51), + maxDiffSize: 50, + wantErr: ErrDiffTooLarge, + }, + { + name: "disabled limit", + diff: strings.Repeat("x", 10000), + maxDiffSize: -1, + wantDiff: strings.Repeat("x", 10000), + }, + { + name: "math.MaxInt64 treated as disabled", + diff: strings.Repeat("x", 10000), + maxDiffSize: math.MaxInt64, + wantDiff: strings.Repeat("x", 10000), + }, + { + name: "default limit", + diff: "diff content", + maxDiffSize: 0, // zero means use DefaultMaxDiffSize + wantDiff: "diff content", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(tt.diff)) //nolint:errcheck // test handler + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + client.MaxDiffSize = tt.maxDiffSize + client.RetryBackoff = []time.Duration{} + + got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + + if tt.wantErr != nil { + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected %v, got: %v", tt.wantErr, err) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantDiff { + t.Errorf("diff mismatch: got 
length %d, want length %d", len(got), len(tt.wantDiff)) + } + }) + } +}