const (
	// defaultBaseURL is the API root used when NewClient receives an empty baseURL.
	defaultBaseURL = "https://api.github.com"

	// defaultAccept is the Accept header sent when the caller does not ask for
	// a specific media type.
	defaultAccept = "application/vnd.github+json"

	// defaultHTTPTimeout bounds a single HTTP round trip of the default client.
	defaultHTTPTimeout = 30 * time.Second

	// maxRetryAttempts is the total number of attempts doRequest makes
	// (the initial request plus retries).
	maxRetryAttempts = 3

	// maxRetryAfter caps the delay taken from a Retry-After header so a
	// server cannot stall the client indefinitely.
	maxRetryAfter = 60 * time.Second

	// maxErrorBodyBytes limits how much of an error response body is read,
	// protecting against servers that send unbounded data.
	maxErrorBodyBytes = 64 * 1024 // 64 KiB

	// maxResponseBodyBytes is the largest successful response body doRequest
	// accepts; larger bodies are rejected with an error.
	maxResponseBodyBytes = 10 * 1024 * 1024 // 10 MiB

	// maxErrorMessageBytes is how much of APIError.Body is embedded in Error().
	maxErrorMessageBytes = 200

	// maxDuration is the largest representable time.Duration.
	maxDuration = time.Duration(1<<63 - 1)
)

// errorBodySanitizer replaces line breaks with spaces in a single pass so an
// upstream response body cannot inject extra lines into logs.
var errorBodySanitizer = strings.NewReplacer("\n", " ", "\r", " ")

// APIError represents an HTTP error response from the GitHub API.
// It carries the status code so callers can distinguish between
// different failure modes (e.g. 404 vs 500).
//
// Body stores up to 64 KiB of the raw response for programmatic inspection.
// Error() embeds at most 200 bytes of it; callers should avoid logging or
// propagating Body directly since it may contain sensitive details from the
// upstream server.
type APIError struct {
	StatusCode int
	Body       string
}

// Error formats the status code followed by a sanitized excerpt of the body.
// The excerpt is cut at a UTF-8 rune boundary at or below 200 bytes so the
// message never ends in a partial multi-byte character.
func (e *APIError) Error() string {
	body := e.Body
	if len(body) > maxErrorMessageBytes {
		cut := maxErrorMessageBytes
		// A byte of the form 10xxxxxx is a UTF-8 continuation byte: cutting
		// in front of it would split a rune. A rune has at most three
		// continuation bytes, so never back up further than that (the bound
		// also keeps non-UTF-8 bodies from shrinking the excerpt to nothing).
		for back := 0; back < 3 && cut > 0 && body[cut]&0xC0 == 0x80; back++ {
			cut--
		}
		body = body[:cut] + "...(truncated)"
	}
	return fmt.Sprintf("HTTP %d: %s", e.StatusCode, errorBodySanitizer.Replace(body))
}

// IsNotFound reports whether err is, or wraps, an API 404 response.
func IsNotFound(err error) bool {
	return hasStatus(err, http.StatusNotFound)
}

// IsUnauthorized reports whether err is, or wraps, an API 401 response.
func IsUnauthorized(err error) bool {
	return hasStatus(err, http.StatusUnauthorized)
}

// hasStatus reports whether err is, or wraps, an *APIError with the given status code.
func hasStatus(err error, code int) bool {
	apiErr, ok := asAPIError(err)
	return ok && apiErr.StatusCode == code
}

// asAPIError extracts the *APIError from err's chain, if there is one.
func asAPIError(err error) (*APIError, bool) {
	var target *APIError
	if errors.As(err, &target) {
		return target, true
	}
	return nil, false
}

// Client interacts with the GitHub API.
// A Client is safe for concurrent use by multiple goroutines.
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
// be called before any goroutines issue requests; they have no synchronization.
type Client struct {
	// TODO: baseURL is populated by NewClient but not yet consumed by doRequest/doGet.
	// Higher-level exported methods (GetPullRequest, etc.) will use it to
	// construct request URLs; remove this field if those methods end up
	// accepting full URLs instead.
	baseURL    string
	token      string
	httpClient *http.Client

	// retryBackoff defines the delays between retry attempts for 429 responses.
	// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
	// If nil, defaults to {1s, 2s}.
	retryBackoff []time.Duration

	// now returns the current time. Defaults to time.Now.
	// Override in tests to control HTTP-date Retry-After calculations.
	now func() time.Time
}

// NewClient creates a new GitHub API client.
// If baseURL is empty, it defaults to https://api.github.com.
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
// Trailing slashes are stripped from baseURL.
func NewClient(token, baseURL string) *Client {
	if baseURL == "" {
		baseURL = defaultBaseURL
	}
	return &Client{
		baseURL:    strings.TrimRight(baseURL, "/"),
		token:      token,
		httpClient: &http.Client{Timeout: defaultHTTPTimeout},
		now:        time.Now,
	}
}

// SetHTTPClient sets the underlying HTTP client used for requests.
// This is intended for testing to inject mock transports.
// Passing nil restores a default client with a 30 second timeout instead of
// leaving the Client in a state where the next request would panic.
func (c *Client) SetHTTPClient(hc *http.Client) {
	if hc == nil {
		hc = &http.Client{Timeout: defaultHTTPTimeout}
	}
	c.httpClient = hc
}

// SetRetryBackoff sets the delays between retry attempts.
// This is intended for testing to speed up retry tests.
//
// backoff[i] is the delay applied after attempt i fails with a 429 that has
// no usable Retry-After header; attempts beyond len(backoff) retry without
// delay. A nil slice restores the default {1s, 2s}. The slice is copied, so
// later changes by the caller do not affect the client.
func (c *Client) SetRetryBackoff(backoff []time.Duration) {
	if backoff == nil {
		c.retryBackoff = nil
		return
	}
	// make+copy (rather than append to nil) keeps an empty slice non-nil,
	// which is what distinguishes "no delays" from "use the defaults".
	c.retryBackoff = make([]time.Duration, len(backoff))
	copy(c.retryBackoff, backoff)
}

// parseRetryAfter parses a Retry-After header value, supporting both integer
// seconds (e.g. "120") and HTTP-date format (e.g. "Mon, 01 Dec 2025 16:00:00 GMT")
// as specified in RFC 7231 §7.1.3.
//
// For integer values it returns that many seconds; values too large for a
// time.Duration saturate at the maximum duration instead of overflowing.
// For HTTP-date values it returns the time remaining until that date, or 0 if
// the date is in the past.
//
// It returns (0, false) if the value cannot be parsed as either format,
// including negative integers. Capping at maxRetryAfter is left to the caller.
func (c *Client) parseRetryAfter(value string) (time.Duration, bool) {
	value = strings.TrimSpace(value)

	// Integer seconds first (most common from GitHub); 0 means retry now.
	// On ErrRange ParseInt still returns the saturated value, so an absurdly
	// large delta is treated as "very long" rather than as unparseable.
	seconds, err := strconv.ParseInt(value, 10, 64)
	if (err == nil || errors.Is(err, strconv.ErrRange)) && seconds >= 0 {
		if seconds > int64(maxDuration/time.Second) {
			return maxDuration, true
		}
		return time.Duration(seconds) * time.Second, true
	}

	// http.ParseTime handles the RFC 1123, RFC 850, and ANSI C asctime formats.
	retryAt, err := http.ParseTime(value)
	if err != nil {
		return 0, false
	}
	now := time.Now
	if c.now != nil { // tolerate a Client that was not built by NewClient
		now = c.now
	}
	if delay := retryAt.Sub(now()); delay > 0 {
		return delay, true
	}
	return 0, true
}

// retryDelay returns how long to wait after attempt (0-based) failed with a
// 429. A parseable Retry-After header wins, capped at maxRetryAfter; otherwise
// the configured backoff for that attempt is used, or no delay if the backoff
// slice has no entry for it.
func (c *Client) retryDelay(attempt int, retryAfter string) time.Duration {
	if retryAfter != "" {
		if delay, ok := c.parseRetryAfter(retryAfter); ok {
			if delay > maxRetryAfter {
				delay = maxRetryAfter
			}
			return delay
		}
	}

	backoff := c.retryBackoff
	if backoff == nil {
		backoff = []time.Duration{1 * time.Second, 2 * time.Second}
	}
	if attempt < len(backoff) {
		return backoff[attempt]
	}
	return 0
}

// waitContext blocks for d, returning early with ctx.Err() if ctx is done
// first. A non-positive d returns immediately without consulting ctx.
func waitContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-timer.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// readSuccessBody reads and closes the body of a 2xx response. Bodies larger
// than maxResponseBodyBytes are rejected rather than silently truncated.
func readSuccessBody(resp *http.Response) ([]byte, error) {
	defer resp.Body.Close()
	// Read one byte past the limit so "exactly at the limit" can be told
	// apart from "over the limit".
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes+1))
	if err != nil {
		return nil, fmt.Errorf("read response body: %w", err)
	}
	if len(body) > maxResponseBodyBytes {
		return nil, fmt.Errorf("read response body: exceeds %d byte limit", maxResponseBodyBytes)
	}
	return body, nil
}

// readAPIError converts a non-2xx response into an *APIError and closes its body.
func readAPIError(resp *http.Response) *APIError {
	defer resp.Body.Close()
	// The read error is deliberately ignored: the status code is the
	// important part and whatever prefix of the body arrived is still useful.
	raw, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
	return &APIError{StatusCode: resp.StatusCode, Body: string(raw)}
}

// doRequest performs an HTTP request, retrying on 429 rate limit responses up
// to maxRetryAttempts total attempts. Between attempts it waits for the
// server's Retry-After (integer seconds or HTTP-date, capped at maxRetryAfter)
// when present and parseable, and for the configured backoff otherwise.
//
// An empty accept selects the default GitHub media type. On a 2xx status the
// response body is returned. Any other status yields an *APIError; only 429 is
// retried. If ctx is done while waiting between attempts, ctx.Err() is returned.
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
	if accept == "" {
		accept = defaultAccept
	}

	var (
		lastErr error
		wait    time.Duration // delay before the next attempt; zero before the first
	)
	for attempt := 0; attempt < maxRetryAttempts; attempt++ {
		if err := waitContext(ctx, wait); err != nil {
			return nil, err
		}

		req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
		if err != nil {
			return nil, fmt.Errorf("create request: %w", err)
		}
		req.Header.Set("Authorization", "Bearer "+c.token)
		req.Header.Set("Accept", accept)

		resp, err := c.httpClient.Do(req)
		if err != nil {
			return nil, fmt.Errorf("do request: %w", err)
		}
		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			return readSuccessBody(resp)
		}

		lastErr = readAPIError(resp)
		// Only rate limiting is retried, and only while attempts remain.
		if resp.StatusCode != http.StatusTooManyRequests || attempt == maxRetryAttempts-1 {
			return nil, lastErr
		}
		wait = c.retryDelay(attempt, resp.Header.Get("Retry-After"))
	}

	return nil, lastErr
}

// doGet is a convenience wrapper for GET requests with the default Accept header.
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
	return c.doRequest(ctx, http.MethodGet, reqURL, "")
}
want %q", body, `{"ok":true}`) + } +} + +func TestDoRequest_429_RetryAfter_IntegerSeconds(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + c.SetRetryBackoff([]time.Duration{0, 0}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != "success" { + t.Errorf("body = %q, want %q", body, "success") + } + if attempts != 2 { + t.Errorf("attempts = %d, want 2", attempts) + } +} + +func TestDoRequest_429_RetryAfter_HTTPDate(t *testing.T) { + // Fix "now" to a known time for deterministic testing. + fixedNow := time.Date(2025, 12, 1, 15, 59, 59, 0, time.UTC) + retryAt := "Mon, 01 Dec 2025 16:00:00 GMT" // 1 second in the future + + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", retryAt) + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + c.now = func() time.Time { return fixedNow } + // Initial backoff is 0; the HTTP-date parser will compute 1s and override. 
+ c.SetRetryBackoff([]time.Duration{0, 0}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != "success" { + t.Errorf("body = %q, want %q", body, "success") + } + if attempts != 2 { + t.Errorf("attempts = %d, want 2", attempts) + } +} + +func TestDoRequest_429_RetryAfter_HTTPDate_InPast(t *testing.T) { + // If the HTTP-date is in the past, delay should be 0 (retry immediately). + fixedNow := time.Date(2025, 12, 1, 17, 0, 0, 0, time.UTC) + retryAt := "Mon, 01 Dec 2025 16:00:00 GMT" // 1 hour in the past + + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", retryAt) + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + c.now = func() time.Time { return fixedNow } + c.SetRetryBackoff([]time.Duration{0, 0}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != "success" { + t.Errorf("body = %q, want %q", body, "success") + } +} + +func TestDoRequest_429_NoRetryAfter_UsesDefaultBackoff(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + c.SetRetryBackoff([]time.Duration{0, 0}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != "success" { + t.Errorf("body = %q, want %q", body, "success") + } + if 
attempts != 2 { + t.Errorf("attempts = %d, want 2", attempts) + } +} + +func TestDoRequest_429_InvalidRetryAfter_UsesDefaultBackoff(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "not-a-number-or-date") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + c.SetRetryBackoff([]time.Duration{0, 0}) + + body, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != "success" { + t.Errorf("body = %q, want %q", body, "success") + } +} + +func TestDoRequest_404_NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("not found")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error, got nil") + } + if !IsNotFound(err) { + t.Errorf("expected IsNotFound, got %v", err) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1 (no retry on 404)", attempts) + } +} + +func TestDoRequest_401_NoRetry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("unauthorized")) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err == nil { + t.Fatal("expected error, got nil") + } + if !IsUnauthorized(err) { + t.Errorf("expected IsUnauthorized, got %v", err) + } + if attempts != 1 { + t.Errorf("attempts = %d, want 1 (no retry on 401)", attempts) + } 
+} + +func TestDoRequest_ContextCanceled(t *testing.T) { + // This test exercises the timer-cancel path in the retry select: + // select { case <-timer.C; case <-ctx.Done() } + // The server returns 429 with a long Retry-After, and we cancel the + // context shortly after the first response so that cancellation races + // against the timer rather than preventing the initial HTTP round-trip. + requestReceived := make(chan struct{}, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case requestReceived <- struct{}{}: + default: + } + w.Header().Set("Retry-After", "10") + w.WriteHeader(http.StatusTooManyRequests) + })) + defer srv.Close() + + c := NewClient("tok", srv.URL) + c.SetRetryBackoff([]time.Duration{10 * time.Second, 10 * time.Second}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Cancel the context after the first request completes, while the + // client is blocked in the retry timer select. + go func() { + <-requestReceived + // Small delay to ensure we're inside the timer select. 
+ time.Sleep(50 * time.Millisecond) + cancel() + }() + + _, err := c.doGet(ctx, srv.URL+"/test") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("err = %v, want context.Canceled", err) + } +} + +func TestParseRetryAfter_IntegerSeconds(t *testing.T) { + c := NewClient("tok", "") + delay, ok := c.parseRetryAfter("42") + if !ok { + t.Fatal("expected ok=true") + } + if delay != 42*time.Second { + t.Errorf("delay = %v, want 42s", delay) + } +} + +func TestParseRetryAfter_ZeroSeconds(t *testing.T) { + c := NewClient("tok", "") + delay, ok := c.parseRetryAfter("0") + if !ok { + t.Fatal("expected ok=true for zero seconds (RFC 7231 allows immediate retry)") + } + if delay != 0 { + t.Errorf("delay = %v, want 0", delay) + } +} + +func TestParseRetryAfter_NegativeSeconds(t *testing.T) { + c := NewClient("tok", "") + _, ok := c.parseRetryAfter("-5") + if ok { + t.Error("expected ok=false for negative seconds") + } +} + +func TestParseRetryAfter_HTTPDate_Future(t *testing.T) { + fixedNow := time.Date(2025, 12, 1, 15, 59, 50, 0, time.UTC) + c := NewClient("tok", "") + c.now = func() time.Time { return fixedNow } + + delay, ok := c.parseRetryAfter("Mon, 01 Dec 2025 16:00:00 GMT") + if !ok { + t.Fatal("expected ok=true") + } + // Should be 10 seconds in the future. 
+ if delay != 10*time.Second { + t.Errorf("delay = %v, want 10s", delay) + } +} + +func TestParseRetryAfter_HTTPDate_Past(t *testing.T) { + fixedNow := time.Date(2025, 12, 1, 17, 0, 0, 0, time.UTC) + c := NewClient("tok", "") + c.now = func() time.Time { return fixedNow } + + delay, ok := c.parseRetryAfter("Mon, 01 Dec 2025 16:00:00 GMT") + if !ok { + t.Fatal("expected ok=true") + } + if delay != 0 { + t.Errorf("delay = %v, want 0 (past date)", delay) + } +} + +func TestParseRetryAfter_RFC850_Format(t *testing.T) { + fixedNow := time.Date(2025, 12, 1, 15, 59, 50, 0, time.UTC) + c := NewClient("tok", "") + c.now = func() time.Time { return fixedNow } + + // RFC 850 format + delay, ok := c.parseRetryAfter("Monday, 01-Dec-25 16:00:00 GMT") + if !ok { + t.Fatal("expected ok=true for RFC 850 format") + } + if delay != 10*time.Second { + t.Errorf("delay = %v, want 10s", delay) + } +} + +func TestParseRetryAfter_Invalid(t *testing.T) { + c := NewClient("tok", "") + _, ok := c.parseRetryAfter("not-valid") + if ok { + t.Error("expected ok=false for invalid value") + } +} + +func TestParseRetryAfter_EmptyString(t *testing.T) { + c := NewClient("tok", "") + _, ok := c.parseRetryAfter("") + if ok { + t.Error("expected ok=false for empty string") + } +} + +func TestParseRetryAfter_MaxCap(t *testing.T) { + // Verify that parseRetryAfter returns the raw value (capping is done by caller). 
+ c := NewClient("tok", "") + delay, ok := c.parseRetryAfter("3600") + if !ok { + t.Fatal("expected ok=true") + } + if delay != 3600*time.Second { + t.Errorf("delay = %v, want 3600s (caller is responsible for capping)", delay) + } +} + +func TestAPIError_Error_Truncation(t *testing.T) { + longBody := make([]byte, 300) + for i := range longBody { + longBody[i] = 'x' + } + apiErr := &APIError{StatusCode: 500, Body: string(longBody)} + msg := apiErr.Error() + if len(msg) > 250 { + // "HTTP 500: " (10) + 200 + "...(truncated)" (14) = 224 + t.Errorf("error message too long: %d chars", len(msg)) + } +} + +func TestAPIError_Error_NewlineSanitized(t *testing.T) { + apiErr := &APIError{StatusCode: 400, Body: "line1\nline2\rline3"} + msg := apiErr.Error() + for _, c := range msg { + if c == '\n' || c == '\r' { + t.Errorf("error message contains unsanitized newline: %q", msg) + break + } + } +}