From d1ef1e21e547a2ed2903346b18afeb4e3970d0bb Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 15:16:33 -0700 Subject: [PATCH 01/16] feat(github): implement PRReader + FileReader client (#80) Implement the GitHub API client with PRReader and FileReader interface conformance for both github.com and GitHub Enterprise. New files: - github/client.go: Client struct, NewClient with configurable base URL, HTTP helpers with 429 retry and Retry-After support - github/pr.go: GetPullRequest, GetPullRequestDiff (per-request Accept header), GetPullRequestFiles (paginated, populates Patch field), GetFileContentAtRef (base64 decode), GetCommitStatuses (merges commit statuses + check runs with conclusion mapping) - github/files.go: GetFileContent (delegates to GetFileContentAtRef), ListContents, escapePath, decodeBase64Content helpers Type changes: - vcs/types.go: Add Patch field to ChangedFile struct Tests cover: happy path, 404, 401, 429+retry, malformed response, pagination, binary files, check run conclusion mapping, base64 decoding. Compile-time checks: var _ vcs.PRReader = (*Client)(nil) var _ vcs.FileReader = (*Client)(nil) Exit criteria met: - go test ./github/... passes (all methods) - NewClient with empty baseURL uses https://api.github.com - NewClient with GHE URL targets correctly - GetFileContent delegates to GetFileContentAtRef with empty ref - GetPullRequestFiles paginates and populates Patch field - GetCommitStatuses merges both commit statuses and check-runs --- github/client.go | 177 +++++++++++ github/client_test.go | 224 +++++++++++++ github/conformance_test.go | 13 + github/files.go | 68 ++++ github/files_test.go | 277 ++++++++++++++++ github/pr.go | 233 ++++++++++++++ github/pr_test.go | 637 +++++++++++++++++++++++++++++++++++++ vcs/types.go | 1 + 8 files changed, 1630 insertions(+) create mode 100644 github/client.go create mode 100644 github/client_test.go create mode 100644 github/conformance_test.go create mode 100644 github/files.go create mode 100644 github/files_test.go create mode 100644 github/pr.go create mode 100644 github/pr_test.go diff --git a/github/client.go b/github/client.go new file mode 100644 index 0000000..ab07a24 --- /dev/null +++ b/github/client.go @@ -0,0 +1,177 @@ +// Package github provides a client for the GitHub API. +// It supports pull request operations, file content retrieval, +// and review submission for both github.com and GitHub Enterprise. +package github + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +const defaultBaseURL = "https://api.github.com" + +// APIError represents an HTTP error response from the GitHub API. +// It carries the status code so callers can distinguish between +// different failure modes (e.g. 404 vs 500). +type APIError struct { + StatusCode int + Body string +} + +func (e *APIError) Error() string { + body := e.Body + if len(body) > 200 { + body = body[:200] + "...(truncated)" + } + return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) +} + +// IsNotFound reports whether an error is an API 404 response. +func IsNotFound(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusNotFound + } + return false +} + +// IsUnauthorized reports whether an error is an API 401 response. +func IsUnauthorized(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusUnauthorized + } + return false +} + +func asAPIError(err error) (*APIError, bool) { + if err == nil { + return nil, false + } + var target *APIError + if errors.As(err, &target) { + return target, true + } + return nil, false +} + +// Client interacts with the GitHub API. +// A Client is safe for concurrent use by multiple goroutines. +type Client struct { + baseURL string + token string + http *http.Client + + // RetryBackoff defines the delays between retry attempts for 429 responses. + // RetryBackoff[i] is the delay before attempt i+1 (after attempt i fails). + // If nil, defaults to {1s, 2s}. Set to shorter durations in tests. + RetryBackoff []time.Duration +} + +// NewClient creates a new GitHub API client. +// If baseURL is empty, it defaults to https://api.github.com. +// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3). +func NewClient(token, baseURL string) *Client { + if baseURL == "" { + baseURL = defaultBaseURL + } + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + token: token, + http: &http.Client{Timeout: 30 * time.Second}, + } +} + +// SetHTTPClient sets the underlying HTTP client used for requests. +// This is intended for testing to inject mock transports. +func (c *Client) SetHTTPClient(hc *http.Client) { + c.http = hc +} + +// doRequest performs an HTTP request with retry on 429 rate limit responses. +// It respects the Retry-After header when present. +func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { + const maxAttempts = 3 + backoff := c.RetryBackoff + if backoff == nil { + backoff = []time.Duration{1 * time.Second, 2 * time.Second} + } + + const maxErrorBodyBytes = 64 * 1024 + + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + var delay time.Duration + if attempt-1 < len(backoff) { + delay = backoff[attempt-1] + } + if delay > 0 { + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + } + } + } + + req, err := http.NewRequestWithContext(ctx, method, url, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.token) + if accept != "" { + req.Header.Set("Accept", accept) + } else { + req.Header.Set("Accept", "application/vnd.github+json") + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("do request: %w", err) + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + return body, nil + } + + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + resp.Body.Close() + + lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} + + // Retry on 429 rate limit + if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { + // Check for Retry-After header and override backoff if present + if ra := resp.Header.Get("Retry-After"); ra != "" { + if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { + if attempt < len(backoff) { + backoff[attempt] = time.Duration(seconds) * time.Second + } + } + } + continue + } + + // Don't retry other errors + return nil, lastErr + } + + return nil, lastErr +} + +// doGet is a convenience wrapper for GET requests with the default Accept header. +func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, url, "") +} diff --git a/github/client_test.go b/github/client_test.go new file mode 100644 index 0000000..e3cd121 --- /dev/null +++ b/github/client_test.go @@ -0,0 +1,224 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestNewClient_DefaultBaseURL(t *testing.T) { + c := NewClient("test-token", "") + if c.baseURL != "https://api.github.com" { + t.Errorf("expected default base URL, got %q", c.baseURL) + } +} + +func TestNewClient_CustomBaseURL(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected custom base URL, got %q", c.baseURL) + } +} + +func TestNewClient_TrimsTrailingSlash(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3/") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected trailing slash trimmed, got %q", c.baseURL) + } +} + +func TestDoRequest_SetsAuthHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("my-token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "Bearer my-token" { + t.Errorf("expected Bearer auth, got %q", gotAuth) + } +} + +func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) { + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAccept != "application/vnd.github+json" { + t.Errorf("expected default Accept header, got %q", gotAccept) + } +} + +func TestDoRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{10 * time.Millisecond, 10 * time.Millisecond} + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestDoRequest_429ExhaustsRetries(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error after exhausting retries") + } + apiErr, ok := err.(*APIError) + if !ok { + t.Fatalf("expected *APIError, got %T", err) + } + if apiErr.StatusCode != 429 { + t.Errorf("expected 429, got %d", apiErr.StatusCode) + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestDoRequest_404NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(404) + w.Write([]byte(`{"message":"not found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 404") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 404), got %d", attempts) + } +} + +func TestDoRequest_401NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(401) + w.Write([]byte(`{"message":"bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 401") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 401), got %d", attempts) + } +} + +func TestIsNotFound(t *testing.T) { + err := &APIError{StatusCode: 404, Body: "not found"} + if !IsNotFound(err) { + t.Error("expected IsNotFound to return true for 404") + } + err2 := &APIError{StatusCode: 500, Body: "server error"} + if IsNotFound(err2) { + t.Error("expected IsNotFound to return false for 500") + } +} + +func TestIsUnauthorized(t *testing.T) { + err := &APIError{StatusCode: 401, Body: "bad credentials"} + if !IsUnauthorized(err) { + t.Error("expected IsUnauthorized to return true for 401") + } +} + +func TestDoRequest_429RetryAfterHeader(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + // Use short backoff; Retry-After should override + c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Retry-After: 1 means at least 1 second delay + if elapsed < 900*time.Millisecond { + t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed) + } +} diff --git a/github/conformance_test.go b/github/conformance_test.go new file mode 100644 index 0000000..ca13188 --- /dev/null +++ b/github/conformance_test.go @@ -0,0 +1,13 @@ +package github_test + +import ( + "gitea.weiker.me/rodin/review-bot/github" + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// Compile-time interface conformance assertions. +// These verify github.Client satisfies vcs.PRReader and vcs.FileReader. +var ( + _ vcs.PRReader = (*github.Client)(nil) + _ vcs.FileReader = (*github.Client)(nil) +) diff --git a/github/files.go b/github/files.go new file mode 100644 index 0000000..9c162bf --- /dev/null +++ b/github/files.go @@ -0,0 +1,68 @@ +package github + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// GetFileContent fetches a file from the default branch of a repo. +// Delegates to GetFileContentAtRef with an empty ref. +func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { + return c.GetFileContentAtRef(ctx, owner, repo, path, ref) +} + +// ListContents lists files and directories at a given path in a repo. +// Returns the directory listing from the GitHub contents API. +func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("list contents %s: %w", path, err) + } + var entries []struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + } + if err := json.Unmarshal(body, &entries); err != nil { + return nil, fmt.Errorf("parse contents JSON: %w", err) + } + result := make([]vcs.ContentEntry, len(entries)) + for i, e := range entries { + result[i] = vcs.ContentEntry{ + Name: e.Name, + Path: e.Path, + Type: e.Type, + } + } + return result, nil +} + +// escapePath escapes each segment of a relative file path for use in URLs. +// Slashes are preserved as path separators; other special characters are escaped. +func escapePath(p string) string { + parts := strings.Split(p, "/") + for i, part := range parts { + parts[i] = url.PathEscape(part) + } + return strings.Join(parts, "/") +} + +// decodeBase64Content decodes base64-encoded content from the GitHub contents API. +// GitHub returns base64 content with newlines for formatting, which we strip before decoding. +func decodeBase64Content(encoded string) (string, error) { + // GitHub inserts newlines in base64 content + cleaned := strings.ReplaceAll(encoded, "\n", "") + decoded, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return "", err + } + return string(decoded), nil +} diff --git a/github/files_test.go b/github/files_test.go new file mode 100644 index 0000000..bb76a0b --- /dev/null +++ b/github/files_test.go @@ -0,0 +1,277 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", // "test" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + // Call with empty ref — should not include ref param + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "test" { + t.Errorf("expected 'test', got %q", content) + } + if gotRef != "" { + t.Errorf("expected empty ref, got %q", gotRef) + } +} + +func TestGetFileContent_WithRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotRef != "abc123" { + t.Errorf("expected ref 'abc123', got %q", gotRef) + } +} + +func TestGetFileContent_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContent_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContent_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetFileContent_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestListContents_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/src" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "main.go", "path": "src/main.go", "type": "file"}, + {"name": "lib", "path": "src/lib", "type": "dir"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + entries, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(entries)) + } + if entries[0].Name != "main.go" { + t.Errorf("expected name 'main.go', got %q", entries[0].Name) + } + if entries[0].Path != "src/main.go" { + t.Errorf("expected path 'src/main.go', got %q", entries[0].Path) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } + if entries[1].Name != "lib" { + t.Errorf("expected name 'lib', got %q", entries[1].Name) + } + if entries[1].Type != "dir" { + t.Errorf("expected type 'dir', got %q", entries[1].Type) + } +} + +func TestListContents_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "missing") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestListContents_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestListContents_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "file.go", "path": "file.go", "type": "file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + entries, err := c.ListContents(context.Background(), "owner", "repo", ".") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestListContents_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestDecodeBase64Content(t *testing.T) { + // Test with newlines (GitHub's format) + encoded := "cGFja2FnZSBt\nYWlu" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "package main" { + t.Errorf("expected 'package main', got %q", decoded) + } +} + +func TestDecodeBase64Content_Invalid(t *testing.T) { + _, err := decodeBase64Content("not!!!valid!!!base64") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} diff --git a/github/pr.go b/github/pr.go new file mode 100644 index 0000000..81bec09 --- /dev/null +++ b/github/pr.go @@ -0,0 +1,233 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// pullRequestResponse is the GitHub API response for a pull request. +type pullRequestResponse struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + Head struct { + SHA string `json:"sha"` + Ref string `json:"ref"` + } `json:"head"` + Base struct { + Ref string `json:"ref"` + } `json:"base"` +} + +// changedFileResponse is the GitHub API response for a changed file in a PR. +type changedFileResponse struct { + Filename string `json:"filename"` + Status string `json:"status"` + Patch string `json:"patch"` +} + +// commitStatusResponse is the GitHub combined status API response. +type commitStatusResponse struct { + State string `json:"state"` + Statuses []struct { + Context string `json:"context"` + State string `json:"state"` + Description string `json:"description"` + TargetURL string `json:"target_url"` + } `json:"statuses"` +} + +// checkRunsResponse is the GitHub check runs API response. +type checkRunsResponse struct { + TotalCount int `json:"total_count"` + CheckRuns []struct { + Name string `json:"name"` + Conclusion *string `json:"conclusion"` + Status string `json:"status"` + HTMLURL string `json:"html_url"` + } `json:"check_runs"` +} + +// GetPullRequest fetches PR metadata. +func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR: %w", err) + } + var resp pullRequestResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parse PR JSON: %w", err) + } + return &vcs.PullRequest{ + Number: resp.Number, + Title: resp.Title, + Body: resp.Body, + Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref}, + Base: vcs.BaseRef{Ref: resp.Base.Ref}, + }, nil +} + +// GetPullRequestDiff fetches the unified diff for a PR. +// Uses Accept: application/vnd.github.diff to get raw diff text. +func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff") + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil +} + +// GetPullRequestFiles fetches the list of files changed in a PR. +// Paginates through all pages (100 per page) to collect all files. +func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { + var allFiles []vcs.ChangedFile + page := 1 + + for { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR files page %d: %w", page, err) + } + var files []changedFileResponse + if err := json.Unmarshal(body, &files); err != nil { + return nil, fmt.Errorf("parse PR files JSON: %w", err) + } + if len(files) == 0 { + break + } + for _, f := range files { + allFiles = append(allFiles, vcs.ChangedFile{ + Filename: f.Filename, + Status: f.Status, + Patch: f.Patch, + }) + } + if len(files) < 100 { + break + } + page++ + } + + return allFiles, nil +} + +// GetFileContentAtRef fetches a file at a specific ref from a repo. +// If ref is empty, the query parameter is omitted (uses default branch). +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + if ref != "" { + reqURL += "?ref=" + url.QueryEscape(ref) + } + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch file %s: %w", path, err) + } + var resp struct { + Content string `json:"content"` + Encoding string `json:"encoding"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("parse file content JSON: %w", err) + } + if resp.Encoding != "base64" { + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + } + decoded, err := decodeBase64Content(resp.Content) + if err != nil { + return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + } + return decoded, nil +} + +// GetCommitStatuses fetches both commit statuses and check runs for a SHA, +// merging them into a unified []vcs.CommitStatus slice. +func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { + var result []vcs.CommitStatus + + // Fetch commit statuses + statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha)) + statusBody, err := c.doGet(ctx, statusURL) + if err != nil { + return nil, fmt.Errorf("fetch commit statuses: %w", err) + } + var statusResp commitStatusResponse + if err := json.Unmarshal(statusBody, &statusResp); err != nil { + return nil, fmt.Errorf("parse commit statuses JSON: %w", err) + } + for _, s := range statusResp.Statuses { + result = append(result, vcs.CommitStatus{ + Context: s.Context, + Status: s.State, + Description: s.Description, + TargetURL: s.TargetURL, + }) + } + + // Fetch check runs (paginated) + checkPage := 1 + for { + checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) + checkBody, err := c.doGet(ctx, checkURL) + if err != nil { + return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err) + } + var checkResp checkRunsResponse + if err := json.Unmarshal(checkBody, &checkResp); err != nil { + return nil, fmt.Errorf("parse check runs JSON: %w", err) + } + for _, cr := range checkResp.CheckRuns { + result = append(result, vcs.CommitStatus{ + Context: cr.Name, + Status: mapCheckRunStatus(cr.Conclusion, cr.Status), + Description: derefString(cr.Conclusion), + TargetURL: cr.HTMLURL, + }) + } + if len(checkResp.CheckRuns) < 100 { + break + } + checkPage++ + } + + return result, nil +} + +// mapCheckRunStatus maps a check run conclusion+status to a vcs.CommitStatus status string. +func mapCheckRunStatus(conclusion *string, status string) string { + if conclusion == nil { + // Still running or queued + return "pending" + } + switch *conclusion { + case "success": + return "success" + case "failure", "action_required", "timed_out": + return "failure" + case "cancelled", "skipped", "neutral": + return "success" // non-blocking + case "in_progress", "queued": + return "pending" + default: + return "pending" + } +} + +// derefString safely dereferences a string pointer, returning empty string if nil. +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/github/pr_test.go b/github/pr_test.go new file mode 100644 index 0000000..78366b7 --- /dev/null +++ b/github/pr_test.go @@ -0,0 +1,637 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetPullRequest_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/pulls/42" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 42, + "title": "Test PR", + "body": "Description", + "head": map[string]string{"sha": "abc123", "ref": "feature-branch"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 42 { + t.Errorf("expected number 42, got %d", pr.Number) + } + if pr.Title != "Test PR" { + t.Errorf("expected title 'Test PR', got %q", pr.Title) + } + if pr.Body != "Description" { + t.Errorf("expected body 'Description', got %q", pr.Body) + } + if pr.Head.SHA != "abc123" { + t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA) + } + if pr.Head.Ref != "feature-branch" { + t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref) + } + if pr.Base.Ref != "main" { + t.Errorf("expected base ref 'main', got %q", pr.Base.Ref) + } +} + +func TestGetPullRequest_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsNotFound(err) { + t.Errorf("expected IsNotFound=true, got error: %v", err) + } +} + +func TestGetPullRequest_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } + if !IsUnauthorized(err) { + t.Errorf("expected IsUnauthorized=true, got error: %v", err) + } +} + +func TestGetPullRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 1, + "title": "PR", + "body": "", + "head": map[string]string{"sha": "abc", "ref": "br"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 1 { + t.Errorf("expected number 1, got %d", pr.Number) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetPullRequest_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{invalid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } + if !strings.Contains(err.Error(), "parse PR JSON") { + t.Errorf("expected parse error, got: %v", err) + } +} + +func TestGetPullRequestDiff_HappyPath(t *testing.T) { + expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n" + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte(expectedDiff)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff != expectedDiff { + t.Errorf("unexpected diff: %q", diff) + } + if gotAccept != "application/vnd.github.diff" { + t.Errorf("expected diff Accept header, got %q", gotAccept) + } +} + +func TestGetPullRequestDiff_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestDiff_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetPullRequestFiles_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"}, + {"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Filename != "main.go" { + t.Errorf("expected filename 'main.go', got %q", files[0].Filename) + } + if files[0].Status != "modified" { + t.Errorf("expected status 'modified', got %q", files[0].Status) + } + if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" { + t.Errorf("unexpected patch: %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_Pagination(t *testing.T) { + // Simulate > 100 files requiring pagination + page1Files := make([]map[string]string, 100) + for i := 0; i < 100; i++ { + page1Files[i] = map[string]string{ + "filename": fmt.Sprintf("file%d.go", i), + "status": "modified", + "patch": fmt.Sprintf("patch%d", i), + } + } + page2Files := []map[string]string{ + {"filename": "file100.go", "status": "added", "patch": "patch100"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + if page == "" || page == "1" { + json.NewEncoder(w).Encode(page1Files) + } else { + json.NewEncoder(w).Encode(page2Files) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 101 { + t.Errorf("expected 101 files (paginated), got %d", len(files)) + } + if files[100].Filename != "file100.go" { + t.Errorf("expected last file 'file100.go', got %q", files[100].Filename) + } + if files[100].Patch != "patch100" { + t.Errorf("expected last patch 'patch100', got %q", files[100].Patch) + } +} + +func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Binary files have no patch field in GitHub response + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "image.png", "status": "added"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + if files[0].Patch != "" { + t.Errorf("expected empty patch for binary file, got %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestFiles_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.URL.Query().Get("ref") != "abc123" { + t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "cGFja2FnZSBtYWlu", // "package main" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "package main" { + t.Errorf("expected 'package main', got %q", content) + } +} + +func TestGetFileContentAtRef_EmptyRef(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("ref") != "" { + t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "aGVsbG8=", // "hello" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "hello" { + t.Errorf("expected 'hello', got %q", content) + } +} + +func TestGetFileContentAtRef_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContentAtRef_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContentAtRef_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not valid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", // "ok" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetCommitStatuses_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/status"): + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []map[string]string{ + { + "context": "ci/build", + "state": "success", + "description": "Build passed", + "target_url": "https://ci.example.com/1", + }, + }, + }) + case strings.Contains(r.URL.Path, "/check-runs"): + conclusion := "success" + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "lint", + "conclusion": &conclusion, + "status": "completed", + "html_url": "https://github.com/check/1", + }, + }, + }) + default: + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(404) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 2 { + t.Fatalf("expected 2 statuses, got %d", len(statuses)) + } + // First should be from commit statuses + if statuses[0].Context != "ci/build" { + t.Errorf("expected context 'ci/build', got %q", statuses[0].Context) + } + if statuses[0].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[0].Status) + } + // Second should be from check runs + if statuses[1].Context != "lint" { + t.Errorf("expected context 'lint', got %q", statuses[1].Context) + } + if statuses[1].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[1].Status) + } +} + +func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { + tests := []struct { + conclusion *string + status string + want string + }{ + {strPtr("success"), "completed", "success"}, + {strPtr("failure"), "completed", "failure"}, + {strPtr("action_required"), "completed", "failure"}, + {strPtr("timed_out"), "completed", "failure"}, + {strPtr("cancelled"), "completed", "success"}, + {strPtr("skipped"), "completed", "success"}, + {strPtr("neutral"), "completed", "success"}, + {nil, "in_progress", "pending"}, + {nil, "queued", "pending"}, + } + + for _, tt := range tests { + name := "nil" + if tt.conclusion != nil { + name = *tt.conclusion + } + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/status") { + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []interface{}{}, + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "check", + "conclusion": tt.conclusion, + "status": tt.status, + "html_url": "https://github.com/check/1", + }, + }, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + if statuses[0].Status != tt.want { + t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status) + } + }) + } +} + +func TestGetCommitStatuses_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetCommitStatuses_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetCommitStatuses_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func strPtr(s string) *string { + return &s +} diff --git a/vcs/types.go b/vcs/types.go index de904f3..608ad27 100644 --- a/vcs/types.go +++ b/vcs/types.go @@ -44,6 +44,7 @@ type PullRequest struct { type ChangedFile struct { Filename string `json:"filename"` Status string `json:"status"` + Patch string `json:"patch"` } // ContentEntry represents a file or directory entry from the contents API. -- 2.47.3 From 5b43afc6d43f890028049589c8889c13d911e36f Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 15:43:45 -0700 Subject: [PATCH 02/16] fix: address review feedback on PR #93 - Fix Retry-After slice mutation: copy c.RetryBackoff before modifying to prevent permanent mutation of the shared slice (sonnet#1, security#1) - Cap Retry-After to 120s maximum to prevent excessive sleeps (security#2) - Guard auth header: only set Authorization when token is non-empty (gpt#2) - Fix GetFileContent doc comment to match actual behavior (sonnet#3, gpt#1) - Remove dead 'in_progress/queued' case in mapCheckRunStatus (sonnet#4) - Add testing.Short() guard to slow retry test (sonnet#5) - Reject dot-segments in escapePath to prevent path traversal (security#3) - Add regression tests for non-mutation and escapePath safety --- github/client.go | 21 ++++++++++++++++----- github/client_test.go | 39 +++++++++++++++++++++++++++++++++++++++ github/files.go | 15 ++++++++++----- github/files_test.go | 20 ++++++++++++++++++++ github/pr.go | 4 +--- 5 files changed, 86 insertions(+), 13 deletions(-) diff --git a/github/client.go b/github/client.go index ab07a24..bd6f7e1 100644 --- a/github/client.go +++ b/github/client.go @@ -93,11 +93,16 @@ func (c *Client) SetHTTPClient(hc *http.Client) { } // doRequest performs an HTTP request with retry on 429 rate limit responses. -// It respects the Retry-After header when present. +// It respects the Retry-After header when present (capped at maxRetryAfter). func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { const maxAttempts = 3 - backoff := c.RetryBackoff - if backoff == nil { + const maxRetryAfter = 120 * time.Second + + var backoff []time.Duration + if c.RetryBackoff != nil { + backoff = make([]time.Duration, len(c.RetryBackoff)) + copy(backoff, c.RetryBackoff) + } else { backoff = []time.Duration{1 * time.Second, 2 * time.Second} } @@ -125,7 +130,9 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin if err != nil { return nil, fmt.Errorf("create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+c.token) + if c.token != "" { + req.Header.Set("Authorization", "Bearer "+c.token) + } if accept != "" { req.Header.Set("Accept", accept) } else { @@ -156,8 +163,12 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin // Check for Retry-After header and override backoff if present if ra := resp.Header.Get("Retry-After"); ra != "" { if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { + delay := time.Duration(seconds) * time.Second + if delay > maxRetryAfter { + delay = maxRetryAfter + } if attempt < len(backoff) { - backoff[attempt] = time.Duration(seconds) * time.Second + backoff[attempt] = delay } } } diff --git a/github/client_test.go b/github/client_test.go index e3cd121..e00e534 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -185,6 +185,9 @@ func TestIsUnauthorized(t *testing.T) { } func TestDoRequest_429RetryAfterHeader(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } attempts := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ @@ -222,3 +225,39 @@ func TestDoRequest_429RetryAfterHeader(t *testing.T) { t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed) } } + +func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the original RetryBackoff slice was not mutated + if c.RetryBackoff[0] != 1*time.Millisecond { + t.Errorf("RetryBackoff[0] was mutated: got %v, want 1ms", c.RetryBackoff[0]) + } + if c.RetryBackoff[1] != 1*time.Millisecond { + t.Errorf("RetryBackoff[1] was mutated: got %v, want 1ms", c.RetryBackoff[1]) + } +} diff --git a/github/files.go b/github/files.go index 9c162bf..a385623 100644 --- a/github/files.go +++ b/github/files.go @@ -11,8 +11,8 @@ import ( "gitea.weiker.me/rodin/review-bot/vcs" ) -// GetFileContent fetches a file from the default branch of a repo. -// Delegates to GetFileContentAtRef with an empty ref. +// GetFileContent fetches a file from a repo at the given ref. +// Delegates to GetFileContentAtRef with the provided ref. func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { return c.GetFileContentAtRef(ctx, owner, repo, path, ref) } @@ -47,12 +47,17 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] // escapePath escapes each segment of a relative file path for use in URLs. // Slashes are preserved as path separators; other special characters are escaped. +// Dot-segments ("." and "..") are removed to prevent path traversal. func escapePath(p string) string { parts := strings.Split(p, "/") - for i, part := range parts { - parts[i] = url.PathEscape(part) + var clean []string + for _, part := range parts { + if part == "." || part == ".." || part == "" { + continue + } + clean = append(clean, url.PathEscape(part)) } - return strings.Join(parts, "/") + return strings.Join(clean, "/") } // decodeBase64Content decodes base64-encoded content from the GitHub contents API. diff --git a/github/files_test.go b/github/files_test.go index bb76a0b..0c077d6 100644 --- a/github/files_test.go +++ b/github/files_test.go @@ -275,3 +275,23 @@ func TestDecodeBase64Content_Invalid(t *testing.T) { t.Fatal("expected error for invalid base64") } } + +func TestEscapePath_RejectsDotSegments(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"src/main.go", "src/main.go"}, + {"../etc/passwd", "etc/passwd"}, + {"./src/../main.go", "src/main.go"}, + {"a/b/c", "a/b/c"}, + {"file with spaces.go", "file%20with%20spaces.go"}, + {"a/./b/../c", "a/b/c"}, + } + for _, tt := range tests { + got := escapePath(tt.input) + if got != tt.want { + t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} diff --git a/github/pr.go b/github/pr.go index 81bec09..0d1046f 100644 --- a/github/pr.go +++ b/github/pr.go @@ -205,7 +205,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) } // mapCheckRunStatus maps a check run conclusion+status to a vcs.CommitStatus status string. -func mapCheckRunStatus(conclusion *string, status string) string { +func mapCheckRunStatus(conclusion *string, _ string) string { if conclusion == nil { // Still running or queued return "pending" @@ -217,8 +217,6 @@ func mapCheckRunStatus(conclusion *string, status string) string { return "failure" case "cancelled", "skipped", "neutral": return "success" // non-blocking - case "in_progress", "queued": - return "pending" default: return "pending" } -- 2.47.3 From 75f65fbf5d59e64f55cb27b60697a466c6890532 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:00:09 -0700 Subject: [PATCH 03/16] fix: address MINOR review findings on PR #93 (round 2) - Add User-Agent header to all requests (gpt-review-bot) - Limit successful response body to 10 MiB via io.LimitReader (security-review-bot) - Add CheckRedirect to strip Authorization on cross-host redirects (security-review-bot) - Fix decodeBase64Content to strip both \r and \n (gpt-review-bot) - Document that transport errors are not retried (sonnet-review-bot) - Update package doc to reflect current scope (no review submission yet) - Add tests for User-Agent, empty-token auth skip, CRLF base64, CheckRedirect --- github/client.go | 26 +++++++++++++++++---- github/client_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++ github/files.go | 4 ++-- github/files_test.go | 12 ++++++++++ 4 files changed, 89 insertions(+), 6 deletions(-) diff --git a/github/client.go b/github/client.go index bd6f7e1..e0f5dfc 100644 --- a/github/client.go +++ b/github/client.go @@ -1,6 +1,6 @@ // Package github provides a client for the GitHub API. -// It supports pull request operations, file content retrieval, -// and review submission for both github.com and GitHub Enterprise. +// It supports pull request operations, file content retrieval, CI status checks, +// and directory listing for both github.com and GitHub Enterprise. package github import ( @@ -15,6 +15,10 @@ import ( ) const defaultBaseURL = "https://api.github.com" +const userAgent = "review-bot/1.0" + +// maxResponseBytes limits successful response body reads to 10 MiB. +const maxResponseBytes = 10 * 1024 * 1024 // APIError represents an HTTP error response from the GitHub API. // It carries the status code so callers can distinguish between @@ -82,7 +86,19 @@ func NewClient(token, baseURL string) *Client { return &Client{ baseURL: strings.TrimRight(baseURL, "/"), token: token, - http: &http.Client{Timeout: 30 * time.Second}, + http: &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Prevent forwarding Authorization header to different hosts on redirect. + if len(via) > 0 && req.URL.Host != via[0].URL.Host { + req.Header.Del("Authorization") + } + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + return nil + }, + }, } } @@ -94,6 +110,7 @@ func (c *Client) SetHTTPClient(hc *http.Client) { // doRequest performs an HTTP request with retry on 429 rate limit responses. // It respects the Retry-After header when present (capped at maxRetryAfter). +// Transport errors (network failures, context cancellation) are not retried. func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { const maxAttempts = 3 const maxRetryAfter = 120 * time.Second @@ -133,6 +150,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin if c.token != "" { req.Header.Set("Authorization", "Bearer "+c.token) } + req.Header.Set("User-Agent", userAgent) if accept != "" { req.Header.Set("Accept", accept) } else { @@ -145,7 +163,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin } if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) resp.Body.Close() if err != nil { return nil, fmt.Errorf("read response body: %w", err) diff --git a/github/client_test.go b/github/client_test.go index e00e534..794df2f 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -261,3 +261,56 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { t.Errorf("RetryBackoff[1] was mutated: got %v, want 1ms", c.RetryBackoff[1]) } } + +func TestDoRequest_SetsUserAgentHeader(t *testing.T) { + var gotUA string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotUA != "review-bot/1.0" { + t.Errorf("expected User-Agent 'review-bot/1.0', got %q", gotUA) + } +} + +func TestDoRequest_LimitsResponseBody(t *testing.T) { + // Verify that responses are read through a limit reader. + // We can't easily test the 10 MiB limit without OOM risk, + // but we verify the constant is set correctly. + if maxResponseBytes != 10*1024*1024 { + t.Errorf("expected maxResponseBytes = 10 MiB, got %d", maxResponseBytes) + } +} + +func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("", srv.URL) // empty token + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "" { + t.Errorf("expected no Authorization header with empty token, got %q", gotAuth) + } +} + +func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) { + // Verify the CheckRedirect function is configured + c := NewClient("secret-token", "https://api.github.com") + if c.http.CheckRedirect == nil { + t.Fatal("expected CheckRedirect to be set") + } +} diff --git a/github/files.go b/github/files.go index a385623..df2d6fc 100644 --- a/github/files.go +++ b/github/files.go @@ -61,10 +61,10 @@ func escapePath(p string) string { } // decodeBase64Content decodes base64-encoded content from the GitHub contents API. -// GitHub returns base64 content with newlines for formatting, which we strip before decoding. +// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. func decodeBase64Content(encoded string) (string, error) { // GitHub inserts newlines in base64 content - cleaned := strings.ReplaceAll(encoded, "\n", "") + cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) decoded, err := base64.StdEncoding.DecodeString(cleaned) if err != nil { return "", err diff --git a/github/files_test.go b/github/files_test.go index 0c077d6..3c6d889 100644 --- a/github/files_test.go +++ b/github/files_test.go @@ -295,3 +295,15 @@ func TestEscapePath_RejectsDotSegments(t *testing.T) { } } } + +func TestDecodeBase64Content_CRLF(t *testing.T) { + // Base64 of "hello world" with CRLF line breaks inserted + encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ=" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "hello world" { + t.Errorf("expected 'hello world', got %q", decoded) + } +} -- 2.47.3 From ae91c8aef53911ad8e178fd379cd58d84a6e5689 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:11:58 -0700 Subject: [PATCH 04/16] fix: address review findings from rounds 2834-2838 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Unexport RetryBackoff, add SetRetryBackoff method (#17286) - Rename http field to httpClient to avoid shadowing (#17289) - Group const blocks into single declaration (#17291) - Fix CheckRedirect to compare against previous hop, not first (#17302) - Strip auth header on protocol downgrade https→http (#17297) - Add maxPages safeguard to pagination loops (#17299, #17300) - Document mapCheckRunStatus unused second parameter (#17287, #17303) --- github/client.go | 47 ++++++++++++++++++++++++++----------------- github/client_test.go | 20 +++++++++--------- github/files_test.go | 4 ++-- github/pr.go | 17 +++++++++------- github/pr_test.go | 4 ++-- 5 files changed, 52 insertions(+), 40 deletions(-) diff --git a/github/client.go b/github/client.go index e0f5dfc..293cc17 100644 --- a/github/client.go +++ b/github/client.go @@ -14,11 +14,13 @@ import ( "time" ) -const defaultBaseURL = "https://api.github.com" -const userAgent = "review-bot/1.0" +const ( + defaultBaseURL = "https://api.github.com" + userAgent = "review-bot/1.0" -// maxResponseBytes limits successful response body reads to 10 MiB. -const maxResponseBytes = 10 * 1024 * 1024 + // maxResponseBytes limits successful response body reads to 10 MiB. + maxResponseBytes = 10 * 1024 * 1024 +) // APIError represents an HTTP error response from the GitHub API. // It carries the status code so callers can distinguish between @@ -68,12 +70,12 @@ func asAPIError(err error) (*APIError, bool) { type Client struct { baseURL string token string - http *http.Client + httpClient *http.Client - // RetryBackoff defines the delays between retry attempts for 429 responses. - // RetryBackoff[i] is the delay before attempt i+1 (after attempt i fails). - // If nil, defaults to {1s, 2s}. Set to shorter durations in tests. - RetryBackoff []time.Duration + // retryBackoff defines the delays between retry attempts for 429 responses. + // retryBackoff[i] is the delay before attempt i+1 (after attempt i fails). + // If nil, defaults to {1s, 2s}. Set to shorter durations in tests via SetRetryBackoff. + retryBackoff []time.Duration } // NewClient creates a new GitHub API client. @@ -86,16 +88,17 @@ func NewClient(token, baseURL string) *Client { return &Client{ baseURL: strings.TrimRight(baseURL, "/"), token: token, - http: &http.Client{ + httpClient: &http.Client{ Timeout: 30 * time.Second, CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Prevent forwarding Authorization header to different hosts on redirect. - if len(via) > 0 && req.URL.Host != via[0].URL.Host { - req.Header.Del("Authorization") - } if len(via) >= 10 { return fmt.Errorf("stopped after 10 redirects") } + // Strip Authorization on cross-host redirect or protocol downgrade (https→http). + prev := via[len(via)-1] + if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { + req.Header.Del("Authorization") + } return nil }, }, @@ -105,7 +108,13 @@ func NewClient(token, baseURL string) *Client { // SetHTTPClient sets the underlying HTTP client used for requests. // This is intended for testing to inject mock transports. func (c *Client) SetHTTPClient(hc *http.Client) { - c.http = hc + c.httpClient = hc +} + +// SetRetryBackoff configures the retry backoff durations for testing. +// In production the default {1s, 2s} applies. +func (c *Client) SetRetryBackoff(d []time.Duration) { + c.retryBackoff = d } // doRequest performs an HTTP request with retry on 429 rate limit responses. @@ -116,9 +125,9 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin const maxRetryAfter = 120 * time.Second var backoff []time.Duration - if c.RetryBackoff != nil { - backoff = make([]time.Duration, len(c.RetryBackoff)) - copy(backoff, c.RetryBackoff) + if c.retryBackoff != nil { + backoff = make([]time.Duration, len(c.retryBackoff)) + copy(backoff, c.retryBackoff) } else { backoff = []time.Duration{1 * time.Second, 2 * time.Second} } @@ -157,7 +166,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin req.Header.Set("Accept", "application/vnd.github+json") } - resp, err := c.http.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("do request: %w", err) } diff --git a/github/client_test.go b/github/client_test.go index 794df2f..d59edc7 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -81,7 +81,7 @@ func TestDoRequest_429Retry(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{10 * time.Millisecond, 10 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond}) body, err := c.doGet(context.Background(), srv.URL+"/test") if err != nil { @@ -106,7 +106,7 @@ func TestDoRequest_429ExhaustsRetries(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) _, err := c.doGet(context.Background(), srv.URL+"/test") if err == nil { @@ -205,7 +205,7 @@ func TestDoRequest_429RetryAfterHeader(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) // Use short backoff; Retry-After should override - c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) start := time.Now() body, err := c.doGet(context.Background(), srv.URL+"/test") @@ -246,19 +246,19 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) _, err := c.doGet(context.Background(), srv.URL+"/test") if err != nil { t.Fatalf("unexpected error: %v", err) } - // Verify the original RetryBackoff slice was not mutated - if c.RetryBackoff[0] != 1*time.Millisecond { - t.Errorf("RetryBackoff[0] was mutated: got %v, want 1ms", c.RetryBackoff[0]) + // Verify the original retryBackoff slice was not mutated + if c.retryBackoff[0] != 1*time.Millisecond { + t.Errorf("retryBackoff[0] was mutated: got %v, want 1ms", c.retryBackoff[0]) } - if c.RetryBackoff[1] != 1*time.Millisecond { - t.Errorf("RetryBackoff[1] was mutated: got %v, want 1ms", c.RetryBackoff[1]) + if c.retryBackoff[1] != 1*time.Millisecond { + t.Errorf("retryBackoff[1] was mutated: got %v, want 1ms", c.retryBackoff[1]) } } @@ -310,7 +310,7 @@ func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) { func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) { // Verify the CheckRedirect function is configured c := NewClient("secret-token", "https://api.github.com") - if c.http.CheckRedirect == nil { + if c.httpClient.CheckRedirect == nil { t.Fatal("expected CheckRedirect to be set") } } diff --git a/github/files_test.go b/github/files_test.go index 3c6d889..2c8d80f 100644 --- a/github/files_test.go +++ b/github/files_test.go @@ -109,7 +109,7 @@ func TestGetFileContent_429Retry(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") if err != nil { @@ -227,7 +227,7 @@ func TestListContents_429Retry(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) entries, err := c.ListContents(context.Background(), "owner", "repo", ".") if err != nil { diff --git a/github/pr.go b/github/pr.go index 0d1046f..c26f061 100644 --- a/github/pr.go +++ b/github/pr.go @@ -84,13 +84,16 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, num return string(body), nil } +// maxPages is the upper bound on pagination loops to prevent unbounded iteration +// in case the server returns a full page indefinitely. +const maxPages = 100 + // GetPullRequestFiles fetches the list of files changed in a PR. // Paginates through all pages (100 per page) to collect all files. func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { var allFiles []vcs.ChangedFile - page := 1 - for { + for page := 1; page <= maxPages; page++ { reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) body, err := c.doGet(ctx, reqURL) @@ -114,7 +117,6 @@ func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, nu if len(files) < 100 { break } - page++ } return allFiles, nil @@ -175,8 +177,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) } // Fetch check runs (paginated) - checkPage := 1 - for { + for checkPage := 1; checkPage <= maxPages; checkPage++ { checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) checkBody, err := c.doGet(ctx, checkURL) @@ -198,13 +199,15 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) if len(checkResp.CheckRuns) < 100 { break } - checkPage++ } return result, nil } -// mapCheckRunStatus maps a check run conclusion+status to a vcs.CommitStatus status string. +// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string. +// The second parameter (check run status field, e.g. "completed", "in_progress") is +// unused because conclusion alone determines the mapped state: nil conclusion means +// the run is still in progress (pending), regardless of the status field value. func mapCheckRunStatus(conclusion *string, _ string) string { if conclusion == nil { // Still running or queued diff --git a/github/pr_test.go b/github/pr_test.go index 78366b7..7dbd2ad 100644 --- a/github/pr_test.go +++ b/github/pr_test.go @@ -112,7 +112,7 @@ func TestGetPullRequest_429Retry(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) if err != nil { @@ -447,7 +447,7 @@ func TestGetFileContentAtRef_429Retry(t *testing.T) { c := NewClient("token", srv.URL) c.SetHTTPClient(srv.Client()) - c.RetryBackoff = []time.Duration{1 * time.Millisecond} + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") if err != nil { -- 2.47.3 From c10bb7211718bd895be4b7f42e37041405d751f7 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:25:32 -0700 Subject: [PATCH 05/16] fix: address self-review NIT findings on PR #93 - Add timer.Stop() on happy path in retry loop (idiomatic) - Add concurrency caveat to Client doc comment for SetHTTPClient/SetRetryBackoff - Add explicit 'stale'/'waiting' cases to mapCheckRunStatus --- github/client.go | 4 +++- github/pr.go | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/github/client.go b/github/client.go index 293cc17..3916425 100644 --- a/github/client.go +++ b/github/client.go @@ -66,7 +66,8 @@ func asAPIError(err error) (*APIError, bool) { } // Client interacts with the GitHub API. -// A Client is safe for concurrent use by multiple goroutines. +// A Client is safe for concurrent use by multiple goroutines; +// however, SetHTTPClient and SetRetryBackoff must not be called concurrently with requests. type Client struct { baseURL string token string @@ -145,6 +146,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin timer := time.NewTimer(delay) select { case <-timer.C: + timer.Stop() case <-ctx.Done(): timer.Stop() return nil, ctx.Err() diff --git a/github/pr.go b/github/pr.go index c26f061..ec330ae 100644 --- a/github/pr.go +++ b/github/pr.go @@ -220,6 +220,8 @@ func mapCheckRunStatus(conclusion *string, _ string) string { return "failure" case "cancelled", "skipped", "neutral": return "success" // non-blocking + case "stale", "waiting": + return "pending" default: return "pending" } -- 2.47.3 From 1bc3f206ba49bdaef28306db64a01b16642cafe0 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:39:01 -0700 Subject: [PATCH 06/16] fix: address review findings from rounds 2843-2846 - Remove redundant timer.Stop() after timer fires (Sonnet #1, GPT #2) - Remove unused TotalCount field from checkRunsResponse (Sonnet #2) - Improve escapePath doc comment to explain deliberate silent stripping (Sonnet #3) - Fix ListContents to handle both array (directory) and object (single file) responses from GitHub Contents API (GPT #3) - Add HTTPS enforcement: refuse to send credentials over non-HTTPS URLs unless AllowInsecureHTTP() option is passed (Security #1) - Replace constant-value test with actual behavior test for response body limiting (Sonnet #6) - Run gofmt for consistent formatting (Sonnet #4) - Add tests for HTTPS enforcement and ListContents single-file handling --- github/client.go | 42 +++++++++++++++--- github/client_test.go | 100 +++++++++++++++++++++++++++++++++++------- github/files.go | 23 ++++++++-- github/files_test.go | 47 +++++++++++++++----- github/pr.go | 3 +- github/pr_test.go | 48 ++++++++++---------- 6 files changed, 201 insertions(+), 62 deletions(-) diff --git a/github/client.go b/github/client.go index 3916425..8dd157b 100644 --- a/github/client.go +++ b/github/client.go @@ -65,13 +65,30 @@ func asAPIError(err error) (*APIError, bool) { return nil, false } +// clientConfig holds optional configuration for NewClient. +type clientConfig struct { + allowInsecureHTTP bool +} + +// ClientOption configures optional behavior of NewClient. +type ClientOption func(*clientConfig) + +// AllowInsecureHTTP permits the client to use HTTP (non-TLS) base URLs. +// This should only be used for trusted internal deployments or testing. +func AllowInsecureHTTP() ClientOption { + return func(c *clientConfig) { + c.allowInsecureHTTP = true + } +} + // Client interacts with the GitHub API. // A Client is safe for concurrent use by multiple goroutines; // however, SetHTTPClient and SetRetryBackoff must not be called concurrently with requests. type Client struct { - baseURL string - token string - httpClient *http.Client + baseURL string + token string + allowInsecureHTTP bool + httpClient *http.Client // retryBackoff defines the delays between retry attempts for 429 responses. // retryBackoff[i] is the delay before attempt i+1 (after attempt i fails). @@ -82,13 +99,20 @@ type Client struct { // NewClient creates a new GitHub API client. // If baseURL is empty, it defaults to https://api.github.com. // For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3). -func NewClient(token, baseURL string) *Client { +// The baseURL must use HTTPS; pass AllowInsecureHTTP() as an option to permit HTTP +// for trusted internal deployments (e.g. local testing). +func NewClient(token, baseURL string, opts ...ClientOption) *Client { if baseURL == "" { baseURL = defaultBaseURL } + cfg := clientConfig{} + for _, o := range opts { + o(&cfg) + } return &Client{ - baseURL: strings.TrimRight(baseURL, "/"), - token: token, + baseURL: strings.TrimRight(baseURL, "/"), + allowInsecureHTTP: cfg.allowInsecureHTTP, + token: token, httpClient: &http.Client{ Timeout: 30 * time.Second, CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -146,7 +170,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin timer := time.NewTimer(delay) select { case <-timer.C: - timer.Stop() + // Timer already fired; Stop() is a no-op here. case <-ctx.Done(): timer.Stop() return nil, ctx.Err() @@ -159,6 +183,10 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin return nil, fmt.Errorf("create request: %w", err) } if c.token != "" { + // Refuse to send credentials over plaintext unless explicitly allowed. + if !c.allowInsecureHTTP && req.URL.Scheme != "https" { + return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", req.URL.Host) + } req.Header.Set("Authorization", "Bearer "+c.token) } req.Header.Set("User-Agent", userAgent) diff --git a/github/client_test.go b/github/client_test.go index d59edc7..ea03ea2 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "time" ) @@ -38,7 +39,7 @@ func TestDoRequest_SetsAuthHeader(t *testing.T) { })) defer srv.Close() - c := NewClient("my-token", srv.URL) + c := NewClient("my-token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, _ = c.doGet(context.Background(), srv.URL+"/test") @@ -56,7 +57,7 @@ func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, _ = c.doGet(context.Background(), srv.URL+"/test") @@ -79,7 +80,7 @@ func TestDoRequest_429Retry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond}) @@ -104,7 +105,7 @@ func TestDoRequest_429ExhaustsRetries(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) @@ -133,7 +134,7 @@ func TestDoRequest_404NoRetry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.doGet(context.Background(), srv.URL+"/test") @@ -154,7 +155,7 @@ func TestDoRequest_401NoRetry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.doGet(context.Background(), srv.URL+"/test") @@ -202,7 +203,7 @@ func TestDoRequest_429RetryAfterHeader(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) // Use short backoff; Retry-After should override c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) @@ -244,7 +245,7 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) @@ -271,7 +272,7 @@ func TestDoRequest_SetsUserAgentHeader(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, _ = c.doGet(context.Background(), srv.URL+"/test") @@ -281,11 +282,24 @@ func TestDoRequest_SetsUserAgentHeader(t *testing.T) { } func TestDoRequest_LimitsResponseBody(t *testing.T) { - // Verify that responses are read through a limit reader. - // We can't easily test the 10 MiB limit without OOM risk, - // but we verify the constant is set correctly. - if maxResponseBytes != 10*1024*1024 { - t.Errorf("expected maxResponseBytes = 10 MiB, got %d", maxResponseBytes) + // Verify that response body reading is actually bounded by maxResponseBytes. + // Use a small custom limit to avoid allocating 10 MiB in tests. + bigBody := strings.Repeat("x", maxResponseBytes+1024) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(bigBody)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // LimitReader should cap the body at maxResponseBytes + if len(body) > maxResponseBytes { + t.Errorf("expected body <= %d bytes, got %d", maxResponseBytes, len(body)) } } @@ -298,7 +312,7 @@ func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) { })) defer srv.Close() - c := NewClient("", srv.URL) // empty token + c := NewClient("", srv.URL, AllowInsecureHTTP()) // empty token c.SetHTTPClient(srv.Client()) _, _ = c.doGet(context.Background(), srv.URL+"/test") @@ -314,3 +328,59 @@ func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) { t.Fatal("expected CheckRedirect to be set") } } + +func TestDoRequest_RejectsHTTPWithToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + // Without AllowInsecureHTTP, should refuse to send token over HTTP + c := NewClient("secret-token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error when sending token over HTTP") + } + if !strings.Contains(err.Error(), "refusing to send credentials") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDoRequest_AllowsHTTPWithoutToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + // Without token, HTTP should be fine (no credentials to leak) + c := NewClient("", srv.URL) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("secret-token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } +} diff --git a/github/files.go b/github/files.go index df2d6fc..442cc63 100644 --- a/github/files.go +++ b/github/files.go @@ -19,6 +19,9 @@ func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref stri // ListContents lists files and directories at a given path in a repo. // Returns the directory listing from the GitHub contents API. +// If the path points to a single file (not a directory), the API returns +// a JSON object instead of an array; this is handled by returning a +// single-element slice. func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) @@ -26,14 +29,24 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] if err != nil { return nil, fmt.Errorf("list contents %s: %w", path, err) } - var entries []struct { + + type entry struct { Name string `json:"name"` Path string `json:"path"` Type string `json:"type"` } + + // The GitHub contents API returns an array for directories and an object + // for single files. Try array first (common case), then fall back to object. + var entries []entry if err := json.Unmarshal(body, &entries); err != nil { - return nil, fmt.Errorf("parse contents JSON: %w", err) + var single entry + if err2 := json.Unmarshal(body, &single); err2 != nil { + return nil, fmt.Errorf("parse contents JSON: %w", err) + } + entries = []entry{single} } + result := make([]vcs.ContentEntry, len(entries)) for i, e := range entries { result[i] = vcs.ContentEntry{ @@ -47,7 +60,11 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] // escapePath escapes each segment of a relative file path for use in URLs. // Slashes are preserved as path separators; other special characters are escaped. -// Dot-segments ("." and "..") are removed to prevent path traversal. +// Dot-segments ("." and "..") are silently removed to prevent path traversal. +// This is intentional: callers may receive a different path than requested without +// error. The function is package-private, and all callers (GetFileContentAtRef, +// ListContents) already handle missing-file errors from the API if the cleaned +// path doesn't match what the caller intended. func escapePath(p string) string { parts := strings.Split(p, "/") var clean []string diff --git a/github/files_test.go b/github/files_test.go index 2c8d80f..eda64a8 100644 --- a/github/files_test.go +++ b/github/files_test.go @@ -20,7 +20,7 @@ func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) // Call with empty ref — should not include ref param @@ -47,7 +47,7 @@ func TestGetFileContent_WithRef(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123") @@ -66,7 +66,7 @@ func TestGetFileContent_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "") @@ -82,7 +82,7 @@ func TestGetFileContent_401(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") @@ -107,7 +107,7 @@ func TestGetFileContent_429Retry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) @@ -130,7 +130,7 @@ func TestGetFileContent_MalformedJSON(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") @@ -151,7 +151,7 @@ func TestListContents_HappyPath(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) entries, err := c.ListContents(context.Background(), "owner", "repo", "src") @@ -185,7 +185,7 @@ func TestListContents_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.ListContents(context.Background(), "owner", "repo", "missing") @@ -201,7 +201,7 @@ func TestListContents_401(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.ListContents(context.Background(), "owner", "repo", "src") @@ -225,7 +225,7 @@ func TestListContents_429Retry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) @@ -248,7 +248,7 @@ func TestListContents_MalformedJSON(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.ListContents(context.Background(), "owner", "repo", "src") @@ -307,3 +307,28 @@ func TestDecodeBase64Content_CRLF(t *testing.T) { t.Errorf("expected 'hello world', got %q", decoded) } } + +func TestListContents_SingleFile(t *testing.T) { + // GitHub Contents API returns a JSON object (not array) for single-file paths + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + entries, err := c.ListContents(context.Background(), "owner", "repo", "README.md") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].Name != "README.md" { + t.Errorf("expected name 'README.md', got %q", entries[0].Name) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } +} diff --git a/github/pr.go b/github/pr.go index ec330ae..1bb428a 100644 --- a/github/pr.go +++ b/github/pr.go @@ -44,8 +44,7 @@ type commitStatusResponse struct { // checkRunsResponse is the GitHub check runs API response. type checkRunsResponse struct { - TotalCount int `json:"total_count"` - CheckRuns []struct { + CheckRuns []struct { Name string `json:"name"` Conclusion *string `json:"conclusion"` Status string `json:"status"` diff --git a/github/pr_test.go b/github/pr_test.go index 7dbd2ad..405cc6f 100644 --- a/github/pr_test.go +++ b/github/pr_test.go @@ -26,7 +26,7 @@ func TestGetPullRequest_HappyPath(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42) @@ -60,7 +60,7 @@ func TestGetPullRequest_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequest(context.Background(), "owner", "repo", 999) @@ -79,7 +79,7 @@ func TestGetPullRequest_401(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) @@ -110,7 +110,7 @@ func TestGetPullRequest_429Retry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) @@ -133,7 +133,7 @@ func TestGetPullRequest_MalformedJSON(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) @@ -155,7 +155,7 @@ func TestGetPullRequestDiff_HappyPath(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42) @@ -177,7 +177,7 @@ func TestGetPullRequestDiff_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999) @@ -193,7 +193,7 @@ func TestGetPullRequestDiff_401(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1) @@ -211,7 +211,7 @@ func TestGetPullRequestFiles_HappyPath(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) @@ -256,7 +256,7 @@ func TestGetPullRequestFiles_Pagination(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) @@ -283,7 +283,7 @@ func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) @@ -305,7 +305,7 @@ func TestGetPullRequestFiles_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999) @@ -321,7 +321,7 @@ func TestGetPullRequestFiles_MalformedJSON(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) @@ -345,7 +345,7 @@ func TestGetFileContentAtRef_HappyPath(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123") @@ -369,7 +369,7 @@ func TestGetFileContentAtRef_EmptyRef(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "") @@ -388,7 +388,7 @@ func TestGetFileContentAtRef_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main") @@ -404,7 +404,7 @@ func TestGetFileContentAtRef_401(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") @@ -420,7 +420,7 @@ func TestGetFileContentAtRef_MalformedJSON(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") @@ -445,7 +445,7 @@ func TestGetFileContentAtRef_429Retry(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) @@ -496,7 +496,7 @@ func TestGetCommitStatuses_HappyPath(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") @@ -567,7 +567,7 @@ func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1") @@ -591,7 +591,7 @@ func TestGetCommitStatuses_404(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha") @@ -607,7 +607,7 @@ func TestGetCommitStatuses_401(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") @@ -623,7 +623,7 @@ func TestGetCommitStatuses_MalformedJSON(t *testing.T) { })) defer srv.Close() - c := NewClient("token", srv.URL) + c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") -- 2.47.3 From af72c64b7f3e7760edb37a73f29f520f40c74510 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:48:39 -0700 Subject: [PATCH 07/16] fix(github): correct ListContents error wrapping and move HTTPS guard before retry loop --- github/client.go | 9 +++++---- github/files.go | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/github/client.go b/github/client.go index 8dd157b..c148f96 100644 --- a/github/client.go +++ b/github/client.go @@ -159,6 +159,11 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin const maxErrorBodyBytes = 64 * 1024 + // Reject non-HTTPS URLs early since the URL is immutable across retries. + if c.token != "" && !c.allowInsecureHTTP && !strings.HasPrefix(url, "https://") { + return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", url) + } + var lastErr error for attempt := 0; attempt < maxAttempts; attempt++ { if attempt > 0 { @@ -183,10 +188,6 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin return nil, fmt.Errorf("create request: %w", err) } if c.token != "" { - // Refuse to send credentials over plaintext unless explicitly allowed. - if !c.allowInsecureHTTP && req.URL.Scheme != "https" { - return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", req.URL.Host) - } req.Header.Set("Authorization", "Bearer "+c.token) } req.Header.Set("User-Agent", userAgent) diff --git a/github/files.go b/github/files.go index 442cc63..f09d3e5 100644 --- a/github/files.go +++ b/github/files.go @@ -42,7 +42,7 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] if err := json.Unmarshal(body, &entries); err != nil { var single entry if err2 := json.Unmarshal(body, &single); err2 != nil { - return nil, fmt.Errorf("parse contents JSON: %w", err) + return nil, fmt.Errorf("parse contents JSON: %w", err2) } entries = []entry{single} } -- 2.47.3 From fce5f2d1840652de0eb3d8bc9068f9a313b0e2a1 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 16:55:32 -0700 Subject: [PATCH 08/16] fix(github): address review findings on client.go - Use net/url.Parse for HTTPS scheme check (case-insensitive) - Guard SetHTTPClient against nil (restores default 30s client) - Rename 'url' param to 'reqURL' in doRequest/doGet for clarity - Return error when response exceeds maxResponseBytes instead of silently truncating Finding #1 (Bearer auth scheme) intentionally kept: GitHub REST API officially supports and recommends Bearer for all token types. See: https://docs.github.com/en/rest/authentication/authenticating-to-the-rest-api --- github/client.go | 26 ++++++++++++++++++++------ github/client_test.go | 25 +++++++++++++++++-------- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/github/client.go b/github/client.go index c148f96..f599305 100644 --- a/github/client.go +++ b/github/client.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strconv" "strings" "time" @@ -132,7 +133,11 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { // SetHTTPClient sets the underlying HTTP client used for requests. // This is intended for testing to inject mock transports. +// Passing nil will restore the default client with a 30s timeout. func (c *Client) SetHTTPClient(hc *http.Client) { + if hc == nil { + hc = &http.Client{Timeout: 30 * time.Second} + } c.httpClient = hc } @@ -145,7 +150,7 @@ func (c *Client) SetRetryBackoff(d []time.Duration) { // doRequest performs an HTTP request with retry on 429 rate limit responses. // It respects the Retry-After header when present (capped at maxRetryAfter). // Transport errors (network failures, context cancellation) are not retried. -func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { +func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) { const maxAttempts = 3 const maxRetryAfter = 120 * time.Second @@ -160,8 +165,14 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin const maxErrorBodyBytes = 64 * 1024 // Reject non-HTTPS URLs early since the URL is immutable across retries. - if c.token != "" && !c.allowInsecureHTTP && !strings.HasPrefix(url, "https://") { - return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", url) + if c.token != "" && !c.allowInsecureHTTP { + parsed, err := url.Parse(reqURL) + if err != nil { + return nil, fmt.Errorf("parse request URL: %w", err) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL) + } } var lastErr error @@ -183,7 +194,7 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin } } - req, err := http.NewRequestWithContext(ctx, method, url, nil) + req, err := http.NewRequestWithContext(ctx, method, reqURL, nil) if err != nil { return nil, fmt.Errorf("create request: %w", err) } @@ -208,6 +219,9 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin if err != nil { return nil, fmt.Errorf("read response body: %w", err) } + if int64(len(body)) >= maxResponseBytes { + return nil, fmt.Errorf("response body exceeded %d bytes (truncated)", maxResponseBytes) + } return body, nil } @@ -241,6 +255,6 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin } // doGet is a convenience wrapper for GET requests with the default Accept header. -func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) { - return c.doRequest(ctx, http.MethodGet, url, "") +func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, reqURL, "") } diff --git a/github/client_test.go b/github/client_test.go index ea03ea2..73cf1df 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -282,8 +282,7 @@ func TestDoRequest_SetsUserAgentHeader(t *testing.T) { } func TestDoRequest_LimitsResponseBody(t *testing.T) { - // Verify that response body reading is actually bounded by maxResponseBytes. - // Use a small custom limit to avoid allocating 10 MiB in tests. + // Verify that oversized responses return an error rather than silently truncating. bigBody := strings.Repeat("x", maxResponseBytes+1024) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) @@ -293,13 +292,12 @@ func TestDoRequest_LimitsResponseBody(t *testing.T) { c := NewClient("token", srv.URL, AllowInsecureHTTP()) c.SetHTTPClient(srv.Client()) - body, err := c.doGet(context.Background(), srv.URL+"/test") - if err != nil { - t.Fatalf("unexpected error: %v", err) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for oversized response body") } - // LimitReader should cap the body at maxResponseBytes - if len(body) > maxResponseBytes { - t.Errorf("expected body <= %d bytes, got %d", maxResponseBytes, len(body)) + if !strings.Contains(err.Error(), "exceeded") { + t.Errorf("expected truncation error, got: %v", err) } } @@ -384,3 +382,14 @@ func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) { t.Errorf("unexpected body: %s", body) } } + +func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { + c := NewClient("token", "https://api.github.com") + c.SetHTTPClient(nil) + if c.httpClient == nil { + t.Fatal("expected non-nil httpClient after SetHTTPClient(nil)") + } + if c.httpClient.Timeout != 30*time.Second { + t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout) + } +} -- 2.47.3 From 1fcc0b738a1fa8e036951e6a6b055a5a64aeabbb Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 17:13:07 -0700 Subject: [PATCH 09/16] fix(github): address MINOR/NIT findings from review #2866 - SetHTTPClient(nil): preserve CheckRedirect auth-stripping policy instead of restoring a plain http.Client that loses cross-host protection. - Authorization header: add comment documenting why Bearer scheme is correct (OAuth2 standard, works for both classic PATs and fine-grained tokens). - Retry-After parsing: support HTTP-date format (RFC 7231) in addition to integer seconds. GitHub only sends integers today, but the implementation is now spec-compliant. - escapePath dot-segment removal: document the behavior in public API doc comments for ListContents and GetFileContentAtRef so callers are aware without reading the internal helper. --- github/client.go | 34 ++++++++++++++++-- github/client_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++ github/files.go | 4 +++ github/pr.go | 4 +++ 4 files changed, 120 insertions(+), 3 deletions(-) diff --git a/github/client.go b/github/client.go index f599305..69baa36 100644 --- a/github/client.go +++ b/github/client.go @@ -133,10 +133,23 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { // SetHTTPClient sets the underlying HTTP client used for requests. // This is intended for testing to inject mock transports. -// Passing nil will restore the default client with a 30s timeout. +// Passing nil restores the default client (30s timeout + auth-stripping +// CheckRedirect policy matching NewClient). func (c *Client) SetHTTPClient(hc *http.Client) { if hc == nil { - hc = &http.Client{Timeout: 30 * time.Second} + hc = &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + prev := via[len(via)-1] + if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { + req.Header.Del("Authorization") + } + return nil + }, + } } c.httpClient = hc } @@ -199,6 +212,9 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, fmt.Errorf("create request: %w", err) } if c.token != "" { + // Bearer is the OAuth2 standard and is accepted by GitHub for both + // classic PATs and fine-grained tokens. The alternative "token" scheme + // is GitHub-specific and offers no additional compatibility. req.Header.Set("Authorization", "Bearer "+c.token) } req.Header.Set("User-Agent", userAgent) @@ -232,7 +248,8 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st // Retry on 429 rate limit if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { - // Check for Retry-After header and override backoff if present + // Check for Retry-After header and override backoff if present. + // Supports both integer seconds (common) and HTTP-date format (RFC 7231). if ra := resp.Header.Get("Retry-After"); ra != "" { if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { delay := time.Duration(seconds) * time.Second @@ -242,6 +259,17 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st if attempt < len(backoff) { backoff[attempt] = delay } + } else if t, err := http.ParseTime(ra); err == nil { + delay := time.Until(t) + if delay < 0 { + delay = 0 + } + if delay > maxRetryAfter { + delay = maxRetryAfter + } + if attempt < len(backoff) { + backoff[attempt] = delay + } } } continue diff --git a/github/client_test.go b/github/client_test.go index 73cf1df..d94bf8e 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -263,6 +263,84 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { } } +func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use HTTP-date format (RFC 7231) — a time 2 seconds in the future. + future := time.Now().Add(2 * time.Second).UTC() + w.Header().Set("Retry-After", future.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // HTTP-date was ~2s in the future; by the time client processes it, + // time.Until gives ~1-2s. Verify it's meaningfully delayed (not instant). + if elapsed < 500*time.Millisecond { + t.Errorf("expected meaningful delay from HTTP-date Retry-After, got %v", elapsed) + } +} + +func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use a time in the past — should result in zero/immediate retry. + past := time.Now().Add(-10 * time.Second).UTC() + w.Header().Set("Retry-After", past.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second}) + + start := time.Now() + _, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Past date should override the 5s backoff to ~0 + if elapsed > 500*time.Millisecond { + t.Errorf("expected near-instant retry for past HTTP-date, got %v", elapsed) + } +} + func TestDoRequest_SetsUserAgentHeader(t *testing.T) { var gotUA string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -392,4 +470,7 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { if c.httpClient.Timeout != 30*time.Second { t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout) } + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)") + } } diff --git a/github/files.go b/github/files.go index f09d3e5..f9a1cf6 100644 --- a/github/files.go +++ b/github/files.go @@ -22,6 +22,10 @@ func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref stri // If the path points to a single file (not a directory), the API returns // a JSON object instead of an array; this is handled by returning a // single-element slice. +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) diff --git a/github/pr.go b/github/pr.go index 1bb428a..3e984c2 100644 --- a/github/pr.go +++ b/github/pr.go @@ -123,6 +123,10 @@ func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, nu // GetFileContentAtRef fetches a file at a specific ref from a repo. // If ref is empty, the query parameter is omitted (uses default branch). +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) -- 2.47.3 From 491df7cb1f67f1fc64d6b99bb453b4432d29e9d3 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 17:30:24 -0700 Subject: [PATCH 10/16] fix(github): address review findings from rounds 2867/2870 - Extract duplicated CheckRedirect lambda to defaultCheckRedirect function (sonnet #1: eliminate duplication between NewClient and SetHTTPClient) - Remove unnecessary int64 cast in response size check (sonnet #3) - Validate fallback unmarshal in ListContents to reject zero-value entries (sonnet #5: prevent accepting unexpected JSON formats silently) - Rename strPtr to stringPtr for consistency (sonnet #6) - Add doc comment about APIError.Error body exposure (security #3) Deferred to separate issues: - #95: Reject cross-host redirects entirely (security #1) - #96: Add safeguards for AllowInsecureHTTP (security #2) --- github/client.go | 48 +++++++++++++++++++++++------------------------ github/files.go | 3 +++ github/pr_test.go | 16 ++++++++-------- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/github/client.go b/github/client.go index 69baa36..23a945e 100644 --- a/github/client.go +++ b/github/client.go @@ -26,6 +26,10 @@ const ( // APIError represents an HTTP error response from the GitHub API. // It carries the status code so callers can distinguish between // different failure modes (e.g. 404 vs 500). +// +// Note: Error() includes up to 200 bytes of the response body for debugging. +// Callers should avoid logging raw error messages in production if the upstream +// server may return sensitive details in error responses. type APIError struct { StatusCode int Body string @@ -97,6 +101,21 @@ type Client struct { retryBackoff []time.Duration } +// defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil). +// It strips the Authorization header on cross-host redirects or protocol downgrades +// (HTTPS→HTTP) to prevent credential leakage, while still following the redirect. +func defaultCheckRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + // Strip Authorization on cross-host redirect or protocol downgrade (https→http). + prev := via[len(via)-1] + if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { + req.Header.Del("Authorization") + } + return nil +} + // NewClient creates a new GitHub API client. // If baseURL is empty, it defaults to https://api.github.com. // For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3). @@ -115,18 +134,8 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { allowInsecureHTTP: cfg.allowInsecureHTTP, token: token, httpClient: &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return fmt.Errorf("stopped after 10 redirects") - } - // Strip Authorization on cross-host redirect or protocol downgrade (https→http). - prev := via[len(via)-1] - if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { - req.Header.Del("Authorization") - } - return nil - }, + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, }, } } @@ -138,17 +147,8 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { func (c *Client) SetHTTPClient(hc *http.Client) { if hc == nil { hc = &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return fmt.Errorf("stopped after 10 redirects") - } - prev := via[len(via)-1] - if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { - req.Header.Del("Authorization") - } - return nil - }, + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, } } c.httpClient = hc @@ -235,7 +235,7 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st if err != nil { return nil, fmt.Errorf("read response body: %w", err) } - if int64(len(body)) >= maxResponseBytes { + if len(body) >= maxResponseBytes { return nil, fmt.Errorf("response body exceeded %d bytes (truncated)", maxResponseBytes) } return body, nil diff --git a/github/files.go b/github/files.go index f9a1cf6..aeebe20 100644 --- a/github/files.go +++ b/github/files.go @@ -48,6 +48,9 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] if err2 := json.Unmarshal(body, &single); err2 != nil { return nil, fmt.Errorf("parse contents JSON: %w", err2) } + if single.Name == "" && single.Path == "" && single.Type == "" { + return nil, fmt.Errorf("parse contents JSON: unexpected response format") + } entries = []entry{single} } diff --git a/github/pr_test.go b/github/pr_test.go index 405cc6f..0e05a50 100644 --- a/github/pr_test.go +++ b/github/pr_test.go @@ -528,13 +528,13 @@ func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { status string want string }{ - {strPtr("success"), "completed", "success"}, - {strPtr("failure"), "completed", "failure"}, - {strPtr("action_required"), "completed", "failure"}, - {strPtr("timed_out"), "completed", "failure"}, - {strPtr("cancelled"), "completed", "success"}, - {strPtr("skipped"), "completed", "success"}, - {strPtr("neutral"), "completed", "success"}, + {stringPtr("success"), "completed", "success"}, + {stringPtr("failure"), "completed", "failure"}, + {stringPtr("action_required"), "completed", "failure"}, + {stringPtr("timed_out"), "completed", "failure"}, + {stringPtr("cancelled"), "completed", "success"}, + {stringPtr("skipped"), "completed", "success"}, + {stringPtr("neutral"), "completed", "success"}, {nil, "in_progress", "pending"}, {nil, "queued", "pending"}, } @@ -632,6 +632,6 @@ func TestGetCommitStatuses_MalformedJSON(t *testing.T) { } } -func strPtr(s string) *string { +func stringPtr(s string) *string { return &s } -- 2.47.3 From 5b2fa0b9afb6c31612d03d3333e5ef4e969f7185 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 18:16:43 -0700 Subject: [PATCH 11/16] refactor(github): address review findings from round 2872 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - client.go: clarify timer drain comment (finding #1) - client.go: rename t -> retryAt for time.Time clarity (finding #2) - pr.go: remove dead _ string parameter from mapCheckRunStatus (finding #3) - files.go: add inline comment explaining zero-value guard (finding #4) Findings #5 (NIT, no code change) and #6 (NIT, defer vs t.Cleanup in t.Run closures) pushed back — see PR comment. --- github/client.go | 6 +++--- github/files.go | 2 ++ github/pr.go | 9 ++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/github/client.go b/github/client.go index 23a945e..2a4569e 100644 --- a/github/client.go +++ b/github/client.go @@ -199,7 +199,7 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st timer := time.NewTimer(delay) select { case <-timer.C: - // Timer already fired; Stop() is a no-op here. + // Backoff elapsed, proceed with retry. case <-ctx.Done(): timer.Stop() return nil, ctx.Err() @@ -259,8 +259,8 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st if attempt < len(backoff) { backoff[attempt] = delay } - } else if t, err := http.ParseTime(ra); err == nil { - delay := time.Until(t) + } else if retryAt, err := http.ParseTime(ra); err == nil { + delay := time.Until(retryAt) if delay < 0 { delay = 0 } diff --git a/github/files.go b/github/files.go index aeebe20..25fd697 100644 --- a/github/files.go +++ b/github/files.go @@ -48,6 +48,8 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] if err2 := json.Unmarshal(body, &single); err2 != nil { return nil, fmt.Errorf("parse contents JSON: %w", err2) } + // Guard against empty objects ({}) or unexpected shapes that + // unmarshal successfully but carry no useful data. if single.Name == "" && single.Path == "" && single.Type == "" { return nil, fmt.Errorf("parse contents JSON: unexpected response format") } diff --git a/github/pr.go b/github/pr.go index 3e984c2..c088f2b 100644 --- a/github/pr.go +++ b/github/pr.go @@ -194,7 +194,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) for _, cr := range checkResp.CheckRuns { result = append(result, vcs.CommitStatus{ Context: cr.Name, - Status: mapCheckRunStatus(cr.Conclusion, cr.Status), + Status: mapCheckRunStatus(cr.Conclusion), Description: derefString(cr.Conclusion), TargetURL: cr.HTMLURL, }) @@ -208,10 +208,9 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) } // mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string. -// The second parameter (check run status field, e.g. "completed", "in_progress") is -// unused because conclusion alone determines the mapped state: nil conclusion means -// the run is still in progress (pending), regardless of the status field value. -func mapCheckRunStatus(conclusion *string, _ string) string { +// Conclusion alone determines the mapped state: nil conclusion means the run is +// still in progress (pending), regardless of the status field value. +func mapCheckRunStatus(conclusion *string) string { if conclusion == nil { // Still running or queued return "pending" -- 2.47.3 From 80af5037b220d3f6221197e4b4e50b59c4252bf7 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 18:41:44 -0700 Subject: [PATCH 12/16] fix(github): address review findings from round 2880/2883 Sonnet MINOR #1: Stop timer after <-timer.C fires for idiomatic cleanup. Sonnet MINOR #2: Document that empty array from contents API is valid (empty dir). Sonnet MINOR #3: Document that GetPullRequestFiles returns nil for no files. Sonnet NIT #4: Strengthen SetHTTPClient/SetRetryBackoff docs to clarify test-only intent. Sonnet NIT #5: Document GetCommitStatuses fail-fast behavior. Sonnet NIT #6: Document double-slash collapsing in escapePath. Security MINOR #1: Document redirect policy responsibility when providing custom client. Security MINOR #2: Reduce maxErrorBodyBytes from 64KB to 4KB to limit sensitive data exposure. --- github/client.go | 28 ++++++++++++++++++++-------- github/files.go | 13 ++++++++----- github/pr.go | 4 ++++ 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/github/client.go b/github/client.go index 2a4569e..75f4596 100644 --- a/github/client.go +++ b/github/client.go @@ -27,9 +27,10 @@ const ( // It carries the status code so callers can distinguish between // different failure modes (e.g. 404 vs 500). // -// Note: Error() includes up to 200 bytes of the response body for debugging. -// Callers should avoid logging raw error messages in production if the upstream -// server may return sensitive details in error responses. +// The Body field stores up to 4 KiB of the raw response for programmatic +// inspection. Error() truncates to 200 bytes for safe logging, but callers +// should avoid logging or propagating Body directly in production since it may +// contain sensitive details from the upstream server. type APIError struct { StatusCode int Body string @@ -87,8 +88,9 @@ func AllowInsecureHTTP() ClientOption { } // Client interacts with the GitHub API. -// A Client is safe for concurrent use by multiple goroutines; -// however, SetHTTPClient and SetRetryBackoff must not be called concurrently with requests. +// A Client is safe for concurrent use by multiple goroutines. +// SetHTTPClient and SetRetryBackoff are intended for test setup only and must +// be called before any goroutines issue requests; they have no synchronization. type Client struct { baseURL string token string @@ -141,9 +143,15 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client { } // SetHTTPClient sets the underlying HTTP client used for requests. -// This is intended for testing to inject mock transports. +// This is intended for test setup only to inject mock transports; it must be +// called before any goroutines issue requests. +// // Passing nil restores the default client (30s timeout + auth-stripping // CheckRedirect policy matching NewClient). +// +// Callers providing a non-nil client are responsible for configuring a safe +// CheckRedirect policy. Without one, the default net/http behavior will follow +// redirects and may forward the Authorization header to untrusted hosts. func (c *Client) SetHTTPClient(hc *http.Client) { if hc == nil { hc = &http.Client{ @@ -155,6 +163,7 @@ func (c *Client) SetHTTPClient(hc *http.Client) { } // SetRetryBackoff configures the retry backoff durations for testing. +// It must be called before any goroutines issue requests. // In production the default {1s, 2s} applies. func (c *Client) SetRetryBackoff(d []time.Duration) { c.retryBackoff = d @@ -175,7 +184,10 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st backoff = []time.Duration{1 * time.Second, 2 * time.Second} } - const maxErrorBodyBytes = 64 * 1024 + // maxErrorBodyBytes limits how much of an error response body is stored. + // Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers + // log APIError.Body directly. Error() further truncates to 200 bytes. + const maxErrorBodyBytes = 4 * 1024 // Reject non-HTTPS URLs early since the URL is immutable across retries. if c.token != "" && !c.allowInsecureHTTP { @@ -199,7 +211,7 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st timer := time.NewTimer(delay) select { case <-timer.C: - // Backoff elapsed, proceed with retry. + timer.Stop() // no-op after fire, releases runtime resources promptly case <-ctx.Done(): timer.Stop() return nil, ctx.Err() diff --git a/github/files.go b/github/files.go index 25fd697..0b12c4e 100644 --- a/github/files.go +++ b/github/files.go @@ -42,6 +42,8 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] // The GitHub contents API returns an array for directories and an object // for single files. Try array first (common case), then fall back to object. + // An empty array ([]) is valid — it represents an empty directory — and + // results in a zero-length slice returned without error. var entries []entry if err := json.Unmarshal(body, &entries); err != nil { var single entry @@ -69,11 +71,12 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] // escapePath escapes each segment of a relative file path for use in URLs. // Slashes are preserved as path separators; other special characters are escaped. -// Dot-segments ("." and "..") are silently removed to prevent path traversal. -// This is intentional: callers may receive a different path than requested without -// error. The function is package-private, and all callers (GetFileContentAtRef, -// ListContents) already handle missing-file errors from the API if the cleaned -// path doesn't match what the caller intended. +// Dot-segments ("." and "..") and empty segments (from consecutive slashes like +// "a//b") are silently removed to prevent path traversal and produce canonical +// paths. This is intentional: callers may receive a different path than requested +// without error. The function is package-private, and all callers +// (GetFileContentAtRef, ListContents) already handle missing-file errors from the +// API if the cleaned path doesn't match what the caller intended. func escapePath(p string) string { parts := strings.Split(p, "/") var clean []string diff --git a/github/pr.go b/github/pr.go index c088f2b..e21eb0e 100644 --- a/github/pr.go +++ b/github/pr.go @@ -89,6 +89,8 @@ const maxPages = 100 // GetPullRequestFiles fetches the list of files changed in a PR. // Paginates through all pages (100 per page) to collect all files. +// Returns nil (not an empty slice) when the PR has no changed files. +// Callers can safely range over or check len() on a nil slice. func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { var allFiles []vcs.ChangedFile @@ -156,6 +158,8 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref // GetCommitStatuses fetches both commit statuses and check runs for a SHA, // merging them into a unified []vcs.CommitStatus slice. +// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the +// function returns immediately without attempting the check-runs endpoint. func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { var result []vcs.CommitStatus -- 2.47.3 From 1194bc758ce8608c9f4fd79442f1ce6c27aedc84 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 19:29:06 -0700 Subject: [PATCH 13/16] fix(github): address review findings from rounds 2884/2885/2887 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix response body limit check: read maxResponseBytes+1 and use > to distinguish exactly-at-limit from truncated (sonnet finding #1) - Reject HTTPS→HTTP redirects outright instead of stripping auth and following; prevents plaintext metadata leakage (sonnet #2, security #1) - Sanitize newlines in APIError.Error to prevent log injection from upstream response bodies (security #2) - Add nil-return documentation to GetCommitStatuses (sonnet #3) - Gate TestDoRequest_429RetryAfterHTTPDate behind testing.Short (sonnet #6) - Add tests for redirect policy, exact-at-limit body, and error sanitization --- github/client.go | 21 ++++++++---- github/client_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++ github/pr.go | 1 + 3 files changed, 96 insertions(+), 6 deletions(-) diff --git a/github/client.go b/github/client.go index 75f4596..eedfdc7 100644 --- a/github/client.go +++ b/github/client.go @@ -41,6 +41,9 @@ func (e *APIError) Error() string { if len(body) > 200 { body = body[:200] + "...(truncated)" } + // Sanitize newlines to prevent log injection from upstream response bodies. + body = strings.ReplaceAll(body, "\n", " ") + body = strings.ReplaceAll(body, "\r", " ") return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) } @@ -104,15 +107,21 @@ type Client struct { } // defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil). -// It strips the Authorization header on cross-host redirects or protocol downgrades -// (HTTPS→HTTP) to prevent credential leakage, while still following the redirect. +// It rejects HTTPS→HTTP protocol downgrades (to prevent plaintext leakage) and strips +// the Authorization header on cross-host redirects to prevent credential leakage to +// third-party hosts (e.g. CDN redirects from GitHub). func defaultCheckRedirect(req *http.Request, via []*http.Request) error { if len(via) >= 10 { return fmt.Errorf("stopped after 10 redirects") } - // Strip Authorization on cross-host redirect or protocol downgrade (https→http). prev := via[len(via)-1] - if req.URL.Host != prev.URL.Host || (prev.URL.Scheme == "https" && req.URL.Scheme == "http") { + // Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext. + if prev.URL.Scheme == "https" && req.URL.Scheme == "http" { + return fmt.Errorf("refusing redirect from HTTPS to HTTP (%s → %s)", prev.URL.Host, req.URL.Host) + } + // Strip Authorization on cross-host redirect to avoid leaking credentials + // to third-party hosts (GitHub legitimately redirects to CDN hosts). + if req.URL.Host != prev.URL.Host { req.Header.Del("Authorization") } return nil @@ -242,12 +251,12 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st } if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) resp.Body.Close() if err != nil { return nil, fmt.Errorf("read response body: %w", err) } - if len(body) >= maxResponseBytes { + if len(body) > maxResponseBytes { return nil, fmt.Errorf("response body exceeded %d bytes (truncated)", maxResponseBytes) } return body, nil diff --git a/github/client_test.go b/github/client_test.go index d94bf8e..a8ccc06 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -185,6 +186,17 @@ func TestIsUnauthorized(t *testing.T) { } } +func TestAPIError_SanitizesNewlines(t *testing.T) { + err := &APIError{StatusCode: 500, Body: "line1\ninjected\rmore"} + msg := err.Error() + if strings.Contains(msg, "\n") || strings.Contains(msg, "\r") { + t.Errorf("expected newlines to be sanitized, got: %q", msg) + } + if !strings.Contains(msg, "line1 injected more") { + t.Errorf("expected sanitized body, got: %q", msg) + } +} + func TestDoRequest_429RetryAfterHeader(t *testing.T) { if testing.Short() { t.Skip("skipping slow retry test in short mode") @@ -264,6 +276,9 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { } func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow Retry-After HTTP-date test in short mode") + } attempts := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ @@ -379,6 +394,26 @@ func TestDoRequest_LimitsResponseBody(t *testing.T) { } } +func TestDoRequest_AcceptsExactlyAtLimit(t *testing.T) { + // A response body exactly equal to maxResponseBytes should succeed (not error). + exactBody := strings.Repeat("x", maxResponseBytes) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(exactBody)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error for exactly-at-limit body: %v", err) + } + if len(body) != maxResponseBytes { + t.Errorf("expected body length %d, got %d", maxResponseBytes, len(body)) + } +} + func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) { var gotAuth string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -405,6 +440,51 @@ func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) { } } +func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "api.github.com", Path: "/foo"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err == nil { + t.Fatal("expected error on HTTPS→HTTP redirect") + } + if !strings.Contains(err.Error(), "refusing redirect from HTTPS to HTTP") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDefaultCheckRedirect_StripsAuthOnCrossHost(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "https", Host: "objects.githubusercontent.com", Path: "/bar"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth := req.Header.Get("Authorization"); auth != "" { + t.Errorf("expected Authorization header to be stripped, got %q", auth) + } +} + +func TestDefaultCheckRedirect_PreservesAuthOnSameHost(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/bar"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth := req.Header.Get("Authorization"); auth != "Bearer token" { + t.Errorf("expected Authorization to be preserved, got %q", auth) + } +} + func TestDoRequest_RejectsHTTPWithToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) diff --git a/github/pr.go b/github/pr.go index e21eb0e..89a3d99 100644 --- a/github/pr.go +++ b/github/pr.go @@ -158,6 +158,7 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref // GetCommitStatuses fetches both commit statuses and check runs for a SHA, // merging them into a unified []vcs.CommitStatus slice. +// Returns nil (not an empty slice) when there are no statuses or check runs. // If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the // function returns immediately without attempting the check-runs endpoint. func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { -- 2.47.3 From 6e8e7448169e639b46cf06f7694011f2ffafb485 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 19:40:30 -0700 Subject: [PATCH 14/16] fix(github): address self-review findings from 1194bc75 - Handle io.ReadAll error on error body read (client.go:265) - Remove unused State field from commitStatusResponse (pr.go) - Guard via slice access in defaultCheckRedirect (client.go:117) - Move GetFileContentAtRef from pr.go to files.go (logical home) --- github/client.go | 10 +++++++++- github/files.go | 33 +++++++++++++++++++++++++++++++++ github/pr.go | 34 ---------------------------------- 3 files changed, 42 insertions(+), 35 deletions(-) diff --git a/github/client.go b/github/client.go index eedfdc7..64976fd 100644 --- a/github/client.go +++ b/github/client.go @@ -114,6 +114,11 @@ func defaultCheckRedirect(req *http.Request, via []*http.Request) error { if len(via) >= 10 { return fmt.Errorf("stopped after 10 redirects") } + // Guard: net/http guarantees len(via) >= 1 but this is undocumented; + // defend against zero-length to avoid panic on index out of range. + if len(via) == 0 { + return nil + } prev := via[len(via)-1] // Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext. if prev.URL.Scheme == "https" && req.URL.Scheme == "http" { @@ -262,7 +267,10 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return body, nil } - errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + if readErr != nil && len(errBody) == 0 { + errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) + } resp.Body.Close() lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} diff --git a/github/files.go b/github/files.go index 0b12c4e..2531b8e 100644 --- a/github/files.go +++ b/github/files.go @@ -17,6 +17,39 @@ func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref stri return c.GetFileContentAtRef(ctx, owner, repo, path, ref) } +// GetFileContentAtRef fetches a file at a specific ref from a repo. +// If ref is empty, the query parameter is omitted (uses default branch). +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + if ref != "" { + reqURL += "?ref=" + url.QueryEscape(ref) + } + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch file %s: %w", path, err) + } + var resp struct { + Content string `json:"content"` + Encoding string `json:"encoding"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("parse file content JSON: %w", err) + } + if resp.Encoding != "base64" { + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + } + decoded, err := decodeBase64Content(resp.Content) + if err != nil { + return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + } + return decoded, nil +} + // ListContents lists files and directories at a given path in a repo. // Returns the directory listing from the GitHub contents API. // If the path points to a single file (not a directory), the API returns diff --git a/github/pr.go b/github/pr.go index 89a3d99..2db9304 100644 --- a/github/pr.go +++ b/github/pr.go @@ -33,7 +33,6 @@ type changedFileResponse struct { // commitStatusResponse is the GitHub combined status API response. type commitStatusResponse struct { - State string `json:"state"` Statuses []struct { Context string `json:"context"` State string `json:"state"` @@ -123,39 +122,6 @@ func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, nu return allFiles, nil } -// GetFileContentAtRef fetches a file at a specific ref from a repo. -// If ref is empty, the query parameter is omitted (uses default branch). -// -// Note: dot-segments ("." and "..") in the path are silently removed to -// prevent path traversal. This means a path like "foo/../bar" resolves -// to "foo/bar" rather than "bar". -func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { - reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", - c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) - if ref != "" { - reqURL += "?ref=" + url.QueryEscape(ref) - } - body, err := c.doGet(ctx, reqURL) - if err != nil { - return "", fmt.Errorf("fetch file %s: %w", path, err) - } - var resp struct { - Content string `json:"content"` - Encoding string `json:"encoding"` - } - if err := json.Unmarshal(body, &resp); err != nil { - return "", fmt.Errorf("parse file content JSON: %w", err) - } - if resp.Encoding != "base64" { - return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) - } - decoded, err := decodeBase64Content(resp.Content) - if err != nil { - return "", fmt.Errorf("decode base64 content for %s: %w", path, err) - } - return decoded, nil -} - // GetCommitStatuses fetches both commit statuses and check runs for a SHA, // merging them into a unified []vcs.CommitStatus slice. // Returns nil (not an empty slice) when there are no statuses or check runs. -- 2.47.3 From 30798ff02383edc4fc82324e0a58fa0191e15cbb Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 20:28:52 -0700 Subject: [PATCH 15/16] fix: address sonnet review MINOR findings (#2916) - client.go: fix misleading timer.Stop() comment (finding #1) - pr.go: document all-or-nothing semantics for GetCommitStatuses when check-runs endpoint fails after statuses succeed (finding #2) - files.go: include both array and object unmarshal errors in ListContents fallback error message (finding #3) - pr.go: expand mapCheckRunStatus comment to explain non-blocking policy decision (finding #4) --- github/client.go | 2 +- github/files.go | 2 +- github/pr.go | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/github/client.go b/github/client.go index 64976fd..c3ea252 100644 --- a/github/client.go +++ b/github/client.go @@ -225,7 +225,7 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st timer := time.NewTimer(delay) select { case <-timer.C: - timer.Stop() // no-op after fire, releases runtime resources promptly + timer.Stop() // no-op after fire; kept for symmetry with the ctx.Done case case <-ctx.Done(): timer.Stop() return nil, ctx.Err() diff --git a/github/files.go b/github/files.go index 2531b8e..9f04941 100644 --- a/github/files.go +++ b/github/files.go @@ -81,7 +81,7 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] if err := json.Unmarshal(body, &entries); err != nil { var single entry if err2 := json.Unmarshal(body, &single); err2 != nil { - return nil, fmt.Errorf("parse contents JSON: %w", err2) + return nil, fmt.Errorf("parse contents JSON: as array: %v; as object: %w", err, err2) } // Guard against empty objects ({}) or unexpected shapes that // unmarshal successfully but carry no useful data. diff --git a/github/pr.go b/github/pr.go index 2db9304..e9bea5a 100644 --- a/github/pr.go +++ b/github/pr.go @@ -127,6 +127,9 @@ func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, nu // Returns nil (not an empty slice) when there are no statuses or check runs. // If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the // function returns immediately without attempting the check-runs endpoint. +// If the check-runs endpoint fails after statuses were fetched successfully, +// the function returns an error (not a partial result) so callers always get +// either a complete view or a clear error signal. func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { var result []vcs.CommitStatus @@ -192,7 +195,7 @@ func mapCheckRunStatus(conclusion *string) string { case "failure", "action_required", "timed_out": return "failure" case "cancelled", "skipped", "neutral": - return "success" // non-blocking + return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics case "stale", "waiting": return "pending" default: -- 2.47.3 From b380e7fcaebbfaa43fd936598f48932c7a37240e Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 20:47:59 -0700 Subject: [PATCH 16/16] refactor(github): extract handleResponse for safe defer body close Address review findings #1 and #2: the response body was closed explicitly rather than via defer, which could leak if future code paths were added. Extract handleResponse helper method that uses defer resp.Body.Close() to guarantee cleanup. This avoids the loop-defer antipattern (defer inside a for loop accumulates defers until function exit) by isolating the body handling into its own function scope. --- github/client.go | 46 ++++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/github/client.go b/github/client.go index c3ea252..7f5c7f9 100644 --- a/github/client.go +++ b/github/client.go @@ -255,25 +255,11 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, fmt.Errorf("do request: %w", err) } - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("read response body: %w", err) - } - if len(body) > maxResponseBytes { - return nil, fmt.Errorf("response body exceeded %d bytes (truncated)", maxResponseBytes) - } - return body, nil + body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) + if done { + return body, err } - - errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) - if readErr != nil && len(errBody) == 0 { - errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) - } - resp.Body.Close() - - lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} + lastErr = err // Retry on 429 rate limit if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { @@ -311,6 +297,30 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st return nil, lastErr } +// handleResponse reads and closes the response body, returning the result. +// It uses defer to ensure the body is always closed regardless of code path. +// Returns (body, done, err) where done=true means the caller should return immediately. +func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1)) + if err != nil { + return nil, true, fmt.Errorf("read response body: %w", err) + } + if len(body) > maxRespBytes { + return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes) + } + return body, true, nil + } + + errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes))) + if readErr != nil && len(errBody) == 0 { + errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) + } + return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} +} + // doGet is a convenience wrapper for GET requests with the default Accept header. func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { return c.doRequest(ctx, http.MethodGet, reqURL, "") -- 2.47.3