diff --git a/gitea/client.go b/gitea/client.go index 150b3c0..5feb9a8 100644 --- a/gitea/client.go +++ b/gitea/client.go @@ -39,6 +39,12 @@ func IsNotFound(err error) bool { return errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusNotFound } +// IsServerError reports whether an error is an API 5xx response. +func IsServerError(err error) bool { + var apiErr *APIError + return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600 +} + // Client interacts with the Gitea API. // A Client is safe for concurrent use by multiple goroutines. type Client struct { @@ -210,24 +216,56 @@ func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, return &review, nil } +// doGet performs an HTTP GET request with retry on 5xx errors. +// Retries up to 3 times with exponential backoff (1s, 2s delays). func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "token "+c.token) + const maxAttempts = 3 + backoff := []time.Duration{0, 1 * time.Second, 2 * time.Second} - resp, err := c.http.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + slog.Warn("retrying request after server error", + "attempt", attempt+1, + "url", reqURL, + "delay", backoff[attempt].String()) + select { + case <-time.After(backoff[attempt]): + case <-ctx.Done(): + return nil, ctx.Err() + } + } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, _ := io.ReadAll(resp.Body) - return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)} + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "token "+c.token) + + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + + body, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + if readErr != nil { + return nil, readErr + } + return body, nil + } + + lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(body)} + + // Only retry on 5xx server errors + if resp.StatusCode < 500 || resp.StatusCode >= 600 { + return nil, lastErr + } } - return io.ReadAll(resp.Body) + + return nil, lastErr } // escapePath escapes each segment of a relative file path for use in URLs. diff --git a/gitea/client_test.go b/gitea/client_test.go index d09e38b..ebd5cfe 100644 --- a/gitea/client_test.go +++ b/gitea/client_test.go @@ -10,6 +10,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" ) func TestGetPullRequest(t *testing.T) { @@ -743,3 +744,137 @@ func TestResolveComment_Error(t *testing.T) { t.Fatal("expected error for 404 response") } } + +func TestIsServerError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {"nil error", nil, false}, + {"non-API error", fmt.Errorf("network timeout"), false}, + {"404 APIError", &APIError{StatusCode: 404, Body: "not found"}, false}, + {"500 APIError", &APIError{StatusCode: 500, Body: "server error"}, true}, + {"502 APIError", &APIError{StatusCode: 502, Body: "bad gateway"}, true}, + {"503 APIError", &APIError{StatusCode: 503, Body: "unavailable"}, true}, + {"599 APIError", &APIError{StatusCode: 599, Body: "edge case"}, true}, + {"600 not server error", &APIError{StatusCode: 600, Body: "edge"}, false}, + {"400 not server error", &APIError{StatusCode: 400, Body: "bad request"}, false}, + {"wrapped 500", fmt.Errorf("fetch: %w", &APIError{StatusCode: 500, Body: "err"}), true}, + {"wrapped 404", fmt.Errorf("fetch: %w", &APIError{StatusCode: 404, Body: "err"}), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsServerError(tt.err) + if got != tt.want { + t.Errorf("IsServerError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestDoGet_RetriesOn500(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts < 3 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"message":"transient error"}`)) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"data":"success"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + body, err := client.doGet(context.Background(), server.URL+"/test") + if err != nil { + t.Fatalf("expected success after retry, got error: %v", err) + } + if string(body) != `{"data":"success"}` { + t.Errorf("body = %q, want %q", string(body), `{"data":"success"}`) + } + if attempts != 3 { + t.Errorf("attempts = %d, want 3", attempts) + } +} + +func TestDoGet_FailsAfterMaxRetries(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"message":"persistent error"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + _, err := client.doGet(context.Background(), server.URL+"/test") + if err == nil { + t.Fatal("expected error after max retries") + } + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected APIError, got: %v", err) + } + if apiErr.StatusCode != http.StatusInternalServerError { + t.Errorf("status = %d, want 500", apiErr.StatusCode) + } + if attempts != 3 { + t.Errorf("attempts = %d, want 3 (max retries)", attempts) + } +} + +func TestDoGet_NoRetryOn4xx(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{"message":"forbidden"}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-token") + _, err := client.doGet(context.Background(), server.URL+"/test") + if err == nil { + t.Fatal("expected error for 403") + } + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected APIError, got: %v", err) + } + if apiErr.StatusCode != http.StatusForbidden { + t.Errorf("status = %d, want 403", apiErr.StatusCode) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1 (no retry on 4xx)", attempts) + } +} + +func TestDoGet_RespectsContextCancellation(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"message":"error"}`)) + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel immediately after first attempt would trigger retry + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + client := NewClient(server.URL, "test-token") + _, err := client.doGet(ctx, server.URL+"/test") + if err == nil { + t.Fatal("expected error on context cancellation") + } + // Should have made 1 attempt, then context cancelled during backoff + if attempts > 2 { + t.Errorf("attempts = %d, expected at most 2 before context cancel", attempts) + } +}