diff --git a/github/client.go b/github/client.go index c3ea252..7f5c7f9 100644 --- a/github/client.go +++ b/github/client.go @@ -255,25 +255,11 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, fmt.Errorf("do request: %w", err) } - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("read response body: %w", err) - } - if len(body) > maxResponseBytes { - return nil, fmt.Errorf("response body exceeded %d bytes (truncated)", maxResponseBytes) - } - return body, nil + body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) + if done { + return body, err } - - errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) - if readErr != nil && len(errBody) == 0 { - errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) - } - resp.Body.Close() - - lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} + lastErr = err // Retry on 429 rate limit if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { @@ -311,6 +297,30 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, lastErr } +// handleResponse reads and closes the response body, returning the result. +// It uses defer to ensure the body is always closed regardless of code path. +// Returns (body, done, err) where done=true means the caller should return immediately. +func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1)) + if err != nil { + return nil, true, fmt.Errorf("read response body: %w", err) + } + if len(body) > maxRespBytes { + return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes) + } + return body, true, nil + } + + errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes))) + if readErr != nil && len(errBody) == 0 { + errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) + } + return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} +} + // doGet is a convenience wrapper for GET requests with the default Accept header. func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { return c.doRequest(ctx, http.MethodGet, reqURL, "")