From 1fcc0b738a1fa8e036951e6a6b055a5a64aeabbb Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 17:13:07 -0700 Subject: [PATCH] fix(github): address MINOR/NIT findings from review #2866 - SetHTTPClient(nil): preserve CheckRedirect auth-stripping policy instead of restoring a plain http.Client that loses cross-host protection. - Authorization header: add comment documenting why Bearer scheme is correct (OAuth2 standard, works for both classic PATs and fine-grained tokens). - Retry-After parsing: support HTTP-date format (RFC 7231) in addition to integer seconds. GitHub only sends integers today, but the implementation is now spec-compliant. - escapePath dot-segment removal: document the behavior in public API doc comments for ListContents and GetFileContentAtRef so callers are aware without reading the internal helper. --- github/client.go | 34 ++++++++++++++++-- github/client_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++ github/files.go | 4 +++ github/pr.go | 4 +++ 4 files changed, 120 insertions(+), 3 deletions(-) diff --git a/github/client.go b/github/client.go index f599305..69baa36 100644 --- a/github/client.go +++ b/github/client.go @@ -133,10 +133,23 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { // SetHTTPClient sets the underlying HTTP client used for requests. // This is intended for testing to inject mock transports. -// Passing nil will restore the default client with a 30s timeout. +// Passing nil restores the default client (30s timeout + auth-stripping +// CheckRedirect policy matching NewClient). func (c *Client) SetHTTPClient(hc *http.Client) { if hc == nil { - hc = &http.Client{Timeout: 30 * time.Second} + hc = &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + prev := via[len(via)-1] + if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { + req.Header.Del("Authorization") + } + return nil + }, + } } c.httpClient = hc } @@ -199,6 +212,9 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, fmt.Errorf("create request: %w", err) } if c.token != "" { + // Bearer is the OAuth2 standard and is accepted by GitHub for both + // classic PATs and fine-grained tokens. The alternative "token" scheme + // is GitHub-specific and offers no additional compatibility. req.Header.Set("Authorization", "Bearer "+c.token) } req.Header.Set("User-Agent", userAgent) @@ -232,7 +248,8 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st // Retry on 429 rate limit if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { - // Check for Retry-After header and override backoff if present + // Check for Retry-After header and override backoff if present. + // Supports both integer seconds (common) and HTTP-date format (RFC 7231). if ra := resp.Header.Get("Retry-After"); ra != "" { if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { delay := time.Duration(seconds) * time.Second @@ -242,6 +259,17 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st if attempt < len(backoff) { backoff[attempt] = delay } + } else if t, err := http.ParseTime(ra); err == nil { + delay := time.Until(t) + if delay < 0 { + delay = 0 + } + if delay > maxRetryAfter { + delay = maxRetryAfter + } + if attempt < len(backoff) { + backoff[attempt] = delay + } } } continue diff --git a/github/client_test.go b/github/client_test.go index 73cf1df..d94bf8e 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -263,6 +263,84 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { } } +func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use HTTP-date format (RFC 7231) — a time 2 seconds in the future. + future := time.Now().Add(2 * time.Second).UTC() + w.Header().Set("Retry-After", future.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // HTTP-date was ~2s in the future; by the time client processes it, + // time.Until gives ~1-2s. Verify it's meaningfully delayed (not instant). + if elapsed < 500*time.Millisecond { + t.Errorf("expected meaningful delay from HTTP-date Retry-After, got %v", elapsed) + } +} + +func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use a time in the past — should result in zero/immediate retry. + past := time.Now().Add(-10 * time.Second).UTC() + w.Header().Set("Retry-After", past.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second}) + + start := time.Now() + _, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Past date should override the 5s backoff to ~0 + if elapsed > 500*time.Millisecond { + t.Errorf("expected near-instant retry for past HTTP-date, got %v", elapsed) + } +} + func TestDoRequest_SetsUserAgentHeader(t *testing.T) { var gotUA string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -392,4 +470,7 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { if c.httpClient.Timeout != 30*time.Second { t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout) } + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)") + } } diff --git a/github/files.go b/github/files.go index f09d3e5..f9a1cf6 100644 --- a/github/files.go +++ b/github/files.go @@ -22,6 +22,10 @@ func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref stri // If the path points to a single file (not a directory), the API returns // a JSON object instead of an array; this is handled by returning a // single-element slice. +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) diff --git a/github/pr.go b/github/pr.go index 1bb428a..3e984c2 100644 --- a/github/pr.go +++ b/github/pr.go @@ -123,6 +123,10 @@ func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, nu // GetFileContentAtRef fetches a file at a specific ref from a repo. // If ref is empty, the query parameter is omitted (uses default branch). +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))