From d0b7f097729b876d030d1bc60caa366376ba04dd Mon Sep 17 00:00:00 2001 From: Aaron Weiker Date: Wed, 13 May 2026 04:12:13 +0000 Subject: [PATCH 1/5] feat(github): implement PRReader interface (#80) Implement PRReader conformance on the GitHub client: GetPullRequest, GetPullRequestDiff, GetPullRequestFiles (paginated, populates Patch), GetCommitStatuses (merges commit statuses + check runs). Adds compile-time PRReader conformance check. Requires PR A. Part 2 of 3 for #80. --- github/conformance_test.go | 5 + github/pr.go | 212 ++++++++++++ github/pr_test.go | 637 +++++++++++++++++++++++++++++++++++++ 3 files changed, 854 insertions(+) create mode 100644 github/conformance_test.go create mode 100644 github/pr.go create mode 100644 github/pr_test.go diff --git a/github/conformance_test.go b/github/conformance_test.go new file mode 100644 index 0000000..4dfa195 --- /dev/null +++ b/github/conformance_test.go @@ -0,0 +1,5 @@ +package github_test + +import "gitea.weiker.me/rodin/review-bot/vcs" + +var _ vcs.PRReader = (*Client)(nil) diff --git a/github/pr.go b/github/pr.go new file mode 100644 index 0000000..e9bea5a --- /dev/null +++ b/github/pr.go @@ -0,0 +1,212 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// pullRequestResponse is the GitHub API response for a pull request. +type pullRequestResponse struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + Head struct { + SHA string `json:"sha"` + Ref string `json:"ref"` + } `json:"head"` + Base struct { + Ref string `json:"ref"` + } `json:"base"` +} + +// changedFileResponse is the GitHub API response for a changed file in a PR. +type changedFileResponse struct { + Filename string `json:"filename"` + Status string `json:"status"` + Patch string `json:"patch"` +} + +// commitStatusResponse is the GitHub combined status API response. +type commitStatusResponse struct { + Statuses []struct { + Context string `json:"context"` + State string `json:"state"` + Description string `json:"description"` + TargetURL string `json:"target_url"` + } `json:"statuses"` +} + +// checkRunsResponse is the GitHub check runs API response. +type checkRunsResponse struct { + CheckRuns []struct { + Name string `json:"name"` + Conclusion *string `json:"conclusion"` + Status string `json:"status"` + HTMLURL string `json:"html_url"` + } `json:"check_runs"` +} + +// GetPullRequest fetches PR metadata. +func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR: %w", err) + } + var resp pullRequestResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parse PR JSON: %w", err) + } + return &vcs.PullRequest{ + Number: resp.Number, + Title: resp.Title, + Body: resp.Body, + Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref}, + Base: vcs.BaseRef{Ref: resp.Base.Ref}, + }, nil +} + +// GetPullRequestDiff fetches the unified diff for a PR. +// Uses Accept: application/vnd.github.diff to get raw diff text. +func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff") + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil +} + +// maxPages is the upper bound on pagination loops to prevent unbounded iteration +// in case the server returns a full page indefinitely. +const maxPages = 100 + +// GetPullRequestFiles fetches the list of files changed in a PR. +// Paginates through all pages (100 per page) to collect all files. +// Returns nil (not an empty slice) when the PR has no changed files. +// Callers can safely range over or check len() on a nil slice. +func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { + var allFiles []vcs.ChangedFile + + for page := 1; page <= maxPages; page++ { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR files page %d: %w", page, err) + } + var files []changedFileResponse + if err := json.Unmarshal(body, &files); err != nil { + return nil, fmt.Errorf("parse PR files JSON: %w", err) + } + if len(files) == 0 { + break + } + for _, f := range files { + allFiles = append(allFiles, vcs.ChangedFile{ + Filename: f.Filename, + Status: f.Status, + Patch: f.Patch, + }) + } + if len(files) < 100 { + break + } + } + + return allFiles, nil +} + +// GetCommitStatuses fetches both commit statuses and check runs for a SHA, +// merging them into a unified []vcs.CommitStatus slice. +// Returns nil (not an empty slice) when there are no statuses or check runs. +// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the +// function returns immediately without attempting the check-runs endpoint. +// If the check-runs endpoint fails after statuses were fetched successfully, +// the function returns an error (not a partial result) so callers always get +// either a complete view or a clear error signal. +func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { + var result []vcs.CommitStatus + + // Fetch commit statuses + statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha)) + statusBody, err := c.doGet(ctx, statusURL) + if err != nil { + return nil, fmt.Errorf("fetch commit statuses: %w", err) + } + var statusResp commitStatusResponse + if err := json.Unmarshal(statusBody, &statusResp); err != nil { + return nil, fmt.Errorf("parse commit statuses JSON: %w", err) + } + for _, s := range statusResp.Statuses { + result = append(result, vcs.CommitStatus{ + Context: s.Context, + Status: s.State, + Description: s.Description, + TargetURL: s.TargetURL, + }) + } + + // Fetch check runs (paginated) + for checkPage := 1; checkPage <= maxPages; checkPage++ { + checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) + checkBody, err := c.doGet(ctx, checkURL) + if err != nil { + return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err) + } + var checkResp checkRunsResponse + if err := json.Unmarshal(checkBody, &checkResp); err != nil { + return nil, fmt.Errorf("parse check runs JSON: %w", err) + } + for _, cr := range checkResp.CheckRuns { + result = append(result, vcs.CommitStatus{ + Context: cr.Name, + Status: mapCheckRunStatus(cr.Conclusion), + Description: derefString(cr.Conclusion), + TargetURL: cr.HTMLURL, + }) + } + if len(checkResp.CheckRuns) < 100 { + break + } + } + + return result, nil +} + +// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string. +// Conclusion alone determines the mapped state: nil conclusion means the run is +// still in progress (pending), regardless of the status field value. +func mapCheckRunStatus(conclusion *string) string { + if conclusion == nil { + // Still running or queued + return "pending" + } + switch *conclusion { + case "success": + return "success" + case "failure", "action_required", "timed_out": + return "failure" + case "cancelled", "skipped", "neutral": + return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics + case "stale", "waiting": + return "pending" + default: + return "pending" + } +} + +// derefString safely dereferences a string pointer, returning empty string if nil. +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/github/pr_test.go b/github/pr_test.go new file mode 100644 index 0000000..0e05a50 --- /dev/null +++ b/github/pr_test.go @@ -0,0 +1,637 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetPullRequest_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/pulls/42" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 42, + "title": "Test PR", + "body": "Description", + "head": map[string]string{"sha": "abc123", "ref": "feature-branch"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 42 { + t.Errorf("expected number 42, got %d", pr.Number) + } + if pr.Title != "Test PR" { + t.Errorf("expected title 'Test PR', got %q", pr.Title) + } + if pr.Body != "Description" { + t.Errorf("expected body 'Description', got %q", pr.Body) + } + if pr.Head.SHA != "abc123" { + t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA) + } + if pr.Head.Ref != "feature-branch" { + t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref) + } + if pr.Base.Ref != "main" { + t.Errorf("expected base ref 'main', got %q", pr.Base.Ref) + } +} + +func TestGetPullRequest_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsNotFound(err) { + t.Errorf("expected IsNotFound=true, got error: %v", err) + } +} + +func TestGetPullRequest_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } + if !IsUnauthorized(err) { + t.Errorf("expected IsUnauthorized=true, got error: %v", err) + } +} + +func TestGetPullRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 1, + "title": "PR", + "body": "", + "head": map[string]string{"sha": "abc", "ref": "br"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 1 { + t.Errorf("expected number 1, got %d", pr.Number) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetPullRequest_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{invalid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } + if !strings.Contains(err.Error(), "parse PR JSON") { + t.Errorf("expected parse error, got: %v", err) + } +} + +func TestGetPullRequestDiff_HappyPath(t *testing.T) { + expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n" + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte(expectedDiff)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff != expectedDiff { + t.Errorf("unexpected diff: %q", diff) + } + if gotAccept != "application/vnd.github.diff" { + t.Errorf("expected diff Accept header, got %q", gotAccept) + } +} + +func TestGetPullRequestDiff_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestDiff_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetPullRequestFiles_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"}, + {"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Filename != "main.go" { + t.Errorf("expected filename 'main.go', got %q", files[0].Filename) + } + if files[0].Status != "modified" { + t.Errorf("expected status 'modified', got %q", files[0].Status) + } + if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" { + t.Errorf("unexpected patch: %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_Pagination(t *testing.T) { + // Simulate > 100 files requiring pagination + page1Files := make([]map[string]string, 100) + for i := 0; i < 100; i++ { + page1Files[i] = map[string]string{ + "filename": fmt.Sprintf("file%d.go", i), + "status": "modified", + "patch": fmt.Sprintf("patch%d", i), + } + } + page2Files := []map[string]string{ + {"filename": "file100.go", "status": "added", "patch": "patch100"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + if page == "" || page == "1" { + json.NewEncoder(w).Encode(page1Files) + } else { + json.NewEncoder(w).Encode(page2Files) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 101 { + t.Errorf("expected 101 files (paginated), got %d", len(files)) + } + if files[100].Filename != "file100.go" { + t.Errorf("expected last file 'file100.go', got %q", files[100].Filename) + } + if files[100].Patch != "patch100" { + t.Errorf("expected last patch 'patch100', got %q", files[100].Patch) + } +} + +func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Binary files have no patch field in GitHub response + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "image.png", "status": "added"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + if files[0].Patch != "" { + t.Errorf("expected empty patch for binary file, got %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestFiles_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.URL.Query().Get("ref") != "abc123" { + t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "cGFja2FnZSBtYWlu", // "package main" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "package main" { + t.Errorf("expected 'package main', got %q", content) + } +} + +func TestGetFileContentAtRef_EmptyRef(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("ref") != "" { + t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "aGVsbG8=", // "hello" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "hello" { + t.Errorf("expected 'hello', got %q", content) + } +} + +func TestGetFileContentAtRef_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContentAtRef_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContentAtRef_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not valid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", // "ok" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetCommitStatuses_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/status"): + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []map[string]string{ + { + "context": "ci/build", + "state": "success", + "description": "Build passed", + "target_url": "https://ci.example.com/1", + }, + }, + }) + case strings.Contains(r.URL.Path, "/check-runs"): + conclusion := "success" + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "lint", + "conclusion": &conclusion, + "status": "completed", + "html_url": "https://github.com/check/1", + }, + }, + }) + default: + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(404) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 2 { + t.Fatalf("expected 2 statuses, got %d", len(statuses)) + } + // First should be from commit statuses + if statuses[0].Context != "ci/build" { + t.Errorf("expected context 'ci/build', got %q", statuses[0].Context) + } + if statuses[0].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[0].Status) + } + // Second should be from check runs + if statuses[1].Context != "lint" { + t.Errorf("expected context 'lint', got %q", statuses[1].Context) + } + if statuses[1].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[1].Status) + } +} + +func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { + tests := []struct { + conclusion *string + status string + want string + }{ + {stringPtr("success"), "completed", "success"}, + {stringPtr("failure"), "completed", "failure"}, + {stringPtr("action_required"), "completed", "failure"}, + {stringPtr("timed_out"), "completed", "failure"}, + {stringPtr("cancelled"), "completed", "success"}, + {stringPtr("skipped"), "completed", "success"}, + {stringPtr("neutral"), "completed", "success"}, + {nil, "in_progress", "pending"}, + {nil, "queued", "pending"}, + } + + for _, tt := range tests { + name := "nil" + if tt.conclusion != nil { + name = *tt.conclusion + } + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/status") { + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []interface{}{}, + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "check", + "conclusion": tt.conclusion, + "status": tt.status, + "html_url": "https://github.com/check/1", + }, + }, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + if statuses[0].Status != tt.want { + t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status) + } + }) + } +} + +func TestGetCommitStatuses_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetCommitStatuses_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetCommitStatuses_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func stringPtr(s string) *string { + return &s +} From 289b400bfd41d0e01f4baa749e7de78a311f7979 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 21:21:01 -0700 Subject: [PATCH 2/5] fix(github): add GetFileContentAtRef and fix conformance test - Implement GetFileContentAtRef on *Client to satisfy vcs.PRReader interface - Add escapePath and decodeBase64Content helpers - Fix conformance_test.go to properly import and qualify github.Client (was using unqualified Client in package github_test) Fixes CI failure: the PRReader interface requires GetFileContentAtRef but it was missing from this PR (only present in the file-reader PR). --- github/conformance_test.go | 9 +++-- github/files.go | 68 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 github/files.go diff --git a/github/conformance_test.go b/github/conformance_test.go index 4dfa195..666bcab 100644 --- a/github/conformance_test.go +++ b/github/conformance_test.go @@ -1,5 +1,10 @@ package github_test -import "gitea.weiker.me/rodin/review-bot/vcs" +import ( + "gitea.weiker.me/rodin/review-bot/github" + "gitea.weiker.me/rodin/review-bot/vcs" +) -var _ vcs.PRReader = (*Client)(nil) +// Compile-time interface conformance assertion. +// Verifies github.Client satisfies vcs.PRReader. +var _ vcs.PRReader = (*github.Client)(nil) diff --git a/github/files.go b/github/files.go new file mode 100644 index 0000000..f7d415d --- /dev/null +++ b/github/files.go @@ -0,0 +1,68 @@ +package github + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" +) + +// GetFileContentAtRef fetches a file at a specific ref from a repo. +// If ref is empty, the query parameter is omitted (uses default branch). +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + if ref != "" { + reqURL += "?ref=" + url.QueryEscape(ref) + } + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch file %s: %w", path, err) + } + var resp struct { + Content string `json:"content"` + Encoding string `json:"encoding"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("parse file content JSON: %w", err) + } + if resp.Encoding != "base64" { + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + } + decoded, err := decodeBase64Content(resp.Content) + if err != nil { + return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + } + return decoded, nil +} + +// escapePath encodes each segment of a slash-separated path, stripping +// dot-segments to prevent path traversal. +func escapePath(p string) string { + parts := strings.Split(p, "/") + var clean []string + for _, part := range parts { + if part == "." || part == ".." || part == "" { + continue + } + clean = append(clean, url.PathEscape(part)) + } + return strings.Join(clean, "/") +} + +// decodeBase64Content decodes base64-encoded content from the GitHub contents API. +// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. +func decodeBase64Content(encoded string) (string, error) { + cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) + decoded, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return "", err + } + return string(decoded), nil +} From eaccc9607375f721b40799201d9dc814ba7dbdc8 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 21:36:45 -0700 Subject: [PATCH 3/5] fix: address review feedback on PR #102 - Separate maxPages into maxFilesPages and maxCheckRunPages constants for clarity (sonnet MINOR #1) - Add parallel to CheckRunConclusions subtests (sonnet MINOR #2) - Add TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed test covering check-runs 500 after statuses succeed (sonnet MINOR #2) - Expand mapCheckRunStatus doc comment with full mapping rules including cancelled/skipped/neutral rationale and unknown value behavior (sonnet MINOR #3, gpt MINOR #1) - Expand GetPullRequest doc comment to mention error types returned (sonnet NIT #4) - Add inline comment on Description field clarifying it holds raw conclusion value (gpt NIT #3) --- github/pr.go | 33 +++++++++++++++++++++++++-------- github/pr_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/github/pr.go b/github/pr.go index e9bea5a..c028506 100644 --- a/github/pr.go +++ b/github/pr.go @@ -51,7 +51,10 @@ type checkRunsResponse struct { } `json:"check_runs"` } -// GetPullRequest fetches PR metadata. +// GetPullRequest fetches PR metadata from the GitHub API. +// Returns an *APIError wrapping the HTTP status on non-2xx responses (e.g. +// IsNotFound for 404, IsUnauthorized for 401). Network and context errors +// are wrapped but not typed as *APIError. func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) body, err := c.doGet(ctx, reqURL) @@ -82,9 +85,15 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, num return string(body), nil } -// maxPages is the upper bound on pagination loops to prevent unbounded iteration -// in case the server returns a full page indefinitely. -const maxPages = 100 +const ( + // maxFilesPages is the upper bound on pagination loops for PR file listing, + // preventing unbounded iteration if the server always returns a full page. + maxFilesPages = 100 + + // maxCheckRunPages is the upper bound on pagination loops for check-run listing, + // preventing unbounded iteration if the server always returns a full page. + maxCheckRunPages = 100 +) // GetPullRequestFiles fetches the list of files changed in a PR. // Paginates through all pages (100 per page) to collect all files. @@ -93,7 +102,7 @@ const maxPages = 100 func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { var allFiles []vcs.ChangedFile - for page := 1; page <= maxPages; page++ { + for page := 1; page <= maxFilesPages; page++ { reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) body, err := c.doGet(ctx, reqURL) @@ -154,7 +163,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) } // Fetch check runs (paginated) - for checkPage := 1; checkPage <= maxPages; checkPage++ { + for checkPage := 1; checkPage <= maxCheckRunPages; checkPage++ { checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) checkBody, err := c.doGet(ctx, checkURL) @@ -169,7 +178,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) result = append(result, vcs.CommitStatus{ Context: cr.Name, Status: mapCheckRunStatus(cr.Conclusion), - Description: derefString(cr.Conclusion), + Description: derefString(cr.Conclusion), // raw conclusion value (e.g. "success", "failure", "skipped") TargetURL: cr.HTMLURL, }) } @@ -181,9 +190,17 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) return result, nil } -// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string. +// mapCheckRunStatus maps a GitHub check run conclusion to a vcs.CommitStatus status string. // Conclusion alone determines the mapped state: nil conclusion means the run is // still in progress (pending), regardless of the status field value. +// +// Mapping rules: +// - nil → "pending" (run still in progress or queued) +// - "success" → "success" +// - "failure", "action_required", "timed_out" → "failure" +// - "cancelled", "skipped", "neutral" → "success" (non-blocking per GitHub check suite semantics) +// - "stale", "waiting" → "pending" +// - unknown values → "pending" (conservative: treat unrecognized conclusions as incomplete) func mapCheckRunStatus(conclusion *string) string { if conclusion == nil { // Still running or queued diff --git a/github/pr_test.go b/github/pr_test.go index 0e05a50..f79147b 100644 --- a/github/pr_test.go +++ b/github/pr_test.go @@ -545,6 +545,7 @@ func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { name = *tt.conclusion } t.Run(name, func(t *testing.T) { + t.Parallel() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.Contains(r.URL.Path, "/status") { json.NewEncoder(w).Encode(map[string]interface{}{ @@ -632,6 +633,44 @@ func TestGetCommitStatuses_MalformedJSON(t *testing.T) { } } +func TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/status"): + // Statuses succeed + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []map[string]string{ + { + "context": "ci/build", + "state": "success", + "description": "Build passed", + "target_url": "https://ci.example.com/1", + }, + }, + }) + case strings.Contains(r.URL.Path, "/check-runs"): + // Check runs fail with 500 + w.WriteHeader(500) + w.Write([]byte(`{"message":"Internal Server Error"}`)) + default: + w.WriteHeader(404) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") + if err == nil { + t.Fatal("expected error when check-runs endpoint fails after statuses succeed") + } + if !strings.Contains(err.Error(), "fetch check runs") { + t.Errorf("expected check runs error, got: %v", err) + } +} + func stringPtr(s string) *string { return &s } From 3cd5ae594e9d7a12486eb9980c3ca44225aeb507 Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 22:03:52 -0700 Subject: [PATCH 4/5] fix(github): escapePath returns error on dot-segments, fix Description semantics - escapePath now returns an error when paths contain dot-segments (".", "..") instead of silently rewriting them. This prevents subtle API misses where callers pass "foo/../bar" expecting to hit "bar" but the old code produced "foo/bar". - Uses path.Clean for canonical form after validation. - CommitStatus.Description for check runs is now empty string instead of the raw conclusion enum. The conclusion is already captured in the Status field via mapCheckRunStatus; storing it again in Description was semantically inconsistent with commit statuses where Description carries a human-readable narrative. - Removed unused derefString helper. - Added tests for escapePath valid paths, dot-segment rejection, and GetFileContentAtRef dot-segment error propagation. --- github/files.go | 51 +++++++++++++++++++--------- github/files_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++ github/pr.go | 9 +---- 3 files changed, 115 insertions(+), 24 deletions(-) create mode 100644 github/files_test.go diff --git a/github/files.go b/github/files.go index f7d415d..d0f7ecc 100644 --- a/github/files.go +++ b/github/files.go @@ -6,24 +6,28 @@ import ( "encoding/json" "fmt" "net/url" + "path" "strings" ) // GetFileContentAtRef fetches a file at a specific ref from a repo. // If ref is empty, the query parameter is omitted (uses default branch). // -// Note: dot-segments ("." and "..") in the path are silently removed to -// prevent path traversal. This means a path like "foo/../bar" resolves -// to "foo/bar" rather than "bar". -func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { +// Returns an error if the path contains dot-segments (".", "..") or +// attempts to traverse above the repository root. +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath, ref string) (string, error) { + escaped, err := escapePath(filePath) + if err != nil { + return "", fmt.Errorf("invalid file path: %w", err) + } reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", - c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped) if ref != "" { reqURL += "?ref=" + url.QueryEscape(ref) } body, err := c.doGet(ctx, reqURL) if err != nil { - return "", fmt.Errorf("fetch file %s: %w", path, err) + return "", fmt.Errorf("fetch file %s: %w", filePath, err) } var resp struct { Content string `json:"content"` @@ -33,27 +37,42 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref return "", fmt.Errorf("parse file content JSON: %w", err) } if resp.Encoding != "base64" { - return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, filePath) } decoded, err := decodeBase64Content(resp.Content) if err != nil { - return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + return "", fmt.Errorf("decode base64 content for %s: %w", filePath, err) } return decoded, nil } -// escapePath encodes each segment of a slash-separated path, stripping -// dot-segments to prevent path traversal. -func escapePath(p string) string { - parts := strings.Split(p, "/") - var clean []string +// escapePath validates and encodes a slash-separated file path for use in +// GitHub API URLs. Returns an error if the path contains dot-segments ("." +// or "..") or resolves to a path outside the repository root. +func escapePath(p string) (string, error) { + // Reject paths containing dot-segments rather than silently rewriting them. + for _, seg := range strings.Split(p, "/") { + if seg == "." || seg == ".." { + return "", fmt.Errorf("path contains dot-segment %q: %s", seg, p) + } + } + + // Use path.Clean for canonical form, then verify it doesn't escape root. + cleaned := path.Clean(p) + if cleaned == "." || strings.HasPrefix(cleaned, "..") { + return "", fmt.Errorf("path resolves outside repository root: %s", p) + } + + // Encode each segment individually. + parts := strings.Split(cleaned, "/") + var encoded []string for _, part := range parts { - if part == "." || part == ".." || part == "" { + if part == "" { continue } - clean = append(clean, url.PathEscape(part)) + encoded = append(encoded, url.PathEscape(part)) } - return strings.Join(clean, "/") + return strings.Join(encoded, "/"), nil } // decodeBase64Content decodes base64-encoded content from the GitHub contents API. diff --git a/github/files_test.go b/github/files_test.go new file mode 100644 index 0000000..62c5412 --- /dev/null +++ b/github/files_test.go @@ -0,0 +1,79 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestEscapePath_ValidPaths(t *testing.T) { + t.Parallel() + tests := []struct { + name string + path string + want string + }{ + {"simple file", "file.go", "file.go"}, + {"nested path", "path/to/file.go", "path/to/file.go"}, + {"special chars", "path/to/my file.go", "path/to/my%20file.go"}, + {"leading slash stripped", "/path/to/file.go", "path/to/file.go"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := escapePath(tt.path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Errorf("escapePath(%q) = %q, want %q", tt.path, got, tt.want) + } + }) + } +} + +func TestEscapePath_DotSegments(t *testing.T) { + t.Parallel() + tests := []struct { + name string + path string + }{ + {"single dot", "./file.go"}, + {"double dot", "../file.go"}, + {"dot in middle", "path/./file.go"}, + {"parent traversal", "path/../file.go"}, + {"only dots", ".."}, + {"nested parent traversal", "a/b/../../c"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := escapePath(tt.path) + if err == nil { + t.Fatalf("expected error for path %q, got nil", tt.path) + } + if !strings.Contains(err.Error(), "dot-segment") { + t.Errorf("expected error about dot-segment, got: %v", err) + } + }) + } +} + +func TestGetFileContentAtRef_DotSegmentError(t *testing.T) { + // Server should never be called — the error is caught before the request. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("server should not have been called") + })) + defer srv.Close() + + c := NewClient(srv.URL, "token") + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "foo/../bar.go", "main") + if err == nil { + t.Fatal("expected error for path with dot-segments") + } + if !strings.Contains(err.Error(), "invalid file path") { + t.Errorf("expected 'invalid file path' error, got: %v", err) + } +} diff --git a/github/pr.go b/github/pr.go index c028506..2aa4c79 100644 --- a/github/pr.go +++ b/github/pr.go @@ -178,7 +178,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) result = append(result, vcs.CommitStatus{ Context: cr.Name, Status: mapCheckRunStatus(cr.Conclusion), - Description: derefString(cr.Conclusion), // raw conclusion value (e.g. "success", "failure", "skipped") + Description: "", // check runs have no human-readable description; conclusion is captured in Status TargetURL: cr.HTMLURL, }) } @@ -220,10 +220,3 @@ func mapCheckRunStatus(conclusion *string) string { } } -// derefString safely dereferences a string pointer, returning empty string if nil. -func derefString(s *string) string { - if s == nil { - return "" - } - return *s -} From 55366b3431d3ec2417d1bb087398ff4e0895effe Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 22:17:32 -0700 Subject: [PATCH 5/5] fix: address review feedback on PRReader implementation - Add maxFileContentSize (10 MB) limit to decodeBase64Content to prevent resource exhaustion from oversized file content (security MINOR) - Fix reversed NewClient arg order in TestGetFileContentAtRef_DotSegmentError (GPT MINOR + Sonnet NIT) - Remove 'waiting' from mapCheckRunStatus conclusion cases since it is a status value not a conclusion, update comment (GPT NIT) - Add TestDecodeBase64Content_SizeLimit test --- github/files.go | 13 +++++++++++++ github/files_test.go | 19 ++++++++++++++++++- github/pr.go | 4 ++-- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/github/files.go b/github/files.go index d0f7ecc..6d968a3 100644 --- a/github/files.go +++ b/github/files.go @@ -75,13 +75,26 @@ func escapePath(p string) (string, error) { return strings.Join(encoded, "/"), nil } +// maxFileContentSize is the maximum decoded file size (10 MB) to prevent +// resource exhaustion when decoding base64 content from the API. +const maxFileContentSize = 10 * 1024 * 1024 + // decodeBase64Content decodes base64-encoded content from the GitHub contents API. // GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. +// Returns an error if the decoded content exceeds maxFileContentSize. func decodeBase64Content(encoded string) (string, error) { cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) + // Check estimated decoded size before allocating. + // Base64 encodes 3 bytes into 4 chars, so decoded ~ len*3/4. + if len(cleaned)*3/4 > maxFileContentSize { + return "", fmt.Errorf("file content too large: estimated %d bytes exceeds limit of %d", len(cleaned)*3/4, maxFileContentSize) + } decoded, err := base64.StdEncoding.DecodeString(cleaned) if err != nil { return "", err } + if len(decoded) > maxFileContentSize { + return "", fmt.Errorf("file content too large: %d bytes exceeds limit of %d", len(decoded), maxFileContentSize) + } return string(decoded), nil } diff --git a/github/files_test.go b/github/files_test.go index 62c5412..8385a07 100644 --- a/github/files_test.go +++ b/github/files_test.go @@ -68,7 +68,7 @@ func TestGetFileContentAtRef_DotSegmentError(t *testing.T) { })) defer srv.Close() - c := NewClient(srv.URL, "token") + c := NewClient("token", srv.URL) _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "foo/../bar.go", "main") if err == nil { t.Fatal("expected error for path with dot-segments") @@ -77,3 +77,20 @@ func TestGetFileContentAtRef_DotSegmentError(t *testing.T) { t.Errorf("expected 'invalid file path' error, got: %v", err) } } + +func TestDecodeBase64Content_SizeLimit(t *testing.T) { + t.Parallel() + // Create base64 content that would decode to > maxFileContentSize. + // maxFileContentSize is 10MB. Base64 of 11MB worth of zeros. + // We just need something big enough to trigger the estimated size check. + // 14MB of base64 chars (decodes to ~10.5MB). + huge := strings.Repeat("A", 14*1024*1024) + _, err := decodeBase64Content(huge) + if err == nil { + t.Fatal("expected error for oversized content") + } + if !strings.Contains(err.Error(), "too large") { + t.Errorf("expected 'too large' error, got: %v", err) + } +} + diff --git a/github/pr.go b/github/pr.go index 2aa4c79..6bee50b 100644 --- a/github/pr.go +++ b/github/pr.go @@ -199,7 +199,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) // - "success" → "success" // - "failure", "action_required", "timed_out" → "failure" // - "cancelled", "skipped", "neutral" → "success" (non-blocking per GitHub check suite semantics) -// - "stale", "waiting" → "pending" +// - "stale" → "pending" (check run became stale before completing) // - unknown values → "pending" (conservative: treat unrecognized conclusions as incomplete) func mapCheckRunStatus(conclusion *string) string { if conclusion == nil { @@ -213,7 +213,7 @@ func mapCheckRunStatus(conclusion *string) string { return "failure" case "cancelled", "skipped", "neutral": return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics - case "stale", "waiting": + case "stale": return "pending" default: return "pending"