Compare commits

..

17 Commits

Author SHA1 Message Date
claw 6316007eb1 fix: address review findings from reviews #2955 and #2958
- Convert handleResponse to package-level function (unused receiver) [#17955]
- Add clarifying comment for nil resp on transport error [#17956]
- Use consistent %w wrapping in dual-unmarshal error path [#17957]
- Add SafeError() method to APIError for safe logging [#17964]
- Enforce safe CheckRedirect policy in SetHTTPClient [#17965]
- Add tests for SafeError and SetHTTPClient enforcement
2026-05-12 21:19:01 -07:00
claw b380e7fcae refactor(github): extract handleResponse for safe defer body close
PR Ready Gate / clear-labels (pull_request) Successful in 1s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 40s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m16s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m44s
Address review findings #1 and #2: the response body was closed explicitly
rather than via defer, which could leak if future code paths were added.

Extract handleResponse helper method that uses defer resp.Body.Close() to
guarantee cleanup. This avoids the loop-defer antipattern (defer inside a
for loop accumulates defers until function exit) by isolating the body
handling into its own function scope.
2026-05-12 20:47:59 -07:00
claw 30798ff023 fix: address sonnet review MINOR findings (#2916)
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 18s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 46s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 59s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m6s
- client.go: fix misleading timer.Stop() comment (finding #1)
- pr.go: document all-or-nothing semantics for GetCommitStatuses
  when check-runs endpoint fails after statuses succeed (finding #2)
- files.go: include both array and object unmarshal errors in
  ListContents fallback error message (finding #3)
- pr.go: expand mapCheckRunStatus comment to explain non-blocking
  policy decision (finding #4)
2026-05-12 20:28:52 -07:00
claw 6e8e744816 fix(github): address self-review findings from 1194bc75
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 51s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m22s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m36s
- Handle io.ReadAll error on error body read (client.go:265)
- Remove unused State field from commitStatusResponse (pr.go)
- Guard via slice access in defaultCheckRedirect (client.go:117)
- Move GetFileContentAtRef from pr.go to files.go (logical home)
2026-05-12 19:40:30 -07:00
claw 1194bc758c fix(github): address review findings from rounds 2884/2885/2887
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 18s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 40s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m18s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m44s
- Fix response body limit check: read maxResponseBytes+1 and use > to
  distinguish exactly-at-limit from truncated (sonnet finding #1)
- Reject HTTPS→HTTP redirects outright instead of stripping auth and
  following; prevents plaintext metadata leakage (sonnet #2, security #1)
- Sanitize newlines in APIError.Error to prevent log injection from
  upstream response bodies (security #2)
- Add nil-return documentation to GetCommitStatuses (sonnet #3)
- Gate TestDoRequest_429RetryAfterHTTPDate behind testing.Short (sonnet #6)
- Add tests for redirect policy, exact-at-limit body, and error sanitization
2026-05-12 19:29:06 -07:00
claw 80af5037b2 fix(github): address review findings from round 2880/2883
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 24s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 43s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m16s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m21s
Sonnet MINOR #1: Stop timer after <-timer.C fires for idiomatic cleanup.
Sonnet MINOR #2: Document that empty array from contents API is valid (empty dir).
Sonnet MINOR #3: Document that GetPullRequestFiles returns nil for no files.
Sonnet NIT #4: Strengthen SetHTTPClient/SetRetryBackoff docs to clarify test-only intent.
Sonnet NIT #5: Document GetCommitStatuses fail-fast behavior.
Sonnet NIT #6: Document double-slash collapsing in escapePath.
Security MINOR #1: Document redirect policy responsibility when providing custom client.
Security MINOR #2: Reduce maxErrorBodyBytes from 64KB to 4KB to limit sensitive data exposure.
2026-05-12 18:41:44 -07:00
claw 5b2fa0b9af refactor(github): address review findings from round 2872
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 36s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m31s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m53s
- client.go: clarify timer drain comment (finding #1)
- client.go: rename t -> retryAt for time.Time clarity (finding #2)
- pr.go: remove dead _ string parameter from mapCheckRunStatus (finding #3)
- files.go: add inline comment explaining zero-value guard (finding #4)

Findings #5 (NIT, no code change) and #6 (NIT, defer vs t.Cleanup
in t.Run closures) pushed back — see PR comment.
2026-05-12 18:16:43 -07:00
claw 491df7cb1f fix(github): address review findings from rounds 2867/2870
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 18s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 41s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m20s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m43s
- Extract duplicated CheckRedirect lambda to defaultCheckRedirect function
  (sonnet #1: eliminate duplication between NewClient and SetHTTPClient)
- Remove unnecessary int64 cast in response size check (sonnet #3)
- Validate fallback unmarshal in ListContents to reject zero-value entries
  (sonnet #5: prevent accepting unexpected JSON formats silently)
- Rename strPtr to stringPtr for consistency (sonnet #6)
- Add doc comment about APIError.Error body exposure (security #3)

Deferred to separate issues:
- #95: Reject cross-host redirects entirely (security #1)
- #96: Add safeguards for AllowInsecureHTTP (security #2)
2026-05-12 17:30:24 -07:00
claw 1fcc0b738a fix(github): address MINOR/NIT findings from review #2866
PR Ready Gate / clear-labels (pull_request) Successful in 1s
CI / test (pull_request) Successful in 18s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 39s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m30s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m8s
- SetHTTPClient(nil): preserve CheckRedirect auth-stripping policy
  instead of restoring a plain http.Client that loses cross-host
  protection.

- Authorization header: add comment documenting why Bearer scheme is
  correct (OAuth2 standard, works for both classic PATs and
  fine-grained tokens).

- Retry-After parsing: support HTTP-date format (RFC 7231) in addition
  to integer seconds. GitHub only sends integers today, but the
  implementation is now spec-compliant.

- escapePath dot-segment removal: document the behavior in public API
  doc comments for ListContents and GetFileContentAtRef so callers are
  aware without reading the internal helper.
2026-05-12 17:13:07 -07:00
claw fce5f2d184 fix(github): address review findings on client.go
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 40s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m23s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m15s
- Use net/url.Parse for HTTPS scheme check (case-insensitive)
- Guard SetHTTPClient against nil (restores default 30s client)
- Rename 'url' param to 'reqURL' in doRequest/doGet for clarity
- Return error when response exceeds maxResponseBytes instead of
  silently truncating

Finding #1 (Bearer auth scheme) intentionally kept: GitHub REST API
officially supports and recommends Bearer for all token types.
See: https://docs.github.com/en/rest/authentication/authenticating-to-the-rest-api
2026-05-12 16:55:32 -07:00
claw af72c64b7f fix(github): correct ListContents error wrapping and move HTTPS guard before retry loop
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 42s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m11s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m11s
2026-05-12 16:48:39 -07:00
claw 1bc3f206ba fix: address review findings from rounds 2843-2846
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 41s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m13s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m23s
- Remove redundant timer.Stop() after timer fires (Sonnet #1, GPT #2)
- Remove unused TotalCount field from checkRunsResponse (Sonnet #2)
- Improve escapePath doc comment to explain deliberate silent stripping (Sonnet #3)
- Fix ListContents to handle both array (directory) and object (single file)
  responses from GitHub Contents API (GPT #3)
- Add HTTPS enforcement: refuse to send credentials over non-HTTPS URLs
  unless AllowInsecureHTTP() option is passed (Security #1)
- Replace constant-value test with actual behavior test for response
  body limiting (Sonnet #6)
- Run gofmt for consistent formatting (Sonnet #4)
- Add tests for HTTPS enforcement and ListContents single-file handling
2026-05-12 16:39:01 -07:00
claw c10bb72117 fix: address self-review NIT findings on PR #93
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 22s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 37s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m9s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m17s
- Add timer.Stop() on happy path in retry loop (idiomatic)
- Add concurrency caveat to Client doc comment for SetHTTPClient/SetRetryBackoff
- Add explicit 'stale'/'waiting' cases to mapCheckRunStatus
2026-05-12 16:25:32 -07:00
claw ae91c8aef5 fix: address review findings from rounds 2834-2838
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 49s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m6s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m19s
- Unexport RetryBackoff, add SetRetryBackoff method (#17286)
- Rename http field to httpClient to avoid shadowing (#17289)
- Group const blocks into single declaration (#17291)
- Fix CheckRedirect to compare against previous hop, not first (#17302)
- Strip auth header on protocol downgrade https→http (#17297)
- Add maxPages safeguard to pagination loops (#17299, #17300)
- Document mapCheckRunStatus unused second parameter (#17287, #17303)
2026-05-12 16:11:58 -07:00
claw 75f65fbf5d fix: address MINOR review findings on PR #93 (round 2)
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / test (pull_request) Successful in 17s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 38s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m28s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 2m50s
- Add User-Agent header to all requests (gpt-review-bot)
- Limit successful response body to 10 MiB via io.LimitReader (security-review-bot)
- Add CheckRedirect to strip Authorization on cross-host redirects (security-review-bot)
- Fix decodeBase64Content to strip both \r and \n (gpt-review-bot)
- Document that transport errors are not retried (sonnet-review-bot)
- Update package doc to reflect current scope (no review submission yet)
- Add tests for User-Agent, empty-token auth skip, CRLF base64, CheckRedirect
2026-05-12 16:00:09 -07:00
claw 5b43afc6d4 fix: address review feedback on PR #93
PR Ready Gate / clear-labels (pull_request) Successful in 1s
CI / test (pull_request) Successful in 23s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 45s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m48s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m7s
- Fix Retry-After slice mutation: copy c.RetryBackoff before modifying
  to prevent permanent mutation of the shared slice (sonnet#1, security#1)
- Cap Retry-After to 120s maximum to prevent excessive sleeps (security#2)
- Guard auth header: only set Authorization when token is non-empty (gpt#2)
- Fix GetFileContent doc comment to match actual behavior (sonnet#3, gpt#1)
- Remove dead 'in_progress/queued' case in mapCheckRunStatus (sonnet#4)
- Add testing.Short() guard to slow retry test (sonnet#5)
- Reject dot-segments in escapePath to prevent path traversal (security#3)
- Add regression tests for non-mutation and escapePath safety
2026-05-12 15:43:45 -07:00
claw d1ef1e21e5 feat(github): implement PRReader + FileReader client (#80)
CI / test (pull_request) Successful in 18s
PR Ready Gate / clear-labels (pull_request) Successful in 2s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 34s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m45s
CI / review (gpt-5, security, ., rodin/security-patterns, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 2m56s
Implement the GitHub API client with PRReader and FileReader interface
conformance for both github.com and GitHub Enterprise.

New files:
- github/client.go: Client struct, NewClient with configurable base URL,
  HTTP helpers with 429 retry and Retry-After support
- github/pr.go: GetPullRequest, GetPullRequestDiff (per-request Accept
  header), GetPullRequestFiles (paginated, populates Patch field),
  GetFileContentAtRef (base64 decode), GetCommitStatuses (merges commit
  statuses + check runs with conclusion mapping)
- github/files.go: GetFileContent (delegates to GetFileContentAtRef),
  ListContents, escapePath, decodeBase64Content helpers

Type changes:
- vcs/types.go: Add Patch field to ChangedFile struct

Tests cover: happy path, 404, 401, 429+retry, malformed response,
pagination, binary files, check run conclusion mapping, base64 decoding.

Compile-time checks:
  var _ vcs.PRReader = (*Client)(nil)
  var _ vcs.FileReader = (*Client)(nil)

Exit criteria met:
- go test ./github/... passes (all methods)
- NewClient with empty baseURL uses https://api.github.com
- NewClient with GHE URL targets correctly
- GetFileContent delegates to GetFileContentAtRef with empty ref
- GetPullRequestFiles paginates and populates Patch field
- GetCommitStatuses merges both commit statuses and check-runs
2026-05-12 15:18:55 -07:00
29 changed files with 654 additions and 3113 deletions
+1 -1
View File
@@ -9,7 +9,7 @@
| Package | Use Case | Scope | | Package | Use Case | Scope |
|---------|----------|-------| |---------|----------|-------|
| `github.com/goccy/go-yaml` | YAML parsing and AST inspection (subpkgs: `ast`, `parser`) | production | | `gopkg.in/yaml.v3` | YAML parsing (persona files, config) | production |
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only | | `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
**Any import not in this table or the Go standard library is forbidden.** **Any import not in this table or the Go standard library is forbidden.**
+159 -250
View File
@@ -2,7 +2,6 @@ package main
import ( import (
"context" "context"
"errors"
"flag" "flag"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -14,7 +13,6 @@ import (
"gitea.weiker.me/rodin/review-bot/budget" "gitea.weiker.me/rodin/review-bot/budget"
"gitea.weiker.me/rodin/review-bot/gitea" "gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/github"
"gitea.weiker.me/rodin/review-bot/llm" "gitea.weiker.me/rodin/review-bot/llm"
"gitea.weiker.me/rodin/review-bot/review" "gitea.weiker.me/rodin/review-bot/review"
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
@@ -56,22 +54,19 @@ func main() {
// Logging flags // Logging flags
logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json") logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json")
verbosity := flag.String("verbosity", envOrDefault("LOG_VERBOSITY", "info"), "Log verbosity: debug, info, warn, error") verbosity := flag.String("verbosity", envOrDefault("LOG_VERBOSITY", "info"), "Log verbosity: debug, info, warn, error")
// VCS flags // CLI flags
provider := flag.String("provider", envOrDefault("VCS_PROVIDER", "gitea"), "VCS provider: gitea or github") giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", envOrDefault("GITHUB_SERVER_URL", "")), "Gitea instance URL")
baseURL := flag.String("base-url", envOrDefault("VCS_BASE_URL", ""), "VCS API base URL (for github provider; defaults to https://api.github.com)") repo := flag.String("repo", envOrDefault("GITEA_REPO", envOrDefault("GITHUB_REPOSITORY", "")), "Repository (owner/name)")
vcsURL := flag.String("vcs-url", envOrDefault("VCS_URL", envOrDefault("GITEA_URL", envOrDefault("GITHUB_SERVER_URL", ""))), "VCS instance URL (Gitea) [deprecated alias: --gitea-url]")
// Keep --gitea-url as backward-compatible alias (flag package doesn't support aliases natively, handle below)
repo := flag.String("repo", envOrDefault("VCS_REPO", envOrDefault("GITEA_REPO", envOrDefault("GITHUB_REPOSITORY", ""))), "Repository (owner/name)")
prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number") prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number")
reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name") reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name")
reviewerToken := flag.String("reviewer-token", envOrDefault("REVIEWER_TOKEN", ""), "VCS token for posting review") reviewerToken := flag.String("reviewer-token", envOrDefault("REVIEWER_TOKEN", ""), "Gitea token for posting review")
llmBaseURL := flag.String("llm-base-url", envOrDefault("LLM_BASE_URL", ""), "LLM API base URL") llmBaseURL := flag.String("llm-base-url", envOrDefault("LLM_BASE_URL", ""), "LLM API base URL")
llmAPIKey := flag.String("llm-api-key", envOrDefault("LLM_API_KEY", ""), "LLM API key") llmAPIKey := flag.String("llm-api-key", envOrDefault("LLM_API_KEY", ""), "LLM API key")
llmModel := flag.String("llm-model", envOrDefault("LLM_MODEL", ""), "LLM model name") llmModel := flag.String("llm-model", envOrDefault("LLM_MODEL", ""), "LLM model name")
conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)") conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)")
systemPromptFile := flag.String("system-prompt-file", envOrDefault("SYSTEM_PROMPT_FILE", ""), "Local file with additional system prompt instructions") systemPromptFile := flag.String("system-prompt-file", envOrDefault("SYSTEM_PROMPT_FILE", ""), "Local file with additional system prompt instructions")
patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)") patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)")
patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", ""), "Comma-separated file paths to fetch from patterns repo (empty = all files)") patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", "README.md"), "Comma-separated file paths to fetch from patterns repo")
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting") dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)") llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)") llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
@@ -85,18 +80,6 @@ func main() {
aicoreAPIURL := flag.String("aicore-api-url", envOrDefault("AICORE_API_URL", ""), "SAP AI Core API URL (for provider=aicore)") aicoreAPIURL := flag.String("aicore-api-url", envOrDefault("AICORE_API_URL", ""), "SAP AI Core API URL (for provider=aicore)")
aicoreResourceGroup := flag.String("aicore-resource-group", envOrDefault("AICORE_RESOURCE_GROUP", "default"), "SAP AI Core resource group (for provider=aicore)") aicoreResourceGroup := flag.String("aicore-resource-group", envOrDefault("AICORE_RESOURCE_GROUP", "default"), "SAP AI Core resource group (for provider=aicore)")
// Register --gitea-url as a backward-compatible alias for --vcs-url.
// StringVar shares the *string pointer with vcsURL, so whichever flag is
// set last by flag.Parse wins — both point to the same underlying value.
// NOTE: If a user passes both --vcs-url and --gitea-url, the last one on
// the command line takes effect (standard flag package behavior). This is
// acceptable since --gitea-url is deprecated and both serve the same purpose.
//
// ORDERING: This must remain AFTER vcsURL's flag.String declaration and BEFORE
// flag.Parse(). The *vcsURL dereference captures the env-var-resolved default
// at registration time; moving flag.Parse() above this line would break it.
flag.StringVar(vcsURL, "gitea-url", *vcsURL, "Deprecated: use --vcs-url instead")
flag.Parse() flag.Parse()
if *versionFlag { if *versionFlag {
@@ -109,25 +92,12 @@ func main() {
slog.Info("review-bot starting", "version", version) slog.Info("review-bot starting", "version", version)
// Validate VCS provider
switch *provider {
case "gitea", "github":
// valid
default:
fmt.Fprintf(os.Stderr, "Error: invalid --provider %q (valid: gitea, github)\n", *provider)
os.Exit(1)
}
// Validate required fields // Validate required fields
// For aicore provider, llm-base-url and llm-api-key are not required
isAICore := llm.Provider(*llmProvider) == llm.ProviderAICore isAICore := llm.Provider(*llmProvider) == llm.ProviderAICore
if *repo == "" || *prNum == "" || *reviewerToken == "" || *llmModel == "" { if *giteaURL == "" || *repo == "" || *prNum == "" || *reviewerToken == "" || *llmModel == "" {
fmt.Fprintf(os.Stderr, "Error: missing required flags or environment variables\n\n") fmt.Fprintf(os.Stderr, "Error: missing required flags or environment variables\n\n")
fmt.Fprintf(os.Stderr, "Required: --repo, --pr, --reviewer-token, --llm-model\n") fmt.Fprintf(os.Stderr, "Required: --gitea-url, --repo, --pr, --reviewer-token, --llm-model\n")
os.Exit(1)
}
// --vcs-url is required only for gitea provider
if *provider == "gitea" && *vcsURL == "" {
fmt.Fprintf(os.Stderr, "Error: --vcs-url (or --gitea-url) is required for provider=gitea\n")
os.Exit(1) os.Exit(1)
} }
if !isAICore && (*llmBaseURL == "" || *llmAPIKey == "") { if !isAICore && (*llmBaseURL == "" || *llmAPIKey == "") {
@@ -146,6 +116,8 @@ func main() {
os.Exit(1) os.Exit(1)
} }
// NOTE: Persona loading deferred until after Gitea client init to support repo personas
// Validate reviewer-name: only safe characters allowed in sentinel // Validate reviewer-name: only safe characters allowed in sentinel
if err := validateReviewerName(*reviewerName); err != nil { if err := validateReviewerName(*reviewerName); err != nil {
slog.Error("invalid reviewer name", "error", err) slog.Error("invalid reviewer name", "error", err)
@@ -167,25 +139,8 @@ func main() {
os.Exit(1) os.Exit(1)
} }
// Initialize VCS client // Initialize clients
var client vcs.Client giteaClient := gitea.NewClient(*giteaURL, *reviewerToken)
switch *provider {
case "gitea":
giteaClient := gitea.NewClient(*vcsURL, *reviewerToken)
client = gitea.NewAdapter(giteaClient)
case "github":
ghBaseURL := *baseURL
if ghBaseURL == "" {
ghBaseURL = "https://api.github.com"
}
client = github.NewClient(*reviewerToken, ghBaseURL)
default:
fmt.Fprintf(os.Stderr, "Error: unhandled provider %q\n", *provider)
os.Exit(1)
}
slog.Info("VCS client initialized", "provider", *provider)
// Initialize LLM client
llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel) llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel)
if *llmTemp < 0 || *llmTemp > 2 { if *llmTemp < 0 || *llmTemp > 2 {
slog.Error("invalid LLM temperature", "temperature", *llmTemp, "range", "0-2") slog.Error("invalid LLM temperature", "temperature", *llmTemp, "range", "0-2")
@@ -219,13 +174,16 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), overallTimeout) ctx, cancel := context.WithTimeout(context.Background(), overallTimeout)
defer cancel() defer cancel()
// Load persona if specified // Load persona if specified (after Gitea client init to support repo personas)
var persona *review.Persona var persona *review.Persona
if *personaName != "" { if *personaName != "" {
// Try loading from repo first, then fall back to built-in // Try loading from repo first, then fall back to built-in
repoPersonas, err := review.LoadRepoPersonas(ctx, client, owner, repoName) repoPersonas, err := review.LoadRepoPersonas(ctx, newGiteaClientAdapter(giteaClient), owner, repoName)
if err != nil { if err != nil {
slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err) slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err)
// Continue with built-in personas only.
// NOTE: repoPersonas is nil here, but map indexing on a nil map is safe in Go
// (returns the zero value), so the fallback to built-in below works correctly.
} }
if p, ok := repoPersonas[*personaName]; ok { if p, ok := repoPersonas[*personaName]; ok {
persona = p persona = p
@@ -256,7 +214,7 @@ func main() {
slog.Info("reviewing pull request", "pr", prNumber, "repo", fmt.Sprintf("%s/%s", owner, repoName)) slog.Info("reviewing pull request", "pr", prNumber, "repo", fmt.Sprintf("%s/%s", owner, repoName))
// Step 1: Fetch PR metadata // Step 1: Fetch PR metadata
pr, err := client.GetPullRequest(ctx, owner, repoName, prNumber) pr, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
if err != nil { if err != nil {
slog.Error("failed to fetch PR", "pr", prNumber, "error", err) slog.Error("failed to fetch PR", "pr", prNumber, "error", err)
os.Exit(1) os.Exit(1)
@@ -264,7 +222,7 @@ func main() {
slog.Info("fetched PR metadata", "pr", prNumber, "title", pr.Title) slog.Info("fetched PR metadata", "pr", prNumber, "title", pr.Title)
// Step 2: Fetch diff // Step 2: Fetch diff
diff, err := client.GetPullRequestDiff(ctx, owner, repoName, prNumber) diff, err := giteaClient.GetPullRequestDiff(ctx, owner, repoName, prNumber)
if err != nil { if err != nil {
slog.Error("failed to fetch diff", "pr", prNumber, "error", err) slog.Error("failed to fetch diff", "pr", prNumber, "error", err)
os.Exit(1) os.Exit(1)
@@ -273,21 +231,21 @@ func main() {
// Step 3: Fetch full file content for modified files // Step 3: Fetch full file content for modified files
fileContext := "" fileContext := ""
files, err := client.GetPullRequestFiles(ctx, owner, repoName, prNumber) files, err := giteaClient.GetPullRequestFiles(ctx, owner, repoName, prNumber)
if err != nil { if err != nil {
slog.Warn("could not fetch PR files list", "pr", prNumber, "error", err) slog.Warn("could not fetch PR files list", "pr", prNumber, "error", err)
} else { } else {
fileContext = fetchFileContext(ctx, client, owner, repoName, pr.Head.Ref, files) fileContext = fetchFileContext(ctx, giteaClient, owner, repoName, pr.Head.Ref, files)
slog.Debug("fetched file context", "files", len(files)) slog.Debug("fetched file context", "files", len(files))
} }
// Step 4: Check CI status // Step 4: Check CI status
ciPassed := true ciPassed := true
ciDetails := "" ciDetails := ""
if pr.Head.SHA != "" { if pr.Head.Sha != "" {
statuses, err := client.GetCommitStatuses(ctx, owner, repoName, pr.Head.SHA) statuses, err := giteaClient.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
if err != nil { if err != nil {
slog.Warn("could not fetch CI status", "sha", pr.Head.SHA, "error", err) slog.Warn("could not fetch CI status", "sha", pr.Head.Sha, "error", err)
} else { } else {
ciPassed, ciDetails = evaluateCIStatus(statuses) ciPassed, ciDetails = evaluateCIStatus(statuses)
slog.Info("CI status checked", "passed", ciPassed) slog.Info("CI status checked", "passed", ciPassed)
@@ -297,7 +255,7 @@ func main() {
// Step 5: Load conventions file if specified // Step 5: Load conventions file if specified
conventions := "" conventions := ""
if *conventionsFile != "" { if *conventionsFile != "" {
content, err := client.GetFileContent(ctx, owner, repoName, *conventionsFile, "") content, err := giteaClient.GetFileContent(ctx, owner, repoName, *conventionsFile)
if err != nil { if err != nil {
slog.Warn("could not load conventions file", "file", *conventionsFile, "error", err) slog.Warn("could not load conventions file", "file", *conventionsFile, "error", err)
} else { } else {
@@ -309,7 +267,7 @@ func main() {
// Step 6: Load patterns from external repo if specified // Step 6: Load patterns from external repo if specified
patterns := "" patterns := ""
if *patternsRepo != "" { if *patternsRepo != "" {
patterns = fetchPatterns(ctx, client, *patternsRepo, *patternsFiles) patterns = fetchPatterns(ctx, giteaClient, *patternsRepo, *patternsFiles)
slog.Debug("loaded patterns", "repo", *patternsRepo, "bytes", len(patterns)) slog.Debug("loaded patterns", "repo", *patternsRepo, "bytes", len(patterns))
} }
@@ -402,16 +360,15 @@ func main() {
} }
// Add commit footer so readers know which commit was evaluated // Add commit footer so readers know which commit was evaluated
if pr.Head.SHA != "" { if pr.Head.Sha != "" {
shortSHA := pr.Head.SHA shortSHA := pr.Head.Sha
if len(shortSHA) > 8 { if len(shortSHA) > 8 {
shortSHA = shortSHA[:8] shortSHA = shortSHA[:8]
} }
reviewBody += fmt.Sprintf("\n\n---\n*Evaluated against %s*", shortSHA) reviewBody += fmt.Sprintf("\n\n---\n*Evaluated against %s*", shortSHA)
} }
// Map verdict to canonical review event event := review.GiteaEvent(result.Verdict)
event := verdictToEvent(result.Verdict)
if *dryRun { if *dryRun {
fmt.Println("--- DRY RUN ---") fmt.Println("--- DRY RUN ---")
@@ -423,13 +380,14 @@ func main() {
sentinel := fmt.Sprintf("<!-- review-bot:%s -->", *reviewerName) sentinel := fmt.Sprintf("<!-- review-bot:%s -->", *reviewerName)
// Stale check: verify HEAD hasn't moved since we started // Stale check: verify HEAD hasn't moved since we started
evaluatedSHA := pr.Head.SHA evaluatedSHA := pr.Head.Sha
var currentSHA string var currentSHA string
currentPR, err := client.GetPullRequest(ctx, owner, repoName, prNumber) currentPR, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
if err != nil { if err != nil {
slog.Warn("could not re-fetch PR for stale check", "pr", prNumber, "error", err) slog.Warn("could not re-fetch PR for stale check", "pr", prNumber, "error", err)
// currentSHA stays empty — shouldSkipStaleReview will return false
} else { } else {
currentSHA = currentPR.Head.SHA currentSHA = currentPR.Head.Sha
} }
if shouldSkipStaleReview(evaluatedSHA, currentSHA) { if shouldSkipStaleReview(evaluatedSHA, currentSHA) {
slog.Warn("HEAD moved during review — skipping stale review", slog.Warn("HEAD moved during review — skipping stale review",
@@ -439,24 +397,17 @@ func main() {
return return
} }
// Build line→position map for inline comments // Map findings to inline comments for lines present in the diff
lineToPosition := vcs.BuildLineToPositionMap(diff) diffRanges := gitea.ParseDiffNewLines(diff)
var inlineComments []vcs.ReviewComment var inlineComments []gitea.ReviewComment
for _, f := range result.Findings { for _, f := range result.Findings {
if f.File == "" || f.Line <= 0 { if f.File != "" && f.Line > 0 && diffRanges.Contains(f.File, f.Line) {
continue inlineComments = append(inlineComments, gitea.ReviewComment{
Path: f.File,
NewPosition: int64(f.Line),
Body: fmt.Sprintf("**[%s]** %s", f.Severity, f.Finding),
})
} }
pos, ok := lineToPosition[f.File][f.Line]
if !ok {
slog.Warn("line not in diff, skipping comment", "file", f.File, "line", f.Line)
continue
}
inlineComments = append(inlineComments, vcs.ReviewComment{
Path: f.File,
Position: pos,
CommitID: pr.Head.SHA,
Body: fmt.Sprintf("**[%s]** %s", f.Severity, f.Finding),
})
} }
if len(inlineComments) > 0 { if len(inlineComments) > 0 {
slog.Debug("attaching inline comments", "count", len(inlineComments)) slog.Debug("attaching inline comments", "count", len(inlineComments))
@@ -465,9 +416,10 @@ func main() {
// --- Review update strategy --- // --- Review update strategy ---
// 1. POST new review first (gets non-stale approval badge on HEAD) // 1. POST new review first (gets non-stale approval badge on HEAD)
// 2. Then supersede old review with link to the new one // 2. Then supersede old review with link to the new one
var oldReviews []vcs.Review // Order matters: post first so we have the new review's URL for the supersede message.
var oldReviews []gitea.Review
if *reviewerName != "" { if *reviewerName != "" {
existingReviews, err := client.ListReviews(ctx, owner, repoName, prNumber) existingReviews, err := giteaClient.ListReviews(ctx, owner, repoName, prNumber)
if err != nil { if err != nil {
slog.Warn("could not list existing reviews", "pr", prNumber, "error", err) slog.Warn("could not list existing reviews", "pr", prNumber, "error", err)
} else { } else {
@@ -479,137 +431,74 @@ func main() {
} }
} }
// Self-request as reviewer (Gitea-specific; ensures we appear in required-reviewer checks) // Self-request as reviewer (ensures we appear in required-reviewer checks)
if giteaAdapter, ok := client.(*gitea.Adapter); ok { authUser, err := giteaClient.GetAuthenticatedUser(ctx)
authUser, err := client.GetAuthenticatedUser(ctx) if err != nil {
if err != nil { slog.Warn("could not determine authenticated user for reviewer self-request", "error", err)
slog.Warn("could not determine authenticated user for reviewer self-request", "error", err) } else if authUser != "" {
} else if authUser != "" { if err := giteaClient.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
if err := giteaAdapter.Underlying().RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil { slog.Warn("could not self-request as reviewer", "user", authUser, "error", err)
slog.Warn("could not self-request as reviewer", "user", authUser, "error", err) } else {
} else { slog.Debug("self-requested as reviewer", "user", authUser, "pr", prNumber)
slog.Debug("self-requested as reviewer", "user", authUser, "pr", prNumber)
}
} }
} else {
slog.Debug("RequestReviewer not supported for provider, skipping")
} }
// POST new review // POST new review
slog.Info("posting review", "event", event, "pr", prNumber) slog.Info("posting review", "event", event, "pr", prNumber)
reviewReq := vcs.ReviewRequest{ posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, inlineComments)
Body: reviewBody,
Event: event,
Comments: inlineComments,
}
posted, err := client.PostReview(ctx, owner, repoName, prNumber, reviewReq)
if err != nil { if err != nil {
slog.Error("failed to post review", "pr", prNumber, "event", event, "error", err) slog.Error("failed to post review", "pr", prNumber, "event", event, "error", err)
os.Exit(1) os.Exit(1)
} }
slog.Info("review posted", "review_id", posted.ID, "user", posted.User.Login, "pr", prNumber) slog.Info("review posted", "review_id", posted.ID, "user", posted.User.Login, "pr", prNumber)
// Supersede all old reviews // Supersede all old reviews with link to the new one
if len(oldReviews) > 0 { if len(oldReviews) > 0 {
if err := supersedeOldReviews(ctx, client, *provider, *vcsURL, owner, repoName, prNumber, oldReviews, posted.ID, sentinel); err != nil { newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(*giteaURL, "/"), owner, repoName, prNumber, posted.ID)
slog.Error("failed to supersede old reviews", "error", err) for _, oldReview := range oldReviews {
os.Exit(1) cid, err := giteaClient.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, prNumber, oldReview.ID)
} if err != nil {
} slog.Warn("could not find comment ID for old review", "review_id", oldReview.ID, "error", err)
}
// verdictToEvent maps a verdict string from the LLM response to a canonical vcs.ReviewEvent.
func verdictToEvent(verdict string) vcs.ReviewEvent {
switch verdict {
case "APPROVE":
return vcs.ReviewEventApprove
case "REQUEST_CHANGES":
return vcs.ReviewEventRequestChanges
default:
return vcs.ReviewEventComment
}
}
// supersedeOldReviews marks prior reviews as superseded so only the latest review is visible.
// For GitHub: dismisses old reviews (vcsURL is unused in this path).
// For Gitea: edits the review body with a link to the new review and resolves inline comments.
//
// The vcsURL parameter is only used in the Gitea path to construct review permalink URLs;
// it is accepted unconditionally to keep the function signature uniform across providers.
func supersedeOldReviews(ctx context.Context, client vcs.Client, provider, vcsURL, owner, repoName string, prNumber int, oldReviews []vcs.Review, newReviewID int64, sentinel string) error {
switch provider {
case "github":
// Best-effort dismissal: attempt all reviews, join any errors.
var errs []error
for _, old := range oldReviews {
if err := client.DismissReview(ctx, owner, repoName, prNumber, old.ID, "Superseded by new review"); err != nil {
slog.Warn("failed to dismiss review", "id", old.ID, "error", err)
errs = append(errs, fmt.Errorf("dismiss review %d: %w", old.ID, err))
} else {
slog.Info("dismissed old review", "review_id", old.ID, "new_review_id", newReviewID, "pr", prNumber)
}
}
return errors.Join(errs...)
case "gitea":
// Continue to Gitea-specific logic below the switch.
default:
return fmt.Errorf("supersedeOldReviews: unsupported provider %q", provider)
}
// The type assertion below is guaranteed to succeed: the caller's provider switch
// ensures we only reach this point when provider == "gitea", and the gitea provider
// always constructs a *gitea.Adapter. The !ok branch guards against future refactors
// (e.g. wrapping the adapter in a decorator) that would silently break this path.
giteaAdapter, ok := client.(*gitea.Adapter)
if !ok {
return fmt.Errorf("expected gitea.Adapter for gitea provider, got %T", client)
}
underlying := giteaAdapter.Underlying()
newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(vcsURL, "/"), owner, repoName, prNumber, newReviewID)
for _, oldReview := range oldReviews {
cid, err := underlying.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, prNumber, oldReview.ID)
if err != nil {
slog.Warn("could not find comment ID for old review", "review_id", oldReview.ID, "error", err)
continue
}
supersededBody := buildSupersededBody(oldReview.Body, oldReview.CommitID, newReviewURL, sentinel)
if err := underlying.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
slog.Warn("could not mark old review as superseded", "review_id", oldReview.ID, "comment_id", cid, "error", err)
continue
}
slog.Info("marked old review as superseded", "review_id", oldReview.ID, "new_review_id", newReviewID, "pr", prNumber)
// Resolve old review's inline comments
oldComments, err := underlying.ListReviewComments(ctx, owner, repoName, prNumber, oldReview.ID)
if err != nil {
slog.Warn("could not list old review comments for resolution", "review_id", oldReview.ID, "error", err)
continue
}
resolved, failed := 0, 0
for _, c := range oldComments {
if c.ID == 0 {
continue continue
} }
if err := underlying.ResolveComment(ctx, owner, repoName, c.ID); err != nil { supersededBody := buildSupersededBody(oldReview.Body, oldReview.CommitID, newReviewURL, sentinel)
slog.Debug("could not resolve inline comment", "comment_id", c.ID, "error", err) if err := giteaClient.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
failed++ slog.Warn("could not mark old review as superseded", "review_id", oldReview.ID, "comment_id", cid, "error", err)
} else { continue
resolved++ }
slog.Info("marked old review as superseded", "review_id", oldReview.ID, "new_review_id", posted.ID, "pr", prNumber)
// Resolve old review's inline comments
oldComments, err := giteaClient.ListReviewComments(ctx, owner, repoName, prNumber, oldReview.ID)
if err != nil {
slog.Warn("could not list old review comments for resolution", "review_id", oldReview.ID, "error", err)
continue
}
resolved, failed := 0, 0
for _, c := range oldComments {
if c.ID == 0 {
continue
}
if err := giteaClient.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
slog.Debug("could not resolve inline comment", "comment_id", c.ID, "error", err)
failed++
} else {
resolved++
}
}
if resolved > 0 {
slog.Info("resolved old inline comments", "review_id", oldReview.ID, "count", resolved, "pr", prNumber)
}
if failed > 0 {
slog.Warn("some inline comments could not be resolved", "review_id", oldReview.ID, "failed", failed, "pr", prNumber)
} }
} }
if resolved > 0 {
slog.Info("resolved old inline comments", "review_id", oldReview.ID, "count", resolved, "pr", prNumber)
}
if failed > 0 {
slog.Warn("some inline comments could not be resolved", "review_id", oldReview.ID, "failed", failed, "pr", prNumber)
}
} }
return nil
} }
// fetchFileContext fetches the full content of modified files from the PR branch. // fetchFileContext fetches the full content of modified files from the PR branch.
func fetchFileContext(ctx context.Context, client vcs.PRReader, owner, repo, ref string, files []vcs.ChangedFile) string { func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, ref string, files []gitea.ChangedFile) string {
var sb strings.Builder var sb strings.Builder
for _, f := range files { for _, f := range files {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -618,7 +507,7 @@ func fetchFileContext(ctx context.Context, client vcs.PRReader, owner, repo, ref
if f.Status == "removed" { if f.Status == "removed" {
continue // Skip deleted files continue // Skip deleted files
} }
content, err := client.GetFileContentAtRef(ctx, owner, repo, f.Filename, ref) content, err := client.GetFileContentRef(ctx, owner, repo, f.Filename, ref)
if err != nil { if err != nil {
slog.Warn("could not fetch file content", "file", f.Filename, "error", err) slog.Warn("could not fetch file content", "file", f.Filename, "error", err)
continue continue
@@ -635,25 +524,11 @@ func fetchFileContext(ctx context.Context, client vcs.PRReader, owner, repo, ref
// patternsRepo is comma-separated list of owner/name repos. // patternsRepo is comma-separated list of owner/name repos.
// patternsFiles is comma-separated list of file paths or directories. // patternsFiles is comma-separated list of file paths or directories.
// If a path ends with / or is a directory, all files within it are fetched recursively. // If a path ends with / or is a directory, all files within it are fetched recursively.
// If patternsFiles is empty, all files from the repo root are fetched. func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string {
func fetchPatterns(ctx context.Context, client vcs.FileReader, patternsRepo, patternsFiles string) string {
var sb strings.Builder var sb strings.Builder
repos := strings.Split(patternsRepo, ",") repos := strings.Split(patternsRepo, ",")
paths := strings.Split(patternsFiles, ",")
// Build the list of paths to fetch
var paths []string
if patternsFiles == "" {
// Empty patternsFiles means "fetch all files from repo root"
paths = []string{""}
} else {
for _, p := range strings.Split(patternsFiles, ",") {
p = strings.TrimSpace(p)
if p != "" {
paths = append(paths, p)
}
}
}
for _, repoRef := range repos { for _, repoRef := range repos {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -674,7 +549,12 @@ func fetchPatterns(ctx context.Context, client vcs.FileReader, patternsRepo, pat
var repoSkippedFiles []string var repoSkippedFiles []string
for _, path := range paths { for _, path := range paths {
files, err := vcs.GetAllFilesInPath(ctx, client, owner, repo, path) path = strings.TrimSpace(path)
if path == "" {
continue
}
files, err := client.GetAllFilesInPath(ctx, owner, repo, path)
if err != nil { if err != nil {
slog.Warn("could not fetch patterns", "path", path, "repo", repoRef, "error", err) slog.Warn("could not fetch patterns", "path", path, "repo", repoRef, "error", err)
continue continue
@@ -713,20 +593,18 @@ func isPatternFile(path string) bool {
} }
// evaluateCIStatus checks if all CI statuses indicate success. // evaluateCIStatus checks if all CI statuses indicate success.
// Returns passed=true if no checks have failed (pending checks are not treated as failures). func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details string) {
func evaluateCIStatus(statuses []vcs.CommitStatus) (passed bool, details string) {
if len(statuses) == 0 { if len(statuses) == 0 {
return true, "no CI statuses found" return true, "no CI statuses found"
} }
var failed []string var failed []string
var pending int
for _, s := range statuses { for _, s := range statuses {
switch s.Status { switch s.Status {
case "success": case "success":
// good // good
case "pending": case "pending":
pending++ // treat pending as not-failed
case "failure", "error": case "failure", "error":
failed = append(failed, fmt.Sprintf("%s: %s", s.Context, s.Description)) failed = append(failed, fmt.Sprintf("%s: %s", s.Context, s.Description))
} }
@@ -735,9 +613,6 @@ func evaluateCIStatus(statuses []vcs.CommitStatus) (passed bool, details string)
if len(failed) > 0 { if len(failed) > 0 {
return false, strings.Join(failed, "; ") return false, strings.Join(failed, "; ")
} }
if pending > 0 {
return true, fmt.Sprintf("no failures (%d pending)", pending)
}
return true, "all checks passed" return true, "all checks passed"
} }
@@ -853,10 +728,10 @@ func buildSupersededBody(originalBody, commitSHA, newReviewURL, sentinel string)
} }
// hasSharedToken detects if another review-bot role posted under the same // hasSharedToken detects if another review-bot role posted under the same
// VCS user. This indicates misconfiguration where two roles share a token // Gitea user. This indicates misconfiguration where two roles share a token
// instead of having separate accounts. Returns true if shared token // instead of having separate Gitea accounts. Returns true if shared token
// detected (caller should skip update-in-place logic to avoid clobbering). // detected (caller should skip update-in-place logic to avoid clobbering).
func hasSharedToken(reviews []vcs.Review, ownSentinel string) bool { func hasSharedToken(reviews []gitea.Review, ownSentinel string) bool {
ownLogin := "" ownLogin := ""
for _, r := range reviews { for _, r := range reviews {
if strings.Contains(r.Body, ownSentinel) { if strings.Contains(r.Body, ownSentinel) {
@@ -869,7 +744,7 @@ func hasSharedToken(reviews []vcs.Review, ownSentinel string) bool {
} }
for _, r := range reviews { for _, r := range reviews {
if r.User.Login == ownLogin && strings.Contains(r.Body, "<!-- review-bot:") && !strings.Contains(r.Body, ownSentinel) { if r.User.Login == ownLogin && strings.Contains(r.Body, "<!-- review-bot:") && !strings.Contains(r.Body, ownSentinel) {
slog.Warn("shared token detected — another review-bot role is using the same VCS user", slog.Warn("shared token detected — another review-bot role is using the same Gitea user",
"sibling_role", extractSentinelName(r.Body), "user", ownLogin) "sibling_role", extractSentinelName(r.Body), "user", ownLogin)
return true return true
} }
@@ -890,27 +765,29 @@ func extractSentinelName(body string) string {
if end < 0 { if end < 0 {
return "unknown" return "unknown"
} }
name := rest[:end] return rest[:end]
// Sanitize: strip control characters to prevent log injection.
name = strings.Map(func(r rune) rune {
if r < 0x20 || r == 0x7f {
return -1
}
return r
}, name)
if len(name) > 64 {
name = name[:64]
}
if name == "" {
return "unknown"
}
return name
} }
// findOwnReview locates the most recent non-superseded review matching the sentinel.
func findOwnReview(reviews []gitea.Review, sentinel string) *gitea.Review {
var best *gitea.Review
for i := range reviews {
if !strings.Contains(reviews[i].Body, sentinel) {
continue
}
if strings.Contains(reviews[i].Body, "~~Original review~~") {
continue
}
if best == nil || reviews[i].ID > best.ID {
best = &reviews[i]
}
}
return best
}
// findAllOwnReviews returns all non-superseded reviews matching the sentinel. // findAllOwnReviews returns all non-superseded reviews matching the sentinel.
func findAllOwnReviews(reviews []vcs.Review, sentinel string) []vcs.Review { func findAllOwnReviews(reviews []gitea.Review, sentinel string) []gitea.Review {
var result []vcs.Review var result []gitea.Review
for i := range reviews { for i := range reviews {
if !strings.Contains(reviews[i].Body, sentinel) { if !strings.Contains(reviews[i].Body, sentinel) {
continue continue
@@ -935,3 +812,35 @@ func shouldSkipStaleReview(evaluatedSHA, currentSHA string) bool {
} }
return evaluatedSHA != currentSHA return evaluatedSHA != currentSHA
} }
// giteaClientAdapter adapts gitea.Client to vcs.FileReader interface.
type giteaClientAdapter struct {
client *gitea.Client
}
func newGiteaClientAdapter(c *gitea.Client) *giteaClientAdapter {
return &giteaClientAdapter{client: c}
}
func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
entries, err := a.client.ListContents(ctx, owner, repo, path)
if err != nil {
return nil, err
}
result := make([]vcs.ContentEntry, len(entries))
for i, e := range entries {
result[i] = vcs.ContentEntry{
Name: e.Name,
Path: e.Path,
Type: e.Type,
}
}
return result, nil
}
func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) {
if ref != "" {
return a.client.GetFileContentRef(ctx, owner, repo, filePath, ref)
}
return a.client.GetFileContent(ctx, owner, repo, filePath)
}
+126 -133
View File
@@ -10,7 +10,7 @@ import (
"strings" "strings"
"testing" "testing"
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/gitea"
) )
func TestValidateReviewerName(t *testing.T) { func TestValidateReviewerName(t *testing.T) {
@@ -107,7 +107,9 @@ func TestValidateWorkspacePath(t *testing.T) {
workspace: tmpDir, workspace: tmpDir,
path: "/etc/passwd", path: "/etc/passwd",
wantErr: true, wantErr: true,
errMatch: "failed to resolve", // Go 1.21+ filepath.Join normalizes absolute paths: Join("/tmp/x", "/etc/passwd")
// becomes "/tmp/x/etc/passwd", which is within workspace but doesn't exist.
errMatch: "failed to resolve",
}, },
{ {
name: "nonexistent file", name: "nonexistent file",
@@ -152,14 +154,15 @@ func TestValidateWorkspacePath(t *testing.T) {
} }
} }
func makeReview(id int64, login, state string, stale bool, body string) vcs.Review { func makeReview(id int64, login, state string, stale bool, body string) gitea.Review {
return vcs.Review{ r := gitea.Review{
ID: id, ID: id,
Body: body, Body: body,
User: vcs.UserInfo{Login: login},
State: state, State: state,
Stale: stale, Stale: stale,
} }
r.User.Login = login
return r
} }
func TestBuildSupersededBody(t *testing.T) { func TestBuildSupersededBody(t *testing.T) {
@@ -210,11 +213,96 @@ func TestBuildSupersededBodyShortSHA(t *testing.T) {
} }
} }
func TestFindOwnReview(t *testing.T) {
tests := []struct {
name string
reviews []gitea.Review
sentinel string
wantID int64
wantNil bool
}{
{
name: "no reviews",
reviews: nil,
sentinel: "<!-- review-bot:sonnet -->",
wantNil: true,
},
{
name: "found by sentinel",
reviews: []gitea.Review{
makeReview(42, "bot", "APPROVED", false, "review body\n<!-- review-bot:sonnet -->"),
},
sentinel: "<!-- review-bot:sonnet -->",
wantID: 42,
},
{
name: "wrong sentinel",
reviews: []gitea.Review{
makeReview(42, "bot", "APPROVED", false, "body\n<!-- review-bot:gpt -->"),
},
sentinel: "<!-- review-bot:sonnet -->",
wantNil: true,
},
{
name: "multiple reviews, returns first match",
reviews: []gitea.Review{
makeReview(10, "bot", "APPROVED", false, "old\n<!-- review-bot:gpt -->"),
makeReview(20, "bot", "APPROVED", false, "new\n<!-- review-bot:sonnet -->"),
},
sentinel: "<!-- review-bot:sonnet -->",
wantID: 20,
},
{
name: "skips superseded review",
reviews: []gitea.Review{
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n**Superseded**\n<!-- review-bot:sonnet -->"),
makeReview(20, "bot", "APPROVED", false, "fresh review\n<!-- review-bot:sonnet -->"),
},
sentinel: "<!-- review-bot:sonnet -->",
wantID: 20,
},
{
name: "only superseded reviews exist",
reviews: []gitea.Review{
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n<!-- review-bot:sonnet -->"),
},
sentinel: "<!-- review-bot:sonnet -->",
wantNil: true,
},
{
name: "picks highest ID among matches",
reviews: []gitea.Review{
makeReview(50, "bot", "APPROVED", false, "v1\n<!-- review-bot:sonnet -->"),
makeReview(30, "bot", "APPROVED", false, "v0\n<!-- review-bot:sonnet -->"),
},
sentinel: "<!-- review-bot:sonnet -->",
wantID: 50,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := findOwnReview(tc.reviews, tc.sentinel)
if tc.wantNil {
if got != nil {
t.Errorf("findOwnReview() = %v, want nil", got)
}
} else {
if got == nil {
t.Fatal("findOwnReview() = nil, want non-nil")
}
if got.ID != tc.wantID {
t.Errorf("findOwnReview().ID = %d, want %d", got.ID, tc.wantID)
}
}
})
}
}
func TestHasSharedToken(t *testing.T) { func TestHasSharedToken(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
reviews []vcs.Review reviews []gitea.Review
sentinel string sentinel string
want bool want bool
}{ }{
@@ -226,36 +314,36 @@ func TestHasSharedToken(t *testing.T) {
}, },
{ {
name: "no own review yet - cannot detect", name: "no own review yet - cannot detect",
reviews: []vcs.Review{ reviews: []gitea.Review{
makeReview(1, "other", "APPROVED", false, "<!-- review-bot:gpt --> body"), {ID: 1, User: struct{ Login string `json:"login"` }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
}, },
sentinel: "<!-- review-bot:sonnet -->", sentinel: "<!-- review-bot:sonnet -->",
want: false, want: false,
}, },
{ {
name: "separate users - no shared token", name: "separate users - no shared token",
reviews: []vcs.Review{ reviews: []gitea.Review{
makeReview(1, "sonnet-review-bot", "APPROVED", false, "<!-- review-bot:sonnet --> body"), {ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
makeReview(2, "security-review-bot", "APPROVED", false, "<!-- review-bot:security --> body"), {ID: 2, User: struct{ Login string `json:"login"` }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
}, },
sentinel: "<!-- review-bot:sonnet -->", sentinel: "<!-- review-bot:sonnet -->",
want: false, want: false,
}, },
{ {
name: "shared token detected - same user different sentinels", name: "shared token detected - same user different sentinels",
reviews: []vcs.Review{ reviews: []gitea.Review{
makeReview(1, "sonnet-review-bot", "APPROVED", false, "<!-- review-bot:sonnet --> body"), {ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
makeReview(2, "sonnet-review-bot", "APPROVED", false, "<!-- review-bot:security --> body"), {ID: 2, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
}, },
sentinel: "<!-- review-bot:sonnet -->", sentinel: "<!-- review-bot:sonnet -->",
want: true, want: true,
}, },
{ {
name: "three roles same user", name: "three roles same user",
reviews: []vcs.Review{ reviews: []gitea.Review{
makeReview(1, "bot", "APPROVED", false, "<!-- review-bot:sonnet --> body"), {ID: 1, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
makeReview(2, "bot", "APPROVED", false, "<!-- review-bot:security --> body"), {ID: 2, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
makeReview(3, "bot", "APPROVED", false, "<!-- review-bot:gpt --> body"), {ID: 3, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
}, },
sentinel: "<!-- review-bot:sonnet -->", sentinel: "<!-- review-bot:sonnet -->",
want: true, want: true,
@@ -416,56 +504,10 @@ func TestIsPatternFile(t *testing.T) {
} }
} }
// TestBuildPatternPaths verifies the path-building logic for fetchPatterns.
// Empty patternsFiles means "fetch all from root" (represented as [""]).
func TestBuildPatternPaths(t *testing.T) {
buildPaths := func(patternsFiles string) []string {
if patternsFiles == "" {
return []string{""}
}
var paths []string
for _, p := range strings.Split(patternsFiles, ",") {
p = strings.TrimSpace(p)
if p != "" {
paths = append(paths, p)
}
}
return paths
}
tests := []struct {
name string
input string
want []string
}{
{"empty fetches root", "", []string{""}},
{"single file", "README.md", []string{"README.md"}},
{"multiple files", "README.md,PATTERNS.md", []string{"README.md", "PATTERNS.md"}},
{"trims whitespace", " foo.md , bar.md ", []string{"foo.md", "bar.md"}},
{"skips empty between commas", "foo.md,,bar.md", []string{"foo.md", "bar.md"}},
{"directory path", "patterns/", []string{"patterns/"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := buildPaths(tc.input)
if len(got) != len(tc.want) {
t.Errorf("buildPaths(%q) = %v, want %v", tc.input, got, tc.want)
return
}
for i := range got {
if got[i] != tc.want[i] {
t.Errorf("buildPaths(%q)[%d] = %q, want %q", tc.input, i, got[i], tc.want[i])
}
}
})
}
}
func TestEvaluateCIStatus(t *testing.T) { func TestEvaluateCIStatus(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
statuses []vcs.CommitStatus statuses []gitea.CommitStatus
wantPassed bool wantPassed bool
wantSubstr string wantSubstr string
}{ }{
@@ -477,7 +519,7 @@ func TestEvaluateCIStatus(t *testing.T) {
}, },
{ {
name: "all success", name: "all success",
statuses: []vcs.CommitStatus{ statuses: []gitea.CommitStatus{
{Status: "success", Context: "ci/build", Description: "Build passed"}, {Status: "success", Context: "ci/build", Description: "Build passed"},
{Status: "success", Context: "ci/test", Description: "Tests passed"}, {Status: "success", Context: "ci/test", Description: "Tests passed"},
}, },
@@ -486,7 +528,7 @@ func TestEvaluateCIStatus(t *testing.T) {
}, },
{ {
name: "one failure", name: "one failure",
statuses: []vcs.CommitStatus{ statuses: []gitea.CommitStatus{
{Status: "success", Context: "ci/build", Description: "Build passed"}, {Status: "success", Context: "ci/build", Description: "Build passed"},
{Status: "failure", Context: "ci/test", Description: "Tests failed"}, {Status: "failure", Context: "ci/test", Description: "Tests failed"},
}, },
@@ -495,7 +537,7 @@ func TestEvaluateCIStatus(t *testing.T) {
}, },
{ {
name: "error status", name: "error status",
statuses: []vcs.CommitStatus{ statuses: []gitea.CommitStatus{
{Status: "error", Context: "ci/lint", Description: "Lint error"}, {Status: "error", Context: "ci/lint", Description: "Lint error"},
}, },
wantPassed: false, wantPassed: false,
@@ -503,16 +545,16 @@ func TestEvaluateCIStatus(t *testing.T) {
}, },
{ {
name: "pending treated as not-failed", name: "pending treated as not-failed",
statuses: []vcs.CommitStatus{ statuses: []gitea.CommitStatus{
{Status: "pending", Context: "ci/build", Description: "In progress"}, {Status: "pending", Context: "ci/build", Description: "In progress"},
{Status: "success", Context: "ci/test", Description: "Tests passed"}, {Status: "success", Context: "ci/test", Description: "Tests passed"},
}, },
wantPassed: true, wantPassed: true,
wantSubstr: "no failures", wantSubstr: "all checks passed",
}, },
{ {
name: "multiple failures", name: "multiple failures",
statuses: []vcs.CommitStatus{ statuses: []gitea.CommitStatus{
{Status: "failure", Context: "ci/build", Description: "Build failed"}, {Status: "failure", Context: "ci/build", Description: "Build failed"},
{Status: "failure", Context: "ci/test", Description: "Tests failed"}, {Status: "failure", Context: "ci/test", Description: "Tests failed"},
}, },
@@ -521,7 +563,7 @@ func TestEvaluateCIStatus(t *testing.T) {
}, },
{ {
name: "mixed with pending and failure", name: "mixed with pending and failure",
statuses: []vcs.CommitStatus{ statuses: []gitea.CommitStatus{
{Status: "success", Context: "ci/build", Description: "Build passed"}, {Status: "success", Context: "ci/build", Description: "Build passed"},
{Status: "pending", Context: "ci/deploy", Description: "Deploying"}, {Status: "pending", Context: "ci/deploy", Description: "Deploying"},
{Status: "failure", Context: "ci/test", Description: "Tests failed"}, {Status: "failure", Context: "ci/test", Description: "Tests failed"},
@@ -750,7 +792,7 @@ func TestMainSubprocess_InvalidReviewerName(t *testing.T) {
if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" { if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
os.Args = []string{"review-bot", os.Args = []string{"review-bot",
"--vcs-url", "http://localhost", "--gitea-url", "http://localhost",
"--repo", "owner/repo", "--repo", "owner/repo",
"--pr", "1", "--pr", "1",
"--reviewer-name", "invalid name", "--reviewer-name", "invalid name",
@@ -778,7 +820,7 @@ func TestMainSubprocess_InvalidRepo(t *testing.T) {
if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" { if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
os.Args = []string{"review-bot", os.Args = []string{"review-bot",
"--vcs-url", "http://localhost", "--gitea-url", "http://localhost",
"--repo", "invalidrepo", "--repo", "invalidrepo",
"--pr", "1", "--pr", "1",
"--reviewer-token", "tok", "--reviewer-token", "tok",
@@ -805,7 +847,7 @@ func TestMainSubprocess_InvalidPRNumber(t *testing.T) {
if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" { if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
os.Args = []string{"review-bot", os.Args = []string{"review-bot",
"--vcs-url", "http://localhost", "--gitea-url", "http://localhost",
"--repo", "owner/repo", "--repo", "owner/repo",
"--pr", "notanumber", "--pr", "notanumber",
"--reviewer-token", "tok", "--reviewer-token", "tok",
@@ -832,7 +874,7 @@ func TestMainSubprocess_InvalidTemperature(t *testing.T) {
if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" { if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
os.Args = []string{"review-bot", os.Args = []string{"review-bot",
"--vcs-url", "http://localhost", "--gitea-url", "http://localhost",
"--repo", "owner/repo", "--repo", "owner/repo",
"--pr", "1", "--pr", "1",
"--reviewer-token", "tok", "--reviewer-token", "tok",
@@ -860,7 +902,7 @@ func TestMainSubprocess_InvalidProvider(t *testing.T) {
if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" { if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
os.Args = []string{"review-bot", os.Args = []string{"review-bot",
"--vcs-url", "http://localhost", "--gitea-url", "http://localhost",
"--repo", "owner/repo", "--repo", "owner/repo",
"--pr", "1", "--pr", "1",
"--reviewer-token", "tok", "--reviewer-token", "tok",
@@ -884,35 +926,7 @@ func TestMainSubprocess_InvalidProvider(t *testing.T) {
} }
} }
func TestMainSubprocess_InvalidVCSProvider(t *testing.T) { // cleanEnv returns environ without any GITEA/LLM/REVIEWER env vars that would
if os.Getenv("TEST_SUBPROCESS_MAIN") == "1" {
flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
os.Args = []string{"review-bot",
"--provider", "invalid",
"--vcs-url", "http://localhost",
"--repo", "owner/repo",
"--pr", "1",
"--reviewer-token", "tok",
"--llm-base-url", "http://localhost",
"--llm-api-key", "key",
"--llm-model", "model",
}
main()
return
}
cmd := exec.Command(os.Args[0], "-test.run=TestMainSubprocess_InvalidVCSProvider")
cmd.Env = append(cleanEnv(), "TEST_SUBPROCESS_MAIN=1")
out, err := cmd.CombinedOutput()
if err == nil {
t.Fatal("expected non-zero exit with invalid VCS provider")
}
if !strings.Contains(string(out), "invalid --provider") {
t.Errorf("expected error about invalid --provider, got: %s", out)
}
}
// cleanEnv returns environ without any GITEA/LLM/REVIEWER/VCS env vars that would
// interfere with testing missing-flag scenarios. // interfere with testing missing-flag scenarios.
func cleanEnv() []string { func cleanEnv() []string {
var env []string var env []string
@@ -920,7 +934,6 @@ func cleanEnv() []string {
key := strings.SplitN(e, "=", 2)[0] key := strings.SplitN(e, "=", 2)[0]
switch { switch {
case strings.HasPrefix(key, "GITEA_"), case strings.HasPrefix(key, "GITEA_"),
strings.HasPrefix(key, "VCS_"),
strings.HasPrefix(key, "LLM_"), strings.HasPrefix(key, "LLM_"),
strings.HasPrefix(key, "REVIEWER_"), strings.HasPrefix(key, "REVIEWER_"),
strings.HasPrefix(key, "PR_"), strings.HasPrefix(key, "PR_"),
@@ -938,12 +951,12 @@ func cleanEnv() []string {
} }
func TestFindAllOwnReviews(t *testing.T) { func TestFindAllOwnReviews(t *testing.T) {
reviews := []vcs.Review{ reviews := []gitea.Review{
makeReview(1, "bot", "APPROVED", false, "<!-- review-bot:sonnet -->\nfirst review"), {ID: 1, Body: "<!-- review-bot:sonnet -->\nfirst review"},
makeReview(2, "bot", "APPROVED", false, "<!-- review-bot:gpt -->\nother bot"), {ID: 2, Body: "<!-- review-bot:gpt -->\nother bot"},
makeReview(3, "bot", "APPROVED", false, "<!-- review-bot:sonnet -->\nsecond review"), {ID: 3, Body: "<!-- review-bot:sonnet -->\nsecond review"},
makeReview(4, "bot", "APPROVED", false, "~~Original review~~\n<!-- review-bot:sonnet -->\nsuperseded"), {ID: 4, Body: "~~Original review~~\n<!-- review-bot:sonnet -->\nsuperseded"},
makeReview(5, "bot", "APPROVED", false, "<!-- review-bot:sonnet -->\nthird review"), {ID: 5, Body: "<!-- review-bot:sonnet -->\nthird review"},
} }
got := findAllOwnReviews(reviews, "<!-- review-bot:sonnet -->") got := findAllOwnReviews(reviews, "<!-- review-bot:sonnet -->")
@@ -1007,23 +1020,3 @@ func TestShouldSkipStaleReview(t *testing.T) {
}) })
} }
} }
func TestVerdictToEvent(t *testing.T) {
tests := []struct {
verdict string
want vcs.ReviewEvent
}{
{"APPROVE", vcs.ReviewEventApprove},
{"REQUEST_CHANGES", vcs.ReviewEventRequestChanges},
{"COMMENT", vcs.ReviewEventComment},
{"other", vcs.ReviewEventComment},
{"", vcs.ReviewEventComment},
}
for _, tc := range tests {
got := verdictToEvent(tc.verdict)
if got != tc.want {
t.Errorf("verdictToEvent(%q) = %q, want %q", tc.verdict, got, tc.want)
}
}
}
+31 -10
View File
@@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi
- Backwards compatibility: existing JSON personas must continue to work - Backwards compatibility: existing JSON personas must continue to work
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486) - Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
- Consistency: use `.yaml` extension (not `.yml`) - Consistency: use `.yaml` extension (not `.yml`)
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation - Library: use `gopkg.in/yaml.v3` (approved in CONVENTIONS.md) with explicit depth limiting
## Proposed Approach ## Proposed Approach
@@ -33,16 +33,37 @@ func parsePersona(data []byte, source string) (*Persona, error) {
### YAML Parsing with Depth Protection ### YAML Parsing with Depth Protection
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in ```go
`review/persona.go`) rather than relying on library decoder options. Key design func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
decisions: var node yaml.Node
dec := yaml.NewDecoder(bytes.NewReader(data))
if err := dec.Decode(&node); err != nil {
return err
}
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
return err
}
return node.Decode(out)
}
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection) if depth > maxDepth {
- **Node-count limit:** Conservative overcounting bounds total validation work return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths }
// Handle alias nodes by following the Alias pointer
if node.Kind == yaml.AliasNode && node.Alias != nil {
return checkYAMLDepth(node.Alias, depth, maxDepth)
}
for _, child := range node.Content {
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
return err
}
}
return nil
}
```
See `review/persona.go:checkYAMLDepth` for the authoritative implementation. The `gopkg.in/yaml.v3` library does not have built-in depth protection, so we implement explicit depth checking by first decoding into a `yaml.Node`, walking the tree to verify depth (including alias resolution), then decoding into the target struct.
## State/Data Model ## State/Data Model
@@ -53,7 +74,7 @@ No new state. Same `Persona` struct, just different parsing.
| Error | Handling | | Error | Handling |
|-------|----------| |-------|----------|
| Invalid YAML syntax | Return parse error with source file | | Invalid YAML syntax | Return parse error with source file |
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode | | Deeply nested YAML | Library rejects (v1.16.0+ fix) |
| Unknown extension | Fall back to JSON parsing | | Unknown extension | Fall back to JSON parsing |
| Missing required fields | Validation rejects after parse | | Missing required fields | Validation rejects after parse |
-232
View File
@@ -1,232 +0,0 @@
package gitea
import (
"context"
"fmt"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Adapter wraps a gitea.Client and satisfies the vcs.Client interface.
// It handles translation between GitHub-canonical diff positions and Gitea
// line numbers, and between canonical review event strings and Gitea-native values.
type Adapter struct {
client *Client
}
// Compile-time interface conformance assertion.
var _ vcs.Client = (*Adapter)(nil)
// NewAdapter creates a new Adapter wrapping the given gitea Client.
func NewAdapter(client *Client) *Adapter {
return &Adapter{client: client}
}
// Underlying returns the wrapped gitea.Client for Gitea-specific operations
// that have no vcs.Client equivalent (resolve comment, timeline, supersede flow).
func (a *Adapter) Underlying() *Client {
return a.client
}
// --- PRReader ---
// GetPullRequest maps gitea.PullRequest to vcs.PullRequest.
func (a *Adapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
pr, err := a.client.GetPullRequest(ctx, owner, repo, number)
if err != nil {
return nil, fmt.Errorf("get pull request: %w", err)
}
return &vcs.PullRequest{
Number: number,
Title: pr.Title,
Body: pr.Body,
Head: vcs.HeadRef{
SHA: pr.Head.Sha,
Ref: pr.Head.Ref,
},
Base: vcs.BaseRef{
Ref: pr.Base.Ref,
},
}, nil
}
// GetPullRequestDiff is a pass-through to the underlying client.
func (a *Adapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
return a.client.GetPullRequestDiff(ctx, owner, repo, number)
}
// GetPullRequestFiles maps []gitea.ChangedFile to []vcs.ChangedFile.
// Patch field is omitted (zero-value) since Gitea's /pulls/{n}/files does not return patch text.
func (a *Adapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
files, err := a.client.GetPullRequestFiles(ctx, owner, repo, number)
if err != nil {
return nil, err
}
result := make([]vcs.ChangedFile, len(files))
for i, f := range files {
result[i] = vcs.ChangedFile{
Filename: f.Filename,
Status: f.Status,
}
}
return result, nil
}
// GetFileContentAtRef is a pass-through to the underlying client.
func (a *Adapter) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
return a.client.GetFileContentAtRef(ctx, owner, repo, path, ref)
}
// GetCommitStatuses maps []gitea.CommitStatus to []vcs.CommitStatus.
func (a *Adapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) {
statuses, err := a.client.GetCommitStatuses(ctx, owner, repo, sha)
if err != nil {
return nil, err
}
result := make([]vcs.CommitStatus, len(statuses))
for i, s := range statuses {
result[i] = vcs.CommitStatus{
Status: s.Status,
Context: s.Context,
Description: s.Description,
TargetURL: s.TargetURL,
}
}
return result, nil
}
// --- FileReader ---
// GetFileContent delegates to the underlying client, routing to the ref-aware
// variant when ref is non-empty.
func (a *Adapter) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
if ref != "" {
return a.client.GetFileContentRef(ctx, owner, repo, path, ref)
}
return a.client.GetFileContent(ctx, owner, repo, path)
}
// ListContents maps []gitea.ContentEntry to []vcs.ContentEntry.
func (a *Adapter) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
entries, err := a.client.ListContents(ctx, owner, repo, path)
if err != nil {
return nil, err
}
result := make([]vcs.ContentEntry, len(entries))
for i, e := range entries {
result[i] = vcs.ContentEntry{
Name: e.Name,
Path: e.Path,
Type: e.Type,
}
}
return result, nil
}
// --- Reviewer ---
// translateEvent translates a vcs.ReviewEvent (GitHub-canonical) to a Gitea-native event string.
func translateEvent(event vcs.ReviewEvent) string {
switch event {
case vcs.ReviewEventApprove:
return "APPROVED"
case vcs.ReviewEventRequestChanges:
return "REQUEST_CHANGES"
case vcs.ReviewEventComment:
return "COMMENT"
default:
// Unknown events pass through as-is. This is intentional: new event types
// added to vcs.ReviewEvent will still be forwarded without a code change here,
// and Gitea will reject truly invalid values with a clear API error.
return string(event)
}
}
// PostReview translates vcs.ReviewRequest to the Gitea-native format.
// It fetches the PR diff, builds a position-to-line map, and translates each
// ReviewComment.Position (GitHub diff-position) to a Gitea new_position (line number).
func (a *Adapter) PostReview(ctx context.Context, owner, repo string, number int, req vcs.ReviewRequest) (*vcs.Review, error) {
event := translateEvent(req.Event)
var giteaComments []ReviewComment
if len(req.Comments) > 0 {
// Fetch diff to build position → line number map.
// The diff is fetched unconditionally when comments exist. This adds latency
// for reviews with inline comments but keeps the implementation simple — caching
// the diff across calls would add complexity for minimal gain since PostReview
// is called at most once per review cycle.
diff, err := a.client.GetPullRequestDiff(ctx, owner, repo, number)
if err != nil {
return nil, fmt.Errorf("fetch diff for position translation: %w", err)
}
posMap := BuildPositionToLineMap(diff)
for _, c := range req.Comments {
lineNum, err := posMap.Translate(c.Path, c.Position)
if err != nil {
return nil, fmt.Errorf("translate position %d in %s: %w", c.Position, c.Path, err)
}
// CommitID from vcs.ReviewComment is intentionally not forwarded:
// Gitea review comments are pinned to the PR head SHA automatically,
// and the CreatePullReview API has no per-comment commit_id field.
giteaComments = append(giteaComments, ReviewComment{
Path: c.Path,
NewPosition: int64(lineNum),
Body: c.Body,
})
}
}
review, err := a.client.PostReview(ctx, owner, repo, number, event, req.Body, giteaComments)
if err != nil {
return nil, fmt.Errorf("post review: %w", err)
}
return &vcs.Review{
ID: review.ID,
Body: review.Body,
User: vcs.UserInfo{Login: review.User.Login},
State: review.State,
Stale: review.Stale,
CommitID: review.CommitID,
}, nil
}
// ListReviews maps []gitea.Review to []vcs.Review.
func (a *Adapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcs.Review, error) {
reviews, err := a.client.ListReviews(ctx, owner, repo, number)
if err != nil {
return nil, err
}
result := make([]vcs.Review, len(reviews))
for i, r := range reviews {
result[i] = vcs.Review{
ID: r.ID,
Body: r.Body,
User: vcs.UserInfo{Login: r.User.Login},
State: r.State,
Stale: r.Stale,
CommitID: r.CommitID,
}
}
return result, nil
}
// DeleteReview is a pass-through to the underlying client.
func (a *Adapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
return a.client.DeleteReview(ctx, owner, repo, number, reviewID)
}
// DismissReview deletes the review. Gitea supports full deletion of any review state.
// The message parameter is intentionally unused — Gitea deletion has no dismissal message.
func (a *Adapter) DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error {
return a.client.DeleteReview(ctx, owner, repo, number, reviewID)
}
// --- Identity ---
// GetAuthenticatedUser is a pass-through to the underlying client.
func (a *Adapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
return a.client.GetAuthenticatedUser(ctx)
}
-388
View File
@@ -1,388 +0,0 @@
package gitea_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/vcs"
)
func TestAdapter_GetPullRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"title": "Test PR",
"body": "PR body",
"head": map[string]any{
"sha": "abc123",
"ref": "feature-branch",
},
"base": map[string]any{
"ref": "main",
},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
pr, err := adapter.GetPullRequest(context.Background(), "owner", "repo", 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pr.Number != 42 {
t.Errorf("Number = %d, want 42", pr.Number)
}
if pr.Title != "Test PR" {
t.Errorf("Title = %q, want %q", pr.Title, "Test PR")
}
if pr.Body != "PR body" {
t.Errorf("Body = %q, want %q", pr.Body, "PR body")
}
if pr.Head.SHA != "abc123" {
t.Errorf("Head.SHA = %q, want %q", pr.Head.SHA, "abc123")
}
if pr.Head.Ref != "feature-branch" {
t.Errorf("Head.Ref = %q, want %q", pr.Head.Ref, "feature-branch")
}
if pr.Base.Ref != "main" {
t.Errorf("Base.Ref = %q, want %q", pr.Base.Ref, "main")
}
}
func TestAdapter_GetPullRequestFiles(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{"filename": "main.go", "status": "modified"},
{"filename": "new.go", "status": "added"},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
files, err := adapter.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 2 {
t.Fatalf("got %d files, want 2", len(files))
}
if files[0].Filename != "main.go" || files[0].Status != "modified" {
t.Errorf("files[0] = %+v", files[0])
}
if files[1].Filename != "new.go" || files[1].Status != "added" {
t.Errorf("files[1] = %+v", files[1])
}
}
func TestAdapter_ListReviews(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{
"id": 1,
"body": "LGTM",
"user": map[string]any{"login": "reviewer1"},
"state": "APPROVED",
"stale": false,
"commit_id": "abc123",
},
{
"id": 2,
"body": "Needs work",
"user": map[string]any{"login": "reviewer2"},
"state": "REQUEST_CHANGES",
"stale": true,
"commit_id": "def456",
},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
reviews, err := adapter.ListReviews(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(reviews) != 2 {
t.Fatalf("got %d reviews, want 2", len(reviews))
}
if reviews[0].ID != 1 || reviews[0].Body != "LGTM" || reviews[0].User.Login != "reviewer1" {
t.Errorf("reviews[0] = %+v", reviews[0])
}
if reviews[0].State != "APPROVED" || reviews[0].Stale || reviews[0].CommitID != "abc123" {
t.Errorf("reviews[0] state/stale/commit = %v/%v/%v", reviews[0].State, reviews[0].Stale, reviews[0].CommitID)
}
if reviews[1].ID != 2 || !reviews[1].Stale || reviews[1].State != "REQUEST_CHANGES" {
t.Errorf("reviews[1] = %+v", reviews[1])
}
}
func TestAdapter_GetCommitStatuses(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{
"status": "success",
"context": "ci/test",
"description": "All tests pass",
"target_url": "https://ci.example.com/1",
},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
statuses, err := adapter.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(statuses) != 1 {
t.Fatalf("got %d statuses, want 1", len(statuses))
}
if statuses[0].Status != "success" {
t.Errorf("Status = %q, want %q", statuses[0].Status, "success")
}
if statuses[0].Context != "ci/test" {
t.Errorf("Context = %q, want %q", statuses[0].Context, "ci/test")
}
if statuses[0].Description != "All tests pass" {
t.Errorf("Description = %q, want %q", statuses[0].Description, "All tests pass")
}
if statuses[0].TargetURL != "https://ci.example.com/1" {
t.Errorf("TargetURL = %q, want %q", statuses[0].TargetURL, "https://ci.example.com/1")
}
}
func TestAdapter_PostReview_EventTranslation(t *testing.T) {
tests := []struct {
name string
event vcs.ReviewEvent
wantEvent string
}{
{"APPROVE becomes APPROVED", vcs.ReviewEventApprove, "APPROVED"},
{"REQUEST_CHANGES stays", vcs.ReviewEventRequestChanges, "REQUEST_CHANGES"},
{"COMMENT stays", vcs.ReviewEventComment, "COMMENT"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var gotEvent string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var payload struct {
Event string `json:"event"`
}
json.NewDecoder(r.Body).Decode(&payload)
gotEvent = payload.Event
json.NewEncoder(w).Encode(map[string]any{
"id": 1,
"body": "test",
"user": map[string]any{"login": "bot"},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
_, err := adapter.PostReview(context.Background(), "owner", "repo", 1, vcs.ReviewRequest{
Body: "test",
Event: tt.event,
// No comments → no diff fetch needed
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotEvent != tt.wantEvent {
t.Errorf("event = %q, want %q", gotEvent, tt.wantEvent)
}
})
}
}
func TestAdapter_PostReview_WithComments_PositionTranslation(t *testing.T) {
diff := `diff --git a/main.go b/main.go
--- a/main.go
+++ b/main.go
@@ -1,3 +1,4 @@
package main
+// new comment at line 3
func main() {}
`
var gotComments []struct {
Path string `json:"path"`
NewPosition int64 `json:"new_position"`
Body string `json:"body"`
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if strings.HasSuffix(r.URL.Path, ".diff") {
// Diff request
w.Write([]byte(diff))
return
}
if strings.HasSuffix(r.URL.Path, "/reviews") {
// Review post
var payload struct {
Comments []struct {
Path string `json:"path"`
NewPosition int64 `json:"new_position"`
Body string `json:"body"`
} `json:"comments"`
}
json.NewDecoder(r.Body).Decode(&payload)
gotComments = payload.Comments
json.NewEncoder(w).Encode(map[string]any{
"id": 1,
"body": "review",
"user": map[string]any{"login": "bot"},
})
return
}
t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
// Position 4 in this diff is "+// new comment at line 3" → new line 3
_, err := adapter.PostReview(context.Background(), "owner", "repo", 1, vcs.ReviewRequest{
Body: "review",
Event: vcs.ReviewEventRequestChanges,
Comments: []vcs.ReviewComment{
{
Path: "main.go",
Position: 4,
CommitID: "abc123",
Body: "needs fix",
},
},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(gotComments) != 1 {
t.Fatalf("got %d comments, want 1", len(gotComments))
}
if gotComments[0].Path != "main.go" {
t.Errorf("path = %q, want %q", gotComments[0].Path, "main.go")
}
if gotComments[0].NewPosition != 3 {
t.Errorf("new_position = %d, want 3", gotComments[0].NewPosition)
}
if gotComments[0].Body != "needs fix" {
t.Errorf("body = %q, want %q", gotComments[0].Body, "needs fix")
}
}
func TestAdapter_DismissReview(t *testing.T) {
var deleteCalled bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodDelete {
deleteCalled = true
w.WriteHeader(204)
return
}
w.WriteHeader(404)
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
err := adapter.DismissReview(context.Background(), "owner", "repo", 1, 99, "stale review")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !deleteCalled {
t.Error("expected delete to be called")
}
}
func TestAdapter_Underlying(t *testing.T) {
client := gitea.NewClient("http://example.com", "token")
adapter := gitea.NewAdapter(client)
if adapter.Underlying() != client {
t.Error("Underlying() should return the wrapped client")
}
}
func TestAdapter_ListContents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{"name": "main.go", "path": "src/main.go", "type": "file"},
{"name": "util", "path": "src/util", "type": "dir"},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
entries, err := adapter.ListContents(context.Background(), "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 2 {
t.Fatalf("got %d entries, want 2", len(entries))
}
if entries[0].Name != "main.go" || entries[0].Type != "file" {
t.Errorf("entries[0] = %+v", entries[0])
}
if entries[1].Name != "util" || entries[1].Type != "dir" {
t.Errorf("entries[1] = %+v", entries[1])
}
}
func TestAdapter_GetFileContent_RefRouting(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// When ref is provided, the URL should contain ?ref=
if r.URL.RawQuery != "" && strings.Contains(r.URL.RawQuery, "ref=") {
w.Write([]byte("content-at-ref"))
} else {
w.Write([]byte("content-default"))
}
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
// Empty ref → routes to GetFileContent (no ?ref= query param)
got, err := adapter.GetFileContent(context.Background(), "owner", "repo", "main.go", "")
if err != nil {
t.Fatalf("GetFileContent(ref=\"\"): %v", err)
}
if got != "content-default" {
t.Errorf("GetFileContent(ref=\"\") = %q, want %q", got, "content-default")
}
// Non-empty ref → routes to GetFileContentRef (with ?ref= query param)
got, err = adapter.GetFileContent(context.Background(), "owner", "repo", "main.go", "abc123")
if err != nil {
t.Fatalf("GetFileContent(ref=\"abc123\"): %v", err)
}
if got != "content-at-ref" {
t.Errorf("GetFileContent(ref=\"abc123\") = %q, want %q", got, "content-at-ref")
}
}
-3
View File
@@ -86,9 +86,6 @@ type PullRequest struct {
Sha string `json:"sha"` Sha string `json:"sha"`
Ref string `json:"ref"` Ref string `json:"ref"`
} `json:"head"` } `json:"head"`
Base struct {
Ref string `json:"ref"`
} `json:"base"`
} }
// CommitStatus represents a single CI status entry. // CommitStatus represents a single CI status entry.
+18 -3
View File
@@ -1,3 +1,5 @@
//go:build phase2
package gitea_test package gitea_test
import ( import (
@@ -5,6 +7,19 @@ import (
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
) )
// Compile-time interface conformance assertion. // Compile-time interface conformance assertions.
// The Adapter (not the raw Client) satisfies the full vcs.Client interface. // These will verify gitea.Client satisfies vcs interfaces once the Phase 2
var _ vcs.Client = (*gitea.Adapter)(nil) // adapter bridges the method signature gaps:
//
// - PRReader: GetPullRequest returns *gitea.PullRequest (needs *vcs.PullRequest)
// - PRReader: GetPullRequestFiles returns []gitea.ChangedFile (needs []vcs.ChangedFile)
// - FileReader: GetFileContent lacks ref parameter
// - Reviewer: PostReview uses (event, body, comments) instead of vcs.ReviewRequest
//
// Remove the phase2 build tag once the adapter is complete.
var (
_ vcs.PRReader = (*gitea.Client)(nil)
_ vcs.FileReader = (*gitea.Client)(nil)
_ vcs.Reviewer = (*gitea.Client)(nil)
_ vcs.Identity = (*gitea.Client)(nil)
)
-197
View File
@@ -1,197 +0,0 @@
package gitea
import (
"fmt"
"strconv"
"strings"
)
// PositionMap holds a per-file mapping of GitHub diff-position to new-file line number.
// Position is a 1-indexed offset from the @@ hunk header line in the unified diff.
type PositionMap struct {
// files maps filename → (position → new-file line number).
// Deletion lines are mapped to -1 (no new-file line).
// Hunk-header lines are mapped to 0 (no new-file line).
files map[string]map[int]int
// maxPositions caches the highest position number per file,
// tracked during construction to avoid O(n) scans at translate time.
maxPositions map[string]int
}
// Translate converts a GitHub diff-position to a new-file line number for a given file.
// Returns an error if the file is not in the diff or the position is out of range.
// If the position targets a deletion or hunk-header line, it maps to the nearest
// context/addition line below; if no such line exists, returns an error.
func (pm *PositionMap) Translate(file string, position int) (int, error) {
if pm == nil || pm.files == nil {
return 0, fmt.Errorf("empty position map")
}
fileMap, ok := pm.files[file]
if !ok {
return 0, fmt.Errorf("file %q not found in diff", file)
}
if position < 1 {
return 0, fmt.Errorf("position %d out of range (must be >= 1)", position)
}
lineNum, ok := fileMap[position]
if !ok {
return 0, fmt.Errorf("position %d out of range for file %q", position, file)
}
// lineNum == -1 means this position is a deletion line.
// lineNum == 0 means this position is a hunk-header line.
// Both map to the nearest context/addition line below.
if lineNum <= 0 {
maxPos := pm.maxPosition(file)
for p := position + 1; p <= maxPos; p++ {
if ln, exists := fileMap[p]; exists && ln > 0 {
return ln, nil
}
}
if lineNum == 0 {
return 0, fmt.Errorf("position %d targets a hunk-header line with no subsequent new-file line in %q", position, file)
}
return 0, fmt.Errorf("position %d targets a deletion line with no subsequent new-file line in %q", position, file)
}
return lineNum, nil
}
// maxPosition returns the highest position number for a file.
// O(1) — the maximum is tracked during map construction.
func (pm *PositionMap) maxPosition(file string) int {
return pm.maxPositions[file]
}
// BuildPositionToLineMap parses a unified diff and builds a PositionMap
// mapping diff-position → new-file line number per file.
//
// Diff-position counting rules (GitHub spec):
// - The @@ hunk header line is position 1 for the file's first hunk
// - Every subsequent line increments position by 1 — context, additions, AND deletions
// - A new @@ hunk within the same file continues incrementing (does not reset)
// - Position maps to the new file line number for additions and context lines
// - Deletion lines have a position but no new-file line number (stored as -1)
// - Hunk-header lines have a position but no new-file line number (stored as 0)
func BuildPositionToLineMap(diff string) *PositionMap {
pm := &PositionMap{
files: make(map[string]map[int]int),
maxPositions: make(map[string]int),
}
lines := strings.Split(diff, "\n")
var currentFile string
var position int
var newLine int
for _, line := range lines {
// Detect new file in diff.
// "+++ b/" is checked before "+++ /dev/null" — the two prefixes are
// non-overlapping ("+++ /dev/null" does not start with "+++ b/"), so
// ordering is independent. Checking the common case first for clarity.
if strings.HasPrefix(line, "+++ b/") {
currentFile = strings.TrimPrefix(line, "+++ b/")
position = 0
newLine = 0
if pm.files[currentFile] == nil {
pm.files[currentFile] = make(map[int]int)
}
continue
}
// Deleted file: +++ /dev/null means the file is being deleted
if strings.HasPrefix(line, "+++ /dev/null") {
currentFile = ""
continue
}
// Skip --- lines (old file header)
if strings.HasPrefix(line, "--- ") {
continue
}
// Skip diff --git lines
if strings.HasPrefix(line, "diff --git") {
continue
}
// Skip index lines
if strings.HasPrefix(line, "index ") {
continue
}
// Binary file detection
if strings.HasPrefix(line, "Binary files") {
currentFile = ""
continue
}
// Parse hunk headers
if strings.HasPrefix(line, "@@") && currentFile != "" {
position++
pm.files[currentFile][position] = 0 // sentinel: hunk-header has no new-file line
pm.maxPositions[currentFile] = position
newLine = parseHunkStart(line)
continue
}
if currentFile == "" {
continue
}
// Skip "\ No newline at end of file" markers
if strings.HasPrefix(line, `\`) {
continue
}
// Process diff content lines
if strings.HasPrefix(line, "+") {
// Addition: has a new-file line number
position++
pm.files[currentFile][position] = newLine
pm.maxPositions[currentFile] = position
newLine++
} else if strings.HasPrefix(line, "-") {
// Deletion: has a position but no new-file line number
position++
pm.files[currentFile][position] = -1
pm.maxPositions[currentFile] = position
} else if strings.HasPrefix(line, " ") {
// Context line
position++
pm.files[currentFile][position] = newLine
pm.maxPositions[currentFile] = position
newLine++
}
}
return pm
}
// parseHunkStart extracts the new-file starting line number from a hunk header.
// Format: @@ -old_start[,old_count] +new_start[,new_count] @@
func parseHunkStart(hunkLine string) int {
plusIdx := strings.Index(hunkLine, "+")
if plusIdx < 0 {
return 1
}
rest := hunkLine[plusIdx+1:]
endIdx := 0
for endIdx < len(rest) && rest[endIdx] >= '0' && rest[endIdx] <= '9' {
endIdx++
}
if endIdx == 0 {
return 1
}
n, err := strconv.Atoi(rest[:endIdx])
if err != nil {
return 1
}
return n
}
-383
View File
@@ -1,383 +0,0 @@
package gitea
import (
"testing"
)
func TestBuildPositionToLineMap_SingleHunk(t *testing.T) {
// @@ -16,4 +16,5 @@ ← position 1
// context ← position 2, new line 16
//-deleted ← position 3, no new line
//+added ← position 4, new line 17
// context ← position 5, new line 18
diff := `diff --git a/file.go b/file.go
index abc..def 100644
--- a/file.go
+++ b/file.go
@@ -16,4 +16,5 @@ func example() {
context line
-deleted line
+added line
context after
`
pm := BuildPositionToLineMap(diff)
tests := []struct {
pos int
wantLine int
}{
{2, 16}, // context line -> new line 16
{4, 17}, // added line -> new line 17
{5, 18}, // context after -> new line 18
}
for _, tt := range tests {
got, err := pm.Translate("file.go", tt.pos)
if err != nil {
t.Errorf("Translate(file.go, %d): unexpected error: %v", tt.pos, err)
continue
}
if got != tt.wantLine {
t.Errorf("Translate(file.go, %d) = %d, want %d", tt.pos, got, tt.wantLine)
}
}
}
func TestBuildPositionToLineMap_MultipleHunks(t *testing.T) {
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,3 +1,3 @@ package main
line1
-old
+new
@@ -10,3 +10,4 @@ func foo() {
func foo() {
+ // added
return
}
`
pm := BuildPositionToLineMap(diff)
tests := []struct {
pos int
wantLine int
}{
// First hunk: @@ is pos 1
{2, 1}, // " line1" -> new line 1
{4, 2}, // "+new" -> new line 2
// Second hunk: @@ is pos 5 (continues from 4)
// Wait: first hunk has pos 1(@@ hdr), 2(" line1"), 3("-old"), 4("+new")
// Second hunk @@ is pos 5
{6, 10}, // " func foo() {" -> new line 10
{7, 11}, // "+\t// added" -> new line 11
{8, 12}, // " \treturn" -> new line 12
{9, 13}, // " }" -> new line 13
}
for _, tt := range tests {
got, err := pm.Translate("file.go", tt.pos)
if err != nil {
t.Errorf("Translate(file.go, %d): unexpected error: %v", tt.pos, err)
continue
}
if got != tt.wantLine {
t.Errorf("Translate(file.go, %d) = %d, want %d", tt.pos, got, tt.wantLine)
}
}
}
func TestBuildPositionToLineMap_DeletionTargeted(t *testing.T) {
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,4 +1,3 @@ package main
line1
-deleted
line3
`
pm := BuildPositionToLineMap(diff)
// Position 3 is the deletion line "-deleted" — should map to nearest below
// Position 4 is " line3" which is new line 2
got, err := pm.Translate("file.go", 3)
if err != nil {
t.Fatalf("Translate(file.go, 3): unexpected error: %v", err)
}
if got != 2 {
t.Errorf("Translate(file.go, 3) = %d, want 2 (nearest non-deletion below)", got)
}
}
func TestBuildPositionToLineMap_DeletionAtEnd(t *testing.T) {
// If a deletion line is at the end with no subsequent non-deletion line, error
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,3 +1,2 @@ package main
line1
line2
-deleted at end
`
pm := BuildPositionToLineMap(diff)
_, err := pm.Translate("file.go", 4)
if err == nil {
t.Error("expected error for deletion at end with no subsequent line")
}
}
func TestBuildPositionToLineMap_NewFile(t *testing.T) {
diff := `diff --git a/new.go b/new.go
new file mode 100644
--- /dev/null
+++ b/new.go
@@ -0,0 +1,3 @@
+package main
+
+func init() {}
`
pm := BuildPositionToLineMap(diff)
tests := []struct {
pos int
wantLine int
}{
{2, 1}, // "+package main" -> line 1
{3, 2}, // "+" (empty line) -> line 2
{4, 3}, // "+func init() {}" -> line 3
}
for _, tt := range tests {
got, err := pm.Translate("new.go", tt.pos)
if err != nil {
t.Errorf("Translate(new.go, %d): unexpected error: %v", tt.pos, err)
continue
}
if got != tt.wantLine {
t.Errorf("Translate(new.go, %d) = %d, want %d", tt.pos, got, tt.wantLine)
}
}
}
func TestBuildPositionToLineMap_DeletedFile(t *testing.T) {
diff := `diff --git a/old.go b/old.go
deleted file mode 100644
--- a/old.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package main
-
-func old() {}
`
pm := BuildPositionToLineMap(diff)
// Deleted file has no new-file lines; positions should error
_, err := pm.Translate("old.go", 2)
if err == nil {
t.Error("expected error for deleted file position")
}
}
func TestBuildPositionToLineMap_BinaryFile(t *testing.T) {
diff := `diff --git a/image.png b/image.png
Binary files /dev/null and b/image.png differ
diff --git a/code.go b/code.go
--- a/code.go
+++ b/code.go
@@ -1,2 +1,3 @@
package main
+// added
func main() {}
`
pm := BuildPositionToLineMap(diff)
// Binary file should not be in the map
_, err := pm.Translate("image.png", 1)
if err == nil {
t.Error("expected error for binary file")
}
// code.go should still work
got, err := pm.Translate("code.go", 3)
if err != nil {
t.Fatalf("Translate(code.go, 3): unexpected error: %v", err)
}
if got != 2 {
t.Errorf("Translate(code.go, 3) = %d, want 2", got)
}
}
func TestBuildPositionToLineMap_OutOfRange(t *testing.T) {
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,2 +1,2 @@
line1
-old
+new
`
pm := BuildPositionToLineMap(diff)
// Position 0 is invalid
_, err := pm.Translate("file.go", 0)
if err == nil {
t.Error("expected error for position 0")
}
// Position 5 is out of range (only positions 1-4 exist)
_, err = pm.Translate("file.go", 5)
if err == nil {
t.Error("expected error for position 5 (out of range)")
}
// Unknown file
_, err = pm.Translate("unknown.go", 1)
if err == nil {
t.Error("expected error for unknown file")
}
}
func TestBuildPositionToLineMap_MultipleFiles(t *testing.T) {
diff := `diff --git a/a.go b/a.go
--- a/a.go
+++ b/a.go
@@ -1,2 +1,3 @@
package a
+// file a
func aFunc() {}
diff --git a/b.go b/b.go
--- a/b.go
+++ b/b.go
@@ -1,2 +1,3 @@
package b
+// file b
func bFunc() {}
`
pm := BuildPositionToLineMap(diff)
// a.go: pos 3 is "+// file a" -> new line 2
got, err := pm.Translate("a.go", 3)
if err != nil {
t.Fatalf("Translate(a.go, 3): %v", err)
}
if got != 2 {
t.Errorf("Translate(a.go, 3) = %d, want 2", got)
}
// b.go: pos 3 is "+// file b" -> new line 2
// Note: position resets per file
got, err = pm.Translate("b.go", 3)
if err != nil {
t.Fatalf("Translate(b.go, 3): %v", err)
}
if got != 2 {
t.Errorf("Translate(b.go, 3) = %d, want 2", got)
}
}
func TestTranslate_HunkHeaderPosition_SingleHunk(t *testing.T) {
// Position 1 is the @@ hunk-header line.
// It should resolve to the first context/addition line below (new line 16).
diff := `diff --git a/file.go b/file.go
index abc..def 100644
--- a/file.go
+++ b/file.go
@@ -16,4 +16,5 @@ func example() {
context line
-deleted line
+added line
context after
`
pm := BuildPositionToLineMap(diff)
got, err := pm.Translate("file.go", 1)
if err != nil {
t.Fatalf("Translate(file.go, 1): unexpected error: %v", err)
}
if got != 16 {
t.Errorf("Translate(file.go, 1) = %d, want 16 (first context/addition line in hunk)", got)
}
}
func TestTranslate_HunkHeaderPosition_MultiHunk(t *testing.T) {
// First hunk: @@ is pos 1, then " line1" (pos 2), "-old" (pos 3), "+new" (pos 4)
// Second hunk: @@ is pos 5, then " func foo() {" (pos 6), "+// added" (pos 7), etc.
// Translating position 5 (second @@) should resolve to new line 10.
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,3 +1,3 @@ package main
line1
-old
+new
@@ -10,3 +10,4 @@ func foo() {
func foo() {
+ // added
return
}
`
pm := BuildPositionToLineMap(diff)
// Position 5 is the second @@ hunk-header — should resolve to new line 10
got, err := pm.Translate("file.go", 5)
if err != nil {
t.Fatalf("Translate(file.go, 5): unexpected error: %v", err)
}
if got != 10 {
t.Errorf("Translate(file.go, 5) = %d, want 10 (first context/addition line in second hunk)", got)
}
// Also verify first hunk header at position 1 resolves to new line 1
got, err = pm.Translate("file.go", 1)
if err != nil {
t.Fatalf("Translate(file.go, 1): unexpected error: %v", err)
}
if got != 1 {
t.Errorf("Translate(file.go, 1) = %d, want 1 (first context/addition line in first hunk)", got)
}
}
func TestTranslate_HunkHeaderPosition_NewFile(t *testing.T) {
// New file: @@ -0,0 +1,3 @@ is position 1.
// Should resolve to new line 1 (the first addition).
diff := `diff --git a/new.go b/new.go
new file mode 100644
--- /dev/null
+++ b/new.go
@@ -0,0 +1,3 @@
+package main
+
+func init() {}
`
pm := BuildPositionToLineMap(diff)
got, err := pm.Translate("new.go", 1)
if err != nil {
t.Fatalf("Translate(new.go, 1): unexpected error: %v", err)
}
if got != 1 {
t.Errorf("Translate(new.go, 1) = %d, want 1 (first addition line)", got)
}
}
func TestTranslate_HunkHeaderAtEnd(t *testing.T) {
// A hunk-header at the last position with no subsequent new-file line should error.
// This is the hunk-header equivalent of TestBuildPositionToLineMap_DeletionAtEnd.
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,2 +1,2 @@ package main
line1
-old
+new
@@ -10,2 +10,1 @@ func foo() {
-removed
`
pm := BuildPositionToLineMap(diff)
// Position 5 is the second @@ hunk-header; the only line after it (pos 6) is a
// deletion (lineNum == -1), so there's no positive new-file line to resolve to.
// The hunk-header lookup should fail.
_, err := pm.Translate("file.go", 5)
if err == nil {
t.Error("expected error for hunk-header at end with no subsequent new-file line")
}
}
+39 -97
View File
@@ -4,9 +4,7 @@
package github package github
import ( import (
"bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -23,10 +21,6 @@ const (
// maxResponseBytes limits successful response body reads to 10 MiB. // maxResponseBytes limits successful response body reads to 10 MiB.
maxResponseBytes = 10 * 1024 * 1024 maxResponseBytes = 10 * 1024 * 1024
// maxRetryAttempts is the number of times doRequest will attempt a request.
// The retry backoff slice must have length maxRetryAttempts-1.
maxRetryAttempts = 3
) )
// APIError represents an HTTP error response from the GitHub API. // APIError represents an HTTP error response from the GitHub API.
@@ -53,6 +47,13 @@ func (e *APIError) Error() string {
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body) return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
} }
// SafeError returns the error string without response body content,
// suitable for logging in contexts where upstream response data should
// not be exposed.
func (e *APIError) SafeError() string {
return fmt.Sprintf("HTTP %d", e.StatusCode)
}
// IsNotFound reports whether an error is an API 404 response. // IsNotFound reports whether an error is an API 404 response.
func IsNotFound(err error) bool { func IsNotFound(err error) bool {
if apiErr, ok := asAPIError(err); ok { if apiErr, ok := asAPIError(err); ok {
@@ -178,61 +179,43 @@ func (c *Client) SetHTTPClient(hc *http.Client) {
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
CheckRedirect: defaultCheckRedirect, CheckRedirect: defaultCheckRedirect,
} }
} else if hc.CheckRedirect == nil {
// Enforce safe redirect policy when caller provides a client without one.
// The default net/http behavior follows up to 10 redirects and forwards
// all headers (including Authorization) to any host, which can leak
// credentials on cross-host redirects.
hc.CheckRedirect = defaultCheckRedirect
} }
c.httpClient = hc c.httpClient = hc
} }
// SetRetryBackoff configures the retry backoff durations for testing. // SetRetryBackoff configures the retry backoff durations for testing.
// It must be called before any goroutines issue requests. // It must be called before any goroutines issue requests.
// The slice must have exactly maxRetryAttempts-1 entries (one delay per retry gap).
// In production the default {1s, 2s} applies. // In production the default {1s, 2s} applies.
func (c *Client) SetRetryBackoff(d []time.Duration) error { func (c *Client) SetRetryBackoff(d []time.Duration) {
if len(d) != maxRetryAttempts-1 {
return fmt.Errorf("github: backoff length %d does not match maxRetryAttempts-1 (%d)", len(d), maxRetryAttempts-1)
}
c.retryBackoff = d c.retryBackoff = d
return nil
} }
// requestOptions holds per-request configuration for doRequestCore. // doRequest performs an HTTP request with retry on 429 rate limit responses.
type requestOptions struct { // It respects the Retry-After header when present (capped at maxRetryAfter).
// bodyFn returns a fresh io.Reader for the request body on each attempt. // Transport errors (network failures, context cancellation) are not retried.
// Must be non-nil for any request that carries a body (POST, PUT, PATCH, func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
// or DELETE when a body is required by the API). const maxAttempts = 3
// Returning a fresh reader on each call allows retries to re-send the body.
bodyFn func() io.Reader
// accept overrides the default Accept header. Empty means "application/vnd.github+json".
accept string
// extraHeaders are additional headers to set on each request attempt.
extraHeaders map[string]string
}
// doRequestCore is the shared implementation for all HTTP requests with retry
// on 429 rate limit responses. It respects the Retry-After header when present
// (capped at maxRetryAfter). Transport errors are not retried.
func (c *Client) doRequestCore(ctx context.Context, method, reqURL string, opts requestOptions) ([]byte, error) {
const maxRetryAfter = 120 * time.Second const maxRetryAfter = 120 * time.Second
var backoff []time.Duration
if c.retryBackoff != nil {
backoff = make([]time.Duration, len(c.retryBackoff))
copy(backoff, c.retryBackoff)
} else {
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
}
// maxErrorBodyBytes limits how much of an error response body is stored. // maxErrorBodyBytes limits how much of an error response body is stored.
// Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers // Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers
// log APIError.Body directly. Error() further truncates to 200 bytes. // log APIError.Body directly. Error() further truncates to 200 bytes.
const maxErrorBodyBytes = 4 * 1024 const maxErrorBodyBytes = 4 * 1024
// backoff holds per-attempt delays: backoff[i] is the delay before attempt i+1.
// Length must be maxRetryAttempts-1 (one entry per retry gap).
// SetRetryBackoff validates at configuration time; the default is always valid.
defaultBackoff := []time.Duration{1 * time.Second, 2 * time.Second}
var backoff []time.Duration
if c.retryBackoff != nil && len(c.retryBackoff) == maxRetryAttempts-1 {
backoff = make([]time.Duration, len(c.retryBackoff))
copy(backoff, c.retryBackoff)
} else {
backoff = make([]time.Duration, len(defaultBackoff))
copy(backoff, defaultBackoff)
}
// Reject non-HTTPS URLs early since the URL is immutable across retries. // Reject non-HTTPS URLs early since the URL is immutable across retries.
if c.token != "" && !c.allowInsecureHTTP { if c.token != "" && !c.allowInsecureHTTP {
parsed, err := url.Parse(reqURL) parsed, err := url.Parse(reqURL)
@@ -245,7 +228,7 @@ func (c *Client) doRequestCore(ctx context.Context, method, reqURL string, opts
} }
var lastErr error var lastErr error
for attempt := 0; attempt < maxRetryAttempts; attempt++ { for attempt := 0; attempt < maxAttempts; attempt++ {
if attempt > 0 { if attempt > 0 {
var delay time.Duration var delay time.Duration
if attempt-1 < len(backoff) { if attempt-1 < len(backoff) {
@@ -263,11 +246,7 @@ func (c *Client) doRequestCore(ctx context.Context, method, reqURL string, opts
} }
} }
var body io.Reader req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
if opts.bodyFn != nil {
body = opts.bodyFn()
}
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
if err != nil { if err != nil {
return nil, fmt.Errorf("create request: %w", err) return nil, fmt.Errorf("create request: %w", err)
} }
@@ -278,35 +257,29 @@ func (c *Client) doRequestCore(ctx context.Context, method, reqURL string, opts
req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Authorization", "Bearer "+c.token)
} }
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
if opts.accept != "" { if accept != "" {
req.Header.Set("Accept", opts.accept) req.Header.Set("Accept", accept)
} else { } else {
req.Header.Set("Accept", "application/vnd.github+json") req.Header.Set("Accept", "application/vnd.github+json")
} }
for k, v := range opts.extraHeaders {
req.Header.Set(k, v)
}
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
// Transport errors (DNS, TLS, timeout) yield nil resp; no body to close.
return nil, fmt.Errorf("do request: %w", err) return nil, fmt.Errorf("do request: %w", err)
} }
// Capture response metadata before handleResponse takes body ownership. body, done, err := handleResponse(resp, maxResponseBytes, maxErrorBodyBytes)
respStatus := resp.StatusCode
retryAfterHeader := resp.Header.Get("Retry-After")
respBody, done, handleErr := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes)
if done { if done {
return respBody, handleErr return body, err
} }
lastErr = handleErr lastErr = err
// Retry on 429 rate limit // Retry on 429 rate limit
if respStatus == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 { if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 {
// Check for Retry-After header and override backoff if present. // Check for Retry-After header and override backoff if present.
// Supports both integer seconds (common) and HTTP-date format (RFC 7231). // Supports both integer seconds (common) and HTTP-date format (RFC 7231).
if ra := retryAfterHeader; ra != "" { if ra := resp.Header.Get("Retry-After"); ra != "" {
if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 { if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 {
delay := time.Duration(seconds) * time.Second delay := time.Duration(seconds) * time.Second
if delay > maxRetryAfter { if delay > maxRetryAfter {
@@ -338,17 +311,10 @@ func (c *Client) doRequestCore(ctx context.Context, method, reqURL string, opts
return nil, lastErr return nil, lastErr
} }
// doRequest performs an HTTP request with retry on 429 rate limit responses.
// It respects the Retry-After header when present (capped at maxRetryAfter).
// Transport errors (network failures, context cancellation) are not retried.
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
return c.doRequestCore(ctx, method, reqURL, requestOptions{accept: accept})
}
// handleResponse reads and closes the response body, returning the result. // handleResponse reads and closes the response body, returning the result.
// It uses defer to ensure the body is always closed regardless of code path. // It uses defer to ensure the body is always closed regardless of code path.
// Returns (body, done, err) where done=true means the caller should return immediately. // Returns (body, done, err) where done=true means the caller should return immediately.
func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) { func handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) {
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 { if resp.StatusCode >= 200 && resp.StatusCode < 300 {
@@ -357,7 +323,7 @@ func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrByt
return nil, true, fmt.Errorf("read response body: %w", err) return nil, true, fmt.Errorf("read response body: %w", err)
} }
if len(body) > maxRespBytes { if len(body) > maxRespBytes {
return nil, true, fmt.Errorf("response body exceeded %d bytes", maxRespBytes) return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes)
} }
return body, true, nil return body, true, nil
} }
@@ -373,27 +339,3 @@ func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrByt
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) { func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
return c.doRequest(ctx, http.MethodGet, reqURL, "") return c.doRequest(ctx, http.MethodGet, reqURL, "")
} }
// doRequestWithBody is like doRequest but sends a request body.
// It accepts the raw body bytes and sets Content-Type to application/json.
// Retry semantics match doRequest (retries on 429 with Retry-After support).
func (c *Client) doRequestWithBody(ctx context.Context, method, reqURL string, reqBody []byte) ([]byte, error) {
var opts requestOptions
if reqBody != nil {
opts.bodyFn = func() io.Reader { return bytes.NewReader(reqBody) }
opts.extraHeaders = map[string]string{"Content-Type": "application/json"}
}
return c.doRequestCore(ctx, method, reqURL, opts)
}
// doJSONRequest performs an HTTP request with a JSON body and returns the response body.
// It delegates retry/backoff/429 handling to doRequestWithBody.
// This is a general-purpose helper used by any method that needs to send JSON payloads
// (e.g. PostReview, DismissReview).
func (c *Client) doJSONRequest(ctx context.Context, method, reqURL string, payload any) ([]byte, error) {
jsonBody, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal request body: %w", err)
}
return c.doRequestWithBody(ctx, method, reqURL, jsonBody)
}
+36 -88
View File
@@ -2,7 +2,6 @@ package github
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@@ -84,9 +83,7 @@ func TestDoRequest_429Retry(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
body, err := c.doGet(context.Background(), srv.URL+"/test") body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil { if err != nil {
@@ -111,9 +108,7 @@ func TestDoRequest_429ExhaustsRetries(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
_, err := c.doGet(context.Background(), srv.URL+"/test") _, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil { if err == nil {
@@ -223,9 +218,7 @@ func TestDoRequest_429RetryAfterHeader(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
// Use short backoff; Retry-After should override // Use short backoff; Retry-After should override
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
start := time.Now() start := time.Now()
body, err := c.doGet(context.Background(), srv.URL+"/test") body, err := c.doGet(context.Background(), srv.URL+"/test")
@@ -266,9 +259,7 @@ func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
_, err := c.doGet(context.Background(), srv.URL+"/test") _, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil { if err != nil {
@@ -306,9 +297,7 @@ func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
start := time.Now() start := time.Now()
body, err := c.doGet(context.Background(), srv.URL+"/test") body, err := c.doGet(context.Background(), srv.URL+"/test")
@@ -349,9 +338,7 @@ func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second}); err != nil { c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second})
t.Fatalf("SetRetryBackoff: %v", err)
}
start := time.Now() start := time.Now()
_, err := c.doGet(context.Background(), srv.URL+"/test") _, err := c.doGet(context.Background(), srv.URL+"/test")
@@ -568,84 +555,45 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
} }
} }
func TestSetHTTPClient_NilCheckRedirectEnforcesDefault(t *testing.T) {
func TestSetRetryBackoff_RejectsInvalidLength(t *testing.T) {
c := NewClient("token", "https://api.github.com") c := NewClient("token", "https://api.github.com")
// Provide a client with nil CheckRedirect — should get default policy enforced.
// Too short hc := &http.Client{Timeout: 5 * time.Second}
err := c.SetRetryBackoff([]time.Duration{1 * time.Second}) c.SetHTTPClient(hc)
if err == nil { if c.httpClient.CheckRedirect == nil {
t.Fatal("expected error for backoff length 1") t.Fatal("expected CheckRedirect to be enforced when caller provides nil")
} }
if !strings.Contains(err.Error(), "backoff length 1") { if c.httpClient.Timeout != 5*time.Second {
t.Errorf("unexpected error message: %v", err) t.Errorf("expected caller's timeout preserved, got %v", c.httpClient.Timeout)
}
// Too long
err = c.SetRetryBackoff([]time.Duration{1 * time.Second, 2 * time.Second, 3 * time.Second})
if err == nil {
t.Fatal("expected error for backoff length 3")
}
// Correct length succeeds
err = c.SetRetryBackoff([]time.Duration{1 * time.Second, 2 * time.Second})
if err != nil {
t.Fatalf("unexpected error for valid backoff: %v", err)
} }
} }
func TestDoJSONRequest_429Retry(t *testing.T) { func TestSetHTTPClient_PreservesCustomCheckRedirect(t *testing.T) {
attempts := 0 c := NewClient("token", "https://api.github.com")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { called := false
attempts++ hc := &http.Client{
if attempts < 3 { CheckRedirect: func(req *http.Request, via []*http.Request) error {
w.WriteHeader(429) called = true
w.Write([]byte(`{"message":"rate limit exceeded"}`)) return nil
return },
}
w.WriteHeader(200)
w.Write([]byte(`{"id":1}`))
}))
defer ts.Close()
c := NewClient("token", ts.URL, AllowInsecureHTTP())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
t.Fatalf("SetRetryBackoff: %v", err)
} }
c.SetHTTPClient(hc)
body, err := c.doJSONRequest(context.Background(), http.MethodPost, ts.URL+"/test", map[string]string{"key": "val"}) // Invoke the redirect to verify original is preserved
if err != nil { _ = c.httpClient.CheckRedirect(nil, []*http.Request{{}})
t.Fatalf("unexpected error: %v", err) if !called {
} t.Fatal("expected custom CheckRedirect to be preserved")
if attempts != 3 {
t.Errorf("expected 3 attempts, got %d", attempts)
}
if string(body) != `{"id":1}` {
t.Errorf("unexpected body: %s", body)
} }
} }
func TestDoJSONRequest_429ExhaustsRetries(t *testing.T) { func TestAPIError_SafeError(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { e := &APIError{StatusCode: 403, Body: "some sensitive body content"}
w.WriteHeader(429) got := e.SafeError()
w.Write([]byte(`{"message":"rate limit"}`)) if got != "HTTP 403" {
})) t.Errorf("SafeError() = %q, want %q", got, "HTTP 403")
defer ts.Close()
c := NewClient("token", ts.URL, AllowInsecureHTTP())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
t.Fatalf("SetRetryBackoff: %v", err)
} }
// Ensure Error() still includes body
_, err := c.doJSONRequest(context.Background(), http.MethodPost, ts.URL+"/test", map[string]string{"key": "val"}) full := e.Error()
if err == nil { if full != "HTTP 403: some sensitive body content" {
t.Fatal("expected error after exhausting retries") t.Errorf("Error() = %q, unexpected", full)
}
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("expected APIError, got %T: %v", err, err)
}
if apiErr.StatusCode != 429 {
t.Errorf("expected 429, got %d", apiErr.StatusCode)
} }
} }
+6 -4
View File
@@ -5,7 +5,9 @@ import (
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
) )
// Compile-time interface conformance assertion. // Compile-time interface conformance assertions.
// This verifies github.Client satisfies the full vcs.Client interface // These verify github.Client satisfies vcs.PRReader and vcs.FileReader.
// (PRReader, FileReader, Reviewer, Identity). var (
var _ vcs.Client = (*github.Client)(nil) _ vcs.PRReader = (*github.Client)(nil)
_ vcs.FileReader = (*github.Client)(nil)
)
+33 -58
View File
@@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"path"
"strings" "strings"
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
@@ -14,28 +13,25 @@ import (
// GetFileContent fetches a file from a repo at the given ref. // GetFileContent fetches a file from a repo at the given ref.
// Delegates to GetFileContentAtRef with the provided ref. // Delegates to GetFileContentAtRef with the provided ref.
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) { func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
return c.GetFileContentAtRef(ctx, owner, repo, filePath, ref) return c.GetFileContentAtRef(ctx, owner, repo, path, ref)
} }
// GetFileContentAtRef fetches a file at a specific ref from a repo. // GetFileContentAtRef fetches a file at a specific ref from a repo.
// If ref is empty, the query parameter is omitted (uses default branch). // If ref is empty, the query parameter is omitted (uses default branch).
// //
// Returns an error if the path contains dot-segments (".", "..") or // Note: dot-segments ("." and "..") in the path are silently removed to
// attempts to traverse above the repository root. // prevent path traversal. This means a path like "foo/../bar" resolves
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath, ref string) (string, error) { // to "foo/bar" rather than "bar".
escaped, err := escapePath(filePath) func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
if err != nil {
return "", fmt.Errorf("invalid file path: %w", err)
}
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
if ref != "" { if ref != "" {
reqURL += "?ref=" + url.QueryEscape(ref) reqURL += "?ref=" + url.QueryEscape(ref)
} }
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
if err != nil { if err != nil {
return "", fmt.Errorf("fetch file %s: %w", filePath, err) return "", fmt.Errorf("fetch file %s: %w", path, err)
} }
var resp struct { var resp struct {
Content string `json:"content"` Content string `json:"content"`
@@ -45,11 +41,11 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath,
return "", fmt.Errorf("parse file content JSON: %w", err) return "", fmt.Errorf("parse file content JSON: %w", err)
} }
if resp.Encoding != "base64" { if resp.Encoding != "base64" {
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, filePath) return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path)
} }
decoded, err := decodeBase64Content(resp.Content) decoded, err := decodeBase64Content(resp.Content)
if err != nil { if err != nil {
return "", fmt.Errorf("decode base64 content for %s: %w", filePath, err) return "", fmt.Errorf("decode base64 content for %s: %w", path, err)
} }
return decoded, nil return decoded, nil
} }
@@ -59,16 +55,16 @@ func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, filePath,
// If the path points to a single file (not a directory), the API returns // If the path points to a single file (not a directory), the API returns
// a JSON object instead of an array; this is handled by returning a // a JSON object instead of an array; this is handled by returning a
// single-element slice. // single-element slice.
func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string) ([]vcs.ContentEntry, error) { //
escaped, err := escapePath(filePath) // Note: dot-segments ("." and "..") in the path are silently removed to
if err != nil { // prevent path traversal. This means a path like "foo/../bar" resolves
return nil, fmt.Errorf("invalid file path: %w", err) // to "foo/bar" rather than "bar".
} func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escaped) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("list contents %s: %w", filePath, err) return nil, fmt.Errorf("list contents %s: %w", path, err)
} }
type entry struct { type entry struct {
@@ -85,7 +81,7 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string)
if err := json.Unmarshal(body, &entries); err != nil { if err := json.Unmarshal(body, &entries); err != nil {
var single entry var single entry
if err2 := json.Unmarshal(body, &single); err2 != nil { if err2 := json.Unmarshal(body, &single); err2 != nil {
return nil, fmt.Errorf("parse contents JSON: as array: %v; as object: %w", err, err2) return nil, fmt.Errorf("parse contents JSON: as array: %w; as object: %w", err, err2)
} }
// Guard against empty objects ({}) or unexpected shapes that // Guard against empty objects ({}) or unexpected shapes that
// unmarshal successfully but carry no useful data. // unmarshal successfully but carry no useful data.
@@ -106,55 +102,34 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, filePath string)
return result, nil return result, nil
} }
// escapePath validates and encodes a slash-separated file path for use in // escapePath escapes each segment of a relative file path for use in URLs.
// GitHub API URLs. Returns an error if the path contains dot-segments ("." // Slashes are preserved as path separators; other special characters are escaped.
// or "..") or resolves to a path outside the repository root. // Dot-segments ("." and "..") and empty segments (from consecutive slashes like
func escapePath(p string) (string, error) { // "a//b") are silently removed to prevent path traversal and produce canonical
// Reject paths containing dot-segments rather than silently rewriting them. // paths. This is intentional: callers may receive a different path than requested
for _, seg := range strings.Split(p, "/") { // without error. The function is package-private, and all callers
if seg == "." || seg == ".." { // (GetFileContentAtRef, ListContents) already handle missing-file errors from the
return "", fmt.Errorf("path contains dot-segment %q: %s", seg, p) // API if the cleaned path doesn't match what the caller intended.
} func escapePath(p string) string {
} parts := strings.Split(p, "/")
var clean []string
// Use path.Clean for canonical form, then verify it doesn't escape root.
cleaned := path.Clean(p)
if cleaned == "." || strings.HasPrefix(cleaned, "..") {
return "", fmt.Errorf("path resolves outside repository root: %s", p)
}
// Encode each segment individually.
parts := strings.Split(cleaned, "/")
var encoded []string
for _, part := range parts { for _, part := range parts {
if part == "" { if part == "." || part == ".." || part == "" {
continue continue
} }
encoded = append(encoded, url.PathEscape(part)) clean = append(clean, url.PathEscape(part))
} }
return strings.Join(encoded, "/"), nil return strings.Join(clean, "/")
} }
// maxFileContentSize is the maximum decoded file size (10 MB) to prevent
// resource exhaustion when decoding base64 content from the API.
const maxFileContentSize = 10 * 1024 * 1024
// decodeBase64Content decodes base64-encoded content from the GitHub contents API. // decodeBase64Content decodes base64-encoded content from the GitHub contents API.
// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. // GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding.
// Returns an error if the decoded content exceeds maxFileContentSize.
func decodeBase64Content(encoded string) (string, error) { func decodeBase64Content(encoded string) (string, error) {
// GitHub inserts newlines in base64 content
cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded)
// Check estimated decoded size before allocating.
// Base64 encodes 3 bytes into 4 chars, so decoded ~ len*3/4.
if len(cleaned)*3/4 > maxFileContentSize {
return "", fmt.Errorf("file content too large: estimated %d bytes exceeds limit of %d", len(cleaned)*3/4, maxFileContentSize)
}
decoded, err := base64.StdEncoding.DecodeString(cleaned) decoded, err := base64.StdEncoding.DecodeString(cleaned)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(decoded) > maxFileContentSize {
return "", fmt.Errorf("file content too large: %d bytes exceeds limit of %d", len(decoded), maxFileContentSize)
}
return string(decoded), nil return string(decoded), nil
} }
+54 -125
View File
@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
) )
@@ -110,9 +109,7 @@ func TestGetFileContent_429Retry(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
if err != nil { if err != nil {
@@ -230,11 +227,9 @@ func TestListContents_429Retry(t *testing.T) {
c := NewClient("token", srv.URL, AllowInsecureHTTP()) c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client()) c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil { c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
t.Fatalf("SetRetryBackoff: %v", err)
}
entries, err := c.ListContents(context.Background(), "owner", "repo", "src") entries, err := c.ListContents(context.Background(), "owner", "repo", ".")
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@@ -262,6 +257,57 @@ func TestListContents_MalformedJSON(t *testing.T) {
} }
} }
func TestDecodeBase64Content(t *testing.T) {
// Test with newlines (GitHub's format)
encoded := "cGFja2FnZSBt\nYWlu"
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "package main" {
t.Errorf("expected 'package main', got %q", decoded)
}
}
func TestDecodeBase64Content_Invalid(t *testing.T) {
_, err := decodeBase64Content("not!!!valid!!!base64")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestEscapePath_RejectsDotSegments(t *testing.T) {
tests := []struct {
input string
want string
}{
{"src/main.go", "src/main.go"},
{"../etc/passwd", "etc/passwd"},
{"./src/../main.go", "src/main.go"},
{"a/b/c", "a/b/c"},
{"file with spaces.go", "file%20with%20spaces.go"},
{"a/./b/../c", "a/b/c"},
}
for _, tt := range tests {
got := escapePath(tt.input)
if got != tt.want {
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestDecodeBase64Content_CRLF(t *testing.T) {
// Base64 of "hello world" with CRLF line breaks inserted
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "hello world" {
t.Errorf("expected 'hello world', got %q", decoded)
}
}
func TestListContents_SingleFile(t *testing.T) { func TestListContents_SingleFile(t *testing.T) {
// GitHub Contents API returns a JSON object (not array) for single-file paths // GitHub Contents API returns a JSON object (not array) for single-file paths
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -286,120 +332,3 @@ func TestListContents_SingleFile(t *testing.T) {
t.Errorf("expected type 'file', got %q", entries[0].Type) t.Errorf("expected type 'file', got %q", entries[0].Type)
} }
} }
func TestEscapePath_ValidPaths(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
want string
}{
{"simple file", "file.go", "file.go"},
{"nested path", "path/to/file.go", "path/to/file.go"},
{"special chars", "path/to/my file.go", "path/to/my%20file.go"},
{"leading slash stripped", "/path/to/file.go", "path/to/file.go"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := escapePath(tt.path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.want {
t.Errorf("escapePath(%q) = %q, want %q", tt.path, got, tt.want)
}
})
}
}
func TestEscapePath_DotSegments(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
}{
{"single dot", "./file.go"},
{"double dot", "../file.go"},
{"dot in middle", "path/./file.go"},
{"parent traversal", "path/../file.go"},
{"only dots", ".."},
{"nested parent traversal", "a/b/../../c"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := escapePath(tt.path)
if err == nil {
t.Fatalf("expected error for path %q, got nil", tt.path)
}
if !strings.Contains(err.Error(), "dot-segment") {
t.Errorf("expected error about dot-segment, got: %v", err)
}
})
}
}
func TestGetFileContentAtRef_DotSegmentError(t *testing.T) {
// Server should never be called — the error is caught before the request.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatal("server should not have been called")
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "foo/../bar.go", "main")
if err == nil {
t.Fatal("expected error for path with dot-segments")
}
if !strings.Contains(err.Error(), "invalid file path") {
t.Errorf("expected 'invalid file path' error, got: %v", err)
}
}
func TestDecodeBase64Content(t *testing.T) {
// Test with newlines (GitHub's format)
encoded := "cGFja2FnZSBt\nYWlu"
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "package main" {
t.Errorf("expected 'package main', got %q", decoded)
}
}
func TestDecodeBase64Content_Invalid(t *testing.T) {
_, err := decodeBase64Content("not!!!valid!!!base64")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestDecodeBase64Content_CRLF(t *testing.T) {
// Base64 of "hello world" with CRLF line breaks inserted
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
decoded, err := decodeBase64Content(encoded)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if decoded != "hello world" {
t.Errorf("expected 'hello world', got %q", decoded)
}
}
func TestDecodeBase64Content_SizeLimit(t *testing.T) {
t.Parallel()
// Create base64 content that would decode to > maxFileContentSize.
// maxFileContentSize is 10MB. Base64 of 11MB worth of zeros.
// We just need something big enough to trigger the estimated size check.
// 14MB of base64 chars (decodes to ~10.5MB).
huge := strings.Repeat("A", 14*1024*1024)
_, err := decodeBase64Content(huge)
if err == nil {
t.Fatal("expected error for oversized content")
}
if !strings.Contains(err.Error(), "too large") {
t.Errorf("expected 'too large' error, got: %v", err)
}
}
-23
View File
@@ -1,23 +0,0 @@
package github
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
// newTestClient creates a *Client backed by an httptest.Server running the
// given handler. The server is automatically closed when the test finishes.
// Shared across test files in package github.
func newTestClient(t *testing.T, handler http.HandlerFunc) *Client {
t.Helper()
srv := httptest.NewServer(handler)
t.Cleanup(srv.Close)
c := NewClient("test-token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
if err := c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond}); err != nil {
t.Fatalf("SetRetryBackoff: %v", err)
}
return c
}
-29
View File
@@ -1,29 +0,0 @@
package github
import (
"context"
"encoding/json"
"fmt"
)
// userResponse is the GitHub API response for the authenticated user.
type userResponse struct {
Login string `json:"login"`
}
// GetAuthenticatedUser returns the login of the currently authenticated user.
func (c *Client) GetAuthenticatedUser(ctx context.Context) (string, error) {
reqURL := fmt.Sprintf("%s/user", c.baseURL)
body, err := c.doGet(ctx, reqURL)
if err != nil {
return "", fmt.Errorf("get authenticated user: %w", err)
}
var resp userResponse
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("parse user response: %w", err)
}
return resp.Login, nil
}
-46
View File
@@ -1,46 +0,0 @@
package github
import (
"context"
"encoding/json"
"net/http"
"testing"
)
func TestGetAuthenticatedUser_HappyPath(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("expected GET, got %s", r.Method)
}
if r.URL.Path != "/user" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("Authorization") != "Bearer test-token" {
t.Errorf("unexpected auth header: %s", r.Header.Get("Authorization"))
}
json.NewEncoder(w).Encode(map[string]string{"login": "review-bot"})
})
login, err := c.GetAuthenticatedUser(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if login != "review-bot" {
t.Errorf("expected login 'review-bot', got %q", login)
}
}
func TestGetAuthenticatedUser_401(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
})
_, err := c.GetAuthenticatedUser(context.Background())
if err == nil {
t.Fatal("expected error for 401")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
}
}
+16 -26
View File
@@ -51,10 +51,7 @@ type checkRunsResponse struct {
} `json:"check_runs"` } `json:"check_runs"`
} }
// GetPullRequest fetches PR metadata from the GitHub API. // GetPullRequest fetches PR metadata.
// Returns an *APIError wrapping the HTTP status on non-2xx responses (e.g.
// IsNotFound for 404, IsUnauthorized for 401). Network and context errors
// are wrapped but not typed as *APIError.
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) { func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number) reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
@@ -85,15 +82,9 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, num
return string(body), nil return string(body), nil
} }
const ( // maxPages is the upper bound on pagination loops to prevent unbounded iteration
// maxFilesPages is the upper bound on pagination loops for PR file listing, // in case the server returns a full page indefinitely.
// preventing unbounded iteration if the server always returns a full page. const maxPages = 100
maxFilesPages = 100
// maxCheckRunPages is the upper bound on pagination loops for check-run listing,
// preventing unbounded iteration if the server always returns a full page.
maxCheckRunPages = 100
)
// GetPullRequestFiles fetches the list of files changed in a PR. // GetPullRequestFiles fetches the list of files changed in a PR.
// Paginates through all pages (100 per page) to collect all files. // Paginates through all pages (100 per page) to collect all files.
@@ -102,7 +93,7 @@ const (
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) { func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
var allFiles []vcs.ChangedFile var allFiles []vcs.ChangedFile
for page := 1; page <= maxFilesPages; page++ { for page := 1; page <= maxPages; page++ {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d", reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page)
body, err := c.doGet(ctx, reqURL) body, err := c.doGet(ctx, reqURL)
@@ -163,7 +154,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
} }
// Fetch check runs (paginated) // Fetch check runs (paginated)
for checkPage := 1; checkPage <= maxCheckRunPages; checkPage++ { for checkPage := 1; checkPage <= maxPages; checkPage++ {
checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d", checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage) c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage)
checkBody, err := c.doGet(ctx, checkURL) checkBody, err := c.doGet(ctx, checkURL)
@@ -178,7 +169,7 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
result = append(result, vcs.CommitStatus{ result = append(result, vcs.CommitStatus{
Context: cr.Name, Context: cr.Name,
Status: mapCheckRunStatus(cr.Conclusion), Status: mapCheckRunStatus(cr.Conclusion),
Description: "", // check runs have no human-readable description; conclusion is captured in Status Description: derefString(cr.Conclusion),
TargetURL: cr.HTMLURL, TargetURL: cr.HTMLURL,
}) })
} }
@@ -190,17 +181,9 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
return result, nil return result, nil
} }
// mapCheckRunStatus maps a GitHub check run conclusion to a vcs.CommitStatus status string. // mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string.
// Conclusion alone determines the mapped state: nil conclusion means the run is // Conclusion alone determines the mapped state: nil conclusion means the run is
// still in progress (pending), regardless of the status field value. // still in progress (pending), regardless of the status field value.
//
// Mapping rules:
// - nil → "pending" (run still in progress or queued)
// - "success" → "success"
// - "failure", "action_required", "timed_out" → "failure"
// - "cancelled", "skipped", "neutral" → "success" (non-blocking per GitHub check suite semantics)
// - "stale" → "pending" (check run became stale before completing)
// - unknown values → "pending" (conservative: treat unrecognized conclusions as incomplete)
func mapCheckRunStatus(conclusion *string) string { func mapCheckRunStatus(conclusion *string) string {
if conclusion == nil { if conclusion == nil {
// Still running or queued // Still running or queued
@@ -213,10 +196,17 @@ func mapCheckRunStatus(conclusion *string) string {
return "failure" return "failure"
case "cancelled", "skipped", "neutral": case "cancelled", "skipped", "neutral":
return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics
case "stale": case "stale", "waiting":
return "pending" return "pending"
default: default:
return "pending" return "pending"
} }
} }
// derefString safely dereferences a string pointer, returning empty string if nil.
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
-39
View File
@@ -545,7 +545,6 @@ func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
name = *tt.conclusion name = *tt.conclusion
} }
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/status") { if strings.Contains(r.URL.Path, "/status") {
json.NewEncoder(w).Encode(map[string]interface{}{ json.NewEncoder(w).Encode(map[string]interface{}{
@@ -633,44 +632,6 @@ func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
} }
} }
func TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/status"):
// Statuses succeed
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []map[string]string{
{
"context": "ci/build",
"state": "success",
"description": "Build passed",
"target_url": "https://ci.example.com/1",
},
},
})
case strings.Contains(r.URL.Path, "/check-runs"):
// Check runs fail with 500
w.WriteHeader(500)
w.Write([]byte(`{"message":"Internal Server Error"}`))
default:
w.WriteHeader(404)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err == nil {
t.Fatal("expected error when check-runs endpoint fails after statuses succeed")
}
if !strings.Contains(err.Error(), "fetch check runs") {
t.Errorf("expected check runs error, got: %v", err)
}
}
func stringPtr(s string) *string { func stringPtr(s string) *string {
return &s return &s
} }
-212
View File
@@ -1,212 +0,0 @@
package github
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// ErrCannotDeleteSubmittedReview is returned when DeleteReview is called on
// a review that has already been submitted (APPROVED, REQUEST_CHANGES, COMMENT).
// GitHub only allows deletion of PENDING reviews. Callers that need to replace
// a submitted review should use DismissReview instead.
var ErrCannotDeleteSubmittedReview = errors.New("cannot delete submitted review: use DismissReview instead")
// ErrConflictingCommitIDs is returned when PostReview receives comments with
// differing non-empty CommitIDs. The GitHub API accepts only a single commit_id
// per review submission; callers must ensure all comments target the same commit.
var ErrConflictingCommitIDs = errors.New("comments contain conflicting commit IDs: all must target the same commit")
// postReviewRequest is the GitHub API request body for creating a review.
type postReviewRequest struct {
CommitID string `json:"commit_id,omitempty"`
Body string `json:"body"`
Event string `json:"event"`
Comments []reviewCommentEntry `json:"comments,omitempty"`
}
// reviewCommentEntry is a single inline comment in a review creation request.
type reviewCommentEntry struct {
Path string `json:"path"`
Position int `json:"position"`
Body string `json:"body"`
}
// reviewResponse is the GitHub API response for a review.
type reviewResponse struct {
ID int64 `json:"id"`
Body string `json:"body"`
State string `json:"state"`
CommitID string `json:"commit_id"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
// dismissReviewRequest is the GitHub API request body for dismissing a review.
type dismissReviewRequest struct {
Message string `json:"message"`
Event string `json:"event"`
}
// translateGitHubReviewState translates a GitHub API review state to the
// canonical vcs.Review.State value.
func translateGitHubReviewState(state string) string {
switch state {
case "CHANGES_REQUESTED":
return "REQUEST_CHANGES"
case "COMMENTED":
return "COMMENT"
default:
// States like APPROVED, DISMISSED, and PENDING pass through unchanged
// as they already match the canonical vcs representation. PENDING appears
// on draft reviews that have not yet been submitted via the GitHub UI or API.
return state
}
}
// PostReview submits a review on a pull request.
//
// The vcs.ReviewEvent constants (ReviewEventApprove, ReviewEventRequestChanges,
// ReviewEventComment) have string values that match GitHub's wire-format event
// strings (APPROVE, REQUEST_CHANGES, COMMENT), so Event is cast directly to
// string without translation.
//
// ReviewComment.Position maps directly to the GitHub API position field.
// When req.Comments is empty, the payload omits the comments field entirely
// (via the omitempty tag on postReviewRequest.Comments).
//
// The GitHub API accepts a single commit_id per review submission. PostReview
// extracts it from the first comment with a non-empty CommitID. If any subsequent
// comment specifies a different CommitID, PostReview returns ErrConflictingCommitIDs.
// Comments with an empty CommitID are allowed and inherit the review-level value.
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, req vcs.ReviewRequest) (*vcs.Review, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
payload := postReviewRequest{
Body: req.Body,
Event: string(req.Event),
}
// Build the payload in one pass. The GitHub API accepts a single commit_id
// per review; we extract it from the first comment that supplies one and
// reject the request if any other comment disagrees.
for _, comment := range req.Comments {
if comment.CommitID != "" {
if payload.CommitID == "" {
payload.CommitID = comment.CommitID
} else if payload.CommitID != comment.CommitID {
return nil, ErrConflictingCommitIDs
}
}
payload.Comments = append(payload.Comments, reviewCommentEntry{
Path: comment.Path,
Position: comment.Position,
Body: comment.Body,
})
}
data, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal review request: %w", err)
}
body, err := c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
if err != nil {
return nil, fmt.Errorf("post review: %w", err)
}
var resp reviewResponse
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("parse review response: %w", err)
}
return &vcs.Review{
ID: resp.ID,
Body: resp.Body,
User: vcs.UserInfo{Login: resp.User.Login},
State: translateGitHubReviewState(resp.State),
CommitID: resp.CommitID,
}, nil
}
// ListReviews retrieves all reviews for a pull request.
// GitHub review states are translated to canonical vcs values.
func (c *Client) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcs.Review, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL)
if err != nil {
return nil, fmt.Errorf("list reviews: %w", err)
}
var responses []reviewResponse
if err := json.Unmarshal(body, &responses); err != nil {
return nil, fmt.Errorf("parse reviews response: %w", err)
}
reviews := make([]vcs.Review, len(responses))
for i, r := range responses {
reviews[i] = vcs.Review{
ID: r.ID,
Body: r.Body,
User: vcs.UserInfo{Login: r.User.Login},
State: translateGitHubReviewState(r.State),
CommitID: r.CommitID,
}
}
return reviews, nil
}
// DeleteReview deletes a pull request review.
// Only PENDING reviews can be deleted; attempting to delete a submitted review
// (APPROVED, CHANGES_REQUESTED, or COMMENTED per GitHub API naming) returns
// ErrCannotDeleteSubmittedReview.
func (c *Client) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews/%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, reviewID)
// nil body: the GitHub DELETE endpoint for reviews requires no request body.
_, err := c.doRequestWithBody(ctx, http.MethodDelete, reqURL, nil)
if err != nil {
var apiErr *APIError
if errors.As(err, &apiErr) && apiErr.StatusCode == 422 {
return fmt.Errorf("delete review: %w", ErrCannotDeleteSubmittedReview)
}
return fmt.Errorf("delete review: %w", err)
}
return nil
}
// DismissReview dismisses a submitted review on a pull request.
// This is the correct way to "remove" a submitted review (APPROVED, REQUEST_CHANGES).
// GitHub does not allow deleting submitted reviews — they must be dismissed.
func (c *Client) DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews/%d/dismissals",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, reviewID)
payload := dismissReviewRequest{
Message: message,
// Event is required by the GitHub API for dismissal requests, even though
// "DISMISS" is the only valid value for this endpoint.
Event: "DISMISS",
}
data, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal dismiss request: %w", err)
}
_, err = c.doRequestWithBody(ctx, http.MethodPut, reqURL, data)
if err != nil {
return fmt.Errorf("dismiss review: %w", err)
}
return nil
}
-391
View File
@@ -1,391 +0,0 @@
package github
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strings"
"testing"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// --- PostReview tests ---
func TestPostReview_HappyPath(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatalf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/repos/owner/repo/pulls/5/reviews" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %q", r.Header.Get("Content-Type"))
}
// Verify request body
body, _ := io.ReadAll(r.Body)
var req postReviewRequest
if err := json.Unmarshal(body, &req); err != nil {
t.Fatalf("unmarshal request: %v", err)
}
if req.Event != "APPROVE" {
t.Errorf("expected event APPROVE, got %q", req.Event)
}
if req.Body != "LGTM" {
t.Errorf("expected body 'LGTM', got %q", req.Body)
}
if req.CommitID != "abc123" {
t.Errorf("expected commit_id 'abc123', got %q", req.CommitID)
}
if len(req.Comments) != 1 {
t.Fatalf("expected 1 comment, got %d", len(req.Comments))
}
if req.Comments[0].Path != "main.go" {
t.Errorf("expected comment path 'main.go', got %q", req.Comments[0].Path)
}
if req.Comments[0].Position != 4 {
t.Errorf("expected comment position 4, got %d", req.Comments[0].Position)
}
json.NewEncoder(w).Encode(map[string]interface{}{
"id": 100,
"body": "LGTM",
"state": "APPROVED",
"commit_id": "abc123",
"user": map[string]string{"login": "reviewer"},
})
})
review, err := c.PostReview(context.Background(), "owner", "repo", 5, vcs.ReviewRequest{
Body: "LGTM",
Event: vcs.ReviewEventApprove,
Comments: []vcs.ReviewComment{
{Path: "main.go", Position: 4, CommitID: "abc123", Body: "nit: rename"},
},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if review.ID != 100 {
t.Errorf("expected ID 100, got %d", review.ID)
}
if review.Body != "LGTM" {
t.Errorf("expected body 'LGTM', got %q", review.Body)
}
if review.State != "APPROVED" {
t.Errorf("expected state 'APPROVED', got %q", review.State)
}
if review.User.Login != "reviewer" {
t.Errorf("expected user 'reviewer', got %q", review.User.Login)
}
if review.CommitID != "abc123" {
t.Errorf("expected commit_id 'abc123', got %q", review.CommitID)
}
}
func TestPostReview_401(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
})
_, err := c.PostReview(context.Background(), "owner", "repo", 5, vcs.ReviewRequest{
Body: "LGTM",
Event: vcs.ReviewEventApprove,
})
if err == nil {
t.Fatal("expected error for 401")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
}
}
func TestPostReview_422(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
w.Write([]byte(`{"message":"Unprocessable Entity"}`))
})
_, err := c.PostReview(context.Background(), "owner", "repo", 5, vcs.ReviewRequest{
Body: "LGTM",
Event: vcs.ReviewEventApprove,
})
if err == nil {
t.Fatal("expected error for 422")
}
// 422 should surface as a wrapped APIError
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("expected *APIError, got %T: %v", err, err)
}
if apiErr.StatusCode != 422 {
t.Errorf("expected status 422, got %d", apiErr.StatusCode)
}
}
func TestPostReview_MalformedResponse(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`not json`))
})
_, err := c.PostReview(context.Background(), "owner", "repo", 5, vcs.ReviewRequest{
Body: "LGTM",
Event: vcs.ReviewEventApprove,
})
if err == nil {
t.Fatal("expected error for malformed response")
}
if !strings.Contains(err.Error(), "parse review response") {
t.Errorf("expected parse error, got: %v", err)
}
}
// --- ListReviews tests ---
func TestListReviews_HappyPath(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Fatalf("expected GET, got %s", r.Method)
}
if r.URL.Path != "/repos/owner/repo/pulls/3/reviews" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
json.NewEncoder(w).Encode([]map[string]interface{}{
{
"id": 1,
"body": "Approved",
"state": "APPROVED",
"commit_id": "sha1",
"user": map[string]string{"login": "user1"},
},
{
"id": 2,
"body": "Needs work",
"state": "CHANGES_REQUESTED",
"commit_id": "sha2",
"user": map[string]string{"login": "user2"},
},
{
"id": 3,
"body": "Comment only",
"state": "COMMENTED",
"commit_id": "sha3",
"user": map[string]string{"login": "user3"},
},
{
"id": 4,
"body": "Old review",
"state": "DISMISSED",
"commit_id": "sha4",
"user": map[string]string{"login": "user4"},
},
})
})
reviews, err := c.ListReviews(context.Background(), "owner", "repo", 3)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(reviews) != 4 {
t.Fatalf("expected 4 reviews, got %d", len(reviews))
}
// Check state translation
expected := []struct {
id int64
state string
}{
{1, "APPROVED"},
{2, "REQUEST_CHANGES"},
{3, "COMMENT"},
{4, "DISMISSED"},
}
for i, e := range expected {
if reviews[i].ID != e.id {
t.Errorf("review[%d]: expected ID %d, got %d", i, e.id, reviews[i].ID)
}
if reviews[i].State != e.state {
t.Errorf("review[%d]: expected state %q, got %q", i, e.state, reviews[i].State)
}
}
}
func TestListReviews_404(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
})
_, err := c.ListReviews(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
if !IsNotFound(err) {
t.Errorf("expected IsNotFound=true, got error: %v", err)
}
}
func TestListReviews_401(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
})
_, err := c.ListReviews(context.Background(), "owner", "repo", 3)
if err == nil {
t.Fatal("expected error for 401")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
}
}
// --- DeleteReview tests ---
func TestDeleteReview_HappyPath(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Fatalf("expected DELETE, got %s", r.Method)
}
if r.URL.Path != "/repos/owner/repo/pulls/5/reviews/42" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
w.WriteHeader(204)
})
err := c.DeleteReview(context.Background(), "owner", "repo", 5, 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDeleteReview_422_SubmittedReview(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(422)
w.Write([]byte(`{"message":"Can not delete a non pending review"}`))
})
err := c.DeleteReview(context.Background(), "owner", "repo", 5, 42)
if err == nil {
t.Fatal("expected error for 422")
}
if !errors.Is(err, ErrCannotDeleteSubmittedReview) {
t.Errorf("expected ErrCannotDeleteSubmittedReview, got: %v", err)
}
}
// --- DismissReview tests ---
func TestDismissReview_HappyPath(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Fatalf("expected PUT, got %s", r.Method)
}
if r.URL.Path != "/repos/owner/repo/pulls/5/reviews/10/dismissals" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
body, _ := io.ReadAll(r.Body)
var req dismissReviewRequest
if err := json.Unmarshal(body, &req); err != nil {
t.Fatalf("unmarshal request: %v", err)
}
if req.Message != "Superseded by new review" {
t.Errorf("expected message 'Superseded by new review', got %q", req.Message)
}
if req.Event != "DISMISS" {
t.Errorf("expected event 'DISMISS', got %q", req.Event)
}
json.NewEncoder(w).Encode(map[string]interface{}{
"id": 10,
"state": "DISMISSED",
})
})
err := c.DismissReview(context.Background(), "owner", "repo", 5, 10, "Superseded by new review")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDismissReview_404(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
})
err := c.DismissReview(context.Background(), "owner", "repo", 5, 999, "dismiss")
if err == nil {
t.Fatal("expected error for 404")
}
if !IsNotFound(err) {
t.Errorf("expected IsNotFound=true, got error: %v", err)
}
}
func TestDismissReview_401(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
})
err := c.DismissReview(context.Background(), "owner", "repo", 5, 10, "dismiss")
if err == nil {
t.Fatal("expected error for 401")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
}
}
// --- State translation tests ---
func TestTranslateGitHubReviewState(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"approved passes through", "APPROVED", "APPROVED"},
{"changes_requested maps to REQUEST_CHANGES", "CHANGES_REQUESTED", "REQUEST_CHANGES"},
{"commented maps to COMMENT", "COMMENTED", "COMMENT"},
{"dismissed passes through", "DISMISSED", "DISMISSED"},
{"unknown state passes through", "UNKNOWN_STATE", "UNKNOWN_STATE"},
{"empty string passes through", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := translateGitHubReviewState(tt.input)
if got != tt.want {
t.Errorf("translateGitHubReviewState(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestPostReview_ConflictingCommitIDs(t *testing.T) {
c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Fatal("request should not be sent when commit IDs conflict")
})
_, err := c.PostReview(context.Background(), "owner", "repo", 5, vcs.ReviewRequest{
Body: "Review",
Event: vcs.ReviewEventComment,
Comments: []vcs.ReviewComment{
{Path: "a.go", Position: 1, CommitID: "sha-1", Body: "first"},
{Path: "b.go", Position: 2, CommitID: "sha-2", Body: "second"},
},
})
if err == nil {
t.Fatal("expected error for conflicting commit IDs")
}
if !errors.Is(err, ErrConflictingCommitIDs) {
t.Errorf("expected ErrConflictingCommitIDs, got: %v", err)
}
}
+1 -1
View File
@@ -2,4 +2,4 @@ module gitea.weiker.me/rodin/review-bot
go 1.26.2 go 1.26.2
require github.com/goccy/go-yaml v1.19.2 require gopkg.in/yaml.v3 v3.0.1
+4 -2
View File
@@ -1,2 +1,4 @@
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+12
View File
@@ -10,6 +10,18 @@ func FormatMarkdown(result *ReviewResult, reviewerName string) string {
return FormatMarkdownWithDisplay(result, reviewerName, reviewerName) return FormatMarkdownWithDisplay(result, reviewerName, reviewerName)
} }
// GiteaEvent converts the verdict to the Gitea API event string.
func GiteaEvent(verdict string) string {
switch verdict {
case "APPROVE":
return "APPROVED"
case "REQUEST_CHANGES":
return "REQUEST_CHANGES"
default:
return "COMMENT"
}
}
// FormatMarkdownWithDisplay formats a ReviewResult with separate display name and sentinel name. // FormatMarkdownWithDisplay formats a ReviewResult with separate display name and sentinel name.
// Note: displayName is not HTML-escaped as Gitea sanitizes rendered Markdown. // Note: displayName is not HTML-escaped as Gitea sanitizes rendered Markdown.
// Persona display names are controlled by repo owners (trusted input). // Persona display names are controlled by repo owners (trusted input).
+19
View File
@@ -98,6 +98,25 @@ func TestFormatMarkdown_SpecialChars(t *testing.T) {
} }
} }
func TestGiteaEvent(t *testing.T) {
tests := []struct {
verdict string
expected string
}{
{"APPROVE", "APPROVED"},
{"REQUEST_CHANGES", "REQUEST_CHANGES"},
{"UNKNOWN", "COMMENT"},
{"", "COMMENT"},
}
for _, tc := range tests {
got := GiteaEvent(tc.verdict)
if got != tc.expected {
t.Errorf("GiteaEvent(%q) = %q, want %q", tc.verdict, got, tc.expected)
}
}
}
func TestFormatMarkdown_Sentinel(t *testing.T) { func TestFormatMarkdown_Sentinel(t *testing.T) {
result := &ReviewResult{ result := &ReviewResult{
Verdict: "APPROVE", Verdict: "APPROVE",
+38 -146
View File
@@ -5,15 +5,12 @@ import (
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"sort" "sort"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"github.com/goccy/go-yaml" "gopkg.in/yaml.v3"
"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/parser"
) )
//go:embed personas/*.yaml //go:embed personas/*.yaml
@@ -121,7 +118,9 @@ func ListBuiltinPersonas() []string {
default: default:
continue continue
} }
seen[personaName] = true if !seen[personaName] {
seen[personaName] = true
}
} }
names := make([]string, 0, len(seen)) names := make([]string, 0, len(seen))
for name := range seen { for name := range seen {
@@ -143,19 +142,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth) err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth)
} else { } else {
// Use json.Decoder with DisallowUnknownFields for consistency with // Use json.Decoder with DisallowUnknownFields for consistency with
// YAML's Strict() - both reject unknown fields to catch typos. // YAML's KnownFields(true) - both reject unknown fields to catch typos.
dec := json.NewDecoder(bytes.NewReader(data)) dec := json.NewDecoder(bytes.NewReader(data))
dec.DisallowUnknownFields() dec.DisallowUnknownFields()
err = dec.Decode(&p) err = dec.Decode(&p)
if err == nil {
// Reject trailing content after the first valid JSON object.
// Without this check, input like `{"name":"x"}garbage` would
// silently succeed because Decoder stops after one object.
var dummy json.RawMessage
if err2 := dec.Decode(&dummy); err2 != io.EOF {
err = fmt.Errorf("unexpected trailing content after JSON object")
}
}
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("parse persona %s: %w", source, err) return nil, fmt.Errorf("parse persona %s: %w", source, err)
@@ -166,164 +156,70 @@ func parsePersona(data []byte, source string) (*Persona, error) {
return &p, nil return &p, nil
} }
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks: // unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion. // and strict field checking. This protects against stack exhaustion from deeply
// - Multi-document rejection: prevents silent data loss from ignored extra documents. // nested structures and catches typos in field names.
// - Strict field checking: rejects unknown YAML keys to catch typos early. // Multi-document YAML files are rejected to prevent silent data loss.
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error { func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
// First pass: parse into AST to check depth limits, node counts, and // First pass: decode into a yaml.Node to check depth limits and node counts.
// multi-document rejection. This prevents stack exhaustion before we // This prevents stack exhaustion before we attempt to decode into structs.
// attempt to decode into structs. var node yaml.Node
file, err := parser.ParseBytes(data, 0) dec := yaml.NewDecoder(bytes.NewReader(data))
if err != nil { if err := dec.Decode(&node); err != nil {
return err return err
} }
// Reject empty YAML input (whitespace-only, comment-only, or truly empty files).
// The parser returns a single doc with nil body for these cases.
if len(file.Docs) == 0 || file.Docs[0].Body == nil {
return fmt.Errorf("empty YAML document")
}
// Reject multi-document YAML files - silently ignoring additional documents // Reject multi-document YAML files - silently ignoring additional documents
// could lead to confusing behavior where users think their changes take effect. // could lead to confusing behavior where users think their changes take effect.
if len(file.Docs) > 1 { var extra yaml.Node
if dec.Decode(&extra) == nil {
return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed") return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed")
} }
nodeCount := 0 nodeCount := 0
if err := checkYAMLDepth(file.Docs[0].Body, 0, maxDepth, MaxYAMLNodes, make(map[ast.Node]int), make(map[ast.Node]bool), &nodeCount); err != nil { if err := checkYAMLDepth(&node, 0, maxDepth, MaxYAMLNodes, make(map[*yaml.Node]struct{}), &nodeCount); err != nil {
return err return err
} }
// Second pass: decode with strict field checking enabled. // Second pass: decode with strict field checking enabled.
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy". // KnownFields(true) rejects unknown keys, catching typos like "focuss" or "identiy".
// // We must re-decode from the original data because yaml.Node.Decode() doesn't
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases // support the KnownFields option.
// recursively — it resolves them via the pre-built AST, which our first strictDec := yaml.NewDecoder(bytes.NewReader(data))
// pass already depth-checked. Alias chains that would exceed depth limits strictDec.KnownFields(true)
// are caught above; the decoder merely reads the resolved scalar values. return strictDec.Decode(out)
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
return dec.Decode(out)
} }
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth // checkYAMLDepth recursively checks that YAML nodes don't exceed the depth limit
// limit or the total node count limit. It uses two tracking maps: // or the total node count limit. It also detects alias cycles to prevent infinite
// - validated: maps each node to the maximum depth at which it was previously // recursion from crafted YAML with self-referential aliases.
// checked. If a node is revisited at a deeper depth (e.g., via an alias), func checkYAMLDepth(node *yaml.Node, depth, maxDepth, maxNodes int, seen map[*yaml.Node]struct{}, nodeCount *int) error {
// we re-check it to ensure the combined effective depth doesn't exceed limits.
// - visiting: per-path recursion stack for true cycle detection. A node on the
// current path is a cycle (alias loop); we return nil to avoid infinite recursion.
//
// This design prevents the alias depth bypass where an anchored subtree validated
// at a shallow depth could be referenced via alias at a greater depth, effectively
// exceeding MaxYAMLDepth.
func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[ast.Node]int, visiting map[ast.Node]bool, nodeCount *int) error {
if node == nil {
return nil
}
if depth > maxDepth { if depth > maxDepth {
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth) return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
} }
// Cycle detection: if we're currently visiting this node on the current
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
// Return nil to break the cycle without error — cycles are a structural
// property, not a depth violation.
if visiting[node] {
return nil
}
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks. // Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
// Placed after cycle detection but before the depth-aware short-circuit. This means
// nodes revisited at shallower depths (via aliases) are counted each time they are
// encountered — intentional conservative overcounting. This bounds the total work
// performed during validation rather than tracking unique nodes, which is the safer
// security posture for untrusted YAML input.
*nodeCount++ *nodeCount++
if *nodeCount > maxNodes { if *nodeCount > maxNodes {
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes) return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
} }
// Depth-aware short-circuit: skip re-validation only when the current visit // Cycle detection: if we've seen this node before, we're in a cycle.
// depth is the same or shallower than the depth at which this node was if _, ok := seen[node]; ok {
// previously validated. A shallower (or equal) current depth means the return nil // Already validated this subtree, skip to avoid infinite recursion.
// prior, deeper validation already covered any subtree depth violations.
// If the current depth exceeds the previous validation depth (e.g., an alias
// references this node deeper in the tree), we must re-traverse to ensure
// the combined effective depth doesn't exceed maxDepth.
//
// Note: using ast.Node (interface) as map key relies on pointer identity,
// which is correct because all goccy/go-yaml AST node types are pointer
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
return nil
} }
validated[node] = depth seen[node] = struct{}{}
// Mark as visiting (on the current recursion path) for cycle detection. // Handle alias nodes: follow the alias to its anchor target.
visiting[node] = true // Increment depth when following aliases since they expand the effective structure.
defer func() { visiting[node] = false }() if node.Kind == yaml.AliasNode && node.Alias != nil {
return checkYAMLDepth(node.Alias, depth+1, maxDepth, maxNodes, seen, nodeCount)
}
// Walk children based on node type. for _, child := range node.Content {
switch n := node.(type) { if err := checkYAMLDepth(child, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil {
case *ast.MappingNode:
for _, value := range n.Values {
if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
}
case *ast.MappingValueNode:
// Both Key and Value are visited at depth+1 relative to this
// MappingValueNode. Since MappingNode visits its MappingValueNode
// children at depth+1 as well, keys and values end up at depth+2
// from the parent MappingNode. This is intentional: it mirrors the
// actual nesting structure (mapping → key-value pair → key/value).
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err return err
} }
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.SequenceNode:
for _, value := range n.Values {
if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
}
case *ast.AliasNode:
// Follow alias to its target, incrementing depth since aliases expand
// the effective structure.
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.AnchorNode:
// Increment depth for anchor values as a conservative measure: the
// anchor definition itself is structural, and treating it as a depth
// level ensures that deeply nested anchors are caught at definition
// time rather than only when referenced via alias. This +1 is
// asymmetric with alias (which also increments) — by design, the
// effective depth budget for anchored-then-aliased content is reduced
// because both the definition site and the reference site each consume
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.TagNode:
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
return err
}
case *ast.MergeKeyNode:
// MergeKeyNode represents the literal "<<" merge key token. It has no
// child nodes — the value side of a merge (e.g., *alias) lives in the
// parent MappingValueNode.Value, which is already recursed into above.
// Explicitly listed here (rather than in the default case) to prevent
// future library changes from silently bypassing depth checks.
default:
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
// recurse into.
} }
return nil return nil
} }
@@ -331,11 +227,7 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
// ParsePersonaBytes parses persona data from bytes with a source label for errors. // ParsePersonaBytes parses persona data from bytes with a source label for errors.
// This is useful for parsing personas fetched from external sources (e.g., Gitea API) // This is useful for parsing personas fetched from external sources (e.g., Gitea API)
// without requiring filesystem access. Format is detected by source extension. // without requiring filesystem access. Format is detected by source extension.
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
func ParsePersonaBytes(data []byte, source string) (*Persona, error) { func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
if len(data) > MaxPersonaFileSize {
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
}
return parsePersona(data, source) return parsePersona(data, source)
} }
+41 -222
View File
@@ -7,7 +7,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/goccy/go-yaml/ast" "gopkg.in/yaml.v3"
) )
func TestLoadBuiltinPersona(t *testing.T) { func TestLoadBuiltinPersona(t *testing.T) {
@@ -459,14 +459,7 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
path := filepath.Join(dir, "deeply-nested.yaml") path := filepath.Join(dir, "deeply-nested.yaml")
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20). // Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
// Depth accumulation trace for "nested: \n level0: \n level1: ...": // Each level adds 2 to the depth count (key + value mapping).
// - Document root parsed at depth 0
// - Root MappingNode children (MappingValueNodes) visited at depth 1
// - "nested" MappingValueNode: key at depth 2, value at depth 2
// - Each levelN adds depth via MappingValueNode traversal (key + value)
// - Exact depth per level depends on AST structure (MappingNode wrapping),
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
var sb strings.Builder var sb strings.Builder
sb.WriteString("name: test\nidentity: test\nnested:\n") sb.WriteString("name: test\nidentity: test\nnested:\n")
indent := " " indent := " "
@@ -490,35 +483,6 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
} }
} }
func TestYAMLEmptyFileRejection(t *testing.T) {
tests := []struct {
name string
content string
}{
{"completely_empty", ""},
{"whitespace_only", " \n\n "},
{"comment_only", "# just a comment\n"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, tc.name+".yaml")
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for empty YAML input, got nil")
}
if !strings.Contains(err.Error(), "empty YAML document") {
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
}
})
}
}
func TestYAMLFileSizeLimit(t *testing.T) { func TestYAMLFileSizeLimit(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
path := filepath.Join(dir, "huge.yaml") path := filepath.Join(dir, "huge.yaml")
@@ -540,41 +504,41 @@ func TestYAMLFileSizeLimit(t *testing.T) {
func TestYAMLAliasCycleDetection(t *testing.T) { func TestYAMLAliasCycleDetection(t *testing.T) {
// Test that our checkYAMLDepth function handles alias cycles gracefully // Test that our checkYAMLDepth function handles alias cycles gracefully
// by using the visiting map to prevent infinite recursion. // by using the seen map to prevent infinite recursion.
// We test this directly because go-yaml's parser handles most cycles
// at parse time, but we need to ensure our checker is robust.
// Create a node structure where an alias points to a parent node, // Create a node structure where an alias points to a parent node,
// simulating what could happen with crafted input. // simulating what could happen with malicious input that bypasses
parent := &ast.MappingNode{ // go-yaml's cycle detection.
Values: []*ast.MappingValueNode{ parent := &yaml.Node{
{ Kind: yaml.MappingNode,
Key: &ast.StringNode{Value: "name"}, Content: []*yaml.Node{
Value: &ast.StringNode{Value: "test"}, {Kind: yaml.ScalarNode, Value: "name"},
}, {Kind: yaml.ScalarNode, Value: "test"},
{Kind: yaml.ScalarNode, Value: "nested"},
}, },
} }
// Create a child that aliases back to the parent (artificial cycle) // Create a child that aliases back to the parent (artificial cycle)
aliasToParent := &ast.AliasNode{ aliasToParent := &yaml.Node{
Value: parent, Kind: yaml.AliasNode,
Alias: parent,
} }
parent.Values = append(parent.Values, &ast.MappingValueNode{ parent.Content = append(parent.Content, aliasToParent)
Key: &ast.StringNode{Value: "nested"},
Value: aliasToParent,
})
nodeCount := 0 nodeCount := 0
validated := make(map[ast.Node]int) seen := make(map[*yaml.Node]struct{})
visiting := make(map[ast.Node]bool)
// This should NOT hang or stack overflow - cycle detection prevents infinite recursion // This should NOT hang or stack overflow - the seen map prevents infinite recursion
err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount) err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount)
if err != nil { if err != nil {
t.Errorf("unexpected error traversing cyclic structure: %v", err) t.Errorf("unexpected error traversing cyclic structure: %v", err)
} }
// Verify we tracked the parent in the validated map // Verify we tracked the parent in the seen map
if _, ok := validated[parent]; !ok { if _, ok := seen[parent]; !ok {
t.Error("parent node not tracked in validated map") t.Error("parent node not tracked in seen map")
} }
} }
@@ -630,82 +594,36 @@ func TestYAMLNodeCountLimit(t *testing.T) {
func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) { func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) {
// Direct test of cycle detection in checkYAMLDepth by creating // Direct test of cycle detection in checkYAMLDepth by creating
// a node structure with an artificial cycle. // a node structure with an artificial cycle.
node := &ast.MappingNode{ // This tests the seen map logic independent of go-yaml's parsing.
Values: []*ast.MappingValueNode{ node := &yaml.Node{
{ Kind: yaml.MappingNode,
Key: &ast.StringNode{Value: "key"}, Content: []*yaml.Node{
Value: &ast.StringNode{Value: "value"}, {Kind: yaml.ScalarNode, Value: "key"},
}, {Kind: yaml.ScalarNode, Value: "value"},
}, },
} }
// Create a cycle by making a child reference the parent // Create a cycle by making a child reference the parent
cycleChild := &ast.AliasNode{ cycleChild := &yaml.Node{
Value: node, // Points back to the parent Kind: yaml.AliasNode,
Alias: node, // Points back to the parent
} }
node.Values = append(node.Values, &ast.MappingValueNode{ node.Content = append(node.Content,
Key: &ast.StringNode{Value: "cyclic"}, &yaml.Node{Kind: yaml.ScalarNode, Value: "cyclic"},
Value: cycleChild, cycleChild,
}) )
nodeCount := 0 nodeCount := 0
validated := make(map[ast.Node]int) seen := make(map[*yaml.Node]struct{})
visiting := make(map[ast.Node]bool) err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount)
err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount)
// Should complete without infinite recursion due to cycle detection // Should complete without infinite recursion due to cycle detection
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
// The validated map should contain multiple entries // The seen map should contain multiple entries
if len(validated) < 2 { if len(seen) < 2 {
t.Errorf("validated map has %d entries, expected at least 2", len(validated)) t.Errorf("seen map has %d entries, expected at least 2", len(seen))
}
}
func TestYAMLAliasDepthBypass(t *testing.T) {
// Test that an anchored subtree first validated at a shallow depth is
// re-checked when referenced via alias at a deeper position. Without the
// depth-aware validated map, the alias reference would skip re-checking
// and allow the effective nesting to exceed MaxYAMLDepth.
dir := t.TempDir()
path := filepath.Join(dir, "alias-depth-bypass.yaml")
// Build YAML with an anchor at shallow depth containing a subtree near the limit,
// then reference it via alias deep enough that effective depth exceeds MaxYAMLDepth.
var sb strings.Builder
sb.WriteString("name: test\nidentity: test\n")
// Create the anchored subtree at depth 1 (key level) that nests 15 levels deep.
sb.WriteString("anchor_key: &deep_anchor\n")
for i := 0; i < 15; i++ {
sb.WriteString(strings.Repeat(" ", i+1))
sb.WriteString(fmt.Sprintf("level%d:\n", i))
}
sb.WriteString(strings.Repeat(" ", 16))
sb.WriteString("leaf: value\n")
// Create a wrapper that nests 6 levels deep, then references the anchor.
// Effective depth at alias target = 6 (wrapper nesting) + 1 (alias) + 15 (subtree) = 22 > 20
sb.WriteString("wrapper:\n")
for i := 0; i < 6; i++ {
sb.WriteString(strings.Repeat(" ", i+1))
sb.WriteString(fmt.Sprintf("n%d:\n", i))
}
sb.WriteString(strings.Repeat(" ", 7))
sb.WriteString("alias_ref: *deep_anchor\n")
if err := os.WriteFile(path, []byte(sb.String()), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for alias depth bypass, got nil")
}
if !strings.Contains(err.Error(), "nesting depth exceeds") {
t.Errorf("error = %q, want containing 'nesting depth exceeds'", err.Error())
} }
} }
@@ -858,102 +776,3 @@ identity: test identity
t.Errorf("Name = %q, want %q", p.Name, "test") t.Errorf("Name = %q, want %q", p.Name, "test")
} }
} }
func TestJSONTrailingContentRejected(t *testing.T) {
tests := []struct {
name string
content string
}{
{
name: "trailing garbage after object",
content: `{"name":"test","identity":"test identity"}garbage`,
},
{
name: "two JSON objects",
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
},
{
name: "trailing array",
content: `{"name":"test","identity":"test identity"}[]`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := LoadPersona(path)
if err == nil {
t.Fatal("expected error for trailing content, got nil")
}
if !strings.Contains(err.Error(), "trailing content") {
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
}
})
}
}
func TestParsePersonaBytesSizeLimit(t *testing.T) {
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
oversized := make([]byte, MaxPersonaFileSize+1)
for i := range oversized {
oversized[i] = 'x'
}
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
if err == nil {
t.Fatal("expected error for oversized input, got nil")
}
if !strings.Contains(err.Error(), "exceeds maximum size") {
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
}
// Just under the limit should not trigger size error (may fail parse, but not size)
underLimit := []byte("name: test\nidentity: test persona\n")
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
if err != nil {
t.Fatalf("unexpected error for valid input: %v", err)
}
if p.Name != "test" {
t.Errorf("Name = %q, want %q", p.Name, "test")
}
}
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
// Verify that YAML merge keys (<<: *alias) are properly handled by the
// depth checker. The merge key content is in the MappingValueNode.Value
// (an AliasNode), not in the MergeKeyNode itself.
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
if err != nil {
t.Fatalf("basic parse failed: %v", err)
}
if p.Name != "merge-test" {
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
}
// Test that deeply nested merge keys still hit depth limit.
// Build YAML with merge key content nested beyond MaxYAMLDepth.
var sb strings.Builder
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
sb.WriteString("anchor: &deep\n")
indent := " "
for i := 0; i < MaxYAMLDepth+5; i++ {
sb.WriteString(indent)
sb.WriteString(fmt.Sprintf("level%d:\n", i))
indent += " "
}
sb.WriteString(indent + "leaf: value\n")
sb.WriteString("target:\n <<: *deep\n")
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
if err == nil {
t.Fatal("expected error for deeply nested merge key content, got nil")
}
if !strings.Contains(err.Error(), "depth") {
t.Errorf("error = %q, want to contain 'depth'", err.Error())
}
}
+20 -4
View File
@@ -1,3 +1,5 @@
//go:build phase2
package vcs_test package vcs_test
import ( import (
@@ -5,7 +7,21 @@ import (
"gitea.weiker.me/rodin/review-bot/vcs" "gitea.weiker.me/rodin/review-bot/vcs"
) )
// Compile-time assertion: the gitea.Adapter satisfies vcs.Client. // Compile-time assertion: documents the gap between gitea.Client and vcs.Client.
// (The raw gitea.Client does NOT satisfy vcs.Client due to signature differences; // Guarded by the "phase2" build tag — enable once the Gitea adapter bridges these gaps:
// the Adapter bridges them.) //
var _ vcs.Client = (*gitea.Adapter)(nil) // 1. PostReview signature mismatch:
// gitea.Client: PostReview(ctx, owner, repo, number, event, body string, comments []gitea.ReviewComment)
// vcs.Reviewer: PostReview(ctx, owner, repo, number, req vcs.ReviewRequest)
//
// 2. GetFileContent signature mismatch:
// gitea.Client: GetFileContent(ctx, owner, repo, filepath string) [no ref; uses default branch]
// vcs.FileReader: GetFileContent(ctx, owner, repo, path, ref string)
// (gitea.Client has GetFileContentRef for the ref variant)
//
// 3. ReviewComment type mismatch:
// gitea.ReviewComment uses NewPosition int64 (Gitea line-number convention)
// vcs.ReviewComment uses Position int (GitHub diff-position convention)
//
// The Gitea adapter (Phase 2) will wrap gitea.Client to bridge these gaps.
var _ vcs.Client = (*gitea.Client)(nil)