diff --git a/github/conformance_test.go b/github/conformance_test.go new file mode 100644 index 0000000..4dfa195 --- /dev/null +++ b/github/conformance_test.go @@ -0,0 +1,5 @@ +package github_test + +import "gitea.weiker.me/rodin/review-bot/vcs" + +var _ vcs.PRReader = (*Client)(nil) diff --git a/github/pr.go b/github/pr.go new file mode 100644 index 0000000..e9bea5a --- /dev/null +++ b/github/pr.go @@ -0,0 +1,212 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// pullRequestResponse is the GitHub API response for a pull request. +type pullRequestResponse struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + Head struct { + SHA string `json:"sha"` + Ref string `json:"ref"` + } `json:"head"` + Base struct { + Ref string `json:"ref"` + } `json:"base"` +} + +// changedFileResponse is the GitHub API response for a changed file in a PR. +type changedFileResponse struct { + Filename string `json:"filename"` + Status string `json:"status"` + Patch string `json:"patch"` +} + +// commitStatusResponse is the GitHub combined status API response. +type commitStatusResponse struct { + Statuses []struct { + Context string `json:"context"` + State string `json:"state"` + Description string `json:"description"` + TargetURL string `json:"target_url"` + } `json:"statuses"` +} + +// checkRunsResponse is the GitHub check runs API response. +type checkRunsResponse struct { + CheckRuns []struct { + Name string `json:"name"` + Conclusion *string `json:"conclusion"` + Status string `json:"status"` + HTMLURL string `json:"html_url"` + } `json:"check_runs"` +} + +// GetPullRequest fetches PR metadata. +func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR: %w", err) + } + var resp pullRequestResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parse PR JSON: %w", err) + } + return &vcs.PullRequest{ + Number: resp.Number, + Title: resp.Title, + Body: resp.Body, + Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref}, + Base: vcs.BaseRef{Ref: resp.Base.Ref}, + }, nil +} + +// GetPullRequestDiff fetches the unified diff for a PR. +// Uses Accept: application/vnd.github.diff to get raw diff text. +func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff") + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil +} + +// maxPages is the upper bound on pagination loops to prevent unbounded iteration +// in case the server returns a full page indefinitely. +const maxPages = 100 + +// GetPullRequestFiles fetches the list of files changed in a PR. +// Paginates through all pages (100 per page) to collect all files. +// Returns nil (not an empty slice) when the PR has no changed files. +// Callers can safely range over or check len() on a nil slice. +func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { + var allFiles []vcs.ChangedFile + + for page := 1; page <= maxPages; page++ { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR files page %d: %w", page, err) + } + var files []changedFileResponse + if err := json.Unmarshal(body, &files); err != nil { + return nil, fmt.Errorf("parse PR files JSON: %w", err) + } + if len(files) == 0 { + break + } + for _, f := range files { + allFiles = append(allFiles, vcs.ChangedFile{ + Filename: f.Filename, + Status: f.Status, + Patch: f.Patch, + }) + } + if len(files) < 100 { + break + } + } + + return allFiles, nil +} + +// GetCommitStatuses fetches both commit statuses and check runs for a SHA, +// merging them into a unified []vcs.CommitStatus slice. +// Returns nil (not an empty slice) when there are no statuses or check runs. +// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the +// function returns immediately without attempting the check-runs endpoint. +// If the check-runs endpoint fails after statuses were fetched successfully, +// the function returns an error (not a partial result) so callers always get +// either a complete view or a clear error signal. +func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { + var result []vcs.CommitStatus + + // Fetch commit statuses + statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha)) + statusBody, err := c.doGet(ctx, statusURL) + if err != nil { + return nil, fmt.Errorf("fetch commit statuses: %w", err) + } + var statusResp commitStatusResponse + if err := json.Unmarshal(statusBody, &statusResp); err != nil { + return nil, fmt.Errorf("parse commit statuses JSON: %w", err) + } + for _, s := range statusResp.Statuses { + result = append(result, vcs.CommitStatus{ + Context: s.Context, + Status: s.State, + Description: s.Description, + TargetURL: s.TargetURL, + }) + } + + // Fetch check runs (paginated) + for checkPage := 1; checkPage <= maxPages; checkPage++ { + checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) + checkBody, err := c.doGet(ctx, checkURL) + if err != nil { + return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err) + } + var checkResp checkRunsResponse + if err := json.Unmarshal(checkBody, &checkResp); err != nil { + return nil, fmt.Errorf("parse check runs JSON: %w", err) + } + for _, cr := range checkResp.CheckRuns { + result = append(result, vcs.CommitStatus{ + Context: cr.Name, + Status: mapCheckRunStatus(cr.Conclusion), + Description: derefString(cr.Conclusion), + TargetURL: cr.HTMLURL, + }) + } + if len(checkResp.CheckRuns) < 100 { + break + } + } + + return result, nil +} + +// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string. +// Conclusion alone determines the mapped state: nil conclusion means the run is +// still in progress (pending), regardless of the status field value. +func mapCheckRunStatus(conclusion *string) string { + if conclusion == nil { + // Still running or queued + return "pending" + } + switch *conclusion { + case "success": + return "success" + case "failure", "action_required", "timed_out": + return "failure" + case "cancelled", "skipped", "neutral": + return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics + case "stale", "waiting": + return "pending" + default: + return "pending" + } +} + +// derefString safely dereferences a string pointer, returning empty string if nil. +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/github/pr_test.go b/github/pr_test.go new file mode 100644 index 0000000..0e05a50 --- /dev/null +++ b/github/pr_test.go @@ -0,0 +1,637 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetPullRequest_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/pulls/42" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 42, + "title": "Test PR", + "body": "Description", + "head": map[string]string{"sha": "abc123", "ref": "feature-branch"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 42 { + t.Errorf("expected number 42, got %d", pr.Number) + } + if pr.Title != "Test PR" { + t.Errorf("expected title 'Test PR', got %q", pr.Title) + } + if pr.Body != "Description" { + t.Errorf("expected body 'Description', got %q", pr.Body) + } + if pr.Head.SHA != "abc123" { + t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA) + } + if pr.Head.Ref != "feature-branch" { + t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref) + } + if pr.Base.Ref != "main" { + t.Errorf("expected base ref 'main', got %q", pr.Base.Ref) + } +} + +func TestGetPullRequest_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsNotFound(err) { + t.Errorf("expected IsNotFound=true, got error: %v", err) + } +} + +func TestGetPullRequest_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } + if !IsUnauthorized(err) { + t.Errorf("expected IsUnauthorized=true, got error: %v", err) + } +} + +func TestGetPullRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 1, + "title": "PR", + "body": "", + "head": map[string]string{"sha": "abc", "ref": "br"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 1 { + t.Errorf("expected number 1, got %d", pr.Number) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetPullRequest_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{invalid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } + if !strings.Contains(err.Error(), "parse PR JSON") { + t.Errorf("expected parse error, got: %v", err) + } +} + +func TestGetPullRequestDiff_HappyPath(t *testing.T) { + expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n" + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte(expectedDiff)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff != expectedDiff { + t.Errorf("unexpected diff: %q", diff) + } + if gotAccept != "application/vnd.github.diff" { + t.Errorf("expected diff Accept header, got %q", gotAccept) + } +} + +func TestGetPullRequestDiff_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestDiff_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetPullRequestFiles_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"}, + {"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Filename != "main.go" { + t.Errorf("expected filename 'main.go', got %q", files[0].Filename) + } + if files[0].Status != "modified" { + t.Errorf("expected status 'modified', got %q", files[0].Status) + } + if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" { + t.Errorf("unexpected patch: %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_Pagination(t *testing.T) { + // Simulate > 100 files requiring pagination + page1Files := make([]map[string]string, 100) + for i := 0; i < 100; i++ { + page1Files[i] = map[string]string{ + "filename": fmt.Sprintf("file%d.go", i), + "status": "modified", + "patch": fmt.Sprintf("patch%d", i), + } + } + page2Files := []map[string]string{ + {"filename": "file100.go", "status": "added", "patch": "patch100"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + if page == "" || page == "1" { + json.NewEncoder(w).Encode(page1Files) + } else { + json.NewEncoder(w).Encode(page2Files) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 101 { + t.Errorf("expected 101 files (paginated), got %d", len(files)) + } + if files[100].Filename != "file100.go" { + t.Errorf("expected last file 'file100.go', got %q", files[100].Filename) + } + if files[100].Patch != "patch100" { + t.Errorf("expected last patch 'patch100', got %q", files[100].Patch) + } +} + +func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Binary files have no patch field in GitHub response + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "image.png", "status": "added"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + if files[0].Patch != "" { + t.Errorf("expected empty patch for binary file, got %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestFiles_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.URL.Query().Get("ref") != "abc123" { + t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "cGFja2FnZSBtYWlu", // "package main" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "package main" { + t.Errorf("expected 'package main', got %q", content) + } +} + +func TestGetFileContentAtRef_EmptyRef(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("ref") != "" { + t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "aGVsbG8=", // "hello" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "hello" { + t.Errorf("expected 'hello', got %q", content) + } +} + +func TestGetFileContentAtRef_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContentAtRef_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContentAtRef_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not valid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", // "ok" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetCommitStatuses_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/status"): + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []map[string]string{ + { + "context": "ci/build", + "state": "success", + "description": "Build passed", + "target_url": "https://ci.example.com/1", + }, + }, + }) + case strings.Contains(r.URL.Path, "/check-runs"): + conclusion := "success" + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "lint", + "conclusion": &conclusion, + "status": "completed", + "html_url": "https://github.com/check/1", + }, + }, + }) + default: + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(404) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 2 { + t.Fatalf("expected 2 statuses, got %d", len(statuses)) + } + // First should be from commit statuses + if statuses[0].Context != "ci/build" { + t.Errorf("expected context 'ci/build', got %q", statuses[0].Context) + } + if statuses[0].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[0].Status) + } + // Second should be from check runs + if statuses[1].Context != "lint" { + t.Errorf("expected context 'lint', got %q", statuses[1].Context) + } + if statuses[1].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[1].Status) + } +} + +func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { + tests := []struct { + conclusion *string + status string + want string + }{ + {stringPtr("success"), "completed", "success"}, + {stringPtr("failure"), "completed", "failure"}, + {stringPtr("action_required"), "completed", "failure"}, + {stringPtr("timed_out"), "completed", "failure"}, + {stringPtr("cancelled"), "completed", "success"}, + {stringPtr("skipped"), "completed", "success"}, + {stringPtr("neutral"), "completed", "success"}, + {nil, "in_progress", "pending"}, + {nil, "queued", "pending"}, + } + + for _, tt := range tests { + name := "nil" + if tt.conclusion != nil { + name = *tt.conclusion + } + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/status") { + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []interface{}{}, + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "check", + "conclusion": tt.conclusion, + "status": tt.status, + "html_url": "https://github.com/check/1", + }, + }, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + if statuses[0].Status != tt.want { + t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status) + } + }) + } +} + +func TestGetCommitStatuses_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetCommitStatuses_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetCommitStatuses_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func stringPtr(s string) *string { + return &s +}