Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7eeb3147db | |||
| d468ea6022 |
+4
-11
@@ -11,7 +11,6 @@ import (
|
||||
type PositionMap struct {
|
||||
// files maps filename → (position → new-file line number).
|
||||
// Deletion lines are mapped to -1 (no new-file line).
|
||||
// Hunk-header lines are mapped to 0 (no new-file line).
|
||||
files map[string]map[int]int
|
||||
// maxPositions caches the highest position number per file,
|
||||
// tracked during construction to avoid O(n) scans at translate time.
|
||||
@@ -20,8 +19,8 @@ type PositionMap struct {
|
||||
|
||||
// Translate converts a GitHub diff-position to a new-file line number for a given file.
|
||||
// Returns an error if the file is not in the diff or the position is out of range.
|
||||
// If the position targets a deletion or hunk-header line, it maps to the nearest
|
||||
// context/addition line below; if no such line exists, returns an error.
|
||||
// If the position targets a deletion line, it maps to the nearest non-deletion line below;
|
||||
// if no such line exists, returns an error.
|
||||
func (pm *PositionMap) Translate(file string, position int) (int, error) {
|
||||
if pm == nil || pm.files == nil {
|
||||
return 0, fmt.Errorf("empty position map")
|
||||
@@ -42,18 +41,14 @@ func (pm *PositionMap) Translate(file string, position int) (int, error) {
|
||||
}
|
||||
|
||||
// lineNum == -1 means this position is a deletion line.
|
||||
// lineNum == 0 means this position is a hunk-header line.
|
||||
// Both map to the nearest context/addition line below.
|
||||
if lineNum <= 0 {
|
||||
// Map to the nearest non-deletion line below.
|
||||
if lineNum == -1 {
|
||||
maxPos := pm.maxPosition(file)
|
||||
for p := position + 1; p <= maxPos; p++ {
|
||||
if ln, exists := fileMap[p]; exists && ln > 0 {
|
||||
return ln, nil
|
||||
}
|
||||
}
|
||||
if lineNum == 0 {
|
||||
return 0, fmt.Errorf("position %d targets a hunk-header line with no subsequent new-file line in %q", position, file)
|
||||
}
|
||||
return 0, fmt.Errorf("position %d targets a deletion line with no subsequent new-file line in %q", position, file)
|
||||
}
|
||||
|
||||
@@ -75,7 +70,6 @@ func (pm *PositionMap) maxPosition(file string) int {
|
||||
// - A new @@ hunk within the same file continues incrementing (does not reset)
|
||||
// - Position maps to the new file line number for additions and context lines
|
||||
// - Deletion lines have a position but no new-file line number (stored as -1)
|
||||
// - Hunk-header lines have a position but no new-file line number (stored as 0)
|
||||
func BuildPositionToLineMap(diff string) *PositionMap {
|
||||
pm := &PositionMap{
|
||||
files: make(map[string]map[int]int),
|
||||
@@ -132,7 +126,6 @@ func BuildPositionToLineMap(diff string) *PositionMap {
|
||||
// Parse hunk headers
|
||||
if strings.HasPrefix(line, "@@") && currentFile != "" {
|
||||
position++
|
||||
pm.files[currentFile][position] = 0 // sentinel: hunk-header has no new-file line
|
||||
pm.maxPositions[currentFile] = position
|
||||
newLine = parseHunkStart(line)
|
||||
continue
|
||||
|
||||
@@ -272,112 +272,3 @@ diff --git a/b.go b/b.go
|
||||
t.Errorf("Translate(b.go, 3) = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslate_HunkHeaderPosition_SingleHunk(t *testing.T) {
|
||||
// Position 1 is the @@ hunk-header line.
|
||||
// It should resolve to the first context/addition line below (new line 16).
|
||||
diff := `diff --git a/file.go b/file.go
|
||||
index abc..def 100644
|
||||
--- a/file.go
|
||||
+++ b/file.go
|
||||
@@ -16,4 +16,5 @@ func example() {
|
||||
context line
|
||||
-deleted line
|
||||
+added line
|
||||
context after
|
||||
`
|
||||
pm := BuildPositionToLineMap(diff)
|
||||
|
||||
got, err := pm.Translate("file.go", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Translate(file.go, 1): unexpected error: %v", err)
|
||||
}
|
||||
if got != 16 {
|
||||
t.Errorf("Translate(file.go, 1) = %d, want 16 (first context/addition line in hunk)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslate_HunkHeaderPosition_MultiHunk(t *testing.T) {
|
||||
// First hunk: @@ is pos 1, then " line1" (pos 2), "-old" (pos 3), "+new" (pos 4)
|
||||
// Second hunk: @@ is pos 5, then " func foo() {" (pos 6), "+// added" (pos 7), etc.
|
||||
// Translating position 5 (second @@) should resolve to new line 10.
|
||||
diff := `diff --git a/file.go b/file.go
|
||||
--- a/file.go
|
||||
+++ b/file.go
|
||||
@@ -1,3 +1,3 @@ package main
|
||||
line1
|
||||
-old
|
||||
+new
|
||||
@@ -10,3 +10,4 @@ func foo() {
|
||||
func foo() {
|
||||
+ // added
|
||||
return
|
||||
}
|
||||
`
|
||||
pm := BuildPositionToLineMap(diff)
|
||||
|
||||
// Position 5 is the second @@ hunk-header — should resolve to new line 10
|
||||
got, err := pm.Translate("file.go", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Translate(file.go, 5): unexpected error: %v", err)
|
||||
}
|
||||
if got != 10 {
|
||||
t.Errorf("Translate(file.go, 5) = %d, want 10 (first context/addition line in second hunk)", got)
|
||||
}
|
||||
|
||||
// Also verify first hunk header at position 1 resolves to new line 1
|
||||
got, err = pm.Translate("file.go", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Translate(file.go, 1): unexpected error: %v", err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Errorf("Translate(file.go, 1) = %d, want 1 (first context/addition line in first hunk)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslate_HunkHeaderPosition_NewFile(t *testing.T) {
|
||||
// New file: @@ -0,0 +1,3 @@ is position 1.
|
||||
// Should resolve to new line 1 (the first addition).
|
||||
diff := `diff --git a/new.go b/new.go
|
||||
new file mode 100644
|
||||
--- /dev/null
|
||||
+++ b/new.go
|
||||
@@ -0,0 +1,3 @@
|
||||
+package main
|
||||
+
|
||||
+func init() {}
|
||||
`
|
||||
pm := BuildPositionToLineMap(diff)
|
||||
|
||||
got, err := pm.Translate("new.go", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Translate(new.go, 1): unexpected error: %v", err)
|
||||
}
|
||||
if got != 1 {
|
||||
t.Errorf("Translate(new.go, 1) = %d, want 1 (first addition line)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslate_HunkHeaderAtEnd(t *testing.T) {
|
||||
// A hunk-header at the last position with no subsequent new-file line should error.
|
||||
// This is the hunk-header equivalent of TestBuildPositionToLineMap_DeletionAtEnd.
|
||||
diff := `diff --git a/file.go b/file.go
|
||||
--- a/file.go
|
||||
+++ b/file.go
|
||||
@@ -1,2 +1,2 @@ package main
|
||||
line1
|
||||
-old
|
||||
+new
|
||||
@@ -10,2 +10,1 @@ func foo() {
|
||||
-removed
|
||||
`
|
||||
pm := BuildPositionToLineMap(diff)
|
||||
|
||||
// Position 5 is the second @@ hunk-header; the only line after it (pos 6) is a
|
||||
// deletion (lineNum == -1), so there's no positive new-file line to resolve to.
|
||||
// The hunk-header lookup should fail.
|
||||
_, err := pm.Translate("file.go", 5)
|
||||
if err == nil {
|
||||
t.Error("expected error for hunk-header at end with no subsequent new-file line")
|
||||
}
|
||||
}
|
||||
|
||||
+8
-25
@@ -21,10 +21,6 @@ const (
|
||||
|
||||
// maxResponseBytes limits successful response body reads to 10 MiB.
|
||||
maxResponseBytes = 10 * 1024 * 1024
|
||||
|
||||
// maxRetryAttempts is the number of times doRequest will attempt a request.
|
||||
// The retry backoff slice must have length maxRetryAttempts-1.
|
||||
maxRetryAttempts = 3
|
||||
)
|
||||
|
||||
// APIError represents an HTTP error response from the GitHub API.
|
||||
@@ -182,33 +178,24 @@ func (c *Client) SetHTTPClient(hc *http.Client) {
|
||||
|
||||
// SetRetryBackoff configures the retry backoff durations for testing.
|
||||
// It must be called before any goroutines issue requests.
|
||||
// The slice must have exactly maxRetryAttempts-1 entries (one delay per retry gap).
|
||||
// In production the default {1s, 2s} applies.
|
||||
func (c *Client) SetRetryBackoff(d []time.Duration) error {
|
||||
if len(d) != maxRetryAttempts-1 {
|
||||
return fmt.Errorf("github: backoff length %d does not match maxRetryAttempts-1 (%d)", len(d), maxRetryAttempts-1)
|
||||
}
|
||||
func (c *Client) SetRetryBackoff(d []time.Duration) {
|
||||
c.retryBackoff = d
|
||||
return nil
|
||||
}
|
||||
|
||||
// doRequest performs an HTTP request with retry on 429 rate limit responses.
|
||||
// It respects the Retry-After header when present (capped at maxRetryAfter).
|
||||
// Transport errors (network failures, context cancellation) are not retried.
|
||||
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
|
||||
const maxAttempts = 3
|
||||
const maxRetryAfter = 120 * time.Second
|
||||
|
||||
// backoff holds per-attempt delays: backoff[i] is the delay before attempt i+1.
|
||||
// Length must be maxRetryAttempts-1 (one entry per retry gap).
|
||||
// SetRetryBackoff validates at configuration time; the default is always valid.
|
||||
defaultBackoff := []time.Duration{1 * time.Second, 2 * time.Second}
|
||||
var backoff []time.Duration
|
||||
if c.retryBackoff != nil && len(c.retryBackoff) == maxRetryAttempts-1 {
|
||||
if c.retryBackoff != nil {
|
||||
backoff = make([]time.Duration, len(c.retryBackoff))
|
||||
copy(backoff, c.retryBackoff)
|
||||
} else {
|
||||
backoff = make([]time.Duration, len(defaultBackoff))
|
||||
copy(backoff, defaultBackoff)
|
||||
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
|
||||
}
|
||||
|
||||
// maxErrorBodyBytes limits how much of an error response body is stored.
|
||||
@@ -228,7 +215,7 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetryAttempts; attempt++ {
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
var delay time.Duration
|
||||
if attempt-1 < len(backoff) {
|
||||
@@ -268,10 +255,6 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
||||
return nil, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
|
||||
// Capture response metadata before handleResponse takes body ownership.
|
||||
respStatus := resp.StatusCode
|
||||
retryAfterHeader := resp.Header.Get("Retry-After")
|
||||
|
||||
body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes)
|
||||
if done {
|
||||
return body, err
|
||||
@@ -279,10 +262,10 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
||||
lastErr = err
|
||||
|
||||
// Retry on 429 rate limit
|
||||
if respStatus == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 {
|
||||
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 {
|
||||
// Check for Retry-After header and override backoff if present.
|
||||
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
|
||||
if ra := retryAfterHeader; ra != "" {
|
||||
if ra := resp.Header.Get("Retry-After"); ra != "" {
|
||||
if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 {
|
||||
delay := time.Duration(seconds) * time.Second
|
||||
if delay > maxRetryAfter {
|
||||
@@ -326,7 +309,7 @@ func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrByt
|
||||
return nil, true, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
if len(body) > maxRespBytes {
|
||||
return nil, true, fmt.Errorf("response body exceeded %d bytes", maxRespBytes)
|
||||
return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes)
|
||||
}
|
||||
return body, true, nil
|
||||
}
|
||||
|
||||
+6
-44
@@ -83,9 +83,7 @@ func TestDoRequest_429Retry(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond})
|
||||
|
||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
||||
if err != nil {
|
||||
@@ -110,9 +108,7 @@ func TestDoRequest_429ExhaustsRetries(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
|
||||
|
||||
_, err := c.doGet(context.Background(), srv.URL+"/test")
|
||||
if err == nil {
|
||||
@@ -222,9 +218,7 @@ func TestDoRequest_429RetryAfterHeader(t *testing.T) {
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
// Use short backoff; Retry-After should override
|
||||
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
|
||||
|
||||
start := time.Now()
|
||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
||||
@@ -265,9 +259,7 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
|
||||
|
||||
_, err := c.doGet(context.Background(), srv.URL+"/test")
|
||||
if err != nil {
|
||||
@@ -305,9 +297,7 @@ func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
|
||||
|
||||
start := time.Now()
|
||||
body, err := c.doGet(context.Background(), srv.URL+"/test")
|
||||
@@ -348,9 +338,7 @@ func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second})
|
||||
|
||||
start := time.Now()
|
||||
_, err := c.doGet(context.Background(), srv.URL+"/test")
|
||||
@@ -566,29 +554,3 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
|
||||
t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestSetRetryBackoff_RejectsInvalidLength(t *testing.T) {
|
||||
c := NewClient("token", "https://api.github.com")
|
||||
|
||||
// Too short
|
||||
err := c.SetRetryBackoff([]time.Duration{1 * time.Second})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for backoff length 1")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "backoff length 1") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
|
||||
// Too long
|
||||
err = c.SetRetryBackoff([]time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for backoff length 3")
|
||||
}
|
||||
|
||||
// Correct length succeeds
|
||||
err = c.SetRetryBackoff([]time.Duration{1 * time.Second, 2 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for valid backoff: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+32
-57
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||
@@ -14,28 +13,25 @@ import (
|
||||
|
||||
// GetFileContent fetches a file from a repo at the given ref.
|
||||
// Delegates to GetFileContentAtRef with the provided ref.
|
||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) {
|
||||
return c.GetFileContentAtRef(ctx, owner, repo, filePath, ref)
|
||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||
return c.GetFileContentAtRef(ctx, owner, repo, path, ref)
|
||||
}
|
||||
|
||||
// GetFileContentAtRef fetches a file at a specific ref from a repo.
|
||||
// If ref is empty, the query parameter is omitted (uses default branch).
|
||||
//
|
||||
// Returns an error if the path contains dot-segments (".", "..") or
|
||||
// attempts to traverse above the repository root.
|
||||
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath, ref string) (string, error) {
|
||||
escaped, err := escapePath(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
// Note: dot-segments ("." and "..") in the path are silently removed to
|
||||
// prevent path traversal. This means a path like "foo/../bar" resolves
|
||||
// to "foo/bar" rather than "bar".
|
||||
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped)
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
||||
if ref != "" {
|
||||
reqURL += "?ref=" + url.QueryEscape(ref)
|
||||
}
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s: %w", filePath, err)
|
||||
return "", fmt.Errorf("fetch file %s: %w", path, err)
|
||||
}
|
||||
var resp struct {
|
||||
Content string `json:"content"`
|
||||
@@ -45,11 +41,11 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath,
|
||||
return "", fmt.Errorf("parse file content JSON: %w", err)
|
||||
}
|
||||
if resp.Encoding != "base64" {
|
||||
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, filePath)
|
||||
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path)
|
||||
}
|
||||
decoded, err := decodeBase64Content(resp.Content)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64 content for %s: %w", filePath, err)
|
||||
return "", fmt.Errorf("decode base64 content for %s: %w", path, err)
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
@@ -59,16 +55,16 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath,
|
||||
// If the path points to a single file (not a directory), the API returns
|
||||
// a JSON object instead of an array; this is handled by returning a
|
||||
// single-element slice.
|
||||
func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string) ([]vcs.ContentEntry, error) {
|
||||
escaped, err := escapePath(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid file path: %w", err)
|
||||
}
|
||||
//
|
||||
// Note: dot-segments ("." and "..") in the path are silently removed to
|
||||
// prevent path traversal. This means a path like "foo/../bar" resolves
|
||||
// to "foo/bar" rather than "bar".
|
||||
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped)
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list contents %s: %w", filePath, err)
|
||||
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
@@ -106,55 +102,34 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// escapePath validates and encodes a slash-separated file path for use in
|
||||
// GitHub API URLs. Returns an error if the path contains dot-segments ("."
|
||||
// or "..") or resolves to a path outside the repository root.
|
||||
func escapePath(p string) (string, error) {
|
||||
// Reject paths containing dot-segments rather than silently rewriting them.
|
||||
for _, seg := range strings.Split(p, "/") {
|
||||
if seg == "." || seg == ".." {
|
||||
return "", fmt.Errorf("path contains dot-segment %q: %s", seg, p)
|
||||
}
|
||||
}
|
||||
|
||||
// Use path.Clean for canonical form, then verify it doesn't escape root.
|
||||
cleaned := path.Clean(p)
|
||||
if cleaned == "." || strings.HasPrefix(cleaned, "..") {
|
||||
return "", fmt.Errorf("path resolves outside repository root: %s", p)
|
||||
}
|
||||
|
||||
// Encode each segment individually.
|
||||
parts := strings.Split(cleaned, "/")
|
||||
var encoded []string
|
||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||
// Slashes are preserved as path separators; other special characters are escaped.
|
||||
// Dot-segments ("." and "..") and empty segments (from consecutive slashes like
|
||||
// "a//b") are silently removed to prevent path traversal and produce canonical
|
||||
// paths. This is intentional: callers may receive a different path than requested
|
||||
// without error. The function is package-private, and all callers
|
||||
// (GetFileContentAtRef, ListContents) already handle missing-file errors from the
|
||||
// API if the cleaned path doesn't match what the caller intended.
|
||||
func escapePath(p string) string {
|
||||
parts := strings.Split(p, "/")
|
||||
var clean []string
|
||||
for _, part := range parts {
|
||||
if part == "" {
|
||||
if part == "." || part == ".." || part == "" {
|
||||
continue
|
||||
}
|
||||
encoded = append(encoded, url.PathEscape(part))
|
||||
clean = append(clean, url.PathEscape(part))
|
||||
}
|
||||
return strings.Join(encoded, "/"), nil
|
||||
return strings.Join(clean, "/")
|
||||
}
|
||||
|
||||
// maxFileContentSize is the maximum decoded file size (10 MiB) accepted by
// decodeBase64Content, bounding the memory spent on decoding API content.
const maxFileContentSize = 10 * 1024 * 1024

// decodeBase64Content decodes base64-encoded content from the GitHub contents API.
//
// GitHub wraps the base64 text with line breaks, so every "\r" and "\n" is
// removed before decoding with the standard (padded) alphabet.
//
// It returns an error if the input is not valid base64 or if the decoded
// content is larger than maxFileContentSize. Content of exactly
// maxFileContentSize bytes is accepted. The decoder's error is returned
// unwrapped.
func decodeBase64Content(encoded string) (string, error) {
	// GitHub inserts newlines in base64 content.
	cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded)

	// Estimate the decoded size before allocating the decode buffer.
	// Four base64 characters carry three bytes, but trailing '=' padding
	// carries none. Counting the padding overestimates by up to two bytes and
	// would reject a file that is exactly at the limit, so it is excluded.
	estimated := len(strings.TrimRight(cleaned, "=")) * 3 / 4
	if estimated > maxFileContentSize {
		return "", fmt.Errorf("file content too large: estimated %d bytes exceeds limit of %d", estimated, maxFileContentSize)
	}

	decoded, err := base64.StdEncoding.DecodeString(cleaned)
	if err != nil {
		return "", err
	}

	// Authoritative check on the real decoded length.
	if len(decoded) > maxFileContentSize {
		return "", fmt.Errorf("file content too large: %d bytes exceeds limit of %d", len(decoded), maxFileContentSize)
	}
	return string(decoded), nil
}
|
||||
|
||||
+54
-125
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -110,9 +109,7 @@ func TestGetFileContent_429Retry(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
|
||||
|
||||
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
|
||||
if err != nil {
|
||||
@@ -230,11 +227,9 @@ func TestListContents_429Retry(t *testing.T) {
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
|
||||
t.Fatalf("SetRetryBackoff: %v", err)
|
||||
}
|
||||
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
|
||||
|
||||
entries, err := c.ListContents(context.Background(), "owner", "repo", "src")
|
||||
entries, err := c.ListContents(context.Background(), "owner", "repo", ".")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -262,6 +257,57 @@ func TestListContents_MalformedJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content(t *testing.T) {
|
||||
// Test with newlines (GitHub's format)
|
||||
encoded := "cGFja2FnZSBt\nYWlu"
|
||||
decoded, err := decodeBase64Content(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if decoded != "package main" {
|
||||
t.Errorf("expected 'package main', got %q", decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content_Invalid(t *testing.T) {
|
||||
_, err := decodeBase64Content("not!!!valid!!!base64")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid base64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapePath_RejectsDotSegments(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"src/main.go", "src/main.go"},
|
||||
{"../etc/passwd", "etc/passwd"},
|
||||
{"./src/../main.go", "src/main.go"},
|
||||
{"a/b/c", "a/b/c"},
|
||||
{"file with spaces.go", "file%20with%20spaces.go"},
|
||||
{"a/./b/../c", "a/b/c"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := escapePath(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content_CRLF(t *testing.T) {
|
||||
// Base64 of "hello world" with CRLF line breaks inserted
|
||||
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
|
||||
decoded, err := decodeBase64Content(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if decoded != "hello world" {
|
||||
t.Errorf("expected 'hello world', got %q", decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListContents_SingleFile(t *testing.T) {
|
||||
// GitHub Contents API returns a JSON object (not array) for single-file paths
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -286,120 +332,3 @@ func TestListContents_SingleFile(t *testing.T) {
|
||||
t.Errorf("expected type 'file', got %q", entries[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapePath_ValidPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{"simple file", "file.go", "file.go"},
|
||||
{"nested path", "path/to/file.go", "path/to/file.go"},
|
||||
{"special chars", "path/to/my file.go", "path/to/my%20file.go"},
|
||||
{"leading slash stripped", "/path/to/file.go", "path/to/file.go"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := escapePath(tt.path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("escapePath(%q) = %q, want %q", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapePath_DotSegments(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
}{
|
||||
{"single dot", "./file.go"},
|
||||
{"double dot", "../file.go"},
|
||||
{"dot in middle", "path/./file.go"},
|
||||
{"parent traversal", "path/../file.go"},
|
||||
{"only dots", ".."},
|
||||
{"nested parent traversal", "a/b/../../c"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := escapePath(tt.path)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for path %q, got nil", tt.path)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "dot-segment") {
|
||||
t.Errorf("expected error about dot-segment, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFileContentAtRef_DotSegmentError(t *testing.T) {
|
||||
// Server should never be called — the error is caught before the request.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Fatal("server should not have been called")
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "foo/../bar.go", "main")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for path with dot-segments")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid file path") {
|
||||
t.Errorf("expected 'invalid file path' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content(t *testing.T) {
|
||||
// Test with newlines (GitHub's format)
|
||||
encoded := "cGFja2FnZSBt\nYWlu"
|
||||
decoded, err := decodeBase64Content(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if decoded != "package main" {
|
||||
t.Errorf("expected 'package main', got %q", decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content_Invalid(t *testing.T) {
|
||||
_, err := decodeBase64Content("not!!!valid!!!base64")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid base64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content_CRLF(t *testing.T) {
|
||||
// Base64 of "hello world" with CRLF line breaks inserted
|
||||
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
|
||||
decoded, err := decodeBase64Content(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if decoded != "hello world" {
|
||||
t.Errorf("expected 'hello world', got %q", decoded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBase64Content_SizeLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create base64 content that would decode to > maxFileContentSize.
|
||||
// maxFileContentSize is 10MB. Base64 of 11MB worth of zeros.
|
||||
// We just need something big enough to trigger the estimated size check.
|
||||
// 14MB of base64 chars (decodes to ~10.5MB).
|
||||
huge := strings.Repeat("A", 14*1024*1024)
|
||||
_, err := decodeBase64Content(huge)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized content")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "too large") {
|
||||
t.Errorf("expected 'too large' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+16
-26
@@ -51,10 +51,7 @@ type checkRunsResponse struct {
|
||||
} `json:"check_runs"`
|
||||
}
|
||||
|
||||
// GetPullRequest fetches PR metadata from the GitHub API.
|
||||
// Returns an *APIError wrapping the HTTP status on non-2xx responses (e.g.
|
||||
// IsNotFound for 404, IsUnauthorized for 401). Network and context errors
|
||||
// are wrapped but not typed as *APIError.
|
||||
// GetPullRequest fetches PR metadata.
|
||||
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
@@ -85,15 +82,9 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, num
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// Pagination ceilings. Each one caps how many pages a listing loop requests,
// so a server that always returns a full page cannot cause unbounded iteration.
const (
	// maxFilesPages caps the pagination loop of the PR file listing.
	maxFilesPages = 100

	// maxCheckRunPages caps the pagination loop of the check-run listing.
	maxCheckRunPages = 100
)
|
||||
// maxPages caps the number of pages a pagination loop will request, so that a
// server which keeps returning full pages cannot make the loop run forever.
const maxPages = 100
|
||||
|
||||
// GetPullRequestFiles fetches the list of files changed in a PR.
|
||||
// Paginates through all pages (100 per page) to collect all files.
|
||||
@@ -102,7 +93,7 @@ const (
|
||||
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
|
||||
var allFiles []vcs.ChangedFile
|
||||
|
||||
for page := 1; page <= maxFilesPages; page++ {
|
||||
for page := 1; page <= maxPages; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
@@ -163,7 +154,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
|
||||
}
|
||||
|
||||
// Fetch check runs (paginated)
|
||||
for checkPage := 1; checkPage <= maxCheckRunPages; checkPage++ {
|
||||
for checkPage := 1; checkPage <= maxPages; checkPage++ {
|
||||
checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage)
|
||||
checkBody, err := c.doGet(ctx, checkURL)
|
||||
@@ -178,7 +169,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
|
||||
result = append(result, vcs.CommitStatus{
|
||||
Context: cr.Name,
|
||||
Status: mapCheckRunStatus(cr.Conclusion),
|
||||
Description: "", // check runs have no human-readable description; conclusion is captured in Status
|
||||
Description: derefString(cr.Conclusion),
|
||||
TargetURL: cr.HTMLURL,
|
||||
})
|
||||
}
|
||||
@@ -190,17 +181,9 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// mapCheckRunStatus maps a GitHub check run conclusion to a vcs.CommitStatus status string.
|
||||
// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string.
|
||||
// Conclusion alone determines the mapped state: nil conclusion means the run is
|
||||
// still in progress (pending), regardless of the status field value.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - nil → "pending" (run still in progress or queued)
|
||||
// - "success" → "success"
|
||||
// - "failure", "action_required", "timed_out" → "failure"
|
||||
// - "cancelled", "skipped", "neutral" → "success" (non-blocking per GitHub check suite semantics)
|
||||
// - "stale" → "pending" (check run became stale before completing)
|
||||
// - unknown values → "pending" (conservative: treat unrecognized conclusions as incomplete)
|
||||
func mapCheckRunStatus(conclusion *string) string {
|
||||
if conclusion == nil {
|
||||
// Still running or queued
|
||||
@@ -213,10 +196,17 @@ func mapCheckRunStatus(conclusion *string) string {
|
||||
return "failure"
|
||||
case "cancelled", "skipped", "neutral":
|
||||
return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics
|
||||
case "stale":
|
||||
case "stale", "waiting":
|
||||
return "pending"
|
||||
default:
|
||||
return "pending"
|
||||
}
|
||||
}
|
||||
|
||||
// derefString returns the string that s points to, or the empty string when
// s is nil.
func derefString(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
|
||||
|
||||
@@ -545,7 +545,6 @@ func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
|
||||
name = *tt.conclusion
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/status") {
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
@@ -633,44 +632,6 @@ func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/status"):
|
||||
// Statuses succeed
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"state": "success",
|
||||
"statuses": []map[string]string{
|
||||
{
|
||||
"context": "ci/build",
|
||||
"state": "success",
|
||||
"description": "Build passed",
|
||||
"target_url": "https://ci.example.com/1",
|
||||
},
|
||||
},
|
||||
})
|
||||
case strings.Contains(r.URL.Path, "/check-runs"):
|
||||
// Check runs fail with 500
|
||||
w.WriteHeader(500)
|
||||
w.Write([]byte(`{"message":"Internal Server Error"}`))
|
||||
default:
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||
c.SetHTTPClient(srv.Client())
|
||||
|
||||
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when check-runs endpoint fails after statuses succeed")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "fetch check runs") {
|
||||
t.Errorf("expected check runs error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to a freshly allocated string holding s. It is
// a test helper for building optional (*string) values from literals.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
||||
|
||||
Reference in New Issue
Block a user