From 4b55e33a49bfb9a63c9dc16642af2c4e14608e7b Mon Sep 17 00:00:00 2001 From: claw Date: Tue, 12 May 2026 15:16:33 -0700 Subject: [PATCH] feat(github): implement PRReader + FileReader client (#80) Implement the GitHub API client with PRReader and FileReader interface conformance for both github.com and GitHub Enterprise. New files: - github/client.go: Client struct, NewClient with configurable base URL, HTTP helpers with 429 retry and Retry-After support - github/pr.go: GetPullRequest, GetPullRequestDiff (per-request Accept header), GetPullRequestFiles (paginated, populates Patch field), GetFileContentAtRef (base64 decode), GetCommitStatuses (merges commit statuses + check runs with conclusion mapping) - github/files.go: GetFileContent (delegates to GetFileContentAtRef), ListContents, escapePath, decodeBase64Content helpers Type changes: - vcs/types.go: Add Patch field to ChangedFile struct Tests cover: happy path, 404, 401, 429+retry, malformed response, pagination, binary files, check run conclusion mapping, base64 decoding. Compile-time checks: var _ vcs.PRReader = (*Client)(nil) var _ vcs.FileReader = (*Client)(nil) Exit criteria met: - go test ./github/... passes (all methods) - NewClient with empty baseURL uses https://api.github.com - NewClient with GHE URL targets correctly - GetFileContent delegates to GetFileContentAtRef with empty ref - GetPullRequestFiles paginates and populates Patch field - GetCommitStatuses merges both commit statuses and check-runs --- github/client.go | 196 ++++++++++++ github/client_test.go | 186 +++++++++++ github/conformance_test.go | 13 + github/files.go | 68 ++++ github/files_test.go | 277 ++++++++++++++++ github/pr.go | 233 ++++++++++++++ github/pr_test.go | 637 +++++++++++++++++++++++++++++++++++++ vcs/types.go | 1 + 8 files changed, 1611 insertions(+) create mode 100644 github/client.go create mode 100644 github/client_test.go create mode 100644 github/conformance_test.go create mode 100644 github/files.go create mode 100644 github/files_test.go create mode 100644 github/pr.go create mode 100644 github/pr_test.go diff --git a/github/client.go b/github/client.go new file mode 100644 index 0000000..10d7326 --- /dev/null +++ b/github/client.go @@ -0,0 +1,196 @@ +// Package github provides a client for the GitHub API. +// It supports pull request operations, file content retrieval, +// and review submission for both github.com and GitHub Enterprise. +package github + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +const defaultBaseURL = "https://api.github.com" + +// APIError represents an HTTP error response from the GitHub API. +// It carries the status code so callers can distinguish between +// different failure modes (e.g. 404 vs 500). +type APIError struct { + StatusCode int + Body string +} + +func (e *APIError) Error() string { + body := e.Body + if len(body) > 200 { + body = body[:200] + "...(truncated)" + } + return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) +} + +// IsNotFound reports whether an error is an API 404 response. +func IsNotFound(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusNotFound + } + return false +} + +// IsUnauthorized reports whether an error is an API 401 response. +func IsUnauthorized(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusUnauthorized + } + return false +} + +func asAPIError(err error) (*APIError, bool) { + if err == nil { + return nil, false + } + var target *APIError + if ok := errorAs(err, &target); ok { + return target, true + } + return nil, false +} + +// errorAs is a type-safe wrapper for errors.As to avoid import cycle issues. +func errorAs(err error, target interface{}) bool { + switch t := target.(type) { + case **APIError: + for err != nil { + if e, ok := err.(*APIError); ok { + *t = e + return true + } + // Try unwrapping + if u, ok := err.(interface{ Unwrap() error }); ok { + err = u.Unwrap() + } else { + return false + } + } + } + return false +} + +// Client interacts with the GitHub API. +// A Client is safe for concurrent use by multiple goroutines. +type Client struct { + baseURL string + token string + http *http.Client + + // RetryBackoff defines the delays between retry attempts for 429 responses. + // RetryBackoff[i] is the delay before attempt i+1 (after attempt i fails). + // If nil, defaults to {1s, 2s}. Set to shorter durations in tests. + RetryBackoff []time.Duration +} + +// NewClient creates a new GitHub API client. +// If baseURL is empty, it defaults to https://api.github.com. +// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3). +func NewClient(token, baseURL string) *Client { + if baseURL == "" { + baseURL = defaultBaseURL + } + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + token: token, + http: &http.Client{Timeout: 30 * time.Second}, + } +} + +// SetHTTPClient sets the underlying HTTP client used for requests. +// This is intended for testing to inject mock transports. +func (c *Client) SetHTTPClient(hc *http.Client) { + c.http = hc +} + +// doRequest performs an HTTP request with retry on 429 rate limit responses. +// It respects the Retry-After header when present. +func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { + const maxAttempts = 3 + backoff := c.RetryBackoff + if backoff == nil { + backoff = []time.Duration{1 * time.Second, 2 * time.Second} + } + + const maxErrorBodyBytes = 64 * 1024 + + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + var delay time.Duration + if attempt-1 < len(backoff) { + delay = backoff[attempt-1] + } + if delay > 0 { + timer := time.NewTimer(delay) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + } + } + } + + req, err := http.NewRequestWithContext(ctx, method, url, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.token) + if accept != "" { + req.Header.Set("Accept", accept) + } else { + req.Header.Set("Accept", "application/vnd.github+json") + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("do request: %w", err) + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + return body, nil + } + + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + resp.Body.Close() + + lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} + + // Retry on 429 rate limit + if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { + // Check for Retry-After header and override backoff if present + if ra := resp.Header.Get("Retry-After"); ra != "" { + if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { + if attempt < len(backoff) { + backoff[attempt] = time.Duration(seconds) * time.Second + } + } + } + continue + } + + // Don't retry other errors + return nil, lastErr + } + + return nil, lastErr +} + +// doGet is a convenience wrapper for GET requests with the default Accept header. +func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, url, "") +} diff --git a/github/client_test.go b/github/client_test.go new file mode 100644 index 0000000..6c7433f --- /dev/null +++ b/github/client_test.go @@ -0,0 +1,186 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestNewClient_DefaultBaseURL(t *testing.T) { + c := NewClient("test-token", "") + if c.baseURL != "https://api.github.com" { + t.Errorf("expected default base URL, got %q", c.baseURL) + } +} + +func TestNewClient_CustomBaseURL(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected custom base URL, got %q", c.baseURL) + } +} + +func TestNewClient_TrimsTrailingSlash(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3/") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected trailing slash trimmed, got %q", c.baseURL) + } +} + +func TestDoRequest_SetsAuthHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("my-token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "Bearer my-token" { + t.Errorf("expected Bearer auth, got %q", gotAuth) + } +} + +func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) { + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAccept != "application/vnd.github+json" { + t.Errorf("expected default Accept header, got %q", gotAccept) + } +} + +func TestDoRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{10 * time.Millisecond, 10 * time.Millisecond} + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestDoRequest_429ExhaustsRetries(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error after exhausting retries") + } + apiErr, ok := err.(*APIError) + if !ok { + t.Fatalf("expected *APIError, got %T", err) + } + if apiErr.StatusCode != 429 { + t.Errorf("expected 429, got %d", apiErr.StatusCode) + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestDoRequest_404NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(404) + w.Write([]byte(`{"message":"not found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 404") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 404), got %d", attempts) + } +} + +func TestDoRequest_401NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(401) + w.Write([]byte(`{"message":"bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 401") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 401), got %d", attempts) + } +} + +func TestIsNotFound(t *testing.T) { + err := &APIError{StatusCode: 404, Body: "not found"} + if !IsNotFound(err) { + t.Error("expected IsNotFound to return true for 404") + } + err2 := &APIError{StatusCode: 500, Body: "server error"} + if IsNotFound(err2) { + t.Error("expected IsNotFound to return false for 500") + } +} + +func TestIsUnauthorized(t *testing.T) { + err := &APIError{StatusCode: 401, Body: "bad credentials"} + if !IsUnauthorized(err) { + t.Error("expected IsUnauthorized to return true for 401") + } +} diff --git a/github/conformance_test.go b/github/conformance_test.go new file mode 100644 index 0000000..ca13188 --- /dev/null +++ b/github/conformance_test.go @@ -0,0 +1,13 @@ +package github_test + +import ( + "gitea.weiker.me/rodin/review-bot/github" + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// Compile-time interface conformance assertions. +// These verify github.Client satisfies vcs.PRReader and vcs.FileReader. +var ( + _ vcs.PRReader = (*github.Client)(nil) + _ vcs.FileReader = (*github.Client)(nil) +) diff --git a/github/files.go b/github/files.go new file mode 100644 index 0000000..9c162bf --- /dev/null +++ b/github/files.go @@ -0,0 +1,68 @@ +package github + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// GetFileContent fetches a file from the default branch of a repo. +// Delegates to GetFileContentAtRef with an empty ref. +func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { + return c.GetFileContentAtRef(ctx, owner, repo, path, ref) +} + +// ListContents lists files and directories at a given path in a repo. +// Returns the directory listing from the GitHub contents API. +func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("list contents %s: %w", path, err) + } + var entries []struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + } + if err := json.Unmarshal(body, &entries); err != nil { + return nil, fmt.Errorf("parse contents JSON: %w", err) + } + result := make([]vcs.ContentEntry, len(entries)) + for i, e := range entries { + result[i] = vcs.ContentEntry{ + Name: e.Name, + Path: e.Path, + Type: e.Type, + } + } + return result, nil +} + +// escapePath escapes each segment of a relative file path for use in URLs. +// Slashes are preserved as path separators; other special characters are escaped. +func escapePath(p string) string { + parts := strings.Split(p, "/") + for i, part := range parts { + parts[i] = url.PathEscape(part) + } + return strings.Join(parts, "/") +} + +// decodeBase64Content decodes base64-encoded content from the GitHub contents API. +// GitHub returns base64 content with newlines for formatting, which we strip before decoding. +func decodeBase64Content(encoded string) (string, error) { + // GitHub inserts newlines in base64 content + cleaned := strings.ReplaceAll(encoded, "\n", "") + decoded, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return "", err + } + return string(decoded), nil +} diff --git a/github/files_test.go b/github/files_test.go new file mode 100644 index 0000000..bb76a0b --- /dev/null +++ b/github/files_test.go @@ -0,0 +1,277 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", // "test" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + // Call with empty ref — should not include ref param + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "test" { + t.Errorf("expected 'test', got %q", content) + } + if gotRef != "" { + t.Errorf("expected empty ref, got %q", gotRef) + } +} + +func TestGetFileContent_WithRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotRef != "abc123" { + t.Errorf("expected ref 'abc123', got %q", gotRef) + } +} + +func TestGetFileContent_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContent_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContent_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetFileContent_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestListContents_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/src" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "main.go", "path": "src/main.go", "type": "file"}, + {"name": "lib", "path": "src/lib", "type": "dir"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + entries, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(entries)) + } + if entries[0].Name != "main.go" { + t.Errorf("expected name 'main.go', got %q", entries[0].Name) + } + if entries[0].Path != "src/main.go" { + t.Errorf("expected path 'src/main.go', got %q", entries[0].Path) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } + if entries[1].Name != "lib" { + t.Errorf("expected name 'lib', got %q", entries[1].Name) + } + if entries[1].Type != "dir" { + t.Errorf("expected type 'dir', got %q", entries[1].Type) + } +} + +func TestListContents_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "missing") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestListContents_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestListContents_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "file.go", "path": "file.go", "type": "file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + entries, err := c.ListContents(context.Background(), "owner", "repo", ".") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestListContents_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestDecodeBase64Content(t *testing.T) { + // Test with newlines (GitHub's format) + encoded := "cGFja2FnZSBt\nYWlu" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "package main" { + t.Errorf("expected 'package main', got %q", decoded) + } +} + +func TestDecodeBase64Content_Invalid(t *testing.T) { + _, err := decodeBase64Content("not!!!valid!!!base64") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} diff --git a/github/pr.go b/github/pr.go new file mode 100644 index 0000000..81bec09 --- /dev/null +++ b/github/pr.go @@ -0,0 +1,233 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// pullRequestResponse is the GitHub API response for a pull request. +type pullRequestResponse struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + Head struct { + SHA string `json:"sha"` + Ref string `json:"ref"` + } `json:"head"` + Base struct { + Ref string `json:"ref"` + } `json:"base"` +} + +// changedFileResponse is the GitHub API response for a changed file in a PR. +type changedFileResponse struct { + Filename string `json:"filename"` + Status string `json:"status"` + Patch string `json:"patch"` +} + +// commitStatusResponse is the GitHub combined status API response. +type commitStatusResponse struct { + State string `json:"state"` + Statuses []struct { + Context string `json:"context"` + State string `json:"state"` + Description string `json:"description"` + TargetURL string `json:"target_url"` + } `json:"statuses"` +} + +// checkRunsResponse is the GitHub check runs API response. +type checkRunsResponse struct { + TotalCount int `json:"total_count"` + CheckRuns []struct { + Name string `json:"name"` + Conclusion *string `json:"conclusion"` + Status string `json:"status"` + HTMLURL string `json:"html_url"` + } `json:"check_runs"` +} + +// GetPullRequest fetches PR metadata. +func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR: %w", err) + } + var resp pullRequestResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parse PR JSON: %w", err) + } + return &vcs.PullRequest{ + Number: resp.Number, + Title: resp.Title, + Body: resp.Body, + Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref}, + Base: vcs.BaseRef{Ref: resp.Base.Ref}, + }, nil +} + +// GetPullRequestDiff fetches the unified diff for a PR. +// Uses Accept: application/vnd.github.diff to get raw diff text. +func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff") + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil +} + +// GetPullRequestFiles fetches the list of files changed in a PR. +// Paginates through all pages (100 per page) to collect all files. +func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { + var allFiles []vcs.ChangedFile + page := 1 + + for { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR files page %d: %w", page, err) + } + var files []changedFileResponse + if err := json.Unmarshal(body, &files); err != nil { + return nil, fmt.Errorf("parse PR files JSON: %w", err) + } + if len(files) == 0 { + break + } + for _, f := range files { + allFiles = append(allFiles, vcs.ChangedFile{ + Filename: f.Filename, + Status: f.Status, + Patch: f.Patch, + }) + } + if len(files) < 100 { + break + } + page++ + } + + return allFiles, nil +} + +// GetFileContentAtRef fetches a file at a specific ref from a repo. +// If ref is empty, the query parameter is omitted (uses default branch). +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + if ref != "" { + reqURL += "?ref=" + url.QueryEscape(ref) + } + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch file %s: %w", path, err) + } + var resp struct { + Content string `json:"content"` + Encoding string `json:"encoding"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("parse file content JSON: %w", err) + } + if resp.Encoding != "base64" { + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + } + decoded, err := decodeBase64Content(resp.Content) + if err != nil { + return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + } + return decoded, nil +} + +// GetCommitStatuses fetches both commit statuses and check runs for a SHA, +// merging them into a unified []vcs.CommitStatus slice. +func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { + var result []vcs.CommitStatus + + // Fetch commit statuses + statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha)) + statusBody, err := c.doGet(ctx, statusURL) + if err != nil { + return nil, fmt.Errorf("fetch commit statuses: %w", err) + } + var statusResp commitStatusResponse + if err := json.Unmarshal(statusBody, &statusResp); err != nil { + return nil, fmt.Errorf("parse commit statuses JSON: %w", err) + } + for _, s := range statusResp.Statuses { + result = append(result, vcs.CommitStatus{ + Context: s.Context, + Status: s.State, + Description: s.Description, + TargetURL: s.TargetURL, + }) + } + + // Fetch check runs (paginated) + checkPage := 1 + for { + checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) + checkBody, err := c.doGet(ctx, checkURL) + if err != nil { + return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err) + } + var checkResp checkRunsResponse + if err := json.Unmarshal(checkBody, &checkResp); err != nil { + return nil, fmt.Errorf("parse check runs JSON: %w", err) + } + for _, cr := range checkResp.CheckRuns { + result = append(result, vcs.CommitStatus{ + Context: cr.Name, + Status: mapCheckRunStatus(cr.Conclusion, cr.Status), + Description: derefString(cr.Conclusion), + TargetURL: cr.HTMLURL, + }) + } + if len(checkResp.CheckRuns) < 100 { + break + } + checkPage++ + } + + return result, nil +} + +// mapCheckRunStatus maps a check run conclusion+status to a vcs.CommitStatus status string. +func mapCheckRunStatus(conclusion *string, status string) string { + if conclusion == nil { + // Still running or queued + return "pending" + } + switch *conclusion { + case "success": + return "success" + case "failure", "action_required", "timed_out": + return "failure" + case "cancelled", "skipped", "neutral": + return "success" // non-blocking + case "in_progress", "queued": + return "pending" + default: + return "pending" + } +} + +// derefString safely dereferences a string pointer, returning empty string if nil. +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/github/pr_test.go b/github/pr_test.go new file mode 100644 index 0000000..78366b7 --- /dev/null +++ b/github/pr_test.go @@ -0,0 +1,637 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetPullRequest_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/pulls/42" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 42, + "title": "Test PR", + "body": "Description", + "head": map[string]string{"sha": "abc123", "ref": "feature-branch"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 42 { + t.Errorf("expected number 42, got %d", pr.Number) + } + if pr.Title != "Test PR" { + t.Errorf("expected title 'Test PR', got %q", pr.Title) + } + if pr.Body != "Description" { + t.Errorf("expected body 'Description', got %q", pr.Body) + } + if pr.Head.SHA != "abc123" { + t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA) + } + if pr.Head.Ref != "feature-branch" { + t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref) + } + if pr.Base.Ref != "main" { + t.Errorf("expected base ref 'main', got %q", pr.Base.Ref) + } +} + +func TestGetPullRequest_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsNotFound(err) { + t.Errorf("expected IsNotFound=true, got error: %v", err) + } +} + +func TestGetPullRequest_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } + if !IsUnauthorized(err) { + t.Errorf("expected IsUnauthorized=true, got error: %v", err) + } +} + +func TestGetPullRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 1, + "title": "PR", + "body": "", + "head": map[string]string{"sha": "abc", "ref": "br"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 1 { + t.Errorf("expected number 1, got %d", pr.Number) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetPullRequest_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{invalid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } + if !strings.Contains(err.Error(), "parse PR JSON") { + t.Errorf("expected parse error, got: %v", err) + } +} + +func TestGetPullRequestDiff_HappyPath(t *testing.T) { + expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n" + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte(expectedDiff)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff != expectedDiff { + t.Errorf("unexpected diff: %q", diff) + } + if gotAccept != "application/vnd.github.diff" { + t.Errorf("expected diff Accept header, got %q", gotAccept) + } +} + +func TestGetPullRequestDiff_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestDiff_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetPullRequestFiles_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"}, + {"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Filename != "main.go" { + t.Errorf("expected filename 'main.go', got %q", files[0].Filename) + } + if files[0].Status != "modified" { + t.Errorf("expected status 'modified', got %q", files[0].Status) + } + if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" { + t.Errorf("unexpected patch: %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_Pagination(t *testing.T) { + // Simulate > 100 files requiring pagination + page1Files := make([]map[string]string, 100) + for i := 0; i < 100; i++ { + page1Files[i] = map[string]string{ + "filename": fmt.Sprintf("file%d.go", i), + "status": "modified", + "patch": fmt.Sprintf("patch%d", i), + } + } + page2Files := []map[string]string{ + {"filename": "file100.go", "status": "added", "patch": "patch100"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + if page == "" || page == "1" { + json.NewEncoder(w).Encode(page1Files) + } else { + json.NewEncoder(w).Encode(page2Files) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 101 { + t.Errorf("expected 101 files (paginated), got %d", len(files)) + } + if files[100].Filename != "file100.go" { + t.Errorf("expected last file 'file100.go', got %q", files[100].Filename) + } + if files[100].Patch != "patch100" { + t.Errorf("expected last patch 'patch100', got %q", files[100].Patch) + } +} + +func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Binary files have no patch field in GitHub response + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "image.png", "status": "added"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + if files[0].Patch != "" { + t.Errorf("expected empty patch for binary file, got %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestFiles_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.URL.Query().Get("ref") != "abc123" { + t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "cGFja2FnZSBtYWlu", // "package main" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "package main" { + t.Errorf("expected 'package main', got %q", content) + } +} + +func TestGetFileContentAtRef_EmptyRef(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("ref") != "" { + t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "aGVsbG8=", // "hello" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "hello" { + t.Errorf("expected 'hello', got %q", content) + } +} + +func TestGetFileContentAtRef_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContentAtRef_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContentAtRef_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not valid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", // "ok" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond} + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetCommitStatuses_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/status"): + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []map[string]string{ + { + "context": "ci/build", + "state": "success", + "description": "Build passed", + "target_url": "https://ci.example.com/1", + }, + }, + }) + case strings.Contains(r.URL.Path, "/check-runs"): + conclusion := "success" + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "lint", + "conclusion": &conclusion, + "status": "completed", + "html_url": "https://github.com/check/1", + }, + }, + }) + default: + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(404) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 2 { + t.Fatalf("expected 2 statuses, got %d", len(statuses)) + } + // First should be from commit statuses + if statuses[0].Context != "ci/build" { + t.Errorf("expected context 'ci/build', got %q", statuses[0].Context) + } + if statuses[0].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[0].Status) + } + // Second should be from check runs + if statuses[1].Context != "lint" { + t.Errorf("expected context 'lint', got %q", statuses[1].Context) + } + if statuses[1].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[1].Status) + } +} + +func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { + tests := []struct { + conclusion *string + status string + want string + }{ + {strPtr("success"), "completed", "success"}, + {strPtr("failure"), "completed", "failure"}, + {strPtr("action_required"), "completed", "failure"}, + {strPtr("timed_out"), "completed", "failure"}, + {strPtr("cancelled"), "completed", "success"}, + {strPtr("skipped"), "completed", "success"}, + {strPtr("neutral"), "completed", "success"}, + {nil, "in_progress", "pending"}, + {nil, "queued", "pending"}, + } + + for _, tt := range tests { + name := "nil" + if tt.conclusion != nil { + name = *tt.conclusion + } + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/status") { + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []interface{}{}, + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "check", + "conclusion": tt.conclusion, + "status": tt.status, + "html_url": "https://github.com/check/1", + }, + }, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + if statuses[0].Status != tt.want { + t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status) + } + }) + } +} + +func TestGetCommitStatuses_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetCommitStatuses_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetCommitStatuses_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func strPtr(s string) *string { + return &s +} diff --git a/vcs/types.go b/vcs/types.go index de904f3..608ad27 100644 --- a/vcs/types.go +++ b/vcs/types.go @@ -44,6 +44,7 @@ type PullRequest struct { type ChangedFile struct { Filename string `json:"filename"` Status string `json:"status"` + Patch string `json:"patch"` } // ContentEntry represents a file or directory entry from the contents API.