diff --git a/github/client.go b/github/client.go index ab07a24..bd6f7e1 100644 --- a/github/client.go +++ b/github/client.go @@ -93,11 +93,16 @@ func (c *Client) SetHTTPClient(hc *http.Client) { } // doRequest performs an HTTP request with retry on 429 rate limit responses. -// It respects the Retry-After header when present. +// It respects the Retry-After header when present (capped at maxRetryAfter). func (c *Client) doRequest(ctx context.Context, method, url string, accept string) ([]byte, error) { const maxAttempts = 3 - backoff := c.RetryBackoff - if backoff == nil { + const maxRetryAfter = 120 * time.Second + + var backoff []time.Duration + if c.RetryBackoff != nil { + backoff = make([]time.Duration, len(c.RetryBackoff)) + copy(backoff, c.RetryBackoff) + } else { backoff = []time.Duration{1 * time.Second, 2 * time.Second} } @@ -125,7 +130,9 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin if err != nil { return nil, fmt.Errorf("create request: %w", err) } - req.Header.Set("Authorization", "Bearer "+c.token) + if c.token != "" { + req.Header.Set("Authorization", "Bearer "+c.token) + } if accept != "" { req.Header.Set("Accept", accept) } else { @@ -156,8 +163,12 @@ func (c *Client) doRequest(ctx context.Context, method, url string, accept strin // Check for Retry-After header and override backoff if present if ra := resp.Header.Get("Retry-After"); ra != "" { if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { + delay := time.Duration(seconds) * time.Second + if delay > maxRetryAfter { + delay = maxRetryAfter + } if attempt < len(backoff) { - backoff[attempt] = time.Duration(seconds) * time.Second + backoff[attempt] = delay } } } diff --git a/github/client_test.go b/github/client_test.go index e3cd121..e00e534 100644 --- a/github/client_test.go +++ b/github/client_test.go @@ -185,6 +185,9 @@ func TestIsUnauthorized(t *testing.T) { } func TestDoRequest_429RetryAfterHeader(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } attempts := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ @@ -222,3 +225,39 @@ func TestDoRequest_429RetryAfterHeader(t *testing.T) { t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed) } } + +func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow retry test in short mode") + } + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + w.WriteHeader(200) + w.Write([]byte(`{"ok":true}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL) + c.SetHTTPClient(srv.Client()) + c.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} + + _, err := c.doGet(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the original RetryBackoff slice was not mutated + if c.RetryBackoff[0] != 1*time.Millisecond { + t.Errorf("RetryBackoff[0] was mutated: got %v, want 1ms", c.RetryBackoff[0]) + } + if c.RetryBackoff[1] != 1*time.Millisecond { + t.Errorf("RetryBackoff[1] was mutated: got %v, want 1ms", c.RetryBackoff[1]) + } +} diff --git a/github/files.go b/github/files.go index 9c162bf..a385623 100644 --- a/github/files.go +++ b/github/files.go @@ -11,8 +11,8 @@ import ( "gitea.weiker.me/rodin/review-bot/vcs" ) -// GetFileContent fetches a file from the default branch of a repo. -// Delegates to GetFileContentAtRef with an empty ref. +// GetFileContent fetches a file from a repo at the given ref. +// Delegates to GetFileContentAtRef with the provided ref. func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { return c.GetFileContentAtRef(ctx, owner, repo, path, ref) } @@ -47,12 +47,17 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([] // escapePath escapes each segment of a relative file path for use in URLs. // Slashes are preserved as path separators; other special characters are escaped. +// Dot-segments ("." and "..") are removed to prevent path traversal. func escapePath(p string) string { parts := strings.Split(p, "/") - for i, part := range parts { - parts[i] = url.PathEscape(part) + var clean []string + for _, part := range parts { + if part == "." || part == ".." || part == "" { + continue + } + clean = append(clean, url.PathEscape(part)) } - return strings.Join(parts, "/") + return strings.Join(clean, "/") } // decodeBase64Content decodes base64-encoded content from the GitHub contents API. diff --git a/github/files_test.go b/github/files_test.go index bb76a0b..0c077d6 100644 --- a/github/files_test.go +++ b/github/files_test.go @@ -275,3 +275,23 @@ func TestDecodeBase64Content_Invalid(t *testing.T) { t.Fatal("expected error for invalid base64") } } + +func TestEscapePath_RejectsDotSegments(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"src/main.go", "src/main.go"}, + {"../etc/passwd", "etc/passwd"}, + {"./src/../main.go", "src/main.go"}, + {"a/b/c", "a/b/c"}, + {"file with spaces.go", "file%20with%20spaces.go"}, + {"a/./b/../c", "a/b/c"}, + } + for _, tt := range tests { + got := escapePath(tt.input) + if got != tt.want { + t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} diff --git a/github/pr.go b/github/pr.go index 81bec09..0d1046f 100644 --- a/github/pr.go +++ b/github/pr.go @@ -205,7 +205,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) } // mapCheckRunStatus maps a check run conclusion+status to a vcs.CommitStatus status string. -func mapCheckRunStatus(conclusion *string, status string) string { +func mapCheckRunStatus(conclusion *string, _ string) string { if conclusion == nil { // Still running or queued return "pending" @@ -217,8 +217,6 @@ func mapCheckRunStatus(conclusion *string, status string) string { return "failure" case "cancelled", "skipped", "neutral": return "success" // non-blocking - case "in_progress", "queued": - return "pending" default: return "pending" }