From eb5457f51e922da23c5c0161db8db0fff484b167 Mon Sep 17 00:00:00 2001 From: Aaron Weiker Date: Wed, 13 May 2026 04:11:53 +0000 Subject: [PATCH] feat(github): implement GitHub API client foundation (#80) Add GitHub API client with configurable base URL and GHE support, HTTP helpers with 429 retry and Retry-After handling. Also adds Patch field to vcs.ChangedFile. Part 1 of 3 for #80. --- github/client.go | 327 +++++++++++++++++++++++++ github/client_test.go | 556 ++++++++++++++++++++++++++++++++++++++++++ vcs/types.go | 1 + 3 files changed, 884 insertions(+) create mode 100644 github/client.go create mode 100644 github/client_test.go diff --git a/github/client.go b/github/client.go new file mode 100644 index 0000000..7f5c7f9 --- /dev/null +++ b/github/client.go @@ -0,0 +1,327 @@ +// Package github provides a client for the GitHub API. +// It supports pull request operations, file content retrieval, CI status checks, +// and directory listing for both github.com and GitHub Enterprise. +package github + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +const ( + defaultBaseURL = "https://api.github.com" + userAgent = "review-bot/1.0" + + // maxResponseBytes limits successful response body reads to 10 MiB. + maxResponseBytes = 10 * 1024 * 1024 +) + +// APIError represents an HTTP error response from the GitHub API. +// It carries the status code so callers can distinguish between +// different failure modes (e.g. 404 vs 500). +// +// The Body field stores up to 4 KiB of the raw response for programmatic +// inspection. Error() truncates to 200 bytes for safe logging, but callers +// should avoid logging or propagating Body directly in production since it may +// contain sensitive details from the upstream server. +type APIError struct { + StatusCode int + Body string +} + +func (e *APIError) Error() string { + body := e.Body + if len(body) > 200 { + body = body[:200] + "...(truncated)" + } + // Sanitize newlines to prevent log injection from upstream response bodies. + body = strings.ReplaceAll(body, "\n", " ") + body = strings.ReplaceAll(body, "\r", " ") + return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) +} + +// IsNotFound reports whether an error is an API 404 response. +func IsNotFound(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusNotFound + } + return false +} + +// IsUnauthorized reports whether an error is an API 401 response. +func IsUnauthorized(err error) bool { + if apiErr, ok := asAPIError(err); ok { + return apiErr.StatusCode == http.StatusUnauthorized + } + return false +} + +func asAPIError(err error) (*APIError, bool) { + if err == nil { + return nil, false + } + var target *APIError + if errors.As(err, &target) { + return target, true + } + return nil, false +} + +// clientConfig holds optional configuration for NewClient. +type clientConfig struct { + allowInsecureHTTP bool +} + +// ClientOption configures optional behavior of NewClient. +type ClientOption func(*clientConfig) + +// AllowInsecureHTTP permits the client to use HTTP (non-TLS) base URLs. +// This should only be used for trusted internal deployments or testing. +func AllowInsecureHTTP() ClientOption { + return func(c *clientConfig) { + c.allowInsecureHTTP = true + } +} + +// Client interacts with the GitHub API. +// A Client is safe for concurrent use by multiple goroutines. +// SetHTTPClient and SetRetryBackoff are intended for test setup only and must +// be called before any goroutines issue requests; they have no synchronization. +type Client struct { + baseURL string + token string + allowInsecureHTTP bool + httpClient *http.Client + + // retryBackoff defines the delays between retry attempts for 429 responses. + // retryBackoff[i] is the delay before attempt i+1 (after attempt i fails). + // If nil, defaults to {1s, 2s}. Set to shorter durations in tests via SetRetryBackoff. + retryBackoff []time.Duration +} + +// defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil). +// It rejects HTTPS→HTTP protocol downgrades (to prevent plaintext leakage) and strips +// the Authorization header on cross-host redirects to prevent credential leakage to +// third-party hosts (e.g. CDN redirects from GitHub). +func defaultCheckRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + // Guard: net/http guarantees len(via) >= 1 but this is undocumented; + // defend against zero-length to avoid panic on index out of range. + if len(via) == 0 { + return nil + } + prev := via[len(via)-1] + // Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext. + if prev.URL.Scheme == "https" && req.URL.Scheme == "http" { + return fmt.Errorf("refusing redirect from HTTPS to HTTP (%s → %s)", prev.URL.Host, req.URL.Host) + } + // Strip Authorization on cross-host redirect to avoid leaking credentials + // to third-party hosts (GitHub legitimately redirects to CDN hosts). + if req.URL.Host != prev.URL.Host { + req.Header.Del("Authorization") + } + return nil +} + +// NewClient creates a new GitHub API client. +// If baseURL is empty, it defaults to https://api.github.com. +// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3). +// The baseURL must use HTTPS; pass AllowInsecureHTTP() as an option to permit HTTP +// for trusted internal deployments (e.g. local testing). +func NewClient(token, baseURL string, opts ...ClientOption) *Client { + if baseURL == "" { + baseURL = defaultBaseURL + } + cfg := clientConfig{} + for _, o := range opts { + o(&cfg) + } + return &Client{ + baseURL: strings.TrimRight(baseURL, "/"), + allowInsecureHTTP: cfg.allowInsecureHTTP, + token: token, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, + }, + } +} + +// SetHTTPClient sets the underlying HTTP client used for requests. +// This is intended for test setup only to inject mock transports; it must be +// called before any goroutines issue requests. +// +// Passing nil restores the default client (30s timeout + auth-stripping +// CheckRedirect policy matching NewClient). +// +// Callers providing a non-nil client are responsible for configuring a safe +// CheckRedirect policy. Without one, the default net/http behavior will follow +// redirects and may forward the Authorization header to untrusted hosts. +func (c *Client) SetHTTPClient(hc *http.Client) { + if hc == nil { + hc = &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, + } + } + c.httpClient = hc +} + +// SetRetryBackoff configures the retry backoff durations for testing. +// It must be called before any goroutines issue requests. +// In production the default {1s, 2s} applies. +func (c *Client) SetRetryBackoff(d []time.Duration) { + c.retryBackoff = d +} + +// doRequest performs an HTTP request with retry on 429 rate limit responses. +// It respects the Retry-After header when present (capped at maxRetryAfter). +// Transport errors (network failures, context cancellation) are not retried. +func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) { + const maxAttempts = 3 + const maxRetryAfter = 120 * time.Second + + var backoff []time.Duration + if c.retryBackoff != nil { + backoff = make([]time.Duration, len(c.retryBackoff)) + copy(backoff, c.retryBackoff) + } else { + backoff = []time.Duration{1 * time.Second, 2 * time.Second} + } + + // maxErrorBodyBytes limits how much of an error response body is stored. + // Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers + // log APIError.Body directly. Error() further truncates to 200 bytes. + const maxErrorBodyBytes = 4 * 1024 + + // Reject non-HTTPS URLs early since the URL is immutable across retries. + if c.token != "" && !c.allowInsecureHTTP { + parsed, err := url.Parse(reqURL) + if err != nil { + return nil, fmt.Errorf("parse request URL: %w", err) + } + if !strings.EqualFold(parsed.Scheme, "https") { + return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL) + } + } + + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + var delay time.Duration + if attempt-1 < len(backoff) { + delay = backoff[attempt-1] + } + if delay > 0 { + timer := time.NewTimer(delay) + select { + case <-timer.C: + timer.Stop() // no-op after fire; kept for symmetry with the ctx.Done case + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + } + } + } + + req, err := http.NewRequestWithContext(ctx, method, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + if c.token != "" { + // Bearer is the OAuth2 standard and is accepted by GitHub for both + // classic PATs and fine-grained tokens. The alternative "token" scheme + // is GitHub-specific and offers no additional compatibility. + req.Header.Set("Authorization", "Bearer "+c.token) + } + req.Header.Set("User-Agent", userAgent) + if accept != "" { + req.Header.Set("Accept", accept) + } else { + req.Header.Set("Accept", "application/vnd.github+json") + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("do request: %w", err) + } + + body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes) + if done { + return body, err + } + lastErr = err + + // Retry on 429 rate limit + if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 { + // Check for Retry-After header and override backoff if present. + // Supports both integer seconds (common) and HTTP-date format (RFC 7231). + if ra := resp.Header.Get("Retry-After"); ra != "" { + if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { + delay := time.Duration(seconds) * time.Second + if delay > maxRetryAfter { + delay = maxRetryAfter + } + if attempt < len(backoff) { + backoff[attempt] = delay + } + } else if retryAt, err := http.ParseTime(ra); err == nil { + delay := time.Until(retryAt) + if delay < 0 { + delay = 0 + } + if delay > maxRetryAfter { + delay = maxRetryAfter + } + if attempt < len(backoff) { + backoff[attempt] = delay + } + } + } + continue + } + + // Don't retry other errors + return nil, lastErr + } + + return nil, lastErr +} + +// handleResponse reads and closes the response body, returning the result. +// It uses defer to ensure the body is always closed regardless of code path. +// Returns (body, done, err) where done=true means the caller should return immediately. +func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1)) + if err != nil { + return nil, true, fmt.Errorf("read response body: %w", err) + } + if len(body) > maxRespBytes { + return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes) + } + return body, true, nil + } + + errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes))) + if readErr != nil && len(errBody) == 0 { + errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr)) + } + return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)} +} + +// doGet is a convenience wrapper for GET requests with the default Accept header. +func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { + return c.doRequest(ctx, http.MethodGet, reqURL, "") +} diff --git a/github/client_test.go b/github/client_test.go new file mode 100644 index 0000000..a8ccc06 --- /dev/null +++ b/github/client_test.go @@ -0,0 +1,556 @@ +package github + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +func TestNewClient_DefaultBaseURL(t *testing.T) { + c := NewClient("test-token", "") + if c.baseURL != "https://api.github.com" { + t.Errorf("expected default base URL, got %q", c.baseURL) + } +} + +func TestNewClient_CustomBaseURL(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected custom base URL, got %q", c.baseURL) + } +} + +func TestNewClient_TrimsTrailingSlash(t *testing.T) { + c := NewClient("test-token", "https://github.concur.com/api/v3/") + if c.baseURL != "https://github.concur.com/api/v3" { + t.Errorf("expected trailing slash trimmed, got %q", c.baseURL) + } +} + +func TestDoRequest_SetsAuthHeader(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("my-token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "Bearer my-token" { + t.Errorf("expected Bearer auth, got %q", gotAuth) + } +} + +func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) { + var gotAccept string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAccept = r.Header.Get("Accept") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAccept != "application/vnd.github+json" { + t.Errorf("expected default Accept header, got %q", gotAccept) + } +} + +func TestDoRequest_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestDoRequest_429ExhaustsRetries(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error after exhausting retries") + } + apiErr, ok := err.(*APIError) + if !ok { + t.Fatalf("expected *APIError, got %T", err) + } + if apiErr.StatusCode != 429 { + t.Errorf("expected 429, got %d", apiErr.StatusCode) + } + if attempts != 3 { + t.Errorf("expected 3 attempts, got %d", attempts) + } +} + +func TestDoRequest_404NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(404) + w.Write([]byte(`{"message":"not found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 404") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 404), got %d", attempts) + } +} + +func TestDoRequest_401NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(401) + w.Write([]byte(`{"message":"bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for 401") + } + if attempts != 1 { + t.Errorf("expected 1 attempt (no retry on 401), got %d", attempts) + } +} + +func TestIsNotFound(t *testing.T) { + err := &APIError{StatusCode: 404, Body: "not found"} + if !IsNotFound(err) { + t.Error("expected IsNotFound to return true for 404") + } + err2 := &APIError{StatusCode: 500, Body: "server error"} + if IsNotFound(err2) { + t.Error("expected IsNotFound to return false for 500") + } +} + +func TestIsUnauthorized(t *testing.T) { + err := &APIError{StatusCode: 401, Body: "bad credentials"} + if !IsUnauthorized(err) { + t.Error("expected IsUnauthorized to return true for 401") + } +} + +func TestAPIError_SanitizesNewlines(t *testing.T) { + err := &APIError{StatusCode: 500, Body: "line1\ninjected\rmore"} + msg := err.Error() + if strings.Contains(msg, "\n") || strings.Contains(msg, "\r") { + t.Errorf("expected newlines to be sanitized, got: %q", msg) + } + if !strings.Contains(msg, "line1 injected more") { + t.Errorf("expected sanitized body, got: %q", msg) + } +} + +func TestDoRequest_429RetryAfterHeader(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + // Use short backoff; Retry-After should override + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Retry-After: 1 means at least 1 second delay + if elapsed < 900*time.Millisecond { + t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed) + } +} + +func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the original retryBackoff slice was not mutated + if c.retryBackoff[0] != 1*time.Millisecond { + t.Errorf("retryBackoff[0] was mutated: got %v, want 1ms", c.retryBackoff[0]) + } + if c.retryBackoff[1] != 1*time.Millisecond { + t.Errorf("retryBackoff[1] was mutated: got %v, want 1ms", c.retryBackoff[1]) + } +} + +func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow Retry-After HTTP-date test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use HTTP-date format (RFC 7231) — a time 2 seconds in the future. + future := time.Now().Add(2 * time.Second).UTC() + w.Header().Set("Retry-After", future.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}) + + start := time.Now() + body, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // HTTP-date was ~2s in the future; by the time client processes it, + // time.Until gives ~1-2s. Verify it's meaningfully delayed (not instant). + if elapsed < 500*time.Millisecond { + t.Errorf("expected meaningful delay from HTTP-date Retry-After, got %v", elapsed) + } +} + +func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + // Use a time in the past — should result in zero/immediate retry. + past := time.Now().Add(-10 * time.Second).UTC() + w.Header().Set("Retry-After", past.Format(http.TimeFormat)) + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second}) + + start := time.Now() + _, err := c.doGet(context.Background(), srv.URL+"/test") + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } + // Past date should override the 5s backoff to ~0 + if elapsed > 500*time.Millisecond { + t.Errorf("expected near-instant retry for past HTTP-date, got %v", elapsed) + } +} + +func TestDoRequest_SetsUserAgentHeader(t *testing.T) { + var gotUA string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotUA != "review-bot/1.0" { + t.Errorf("expected User-Agent 'review-bot/1.0', got %q", gotUA) + } +} + +func TestDoRequest_LimitsResponseBody(t *testing.T) { + // Verify that oversized responses return an error rather than silently truncating. + bigBody := strings.Repeat("x", maxResponseBytes+1024) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(bigBody)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error for oversized response body") + } + if !strings.Contains(err.Error(), "exceeded") { + t.Errorf("expected truncation error, got: %v", err) + } +} + +func TestDoRequest_AcceptsExactlyAtLimit(t *testing.T) { + // A response body exactly equal to maxResponseBytes should succeed (not error). + exactBody := strings.Repeat("x", maxResponseBytes) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(exactBody)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error for exactly-at-limit body: %v", err) + } + if len(body) != maxResponseBytes { + t.Errorf("expected body length %d, got %d", maxResponseBytes, len(body)) + } +} + +func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + c := NewClient("", srv.URL, AllowInsecureHTTP()) // empty token + c.SetHTTPClient(srv.Client()) + _, _ = c.doGet(context.Background(), srv.URL+"/test") + + if gotAuth != "" { + t.Errorf("expected no Authorization header with empty token, got %q", gotAuth) + } +} + +func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) { + // Verify the CheckRedirect function is configured + c := NewClient("secret-token", "https://api.github.com") + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect to be set") + } +} + +func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "http", Host: "api.github.com", Path: "/foo"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err == nil { + t.Fatal("expected error on HTTPS→HTTP redirect") + } + if !strings.Contains(err.Error(), "refusing redirect from HTTPS to HTTP") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDefaultCheckRedirect_StripsAuthOnCrossHost(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "https", Host: "objects.githubusercontent.com", Path: "/bar"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth := req.Header.Get("Authorization"); auth != "" { + t.Errorf("expected Authorization header to be stripped, got %q", auth) + } +} + +func TestDefaultCheckRedirect_PreservesAuthOnSameHost(t *testing.T) { + prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}} + req := &http.Request{ + URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/bar"}, + Header: http.Header{"Authorization": []string{"Bearer token"}}, + } + err := defaultCheckRedirect(req, []*http.Request{prev}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth := req.Header.Get("Authorization"); auth != "Bearer token" { + t.Errorf("expected Authorization to be preserved, got %q", auth) + } +} + +func TestDoRequest_RejectsHTTPWithToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("{}")) + })) + defer srv.Close() + + // Without AllowInsecureHTTP, should refuse to send token over HTTP + c := NewClient("secret-token", srv.URL) + c.SetHTTPClient(srv.Client()) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error when sending token over HTTP") + } + if !strings.Contains(err.Error(), "refusing to send credentials") { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestDoRequest_AllowsHTTPWithoutToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + // Without token, HTTP should be fine (no credentials to leak) + c := NewClient("", srv.URL) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } +} + +func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("secret-token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != `{"ok":true}` { + t.Errorf("unexpected body: %s", body) + } +} + +func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { + c := NewClient("token", "https://api.github.com") + c.SetHTTPClient(nil) + if c.httpClient == nil { + t.Fatal("expected non-nil httpClient after SetHTTPClient(nil)") + } + if c.httpClient.Timeout != 30*time.Second { + t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout) + } + if c.httpClient.CheckRedirect == nil { + t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)") + } +} diff --git a/vcs/types.go b/vcs/types.go index de904f3..608ad27 100644 --- a/vcs/types.go +++ b/vcs/types.go @@ -44,6 +44,7 @@ type PullRequest struct { type ChangedFile struct { Filename string `json:"filename"` Status string `json:"status"` + Patch string `json:"patch"` } // ContentEntry represents a file or directory entry from the contents API.