Compare commits

..

2 Commits

Author SHA1 Message Date
aweiker 642d1e129f feat(github): implement FileReader interface (#80)
PR Ready Gate / clear-labels (pull_request) Successful in 1s
CI / test (pull_request) Successful in 21s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 36s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 57s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m11s
Implement FileReader conformance on the GitHub client: GetFileContent,
ListContents, path helpers, base64 decode. Includes compile-time
conformance checks for both PRReader and FileReader.

Requires PR B (#102). Part 3 of 3 for #80.
2026-05-13 05:29:50 +00:00
aweiker 1cf0149c24 feat(github): implement PRReader interface (#80)
Implement PRReader conformance on the GitHub client: GetPullRequest,
GetPullRequestDiff, GetPullRequestFiles (paginated, populates Patch),
GetCommitStatuses (merges commit statuses + check runs).
Adds compile-time PRReader conformance check.

Requires PR A. Part 2 of 3 for #80.
2026-05-13 05:29:50 +00:00
4 changed files with 102 additions and 247 deletions
+32 -57
View File
@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"path"
"strings" "strings"
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
@@ -14,28 +13,25 @@ import (
// GetFileContent fetches a file from a repo at the given ref. // GetFileContent fetches a file from a repo at the given ref.
// Delegates to GetFileContentAtRef with the provided ref. // Delegates to GetFileContentAtRef with the provided ref.
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) { func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
return c.GetFileContentAtRef(ctx, owner, repo, filePath, ref) return c.GetFileContentAtRef(ctx, owner, repo, path, ref)
} }
// GetFileContentAtRef fetches a file at a specific ref from a repo. // GetFileContentAtRef fetches a file at a specific ref from a repo.
// If ref is empty, the query parameter is omitted (uses default branch). // If ref is empty, the query parameter is omitted (uses default branch).
// //
// Returns an error if the path contains dot-segments (".", "..") or // Note: dot-segments ("." and "..") in the path are silently removed to
// attempts to traverse above the repository root. // prevent path traversal. This means a path like "foo/../bar" resolves
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath, ref string) (string, error) { // to "foo/bar" rather than "bar".
escaped, err := escapePath(filePath) func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
if err != nil {
return "", fmt.Errorf("invalid file path: %w", err)
}
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
if ref != "" { if ref != "" {
reqURL += "?ref=" + url.QueryEscape(ref) reqURL += "?ref=" + url.QueryEscape(ref)
} }
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
if err != nil { if err != nil {
return "", fmt.Errorf("fetch file %s: %w", filePath, err) return "", fmt.Errorf("fetch file %s: %w", path, err)
} }
var resp struct { var resp struct {
Content string `json:"content"` Content string `json:"content"`
@@ -45,11 +41,11 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath,
return "", fmt.Errorf("parse file content JSON: %w", err) return "", fmt.Errorf("parse file content JSON: %w", err)
} }
if resp.Encoding != "base64" { if resp.Encoding != "base64" {
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, filePath) return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path)
} }
decoded, err := decodeBase64Content(resp.Content) decoded, err := decodeBase64Content(resp.Content)
if err != nil { if err != nil {
return "", fmt.Errorf("decode base64 content for %s: %w", filePath, err) return "", fmt.Errorf("decode base64 content for %s: %w", path, err)
} }
return decoded, nil return decoded, nil
} }
@@ -59,16 +55,16 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath,
// If the path points to a single file (not a directory), the API returns // If the path points to a single file (not a directory), the API returns
// a JSON object instead of an array; this is handled by returning a // a JSON object instead of an array; this is handled by returning a
// single-element slice. // single-element slice.
func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string) ([]vcs.ContentEntry, error) { //
escaped, err := escapePath(filePath) // Note: dot-segments ("." and "..") in the path are silently removed to
if err != nil { // prevent path traversal. This means a path like "foo/../bar" resolves
return nil, fmt.Errorf("invalid file path: %w", err) // to "foo/bar" rather than "bar".
} func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("list contents %s: %w", filePath, err) return nil, fmt.Errorf("list contents %s: %w", path, err)
} }
type entry struct { type entry struct {
@@ -106,55 +102,34 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string)
return result, nil return result, nil
} }
// escapePath validates and encodes a slash-separated file path for use in // escapePath escapes each segment of a relative file path for use in URLs.
// GitHub API URLs. Returns an error if the path contains dot-segments ("." // Slashes are preserved as path separators; other special characters are escaped.
// or "..") or resolves to a path outside the repository root. // Dot-segments ("." and "..") and empty segments (from consecutive slashes like
func escapePath(p string) (string, error) { // "a//b") are silently removed to prevent path traversal and produce canonical
// Reject paths containing dot-segments rather than silently rewriting them. // paths. This is intentional: callers may receive a different path than requested
for _, seg := range strings.Split(p, "/") { // without error. The function is package-private, and all callers
if seg == "." || seg == ".." { // (GetFileContentAtRef, ListContents) already handle missing-file errors from the
return "", fmt.Errorf("path contains dot-segment %q: %s", seg, p) // API if the cleaned path doesn't match what the caller intended.
} func escapePath(p string) string {
} parts := strings.Split(p, "/")
var clean []string
// Use path.Clean for canonical form, then verify it doesn't escape root.
cleaned := path.Clean(p)
if cleaned == "." || strings.HasPrefix(cleaned, "..") {
return "", fmt.Errorf("path resolves outside repository root: %s", p)
}
// Encode each segment individually.
parts := strings.Split(cleaned, "/")
var encoded []string
for _, part := range parts { for _, part := range parts {
if part == "" { if part == "." || part == ".." || part == "" {
continue continue
} }
encoded = append(encoded, url.PathEscape(part)) clean = append(clean, url.PathEscape(part))
} }
return strings.Join(encoded, "/"), nil return strings.Join(clean, "/")
} }
// maxFileContentSize is the maximum decoded file size (10 MB) to prevent
// resource exhaustion when decoding base64 content from the API.
const maxFileContentSize = 10 * 1024 * 1024
// decodeBase64Content decodes base64-encoded content from the GitHub contents API. // decodeBase64Content decodes base64-encoded content from the GitHub contents API.
// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. // GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding.
// Returns an error if the decoded content exceeds maxFileContentSize.
func decodeBase64Content(encoded string) (string, error) { func decodeBase64Content(encoded string) (string, error) {
// GitHub inserts newlines in base64 content
cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded)
// Check estimated decoded size before allocating.
// Base64 encodes 3 bytes into 4 chars, so decoded ~ len*3/4.
if len(cleaned)*3/4 > maxFileContentSize {
return "", fmt.Errorf("file content too large: estimated %d bytes exceeds limit of %d", len(cleaned)*3/4, maxFileContentSize)
}
decoded, err := base64.StdEncoding.DecodeString(cleaned) decoded, err := base64.StdEncoding.DecodeString(cleaned)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(decoded) > maxFileContentSize {
return "", fmt.Errorf("file content too large: %d bytes exceeds limit of %d", len(decoded), maxFileContentSize)
}
return string(decoded), nil return string(decoded), nil
} }
+54 -125
View File
@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
) )
@@ -110,9 +109,7 @@ func TestGetFileContent_429Retry(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
if err != nil { if err != nil {
@@ -230,11 +227,9 @@ func TestListContents_429Retry(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
entries, err := c.ListContents(context.Background(), "owner", "repo", "src") entries, err := c.ListContents(context.Background(), "owner", "repo", ".")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -262,6 +257,57 @@ func TestListContents_MalformedJSON(t *testing.T) {
} }
} }
func TestDecodeBase64Content(t *testing.T) {
// Test with newlines (GitHub's format)
encoded := "cGFja2FnZSBt\nYWlu"
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "package main" {
t.Errorf("expected 'package main', got %q", decoded)
}
}
func TestDecodeBase64Content_Invalid(t *testing.T) {
_, err := decodeBase64Content("not!!!valid!!!base64")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestEscapePath_RejectsDotSegments(t *testing.T) {
tests := []struct {
input string
want string
}{
{"src/main.go", "src/main.go"},
{"../etc/passwd", "etc/passwd"},
{"./src/../main.go", "src/main.go"},
{"a/b/c", "a/b/c"},
{"file with spaces.go", "file%20with%20spaces.go"},
{"a/./b/../c", "a/b/c"},
}
for _, tt := range tests {
got := escapePath(tt.input)
if got != tt.want {
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestDecodeBase64Content_CRLF(t *testing.T) {
// Base64 of "hello world" with CRLF line breaks inserted
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "hello world" {
t.Errorf("expected 'hello world', got %q", decoded)
}
}
func TestListContents_SingleFile(t *testing.T) { func TestListContents_SingleFile(t *testing.T) {
// GitHub Contents API returns a JSON object (not array) for single-file paths // GitHub Contents API returns a JSON object (not array) for single-file paths
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -286,120 +332,3 @@ func TestListContents_SingleFile(t *testing.T) {
t.Errorf("expected type 'file', got %q", entries[0].Type) t.Errorf("expected type 'file', got %q", entries[0].Type)
} }
} }
func TestEscapePath_ValidPaths(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
want string
}{
{"simple file", "file.go", "file.go"},
{"nested path", "path/to/file.go", "path/to/file.go"},
{"special chars", "path/to/my file.go", "path/to/my%20file.go"},
{"leading slash stripped", "/path/to/file.go", "path/to/file.go"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := escapePath(tt.path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("escapePath(%q) = %q, want %q", tt.path, got, tt.want)
}
})
}
}
func TestEscapePath_DotSegments(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
}{
{"single dot", "./file.go"},
{"double dot", "../file.go"},
{"dot in middle", "path/./file.go"},
{"parent traversal", "path/../file.go"},
{"only dots", ".."},
{"nested parent traversal", "a/b/../../c"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := escapePath(tt.path)
if err == nil {
t.Fatalf("expected error for path %q, got nil", tt.path)
}
if !strings.Contains(err.Error(), "dot-segment") {
t.Errorf("expected error about dot-segment, got: %v", err)
}
})
}
}
func TestGetFileContentAtRef_DotSegmentError(t *testing.T) {
// Server should never be called — the error is caught before the request.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("server should not have been called")
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "foo/../bar.go", "main")
if err == nil {
t.Fatal("expected error for path with dot-segments")
}
if !strings.Contains(err.Error(), "invalid file path") {
t.Errorf("expected 'invalid file path' error, got: %v", err)
}
}
func TestDecodeBase64Content(t *testing.T) {
// Test with newlines (GitHub's format)
encoded := "cGFja2FnZSBt\nYWlu"
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "package main" {
t.Errorf("expected 'package main', got %q", decoded)
}
}
func TestDecodeBase64Content_Invalid(t *testing.T) {
_, err := decodeBase64Content("not!!!valid!!!base64")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestDecodeBase64Content_CRLF(t *testing.T) {
// Base64 of "hello world" with CRLF line breaks inserted
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "hello world" {
t.Errorf("expected 'hello world', got %q", decoded)
}
}
func TestDecodeBase64Content_SizeLimit(t *testing.T) {
t.Parallel()
// Create base64 content that would decode to > maxFileContentSize.
// maxFileContentSize is 10MB. Base64 of 11MB worth of zeros.
// We just need something big enough to trigger the estimated size check.
// 14MB of base64 chars (decodes to ~10.5MB).
huge := strings.Repeat("A", 14*1024*1024)
_, err := decodeBase64Content(huge)
if err == nil {
t.Fatal("expected error for oversized content")
}
if !strings.Contains(err.Error(), "too large") {
t.Errorf("expected 'too large' error, got: %v", err)
}
}
+16 -26
View File
@@ -51,10 +51,7 @@ type checkRunsResponse struct {
} `json:"check_runs"` } `json:"check_runs"`
} }
// GetPullRequest fetches PR metadata from the GitHub API. // GetPullRequest fetches PR metadata.
// Returns an *APIError wrapping the HTTP status on non-2xx responses (e.g.
// IsNotFound for 404, IsUnauthorized for 401). Network and context errors
// are wrapped but not typed as *APIError.
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
@@ -85,15 +82,9 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, num
return string(body), nil return string(body), nil
} }
const ( // maxPages is the upper bound on pagination loops to prevent unbounded iteration
// maxFilesPages is the upper bound on pagination loops for PR file listing, // in case the server returns a full page indefinitely.
// preventing unbounded iteration if the server always returns a full page. const maxPages = 100
maxFilesPages = 100
// maxCheckRunPages is the upper bound on pagination loops for check-run listing,
// preventing unbounded iteration if the server always returns a full page.
maxCheckRunPages = 100
)
// GetPullRequestFiles fetches the list of files changed in a PR. // GetPullRequestFiles fetches the list of files changed in a PR.
// Paginates through all pages (100 per page) to collect all files. // Paginates through all pages (100 per page) to collect all files.
@@ -102,7 +93,7 @@ const (
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
var allFiles []vcs.ChangedFile var allFiles []vcs.ChangedFile
for page := 1; page <= maxFilesPages; page++ { for page := 1; page <= maxPages; page++ {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page)
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
@@ -163,7 +154,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
} }
// Fetch check runs (paginated) // Fetch check runs (paginated)
for checkPage := 1; checkPage <= maxCheckRunPages; checkPage++ { for checkPage := 1; checkPage <= maxPages; checkPage++ {
checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage)
checkBody, err := c.doGet(ctx, checkURL) checkBody, err := c.doGet(ctx, checkURL)
@@ -178,7 +169,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
result = append(result, vcs.CommitStatus{ result = append(result, vcs.CommitStatus{
Context: cr.Name, Context: cr.Name,
Status: mapCheckRunStatus(cr.Conclusion), Status: mapCheckRunStatus(cr.Conclusion),
Description: "", // check runs have no human-readable description; conclusion is captured in Status Description: derefString(cr.Conclusion),
TargetURL: cr.HTMLURL, TargetURL: cr.HTMLURL,
}) })
} }
@@ -190,17 +181,9 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
return result, nil return result, nil
} }
// mapCheckRunStatus maps a GitHub check run conclusion to a vcs.CommitStatus status string. // mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string.
// Conclusion alone determines the mapped state: nil conclusion means the run is // Conclusion alone determines the mapped state: nil conclusion means the run is
// still in progress (pending), regardless of the status field value. // still in progress (pending), regardless of the status field value.
//
// Mapping rules:
// - nil → "pending" (run still in progress or queued)
// - "success" → "success"
// - "failure", "action_required", "timed_out" → "failure"
// - "cancelled", "skipped", "neutral" → "success" (non-blocking per GitHub check suite semantics)
// - "stale" → "pending" (check run became stale before completing)
// - unknown values → "pending" (conservative: treat unrecognized conclusions as incomplete)
func mapCheckRunStatus(conclusion *string) string { func mapCheckRunStatus(conclusion *string) string {
if conclusion == nil { if conclusion == nil {
// Still running or queued // Still running or queued
@@ -213,10 +196,17 @@ func mapCheckRunStatus(conclusion *string) string {
return "failure" return "failure"
case "cancelled", "skipped", "neutral": case "cancelled", "skipped", "neutral":
return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics
case "stale": case "stale", "waiting":
return "pending" return "pending"
default: default:
return "pending" return "pending"
} }
} }
// derefString safely dereferences a string pointer, returning empty string if nil.
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
-39
View File
@@ -545,7 +545,6 @@ func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
name = *tt.conclusion name = *tt.conclusion
} }
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/status") { if strings.Contains(r.URL.Path, "/status") {
json.NewEncoder(w).Encode(map[string]interface{}{ json.NewEncoder(w).Encode(map[string]interface{}{
@@ -633,44 +632,6 @@ func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
} }
} }
func TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/status"):
// Statuses succeed
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []map[string]string{
{
"context": "ci/build",
"state": "success",
"description": "Build passed",
"target_url": "https://ci.example.com/1",
},
},
})
case strings.Contains(r.URL.Path, "/check-runs"):
// Check runs fail with 500
w.WriteHeader(500)
w.Write([]byte(`{"message":"Internal Server Error"}`))
default:
w.WriteHeader(404)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err == nil {
t.Fatal("expected error when check-runs endpoint fails after statuses succeed")
}
if !strings.Contains(err.Error(), "fetch check runs") {
t.Errorf("expected check runs error, got: %v", err)
}
}
func stringPtr(s string) *string { func stringPtr(s string) *string {
return &s return &s
} }