From fce5f2d1840652de0eb3d8bc9068f9a313b0e2a1 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:55:32 -0700 Subject: [PATCH] fix(github): address review findings on client.go - Use net/url.Parse for HTTPS scheme check (case-insensitive) - Guard SetHTTPClient against nil (restores default 30s client) - Rename 'url' param to 'reqURL' in doRequest/doGet for clarity - Return error when response exceeds maxResponseBytes instead of silently truncating Finding #1 (Bearer auth scheme) intentionally kept: GitHub REST API officially supports and recommends Bearer for all token types. See: https://docs.github.com/en/rest/authentication/authenticating-to-the-rest-api --- github/client.go | 26 ++++++++++++++++++++------ github/client_test.go | 25 +++++++++++++++++-------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/github/client.go b/github/client.go index c148f96..f599305 100644 --- a/github/client.go +++ b/github/client.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strconv" "strings" "time" @@ -132,7 +133,11 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { // SetHTTPClient sets the underlying HTTP client used for requests. // This is intended for testing to inject mock transports. +// Passing nil will restore the default client with a 30s timeout. func (c *Client) SetHTTPClient(hc *http.Client) { + if hc == nil { + hc = &http.Client{Timeout: 30 * time.Second} + } c.httpClient = hc } @@ -145,7 +150,7 @@ func (c *Client) SetRetryBackoff(d []time.Duration) { // doRequest performs an HTTP request with retry on 429 rate limit responses. // It respects the Retry-After header when present (capped at maxRetryAfter). // Transport errors (network failures, context cancellation) are not retried. -func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { +func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) { const maxAttempts = 3 const maxRetryAfter = 120 * time.Second @@ -160,8 +165,14 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin const maxErrorBodyBytes = 64 * 1024 // Reject non-HTTPS URLs early since the URL is immutable across retries. - if c.token != "" && !c.allowInsecureHTTP && !strings.HasPrefix(url, "https://") { - return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", url) + if c.token != "" && !c.allowInsecureHTTP { + parsed, err := url.Parse(reqURL) + if err != nil { + return nil, fmt.Errorf("parse request URL: %w", err) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL) + } } var lastErr error @@ -183,7 +194,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin } } - req, err := http.NewRequestWithContext(ctx, method, url, nil) + req, err := http.NewRequestWithContext(ctx, method, reqURL, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -208,6 +219,9 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin if err != nil { return nil, fmt.Errorf("read response body: %w", err) } + if int64(len(body)) >= maxResponseBytes { + return nil, fmt.Errorf("response body exceeded %d bytes (truncated)", maxResponseBytes) + } return body, nil } @@ -241,6 +255,6 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin } // doGet is a convenience wrapper for GET requests with the default Accept header. -func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) { - return c.doRequest(ctx, http.MethodGet, url, "") +func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, reqURL, "") } diff --git a/github/client_test.go b/github/client_test.go index ea03ea2..73cf1df 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -282,8 +282,7 @@ func TestDoRequest_SetsUserAgentHeader(t *testing.T) { } func TestDoRequest_LimitsResponseBody(t *testing.T) { - // Verify that response body reading is actually bounded by maxResponseBytes. - // Use a small custom limit to avoid allocating 10 MiB in tests. + // Verify that oversized responses return an error rather than silently truncating. bigBody := strings.Repeat("x", maxResponseBytes+1024) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -293,13 +292,12 @@ func TestDoRequest_LimitsResponseBody(t *testing.T) { c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) - body, err := c.doGet(context.Background(), srv.URL+"/test") - if err != nil { - t.Fatalf("unexpected error: %v", err) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for oversized response body") } - // LimitReader should cap the body at maxResponseBytes - if len(body) > maxResponseBytes { - t.Errorf("expected body <= %d bytes, got %d", maxResponseBytes, len(body)) + if !strings.Contains(err.Error(), "exceeded") { + t.Errorf("expected truncation error, got: %v", err) } } @@ -384,3 +382,14 @@ func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) { t.Errorf("unexpected body: %s", body) } } + +func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { + c := NewClient("token", "https://api.github.com") + c.SetHTTPClient(nil) + if c.httpClient == nil { + t.Fatal("expected non-nil httpClient after SetHTTPClient(nil)") + } + if c.httpClient.Timeout != 30*time.Second { + t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout) + } +}