diff --git a/gitea/client.go b/gitea/client.go index 150b3c0..b016349 100644 --- a/gitea/client.go +++ b/gitea/client.go @@ -11,9 +11,11 @@ import ( "fmt" "io" "log/slog" + "net" "net/http" "net/url" "strings" + "syscall" "time" ) @@ -39,12 +41,26 @@ func IsNotFound(err error) bool { return errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusNotFound } +// IsServerError reports whether an error is an API 5xx response. +func IsServerError(err error) bool { + var apiErr *APIError + return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600 +} + // Client interacts with the Gitea API. // A Client is safe for concurrent use by multiple goroutines. type Client struct { baseURL string token string http *http.Client + + // RetryBackoff defines the delays between retry attempts. + // RetryBackoff[i] is the delay before attempt i+1 (after attempt i fails). + // If nil, defaults to {1s, 2s}. Set to shorter durations in tests. + // + // This field must be configured before the first request is made. + // Modifying it while requests are in flight is not safe. + RetryBackoff []time.Duration } // NewClient creates a new Gitea API client. @@ -56,6 +72,12 @@ func NewClient(baseURL, token string) *Client { } } +// SetHTTPClient sets the underlying HTTP client used for requests. +// This is intended for testing to inject mock transports. +func (c *Client) SetHTTPClient(hc *http.Client) { + c.http = hc +} + // PullRequest holds relevant PR metadata. type PullRequest struct { Title string `json:"title"` @@ -210,24 +232,185 @@ func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, return &review, nil } +// isTemporaryNetError reports whether err is a temporary network error worth retrying. +// This includes connection refused, network unreachable, connection reset, and DNS +// timeouts. It explicitly excludes permanent errors like permission denied or +// "no such host" DNS failures. +func isTemporaryNetError(err error) bool { + if err == nil { + return false + } + + // Check for OpError and inspect the underlying syscall error. + // Not all OpErrors are transient — permission denied, for example, is permanent. + var opErr *net.OpError + if errors.As(err, &opErr) { + return isRetriableSyscallError(opErr.Err) + } + + // DNS errors: only retry on timeout, not on "no such host" which is permanent. + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return dnsErr.IsTimeout + } + + // Check for net.Error with Timeout() (Temporary is deprecated) + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + + return false +} + +// isRetriableSyscallError reports whether the underlying error from a net.OpError +// is a transient syscall error worth retrying. +func isRetriableSyscallError(err error) bool { + if err == nil { + return false + } + + // Check for syscall.Errno directly or wrapped + var errno syscall.Errno + if errors.As(err, &errno) { + switch errno { + case syscall.ECONNREFUSED, // connection refused — server not listening + syscall.ECONNRESET, // connection reset by peer + syscall.ENETUNREACH, // network unreachable + syscall.EHOSTUNREACH, // host unreachable + syscall.ETIMEDOUT: // connection timed out + return true + default: + // EACCES, EPERM, etc. are permanent — don't retry + return false + } + } + + // If we can't identify the specific syscall error, be conservative and retry. + // This handles wrapped errors or platform-specific error types. + // The retry count is limited, so erring on the side of retrying is safe. + return true +} + +// redactURL strips query parameters from a URL for safe logging. +// This prevents accidental exposure of sensitive data that future callers +// might pass via query strings. +func redactURL(rawURL string) string { + parsed, err := url.Parse(rawURL) + if err != nil { + // If we cannot parse it, return a safe placeholder rather than + // potentially logging something sensitive. + return "[invalid URL]" + } + if parsed.RawQuery != "" { + parsed.RawQuery = "[redacted]" + } + return parsed.String() +} + +// sanitizeErrorForLog returns a loggable version of an error that omits +// potentially sensitive content like response bodies. For APIError, only +// the status code is included; for other errors, the type is preserved. +func sanitizeErrorForLog(err error) string { + if err == nil { + return "" + } + var apiErr *APIError + if errors.As(err, &apiErr) { + return fmt.Sprintf("HTTP %d", apiErr.StatusCode) + } + return err.Error() +} + +// doGet performs an HTTP GET request with retry on 5xx errors and temporary +// network errors. Retries up to 3 times with exponential backoff (1s, 2s delays +// by default; configurable via Client.RetryBackoff for testing). func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) - if err != nil { - return nil, err + const maxAttempts = 3 + // backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails). + // First attempt (i=0) has no delay; retries wait 1s then 2s by default. + backoff := c.RetryBackoff + if backoff == nil { + backoff = []time.Duration{1 * time.Second, 2 * time.Second} } - req.Header.Set("Authorization", "token "+c.token) - resp, err := c.http.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() + // maxErrorBodyBytes limits how much of an error response body we read + // to protect against malicious servers sending unbounded data. + const maxErrorBodyBytes = 64 * 1024 // 64 KB - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) - return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)} + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + // Determine delay: use backoff slice if available, otherwise retry immediately. + // An empty RetryBackoff slice means "retry without delay" — this is intentional + // as the caller explicitly configured no delays. + var delay time.Duration + if attempt-1 < len(backoff) { + delay = backoff[attempt-1] + } + + if delay > 0 { + slog.Warn("retrying request after error", + "attempt", attempt+1, + "url", redactURL(reqURL), + "delay", delay.String(), + "lastError", sanitizeErrorForLog(lastErr)) + + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + } + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "token "+c.token) + + resp, err := c.http.Do(req) + if err != nil { + // Always capture the error for consistent return at loop end. + // This ensures both network errors and HTTP 5xx return lastErr. + lastErr = err + + // Only retry temporary network errors when attempts remain. + if attempt < maxAttempts-1 && isTemporaryNetError(err) { + slog.Warn("temporary network error, will retry", + "attempt", attempt+1, + "url", redactURL(reqURL), + "error", err) + continue + } + // Non-retryable network error or final attempt exhausted. + return nil, lastErr + } + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, err + } + return body, nil + } + + // Error path: limit how much we read from potentially malicious server + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + resp.Body.Close() + + lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} + + // Only retry on 5xx server errors + if resp.StatusCode < 500 || resp.StatusCode >= 600 { + return nil, lastErr + } } - return io.ReadAll(resp.Body) + + return nil, lastErr } // escapePath escapes each segment of a relative file path for use in URLs. @@ -317,9 +500,9 @@ func (c *Client) GetAllFilesInPath(ctx context.Context, owner, repo, path string // Review represents a pull request review from the Gitea API. type Review struct { - ID int64 `json:"id"` - Body string `json:"body"` - User struct { + ID int64 `json:"id"` + Body string `json:"body"` + User struct { Login string `json:"login"` } `json:"user"` State string `json:"state"` diff --git a/gitea/client_test.go b/gitea/client_test.go index d09e38b..76156e7 100644 --- a/gitea/client_test.go +++ b/gitea/client_test.go @@ -6,10 +6,14 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "strings" + "sync/atomic" + "syscall" "testing" + "time" ) func TestGetPullRequest(t *testing.T) { @@ -584,9 +588,9 @@ func TestGetAllFilesInPath_403Propagates(t *testing.T) { func TestIsNotFound(t *testing.T) { tests := []struct { - name string - err error - want bool + name string + err error + want bool }{ {"nil error", nil, false}, {"non-API error", fmt.Errorf("network timeout"), false}, @@ -743,3 +747,347 @@ func TestResolveComment_Error(t *testing.T) { t.Fatal("expected error for 404 response") } } + +func TestIsServerError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"non-API error", fmt.Errorf("network timeout"), false}, + {"404 APIError", &APIError{StatusCode: 404, Body: "not found"}, false}, + {"500 APIError", &APIError{StatusCode: 500, Body: "server error"}, true}, + {"502 APIError", &APIError{StatusCode: 502, Body: "bad gateway"}, true}, + {"503 APIError", &APIError{StatusCode: 503, Body: "unavailable"}, true}, + {"599 APIError", &APIError{StatusCode: 599, Body: "edge case"}, true}, + {"600 not server error", &APIError{StatusCode: 600, Body: "edge"}, false}, + {"400 not server error", &APIError{StatusCode: 400, Body: "bad request"}, false}, + {"wrapped 500", fmt.Errorf("fetch: %w", &APIError{StatusCode: 500, Body: "err"}), true}, + {"wrapped 404", fmt.Errorf("fetch: %w", &APIError{StatusCode: 404, Body: "err"}), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsServerError(tt.err) + if got != tt.want { + t.Errorf("IsServerError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestDoGet_RetriesOn500(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts < 3 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"message":"transient error"}`)) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"data":"success"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + // Use short backoff for fast tests + client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + body, err := client.doGet(context.Background(), server.URL+"/test") + if err != nil { + t.Fatalf("expected success after retry, got error: %v", err) + } + if string(body) != `{"data":"success"}` { + t.Errorf("body = %q, want %q", string(body), `{"data":"success"}`) + } + if attempts != 3 { + t.Errorf("attempts = %d, want 3", attempts) + } +} + +func TestDoGet_FailsAfterMaxRetries(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"message":"persistent error"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + // Use short backoff for fast tests + client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + _, err := client.doGet(context.Background(), server.URL+"/test") + if err == nil { + t.Fatal("expected error after max retries") + } + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected APIError, got: %v", err) + } + if apiErr.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", apiErr.StatusCode) + } + if attempts != 3 { + t.Errorf("attempts = %d, want 3 (max retries)", attempts) + } +} + +func TestDoGet_NoRetryOn4xx(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"message":"forbidden"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + _, err := client.doGet(context.Background(), server.URL+"/test") + if err == nil { + t.Fatal("expected error for 403") + } + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected APIError, got: %v", err) + } + if apiErr.StatusCode != http.StatusForbidden { + t.Errorf("status = %d, want 403", apiErr.StatusCode) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1 (no retry on 4xx)", attempts) + } +} + +func TestDoGet_RespectsContextCancellation(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"message":"error"}`)) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + client := NewClient(server.URL, "test-token") + // Use longer backoff to give us time to cancel during the wait + client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond} + + // Cancel after first attempt returns and retry begins + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + _, err := client.doGet(ctx, server.URL+"/test") + if err == nil { + t.Fatal("expected error on context cancellation") + } + // Should have made 1 attempt, then context cancelled during backoff + if attempts != 1 { + t.Errorf("attempts = %d, expected 1 before context cancel during backoff", attempts) + } +} + + +// mockTransport is a test helper that returns errors for the first N calls, +// then delegates to a real server. +type mockTransport struct { + failCount int32 // number of failures remaining (atomic) + failErr error // error to return on failure + realServer *httptest.Server + attemptsMade atomic.Int32 // tracks total attempts +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + m.attemptsMade.Add(1) + remaining := atomic.AddInt32(&m.failCount, -1) + if remaining >= 0 { + // Still have failures to return + return nil, m.failErr + } + // Redirect to real server + req.URL.Host = m.realServer.Listener.Addr().String() + req.URL.Scheme = "http" + return http.DefaultTransport.RoundTrip(req) +} + +func TestDoGet_RetriesOnTemporaryNetError(t *testing.T) { + // Real server that will handle successful requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer server.Close() + + // Mock transport: fail twice with ECONNREFUSED, then succeed + mt := &mockTransport{ + failCount: 2, + failErr: &net.OpError{Op: "dial", Net: "tcp", Err: syscall.ECONNREFUSED}, + realServer: server, + } + + client := NewClient("http://fake-host/", "test-token") + client.SetHTTPClient(&http.Client{Transport: mt}) + client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + body, err := client.doGet(context.Background(), "http://fake-host/test") + if err != nil { + t.Fatalf("expected success after retries, got error: %v", err) + } + if string(body) != `{"status":"ok"}` { + t.Errorf("body = %q, want %q", string(body), `{"status":"ok"}`) + } + + // Should have made exactly 3 attempts: 2 failures + 1 success + if got := mt.attemptsMade.Load(); got != 3 { + t.Errorf("attempts = %d, want 3 (2 failures + 1 success)", got) + } +} + +func TestIsTemporaryNetError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"plain error", fmt.Errorf("some error"), false}, + // OpError with retriable syscall errors + {"OpError ECONNREFUSED", &net.OpError{Op: "dial", Err: syscall.ECONNREFUSED}, true}, + {"OpError ECONNRESET", &net.OpError{Op: "read", Err: syscall.ECONNRESET}, true}, + {"OpError ENETUNREACH", &net.OpError{Op: "dial", Err: syscall.ENETUNREACH}, true}, + {"OpError EHOSTUNREACH", &net.OpError{Op: "dial", Err: syscall.EHOSTUNREACH}, true}, + {"OpError ETIMEDOUT", &net.OpError{Op: "dial", Err: syscall.ETIMEDOUT}, true}, + // OpError with permanent syscall errors — should NOT retry + {"OpError EACCES", &net.OpError{Op: "dial", Err: syscall.EACCES}, false}, + {"OpError EPERM", &net.OpError{Op: "dial", Err: syscall.EPERM}, false}, + // OpError with unknown inner error — conservative retry + {"OpError unknown inner", &net.OpError{Op: "dial", Err: fmt.Errorf("unknown")}, true}, + // DNS errors + {"DNS timeout", &net.DNSError{IsTimeout: true}, true}, + {"DNS no such host", &net.DNSError{IsTimeout: false, Name: "bad.host"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isTemporaryNetError(tt.err) + if got != tt.want { + t.Errorf("isTemporaryNetError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestIsRetriableSyscallError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"ECONNREFUSED", syscall.ECONNREFUSED, true}, + {"ECONNRESET", syscall.ECONNRESET, true}, + {"ENETUNREACH", syscall.ENETUNREACH, true}, + {"EHOSTUNREACH", syscall.EHOSTUNREACH, true}, + {"ETIMEDOUT", syscall.ETIMEDOUT, true}, + {"EACCES (permanent)", syscall.EACCES, false}, + {"EPERM (permanent)", syscall.EPERM, false}, + {"ENOENT (permanent)", syscall.ENOENT, false}, + {"unknown error", fmt.Errorf("something"), true}, // conservative retry + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isRetriableSyscallError(tt.err) + if got != tt.want { + t.Errorf("isRetriableSyscallError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestRedactURL(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no query params", + input: "https://gitea.example.com/api/v1/repos/owner/repo/pulls/1", + want: "https://gitea.example.com/api/v1/repos/owner/repo/pulls/1", + }, + { + name: "with query params - redacts", + input: "https://gitea.example.com/api/v1/repos/owner/repo/raw/file?ref=main", + want: "https://gitea.example.com/api/v1/repos/owner/repo/raw/file?[redacted]", + }, + { + name: "multiple query params", + input: "https://example.com/path?token=secret&page=1", + want: "https://example.com/path?[redacted]", + }, + { + name: "invalid URL", + input: "://invalid", + want: "[invalid URL]", + }, + { + name: "empty string", + input: "", + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := redactURL(tt.input) + if got != tt.want { + t.Errorf("redactURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestSanitizeErrorForLog(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + { + name: "nil error", + err: nil, + want: "", + }, + { + name: "APIError omits body", + err: &APIError{StatusCode: 500, Body: "internal error: database connection failed"}, + want: "HTTP 500", + }, + { + name: "APIError with large body still only shows status", + err: &APIError{StatusCode: 502, Body: strings.Repeat("x", 1000)}, + want: "HTTP 502", + }, + { + name: "non-API error preserved", + err: fmt.Errorf("connection refused"), + want: "connection refused", + }, + { + name: "wrapped APIError", + err: fmt.Errorf("request failed: %w", &APIError{StatusCode: 503, Body: "service unavailable"}), + want: "HTTP 503", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizeErrorForLog(tt.err) + if got != tt.want { + t.Errorf("sanitizeErrorForLog() = %q, want %q", got, tt.want) + } + }) + } +}