diff --git a/github/client.go b/github/client.go new file mode 100644 index 0000000..7f5c7f9 --- /dev/null +++ b/github/client.go @@ -0,0 +1,327 @@ +// Package github provides a client for the GitHub API. +// It supports pull request operations, file content retrieval, CI status checks, +// and directory listing for both github.com and GitHub Enterprise. +package github + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +const ( + defaultBaseURL = "https://api.github.com" + userAgent = "review-bot/1.0" + + // maxResponseBytes limits successful response body reads to 10 MiB. + maxResponseBytes = 10 * 1024 * 1024 +) + +// APIError represents an HTTP error response from the GitHub API. +// It carries the status code so callers can distinguish between +// different failure modes (e.g. 404 vs 500). +// +// The Body field stores up to 4 KiB of the raw response for programmatic +// inspection. Error() truncates to 200 bytes for safe logging, but callers +// should avoid logging or propagating Body directly in production since it may +// contain sensitive details from the upstream server. +type APIError struct { + StatusCode int + Body string +} + +func (e *APIError) Error() string { + body := e.Body + if len(body) > 200 { + body = body[:200] + "...(truncated)" + } + // Sanitize newlines to prevent log injection from upstream response bodies. + body = strings.ReplaceAll(body, "\n", " ") + body = strings.ReplaceAll(body, "\r", " ") + return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) +} + +// IsNotFound reports whether an error is an API 404 response. +func IsNotFound(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusNotFound + } + return false +} + +// IsUnauthorized reports whether an error is an API 401 response. +func IsUnauthorized(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusUnauthorized + } + return false +} + +func asAPIError(err error) (*APIError, bool) { + if err == nil { + return nil, false + } + var target *APIError + if errors.As(err, &target) { + return target, true + } + return nil, false +} + +// clientConfig holds optional configuration for NewClient. +type clientConfig struct { + allowInsecureHTTP bool +} + +// ClientOption configures optional behavior of NewClient. +type ClientOption func(*clientConfig) + +// AllowInsecureHTTP permits the client to use HTTP (non-TLS) base URLs. +// This should only be used for trusted internal deployments or testing. +func AllowInsecureHTTP() ClientOption { + return func(c *clientConfig) { + c.allowInsecureHTTP = true + } +} + +// Client interacts with the GitHub API. +// A Client is safe for concurrent use by multiple goroutines. +// SetHTTPClient and SetRetryBackoff are intended for test setup only and must +// be called before any goroutines issue requests; they have no synchronization. +type Client struct { + baseURL string + token string + allowInsecureHTTP bool + httpClient *http.Client + + // retryBackoff defines the delays between retry attempts for 429 responses. + // retryBackoff[i] is the delay before attempt i+1 (after attempt i fails). + // If nil, defaults to {1s, 2s}. Set to shorter durations in tests via SetRetryBackoff. + retryBackoff []time.Duration +} + +// defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil). +// It rejects HTTPS→HTTP protocol downgrades (to prevent plaintext leakage) and strips +// the Authorization header on cross-host redirects to prevent credential leakage to +// third-party hosts (e.g. CDN redirects from GitHub). +func defaultCheckRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + // Guard: net/http guarantees len(via) >= 1 but this is undocumented; + // defend against zero-length to avoid panic on index out of range. + if len(via) == 0 { + return nil + } + prev := via[len(via)-1] + // Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext. + if prev.URL.Scheme == "https" && req.URL.Scheme == "http" { + return fmt.Errorf("refusing redirect from HTTPS to HTTP (%s → %s)", prev.URL.Host, req.URL.Host) + } + // Strip Authorization on cross-host redirect to avoid leaking credentials + // to third-party hosts (GitHub legitimately redirects to CDN hosts). + if req.URL.Host != prev.URL.Host { + req.Header.Del("Authorization") + } + return nil +} + +// NewClient creates a new GitHub API client. +// If baseURL is empty, it defaults to https://api.github.com. +// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3). +// The baseURL must use HTTPS; pass AllowInsecureHTTP() as an option to permit HTTP +// for trusted internal deployments (e.g. local testing). +func NewClient(token, baseURL string, opts ...ClientOption) *Client { + if baseURL == "" { + baseURL = defaultBaseURL + } + cfg := clientConfig{} + for _, o := range opts { + o(&cfg) + } + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + allowInsecureHTTP: cfg.allowInsecureHTTP, + token: token, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, + }, + } +} + +// SetHTTPClient sets the underlying HTTP client used for requests. +// This is intended for test setup only to inject mock transports; it must be +// called before any goroutines issue requests. +// +// Passing nil restores the default client (30s timeout + auth-stripping +// CheckRedirect policy matching NewClient). +// +// Callers providing a non-nil client are responsible for configuring a safe +// CheckRedirect policy. Without one, the default net/http behavior will follow +// redirects and may forward the Authorization header to untrusted hosts. +func (c *Client) SetHTTPClient(hc *http.Client) { + if hc == nil { + hc = &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, + } + } + c.httpClient = hc +} + +// SetRetryBackoff configures the retry backoff durations for testing. +// It must be called before any goroutines issue requests. +// In production the default {1s, 2s} applies. +func (c *Client) SetRetryBackoff(d []time.Duration) { + c.retryBackoff = d +} + +// doRequest performs an HTTP request with retry on 429 rate limit responses. +// It respects the Retry-After header when present (capped at maxRetryAfter). +// Transport errors (network failures, context cancellation) are not retried. +func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) { + const maxAttempts = 3 + const maxRetryAfter = 120 * time.Second + + var backoff []time.Duration + if c.retryBackoff != nil { + backoff = make([]time.Duration, len(c.retryBackoff)) + copy(backoff, c.retryBackoff) + } else { + backoff = []time.Duration{1 * time.Second, 2 * time.Second} + } + + // maxErrorBodyBytes limits how much of an error response body is stored. + // Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers + // log APIError.Body directly. Error() further truncates to 200 bytes. + const maxErrorBodyBytes = 4 * 1024 + + // Reject non-HTTPS URLs early since the URL is immutable across retries. + if c.token != "" && !c.allowInsecureHTTP { + parsed, err := url.Parse(reqURL) + if err != nil { + return nil, fmt.Errorf("parse request URL: %w", err) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL) + } + } + + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + var delay time.Duration + if attempt-1 < len(backoff) { + delay = backoff[attempt-1] + } + if delay > 0 { + timer := time.NewTimer(delay) + select { + case <-timer.C: + timer.Stop() // no-op after fire; kept for symmetry with the ctx.Done case + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + } + } + } + + req, err := http.NewRequestWithContext(ctx, method, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if c.token != "" { + // Bearer is the OAuth2 standard and is accepted by GitHub for both + // classic PATs and fine-grained tokens. The alternative "token" scheme + // is GitHub-specific and offers no additional compatibility. + req.Header.Set("Authorization", "Bearer "+c.token) + } + req.Header.Set("User-Agent", userAgent) + if accept != "" { + req.Header.Set("Accept", accept) + } else { + req.Header.Set("Accept", "application/vnd.github+json") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("do request: %w", err) + } + + body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) + if done { + return body, err + } + lastErr = err + + // Retry on 429 rate limit + if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { + // Check for Retry-After header and override backoff if present. + // Supports both integer seconds (common) and HTTP-date format (RFC 7231). + if ra := resp.Header.Get("Retry-After"); ra != "" { + if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { + delay := time.Duration(seconds) * time.Second + if delay > maxRetryAfter { + delay = maxRetryAfter + } + if attempt < len(backoff) { + backoff[attempt] = delay + } + } else if retryAt, err := http.ParseTime(ra); err == nil { + delay := time.Until(retryAt) + if delay < 0 { + delay = 0 + } + if delay > maxRetryAfter { + delay = maxRetryAfter + } + if attempt < len(backoff) { + backoff[attempt] = delay + } + } + } + continue + } + + // Don't retry other errors + return nil, lastErr + } + + return nil, lastErr +} + +// handleResponse reads and closes the response body, returning the result. +// It uses defer to ensure the body is always closed regardless of code path. +// Returns (body, done, err) where done=true means the caller should return immediately. +func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1)) + if err != nil { + return nil, true, fmt.Errorf("read response body: %w", err) + } + if len(body) > maxRespBytes { + return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes) + } + return body, true, nil + } + + errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes))) + if readErr != nil && len(errBody) == 0 { + errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) + } + return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} +} + +// doGet is a convenience wrapper for GET requests with the default Accept header. +func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, reqURL, "") +} diff --git a/github/client_test.go b/github/client_test.go new file mode 100644 index 0000000..a8ccc06 --- /dev/null +++ b/github/client_test.go @@ -0,0 +1,556 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +func TestNewClient_DefaultBaseURL(t *testing.T) { + c := NewClient("test-token", "") + if c.baseURL != "https://api.github.com" { + t.Errorf("expected default base URL, got %q", c.baseURL) + } +} + +func TestNewClient_CustomBaseURL(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected custom base URL, got %q", c.baseURL) + } +} + +func TestNewClient_TrimsTrailingSlash(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3/") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected trailing slash trimmed, got %q", c.baseURL) + } +} + +func TestDoRequest_SetsAuthHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("my-token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "Bearer my-token" { + t.Errorf("expected Bearer auth, got %q", gotAuth) + } +} + +func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) { + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAccept != "application/vnd.github+json" { + t.Errorf("expected default Accept header, got %q", gotAccept) + } +} + +func TestDoRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestDoRequest_429ExhaustsRetries(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error after exhausting retries") + } + apiErr, ok := err.(*APIError) + if !ok { + t.Fatalf("expected *APIError, got %T", err) + } + if apiErr.StatusCode != 429 { + t.Errorf("expected 429, got %d", apiErr.StatusCode) + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestDoRequest_404NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(404) + w.Write([]byte(`{"message":"not found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 404") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 404), got %d", attempts) + } +} + +func TestDoRequest_401NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(401) + w.Write([]byte(`{"message":"bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 401") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 401), got %d", attempts) + } +} + +func TestIsNotFound(t *testing.T) { + err := &APIError{StatusCode: 404, Body: "not found"} + if !IsNotFound(err) { + t.Error("expected IsNotFound to return true for 404") + } + err2 := &APIError{StatusCode: 500, Body: "server error"} + if IsNotFound(err2) { + t.Error("expected IsNotFound to return false for 500") + } +} + +func TestIsUnauthorized(t *testing.T) { + err := &APIError{StatusCode: 401, Body: "bad credentials"} + if !IsUnauthorized(err) { + t.Error("expected IsUnauthorized to return true for 401") + } +} + +func TestAPIError_SanitizesNewlines(t *testing.T) { + err := &APIError{StatusCode: 500, Body: "line1\ninjected\rmore"} + msg := err.Error() + if strings.Contains(msg, "\n") || strings.Contains(msg, "\r") { + t.Errorf("expected newlines to be sanitized, got: %q", msg) + } + if !strings.Contains(msg, "line1 injected more") { + t.Errorf("expected sanitized body, got: %q", msg) + } +} + +func TestDoRequest_429RetryAfterHeader(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + // Use short backoff; Retry-After should override + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Retry-After: 1 means at least 1 second delay + if elapsed < 900*time.Millisecond { + t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed) + } +} + +func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the original retryBackoff slice was not mutated + if c.retryBackoff[0] != 1*time.Millisecond { + t.Errorf("retryBackoff[0] was mutated: got %v, want 1ms", c.retryBackoff[0]) + } + if c.retryBackoff[1] != 1*time.Millisecond { + t.Errorf("retryBackoff[1] was mutated: got %v, want 1ms", c.retryBackoff[1]) + } +} + +func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow Retry-After HTTP-date test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use HTTP-date format (RFC 7231) — a time 2 seconds in the future. + future := time.Now().Add(2 * time.Second).UTC() + w.Header().Set("Retry-After", future.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // HTTP-date was ~2s in the future; by the time client processes it, + // time.Until gives ~1-2s. Verify it's meaningfully delayed (not instant). + if elapsed < 500*time.Millisecond { + t.Errorf("expected meaningful delay from HTTP-date Retry-After, got %v", elapsed) + } +} + +func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use a time in the past — should result in zero/immediate retry. + past := time.Now().Add(-10 * time.Second).UTC() + w.Header().Set("Retry-After", past.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second}) + + start := time.Now() + _, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Past date should override the 5s backoff to ~0 + if elapsed > 500*time.Millisecond { + t.Errorf("expected near-instant retry for past HTTP-date, got %v", elapsed) + } +} + +func TestDoRequest_SetsUserAgentHeader(t *testing.T) { + var gotUA string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotUA != "review-bot/1.0" { + t.Errorf("expected User-Agent 'review-bot/1.0', got %q", gotUA) + } +} + +func TestDoRequest_LimitsResponseBody(t *testing.T) { + // Verify that oversized responses return an error rather than silently truncating. + bigBody := strings.Repeat("x", maxResponseBytes+1024) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(bigBody)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for oversized response body") + } + if !strings.Contains(err.Error(), "exceeded") { + t.Errorf("expected truncation error, got: %v", err) + } +} + +func TestDoRequest_AcceptsExactlyAtLimit(t *testing.T) { + // A response body exactly equal to maxResponseBytes should succeed (not error). + exactBody := strings.Repeat("x", maxResponseBytes) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(exactBody)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error for exactly-at-limit body: %v", err) + } + if len(body) != maxResponseBytes { + t.Errorf("expected body length %d, got %d", maxResponseBytes, len(body)) + } +} + +func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("", srv.URL, AllowInsecureHTTP()) // empty token + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "" { + t.Errorf("expected no Authorization header with empty token, got %q", gotAuth) + } +} + +func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) { + // Verify the CheckRedirect function is configured + c := NewClient("secret-token", "https://api.github.com") + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect to be set") + } +} + +func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "api.github.com", Path: "/foo"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err == nil { + t.Fatal("expected error on HTTPS→HTTP redirect") + } + if !strings.Contains(err.Error(), "refusing redirect from HTTPS to HTTP") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDefaultCheckRedirect_StripsAuthOnCrossHost(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "https", Host: "objects.githubusercontent.com", Path: "/bar"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth := req.Header.Get("Authorization"); auth != "" { + t.Errorf("expected Authorization header to be stripped, got %q", auth) + } +} + +func TestDefaultCheckRedirect_PreservesAuthOnSameHost(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/bar"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth := req.Header.Get("Authorization"); auth != "Bearer token" { + t.Errorf("expected Authorization to be preserved, got %q", auth) + } +} + +func TestDoRequest_RejectsHTTPWithToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + // Without AllowInsecureHTTP, should refuse to send token over HTTP + c := NewClient("secret-token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error when sending token over HTTP") + } + if !strings.Contains(err.Error(), "refusing to send credentials") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDoRequest_AllowsHTTPWithoutToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + // Without token, HTTP should be fine (no credentials to leak) + c := NewClient("", srv.URL) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("secret-token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } +} + +func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { + c := NewClient("token", "https://api.github.com") + c.SetHTTPClient(nil) + if c.httpClient == nil { + t.Fatal("expected non-nil httpClient after SetHTTPClient(nil)") + } + if c.httpClient.Timeout != 30*time.Second { + t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout) + } + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)") + } +} diff --git a/github/conformance_test.go b/github/conformance_test.go new file mode 100644 index 0000000..ca13188 --- /dev/null +++ b/github/conformance_test.go @@ -0,0 +1,13 @@ +package github_test + +import ( + "gitea.weiker.me/rodin/review-bot/github" + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// Compile-time interface conformance assertions. +// These verify github.Client satisfies vcs.PRReader and vcs.FileReader. +var ( + _ vcs.PRReader = (*github.Client)(nil) + _ vcs.FileReader = (*github.Client)(nil) +) diff --git a/github/files.go b/github/files.go new file mode 100644 index 0000000..9f04941 --- /dev/null +++ b/github/files.go @@ -0,0 +1,135 @@ +package github + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// GetFileContent fetches a file from a repo at the given ref. +// Delegates to GetFileContentAtRef with the provided ref. +func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { + return c.GetFileContentAtRef(ctx, owner, repo, path, ref) +} + +// GetFileContentAtRef fetches a file at a specific ref from a repo. +// If ref is empty, the query parameter is omitted (uses default branch). +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + if ref != "" { + reqURL += "?ref=" + url.QueryEscape(ref) + } + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch file %s: %w", path, err) + } + var resp struct { + Content string `json:"content"` + Encoding string `json:"encoding"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("parse file content JSON: %w", err) + } + if resp.Encoding != "base64" { + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + } + decoded, err := decodeBase64Content(resp.Content) + if err != nil { + return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + } + return decoded, nil +} + +// ListContents lists files and directories at a given path in a repo. +// Returns the directory listing from the GitHub contents API. +// If the path points to a single file (not a directory), the API returns +// a JSON object instead of an array; this is handled by returning a +// single-element slice. +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". +func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("list contents %s: %w", path, err) + } + + type entry struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + } + + // The GitHub contents API returns an array for directories and an object + // for single files. Try array first (common case), then fall back to object. + // An empty array ([]) is valid — it represents an empty directory — and + // results in a zero-length slice returned without error. + var entries []entry + if err := json.Unmarshal(body, &entries); err != nil { + var single entry + if err2 := json.Unmarshal(body, &single); err2 != nil { + return nil, fmt.Errorf("parse contents JSON: as array: %v; as object: %w", err, err2) + } + // Guard against empty objects ({}) or unexpected shapes that + // unmarshal successfully but carry no useful data. + if single.Name == "" && single.Path == "" && single.Type == "" { + return nil, fmt.Errorf("parse contents JSON: unexpected response format") + } + entries = []entry{single} + } + + result := make([]vcs.ContentEntry, len(entries)) + for i, e := range entries { + result[i] = vcs.ContentEntry{ + Name: e.Name, + Path: e.Path, + Type: e.Type, + } + } + return result, nil +} + +// escapePath escapes each segment of a relative file path for use in URLs. +// Slashes are preserved as path separators; other special characters are escaped. +// Dot-segments ("." and "..") and empty segments (from consecutive slashes like +// "a//b") are silently removed to prevent path traversal and produce canonical +// paths. This is intentional: callers may receive a different path than requested +// without error. The function is package-private, and all callers +// (GetFileContentAtRef, ListContents) already handle missing-file errors from the +// API if the cleaned path doesn't match what the caller intended. +func escapePath(p string) string { + parts := strings.Split(p, "/") + var clean []string + for _, part := range parts { + if part == "." || part == ".." || part == "" { + continue + } + clean = append(clean, url.PathEscape(part)) + } + return strings.Join(clean, "/") +} + +// decodeBase64Content decodes base64-encoded content from the GitHub contents API. +// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. +func decodeBase64Content(encoded string) (string, error) { + // GitHub inserts newlines in base64 content + cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) + decoded, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return "", err + } + return string(decoded), nil +} diff --git a/github/files_test.go b/github/files_test.go new file mode 100644 index 0000000..eda64a8 --- /dev/null +++ b/github/files_test.go @@ -0,0 +1,334 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", // "test" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + // Call with empty ref — should not include ref param + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "test" { + t.Errorf("expected 'test', got %q", content) + } + if gotRef != "" { + t.Errorf("expected empty ref, got %q", gotRef) + } +} + +func TestGetFileContent_WithRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotRef != "abc123" { + t.Errorf("expected ref 'abc123', got %q", gotRef) + } +} + +func TestGetFileContent_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContent_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContent_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetFileContent_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestListContents_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/src" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "main.go", "path": "src/main.go", "type": "file"}, + {"name": "lib", "path": "src/lib", "type": "dir"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + entries, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(entries)) + } + if entries[0].Name != "main.go" { + t.Errorf("expected name 'main.go', got %q", entries[0].Name) + } + if entries[0].Path != "src/main.go" { + t.Errorf("expected path 'src/main.go', got %q", entries[0].Path) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } + if entries[1].Name != "lib" { + t.Errorf("expected name 'lib', got %q", entries[1].Name) + } + if entries[1].Type != "dir" { + t.Errorf("expected type 'dir', got %q", entries[1].Type) + } +} + +func TestListContents_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "missing") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestListContents_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestListContents_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "file.go", "path": "file.go", "type": "file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + entries, err := c.ListContents(context.Background(), "owner", "repo", ".") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestListContents_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestDecodeBase64Content(t *testing.T) { + // Test with newlines (GitHub's format) + encoded := "cGFja2FnZSBt\nYWlu" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "package main" { + t.Errorf("expected 'package main', got %q", decoded) + } +} + +func TestDecodeBase64Content_Invalid(t *testing.T) { + _, err := decodeBase64Content("not!!!valid!!!base64") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestEscapePath_RejectsDotSegments(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"src/main.go", "src/main.go"}, + {"../etc/passwd", "etc/passwd"}, + {"./src/../main.go", "src/main.go"}, + {"a/b/c", "a/b/c"}, + {"file with spaces.go", "file%20with%20spaces.go"}, + {"a/./b/../c", "a/b/c"}, + } + for _, tt := range tests { + got := escapePath(tt.input) + if got != tt.want { + t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestDecodeBase64Content_CRLF(t *testing.T) { + // Base64 of "hello world" with CRLF line breaks inserted + encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ=" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "hello world" { + t.Errorf("expected 'hello world', got %q", decoded) + } +} + +func TestListContents_SingleFile(t *testing.T) { + // GitHub Contents API returns a JSON object (not array) for single-file paths + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + entries, err := c.ListContents(context.Background(), "owner", "repo", "README.md") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].Name != "README.md" { + t.Errorf("expected name 'README.md', got %q", entries[0].Name) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } +} diff --git a/github/pr.go b/github/pr.go new file mode 100644 index 0000000..e9bea5a --- /dev/null +++ b/github/pr.go @@ -0,0 +1,212 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// pullRequestResponse is the GitHub API response for a pull request. +type pullRequestResponse struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + Head struct { + SHA string `json:"sha"` + Ref string `json:"ref"` + } `json:"head"` + Base struct { + Ref string `json:"ref"` + } `json:"base"` +} + +// changedFileResponse is the GitHub API response for a changed file in a PR. +type changedFileResponse struct { + Filename string `json:"filename"` + Status string `json:"status"` + Patch string `json:"patch"` +} + +// commitStatusResponse is the GitHub combined status API response. +type commitStatusResponse struct { + Statuses []struct { + Context string `json:"context"` + State string `json:"state"` + Description string `json:"description"` + TargetURL string `json:"target_url"` + } `json:"statuses"` +} + +// checkRunsResponse is the GitHub check runs API response. +type checkRunsResponse struct { + CheckRuns []struct { + Name string `json:"name"` + Conclusion *string `json:"conclusion"` + Status string `json:"status"` + HTMLURL string `json:"html_url"` + } `json:"check_runs"` +} + +// GetPullRequest fetches PR metadata. +func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR: %w", err) + } + var resp pullRequestResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("parse PR JSON: %w", err) + } + return &vcs.PullRequest{ + Number: resp.Number, + Title: resp.Title, + Body: resp.Body, + Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref}, + Base: vcs.BaseRef{Ref: resp.Base.Ref}, + }, nil +} + +// GetPullRequestDiff fetches the unified diff for a PR. +// Uses Accept: application/vnd.github.diff to get raw diff text. +func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) + body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff") + if err != nil { + return "", fmt.Errorf("fetch diff: %w", err) + } + return string(body), nil +} + +// maxPages is the upper bound on pagination loops to prevent unbounded iteration +// in case the server returns a full page indefinitely. +const maxPages = 100 + +// GetPullRequestFiles fetches the list of files changed in a PR. +// Paginates through all pages (100 per page) to collect all files. +// Returns nil (not an empty slice) when the PR has no changed files. +// Callers can safely range over or check len() on a nil slice. +func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { + var allFiles []vcs.ChangedFile + + for page := 1; page <= maxPages; page++ { + reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("fetch PR files page %d: %w", page, err) + } + var files []changedFileResponse + if err := json.Unmarshal(body, &files); err != nil { + return nil, fmt.Errorf("parse PR files JSON: %w", err) + } + if len(files) == 0 { + break + } + for _, f := range files { + allFiles = append(allFiles, vcs.ChangedFile{ + Filename: f.Filename, + Status: f.Status, + Patch: f.Patch, + }) + } + if len(files) < 100 { + break + } + } + + return allFiles, nil +} + +// GetCommitStatuses fetches both commit statuses and check runs for a SHA, +// merging them into a unified []vcs.CommitStatus slice. +// Returns nil (not an empty slice) when there are no statuses or check runs. +// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the +// function returns immediately without attempting the check-runs endpoint. +// If the check-runs endpoint fails after statuses were fetched successfully, +// the function returns an error (not a partial result) so callers always get +// either a complete view or a clear error signal. +func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) { + var result []vcs.CommitStatus + + // Fetch commit statuses + statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha)) + statusBody, err := c.doGet(ctx, statusURL) + if err != nil { + return nil, fmt.Errorf("fetch commit statuses: %w", err) + } + var statusResp commitStatusResponse + if err := json.Unmarshal(statusBody, &statusResp); err != nil { + return nil, fmt.Errorf("parse commit statuses JSON: %w", err) + } + for _, s := range statusResp.Statuses { + result = append(result, vcs.CommitStatus{ + Context: s.Context, + Status: s.State, + Description: s.Description, + TargetURL: s.TargetURL, + }) + } + + // Fetch check runs (paginated) + for checkPage := 1; checkPage <= maxPages; checkPage++ { + checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) + checkBody, err := c.doGet(ctx, checkURL) + if err != nil { + return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err) + } + var checkResp checkRunsResponse + if err := json.Unmarshal(checkBody, &checkResp); err != nil { + return nil, fmt.Errorf("parse check runs JSON: %w", err) + } + for _, cr := range checkResp.CheckRuns { + result = append(result, vcs.CommitStatus{ + Context: cr.Name, + Status: mapCheckRunStatus(cr.Conclusion), + Description: derefString(cr.Conclusion), + TargetURL: cr.HTMLURL, + }) + } + if len(checkResp.CheckRuns) < 100 { + break + } + } + + return result, nil +} + +// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string. +// Conclusion alone determines the mapped state: nil conclusion means the run is +// still in progress (pending), regardless of the status field value. +func mapCheckRunStatus(conclusion *string) string { + if conclusion == nil { + // Still running or queued + return "pending" + } + switch *conclusion { + case "success": + return "success" + case "failure", "action_required", "timed_out": + return "failure" + case "cancelled", "skipped", "neutral": + return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics + case "stale", "waiting": + return "pending" + default: + return "pending" + } +} + +// derefString safely dereferences a string pointer, returning empty string if nil. +func derefString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/github/pr_test.go b/github/pr_test.go new file mode 100644 index 0000000..0e05a50 --- /dev/null +++ b/github/pr_test.go @@ -0,0 +1,637 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestGetPullRequest_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/pulls/42" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 42, + "title": "Test PR", + "body": "Description", + "head": map[string]string{"sha": "abc123", "ref": "feature-branch"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 42 { + t.Errorf("expected number 42, got %d", pr.Number) + } + if pr.Title != "Test PR" { + t.Errorf("expected title 'Test PR', got %q", pr.Title) + } + if pr.Body != "Description" { + t.Errorf("expected body 'Description', got %q", pr.Body) + } + if pr.Head.SHA != "abc123" { + t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA) + } + if pr.Head.Ref != "feature-branch" { + t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref) + } + if pr.Base.Ref != "main" { + t.Errorf("expected base ref 'main', got %q", pr.Base.Ref) + } +} + +func TestGetPullRequest_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } + if !IsNotFound(err) { + t.Errorf("expected IsNotFound=true, got error: %v", err) + } +} + +func TestGetPullRequest_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } + if !IsUnauthorized(err) { + t.Errorf("expected IsUnauthorized=true, got error: %v", err) + } +} + +func TestGetPullRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "number": 1, + "title": "PR", + "body": "", + "head": map[string]string{"sha": "abc", "ref": "br"}, + "base": map[string]string{"ref": "main"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if pr.Number != 1 { + t.Errorf("expected number 1, got %d", pr.Number) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetPullRequest_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{invalid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } + if !strings.Contains(err.Error(), "parse PR JSON") { + t.Errorf("expected parse error, got: %v", err) + } +} + +func TestGetPullRequestDiff_HappyPath(t *testing.T) { + expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n" + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte(expectedDiff)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff != expectedDiff { + t.Errorf("unexpected diff: %q", diff) + } + if gotAccept != "application/vnd.github.diff" { + t.Errorf("expected diff Accept header, got %q", gotAccept) + } +} + +func TestGetPullRequestDiff_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestDiff_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetPullRequestFiles_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"}, + {"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Filename != "main.go" { + t.Errorf("expected filename 'main.go', got %q", files[0].Filename) + } + if files[0].Status != "modified" { + t.Errorf("expected status 'modified', got %q", files[0].Status) + } + if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" { + t.Errorf("unexpected patch: %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_Pagination(t *testing.T) { + // Simulate > 100 files requiring pagination + page1Files := make([]map[string]string, 100) + for i := 0; i < 100; i++ { + page1Files[i] = map[string]string{ + "filename": fmt.Sprintf("file%d.go", i), + "status": "modified", + "patch": fmt.Sprintf("patch%d", i), + } + } + page2Files := []map[string]string{ + {"filename": "file100.go", "status": "added", "patch": "patch100"}, + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + page := r.URL.Query().Get("page") + if page == "" || page == "1" { + json.NewEncoder(w).Encode(page1Files) + } else { + json.NewEncoder(w).Encode(page2Files) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 101 { + t.Errorf("expected 101 files (paginated), got %d", len(files)) + } + if files[100].Filename != "file100.go" { + t.Errorf("expected last file 'file100.go', got %q", files[100].Filename) + } + if files[100].Patch != "patch100" { + t.Errorf("expected last patch 'patch100', got %q", files[100].Patch) + } +} + +func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Binary files have no patch field in GitHub response + json.NewEncoder(w).Encode([]map[string]interface{}{ + {"filename": "image.png", "status": "added"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 file, got %d", len(files)) + } + if files[0].Patch != "" { + t.Errorf("expected empty patch for binary file, got %q", files[0].Patch) + } +} + +func TestGetPullRequestFiles_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999) + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetPullRequestFiles_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1) + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.URL.Query().Get("ref") != "abc123" { + t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "cGFja2FnZSBtYWlu", // "package main" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "package main" { + t.Errorf("expected 'package main', got %q", content) + } +} + +func TestGetFileContentAtRef_EmptyRef(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("ref") != "" { + t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref")) + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "aGVsbG8=", // "hello" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "hello" { + t.Errorf("expected 'hello', got %q", content) + } +} + +func TestGetFileContentAtRef_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContentAtRef_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContentAtRef_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not valid json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestGetFileContentAtRef_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", // "ok" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetCommitStatuses_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/status"): + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []map[string]string{ + { + "context": "ci/build", + "state": "success", + "description": "Build passed", + "target_url": "https://ci.example.com/1", + }, + }, + }) + case strings.Contains(r.URL.Path, "/check-runs"): + conclusion := "success" + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "lint", + "conclusion": &conclusion, + "status": "completed", + "html_url": "https://github.com/check/1", + }, + }, + }) + default: + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(404) + } + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 2 { + t.Fatalf("expected 2 statuses, got %d", len(statuses)) + } + // First should be from commit statuses + if statuses[0].Context != "ci/build" { + t.Errorf("expected context 'ci/build', got %q", statuses[0].Context) + } + if statuses[0].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[0].Status) + } + // Second should be from check runs + if statuses[1].Context != "lint" { + t.Errorf("expected context 'lint', got %q", statuses[1].Context) + } + if statuses[1].Status != "success" { + t.Errorf("expected status 'success', got %q", statuses[1].Status) + } +} + +func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) { + tests := []struct { + conclusion *string + status string + want string + }{ + {stringPtr("success"), "completed", "success"}, + {stringPtr("failure"), "completed", "failure"}, + {stringPtr("action_required"), "completed", "failure"}, + {stringPtr("timed_out"), "completed", "failure"}, + {stringPtr("cancelled"), "completed", "success"}, + {stringPtr("skipped"), "completed", "success"}, + {stringPtr("neutral"), "completed", "success"}, + {nil, "in_progress", "pending"}, + {nil, "queued", "pending"}, + } + + for _, tt := range tests { + name := "nil" + if tt.conclusion != nil { + name = *tt.conclusion + } + t.Run(name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/status") { + json.NewEncoder(w).Encode(map[string]interface{}{ + "state": "success", + "statuses": []interface{}{}, + }) + return + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "total_count": 1, + "check_runs": []map[string]interface{}{ + { + "name": "check", + "conclusion": tt.conclusion, + "status": tt.status, + "html_url": "https://github.com/check/1", + }, + }, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + if statuses[0].Status != tt.want { + t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status) + } + }) + } +} + +func TestGetCommitStatuses_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetCommitStatuses_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetCommitStatuses_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func stringPtr(s string) *string { + return &s +} diff --git a/vcs/types.go b/vcs/types.go index de904f3..608ad27 100644 --- a/vcs/types.go +++ b/vcs/types.go @@ -44,6 +44,7 @@ type PullRequest struct { type ChangedFile struct { Filename string `json:"filename"` Status string `json:"status"` + Patch string `json:"patch"` } // ContentEntry represents a file or directory entry from the contents API.