diff --git a/github/client.go b/github/client.go index 9fd7eb8..2d5204e 100644 --- a/github/client.go +++ b/github/client.go @@ -193,10 +193,24 @@ func (c *Client) SetRetryBackoff(d []time.Duration) error { return nil } -// doRequest performs an HTTP request with retry on 429 rate limit responses. -// It respects the Retry-After header when present (capped at maxRetryAfter). -// Transport errors (network failures, context cancellation) are not retried. -func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) { +// requestOptions holds per-request configuration for doRequestCore. +type requestOptions struct { + // bodyFn returns a fresh io.Reader for the request body on each attempt. + // Must be non-nil for requests that carry a body (POST, PUT, PATCH). + // Returning a fresh reader on each call allows retries to re-send the body. + bodyFn func() io.Reader + + // accept overrides the default Accept header. Empty means "application/vnd.github+json". + accept string + + // extraHeaders are additional headers to set on each request attempt. + extraHeaders map[string]string +} + +// doRequestCore is the shared implementation for all HTTP requests with retry +// on 429 rate limit responses. It respects the Retry-After header when present +// (capped at maxRetryAfter). Transport errors are not retried. +func (c *Client) doRequestCore(ctx context.Context, method, reqURL string, opts requestOptions) ([]byte, error) { const maxRetryAfter = 120 * time.Second // backoff holds per-attempt delays: backoff[i] is the delay before attempt i+1. @@ -247,7 +261,11 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st } } - req, err := http.NewRequestWithContext(ctx, method, reqURL, nil) + var body io.Reader + if opts.bodyFn != nil { + body = opts.bodyFn() + } + req, err := http.NewRequestWithContext(ctx, method, reqURL, body) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -258,11 +276,14 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st req.Header.Set("Authorization", "Bearer "+c.token) } req.Header.Set("User-Agent", userAgent) - if accept != "" { - req.Header.Set("Accept", accept) + if opts.accept != "" { + req.Header.Set("Accept", opts.accept) } else { req.Header.Set("Accept", "application/vnd.github+json") } + for k, v := range opts.extraHeaders { + req.Header.Set(k, v) + } resp, err := c.httpClient.Do(req) if err != nil { @@ -273,11 +294,11 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st respStatus := resp.StatusCode retryAfterHeader := resp.Header.Get("Retry-After") - body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) + respBody, done, handleErr := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) if done { - return body, err + return respBody, handleErr } - lastErr = err + lastErr = handleErr // Retry on 429 rate limit if respStatus == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 { @@ -315,6 +336,13 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, lastErr } +// doRequest performs an HTTP request with retry on 429 rate limit responses. +// It respects the Retry-After header when present (capped at maxRetryAfter). +// Transport errors (network failures, context cancellation) are not retried. +func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) { + return c.doRequestCore(ctx, method, reqURL, requestOptions{accept: accept}) +} + // handleResponse reads and closes the response body, returning the result. // It uses defer to ensure the body is always closed regardless of code path. // Returns (body, done, err) where done=true means the caller should return immediately. @@ -347,109 +375,11 @@ func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { // doRequestWithBody is like doRequest but sends a request body. // It accepts the raw body bytes and sets Content-Type to application/json. // Retry semantics match doRequest (retries on 429 with Retry-After support). -func (c *Client) doRequestWithBody(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) { - const maxRetryAfter = 120 * time.Second - - defaultBackoff := []time.Duration{1 * time.Second, 2 * time.Second} - var backoff []time.Duration - if c.retryBackoff != nil && len(c.retryBackoff) == maxRetryAttempts-1 { - backoff = make([]time.Duration, len(c.retryBackoff)) - copy(backoff, c.retryBackoff) - } else { - backoff = make([]time.Duration, len(defaultBackoff)) - copy(backoff, defaultBackoff) +func (c *Client) doRequestWithBody(ctx context.Context, method, reqURL string, reqBody []byte) ([]byte, error) { + var opts requestOptions + if reqBody != nil { + opts.bodyFn = func() io.Reader { return bytes.NewReader(reqBody) } + opts.extraHeaders = map[string]string{"Content-Type": "application/json"} } - - const maxErrorBodyBytes = 4 * 1024 - - if c.token != "" && !c.allowInsecureHTTP { - parsed, err := url.Parse(reqURL) - if err != nil { - return nil, fmt.Errorf("parse request URL: %w", err) - } - if !strings.EqualFold(parsed.Scheme, "https") { - return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL) - } - } - - var lastErr error - for attempt := 0; attempt < maxRetryAttempts; attempt++ { - if attempt > 0 { - var delay time.Duration - if attempt-1 < len(backoff) { - delay = backoff[attempt-1] - } - if delay > 0 { - timer := time.NewTimer(delay) - select { - case <-timer.C: - timer.Stop() - case <-ctx.Done(): - timer.Stop() - return nil, ctx.Err() - } - } - } - - var bodyReader io.Reader - if body != nil { - bodyReader = bytes.NewReader(body) - } - req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - if c.token != "" { - req.Header.Set("Authorization", "Bearer "+c.token) - } - req.Header.Set("User-Agent", userAgent) - req.Header.Set("Accept", "application/vnd.github+json") - if body != nil { - req.Header.Set("Content-Type", "application/json") - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("do request: %w", err) - } - - respStatus := resp.StatusCode - retryAfterHeader := resp.Header.Get("Retry-After") - - respBody, done, handleErr := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) - if done { - return respBody, handleErr - } - lastErr = handleErr - - if respStatus == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 { - if ra := retryAfterHeader; ra != "" { - if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { - delay := time.Duration(seconds) * time.Second - if delay > maxRetryAfter { - delay = maxRetryAfter - } - if attempt < len(backoff) { - backoff[attempt] = delay - } - } else if retryAt, err := http.ParseTime(ra); err == nil { - delay := time.Until(retryAt) - if delay < 0 { - delay = 0 - } - if delay > maxRetryAfter { - delay = maxRetryAfter - } - if attempt < len(backoff) { - backoff[attempt] = delay - } - } - } - continue - } - - return nil, lastErr - } - - return nil, lastErr + return c.doRequestCore(ctx, method, reqURL, opts) } diff --git a/github/review.go b/github/review.go index e9e20cd..785175b 100644 --- a/github/review.go +++ b/github/review.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "net/url" "gitea.weiker.me/rodin/review-bot/vcs" @@ -52,15 +53,13 @@ type dismissReviewRequest struct { // canonical vcs.Review.State value. func translateGitHubReviewState(state string) string { switch state { - case "APPROVED": - return "APPROVED" case "CHANGES_REQUESTED": return "REQUEST_CHANGES" case "COMMENTED": return "COMMENT" - case "DISMISSED": - return "DISMISSED" default: + // States like APPROVED, DISMISSED, and PENDING pass through unchanged + // as they already match the canonical vcs representation. return state } } @@ -100,7 +99,7 @@ func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, return nil, fmt.Errorf("marshal review request: %w", err) } - body, err := c.doRequestWithBody(ctx, "POST", reqURL, data) + body, err := c.doRequestWithBody(ctx, http.MethodPost, reqURL, data) if err != nil { return nil, fmt.Errorf("post review: %w", err) } @@ -150,12 +149,13 @@ func (c *Client) ListReviews(ctx context.Context, owner, repo string, number int // DeleteReview deletes a pull request review. // Only PENDING reviews can be deleted; attempting to delete a submitted review -// (APPROVED, CHANGES_REQUESTED, COMMENTED) returns ErrCannotDeleteSubmittedReview. +// (APPROVED, CHANGES_REQUESTED, or COMMENTED per GitHub API naming) returns +// ErrCannotDeleteSubmittedReview. func (c *Client) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error { reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, reviewID) - _, err := c.doRequestWithBody(ctx, "DELETE", reqURL, nil) + _, err := c.doRequestWithBody(ctx, http.MethodDelete, reqURL, nil) if err != nil { var apiErr *APIError if errors.As(err, &apiErr) && apiErr.StatusCode == 422 { @@ -183,7 +183,7 @@ func (c *Client) DismissReview(ctx context.Context, owner, repo string, number i return fmt.Errorf("marshal dismiss request: %w", err) } - _, err = c.doRequestWithBody(ctx, "PUT", reqURL, data) + _, err = c.doRequestWithBody(ctx, http.MethodPut, reqURL, data) if err != nil { return fmt.Errorf("dismiss review: %w", err) } diff --git a/github/review_test.go b/github/review_test.go index 50e6700..1c3d0ac 100644 --- a/github/review_test.go +++ b/github/review_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -153,7 +154,7 @@ func TestPostReview_MalformedResponse(t *testing.T) { if err == nil { t.Fatal("expected error for malformed response") } - if !containsStr(err.Error(), "parse review response") { + if !strings.Contains(err.Error(), "parse review response") { t.Errorf("expected parse error, got: %v", err) } } @@ -379,16 +380,4 @@ func TestTranslateGitHubReviewState(t *testing.T) { } } -// containsStr is a test helper for checking error messages. -func containsStr(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) -} -func containsSubstring(s, sub string) bool { - for i := 0; i <= len(s)-len(sub); i++ { - if s[i:i+len(sub)] == sub { - return true - } - } - return false -}