diff --git a/github/client.go b/github/client.go index 7f5c7f9..1555b34 100644 --- a/github/client.go +++ b/github/client.go @@ -47,6 +47,13 @@ func (e *APIError) Error() string { return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) } +// SafeError returns the error string without response body content, +// suitable for logging in contexts where upstream response data should +// not be exposed. +func (e *APIError) SafeError() string { + return fmt.Sprintf("HTTP %d", e.StatusCode) +} + // IsNotFound reports whether an error is an API 404 response. func IsNotFound(err error) bool { if apiErr, ok := asAPIError(err); ok { @@ -172,6 +179,12 @@ func (c *Client) SetHTTPClient(hc *http.Client) { Timeout: 30 * time.Second, CheckRedirect: defaultCheckRedirect, } + } else if hc.CheckRedirect == nil { + // Enforce safe redirect policy when caller provides a client without one. + // The default net/http behavior follows up to 10 redirects and forwards + // all headers (including Authorization) to any host, which can leak + // credentials on cross-host redirects. + hc.CheckRedirect = defaultCheckRedirect } c.httpClient = hc } @@ -252,10 +265,11 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st resp, err := c.httpClient.Do(req) if err != nil { + // Transport errors (DNS, TLS, timeout) yield nil resp; no body to close. return nil, fmt.Errorf("do request: %w", err) } - body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) + body, done, err := handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) if done { return body, err } @@ -300,7 +314,7 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st // handleResponse reads and closes the response body, returning the result. // It uses defer to ensure the body is always closed regardless of code path. // Returns (body, done, err) where done=true means the caller should return immediately. -func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { +func handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { defer resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { diff --git a/github/client_test.go b/github/client_test.go index a8ccc06..f1ccfcd 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -554,3 +554,46 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)") } } + +func TestSetHTTPClient_NilCheckRedirectEnforcesDefault(t *testing.T) { + c := NewClient("token", "https://api.github.com") + // Provide a client with nil CheckRedirect — should get default policy enforced. + hc := &http.Client{Timeout: 5 * time.Second} + c.SetHTTPClient(hc) + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect to be enforced when caller provides nil") + } + if c.httpClient.Timeout != 5*time.Second { + t.Errorf("expected caller's timeout preserved, got %v", c.httpClient.Timeout) + } +} + +func TestSetHTTPClient_PreservesCustomCheckRedirect(t *testing.T) { + c := NewClient("token", "https://api.github.com") + called := false + hc := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + called = true + return nil + }, + } + c.SetHTTPClient(hc) + // Invoke the redirect to verify original is preserved + _ = c.httpClient.CheckRedirect(nil, []*http.Request{{}}) + if !called { + t.Fatal("expected custom CheckRedirect to be preserved") + } +} + +func TestAPIError_SafeError(t *testing.T) { + e := &APIError{StatusCode: 403, Body: "some sensitive body content"} + got := e.SafeError() + if got != "HTTP 403" { + t.Errorf("SafeError() = %q, want %q", got, "HTTP 403") + } + // Ensure Error() still includes body + full := e.Error() + if full != "HTTP 403: some sensitive body content" { + t.Errorf("Error() = %q, unexpected", full) + } +} diff --git a/github/files.go b/github/files.go index 9f04941..bb301f0 100644 --- a/github/files.go +++ b/github/files.go @@ -81,7 +81,7 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] if err := json.Unmarshal(body, &entries); err != nil { var single entry if err2 := json.Unmarshal(body, &single); err2 != nil { - return nil, fmt.Errorf("parse contents JSON: as array: %v; as object: %w", err, err2) + return nil, fmt.Errorf("parse contents JSON: as array: %w; as object: %w", err, err2) } // Guard against empty objects ({}) or unexpected shapes that // unmarshal successfully but carry no useful data.