Compare commits

..

1 Commits

Author SHA1 Message Date
claw 90959da830 feat(vcs): add util.go with GetAllFilesInPath and BuildLineToPositionMap (#84)
CI / test (pull_request) Successful in 18s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 32s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 57s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m40s
Add vcs/util.go containing two utility functions that were specified
in issue #78 but omitted from PR #83:

- GetAllFilesInPath: recursively fetches all file contents under a path
  using the vcs.FileReader interface. Returns map[string]string of
  path -> content.

- BuildLineToPositionMap: parses a unified diff and returns per-file
  mapping of new-file line numbers to GitHub diff-position convention.
  Position is 1-indexed from @@ hunk headers. Deletion lines are not
  mapped (no new-file line number). Position counter continues across
  hunks within the same file but resets for each new file.

Unit tests cover:
- GetAllFilesInPath: empty dir, flat dir, nested dirs, mixed, errors
- BuildLineToPositionMap: single hunk, multi-hunk, deletions not mapped,
  multiple files, empty diff, no-newline marker, single-line hunk,
  deleted file
2026-05-12 12:31:10 -07:00
16 changed files with 573 additions and 2786 deletions
+6 -10
View File
@@ -15,7 +15,6 @@ import (
"gitea.weiker.me/rodin/review-bot/gitea" "gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/llm" "gitea.weiker.me/rodin/review-bot/llm"
"gitea.weiker.me/rodin/review-bot/review" "gitea.weiker.me/rodin/review-bot/review"
"gitea.weiker.me/rodin/review-bot/vcs"
) )
var version = "dev" var version = "dev"
@@ -813,7 +812,7 @@ func shouldSkipStaleReview(evaluatedSHA, currentSHA string) bool {
return evaluatedSHA != currentSHA return evaluatedSHA != currentSHA
} }
// giteaClientAdapter adapts gitea.Client to vcs.FileReader interface. // giteaClientAdapter adapts gitea.Client to review.GiteaClient interface.
type giteaClientAdapter struct { type giteaClientAdapter struct {
client *gitea.Client client *gitea.Client
} }
@@ -822,14 +821,14 @@ func newGiteaClientAdapter(c *gitea.Client) *giteaClientAdapter {
return &giteaClientAdapter{client: c} return &giteaClientAdapter{client: c}
} }
func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
entries, err := a.client.ListContents(ctx, owner, repo, path) entries, err := a.client.ListContents(ctx, owner, repo, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
result := make([]vcs.ContentEntry, len(entries)) result := make([]review.ContentEntry, len(entries))
for i, e := range entries { for i, e := range entries {
result[i] = vcs.ContentEntry{ result[i] = review.ContentEntry{
Name: e.Name, Name: e.Name,
Path: e.Path, Path: e.Path,
Type: e.Type, Type: e.Type,
@@ -838,9 +837,6 @@ func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path
return result, nil return result, nil
} }
func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) { func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
if ref != "" { return a.client.GetFileContent(ctx, owner, repo, filepath)
return a.client.GetFileContentRef(ctx, owner, repo, filePath, ref)
}
return a.client.GetFileContent(ctx, owner, repo, filePath)
} }
-12
View File
@@ -831,15 +831,3 @@ func (c *Client) ResolveComment(ctx context.Context, owner, repo string, comment
} }
return nil return nil
} }
// DismissReview dismisses a review on a pull request.
// This is a stub for the vcs.Reviewer interface; full implementation is Phase 2.
func (c *Client) DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error {
return fmt.Errorf("dismiss review %d on %s/%s#%d: %w", reviewID, owner, repo, number, errors.ErrUnsupported)
}
// GetFileContentAtRef fetches a file at a specific ref from a repo.
// This delegates to GetFileContentRef for the Gitea implementation.
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
return c.GetFileContentRef(ctx, owner, repo, path, ref)
}
-25
View File
@@ -1,25 +0,0 @@
//go:build phase2
package gitea_test
import (
"gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Compile-time interface conformance assertions.
// These will verify gitea.Client satisfies vcs interfaces once the Phase 2
// adapter bridges the method signature gaps:
//
// - PRReader: GetPullRequest returns *gitea.PullRequest (needs *vcs.PullRequest)
// - PRReader: GetPullRequestFiles returns []gitea.ChangedFile (needs []vcs.ChangedFile)
// - FileReader: GetFileContent lacks ref parameter
// - Reviewer: PostReview uses (event, body, comments) instead of vcs.ReviewRequest
//
// Remove the phase2 build tag once the adapter is complete.
var (
_ vcs.PRReader = (*gitea.Client)(nil)
_ vcs.FileReader = (*gitea.Client)(nil)
_ vcs.Reviewer = (*gitea.Client)(nil)
_ vcs.Identity = (*gitea.Client)(nil)
)
-341
View File
@@ -1,341 +0,0 @@
// Package github provides a client for the GitHub API.
// It supports pull request operations, file content retrieval, CI status checks,
// and directory listing for both github.com and GitHub Enterprise.
package github
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const (
defaultBaseURL = "https://api.github.com"
userAgent = "review-bot/1.0"
// maxResponseBytes limits successful response body reads to 10 MiB.
maxResponseBytes = 10 * 1024 * 1024
)
// APIError represents an HTTP error response from the GitHub API.
// It carries the status code so callers can distinguish between
// different failure modes (e.g. 404 vs 500).
//
// The Body field stores up to 4 KiB of the raw response for programmatic
// inspection. Error() truncates to 200 bytes for safe logging, but callers
// should avoid logging or propagating Body directly in production since it may
// contain sensitive details from the upstream server.
type APIError struct {
StatusCode int
Body string
}
func (e *APIError) Error() string {
body := e.Body
if len(body) > 200 {
body = body[:200] + "...(truncated)"
}
// Sanitize newlines to prevent log injection from upstream response bodies.
body = strings.ReplaceAll(body, "\n", " ")
body = strings.ReplaceAll(body, "\r", " ")
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
}
// SafeError returns the error string without response body content,
// suitable for logging in contexts where upstream response data should
// not be exposed.
func (e *APIError) SafeError() string {
return fmt.Sprintf("HTTP %d", e.StatusCode)
}
// IsNotFound reports whether an error is an API 404 response.
func IsNotFound(err error) bool {
if apiErr, ok := asAPIError(err); ok {
return apiErr.StatusCode == http.StatusNotFound
}
return false
}
// IsUnauthorized reports whether an error is an API 401 response.
func IsUnauthorized(err error) bool {
if apiErr, ok := asAPIError(err); ok {
return apiErr.StatusCode == http.StatusUnauthorized
}
return false
}
func asAPIError(err error) (*APIError, bool) {
if err == nil {
return nil, false
}
var target *APIError
if errors.As(err, &target) {
return target, true
}
return nil, false
}
// clientConfig holds optional configuration for NewClient.
type clientConfig struct {
allowInsecureHTTP bool
}
// ClientOption configures optional behavior of NewClient.
type ClientOption func(*clientConfig)
// AllowInsecureHTTP permits the client to use HTTP (non-TLS) base URLs.
// This should only be used for trusted internal deployments or testing.
func AllowInsecureHTTP() ClientOption {
return func(c *clientConfig) {
c.allowInsecureHTTP = true
}
}
// Client interacts with the GitHub API.
// A Client is safe for concurrent use by multiple goroutines.
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
// be called before any goroutines issue requests; they have no synchronization.
type Client struct {
baseURL string
token string
allowInsecureHTTP bool
httpClient *http.Client
// retryBackoff defines the delays between retry attempts for 429 responses.
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
// If nil, defaults to {1s, 2s}. Set to shorter durations in tests via SetRetryBackoff.
retryBackoff []time.Duration
}
// defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil).
// It rejects HTTPS→HTTP protocol downgrades (to prevent plaintext leakage) and strips
// the Authorization header on cross-host redirects to prevent credential leakage to
// third-party hosts (e.g. CDN redirects from GitHub).
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("stopped after 10 redirects")
}
// Guard: net/http guarantees len(via) >= 1 but this is undocumented;
// defend against zero-length to avoid panic on index out of range.
if len(via) == 0 {
return nil
}
prev := via[len(via)-1]
// Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext.
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
return fmt.Errorf("refusing redirect from HTTPS to HTTP (%s → %s)", prev.URL.Host, req.URL.Host)
}
// Strip Authorization on cross-host redirect to avoid leaking credentials
// to third-party hosts (GitHub legitimately redirects to CDN hosts).
if req.URL.Host != prev.URL.Host {
req.Header.Del("Authorization")
}
return nil
}
// NewClient creates a new GitHub API client.
// If baseURL is empty, it defaults to https://api.github.com.
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
// The baseURL must use HTTPS; pass AllowInsecureHTTP() as an option to permit HTTP
// for trusted internal deployments (e.g. local testing).
func NewClient(token, baseURL string, opts ...ClientOption) *Client {
if baseURL == "" {
baseURL = defaultBaseURL
}
cfg := clientConfig{}
for _, o := range opts {
o(&cfg)
}
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
allowInsecureHTTP: cfg.allowInsecureHTTP,
token: token,
httpClient: &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: defaultCheckRedirect,
},
}
}
// SetHTTPClient sets the underlying HTTP client used for requests.
// This is intended for test setup only to inject mock transports; it must be
// called before any goroutines issue requests.
//
// Passing nil restores the default client (30s timeout + auth-stripping
// CheckRedirect policy matching NewClient).
//
// Callers providing a non-nil client are responsible for configuring a safe
// CheckRedirect policy. Without one, the default net/http behavior will follow
// redirects and may forward the Authorization header to untrusted hosts.
func (c *Client) SetHTTPClient(hc *http.Client) {
if hc == nil {
hc = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: defaultCheckRedirect,
}
} else if hc.CheckRedirect == nil {
// Enforce safe redirect policy when caller provides a client without one.
// The default net/http behavior follows up to 10 redirects and forwards
// all headers (including Authorization) to any host, which can leak
// credentials on cross-host redirects.
hc.CheckRedirect = defaultCheckRedirect
}
c.httpClient = hc
}
// SetRetryBackoff configures the retry backoff durations for testing.
// It must be called before any goroutines issue requests.
// In production the default {1s, 2s} applies.
func (c *Client) SetRetryBackoff(d []time.Duration) {
c.retryBackoff = d
}
// doRequest performs an HTTP request with retry on 429 rate limit responses.
// It respects the Retry-After header when present (capped at maxRetryAfter).
// Transport errors (network failures, context cancellation) are not retried.
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
const maxAttempts = 3
const maxRetryAfter = 120 * time.Second
var backoff []time.Duration
if c.retryBackoff != nil {
backoff = make([]time.Duration, len(c.retryBackoff))
copy(backoff, c.retryBackoff)
} else {
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
}
// maxErrorBodyBytes limits how much of an error response body is stored.
// Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers
// log APIError.Body directly. Error() further truncates to 200 bytes.
const maxErrorBodyBytes = 4 * 1024
// Reject non-HTTPS URLs early since the URL is immutable across retries.
if c.token != "" && !c.allowInsecureHTTP {
parsed, err := url.Parse(reqURL)
if err != nil {
return nil, fmt.Errorf("parse request URL: %w", err)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL)
}
}
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
if attempt > 0 {
var delay time.Duration
if attempt-1 < len(backoff) {
delay = backoff[attempt-1]
}
if delay > 0 {
timer := time.NewTimer(delay)
select {
case <-timer.C:
timer.Stop() // no-op after fire; kept for symmetry with the ctx.Done case
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
}
}
}
req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
if c.token != "" {
// Bearer is the OAuth2 standard and is accepted by GitHub for both
// classic PATs and fine-grained tokens. The alternative "token" scheme
// is GitHub-specific and offers no additional compatibility.
req.Header.Set("Authorization", "Bearer "+c.token)
}
req.Header.Set("User-Agent", userAgent)
if accept != "" {
req.Header.Set("Accept", accept)
} else {
req.Header.Set("Accept", "application/vnd.github+json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
// Transport errors (DNS, TLS, timeout) yield nil resp; no body to close.
return nil, fmt.Errorf("do request: %w", err)
}
body, done, err := handleResponse(resp, maxResponseBytes, maxErrorBodyBytes)
if done {
return body, err
}
lastErr = err
// Retry on 429 rate limit
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 {
// Check for Retry-After header and override backoff if present.
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
if ra := resp.Header.Get("Retry-After"); ra != "" {
if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 {
delay := time.Duration(seconds) * time.Second
if delay > maxRetryAfter {
delay = maxRetryAfter
}
if attempt < len(backoff) {
backoff[attempt] = delay
}
} else if retryAt, err := http.ParseTime(ra); err == nil {
delay := time.Until(retryAt)
if delay < 0 {
delay = 0
}
if delay > maxRetryAfter {
delay = maxRetryAfter
}
if attempt < len(backoff) {
backoff[attempt] = delay
}
}
}
continue
}
// Don't retry other errors
return nil, lastErr
}
return nil, lastErr
}
// handleResponse reads and closes the response body, returning the result.
// It uses defer to ensure the body is always closed regardless of code path.
// Returns (body, done, err) where done=true means the caller should return immediately.
func handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) {
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1))
if err != nil {
return nil, true, fmt.Errorf("read response body: %w", err)
}
if len(body) > maxRespBytes {
return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes)
}
return body, true, nil
}
errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes)))
if readErr != nil && len(errBody) == 0 {
errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr))
}
return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
}
// doGet is a convenience wrapper for GET requests with the default Accept header.
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
return c.doRequest(ctx, http.MethodGet, reqURL, "")
}
-599
View File
@@ -1,599 +0,0 @@
package github
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
func TestNewClient_DefaultBaseURL(t *testing.T) {
c := NewClient("test-token", "")
if c.baseURL != "https://api.github.com" {
t.Errorf("expected default base URL, got %q", c.baseURL)
}
}
func TestNewClient_CustomBaseURL(t *testing.T) {
c := NewClient("test-token", "https://github.concur.com/api/v3")
if c.baseURL != "https://github.concur.com/api/v3" {
t.Errorf("expected custom base URL, got %q", c.baseURL)
}
}
func TestNewClient_TrimsTrailingSlash(t *testing.T) {
c := NewClient("test-token", "https://github.concur.com/api/v3/")
if c.baseURL != "https://github.concur.com/api/v3" {
t.Errorf("expected trailing slash trimmed, got %q", c.baseURL)
}
}
func TestDoRequest_SetsAuthHeader(t *testing.T) {
var gotAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("my-token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotAuth != "Bearer my-token" {
t.Errorf("expected Bearer auth, got %q", gotAuth)
}
}
func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) {
var gotAccept string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAccept = r.Header.Get("Accept")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotAccept != "application/vnd.github+json" {
t.Errorf("expected default Accept header, got %q", gotAccept)
}
}
func TestDoRequest_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestDoRequest_429ExhaustsRetries(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error after exhausting retries")
}
apiErr, ok := err.(*APIError)
if !ok {
t.Fatalf("expected *APIError, got %T", err)
}
if apiErr.StatusCode != 429 {
t.Errorf("expected 429, got %d", apiErr.StatusCode)
}
if attempts != 3 {
t.Errorf("expected 3 attempts, got %d", attempts)
}
}
func TestDoRequest_404NoRetry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(404)
w.Write([]byte(`{"message":"not found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error for 404")
}
if attempts != 1 {
t.Errorf("expected 1 attempt (no retry on 404), got %d", attempts)
}
}
func TestDoRequest_401NoRetry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(401)
w.Write([]byte(`{"message":"bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error for 401")
}
if attempts != 1 {
t.Errorf("expected 1 attempt (no retry on 401), got %d", attempts)
}
}
func TestIsNotFound(t *testing.T) {
err := &APIError{StatusCode: 404, Body: "not found"}
if !IsNotFound(err) {
t.Error("expected IsNotFound to return true for 404")
}
err2 := &APIError{StatusCode: 500, Body: "server error"}
if IsNotFound(err2) {
t.Error("expected IsNotFound to return false for 500")
}
}
func TestIsUnauthorized(t *testing.T) {
err := &APIError{StatusCode: 401, Body: "bad credentials"}
if !IsUnauthorized(err) {
t.Error("expected IsUnauthorized to return true for 401")
}
}
func TestAPIError_SanitizesNewlines(t *testing.T) {
err := &APIError{StatusCode: 500, Body: "line1\ninjected\rmore"}
msg := err.Error()
if strings.Contains(msg, "\n") || strings.Contains(msg, "\r") {
t.Errorf("expected newlines to be sanitized, got: %q", msg)
}
if !strings.Contains(msg, "line1 injected more") {
t.Errorf("expected sanitized body, got: %q", msg)
}
}
func TestDoRequest_429RetryAfterHeader(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow retry test in short mode")
}
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "1")
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
// Use short backoff; Retry-After should override
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
start := time.Now()
body, err := c.doGet(context.Background(), srv.URL+"/test")
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
// Retry-After: 1 means at least 1 second delay
if elapsed < 900*time.Millisecond {
t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed)
}
}
func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow retry test in short mode")
}
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "1")
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify the original retryBackoff slice was not mutated
if c.retryBackoff[0] != 1*time.Millisecond {
t.Errorf("retryBackoff[0] was mutated: got %v, want 1ms", c.retryBackoff[0])
}
if c.retryBackoff[1] != 1*time.Millisecond {
t.Errorf("retryBackoff[1] was mutated: got %v, want 1ms", c.retryBackoff[1])
}
}
func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow Retry-After HTTP-date test in short mode")
}
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
// Use HTTP-date format (RFC 7231) — a time 2 seconds in the future.
future := time.Now().Add(2 * time.Second).UTC()
w.Header().Set("Retry-After", future.Format(http.TimeFormat))
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
start := time.Now()
body, err := c.doGet(context.Background(), srv.URL+"/test")
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
// HTTP-date was ~2s in the future; by the time client processes it,
// time.Until gives ~1-2s. Verify it's meaningfully delayed (not instant).
if elapsed < 500*time.Millisecond {
t.Errorf("expected meaningful delay from HTTP-date Retry-After, got %v", elapsed)
}
}
func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
// Use a time in the past — should result in zero/immediate retry.
past := time.Now().Add(-10 * time.Second).UTC()
w.Header().Set("Retry-After", past.Format(http.TimeFormat))
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second})
start := time.Now()
_, err := c.doGet(context.Background(), srv.URL+"/test")
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
// Past date should override the 5s backoff to ~0
if elapsed > 500*time.Millisecond {
t.Errorf("expected near-instant retry for past HTTP-date, got %v", elapsed)
}
}
func TestDoRequest_SetsUserAgentHeader(t *testing.T) {
var gotUA string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUA = r.Header.Get("User-Agent")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotUA != "review-bot/1.0" {
t.Errorf("expected User-Agent 'review-bot/1.0', got %q", gotUA)
}
}
func TestDoRequest_LimitsResponseBody(t *testing.T) {
// Verify that oversized responses return an error rather than silently truncating.
bigBody := strings.Repeat("x", maxResponseBytes+1024)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(bigBody))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error for oversized response body")
}
if !strings.Contains(err.Error(), "exceeded") {
t.Errorf("expected truncation error, got: %v", err)
}
}
func TestDoRequest_AcceptsExactlyAtLimit(t *testing.T) {
// A response body exactly equal to maxResponseBytes should succeed (not error).
exactBody := strings.Repeat("x", maxResponseBytes)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(exactBody))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error for exactly-at-limit body: %v", err)
}
if len(body) != maxResponseBytes {
t.Errorf("expected body length %d, got %d", maxResponseBytes, len(body))
}
}
func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) {
var gotAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("", srv.URL, AllowInsecureHTTP()) // empty token
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotAuth != "" {
t.Errorf("expected no Authorization header with empty token, got %q", gotAuth)
}
}
func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) {
// Verify the CheckRedirect function is configured
c := NewClient("secret-token", "https://api.github.com")
if c.httpClient.CheckRedirect == nil {
t.Fatal("expected CheckRedirect to be set")
}
}
func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) {
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}}
req := &http.Request{
URL: &url.URL{Scheme: "http", Host: "api.github.com", Path: "/foo"},
Header: http.Header{"Authorization": []string{"Bearer token"}},
}
err := defaultCheckRedirect(req, []*http.Request{prev})
if err == nil {
t.Fatal("expected error on HTTPS→HTTP redirect")
}
if !strings.Contains(err.Error(), "refusing redirect from HTTPS to HTTP") {
t.Errorf("unexpected error message: %v", err)
}
}
func TestDefaultCheckRedirect_StripsAuthOnCrossHost(t *testing.T) {
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}}
req := &http.Request{
URL: &url.URL{Scheme: "https", Host: "objects.githubusercontent.com", Path: "/bar"},
Header: http.Header{"Authorization": []string{"Bearer token"}},
}
err := defaultCheckRedirect(req, []*http.Request{prev})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if auth := req.Header.Get("Authorization"); auth != "" {
t.Errorf("expected Authorization header to be stripped, got %q", auth)
}
}
func TestDefaultCheckRedirect_PreservesAuthOnSameHost(t *testing.T) {
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}}
req := &http.Request{
URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/bar"},
Header: http.Header{"Authorization": []string{"Bearer token"}},
}
err := defaultCheckRedirect(req, []*http.Request{prev})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if auth := req.Header.Get("Authorization"); auth != "Bearer token" {
t.Errorf("expected Authorization to be preserved, got %q", auth)
}
}
func TestDoRequest_RejectsHTTPWithToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
// Without AllowInsecureHTTP, should refuse to send token over HTTP
c := NewClient("secret-token", srv.URL)
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error when sending token over HTTP")
}
if !strings.Contains(err.Error(), "refusing to send credentials") {
t.Errorf("unexpected error message: %v", err)
}
}
func TestDoRequest_AllowsHTTPWithoutToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
// Without token, HTTP should be fine (no credentials to leak)
c := NewClient("", srv.URL)
c.SetHTTPClient(srv.Client())
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
}
func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("secret-token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
}
func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
c := NewClient("token", "https://api.github.com")
c.SetHTTPClient(nil)
if c.httpClient == nil {
t.Fatal("expected non-nil httpClient after SetHTTPClient(nil)")
}
if c.httpClient.Timeout != 30*time.Second {
t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout)
}
if c.httpClient.CheckRedirect == nil {
t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)")
}
}
func TestSetHTTPClient_NilCheckRedirectEnforcesDefault(t *testing.T) {
c := NewClient("token", "https://api.github.com")
// Provide a client with nil CheckRedirect — should get default policy enforced.
hc := &http.Client{Timeout: 5 * time.Second}
c.SetHTTPClient(hc)
if c.httpClient.CheckRedirect == nil {
t.Fatal("expected CheckRedirect to be enforced when caller provides nil")
}
if c.httpClient.Timeout != 5*time.Second {
t.Errorf("expected caller's timeout preserved, got %v", c.httpClient.Timeout)
}
}
func TestSetHTTPClient_PreservesCustomCheckRedirect(t *testing.T) {
c := NewClient("token", "https://api.github.com")
called := false
hc := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
called = true
return nil
},
}
c.SetHTTPClient(hc)
// Invoke the redirect to verify original is preserved
_ = c.httpClient.CheckRedirect(nil, []*http.Request{{}})
if !called {
t.Fatal("expected custom CheckRedirect to be preserved")
}
}
func TestAPIError_SafeError(t *testing.T) {
e := &APIError{StatusCode: 403, Body: "some sensitive body content"}
got := e.SafeError()
if got != "HTTP 403" {
t.Errorf("SafeError() = %q, want %q", got, "HTTP 403")
}
// Ensure Error() still includes body
full := e.Error()
if full != "HTTP 403: some sensitive body content" {
t.Errorf("Error() = %q, unexpected", full)
}
}
-13
View File
@@ -1,13 +0,0 @@
package github_test
import (
"gitea.weiker.me/rodin/review-bot/github"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Compile-time interface conformance assertions.
// These verify github.Client satisfies vcs.PRReader and vcs.FileReader.
var (
_ vcs.PRReader = (*github.Client)(nil)
_ vcs.FileReader = (*github.Client)(nil)
)
-135
View File
@@ -1,135 +0,0 @@
package github
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// GetFileContent fetches a file from a repo at the given ref.
// Delegates to GetFileContentAtRef with the provided ref.
func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
return c.GetFileContentAtRef(ctx, owner, repo, path, ref)
}
// GetFileContentAtRef fetches a file at a specific ref from a repo.
// If ref is empty, the query parameter is omitted (uses default branch).
//
// Note: dot-segments ("." and "..") in the path are silently removed to
// prevent path traversal. This means a path like "foo/../bar" resolves
// to "foo/bar" rather than "bar".
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
if ref != "" {
reqURL += "?ref=" + url.QueryEscape(ref)
}
body, err := c.doGet(ctx, reqURL)
if err != nil {
return "", fmt.Errorf("fetch file %s: %w", path, err)
}
var resp struct {
Content string `json:"content"`
Encoding string `json:"encoding"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("parse file content JSON: %w", err)
}
if resp.Encoding != "base64" {
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path)
}
decoded, err := decodeBase64Content(resp.Content)
if err != nil {
return "", fmt.Errorf("decode base64 content for %s: %w", path, err)
}
return decoded, nil
}
// ListContents lists files and directories at a given path in a repo.
// Returns the directory listing from the GitHub contents API.
// If the path points to a single file (not a directory), the API returns
// a JSON object instead of an array; this is handled by returning a
// single-element slice.
//
// Note: dot-segments ("." and "..") in the path are silently removed to
// prevent path traversal. This means a path like "foo/../bar" resolves
// to "foo/bar" rather than "bar".
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
body, err := c.doGet(ctx, reqURL)
if err != nil {
return nil, fmt.Errorf("list contents %s: %w", path, err)
}
type entry struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"`
}
// The GitHub contents API returns an array for directories and an object
// for single files. Try array first (common case), then fall back to object.
// An empty array ([]) is valid — it represents an empty directory — and
// results in a zero-length slice returned without error.
var entries []entry
if err := json.Unmarshal(body, &entries); err != nil {
var single entry
if err2 := json.Unmarshal(body, &single); err2 != nil {
return nil, fmt.Errorf("parse contents JSON: as array: %w; as object: %w", err, err2)
}
// Guard against empty objects ({}) or unexpected shapes that
// unmarshal successfully but carry no useful data.
if single.Name == "" && single.Path == "" && single.Type == "" {
return nil, fmt.Errorf("parse contents JSON: unexpected response format")
}
entries = []entry{single}
}
result := make([]vcs.ContentEntry, len(entries))
for i, e := range entries {
result[i] = vcs.ContentEntry{
Name: e.Name,
Path: e.Path,
Type: e.Type,
}
}
return result, nil
}
// escapePath escapes each segment of a relative file path for use in URLs.
// Slashes are preserved as path separators; other special characters are escaped.
// Dot-segments ("." and "..") and empty segments (from consecutive slashes like
// "a//b") are silently removed to prevent path traversal and produce canonical
// paths. This is intentional: callers may receive a different path than requested
// without error. The function is package-private, and all callers
// (GetFileContentAtRef, ListContents) already handle missing-file errors from the
// API if the cleaned path doesn't match what the caller intended.
func escapePath(p string) string {
parts := strings.Split(p, "/")
var clean []string
for _, part := range parts {
if part == "." || part == ".." || part == "" {
continue
}
clean = append(clean, url.PathEscape(part))
}
return strings.Join(clean, "/")
}
// decodeBase64Content decodes base64-encoded content from the GitHub contents API.
// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding.
func decodeBase64Content(encoded string) (string, error) {
// GitHub inserts newlines in base64 content
cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded)
decoded, err := base64.StdEncoding.DecodeString(cleaned)
if err != nil {
return "", err
}
return string(decoded), nil
}
-334
View File
@@ -1,334 +0,0 @@
package github
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) {
var gotRef string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotRef = r.URL.Query().Get("ref")
json.NewEncoder(w).Encode(map[string]string{
"content": "dGVzdA==", // "test" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
// Call with empty ref — should not include ref param
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "test" {
t.Errorf("expected 'test', got %q", content)
}
if gotRef != "" {
t.Errorf("expected empty ref, got %q", gotRef)
}
}
func TestGetFileContent_WithRef(t *testing.T) {
var gotRef string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotRef = r.URL.Query().Get("ref")
json.NewEncoder(w).Encode(map[string]string{
"content": "dGVzdA==",
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotRef != "abc123" {
t.Errorf("expected ref 'abc123', got %q", gotRef)
}
}
func TestGetFileContent_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "")
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetFileContent_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetFileContent_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
json.NewEncoder(w).Encode(map[string]string{
"content": "b2s=",
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "ok" {
t.Errorf("expected 'ok', got %q", content)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestGetFileContent_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestListContents_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/repos/owner/repo/contents/src" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
json.NewEncoder(w).Encode([]map[string]string{
{"name": "main.go", "path": "src/main.go", "type": "file"},
{"name": "lib", "path": "src/lib", "type": "dir"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
entries, err := c.ListContents(context.Background(), "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 2 {
t.Fatalf("expected 2 entries, got %d", len(entries))
}
if entries[0].Name != "main.go" {
t.Errorf("expected name 'main.go', got %q", entries[0].Name)
}
if entries[0].Path != "src/main.go" {
t.Errorf("expected path 'src/main.go', got %q", entries[0].Path)
}
if entries[0].Type != "file" {
t.Errorf("expected type 'file', got %q", entries[0].Type)
}
if entries[1].Name != "lib" {
t.Errorf("expected name 'lib', got %q", entries[1].Name)
}
if entries[1].Type != "dir" {
t.Errorf("expected type 'dir', got %q", entries[1].Type)
}
}
func TestListContents_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.ListContents(context.Background(), "owner", "repo", "missing")
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestListContents_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.ListContents(context.Background(), "owner", "repo", "src")
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestListContents_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
json.NewEncoder(w).Encode([]map[string]string{
{"name": "file.go", "path": "file.go", "type": "file"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
entries, err := c.ListContents(context.Background(), "owner", "repo", ".")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestListContents_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.ListContents(context.Background(), "owner", "repo", "src")
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestDecodeBase64Content(t *testing.T) {
// Test with newlines (GitHub's format)
encoded := "cGFja2FnZSBt\nYWlu"
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "package main" {
t.Errorf("expected 'package main', got %q", decoded)
}
}
func TestDecodeBase64Content_Invalid(t *testing.T) {
_, err := decodeBase64Content("not!!!valid!!!base64")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestEscapePath_RejectsDotSegments(t *testing.T) {
tests := []struct {
input string
want string
}{
{"src/main.go", "src/main.go"},
{"../etc/passwd", "etc/passwd"},
{"./src/../main.go", "src/main.go"},
{"a/b/c", "a/b/c"},
{"file with spaces.go", "file%20with%20spaces.go"},
{"a/./b/../c", "a/b/c"},
}
for _, tt := range tests {
got := escapePath(tt.input)
if got != tt.want {
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestDecodeBase64Content_CRLF(t *testing.T) {
// Base64 of "hello world" with CRLF line breaks inserted
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "hello world" {
t.Errorf("expected 'hello world', got %q", decoded)
}
}
func TestListContents_SingleFile(t *testing.T) {
// GitHub Contents API returns a JSON object (not array) for single-file paths
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
entries, err := c.ListContents(context.Background(), "owner", "repo", "README.md")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Name != "README.md" {
t.Errorf("expected name 'README.md', got %q", entries[0].Name)
}
if entries[0].Type != "file" {
t.Errorf("expected type 'file', got %q", entries[0].Type)
}
}
-212
View File
@@ -1,212 +0,0 @@
package github
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// pullRequestResponse is the GitHub API response for a pull request.
type pullRequestResponse struct {
Number int `json:"number"`
Title string `json:"title"`
Body string `json:"body"`
Head struct {
SHA string `json:"sha"`
Ref string `json:"ref"`
} `json:"head"`
Base struct {
Ref string `json:"ref"`
} `json:"base"`
}
// changedFileResponse is the GitHub API response for a changed file in a PR.
type changedFileResponse struct {
Filename string `json:"filename"`
Status string `json:"status"`
Patch string `json:"patch"`
}
// commitStatusResponse is the GitHub combined status API response.
type commitStatusResponse struct {
Statuses []struct {
Context string `json:"context"`
State string `json:"state"`
Description string `json:"description"`
TargetURL string `json:"target_url"`
} `json:"statuses"`
}
// checkRunsResponse is the GitHub check runs API response.
type checkRunsResponse struct {
CheckRuns []struct {
Name string `json:"name"`
Conclusion *string `json:"conclusion"`
Status string `json:"status"`
HTMLURL string `json:"html_url"`
} `json:"check_runs"`
}
// GetPullRequest fetches PR metadata.
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL)
if err != nil {
return nil, fmt.Errorf("fetch PR: %w", err)
}
var resp pullRequestResponse
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("parse PR JSON: %w", err)
}
return &vcs.PullRequest{
Number: resp.Number,
Title: resp.Title,
Body: resp.Body,
Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref},
Base: vcs.BaseRef{Ref: resp.Base.Ref},
}, nil
}
// GetPullRequestDiff fetches the unified diff for a PR.
// Uses Accept: application/vnd.github.diff to get raw diff text.
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff")
if err != nil {
return "", fmt.Errorf("fetch diff: %w", err)
}
return string(body), nil
}
// maxPages is the upper bound on pagination loops to prevent unbounded iteration
// in case the server returns a full page indefinitely.
const maxPages = 100
// GetPullRequestFiles fetches the list of files changed in a PR.
// Paginates through all pages (100 per page) to collect all files.
// Returns nil (not an empty slice) when the PR has no changed files.
// Callers can safely range over or check len() on a nil slice.
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
var allFiles []vcs.ChangedFile
for page := 1; page <= maxPages; page++ {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page)
body, err := c.doGet(ctx, reqURL)
if err != nil {
return nil, fmt.Errorf("fetch PR files page %d: %w", page, err)
}
var files []changedFileResponse
if err := json.Unmarshal(body, &files); err != nil {
return nil, fmt.Errorf("parse PR files JSON: %w", err)
}
if len(files) == 0 {
break
}
for _, f := range files {
allFiles = append(allFiles, vcs.ChangedFile{
Filename: f.Filename,
Status: f.Status,
Patch: f.Patch,
})
}
if len(files) < 100 {
break
}
}
return allFiles, nil
}
// GetCommitStatuses fetches both commit statuses and check runs for a SHA,
// merging them into a unified []vcs.CommitStatus slice.
// Returns nil (not an empty slice) when there are no statuses or check runs.
// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the
// function returns immediately without attempting the check-runs endpoint.
// If the check-runs endpoint fails after statuses were fetched successfully,
// the function returns an error (not a partial result) so callers always get
// either a complete view or a clear error signal.
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) {
var result []vcs.CommitStatus
// Fetch commit statuses
statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha))
statusBody, err := c.doGet(ctx, statusURL)
if err != nil {
return nil, fmt.Errorf("fetch commit statuses: %w", err)
}
var statusResp commitStatusResponse
if err := json.Unmarshal(statusBody, &statusResp); err != nil {
return nil, fmt.Errorf("parse commit statuses JSON: %w", err)
}
for _, s := range statusResp.Statuses {
result = append(result, vcs.CommitStatus{
Context: s.Context,
Status: s.State,
Description: s.Description,
TargetURL: s.TargetURL,
})
}
// Fetch check runs (paginated)
for checkPage := 1; checkPage <= maxPages; checkPage++ {
checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage)
checkBody, err := c.doGet(ctx, checkURL)
if err != nil {
return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err)
}
var checkResp checkRunsResponse
if err := json.Unmarshal(checkBody, &checkResp); err != nil {
return nil, fmt.Errorf("parse check runs JSON: %w", err)
}
for _, cr := range checkResp.CheckRuns {
result = append(result, vcs.CommitStatus{
Context: cr.Name,
Status: mapCheckRunStatus(cr.Conclusion),
Description: derefString(cr.Conclusion),
TargetURL: cr.HTMLURL,
})
}
if len(checkResp.CheckRuns) < 100 {
break
}
}
return result, nil
}
// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string.
// Conclusion alone determines the mapped state: nil conclusion means the run is
// still in progress (pending), regardless of the status field value.
func mapCheckRunStatus(conclusion *string) string {
if conclusion == nil {
// Still running or queued
return "pending"
}
switch *conclusion {
case "success":
return "success"
case "failure", "action_required", "timed_out":
return "failure"
case "cancelled", "skipped", "neutral":
return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics
case "stale", "waiting":
return "pending"
default:
return "pending"
}
}
// derefString safely dereferences a string pointer, returning empty string if nil.
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
-637
View File
@@ -1,637 +0,0 @@
package github
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestGetPullRequest_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/repos/owner/repo/pulls/42" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]interface{}{
"number": 42,
"title": "Test PR",
"body": "Description",
"head": map[string]string{"sha": "abc123", "ref": "feature-branch"},
"base": map[string]string{"ref": "main"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pr.Number != 42 {
t.Errorf("expected number 42, got %d", pr.Number)
}
if pr.Title != "Test PR" {
t.Errorf("expected title 'Test PR', got %q", pr.Title)
}
if pr.Body != "Description" {
t.Errorf("expected body 'Description', got %q", pr.Body)
}
if pr.Head.SHA != "abc123" {
t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA)
}
if pr.Head.Ref != "feature-branch" {
t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref)
}
if pr.Base.Ref != "main" {
t.Errorf("expected base ref 'main', got %q", pr.Base.Ref)
}
}
func TestGetPullRequest_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
if !IsNotFound(err) {
t.Errorf("expected IsNotFound=true, got error: %v", err)
}
}
func TestGetPullRequest_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for 401")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
}
}
func TestGetPullRequest_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
json.NewEncoder(w).Encode(map[string]interface{}{
"number": 1,
"title": "PR",
"body": "",
"head": map[string]string{"sha": "abc", "ref": "br"},
"base": map[string]string{"ref": "main"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pr.Number != 1 {
t.Errorf("expected number 1, got %d", pr.Number)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestGetPullRequest_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{invalid json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for malformed JSON")
}
if !strings.Contains(err.Error(), "parse PR JSON") {
t.Errorf("expected parse error, got: %v", err)
}
}
func TestGetPullRequestDiff_HappyPath(t *testing.T) {
expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n"
var gotAccept string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAccept = r.Header.Get("Accept")
w.WriteHeader(200)
w.Write([]byte(expectedDiff))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff != expectedDiff {
t.Errorf("unexpected diff: %q", diff)
}
if gotAccept != "application/vnd.github.diff" {
t.Errorf("expected diff Accept header, got %q", gotAccept)
}
}
func TestGetPullRequestDiff_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetPullRequestDiff_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetPullRequestFiles_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode([]map[string]interface{}{
{"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"},
{"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 2 {
t.Fatalf("expected 2 files, got %d", len(files))
}
if files[0].Filename != "main.go" {
t.Errorf("expected filename 'main.go', got %q", files[0].Filename)
}
if files[0].Status != "modified" {
t.Errorf("expected status 'modified', got %q", files[0].Status)
}
if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" {
t.Errorf("unexpected patch: %q", files[0].Patch)
}
}
func TestGetPullRequestFiles_Pagination(t *testing.T) {
// Simulate > 100 files requiring pagination
page1Files := make([]map[string]string, 100)
for i := 0; i < 100; i++ {
page1Files[i] = map[string]string{
"filename": fmt.Sprintf("file%d.go", i),
"status": "modified",
"patch": fmt.Sprintf("patch%d", i),
}
}
page2Files := []map[string]string{
{"filename": "file100.go", "status": "added", "patch": "patch100"},
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
page := r.URL.Query().Get("page")
if page == "" || page == "1" {
json.NewEncoder(w).Encode(page1Files)
} else {
json.NewEncoder(w).Encode(page2Files)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 101 {
t.Errorf("expected 101 files (paginated), got %d", len(files))
}
if files[100].Filename != "file100.go" {
t.Errorf("expected last file 'file100.go', got %q", files[100].Filename)
}
if files[100].Patch != "patch100" {
t.Errorf("expected last patch 'patch100', got %q", files[100].Patch)
}
}
func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Binary files have no patch field in GitHub response
json.NewEncoder(w).Encode([]map[string]interface{}{
{"filename": "image.png", "status": "added"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Patch != "" {
t.Errorf("expected empty patch for binary file, got %q", files[0].Patch)
}
}
func TestGetPullRequestFiles_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetPullRequestFiles_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestGetFileContentAtRef_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.URL.Query().Get("ref") != "abc123" {
t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref"))
}
json.NewEncoder(w).Encode(map[string]string{
"content": "cGFja2FnZSBtYWlu", // "package main" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "package main" {
t.Errorf("expected 'package main', got %q", content)
}
}
func TestGetFileContentAtRef_EmptyRef(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("ref") != "" {
t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref"))
}
json.NewEncoder(w).Encode(map[string]string{
"content": "aGVsbG8=", // "hello" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "hello" {
t.Errorf("expected 'hello', got %q", content)
}
}
func TestGetFileContentAtRef_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main")
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetFileContentAtRef_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetFileContentAtRef_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not valid json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestGetFileContentAtRef_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
json.NewEncoder(w).Encode(map[string]string{
"content": "b2s=", // "ok" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "ok" {
t.Errorf("expected 'ok', got %q", content)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestGetCommitStatuses_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/status"):
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []map[string]string{
{
"context": "ci/build",
"state": "success",
"description": "Build passed",
"target_url": "https://ci.example.com/1",
},
},
})
case strings.Contains(r.URL.Path, "/check-runs"):
conclusion := "success"
json.NewEncoder(w).Encode(map[string]interface{}{
"total_count": 1,
"check_runs": []map[string]interface{}{
{
"name": "lint",
"conclusion": &conclusion,
"status": "completed",
"html_url": "https://github.com/check/1",
},
},
})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
w.WriteHeader(404)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(statuses) != 2 {
t.Fatalf("expected 2 statuses, got %d", len(statuses))
}
// First should be from commit statuses
if statuses[0].Context != "ci/build" {
t.Errorf("expected context 'ci/build', got %q", statuses[0].Context)
}
if statuses[0].Status != "success" {
t.Errorf("expected status 'success', got %q", statuses[0].Status)
}
// Second should be from check runs
if statuses[1].Context != "lint" {
t.Errorf("expected context 'lint', got %q", statuses[1].Context)
}
if statuses[1].Status != "success" {
t.Errorf("expected status 'success', got %q", statuses[1].Status)
}
}
func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
tests := []struct {
conclusion *string
status string
want string
}{
{stringPtr("success"), "completed", "success"},
{stringPtr("failure"), "completed", "failure"},
{stringPtr("action_required"), "completed", "failure"},
{stringPtr("timed_out"), "completed", "failure"},
{stringPtr("cancelled"), "completed", "success"},
{stringPtr("skipped"), "completed", "success"},
{stringPtr("neutral"), "completed", "success"},
{nil, "in_progress", "pending"},
{nil, "queued", "pending"},
}
for _, tt := range tests {
name := "nil"
if tt.conclusion != nil {
name = *tt.conclusion
}
t.Run(name, func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/status") {
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []interface{}{},
})
return
}
json.NewEncoder(w).Encode(map[string]interface{}{
"total_count": 1,
"check_runs": []map[string]interface{}{
{
"name": "check",
"conclusion": tt.conclusion,
"status": tt.status,
"html_url": "https://github.com/check/1",
},
},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(statuses) != 1 {
t.Fatalf("expected 1 status, got %d", len(statuses))
}
if statuses[0].Status != tt.want {
t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status)
}
})
}
}
func TestGetCommitStatuses_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha")
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetCommitStatuses_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func stringPtr(s string) *string {
return &s
}
+17 -4
View File
@@ -4,19 +4,32 @@ import (
"context" "context"
"log/slog" "log/slog"
"strings" "strings"
"gitea.weiker.me/rodin/review-bot/vcs"
) )
// RepoPersonaPath is the directory path where repo-specific personas are stored. // RepoPersonaPath is the directory path where repo-specific personas are stored.
const RepoPersonaPath = ".review-bot/personas" const RepoPersonaPath = ".review-bot/personas"
// GiteaClient defines the subset of gitea.Client methods needed for loading repo personas.
// This interface allows for easier testing and decouples the review package from gitea.
type GiteaClient interface {
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
}
// ContentEntry represents a file or directory entry from the contents API.
// This mirrors gitea.ContentEntry to avoid import cycles.
type ContentEntry struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"` // "file" or "dir"
}
// LoadRepoPersonas fetches personas from a repository's .review-bot/personas/ directory. // LoadRepoPersonas fetches personas from a repository's .review-bot/personas/ directory.
// Returns an empty map (not nil) if the directory doesn't exist or is empty. // Returns an empty map (not nil) if the directory doesn't exist or is empty.
// Individual parse failures are logged and skipped; the remaining personas are still returned. // Individual parse failures are logged and skipped; the remaining personas are still returned.
// Auth errors and other non-404 errors are propagated. // Auth errors and other non-404 errors are propagated.
// Files exceeding MaxPersonaFileSize are rejected to prevent resource exhaustion. // Files exceeding MaxPersonaFileSize are rejected to prevent resource exhaustion.
func LoadRepoPersonas(ctx context.Context, client vcs.FileReader, owner, repo string) (map[string]*Persona, error) { func LoadRepoPersonas(ctx context.Context, client GiteaClient, owner, repo string) (map[string]*Persona, error) {
result := make(map[string]*Persona) result := make(map[string]*Persona)
entries, err := client.ListContents(ctx, owner, repo, RepoPersonaPath) entries, err := client.ListContents(ctx, owner, repo, RepoPersonaPath)
@@ -44,7 +57,7 @@ func LoadRepoPersonas(ctx context.Context, client vcs.FileReader, owner, repo st
continue continue
} }
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "") content, err := client.GetFileContent(ctx, owner, repo, entry.Path)
if err != nil { if err != nil {
slog.Warn("could not fetch repo persona file", slog.Warn("could not fetch repo persona file",
"file", entry.Path, "file", entry.Path,
+62 -31
View File
@@ -5,21 +5,23 @@ import (
"errors" "errors"
"strings" "strings"
"testing" "testing"
"gitea.weiker.me/rodin/review-bot/vcs"
) )
func TestParsePersonaBytes(t *testing.T) { func TestParsePersonaBytes(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
data string data string
source string source string
wantName string wantName string
wantErr string wantErr string
}{ }{
{ {
name: "valid yaml", name: "valid yaml",
data: "name: test\nidentity: test identity\nfocus:\n - testing\n", data: `name: test
identity: test identity
focus:
- testing
`,
source: "test.yaml", source: "test.yaml",
wantName: "test", wantName: "test",
}, },
@@ -36,8 +38,8 @@ func TestParsePersonaBytes(t *testing.T) {
wantErr: "parse", wantErr: "parse",
}, },
{ {
name: "json format by extension", name: "json format by extension",
data: `{"name": "jsontest", "identity": "json identity"}`, data: `{"name": "jsontest", "identity": "json identity"}`,
source: "test.json", source: "test.json",
wantName: "jsontest", wantName: "jsontest",
}, },
@@ -65,15 +67,15 @@ func TestParsePersonaBytes(t *testing.T) {
} }
} }
// mockGiteaClient implements vcs.FileReader for testing. // mockGiteaClient implements GiteaClient for testing.
type mockGiteaClient struct { type mockGiteaClient struct {
contents map[string][]vcs.ContentEntry // path -> entries contents map[string][]ContentEntry // path -> entries
files map[string]string // path -> content files map[string]string // path -> content
listErr error listErr error
fileErr map[string]error // path -> error fileErr map[string]error // path -> error
} }
func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
if m.listErr != nil { if m.listErr != nil {
return nil, m.listErr return nil, m.listErr
} }
@@ -84,7 +86,7 @@ func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path st
return entries, nil return entries, nil
} }
func (m *mockGiteaClient) GetFileContent(ctx context.Context, owner, repo, filepath, ref string) (string, error) { func (m *mockGiteaClient) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
if m.fileErr != nil { if m.fileErr != nil {
if err, ok := m.fileErr[filepath]; ok { if err, ok := m.fileErr[filepath]; ok {
return "", err return "", err
@@ -116,7 +118,7 @@ func TestLoadRepoPersonas(t *testing.T) {
t.Run("empty directory returns empty map", func(t *testing.T) { t.Run("empty directory returns empty map", func(t *testing.T) {
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: {}, RepoPersonaPath: {},
}, },
} }
@@ -131,15 +133,27 @@ func TestLoadRepoPersonas(t *testing.T) {
t.Run("loads valid personas", func(t *testing.T) { t.Run("loads valid personas", func(t *testing.T) {
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: { RepoPersonaPath: {
{Name: "trading.yaml", Path: ".review-bot/personas/trading.yaml", Type: "file"}, {Name: "trading.yaml", Path: ".review-bot/personas/trading.yaml", Type: "file"},
{Name: "crypto.yaml", Path: ".review-bot/personas/crypto.yaml", Type: "file"}, {Name: "crypto.yaml", Path: ".review-bot/personas/crypto.yaml", Type: "file"},
}, },
}, },
files: map[string]string{ files: map[string]string{
".review-bot/personas/trading.yaml": "name: trading\ndisplay_name: Trading Expert\nidentity: You are a trading expert.\nfocus:\n - order handling\n - risk management\n", ".review-bot/personas/trading.yaml": `name: trading
".review-bot/personas/crypto.yaml": "name: crypto\ndisplay_name: Crypto Expert\nidentity: You are a cryptography expert.\nfocus:\n - key management\n - encryption\n", display_name: Trading Expert
identity: You are a trading expert.
focus:
- order handling
- risk management
`,
".review-bot/personas/crypto.yaml": `name: crypto
display_name: Crypto Expert
identity: You are a cryptography expert.
focus:
- key management
- encryption
`,
}, },
} }
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo") personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
@@ -162,14 +176,16 @@ func TestLoadRepoPersonas(t *testing.T) {
t.Run("skips invalid persona files", func(t *testing.T) { t.Run("skips invalid persona files", func(t *testing.T) {
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: { RepoPersonaPath: {
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"}, {Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
{Name: "invalid.yaml", Path: ".review-bot/personas/invalid.yaml", Type: "file"}, {Name: "invalid.yaml", Path: ".review-bot/personas/invalid.yaml", Type: "file"},
}, },
}, },
files: map[string]string{ files: map[string]string{
".review-bot/personas/valid.yaml": "name: valid\nidentity: Valid persona\n", ".review-bot/personas/valid.yaml": `name: valid
identity: Valid persona
`,
".review-bot/personas/invalid.yaml": "not valid yaml: [broken", ".review-bot/personas/invalid.yaml": "not valid yaml: [broken",
}, },
} }
@@ -177,6 +193,7 @@ func TestLoadRepoPersonas(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
// Should have the valid one, skip the invalid
if len(personas) != 1 { if len(personas) != 1 {
t.Fatalf("expected 1 persona (skipped invalid), got %d", len(personas)) t.Fatalf("expected 1 persona (skipped invalid), got %d", len(personas))
} }
@@ -187,7 +204,7 @@ func TestLoadRepoPersonas(t *testing.T) {
t.Run("skips non-yaml files", func(t *testing.T) { t.Run("skips non-yaml files", func(t *testing.T) {
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: { RepoPersonaPath: {
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"}, {Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
{Name: "README.md", Path: ".review-bot/personas/README.md", Type: "file"}, {Name: "README.md", Path: ".review-bot/personas/README.md", Type: "file"},
@@ -195,8 +212,10 @@ func TestLoadRepoPersonas(t *testing.T) {
}, },
}, },
files: map[string]string{ files: map[string]string{
".review-bot/personas/persona.yaml": "name: test\nidentity: Test persona\n", ".review-bot/personas/persona.yaml": `name: test
".review-bot/personas/README.md": "# Personas\n\nPut your personas here.", identity: Test persona
`,
".review-bot/personas/README.md": "# Personas\n\nPut your personas here.",
}, },
} }
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo") personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
@@ -210,14 +229,16 @@ func TestLoadRepoPersonas(t *testing.T) {
t.Run("skips subdirectories", func(t *testing.T) { t.Run("skips subdirectories", func(t *testing.T) {
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: { RepoPersonaPath: {
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"}, {Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
{Name: "subdir", Path: ".review-bot/personas/subdir", Type: "dir"}, {Name: "subdir", Path: ".review-bot/personas/subdir", Type: "dir"},
}, },
}, },
files: map[string]string{ files: map[string]string{
".review-bot/personas/persona.yaml": "name: test\nidentity: Test persona\n", ".review-bot/personas/persona.yaml": `name: test
identity: Test persona
`,
}, },
} }
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo") personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
@@ -244,14 +265,16 @@ func TestLoadRepoPersonas(t *testing.T) {
t.Run("skips files that fail to fetch", func(t *testing.T) { t.Run("skips files that fail to fetch", func(t *testing.T) {
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: { RepoPersonaPath: {
{Name: "good.yaml", Path: ".review-bot/personas/good.yaml", Type: "file"}, {Name: "good.yaml", Path: ".review-bot/personas/good.yaml", Type: "file"},
{Name: "bad.yaml", Path: ".review-bot/personas/bad.yaml", Type: "file"}, {Name: "bad.yaml", Path: ".review-bot/personas/bad.yaml", Type: "file"},
}, },
}, },
files: map[string]string{ files: map[string]string{
".review-bot/personas/good.yaml": "name: good\nidentity: Good persona\n", ".review-bot/personas/good.yaml": `name: good
identity: Good persona
`,
}, },
fileErr: map[string]error{ fileErr: map[string]error{
".review-bot/personas/bad.yaml": errors.New("HTTP 500: internal server error"), ".review-bot/personas/bad.yaml": errors.New("HTTP 500: internal server error"),
@@ -267,23 +290,27 @@ func TestLoadRepoPersonas(t *testing.T) {
}) })
t.Run("skips oversized files", func(t *testing.T) { t.Run("skips oversized files", func(t *testing.T) {
// Create a content string that exceeds MaxPersonaFileSize (64KB)
oversizedContent := strings.Repeat("a", MaxPersonaFileSize+1) oversizedContent := strings.Repeat("a", MaxPersonaFileSize+1)
client := &mockGiteaClient{ client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{ contents: map[string][]ContentEntry{
RepoPersonaPath: { RepoPersonaPath: {
{Name: "normal.yaml", Path: ".review-bot/personas/normal.yaml", Type: "file"}, {Name: "normal.yaml", Path: ".review-bot/personas/normal.yaml", Type: "file"},
{Name: "huge.yaml", Path: ".review-bot/personas/huge.yaml", Type: "file"}, {Name: "huge.yaml", Path: ".review-bot/personas/huge.yaml", Type: "file"},
}, },
}, },
files: map[string]string{ files: map[string]string{
".review-bot/personas/normal.yaml": "name: normal\nidentity: Normal sized persona\n", ".review-bot/personas/normal.yaml": `name: normal
".review-bot/personas/huge.yaml": oversizedContent, identity: Normal sized persona
`,
".review-bot/personas/huge.yaml": oversizedContent,
}, },
} }
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo") personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
// Should have the normal one, skip the oversized
if len(personas) != 1 { if len(personas) != 1 {
t.Fatalf("expected 1 persona (skipped oversized), got %d", len(personas)) t.Fatalf("expected 1 persona (skipped oversized), got %d", len(personas))
} }
@@ -343,6 +370,7 @@ func TestGetBuiltinPersonasMap(t *testing.T) {
t.Fatal("expected at least one built-in persona") t.Fatal("expected at least one built-in persona")
} }
// Verify expected personas exist
expected := []string{"security", "architect", "docs"} expected := []string{"security", "architect", "docs"}
for _, name := range expected { for _, name := range expected {
if personas[name] == nil { if personas[name] == nil {
@@ -350,6 +378,7 @@ func TestGetBuiltinPersonasMap(t *testing.T) {
} }
} }
// Verify personas are valid
for name, p := range personas { for name, p := range personas {
if p.Name != name { if p.Name != name {
t.Errorf("persona %q has mismatched name %q", name, p.Name) t.Errorf("persona %q has mismatched name %q", name, p.Name)
@@ -393,6 +422,8 @@ func TestIsNotFoundError(t *testing.T) {
{nil, false}, {nil, false},
{errors.New("HTTP 404: not found"), true}, {errors.New("HTTP 404: not found"), true},
{errors.New("HTTP 404"), true}, {errors.New("HTTP 404"), true},
// Intentionally false: generic "not found" could mask auth/transport errors.
// Only explicit HTTP 404 responses should be treated as "directory doesn't exist".
{errors.New("something not found"), false}, {errors.New("something not found"), false},
{errors.New("HTTP 401: unauthorized"), false}, {errors.New("HTTP 401: unauthorized"), false},
{errors.New("connection refused"), false}, {errors.New("connection refused"), false},
-3
View File
@@ -10,8 +10,6 @@ type PRReader interface {
GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error) GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error)
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error)
GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error)
GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error)
} }
// FileReader can fetch file contents and list directory entries. // FileReader can fetch file contents and list directory entries.
@@ -25,7 +23,6 @@ type Reviewer interface {
PostReview(ctx context.Context, owner, repo string, number int, req ReviewRequest) (*Review, error) PostReview(ctx context.Context, owner, repo string, number int, req ReviewRequest) (*Review, error)
ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error) ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error)
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error
} }
// Identity can report who the authenticated user is. // Identity can report who the authenticated user is.
+3 -19
View File
@@ -15,11 +15,6 @@ const (
ReviewEventComment ReviewEvent = "COMMENT" ReviewEventComment ReviewEvent = "COMMENT"
) )
// BaseRef identifies the target branch of a pull request.
type BaseRef struct {
Ref string `json:"ref"`
}
// HeadRef identifies the source branch and latest commit of a pull request. // HeadRef identifies the source branch and latest commit of a pull request.
type HeadRef struct { type HeadRef struct {
SHA string `json:"sha"` SHA string `json:"sha"`
@@ -33,18 +28,15 @@ type UserInfo struct {
// PullRequest holds relevant PR metadata. // PullRequest holds relevant PR metadata.
type PullRequest struct { type PullRequest struct {
Number int `json:"number"` Title string `json:"title"`
Title string `json:"title"` Body string `json:"body"`
Body string `json:"body"` Head HeadRef `json:"head"`
Head HeadRef `json:"head"`
Base BaseRef `json:"base"`
} }
// ChangedFile represents a file modified in a PR. // ChangedFile represents a file modified in a PR.
type ChangedFile struct { type ChangedFile struct {
Filename string `json:"filename"` Filename string `json:"filename"`
Status string `json:"status"` Status string `json:"status"`
Patch string `json:"patch"`
} }
// ContentEntry represents a file or directory entry from the contents API. // ContentEntry represents a file or directory entry from the contents API.
@@ -54,14 +46,6 @@ type ContentEntry struct {
Type string `json:"type"` // "file" or "dir" Type string `json:"type"` // "file" or "dir"
} }
// CommitStatus represents a single CI status entry for a commit.
type CommitStatus struct {
Status string `json:"status"`
Context string `json:"context"`
Description string `json:"description"`
TargetURL string `json:"target_url"`
}
// Review represents a pull request review. // Review represents a pull request review.
type Review struct { type Review struct {
ID int64 `json:"id"` ID int64 `json:"id"`
+62 -113
View File
@@ -7,86 +7,51 @@ import (
"strings" "strings"
) )
const (
// maxFilesInPath is the maximum number of files GetAllFilesInPath will fetch.
// Prevents unbounded resource consumption on very large directory trees.
maxFilesInPath = 10000
// maxTotalBytesInPath is the maximum total bytes GetAllFilesInPath will accumulate.
// Prevents memory exhaustion when fetching large repositories.
maxTotalBytesInPath = 100 * 1024 * 1024 // 100 MB
)
// GetAllFilesInPath recursively fetches all file contents under a path using the // GetAllFilesInPath recursively fetches all file contents under a path using the
// provided FileReader. Returns a map of filepath -> content for all files found. // provided FileReader. Returns a map of filepath -> content for all files found.
// If the path points to an empty directory, returns an empty map. // If the path points to an empty directory, returns an empty map.
//
// This function uses fail-fast error handling: any error from ListContents or
// GetFileContent aborts the entire traversal and returns the error immediately.
// This differs from gitea.Client.GetAllFilesInPath, which logs errors and continues.
// The fail-fast contract ensures callers can trust that a nil error means all files
// were successfully fetched.
//
// Resource limits: the traversal is bounded by maxFilesInPath (file count) and
// maxTotalBytesInPath (total accumulated bytes). The context is checked before each
// recursive call and file fetch to respect cancellation.
func GetAllFilesInPath(ctx context.Context, client FileReader, owner, repo, path string) (map[string]string, error) { func GetAllFilesInPath(ctx context.Context, client FileReader, owner, repo, path string) (map[string]string, error) {
results := make(map[string]string) results := make(map[string]string)
totalBytes := 0
var walk func(string) error entries, err := client.ListContents(ctx, owner, repo, path)
walk = func(dir string) error { if err != nil {
if err := ctx.Err(); err != nil { return nil, fmt.Errorf("list contents %q: %w", path, err)
return fmt.Errorf("context canceled during traversal: %w", err)
}
entries, err := client.ListContents(ctx, owner, repo, dir)
if err != nil {
return fmt.Errorf("list contents %q: %w", dir, err)
}
for _, entry := range entries {
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled during traversal: %w", err)
}
switch entry.Type {
case "file":
if len(results) >= maxFilesInPath {
return fmt.Errorf("exceeded max file count (%d) in path %q", maxFilesInPath, path)
}
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "")
if err != nil {
return fmt.Errorf("get file %q: %w", entry.Path, err)
}
totalBytes += len(content)
if totalBytes > maxTotalBytesInPath {
return fmt.Errorf("exceeded max total bytes (%d) in path %q", maxTotalBytesInPath, path)
}
results[entry.Path] = content
case "dir":
if err := walk(entry.Path); err != nil {
return err
}
}
}
return nil
} }
if err := walk(path); err != nil { for _, entry := range entries {
return nil, err switch entry.Type {
case "file":
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "")
if err != nil {
return nil, fmt.Errorf("get file %q: %w", entry.Path, err)
}
results[entry.Path] = content
case "dir":
subResults, err := GetAllFilesInPath(ctx, client, owner, repo, entry.Path)
if err != nil {
return nil, fmt.Errorf("recurse into %q: %w", entry.Path, err)
}
for k, v := range subResults {
results[k] = v
}
}
} }
return results, nil return results, nil
} }
// BuildLineToPositionMap parses a unified diff and returns a map of // BuildLineToPositionMap parses a unified diff and returns a per-file map of
// filename -> (new line number -> diff position). The diff position is a // new-file line number diff-position.
// 1-indexed offset from the @@ hunk header line for each file. //
// Only lines that appear in the new file (context lines and additions) are mapped. // Position is a 1-indexed offset from the @@ hunk line (GitHub convention):
// Deletion-only lines are not included. // - The @@ hunk header itself counts as position 1 (for the first hunk)
// - Every subsequent content line (context, addition, deletion) increments position
// - A second @@ hunk in the same file continues the count; it does not reset
// - Addition and context lines have new-file line numbers and are mapped
// - Deletion lines have a position but no new-file line; they are not in the map
// - "\ No newline at end of file" markers do not count as a position
//
// Returns: map[filename]map[newFileLine]diffPosition
func BuildLineToPositionMap(diff string) map[string]map[int]int { func BuildLineToPositionMap(diff string) map[string]map[int]int {
result := make(map[string]map[int]int) result := make(map[string]map[int]int)
@@ -96,7 +61,7 @@ func BuildLineToPositionMap(diff string) map[string]map[int]int {
var newLine int var newLine int
for _, line := range lines { for _, line := range lines {
// Detect new file in diff // Detect new file in diff — resets position counter per file.
if strings.HasPrefix(line, "+++ b/") { if strings.HasPrefix(line, "+++ b/") {
currentFile = strings.TrimPrefix(line, "+++ b/") currentFile = strings.TrimPrefix(line, "+++ b/")
position = 0 position = 0
@@ -107,85 +72,69 @@ func BuildLineToPositionMap(diff string) map[string]map[int]int {
continue continue
} }
// Skip --- lines (old file header) // Skip deleted-file target (no new lines to map).
if strings.HasPrefix(line, "--- ") { if strings.HasPrefix(line, "+++ /dev/null") {
currentFile = ""
continue continue
} }
// Skip diff --git lines // Skip metadata lines that are not part of position counting.
if strings.HasPrefix(line, "diff --git") { if strings.HasPrefix(line, "diff --git") ||
strings.HasPrefix(line, "--- ") ||
strings.HasPrefix(line, "index ") ||
strings.HasPrefix(line, "\\") {
continue continue
} }
// Skip index lines // Parse hunk headers — the @@ line itself occupies a position.
if strings.HasPrefix(line, "index ") { if strings.HasPrefix(line, "@@") && currentFile != "" {
continue
}
// Parse hunk headers
if strings.HasPrefix(line, "@@") {
position++ position++
// Extract new file start line from @@ -a,b +c,d @@
newLine = parseHunkNewStart(line) newLine = parseHunkNewStart(line)
continue continue
} }
// We need a current file to map lines
if currentFile == "" { if currentFile == "" {
continue continue
} }
// Skip "\ No newline at end of file" markers — these are git diff // Content lines within a hunk.
// metadata and not part of the file content. switch {
if strings.HasPrefix(line, `\`) { case strings.HasPrefix(line, "+"):
continue // Addition: maps to both a position and a new-file line.
}
// Process diff content lines
if strings.HasPrefix(line, "+") {
position++ position++
result[currentFile][newLine] = position result[currentFile][newLine] = position
newLine++ newLine++
} else if strings.HasPrefix(line, "-") { case strings.HasPrefix(line, "-"):
// Deletion: occupies a position but has no new-file line number.
position++ position++
// Deletion lines don't map to new line numbers case strings.HasPrefix(line, " "):
} else if strings.HasPrefix(line, " ") { // Context line: maps to both a position and a new-file line.
// Context line (space-prefixed). position++
// Only map if position > 0, which means we've seen a hunk header. result[currentFile][newLine] = position
// Lines before the first hunk header (position == 0) are not part newLine++
// of any diff hunk and should be skipped.
if position > 0 {
position++
result[currentFile][newLine] = position
newLine++
}
} }
// Any other line (empty trailing lines from Split, etc.) is ignored.
} }
return result return result
} }
// parseHunkNewStart extracts the new-file starting line number from a hunk header. // parseHunkNewStart extracts the new-file starting line number from a hunk header.
// Format: @@ -old_start[,old_count] +new_start[,new_count] @@ // Format: @@ -old_start[,old_count] +new_start[,new_count] @@[ optional section heading]
func parseHunkNewStart(hunkLine string) int { func parseHunkNewStart(hunkLine string) int {
// Find the +N part // Find the +N part after the first @@
plusIdx := strings.Index(hunkLine, "+") plusIdx := strings.Index(hunkLine, "+")
if plusIdx < 0 { if plusIdx < 0 {
return 1 return 1
} }
rest := hunkLine[plusIdx+1:] rest := hunkLine[plusIdx+1:]
// Find the end of the number (first non-digit after +) // Read digits until comma, space, or end.
endIdx := 0 if idx := strings.IndexAny(rest, ", @"); idx > 0 {
for endIdx < len(rest) && rest[endIdx] >= '0' && rest[endIdx] <= '9' { rest = rest[:idx]
endIdx++
} }
if endIdx == 0 { n, err := strconv.Atoi(rest)
return 1
}
n, err := strconv.Atoi(rest[:endIdx])
if err != nil { if err != nil {
return 1 return 1
} }
+423 -298
View File
@@ -3,7 +3,6 @@ package vcs_test
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"testing" "testing"
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
@@ -11,321 +10,447 @@ import (
// mockFileReader implements vcs.FileReader for testing. // mockFileReader implements vcs.FileReader for testing.
type mockFileReader struct { type mockFileReader struct {
contents map[string][]vcs.ContentEntry // path -> entries contents map[string]string // path -> content
files map[string]string // path -> content dirs map[string][]vcs.ContentEntry // path -> entries
} }
func (m *mockFileReader) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { func (m *mockFileReader) GetFileContent(_ context.Context, _, _, path, _ string) (string, error) {
content, ok := m.files[path] content, ok := m.contents[path]
if !ok { if !ok {
return "", fmt.Errorf("HTTP 404: file not found: %s", path) return "", fmt.Errorf("file not found: %s", path)
} }
return content, nil return content, nil
} }
func (m *mockFileReader) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { func (m *mockFileReader) ListContents(_ context.Context, _, _, path string) ([]vcs.ContentEntry, error) {
entries, ok := m.contents[path] entries, ok := m.dirs[path]
if !ok { if !ok {
return nil, fmt.Errorf("HTTP 404: path not found: %s", path) return nil, fmt.Errorf("directory not found: %s", path)
} }
return entries, nil return entries, nil
} }
func TestGetAllFilesInPath(t *testing.T) { // --- GetAllFilesInPath tests ---
ctx := context.Background()
t.Run("empty directory", func(t *testing.T) { func TestGetAllFilesInPath_EmptyDir(t *testing.T) {
client := &mockFileReader{ mock := &mockFileReader{
contents: map[string][]vcs.ContentEntry{ dirs: map[string][]vcs.ContentEntry{
"src": {}, "src": {},
}, },
} }
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 0 {
t.Errorf("expected empty map, got %d entries", len(result))
}
})
t.Run("flat directory", func(t *testing.T) { result, err := vcs.GetAllFilesInPath(context.Background(), mock, "owner", "repo", "src")
client := &mockFileReader{ if err != nil {
contents: map[string][]vcs.ContentEntry{ t.Fatalf("unexpected error: %v", err)
"src": { }
{Name: "main.go", Path: "src/main.go", Type: "file"}, if len(result) != 0 {
{Name: "util.go", Path: "src/util.go", Type: "file"}, t.Errorf("expected empty map, got %d entries", len(result))
}, }
},
files: map[string]string{
"src/main.go": "package main",
"src/util.go": "package main\n// util",
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 2 {
t.Fatalf("expected 2 files, got %d", len(result))
}
if result["src/main.go"] != "package main" {
t.Errorf("main.go content = %q", result["src/main.go"])
}
if result["src/util.go"] != "package main\n// util" {
t.Errorf("util.go content = %q", result["src/util.go"])
}
})
t.Run("nested directories", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {
{Name: "main.go", Path: "src/main.go", Type: "file"},
{Name: "pkg", Path: "src/pkg", Type: "dir"},
},
"src/pkg": {
{Name: "lib.go", Path: "src/pkg/lib.go", Type: "file"},
{Name: "sub", Path: "src/pkg/sub", Type: "dir"},
},
"src/pkg/sub": {
{Name: "deep.go", Path: "src/pkg/sub/deep.go", Type: "file"},
},
},
files: map[string]string{
"src/main.go": "package main",
"src/pkg/lib.go": "package pkg",
"src/pkg/sub/deep.go": "package sub",
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 3 {
t.Fatalf("expected 3 files, got %d", len(result))
}
if result["src/main.go"] != "package main" {
t.Errorf("main.go content = %q", result["src/main.go"])
}
if result["src/pkg/lib.go"] != "package pkg" {
t.Errorf("lib.go content = %q", result["src/pkg/lib.go"])
}
if result["src/pkg/sub/deep.go"] != "package sub" {
t.Errorf("deep.go content = %q", result["src/pkg/sub/deep.go"])
}
})
t.Run("mixed files and dirs", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"root": {
{Name: "README.md", Path: "root/README.md", Type: "file"},
{Name: "docs", Path: "root/docs", Type: "dir"},
{Name: "config.yaml", Path: "root/config.yaml", Type: "file"},
},
"root/docs": {
{Name: "guide.md", Path: "root/docs/guide.md", Type: "file"},
},
},
files: map[string]string{
"root/README.md": "# Hello",
"root/config.yaml": "key: value",
"root/docs/guide.md": "## Guide",
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "root")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 3 {
t.Fatalf("expected 3 files, got %d", len(result))
}
if result["root/README.md"] != "# Hello" {
t.Errorf("README content = %q", result["root/README.md"])
}
if result["root/docs/guide.md"] != "## Guide" {
t.Errorf("guide content = %q", result["root/docs/guide.md"])
}
})
} }
func TestBuildLineToPositionMap(t *testing.T) { func TestGetAllFilesInPath_FlatDir(t *testing.T) {
t.Run("single hunk", func(t *testing.T) { mock := &mockFileReader{
diff := "diff --git a/file.go b/file.go\nindex abc..def 100644\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n package main\n \n+// new comment\n func main() {}\n" dirs: map[string][]vcs.ContentEntry{
result := vcs.BuildLineToPositionMap(diff) "src": {
fileMap, ok := result["file.go"] {Name: "main.go", Path: "src/main.go", Type: "file"},
if !ok { {Name: "util.go", Path: "src/util.go", Type: "file"},
t.Fatal("expected file.go in result") },
} },
// Hunk header @@ is position 1 contents: map[string]string{
// Line 1: " package main" -> position 2 "src/main.go": "package main",
if fileMap[1] != 2 { "src/util.go": "package main\n\nfunc helper() {}",
t.Errorf("line 1 position = %d, want 2", fileMap[1]) },
} }
// Line 2: " " (context) -> position 3
if fileMap[2] != 3 {
t.Errorf("line 2 position = %d, want 3", fileMap[2])
}
// Line 3: "+// new comment" -> position 4
if fileMap[3] != 4 {
t.Errorf("line 3 position = %d, want 4", fileMap[3])
}
// Line 4: " func main() {}" -> position 5
if fileMap[4] != 5 {
t.Errorf("line 4 position = %d, want 5", fileMap[4])
}
})
t.Run("multi hunk", func(t *testing.T) { result, err := vcs.GetAllFilesInPath(context.Background(), mock, "owner", "repo", "src")
diff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,3 @@\n package main\n \n-// old\n+// new\n@@ -10,3 +10,4 @@\n func foo() {\n+\t// added\n \treturn\n }\n" if err != nil {
result := vcs.BuildLineToPositionMap(diff) t.Fatalf("unexpected error: %v", err)
fileMap, ok := result["file.go"] }
if !ok { if len(result) != 2 {
t.Fatal("expected file.go in result") t.Fatalf("expected 2 files, got %d", len(result))
} }
// First hunk: @@ is position 1 if result["src/main.go"] != "package main" {
// Line 1: " package main" -> position 2 t.Errorf("unexpected content for main.go: %q", result["src/main.go"])
if fileMap[1] != 2 { }
t.Errorf("line 1 position = %d, want 2", fileMap[1]) if result["src/util.go"] != "package main\n\nfunc helper() {}" {
} t.Errorf("unexpected content for util.go: %q", result["src/util.go"])
// Line 3: "+// new" -> position 5 (after " ", "-// old" at pos 3,4) }
if fileMap[3] != 5 {
t.Errorf("line 3 position = %d, want 5", fileMap[3])
}
// Second hunk: @@ is position 6
// Line 10: " func foo() {" -> position 7
if fileMap[10] != 7 {
t.Errorf("line 10 position = %d, want 7", fileMap[10])
}
// Line 11: "+\t// added" -> position 8
if fileMap[11] != 8 {
t.Errorf("line 11 position = %d, want 8", fileMap[11])
}
})
t.Run("deletion lines not in map", func(t *testing.T) {
diff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,4 +1,3 @@\n package main\n \n-// deleted line\n func main() {}\n"
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["file.go"]
if !ok {
t.Fatal("expected file.go in result")
}
// Line 1: " package main" -> position 2
if fileMap[1] != 2 {
t.Errorf("line 1 position = %d, want 2", fileMap[1])
}
// Line 3 in new file: " func main() {}" -> position 5 (after deletion at pos 4)
if fileMap[3] != 5 {
t.Errorf("line 3 position = %d, want 5", fileMap[3])
}
// Should only have 3 entries (lines 1, 2, 3 of new file)
if len(fileMap) != 3 {
t.Errorf("expected 3 mapped lines, got %d: %v", len(fileMap), fileMap)
}
})
t.Run("multiple files", func(t *testing.T) {
diff := "diff --git a/a.go b/a.go\n--- a/a.go\n+++ b/a.go\n@@ -1,2 +1,3 @@\n package a\n \n+// file a\ndiff --git a/b.go b/b.go\n--- a/b.go\n+++ b/b.go\n@@ -1,2 +1,3 @@\n package b\n \n+// file b\n"
result := vcs.BuildLineToPositionMap(diff)
if len(result) != 2 {
t.Fatalf("expected 2 files, got %d", len(result))
}
aMap, ok := result["a.go"]
if !ok {
t.Fatal("expected a.go in result")
}
bMap, ok := result["b.go"]
if !ok {
t.Fatal("expected b.go in result")
}
// a.go line 3: "+// file a" -> position 4
if aMap[3] != 4 {
t.Errorf("a.go line 3 position = %d, want 4", aMap[3])
}
// b.go line 3: "+// file b" -> position 4
if bMap[3] != 4 {
t.Errorf("b.go line 3 position = %d, want 4", bMap[3])
}
})
} }
func TestGetAllFilesInPath_ErrorPropagation(t *testing.T) { func TestGetAllFilesInPath_NestedDirs(t *testing.T) {
ctx := context.Background() mock := &mockFileReader{
dirs: map[string][]vcs.ContentEntry{
"src": {
{Name: "main.go", Path: "src/main.go", Type: "file"},
{Name: "pkg", Path: "src/pkg", Type: "dir"},
},
"src/pkg": {
{Name: "lib.go", Path: "src/pkg/lib.go", Type: "file"},
{Name: "internal", Path: "src/pkg/internal", Type: "dir"},
},
"src/pkg/internal": {
{Name: "helper.go", Path: "src/pkg/internal/helper.go", Type: "file"},
},
},
contents: map[string]string{
"src/main.go": "package main",
"src/pkg/lib.go": "package pkg",
"src/pkg/internal/helper.go": "package internal",
},
}
t.Run("ListContents error propagates", func(t *testing.T) { result, err := vcs.GetAllFilesInPath(context.Background(), mock, "owner", "repo", "src")
client := &mockFileReader{ if err != nil {
contents: map[string][]vcs.ContentEntry{ t.Fatalf("unexpected error: %v", err)
// "src" not in map, so ListContents will fail }
}, if len(result) != 3 {
} t.Fatalf("expected 3 files, got %d", len(result))
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src") }
if err == nil { if result["src/main.go"] != "package main" {
t.Fatal("expected error, got nil") t.Errorf("unexpected content for main.go")
} }
if !strings.Contains(err.Error(), "list contents") { if result["src/pkg/lib.go"] != "package pkg" {
t.Errorf("expected error about list contents, got: %v", err) t.Errorf("unexpected content for lib.go")
} }
}) if result["src/pkg/internal/helper.go"] != "package internal" {
t.Errorf("unexpected content for helper.go")
t.Run("GetFileContent error propagates", func(t *testing.T) { }
client := &mockFileReader{ }
contents: map[string][]vcs.ContentEntry{
"src": { func TestGetAllFilesInPath_Mixed(t *testing.T) {
{Name: "main.go", Path: "src/main.go", Type: "file"}, mock := &mockFileReader{
}, dirs: map[string][]vcs.ContentEntry{
}, "root": {
files: map[string]string{ {Name: "README.md", Path: "root/README.md", Type: "file"},
// "src/main.go" not in files map, so GetFileContent will fail {Name: "empty", Path: "root/empty", Type: "dir"},
}, {Name: "docs", Path: "root/docs", Type: "dir"},
} },
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src") "root/empty": {},
if err == nil { "root/docs": {
t.Fatal("expected error, got nil") {Name: "guide.md", Path: "root/docs/guide.md", Type: "file"},
} },
if !strings.Contains(err.Error(), "get file") { },
t.Errorf("expected error about get file, got: %v", err) contents: map[string]string{
} "root/README.md": "# Hello",
}) "root/docs/guide.md": "## Guide",
},
t.Run("nested ListContents error propagates", func(t *testing.T) { }
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{ result, err := vcs.GetAllFilesInPath(context.Background(), mock, "owner", "repo", "root")
"src": { if err != nil {
{Name: "pkg", Path: "src/pkg", Type: "dir"}, t.Fatalf("unexpected error: %v", err)
}, }
// "src/pkg" not in map, so recursive ListContents will fail if len(result) != 2 {
}, t.Fatalf("expected 2 files, got %d", len(result))
} }
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src") if result["root/README.md"] != "# Hello" {
if err == nil { t.Errorf("unexpected content for README.md")
t.Fatal("expected error, got nil") }
} if result["root/docs/guide.md"] != "## Guide" {
if !strings.Contains(err.Error(), "list contents") { t.Errorf("unexpected content for guide.md")
t.Errorf("expected error about list contents, got: %v", err) }
} }
})
func TestGetAllFilesInPath_ListContentsError(t *testing.T) {
t.Run("canceled context propagates", func(t *testing.T) { mock := &mockFileReader{
ctx, cancel := context.WithCancel(context.Background()) dirs: map[string][]vcs.ContentEntry{},
cancel() // Cancel immediately }
client := &mockFileReader{ _, err := vcs.GetAllFilesInPath(context.Background(), mock, "owner", "repo", "missing")
contents: map[string][]vcs.ContentEntry{ if err == nil {
"src": { t.Fatal("expected error for missing directory")
{Name: "main.go", Path: "src/main.go", Type: "file"}, }
}, }
},
files: map[string]string{ func TestGetAllFilesInPath_GetFileContentError(t *testing.T) {
"src/main.go": "package main", mock := &mockFileReader{
}, dirs: map[string][]vcs.ContentEntry{
} "src": {
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src") {Name: "bad.go", Path: "src/bad.go", Type: "file"},
if err == nil { },
t.Fatal("expected error from canceled context, got nil") },
} contents: map[string]string{},
if !strings.Contains(err.Error(), "context canceled") { }
t.Errorf("expected context cancellation error, got: %v", err)
} _, err := vcs.GetAllFilesInPath(context.Background(), mock, "owner", "repo", "src")
}) if err == nil {
t.Fatal("expected error when file content fetch fails")
}
}
// --- BuildLineToPositionMap tests ---
func TestBuildLineToPositionMap_SingleHunk(t *testing.T) {
diff := `diff --git a/main.go b/main.go
index abc..def 100644
--- a/main.go
+++ b/main.go
@@ -5,7 +5,8 @@ package main
import "fmt"
func main() {
- fmt.Println("old")
+ fmt.Println("new")
+ fmt.Println("added")
return
}
`
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["main.go"]
if !ok {
t.Fatal("expected main.go in result")
}
// @@ line is position 1
// " import \"fmt\"" -> pos 2, newLine 5
// " " -> pos 3, newLine 6
// " func main() {" -> pos 4, newLine 7
// "-\tfmt.Println..." -> pos 5, no new line
// "+\tfmt.Println..." -> pos 6, newLine 8
// "+\tfmt.Println..." -> pos 7, newLine 9
// " \treturn" -> pos 8, newLine 10
// " }" -> pos 9, newLine 11
expected := map[int]int{
5: 2,
6: 3,
7: 4,
8: 6,
9: 7,
10: 8,
11: 9,
}
for line, wantPos := range expected {
gotPos, exists := fileMap[line]
if !exists {
t.Errorf("line %d: expected position %d, but line not in map", line, wantPos)
continue
}
if gotPos != wantPos {
t.Errorf("line %d: expected position %d, got %d", line, wantPos, gotPos)
}
}
if len(fileMap) != len(expected) {
t.Errorf("expected %d entries, got %d", len(expected), len(fileMap))
}
}
func TestBuildLineToPositionMap_MultiHunk(t *testing.T) {
diff := `diff --git a/main.go b/main.go
--- a/main.go
+++ b/main.go
@@ -1,3 +1,4 @@
package main
+import "fmt"
func main() {
@@ -10,3 +11,4 @@ func main() {
return
+ // extra
}
`
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["main.go"]
if !ok {
t.Fatal("expected main.go in result")
}
// First hunk:
// @@ line -> pos 1
// " package main" -> pos 2, newLine 1
// " " -> pos 3, newLine 2
// "+import..." -> pos 4, newLine 3
// " func main.." -> pos 5, newLine 4
//
// Second hunk (position continues!):
// @@ line -> pos 6
// " \treturn" -> pos 7, newLine 11
// "+\t// extra" -> pos 8, newLine 12
// " }" -> pos 9, newLine 13
expected := map[int]int{
1: 2,
2: 3,
3: 4,
4: 5,
11: 7,
12: 8,
13: 9,
}
for line, wantPos := range expected {
gotPos, exists := fileMap[line]
if !exists {
t.Errorf("line %d: expected position %d, but line not in map", line, wantPos)
continue
}
if gotPos != wantPos {
t.Errorf("line %d: expected position %d, got %d", line, wantPos, gotPos)
}
}
}
func TestBuildLineToPositionMap_DeletionLinesNotInMap(t *testing.T) {
diff := `diff --git a/old.go b/old.go
--- a/old.go
+++ b/old.go
@@ -1,3 +1,1 @@
-line one
-line two
remaining
`
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["old.go"]
if !ok {
t.Fatal("expected old.go in result")
}
// @@ line -> pos 1
// "-line one" -> pos 2, no new line
// "-line two" -> pos 3, no new line
// " remaining" -> pos 4, newLine 1
if pos, ok := fileMap[1]; !ok || pos != 4 {
t.Errorf("line 1: expected position 4, got %d (exists=%v)", pos, ok)
}
if len(fileMap) != 1 {
t.Errorf("expected 1 entry (only 'remaining'), got %d", len(fileMap))
}
}
func TestBuildLineToPositionMap_MultipleFiles(t *testing.T) {
diff := `diff --git a/a.go b/a.go
--- a/a.go
+++ b/a.go
@@ -1,2 +1,3 @@
package a
+// added
diff --git a/b.go b/b.go
new file mode 100644
--- /dev/null
+++ b/b.go
@@ -0,0 +1,3 @@
+package b
+
+func B() {}
`
result := vcs.BuildLineToPositionMap(diff)
// File a.go
aMap, ok := result["a.go"]
if !ok {
t.Fatal("expected a.go in result")
}
// @@ line -> pos 1
// " package a" -> pos 2, newLine 1
// "+// added" -> pos 3, newLine 2
// " " -> pos 4, newLine 3
if aMap[1] != 2 {
t.Errorf("a.go line 1: expected pos 2, got %d", aMap[1])
}
if aMap[2] != 3 {
t.Errorf("a.go line 2: expected pos 3, got %d", aMap[2])
}
if aMap[3] != 4 {
t.Errorf("a.go line 3: expected pos 4, got %d", aMap[3])
}
// File b.go — position resets for new file
bMap, ok := result["b.go"]
if !ok {
t.Fatal("expected b.go in result")
}
// @@ line -> pos 1
// "+package b" -> pos 2, newLine 1
// "+" -> pos 3, newLine 2
// "+func B() {}" -> pos 4, newLine 3
if bMap[1] != 2 {
t.Errorf("b.go line 1: expected pos 2, got %d", bMap[1])
}
if bMap[2] != 3 {
t.Errorf("b.go line 2: expected pos 3, got %d", bMap[2])
}
if bMap[3] != 4 {
t.Errorf("b.go line 3: expected pos 4, got %d", bMap[3])
}
}
func TestBuildLineToPositionMap_EmptyDiff(t *testing.T) {
result := vcs.BuildLineToPositionMap("")
if len(result) != 0 {
t.Errorf("expected empty map for empty diff, got %d entries", len(result))
}
}
func TestBuildLineToPositionMap_NoNewlineMarker(t *testing.T) {
diff := `diff --git a/a.go b/a.go
--- a/a.go
+++ b/a.go
@@ -1,2 +1,2 @@
-old
+new
\ No newline at end of file
`
result := vcs.BuildLineToPositionMap(diff)
fileMap := result["a.go"]
// @@ line -> pos 1
// "-old" -> pos 2
// "+new" -> pos 3, newLine 1
// "\ No.." -> not counted
if fileMap[1] != 3 {
t.Errorf("line 1: expected position 3, got %d", fileMap[1])
}
if len(fileMap) != 1 {
t.Errorf("expected 1 entry, got %d", len(fileMap))
}
}
func TestBuildLineToPositionMap_SingleLineHunk(t *testing.T) {
diff := `diff --git a/single.go b/single.go
--- a/single.go
+++ b/single.go
@@ -1 +1 @@
-old
+new
`
result := vcs.BuildLineToPositionMap(diff)
fileMap := result["single.go"]
// @@ line -> pos 1
// "-old" -> pos 2
// "+new" -> pos 3, newLine 1
if fileMap[1] != 3 {
t.Errorf("line 1: expected position 3, got %d", fileMap[1])
}
}
func TestBuildLineToPositionMap_DeletedFile(t *testing.T) {
diff := `diff --git a/deleted.go b/deleted.go
deleted file mode 100644
--- a/deleted.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package deleted
-
-func Gone() {}
`
result := vcs.BuildLineToPositionMap(diff)
if _, ok := result["deleted.go"]; ok {
t.Error("deleted file should not appear in result")
}
} }