Compare commits

..

3 Commits

Author SHA1 Message Date
Rodin 27a9be38bc fix: address PR #63 review findings
1. Refactor err2 to use scoped loadErr variable (MINOR - sonnet-review-bot)
   The else-if branches are mutually exclusive, so the error variable
   should be scoped inside the block, not declared outside with err2.

2. Sanitize DisplayName before embedding in Markdown (MINOR - security-review-bot)
   Remote persona metadata is untrusted. Added sanitizeMarkdownText() to
   escape Markdown special characters and strip control characters.
   Applied to both the header title and the footer attribution.

3. Document YAML DoS mitigations (MINOR - security-review-bot)
   Added comprehensive comment in remote_persona.go explaining existing
   defenses: file size limit, file count cap, depth limit, node count cap,
   and alias cycle detection. These collectively mitigate billion-laughs
   and stack exhaustion attacks.
2026-05-10 20:54:20 -07:00
Rodin 5fac8bc505 fix: address PR #62 review findings
CI / test (pull_request) Successful in 16s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 27s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m5s
CI / review (gpt-5, security, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 1m40s
- Remove duplicate flag.Parse() call
- Fix nil map panic in LoadRemotePersonas error path by assigning
  empty map when LoadRemotePersonas returns an error
- Tighten isNotFoundError to only check HTTP 404 (remove broad
  'not found' substring check to avoid false positives)
- Clean up personaErr variable scope using narrower-scoped err variables
- Add proper doc comment to LoadRemotePersonasFromPath (Go convention)
- Add file count cap (50 files) in LoadRemotePersonasFromPath to
  prevent resource exhaustion from repos with thousands of small files
- Update test expectation for tightened isNotFoundError
2026-05-10 20:44:24 -07:00
Rodin 2f8d047ef2 feat: load personas from target repo .review-bot/personas/
CI / review (gpt-5, security, SECURITY_REVIEW.md, SECURITY_REVIEW_TOKEN) (pull_request) Successful in 8m12s
CI / review (gpt-5, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 8m15s
CI / test (pull_request) Successful in 15s
CI / review (anthropic--claude-4.6-sonnet, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Failing after 42s
Adds support for repository-specific personas. When --persona is
specified, review-bot now:

1. Checks the target repo's .review-bot/personas/<name>.yaml directory
2. Falls back to built-in persona if not found in repo

This allows repos to define domain-specific personas (trading, regulatory,
etc.) or override built-in personas with project-specific rules, without
requiring changes to CI configuration.

Implementation:
- New review.PersonaFetcher interface for abstracting Gitea API access
- review.LoadRemotePersonas() with graceful fallback on 404
- review.MergePersonas() for combining remote and built-in personas
- giteaFetcher adapter in main.go to bridge gitea.Client

The feature follows a partial-success model: invalid YAML files or
network errors for individual persona files are logged and skipped,
allowing other valid personas to load.

Closes #60
2026-05-10 19:05:55 -07:00
33 changed files with 751 additions and 5559 deletions
+2 -4
View File
@@ -38,8 +38,6 @@ jobs:
- name: security
token_secret: SECURITY_REVIEW_TOKEN
model: gpt-5
patterns_repo: rodin/security-patterns
patterns_files: "."
system_prompt_file: SECURITY_REVIEW.md
steps:
- uses: actions/checkout@v4
@@ -62,8 +60,8 @@ jobs:
AICORE_API_URL: ${{ secrets.AICORE_API_URL }}
AICORE_RESOURCE_GROUP: ${{ secrets.AICORE_RESOURCE_GROUP }}
CONVENTIONS_FILE: "CONVENTIONS.md"
PATTERNS_REPO: ${{ matrix.patterns_repo || 'rodin/go-patterns' }}
PATTERNS_FILES: ${{ matrix.patterns_files || 'README.md,patterns/' }}
PATTERNS_REPO: "rodin/go-patterns"
PATTERNS_FILES: "README.md,patterns/"
LLM_TIMEOUT: "600"
SYSTEM_PROMPT_FILE: ${{ matrix.system_prompt_file }}
run: ./review-bot
-200
View File
@@ -1,200 +0,0 @@
# This composite action is designed for Gitea Actions runners.
# Gitea Actions supports GitHub Actions syntax including $GITHUB_OUTPUT,
# actions/cache, and actions/checkout.
# Requirements: python3, sha256sum, curl (all present on ubuntu-* runners).
name: 'AI Code Review'
description: 'Run AI-powered code review on a pull request using review-bot'
inputs:
gitea-url:
description: 'Gitea instance URL (defaults to server_url)'
required: false
default: ''
repo:
description: 'Repository (owner/name, defaults to current)'
required: false
default: ''
pr-number:
description: 'Pull request number (defaults to current PR)'
required: false
default: ''
reviewer-token:
description: 'Gitea token for posting the review'
required: true
reviewer-name:
description: 'Display name for the reviewer'
required: false
default: ''
llm-base-url:
description: 'OpenAI-compatible LLM API base URL (not required for aicore provider)'
required: false
default: ''
llm-api-key:
description: 'LLM API key (not required for aicore provider)'
required: false
default: ''
llm-model:
description: 'LLM model name'
required: true
llm-provider:
description: 'LLM API provider: openai, anthropic, or aicore (default openai)'
required: false
default: 'openai'
aicore-client-id:
description: 'SAP AI Core client ID (required for aicore provider)'
required: false
default: ''
aicore-client-secret:
description: 'SAP AI Core client secret (required for aicore provider)'
required: false
default: ''
aicore-auth-url:
description: 'SAP AI Core authentication URL (required for aicore provider)'
required: false
default: ''
aicore-api-url:
description: 'SAP AI Core API URL (required for aicore provider)'
required: false
default: ''
aicore-resource-group:
description: 'SAP AI Core resource group (default: default)'
required: false
default: 'default'
conventions-file:
description: 'Path to conventions file in the repo (e.g. CLAUDE.md)'
required: false
default: ''
patterns-repo:
description: 'Comma-separated repos with language patterns (e.g. rodin/elixir-patterns,rodin/phoenix-conventions)'
required: false
default: ''
patterns-files:
description: 'Comma-separated file paths or directories to fetch from patterns repos'
required: false
default: 'README.md'
temperature:
description: 'LLM temperature (0 = server default)'
required: false
default: '0'
timeout:
description: 'LLM request timeout in seconds (default 300)'
required: false
default: '300'
version:
description: 'review-bot version to install (e.g. v0.1.0, defaults to latest)'
required: false
default: 'latest'
dry-run:
description: 'Print review to stdout instead of posting'
required: false
default: 'false'
update-existing:
description: 'Delete previous review from same bot after posting new one. Accepts: true/1/yes or false/0/no (default true)'
required: false
default: 'true'
system-prompt-file:
description: 'Local file with additional system prompt instructions (e.g. security review focus)'
required: false
default: ''
persona:
description: 'Built-in persona name (security, architect, docs)'
required: false
default: ''
persona-file:
description: 'Path to custom persona JSON file'
required: false
default: ''
runs:
using: 'composite'
steps:
- name: Determine version
id: version
shell: bash
run: |
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
if [ "${{ inputs.version }}" = "latest" ]; then
VERSION=$(curl -sSf "${GITEA_URL}/api/v1/repos/${REPO}/releases?limit=1" \
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
if [ -z "$VERSION" ]; then
echo "Failed to determine latest version" >&2
exit 1
fi
else
VERSION="${{ inputs.version }}"
fi
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
- name: Cache review-bot binary
id: cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/review-bot
key: review-bot-linux-amd64-${{ steps.version.outputs.version }}
- name: Install review-bot
if: steps.cache.outputs.cache-hit != 'true'
shell: bash
run: |
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
VERSION="${{ steps.version.outputs.version }}"
BINARY="review-bot-linux-amd64"
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/${BINARY}" \
-o "${{ runner.temp }}/review-bot"
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/checksums.txt" \
-o "${{ runner.temp }}/checksums.txt"
# Verify SHA-256 checksum
cd "${{ runner.temp }}"
EXPECTED=$(grep "${BINARY}" checksums.txt | awk '{print $1}')
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
if [ -z "$EXPECTED" ]; then
echo "Error: no checksum found for ${BINARY}" >&2
exit 1
fi
if [ "$EXPECTED" != "$ACTUAL" ]; then
echo "Error: checksum mismatch!" >&2
echo " Expected: $EXPECTED" >&2
echo " Actual: $ACTUAL" >&2
exit 1
fi
chmod +x "${{ runner.temp }}/review-bot"
echo "Installed review-bot ${VERSION} (checksum verified)"
- name: Run review
shell: bash
env:
GITHUB_SERVER_URL: ${{ inputs.gitea-url || github.server_url }}
GITHUB_REPOSITORY: ${{ inputs.repo || github.repository }}
PR_NUMBER: ${{ inputs.pr-number || github.event.pull_request.number }}
REVIEWER_TOKEN: ${{ inputs.reviewer-token }}
REVIEWER_NAME: ${{ inputs.reviewer-name }}
LLM_BASE_URL: ${{ inputs.llm-base-url }}
LLM_API_KEY: ${{ inputs.llm-api-key }}
LLM_MODEL: ${{ inputs.llm-model }}
CONVENTIONS_FILE: ${{ inputs.conventions-file }}
PATTERNS_REPO: ${{ inputs.patterns-repo }}
PATTERNS_FILES: ${{ inputs.patterns-files }}
LLM_TEMPERATURE: ${{ inputs.temperature }}
LLM_TIMEOUT: ${{ inputs.timeout }}
LLM_PROVIDER: ${{ inputs.llm-provider }}
UPDATE_EXISTING: ${{ inputs.update-existing }}
SYSTEM_PROMPT_FILE: ${{ inputs.system-prompt-file }}
PERSONA: ${{ inputs.persona }}
PERSONA_FILE: ${{ inputs.persona-file }}
AICORE_CLIENT_ID: ${{ inputs.aicore-client-id }}
AICORE_CLIENT_SECRET: ${{ inputs.aicore-client-secret }}
AICORE_AUTH_URL: ${{ inputs.aicore-auth-url }}
AICORE_API_URL: ${{ inputs.aicore-api-url }}
AICORE_RESOURCE_GROUP: ${{ inputs.aicore-resource-group }}
run: |
ARGS=""
if [ "${{ inputs.dry-run }}" = "true" ]; then
ARGS="--dry-run"
fi
${{ runner.temp }}/review-bot $ARGS
-69
View File
@@ -1,69 +0,0 @@
name: CI
on:
push:
branches: [main]
pull_request:
types: [opened, synchronize]
jobs:
test:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.26'
- run: go test ./...
- run: go vet ./...
- run: go build -o review-bot ./cmd/review-bot
# Self-review using native SAP AI Core provider
# Models must match SAP AI Core deployments
# Available models: gpt-5, anthropic--claude-4.6-sonnet, anthropic--claude-4.6-opus
# Removed gpt-4.1, gpt-5-mini, gpt-4.1-mini - not deployed on AI Core
review:
runs-on: ubuntu-24.04
if: github.event_name == 'pull_request'
needs: test
strategy:
matrix:
include:
- name: sonnet
token_secret: SONNET_REVIEW_TOKEN
model: anthropic--claude-4.6-sonnet
- name: gpt
token_secret: GPT_REVIEW_TOKEN
model: gpt-5
- name: security
token_secret: SECURITY_REVIEW_TOKEN
model: gpt-5
patterns_repo: rodin/security-patterns
patterns_files: "."
system_prompt_file: SECURITY_REVIEW.md
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.26'
- run: go build -o review-bot ./cmd/review-bot
- name: Run ${{ matrix.name }} review
env:
GITHUB_SERVER_URL: ${{ github.server_url }}
GITHUB_REPOSITORY: ${{ github.repository }}
PR_NUMBER: ${{ github.event.pull_request.number }}
REVIEWER_TOKEN: ${{ secrets[matrix.token_secret] }}
REVIEWER_NAME: ${{ matrix.name }}
LLM_PROVIDER: aicore
LLM_MODEL: ${{ matrix.model }}
AICORE_CLIENT_ID: ${{ secrets.AICORE_CLIENT_ID }}
AICORE_CLIENT_SECRET: ${{ secrets.AICORE_CLIENT_SECRET }}
AICORE_AUTH_URL: ${{ secrets.AICORE_AUTH_URL }}
AICORE_API_URL: ${{ secrets.AICORE_API_URL }}
AICORE_RESOURCE_GROUP: ${{ secrets.AICORE_RESOURCE_GROUP }}
CONVENTIONS_FILE: "CONVENTIONS.md"
PATTERNS_REPO: ${{ matrix.patterns_repo || 'rodin/go-patterns' }}
PATTERNS_FILES: ${{ matrix.patterns_files || 'README.md,patterns/' }}
LLM_TIMEOUT: "600"
SYSTEM_PROMPT_FILE: ${{ matrix.system_prompt_file }}
run: ./review-bot
-38
View File
@@ -1,38 +0,0 @@
name: PR Ready Gate
on:
pull_request:
types: [synchronize]
jobs:
clear-labels:
runs-on: ubuntu-24.04
# Always run - curl commands are safe if labels don't exist
steps:
- name: Remove ready and self-reviewed labels, reassign to author
env:
GITEA_TOKEN: ${{ secrets.RODIN_TOKEN }}
run: |
PR_NUMBER=${{ github.event.pull_request.number }}
AUTHOR=${{ github.event.pull_request.user.login }}
READY_LABEL_ID=38
SELF_REVIEWED_LABEL_ID=37
# Remove ready label if present
curl -sS -X DELETE \
-H "Authorization: token $GITEA_TOKEN" \
"https://gitea.weiker.me/api/v1/repos/${{ github.repository }}/issues/${PR_NUMBER}/labels/${READY_LABEL_ID}" || true
# Remove self-reviewed label if present
curl -sS -X DELETE \
-H "Authorization: token $GITEA_TOKEN" \
"https://gitea.weiker.me/api/v1/repos/${{ github.repository }}/issues/${PR_NUMBER}/labels/${SELF_REVIEWED_LABEL_ID}" || true
# Reassign to author
curl -sS -X PATCH \
-H "Authorization: token $GITEA_TOKEN" \
-H "Content-Type: application/json" \
-d "{\"assignees\": [\"${AUTHOR}\"]}" \
"https://gitea.weiker.me/api/v1/repos/${{ github.repository }}/pulls/${PR_NUMBER}"
echo "Cleared ready/self-reviewed labels and reassigned PR #${PR_NUMBER} to ${AUTHOR}"
-97
View File
@@ -1,97 +0,0 @@
name: Release
on:
push:
tags:
- 'v*'
jobs:
release:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.26'
- name: Run tests
run: |
go vet ./...
go test ./...
- name: Build binaries
run: |
VERSION=${GITHUB_REF_NAME}
mkdir -p dist
GOOS=linux GOARCH=amd64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-linux-amd64 ./cmd/review-bot
GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-linux-arm64 ./cmd/review-bot
GOOS=darwin GOARCH=amd64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-darwin-amd64 ./cmd/review-bot
GOOS=darwin GOARCH=arm64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-darwin-arm64 ./cmd/review-bot
cd dist && sha256sum * > checksums.txt
- name: Create release and upload assets
env:
GITEA_TOKEN: ${{ secrets.RELEASE_TOKEN }}
run: |
VERSION=${GITHUB_REF_NAME}
GITEA_URL="${{ github.server_url }}"
REPO="${{ github.repository }}"
# Create release (or find existing one for this tag)
HTTP_CODE=$(curl -s -o /tmp/release_response.json -w "%{http_code}" -X POST \
-H "Authorization: token ${GITEA_TOKEN}" \
-H "Content-Type: application/json" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases" \
-d "{\"tag_name\": \"${VERSION}\", \"name\": \"${VERSION}\", \"body\": \"Release ${VERSION}\", \"draft\": false, \"prerelease\": false}")
if [ "$HTTP_CODE" = "409" ]; then
echo "Release for ${VERSION} already exists, fetching existing..."
curl -sSf -o /tmp/release_response.json \
-H "Authorization: token ${GITEA_TOKEN}" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases/tags/${VERSION}"
elif [ "$HTTP_CODE" != "201" ]; then
echo "Failed to create release (HTTP ${HTTP_CODE})" >&2
cat /tmp/release_response.json >&2
exit 1
fi
# Parse release ID (python3 available on ubuntu-24.04 runners)
RELEASE_ID=$(python3 -c "import json; print(json.load(open('/tmp/release_response.json'))['id'])")
if [ -z "$RELEASE_ID" ]; then
echo "Failed to parse release ID" >&2
cat /tmp/release_response.json >&2
exit 1
fi
echo "Release ID: ${RELEASE_ID}"
# Upload each asset (idempotent: delete existing asset with same name first)
for file in dist/*; do
filename=$(basename "$file")
echo "Uploading ${filename}..."
# Check if asset already exists and delete it
EXISTING_ID=$(export ASSET_NAME="${filename}"; curl -sS \
-H "Authorization: token ${GITEA_TOKEN}" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets" \
| python3 -c "import json,sys,os; name=os.environ['ASSET_NAME']; assets=json.load(sys.stdin); print(next((str(a['id']) for a in assets if a['name']==name),''))" 2>/dev/null)
if [ -n "$EXISTING_ID" ]; then
echo " Asset ${filename} already exists (id=${EXISTING_ID}), deleting..."
curl -sSf -X DELETE \
-H "Authorization: token ${GITEA_TOKEN}" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets/${EXISTING_ID}"
fi
curl -sSf -X POST \
-H "Authorization: token ${GITEA_TOKEN}" \
-H "Content-Type: application/octet-stream" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets?name=$(printf '%s' "${filename}" | jq -sRr @uri)" \
--data-binary "@${file}"
done
echo "Release ${VERSION} created with assets"
+37 -3
View File
@@ -329,12 +329,11 @@ All flags have environment variable equivalents:
### Token Scopes Required
| Scope | Purpose |
|-------|--------|
|-------|---------|
| `write:issue` | Post and delete reviews |
| `write:repository` | Read PR diffs, file content, commit statuses |
| `read:user` | Self-request as reviewer (optional but recommended) |
Without `read:user`, the bot still works but cannot add itself to the PR's reviewer list.
No `read:user` scope needed — the bot identifies itself from the review response.
## Development
@@ -460,6 +459,41 @@ YAML is the recommended format for personas because it supports:
JSON is also supported for backwards compatibility—just use `.json` extension.
### Repository Personas (Auto-Discovery)
Repositories can ship their own personas in `.review-bot/personas/`. When you specify `--persona <name>`, review-bot will:
1. **Try to load from the target repo** — Checks `.review-bot/personas/<name>.yaml` (or `.yml`)
2. **Fall back to built-in** — If not found in repo, uses the built-in persona
This lets each repo define domain-specific personas without modifying CI config:
```
my-trading-repo/
├── .review-bot/
│ └── personas/
│ ├── trading.yaml # Custom trading persona
│ └── regulatory.yaml # Compliance-focused reviews
├── lib/
└── ...
```
```yaml
# CI config (no persona-file needed)
- uses: rodin/review-bot/.gitea/actions/review@v1
with:
reviewer-name: trading
persona: trading # Will find .review-bot/personas/trading.yaml
...
```
**Priority order:**
1. Repo's `.review-bot/personas/<name>.yaml`
2. Built-in persona with matching name
3. Error if neither exists
This allows repos to override built-in personas (e.g., a custom `security` persona that adds project-specific rules) while keeping the simple `persona: security` syntax in CI.
### Persona vs system-prompt-file
+30 -47
View File
@@ -15,7 +15,6 @@ import (
"gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/llm"
"gitea.weiker.me/rodin/review-bot/review"
"gitea.weiker.me/rodin/review-bot/vcs"
)
var version = "dev"
@@ -55,8 +54,8 @@ func main() {
logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json")
verbosity := flag.String("verbosity", envOrDefault("LOG_VERBOSITY", "info"), "Log verbosity: debug, info, warn, error")
// CLI flags
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", envOrDefault("GITHUB_SERVER_URL", "")), "Gitea instance URL")
repo := flag.String("repo", envOrDefault("GITEA_REPO", envOrDefault("GITHUB_REPOSITORY", "")), "Repository (owner/name)")
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", ""), "Gitea instance URL")
repo := flag.String("repo", envOrDefault("GITEA_REPO", ""), "Repository (owner/name)")
prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number")
reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name")
reviewerToken := flag.String("reviewer-token", envOrDefault("REVIEWER_TOKEN", ""), "Gitea token for posting review")
@@ -116,7 +115,9 @@ func main() {
os.Exit(1)
}
// NOTE: Persona loading deferred until after Gitea client init to support repo personas
// Persona loading is deferred until after giteaClient is initialized,
// so we can try loading from the target repo first.
var persona *review.Persona
// Validate reviewer-name: only safe characters allowed in sentinel
if err := validateReviewerName(*reviewerName); err != nil {
@@ -174,22 +175,23 @@ func main() {
ctx, cancel := context.WithTimeout(context.Background(), overallTimeout)
defer cancel()
// Load persona if specified (after Gitea client init to support repo personas)
var persona *review.Persona
// Load persona: try remote repo first, then fall back to built-in
if *personaName != "" {
// Try loading from repo first, then fall back to built-in
repoPersonas, err := review.LoadRepoPersonas(ctx, newGiteaClientAdapter(giteaClient), owner, repoName)
// Try loading from target repo's .review-bot/personas/ directory
fetcher := &giteaFetcher{client: giteaClient}
remotePersonas, err := review.LoadRemotePersonas(ctx, fetcher, owner, repoName)
if err != nil {
slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err)
// Continue with built-in personas only.
// NOTE: repoPersonas is nil here, but map indexing on a nil map is safe in Go
// (returns the zero value), so the fallback to built-in below works correctly.
slog.Warn("could not load remote personas", "repo", fmt.Sprintf("%s/%s", owner, repoName), "error", err)
// Assign empty map so the lookup below doesn't panic
remotePersonas = map[string]*review.Persona{}
}
if p, ok := repoPersonas[*personaName]; ok {
if p, ok := remotePersonas[*personaName]; ok {
persona = p
slog.Info("loaded repo persona", "persona", persona.Name, "display", persona.DisplayName, "repo", owner+"/"+repoName)
slog.Info("loaded persona from target repo", "persona", persona.Name, "display", persona.DisplayName)
} else {
// Fall back to built-in
// Fall back to built-in persona
var err error
persona, err = review.LoadBuiltinPersona(*personaName)
if err != nil {
slog.Error("failed to load persona", "persona", *personaName, "error", err)
@@ -203,11 +205,12 @@ func main() {
slog.Error("invalid persona-file path", "error", err)
os.Exit(1)
}
persona, err = review.LoadPersona(resolvedPath)
if err != nil {
slog.Error("failed to load persona file", "file", *personaFile, "error", err)
loadedPersona, loadErr := review.LoadPersona(resolvedPath)
if loadErr != nil {
slog.Error("failed to load persona file", "file", *personaFile, "error", loadErr)
os.Exit(1)
}
persona = loadedPersona
slog.Info("loaded persona from file", "file", *personaFile, "persona", persona.Name)
}
@@ -545,9 +548,6 @@ func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patt
}
owner, repo := parts[0], parts[1]
var repoLoadedFiles []string
var repoSkippedFiles []string
for _, path := range paths {
path = strings.TrimSpace(path)
if path == "" {
@@ -563,22 +563,11 @@ func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patt
for filePath, content := range files {
// Only include markdown and text files as patterns
if !isPatternFile(filePath) {
repoSkippedFiles = append(repoSkippedFiles, filePath)
continue
}
repoLoadedFiles = append(repoLoadedFiles, filePath)
sb.WriteString(fmt.Sprintf("### %s/%s\n\n%s\n\n", repoRef, filePath, content))
}
}
if len(repoLoadedFiles) > 0 {
slog.Info("loaded pattern files", "repo", repoRef, "count", len(repoLoadedFiles), "files", repoLoadedFiles)
} else {
slog.Warn("no pattern files loaded", "repo", repoRef, "paths", paths)
}
if len(repoSkippedFiles) > 0 {
slog.Debug("skipped non-pattern files", "repo", repoRef, "count", len(repoSkippedFiles), "files", repoSkippedFiles)
}
}
return sb.String()
}
@@ -813,23 +802,20 @@ func shouldSkipStaleReview(evaluatedSHA, currentSHA string) bool {
return evaluatedSHA != currentSHA
}
// giteaClientAdapter adapts gitea.Client to vcs.FileReader interface.
type giteaClientAdapter struct {
// giteaFetcher adapts gitea.Client to review.PersonaFetcher interface.
type giteaFetcher struct {
client *gitea.Client
}
func newGiteaClientAdapter(c *gitea.Client) *giteaClientAdapter {
return &giteaClientAdapter{client: c}
}
func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
entries, err := a.client.ListContents(ctx, owner, repo, path)
func (f *giteaFetcher) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
entries, err := f.client.ListContents(ctx, owner, repo, path)
if err != nil {
return nil, err
}
result := make([]vcs.ContentEntry, len(entries))
// Convert gitea.ContentEntry to review.ContentEntry
result := make([]review.ContentEntry, len(entries))
for i, e := range entries {
result[i] = vcs.ContentEntry{
result[i] = review.ContentEntry{
Name: e.Name,
Path: e.Path,
Type: e.Type,
@@ -838,9 +824,6 @@ func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path
return result, nil
}
func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) {
if ref != "" {
return a.client.GetFileContentRef(ctx, owner, repo, filePath, ref)
}
return a.client.GetFileContent(ctx, owner, repo, filePath)
func (f *giteaFetcher) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
return f.client.GetFileContent(ctx, owner, repo, filepath)
}
-268
View File
@@ -1,268 +0,0 @@
# GitHub Support for review-bot
## Goal
AI code reviews on GitHub PRs using SAP AI Core as the LLM provider.
## Non-Goals
- Auto-detection of platform (explicit `--provider` flag is fine)
- Unifying into one abstraction layer for its own sake
## Constraints
1. **Same features on both platforms** — anything review-bot does on Gitea should work on GitHub
2. **Testable** — small interfaces, dependency injection, no global state
3. **Interface from working code** — extract from gitea/, don't invent in vacuum
---
## Part 1: Feature Inventory
What does review-bot actually do?
### Core Review Flow
| Feature | Description |
|---------|-------------|
| Get PR metadata | Title, body, head SHA, base ref |
| Get PR diff | Unified diff format |
| Get PR files | List of changed files with status |
| Get file content | Raw file at ref |
| List directory | Enumerate files in path |
| Post review | Body + inline comments + verdict |
### Review Management
| Feature | Description |
|---------|-------------|
| List reviews | Get existing reviews on PR |
| Delete review | Remove old review before re-posting |
| Get authenticated user | Who am I? |
### Platform-Specific (not in shared interface)
| Feature | Gitea | GitHub |
|---------|-------|--------|
| Resolve comment | Yes | No equivalent |
| Timeline API | Yes | No equivalent |
These stay on gitea.Client directly. Callers that need them type-assert.
---
## Part 2: GitHub API Mapping
| Feature | Gitea API | GitHub API |
|---------|-----------|------------|
| Get PR | `GET /api/v1/repos/.../pulls/{n}` | `GET /repos/.../pulls/{n}` |
| Get diff | `.diff` suffix | `Accept: application/vnd.github.diff` header |
| Get files | `GET .../pulls/{n}/files` | Same |
| Get file content | `GET .../raw/{path}?ref=` | `GET .../contents/{path}?ref=` + base64 decode |
| List directory | `GET .../contents/{path}` | Same |
| Post review | `POST .../pulls/{n}/reviews` | Same (adapter handles comment schema) |
| List reviews | `GET .../pulls/{n}/reviews` | Same |
| Delete review | `DELETE .../pulls/{n}/reviews/{id}` | Same |
| Get user | `GET /api/v1/user` | `GET /user` |
---
## Part 3: Interface Design
**Principle:** Extract from working gitea/ code. The interface is discovered, not invented.
### Small, role-based interfaces
```go
// vcs/interfaces.go
type PRReader interface {
GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error)
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error)
}
type FileReader interface {
GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error)
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
}
type Reviewer interface {
PostReview(ctx context.Context, owner, repo string, number int, req ReviewRequest) (*Review, error)
ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error)
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
}
type Identity interface {
GetAuthenticatedUser(ctx context.Context) (string, error)
}
// Client combines all for callers that need everything
type Client interface {
PRReader
FileReader
Reviewer
Identity
}
```
### Types
Use what gitea/ already has. Move to vcs/types.go or re-export.
```go
type PullRequest struct { ... } // from gitea.PullRequest
type ChangedFile struct { ... } // from gitea.ChangedFile
type ContentEntry struct { ... } // from gitea.ContentEntry
type Review struct { ... } // from gitea.Review
type ReviewRequest struct { ... } // new, for PostReview input
type ReviewComment struct { ... } // from gitea.ReviewComment
```
### Adapter responsibilities
Each adapter (gitea, github) handles:
- API URL construction
- Auth header format (`token` vs `Bearer`)
- Request/response mapping
- Comment schema translation (line numbers, commit IDs, etc.)
---
## Part 4: Test Plan
### Unit Tests (mock HTTP)
```
github/
pr_test.go # TestGetPullRequest, TestGetDiff, TestGetFiles
files_test.go # TestGetFileContent, TestListContents
review_test.go # TestPostReview, TestListReviews, TestDeleteReview
identity_test.go # TestGetAuthenticatedUser
```
Per method: happy path, 404, 401, 429, malformed response.
### Integration Tests
Against github.com/aweiker/ai-core-review-bot:
- Fetch real PR
- Fetch real file
- Post + delete review (clean up)
### End-to-End
Open PR on test repo, run full review-bot, verify review appears.
---
## Part 5: Implementation Phases
### Phase 1: Extract interfaces from gitea/
**Work:**
- Create `vcs/interfaces.go` with interfaces extracted from gitea/client.go signatures
- Create `vcs/types.go` — move or alias types from gitea/
- Verify gitea.Client satisfies vcs.Client (compile-time check)
**Exit criteria:** `var _ vcs.Client = (*gitea.Client)(nil)` compiles.
---
### Phase 2: Gitea adapter (if needed)
**Work:**
- If gitea.Client method signatures don't match exactly, create wrapper
- Keep gitea/ working exactly as before
**Exit criteria:** Existing tests pass. No behavior change.
---
### Phase 3: GitHub client — PRReader
**Work:**
- `github/client.go` — struct, constructor, HTTP helpers
- `github/pr.go` — GetPullRequest, GetPullRequestDiff, GetPullRequestFiles
- Unit tests
**Exit criteria:** `go test ./github/...` passes for PR methods.
---
### Phase 4: GitHub client — FileReader
**Work:**
- `github/files.go` — GetFileContent, ListContents
- Unit tests
**Exit criteria:** Unit tests pass.
---
### Phase 5: GitHub client — Reviewer + Identity
**Work:**
- `github/review.go` — PostReview, ListReviews, DeleteReview
- `github/identity.go` — GetAuthenticatedUser
- Unit tests
**Exit criteria:** Unit tests pass.
---
### Phase 6: Integration tests
**Work:**
- `integration/github_test.go`
- Test against real GitHub
**Exit criteria:** All integration tests pass.
---
### Phase 7: Wire into cmd/review-bot
**Work:**
- Add `--provider github|gitea` flag (default: gitea for backward compat)
- Select client based on flag
- Update to use vcs interfaces where it makes sense
**Exit criteria:**
- `./review-bot --provider github ...` works
- `./review-bot --provider gitea ...` works (same as before)
- Existing Gitea workflows unchanged
---
### Phase 8: GitHub Actions workflow + releases
**Work:**
- `.github/workflows/ci.yml` — test on PR
- `.github/workflows/release.yml` — publish binary to GitHub releases
- `.github/actions/review/action.yml` — composite action
- Action downloads binary from github.com/aweiker/ai-core-review-bot releases
**Exit criteria:**
- CI runs on github.com/aweiker/ai-core-review-bot
- Release creates downloadable binary
- Review action posts review successfully
---
## Part 6: Decisions
| Question | Decision |
|----------|----------|
| Auth token | Workflow `GITHUB_TOKEN` (automatic) |
| Binary distribution | GitHub releases on aweiker/ai-core-review-bot |
| Comment schema | Adapter's job — translate ReviewComment to platform format |
| Default provider | `gitea` for backward compatibility |
| Shared types | vcs/types.go (extracted from gitea/) |
| Platform-specific features | Stay on concrete client, not interface |
---
## Summary
8 phases. Start by extracting interfaces from working gitea/ code, not inventing them. GitHub implements the same interfaces. Each phase has clear exit criteria.
-232
View File
@@ -1,232 +0,0 @@
package gitea
import (
"context"
"fmt"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Adapter wraps a gitea.Client and satisfies the vcs.Client interface.
// It handles translation between GitHub-canonical diff positions and Gitea
// line numbers, and between canonical review event strings and Gitea-native values.
type Adapter struct {
client *Client
}
// Compile-time interface conformance assertion.
var _ vcs.Client = (*Adapter)(nil)
// NewAdapter creates a new Adapter wrapping the given gitea Client.
func NewAdapter(client *Client) *Adapter {
return &Adapter{client: client}
}
// Underlying returns the wrapped gitea.Client for Gitea-specific operations
// that have no vcs.Client equivalent (resolve comment, timeline, supersede flow).
func (a *Adapter) Underlying() *Client {
return a.client
}
// --- PRReader ---
// GetPullRequest maps gitea.PullRequest to vcs.PullRequest.
func (a *Adapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
pr, err := a.client.GetPullRequest(ctx, owner, repo, number)
if err != nil {
return nil, fmt.Errorf("get pull request: %w", err)
}
return &vcs.PullRequest{
Number: number,
Title: pr.Title,
Body: pr.Body,
Head: vcs.HeadRef{
SHA: pr.Head.Sha,
Ref: pr.Head.Ref,
},
Base: vcs.BaseRef{
Ref: pr.Base.Ref,
},
}, nil
}
// GetPullRequestDiff is a pass-through to the underlying client.
func (a *Adapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
return a.client.GetPullRequestDiff(ctx, owner, repo, number)
}
// GetPullRequestFiles maps []gitea.ChangedFile to []vcs.ChangedFile.
// Patch field is omitted (zero-value) since Gitea's /pulls/{n}/files does not return patch text.
func (a *Adapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
files, err := a.client.GetPullRequestFiles(ctx, owner, repo, number)
if err != nil {
return nil, err
}
result := make([]vcs.ChangedFile, len(files))
for i, f := range files {
result[i] = vcs.ChangedFile{
Filename: f.Filename,
Status: f.Status,
}
}
return result, nil
}
// GetFileContentAtRef is a pass-through to the underlying client.
func (a *Adapter) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
return a.client.GetFileContentAtRef(ctx, owner, repo, path, ref)
}
// GetCommitStatuses maps []gitea.CommitStatus to []vcs.CommitStatus.
func (a *Adapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) {
statuses, err := a.client.GetCommitStatuses(ctx, owner, repo, sha)
if err != nil {
return nil, err
}
result := make([]vcs.CommitStatus, len(statuses))
for i, s := range statuses {
result[i] = vcs.CommitStatus{
Status: s.Status,
Context: s.Context,
Description: s.Description,
TargetURL: s.TargetURL,
}
}
return result, nil
}
// --- FileReader ---
// GetFileContent delegates to the underlying client, routing to the ref-aware
// variant when ref is non-empty.
func (a *Adapter) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
if ref != "" {
return a.client.GetFileContentRef(ctx, owner, repo, path, ref)
}
return a.client.GetFileContent(ctx, owner, repo, path)
}
// ListContents maps []gitea.ContentEntry to []vcs.ContentEntry.
func (a *Adapter) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
entries, err := a.client.ListContents(ctx, owner, repo, path)
if err != nil {
return nil, err
}
result := make([]vcs.ContentEntry, len(entries))
for i, e := range entries {
result[i] = vcs.ContentEntry{
Name: e.Name,
Path: e.Path,
Type: e.Type,
}
}
return result, nil
}
// --- Reviewer ---
// translateEvent translates a vcs.ReviewEvent (GitHub-canonical) to a Gitea-native event string.
func translateEvent(event vcs.ReviewEvent) string {
switch event {
case vcs.ReviewEventApprove:
return "APPROVED"
case vcs.ReviewEventRequestChanges:
return "REQUEST_CHANGES"
case vcs.ReviewEventComment:
return "COMMENT"
default:
// Unknown events pass through as-is. This is intentional: new event types
// added to vcs.ReviewEvent will still be forwarded without a code change here,
// and Gitea will reject truly invalid values with a clear API error.
return string(event)
}
}
// PostReview translates vcs.ReviewRequest to the Gitea-native format.
// It fetches the PR diff, builds a position-to-line map, and translates each
// ReviewComment.Position (GitHub diff-position) to a Gitea new_position (line number).
func (a *Adapter) PostReview(ctx context.Context, owner, repo string, number int, req vcs.ReviewRequest) (*vcs.Review, error) {
event := translateEvent(req.Event)
var giteaComments []ReviewComment
if len(req.Comments) > 0 {
// Fetch diff to build position → line number map.
// The diff is fetched unconditionally when comments exist. This adds latency
// for reviews with inline comments but keeps the implementation simple — caching
// the diff across calls would add complexity for minimal gain since PostReview
// is called at most once per review cycle.
diff, err := a.client.GetPullRequestDiff(ctx, owner, repo, number)
if err != nil {
return nil, fmt.Errorf("fetch diff for position translation: %w", err)
}
posMap := BuildPositionToLineMap(diff)
for _, c := range req.Comments {
lineNum, err := posMap.Translate(c.Path, c.Position)
if err != nil {
return nil, fmt.Errorf("translate position %d in %s: %w", c.Position, c.Path, err)
}
// CommitID from vcs.ReviewComment is intentionally not forwarded:
// Gitea review comments are pinned to the PR head SHA automatically,
// and the CreatePullReview API has no per-comment commit_id field.
giteaComments = append(giteaComments, ReviewComment{
Path: c.Path,
NewPosition: int64(lineNum),
Body: c.Body,
})
}
}
review, err := a.client.PostReview(ctx, owner, repo, number, event, req.Body, giteaComments)
if err != nil {
return nil, fmt.Errorf("post review: %w", err)
}
return &vcs.Review{
ID: review.ID,
Body: review.Body,
User: vcs.UserInfo{Login: review.User.Login},
State: review.State,
Stale: review.Stale,
CommitID: review.CommitID,
}, nil
}
// ListReviews maps []gitea.Review to []vcs.Review.
func (a *Adapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcs.Review, error) {
reviews, err := a.client.ListReviews(ctx, owner, repo, number)
if err != nil {
return nil, err
}
result := make([]vcs.Review, len(reviews))
for i, r := range reviews {
result[i] = vcs.Review{
ID: r.ID,
Body: r.Body,
User: vcs.UserInfo{Login: r.User.Login},
State: r.State,
Stale: r.Stale,
CommitID: r.CommitID,
}
}
return result, nil
}
// DeleteReview is a pass-through to the underlying client.
func (a *Adapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
return a.client.DeleteReview(ctx, owner, repo, number, reviewID)
}
// DismissReview deletes the review. Gitea supports full deletion of any review state.
// The message parameter is intentionally unused — Gitea deletion has no dismissal message.
func (a *Adapter) DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error {
return a.client.DeleteReview(ctx, owner, repo, number, reviewID)
}
// --- Identity ---
// GetAuthenticatedUser is a pass-through to the underlying client.
func (a *Adapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
return a.client.GetAuthenticatedUser(ctx)
}
-388
View File
@@ -1,388 +0,0 @@
package gitea_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/vcs"
)
func TestAdapter_GetPullRequest(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"title": "Test PR",
"body": "PR body",
"head": map[string]any{
"sha": "abc123",
"ref": "feature-branch",
},
"base": map[string]any{
"ref": "main",
},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
pr, err := adapter.GetPullRequest(context.Background(), "owner", "repo", 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pr.Number != 42 {
t.Errorf("Number = %d, want 42", pr.Number)
}
if pr.Title != "Test PR" {
t.Errorf("Title = %q, want %q", pr.Title, "Test PR")
}
if pr.Body != "PR body" {
t.Errorf("Body = %q, want %q", pr.Body, "PR body")
}
if pr.Head.SHA != "abc123" {
t.Errorf("Head.SHA = %q, want %q", pr.Head.SHA, "abc123")
}
if pr.Head.Ref != "feature-branch" {
t.Errorf("Head.Ref = %q, want %q", pr.Head.Ref, "feature-branch")
}
if pr.Base.Ref != "main" {
t.Errorf("Base.Ref = %q, want %q", pr.Base.Ref, "main")
}
}
func TestAdapter_GetPullRequestFiles(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{"filename": "main.go", "status": "modified"},
{"filename": "new.go", "status": "added"},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
files, err := adapter.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 2 {
t.Fatalf("got %d files, want 2", len(files))
}
if files[0].Filename != "main.go" || files[0].Status != "modified" {
t.Errorf("files[0] = %+v", files[0])
}
if files[1].Filename != "new.go" || files[1].Status != "added" {
t.Errorf("files[1] = %+v", files[1])
}
}
func TestAdapter_ListReviews(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{
"id": 1,
"body": "LGTM",
"user": map[string]any{"login": "reviewer1"},
"state": "APPROVED",
"stale": false,
"commit_id": "abc123",
},
{
"id": 2,
"body": "Needs work",
"user": map[string]any{"login": "reviewer2"},
"state": "REQUEST_CHANGES",
"stale": true,
"commit_id": "def456",
},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
reviews, err := adapter.ListReviews(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(reviews) != 2 {
t.Fatalf("got %d reviews, want 2", len(reviews))
}
if reviews[0].ID != 1 || reviews[0].Body != "LGTM" || reviews[0].User.Login != "reviewer1" {
t.Errorf("reviews[0] = %+v", reviews[0])
}
if reviews[0].State != "APPROVED" || reviews[0].Stale || reviews[0].CommitID != "abc123" {
t.Errorf("reviews[0] state/stale/commit = %v/%v/%v", reviews[0].State, reviews[0].Stale, reviews[0].CommitID)
}
if reviews[1].ID != 2 || !reviews[1].Stale || reviews[1].State != "REQUEST_CHANGES" {
t.Errorf("reviews[1] = %+v", reviews[1])
}
}
func TestAdapter_GetCommitStatuses(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{
"status": "success",
"context": "ci/test",
"description": "All tests pass",
"target_url": "https://ci.example.com/1",
},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
statuses, err := adapter.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(statuses) != 1 {
t.Fatalf("got %d statuses, want 1", len(statuses))
}
if statuses[0].Status != "success" {
t.Errorf("Status = %q, want %q", statuses[0].Status, "success")
}
if statuses[0].Context != "ci/test" {
t.Errorf("Context = %q, want %q", statuses[0].Context, "ci/test")
}
if statuses[0].Description != "All tests pass" {
t.Errorf("Description = %q, want %q", statuses[0].Description, "All tests pass")
}
if statuses[0].TargetURL != "https://ci.example.com/1" {
t.Errorf("TargetURL = %q, want %q", statuses[0].TargetURL, "https://ci.example.com/1")
}
}
func TestAdapter_PostReview_EventTranslation(t *testing.T) {
tests := []struct {
name string
event vcs.ReviewEvent
wantEvent string
}{
{"APPROVE becomes APPROVED", vcs.ReviewEventApprove, "APPROVED"},
{"REQUEST_CHANGES stays", vcs.ReviewEventRequestChanges, "REQUEST_CHANGES"},
{"COMMENT stays", vcs.ReviewEventComment, "COMMENT"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var gotEvent string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var payload struct {
Event string `json:"event"`
}
json.NewDecoder(r.Body).Decode(&payload)
gotEvent = payload.Event
json.NewEncoder(w).Encode(map[string]any{
"id": 1,
"body": "test",
"user": map[string]any{"login": "bot"},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
_, err := adapter.PostReview(context.Background(), "owner", "repo", 1, vcs.ReviewRequest{
Body: "test",
Event: tt.event,
// No comments → no diff fetch needed
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if gotEvent != tt.wantEvent {
t.Errorf("event = %q, want %q", gotEvent, tt.wantEvent)
}
})
}
}
func TestAdapter_PostReview_WithComments_PositionTranslation(t *testing.T) {
diff := `diff --git a/main.go b/main.go
--- a/main.go
+++ b/main.go
@@ -1,3 +1,4 @@
package main
+// new comment at line 3
func main() {}
`
var gotComments []struct {
Path string `json:"path"`
NewPosition int64 `json:"new_position"`
Body string `json:"body"`
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if strings.HasSuffix(r.URL.Path, ".diff") {
// Diff request
w.Write([]byte(diff))
return
}
if strings.HasSuffix(r.URL.Path, "/reviews") {
// Review post
var payload struct {
Comments []struct {
Path string `json:"path"`
NewPosition int64 `json:"new_position"`
Body string `json:"body"`
} `json:"comments"`
}
json.NewDecoder(r.Body).Decode(&payload)
gotComments = payload.Comments
json.NewEncoder(w).Encode(map[string]any{
"id": 1,
"body": "review",
"user": map[string]any{"login": "bot"},
})
return
}
t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path)
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
// Position 4 in this diff is "+// new comment at line 3" → new line 3
_, err := adapter.PostReview(context.Background(), "owner", "repo", 1, vcs.ReviewRequest{
Body: "review",
Event: vcs.ReviewEventRequestChanges,
Comments: []vcs.ReviewComment{
{
Path: "main.go",
Position: 4,
CommitID: "abc123",
Body: "needs fix",
},
},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(gotComments) != 1 {
t.Fatalf("got %d comments, want 1", len(gotComments))
}
if gotComments[0].Path != "main.go" {
t.Errorf("path = %q, want %q", gotComments[0].Path, "main.go")
}
if gotComments[0].NewPosition != 3 {
t.Errorf("new_position = %d, want 3", gotComments[0].NewPosition)
}
if gotComments[0].Body != "needs fix" {
t.Errorf("body = %q, want %q", gotComments[0].Body, "needs fix")
}
}
func TestAdapter_DismissReview(t *testing.T) {
var deleteCalled bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodDelete {
deleteCalled = true
w.WriteHeader(204)
return
}
w.WriteHeader(404)
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
err := adapter.DismissReview(context.Background(), "owner", "repo", 1, 99, "stale review")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !deleteCalled {
t.Error("expected delete to be called")
}
}
func TestAdapter_Underlying(t *testing.T) {
client := gitea.NewClient("http://example.com", "token")
adapter := gitea.NewAdapter(client)
if adapter.Underlying() != client {
t.Error("Underlying() should return the wrapped client")
}
}
func TestAdapter_ListContents(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode([]map[string]any{
{"name": "main.go", "path": "src/main.go", "type": "file"},
{"name": "util", "path": "src/util", "type": "dir"},
})
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
entries, err := adapter.ListContents(context.Background(), "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 2 {
t.Fatalf("got %d entries, want 2", len(entries))
}
if entries[0].Name != "main.go" || entries[0].Type != "file" {
t.Errorf("entries[0] = %+v", entries[0])
}
if entries[1].Name != "util" || entries[1].Type != "dir" {
t.Errorf("entries[1] = %+v", entries[1])
}
}
func TestAdapter_GetFileContent_RefRouting(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// When ref is provided, the URL should contain ?ref=
if r.URL.RawQuery != "" && strings.Contains(r.URL.RawQuery, "ref=") {
w.Write([]byte("content-at-ref"))
} else {
w.Write([]byte("content-default"))
}
}))
defer server.Close()
client := gitea.NewClient(server.URL, "token")
adapter := gitea.NewAdapter(client)
// Empty ref → routes to GetFileContent (no ?ref= query param)
got, err := adapter.GetFileContent(context.Background(), "owner", "repo", "main.go", "")
if err != nil {
t.Fatalf("GetFileContent(ref=\"\"): %v", err)
}
if got != "content-default" {
t.Errorf("GetFileContent(ref=\"\") = %q, want %q", got, "content-default")
}
// Non-empty ref → routes to GetFileContentRef (with ?ref= query param)
got, err = adapter.GetFileContent(context.Background(), "owner", "repo", "main.go", "abc123")
if err != nil {
t.Fatalf("GetFileContent(ref=\"abc123\"): %v", err)
}
if got != "content-at-ref" {
t.Errorf("GetFileContent(ref=\"abc123\") = %q, want %q", got, "content-at-ref")
}
}
+17 -230
View File
@@ -11,11 +11,9 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"syscall"
"time"
)
@@ -41,26 +39,12 @@ func IsNotFound(err error) bool {
return errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusNotFound
}
// IsServerError reports whether an error is an API 5xx response.
func IsServerError(err error) bool {
var apiErr *APIError
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
}
// Client interacts with the Gitea API.
// A Client is safe for concurrent use by multiple goroutines.
type Client struct {
baseURL string
token string
http *http.Client
// RetryBackoff defines the delays between retry attempts.
// RetryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
// If nil, defaults to {1s, 2s}. Set to shorter durations in tests.
//
// This field must be configured before the first request is made.
// Modifying it while requests are in flight is not safe.
RetryBackoff []time.Duration
}
// NewClient creates a new Gitea API client.
@@ -72,12 +56,6 @@ func NewClient(baseURL, token string) *Client {
}
}
// SetHTTPClient sets the underlying HTTP client used for requests.
// This is intended for testing to inject mock transports.
func (c *Client) SetHTTPClient(hc *http.Client) {
c.http = hc
}
// PullRequest holds relevant PR metadata.
type PullRequest struct {
Title string `json:"title"`
@@ -86,9 +64,6 @@ type PullRequest struct {
Sha string `json:"sha"`
Ref string `json:"ref"`
} `json:"head"`
Base struct {
Ref string `json:"ref"`
} `json:"base"`
}
// CommitStatus represents a single CI status entry.
@@ -235,185 +210,24 @@ func (c *Client) PostReview(ctx context.Context, owner, repo string, number int,
return &review, nil
}
// isTemporaryNetError reports whether err is a temporary network error worth retrying.
// This includes connection refused, network unreachable, connection reset, and DNS
// timeouts. It explicitly excludes permanent errors like permission denied or
// "no such host" DNS failures.
func isTemporaryNetError(err error) bool {
if err == nil {
return false
}
// Check for OpError and inspect the underlying syscall error.
// Not all OpErrors are transient — permission denied, for example, is permanent.
var opErr *net.OpError
if errors.As(err, &opErr) {
return isRetriableSyscallError(opErr.Err)
}
// DNS errors: only retry on timeout, not on "no such host" which is permanent.
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) {
return dnsErr.IsTimeout
}
// Check for net.Error with Timeout() (Temporary is deprecated)
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
// isRetriableSyscallError reports whether the underlying error from a net.OpError
// is a transient syscall error worth retrying.
func isRetriableSyscallError(err error) bool {
if err == nil {
return false
}
// Check for syscall.Errno directly or wrapped
var errno syscall.Errno
if errors.As(err, &errno) {
switch errno {
case syscall.ECONNREFUSED, // connection refused — server not listening
syscall.ECONNRESET, // connection reset by peer
syscall.ENETUNREACH, // network unreachable
syscall.EHOSTUNREACH, // host unreachable
syscall.ETIMEDOUT: // connection timed out
return true
default:
// EACCES, EPERM, etc. are permanent — don't retry
return false
}
}
// If we can't identify the specific syscall error, be conservative and retry.
// This handles wrapped errors or platform-specific error types.
// The retry count is limited, so erring on the side of retrying is safe.
return true
}
// redactURL strips query parameters from a URL for safe logging.
// This prevents accidental exposure of sensitive data that future callers
// might pass via query strings.
func redactURL(rawURL string) string {
parsed, err := url.Parse(rawURL)
if err != nil {
// If we cannot parse it, return a safe placeholder rather than
// potentially logging something sensitive.
return "[invalid URL]"
}
if parsed.RawQuery != "" {
parsed.RawQuery = "[redacted]"
}
return parsed.String()
}
// sanitizeErrorForLog returns a loggable version of an error that omits
// potentially sensitive content like response bodies. For APIError, only
// the status code is included; for other errors, the type is preserved.
func sanitizeErrorForLog(err error) string {
if err == nil {
return "<nil>"
}
var apiErr *APIError
if errors.As(err, &apiErr) {
return fmt.Sprintf("HTTP %d", apiErr.StatusCode)
}
return err.Error()
}
// doGet performs an HTTP GET request with retry on 5xx errors and temporary
// network errors. Retries up to 3 times with exponential backoff (1s, 2s delays
// by default; configurable via Client.RetryBackoff for testing).
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
const maxAttempts = 3
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
// First attempt (i=0) has no delay; retries wait 1s then 2s by default.
backoff := c.RetryBackoff
if backoff == nil {
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "token "+c.token)
// maxErrorBodyBytes limits how much of an error response body we read
// to protect against malicious servers sending unbounded data.
const maxErrorBodyBytes = 64 * 1024 // 64 KB
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
if attempt > 0 {
// Determine delay: use backoff slice if available, otherwise retry immediately.
// An empty RetryBackoff slice means "retry without delay" — this is intentional
// as the caller explicitly configured no delays.
var delay time.Duration
if attempt-1 < len(backoff) {
delay = backoff[attempt-1]
}
if delay > 0 {
slog.Warn("retrying request after error",
"attempt", attempt+1,
"url", redactURL(reqURL),
"delay", delay.String(),
"lastError", sanitizeErrorForLog(lastErr))
timer := time.NewTimer(delay)
select {
case <-timer.C:
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
}
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "token "+c.token)
resp, err := c.http.Do(req)
if err != nil {
// Always capture the error for consistent return at loop end.
// This ensures both network errors and HTTP 5xx return lastErr.
lastErr = err
// Only retry temporary network errors when attempts remain.
if attempt < maxAttempts-1 && isTemporaryNetError(err) {
slog.Warn("temporary network error, will retry",
"attempt", attempt+1,
"url", redactURL(reqURL),
"error", err)
continue
}
// Non-retryable network error or final attempt exhausted.
return nil, lastErr
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return nil, err
}
return body, nil
}
// Error path: limit how much we read from potentially malicious server
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
resp.Body.Close()
lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
// Only retry on 5xx server errors
if resp.StatusCode < 500 || resp.StatusCode >= 600 {
return nil, lastErr
}
resp, err := c.http.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return nil, lastErr
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(body)}
}
return io.ReadAll(resp.Body)
}
// escapePath escapes each segment of a relative file path for use in URLs.
@@ -437,13 +251,7 @@ type ContentEntry struct {
// ListContents lists files and directories at a given path in a repo.
// Pass an empty path to list the repository root.
// If the path points to a file (not a directory), Gitea returns a single
// object instead of an array; this method normalizes both cases to a slice.
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
// Normalize "." to empty string — Gitea API rejects "." with 500
if path == "." {
path = ""
}
var reqURL string
if path == "" {
reqURL = fmt.Sprintf("%s/api/v1/repos/%s/%s/contents", c.baseURL, url.PathEscape(owner), url.PathEscape(repo))
@@ -456,16 +264,7 @@ func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]
}
var entries []ContentEntry
if err := json.Unmarshal(body, &entries); err != nil {
// Gitea returns a single object (not an array) when path is a file
var single ContentEntry
if err2 := json.Unmarshal(body, &single); err2 != nil {
return nil, fmt.Errorf("parse contents JSON: %w", err)
}
// Guard against empty/malformed responses
if single.Name == "" && single.Path == "" {
return nil, fmt.Errorf("parse contents JSON: empty response for path %q", path)
}
entries = []ContentEntry{single}
return nil, fmt.Errorf("parse contents JSON: %w", err)
}
return entries, nil
}
@@ -518,9 +317,9 @@ func (c *Client) GetAllFilesInPath(ctx context.Context, owner, repo, path string
// Review represents a pull request review from the Gitea API.
type Review struct {
ID int64 `json:"id"`
Body string `json:"body"`
User struct {
ID int64 `json:"id"`
Body string `json:"body"`
User struct {
Login string `json:"login"`
} `json:"user"`
State string `json:"state"`
@@ -834,15 +633,3 @@ func (c *Client) ResolveComment(ctx context.Context, owner, repo string, comment
}
return nil
}
// DismissReview dismisses a review on a pull request.
// This is a stub for the vcs.Reviewer interface; full implementation is Phase 2.
func (c *Client) DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error {
return fmt.Errorf("dismiss review %d on %s/%s#%d: %w", reviewID, owner, repo, number, errors.ErrUnsupported)
}
// GetFileContentAtRef fetches a file at a specific ref from a repo.
// This delegates to GetFileContentRef for the Gitea implementation.
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
return c.GetFileContentRef(ctx, owner, repo, path, ref)
}
+5 -406
View File
@@ -6,14 +6,10 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"syscall"
"testing"
"time"
)
func TestGetPullRequest(t *testing.T) {
@@ -280,64 +276,11 @@ func TestListContents(t *testing.T) {
}
}
func TestListContents_DotPath(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// "." should be normalized to empty path, which hits the root contents endpoint
if r.URL.Path != "/api/v1/repos/owner/repo/contents" {
t.Errorf("expected root contents path, got: %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `[{"name":"README.md","path":"README.md","type":"file"}]`)
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
entries, err := client.ListContents(context.Background(), "owner", "repo", ".")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Name != "README.md" {
t.Errorf("expected README.md, got %s", entries[0].Name)
}
}
func TestListContents_FilePath(t *testing.T) {
// Gitea returns a single object (not an array) when path is a file
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/repos/owner/repo/contents/README.md" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
// Single object, not an array
fmt.Fprintf(w, `{"name":"README.md","path":"README.md","type":"file"}`)
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(entries) != 1 {
t.Fatalf("expected 1 entry, got %d", len(entries))
}
if entries[0].Name != "README.md" {
t.Errorf("expected README.md, got %s", entries[0].Name)
}
if entries[0].Type != "file" {
t.Errorf("expected type file, got %s", entries[0].Type)
}
}
func TestGetAllFilesInPath_File(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/repos/owner/repo/contents/README.md" {
// Gitea returns a single object (not array) when path is a file
w.Header().Set("Content-Type", "application/json")
fmt.Fprintf(w, `{"name":"README.md","path":"README.md","type":"file"}`)
// Gitea returns 404 for contents API on files (it's not a dir)
http.NotFound(w, r)
return
}
if r.URL.Path == "/api/v1/repos/owner/repo/raw/README.md" {
@@ -641,9 +584,9 @@ func TestGetAllFilesInPath_403Propagates(t *testing.T) {
func TestIsNotFound(t *testing.T) {
tests := []struct {
name string
err error
want bool
name string
err error
want bool
}{
{"nil error", nil, false},
{"non-API error", fmt.Errorf("network timeout"), false},
@@ -800,347 +743,3 @@ func TestResolveComment_Error(t *testing.T) {
t.Fatal("expected error for 404 response")
}
}
func TestIsServerError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil error", nil, false},
{"non-API error", fmt.Errorf("network timeout"), false},
{"404 APIError", &APIError{StatusCode: 404, Body: "not found"}, false},
{"500 APIError", &APIError{StatusCode: 500, Body: "server error"}, true},
{"502 APIError", &APIError{StatusCode: 502, Body: "bad gateway"}, true},
{"503 APIError", &APIError{StatusCode: 503, Body: "unavailable"}, true},
{"599 APIError", &APIError{StatusCode: 599, Body: "edge case"}, true},
{"600 not server error", &APIError{StatusCode: 600, Body: "edge"}, false},
{"400 not server error", &APIError{StatusCode: 400, Body: "bad request"}, false},
{"wrapped 500", fmt.Errorf("fetch: %w", &APIError{StatusCode: 500, Body: "err"}), true},
{"wrapped 404", fmt.Errorf("fetch: %w", &APIError{StatusCode: 404, Body: "err"}), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsServerError(tt.err)
if got != tt.want {
t.Errorf("IsServerError(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestDoGet_RetriesOn500(t *testing.T) {
attempts := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts < 3 {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"message":"transient error"}`))
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"data":"success"}`))
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
// Use short backoff for fast tests
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
body, err := client.doGet(context.Background(), server.URL+"/test")
if err != nil {
t.Fatalf("expected success after retry, got error: %v", err)
}
if string(body) != `{"data":"success"}` {
t.Errorf("body = %q, want %q", string(body), `{"data":"success"}`)
}
if attempts != 3 {
t.Errorf("attempts = %d, want 3", attempts)
}
}
func TestDoGet_FailsAfterMaxRetries(t *testing.T) {
attempts := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"message":"persistent error"}`))
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
// Use short backoff for fast tests
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
_, err := client.doGet(context.Background(), server.URL+"/test")
if err == nil {
t.Fatal("expected error after max retries")
}
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("expected APIError, got: %v", err)
}
if apiErr.StatusCode != http.StatusInternalServerError {
t.Errorf("status = %d, want 500", apiErr.StatusCode)
}
if attempts != 3 {
t.Errorf("attempts = %d, want 3 (max retries)", attempts)
}
}
func TestDoGet_NoRetryOn4xx(t *testing.T) {
attempts := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"message":"forbidden"}`))
}))
defer server.Close()
client := NewClient(server.URL, "test-token")
_, err := client.doGet(context.Background(), server.URL+"/test")
if err == nil {
t.Fatal("expected error for 403")
}
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("expected APIError, got: %v", err)
}
if apiErr.StatusCode != http.StatusForbidden {
t.Errorf("status = %d, want 403", apiErr.StatusCode)
}
if attempts != 1 {
t.Errorf("attempts = %d, want 1 (no retry on 4xx)", attempts)
}
}
func TestDoGet_RespectsContextCancellation(t *testing.T) {
attempts := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"message":"error"}`))
}))
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
client := NewClient(server.URL, "test-token")
// Use longer backoff to give us time to cancel during the wait
client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond}
// Cancel after first attempt returns and retry begins
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
_, err := client.doGet(ctx, server.URL+"/test")
if err == nil {
t.Fatal("expected error on context cancellation")
}
// Should have made 1 attempt, then context cancelled during backoff
if attempts != 1 {
t.Errorf("attempts = %d, expected 1 before context cancel during backoff", attempts)
}
}
// mockTransport is a test helper that returns errors for the first N calls,
// then delegates to a real server.
type mockTransport struct {
failCount int32 // number of failures remaining (atomic)
failErr error // error to return on failure
realServer *httptest.Server
attemptsMade atomic.Int32 // tracks total attempts
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
m.attemptsMade.Add(1)
remaining := atomic.AddInt32(&m.failCount, -1)
if remaining >= 0 {
// Still have failures to return
return nil, m.failErr
}
// Redirect to real server
req.URL.Host = m.realServer.Listener.Addr().String()
req.URL.Scheme = "http"
return http.DefaultTransport.RoundTrip(req)
}
func TestDoGet_RetriesOnTemporaryNetError(t *testing.T) {
// Real server that will handle successful requests
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"status":"ok"}`))
}))
defer server.Close()
// Mock transport: fail twice with ECONNREFUSED, then succeed
mt := &mockTransport{
failCount: 2,
failErr: &net.OpError{Op: "dial", Net: "tcp", Err: syscall.ECONNREFUSED},
realServer: server,
}
client := NewClient("http://fake-host/", "test-token")
client.SetHTTPClient(&http.Client{Transport: mt})
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
body, err := client.doGet(context.Background(), "http://fake-host/test")
if err != nil {
t.Fatalf("expected success after retries, got error: %v", err)
}
if string(body) != `{"status":"ok"}` {
t.Errorf("body = %q, want %q", string(body), `{"status":"ok"}`)
}
// Should have made exactly 3 attempts: 2 failures + 1 success
if got := mt.attemptsMade.Load(); got != 3 {
t.Errorf("attempts = %d, want 3 (2 failures + 1 success)", got)
}
}
func TestIsTemporaryNetError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil error", nil, false},
{"plain error", fmt.Errorf("some error"), false},
// OpError with retriable syscall errors
{"OpError ECONNREFUSED", &net.OpError{Op: "dial", Err: syscall.ECONNREFUSED}, true},
{"OpError ECONNRESET", &net.OpError{Op: "read", Err: syscall.ECONNRESET}, true},
{"OpError ENETUNREACH", &net.OpError{Op: "dial", Err: syscall.ENETUNREACH}, true},
{"OpError EHOSTUNREACH", &net.OpError{Op: "dial", Err: syscall.EHOSTUNREACH}, true},
{"OpError ETIMEDOUT", &net.OpError{Op: "dial", Err: syscall.ETIMEDOUT}, true},
// OpError with permanent syscall errors — should NOT retry
{"OpError EACCES", &net.OpError{Op: "dial", Err: syscall.EACCES}, false},
{"OpError EPERM", &net.OpError{Op: "dial", Err: syscall.EPERM}, false},
// OpError with unknown inner error — conservative retry
{"OpError unknown inner", &net.OpError{Op: "dial", Err: fmt.Errorf("unknown")}, true},
// DNS errors
{"DNS timeout", &net.DNSError{IsTimeout: true}, true},
{"DNS no such host", &net.DNSError{IsTimeout: false, Name: "bad.host"}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isTemporaryNetError(tt.err)
if got != tt.want {
t.Errorf("isTemporaryNetError(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestIsRetriableSyscallError(t *testing.T) {
tests := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"ECONNREFUSED", syscall.ECONNREFUSED, true},
{"ECONNRESET", syscall.ECONNRESET, true},
{"ENETUNREACH", syscall.ENETUNREACH, true},
{"EHOSTUNREACH", syscall.EHOSTUNREACH, true},
{"ETIMEDOUT", syscall.ETIMEDOUT, true},
{"EACCES (permanent)", syscall.EACCES, false},
{"EPERM (permanent)", syscall.EPERM, false},
{"ENOENT (permanent)", syscall.ENOENT, false},
{"unknown error", fmt.Errorf("something"), true}, // conservative retry
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isRetriableSyscallError(tt.err)
if got != tt.want {
t.Errorf("isRetriableSyscallError(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
func TestRedactURL(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "no query params",
input: "https://gitea.example.com/api/v1/repos/owner/repo/pulls/1",
want: "https://gitea.example.com/api/v1/repos/owner/repo/pulls/1",
},
{
name: "with query params - redacts",
input: "https://gitea.example.com/api/v1/repos/owner/repo/raw/file?ref=main",
want: "https://gitea.example.com/api/v1/repos/owner/repo/raw/file?[redacted]",
},
{
name: "multiple query params",
input: "https://example.com/path?token=secret&page=1",
want: "https://example.com/path?[redacted]",
},
{
name: "invalid URL",
input: "://invalid",
want: "[invalid URL]",
},
{
name: "empty string",
input: "",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := redactURL(tt.input)
if got != tt.want {
t.Errorf("redactURL(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestSanitizeErrorForLog(t *testing.T) {
tests := []struct {
name string
err error
want string
}{
{
name: "nil error",
err: nil,
want: "<nil>",
},
{
name: "APIError omits body",
err: &APIError{StatusCode: 500, Body: "internal error: database connection failed"},
want: "HTTP 500",
},
{
name: "APIError with large body still only shows status",
err: &APIError{StatusCode: 502, Body: strings.Repeat("x", 1000)},
want: "HTTP 502",
},
{
name: "non-API error preserved",
err: fmt.Errorf("connection refused"),
want: "connection refused",
},
{
name: "wrapped APIError",
err: fmt.Errorf("request failed: %w", &APIError{StatusCode: 503, Body: "service unavailable"}),
want: "HTTP 503",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sanitizeErrorForLog(tt.err)
if got != tt.want {
t.Errorf("sanitizeErrorForLog() = %q, want %q", got, tt.want)
}
})
}
}
-10
View File
@@ -1,10 +0,0 @@
package gitea_test
import (
"gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Compile-time interface conformance assertion.
// The Adapter (not the raw Client) satisfies the full vcs.Client interface.
var _ vcs.Client = (*gitea.Adapter)(nil)
-190
View File
@@ -1,190 +0,0 @@
package gitea
import (
"fmt"
"strconv"
"strings"
)
// PositionMap holds a per-file mapping of GitHub diff-position to new-file line number.
// Position is a 1-indexed offset from the @@ hunk header line in the unified diff.
type PositionMap struct {
// files maps filename → (position → new-file line number).
// Deletion lines are mapped to -1 (no new-file line).
files map[string]map[int]int
// maxPositions caches the highest position number per file,
// tracked during construction to avoid O(n) scans at translate time.
maxPositions map[string]int
}
// Translate converts a GitHub diff-position to a new-file line number for a given file.
// Returns an error if the file is not in the diff or the position is out of range.
// If the position targets a deletion line, it maps to the nearest non-deletion line below;
// if no such line exists, returns an error.
func (pm *PositionMap) Translate(file string, position int) (int, error) {
if pm == nil || pm.files == nil {
return 0, fmt.Errorf("empty position map")
}
fileMap, ok := pm.files[file]
if !ok {
return 0, fmt.Errorf("file %q not found in diff", file)
}
if position < 1 {
return 0, fmt.Errorf("position %d out of range (must be >= 1)", position)
}
lineNum, ok := fileMap[position]
if !ok {
return 0, fmt.Errorf("position %d out of range for file %q", position, file)
}
// lineNum == -1 means this position is a deletion line.
// Map to the nearest non-deletion line below.
if lineNum == -1 {
maxPos := pm.maxPosition(file)
for p := position + 1; p <= maxPos; p++ {
if ln, exists := fileMap[p]; exists && ln > 0 {
return ln, nil
}
}
return 0, fmt.Errorf("position %d targets a deletion line with no subsequent new-file line in %q", position, file)
}
return lineNum, nil
}
// maxPosition returns the highest position number for a file.
// O(1) — the maximum is tracked during map construction.
func (pm *PositionMap) maxPosition(file string) int {
return pm.maxPositions[file]
}
// BuildPositionToLineMap parses a unified diff and builds a PositionMap
// mapping diff-position → new-file line number per file.
//
// Diff-position counting rules (GitHub spec):
// - The @@ hunk header line is position 1 for the file's first hunk
// - Every subsequent line increments position by 1 — context, additions, AND deletions
// - A new @@ hunk within the same file continues incrementing (does not reset)
// - Position maps to the new file line number for additions and context lines
// - Deletion lines have a position but no new-file line number (stored as -1)
func BuildPositionToLineMap(diff string) *PositionMap {
pm := &PositionMap{
files: make(map[string]map[int]int),
maxPositions: make(map[string]int),
}
lines := strings.Split(diff, "\n")
var currentFile string
var position int
var newLine int
for _, line := range lines {
// Detect new file in diff.
// "+++ b/" is checked before "+++ /dev/null" — the two prefixes are
// non-overlapping ("+++ /dev/null" does not start with "+++ b/"), so
// ordering is independent. Checking the common case first for clarity.
if strings.HasPrefix(line, "+++ b/") {
currentFile = strings.TrimPrefix(line, "+++ b/")
position = 0
newLine = 0
if pm.files[currentFile] == nil {
pm.files[currentFile] = make(map[int]int)
}
continue
}
// Deleted file: +++ /dev/null means the file is being deleted
if strings.HasPrefix(line, "+++ /dev/null") {
currentFile = ""
continue
}
// Skip --- lines (old file header)
if strings.HasPrefix(line, "--- ") {
continue
}
// Skip diff --git lines
if strings.HasPrefix(line, "diff --git") {
continue
}
// Skip index lines
if strings.HasPrefix(line, "index ") {
continue
}
// Binary file detection
if strings.HasPrefix(line, "Binary files") {
currentFile = ""
continue
}
// Parse hunk headers
if strings.HasPrefix(line, "@@") && currentFile != "" {
position++
pm.maxPositions[currentFile] = position
newLine = parseHunkStart(line)
continue
}
if currentFile == "" {
continue
}
// Skip "\ No newline at end of file" markers
if strings.HasPrefix(line, `\`) {
continue
}
// Process diff content lines
if strings.HasPrefix(line, "+") {
// Addition: has a new-file line number
position++
pm.files[currentFile][position] = newLine
pm.maxPositions[currentFile] = position
newLine++
} else if strings.HasPrefix(line, "-") {
// Deletion: has a position but no new-file line number
position++
pm.files[currentFile][position] = -1
pm.maxPositions[currentFile] = position
} else if strings.HasPrefix(line, " ") {
// Context line
position++
pm.files[currentFile][position] = newLine
pm.maxPositions[currentFile] = position
newLine++
}
}
return pm
}
// parseHunkStart extracts the new-file starting line number from a hunk header.
// Format: @@ -old_start[,old_count] +new_start[,new_count] @@
func parseHunkStart(hunkLine string) int {
plusIdx := strings.Index(hunkLine, "+")
if plusIdx < 0 {
return 1
}
rest := hunkLine[plusIdx+1:]
endIdx := 0
for endIdx < len(rest) && rest[endIdx] >= '0' && rest[endIdx] <= '9' {
endIdx++
}
if endIdx == 0 {
return 1
}
n, err := strconv.Atoi(rest[:endIdx])
if err != nil {
return 1
}
return n
}
-274
View File
@@ -1,274 +0,0 @@
package gitea
import (
"testing"
)
func TestBuildPositionToLineMap_SingleHunk(t *testing.T) {
// @@ -16,4 +16,5 @@ ← position 1
// context ← position 2, new line 16
//-deleted ← position 3, no new line
//+added ← position 4, new line 17
// context ← position 5, new line 18
diff := `diff --git a/file.go b/file.go
index abc..def 100644
--- a/file.go
+++ b/file.go
@@ -16,4 +16,5 @@ func example() {
context line
-deleted line
+added line
context after
`
pm := BuildPositionToLineMap(diff)
tests := []struct {
pos int
wantLine int
}{
{2, 16}, // context line -> new line 16
{4, 17}, // added line -> new line 17
{5, 18}, // context after -> new line 18
}
for _, tt := range tests {
got, err := pm.Translate("file.go", tt.pos)
if err != nil {
t.Errorf("Translate(file.go, %d): unexpected error: %v", tt.pos, err)
continue
}
if got != tt.wantLine {
t.Errorf("Translate(file.go, %d) = %d, want %d", tt.pos, got, tt.wantLine)
}
}
}
func TestBuildPositionToLineMap_MultipleHunks(t *testing.T) {
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,3 +1,3 @@ package main
line1
-old
+new
@@ -10,3 +10,4 @@ func foo() {
func foo() {
+ // added
return
}
`
pm := BuildPositionToLineMap(diff)
tests := []struct {
pos int
wantLine int
}{
// First hunk: @@ is pos 1
{2, 1}, // " line1" -> new line 1
{4, 2}, // "+new" -> new line 2
// Second hunk: @@ is pos 5 (continues from 4)
// Wait: first hunk has pos 1(@@ hdr), 2(" line1"), 3("-old"), 4("+new")
// Second hunk @@ is pos 5
{6, 10}, // " func foo() {" -> new line 10
{7, 11}, // "+\t// added" -> new line 11
{8, 12}, // " \treturn" -> new line 12
{9, 13}, // " }" -> new line 13
}
for _, tt := range tests {
got, err := pm.Translate("file.go", tt.pos)
if err != nil {
t.Errorf("Translate(file.go, %d): unexpected error: %v", tt.pos, err)
continue
}
if got != tt.wantLine {
t.Errorf("Translate(file.go, %d) = %d, want %d", tt.pos, got, tt.wantLine)
}
}
}
func TestBuildPositionToLineMap_DeletionTargeted(t *testing.T) {
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,4 +1,3 @@ package main
line1
-deleted
line3
`
pm := BuildPositionToLineMap(diff)
// Position 3 is the deletion line "-deleted" — should map to nearest below
// Position 4 is " line3" which is new line 2
got, err := pm.Translate("file.go", 3)
if err != nil {
t.Fatalf("Translate(file.go, 3): unexpected error: %v", err)
}
if got != 2 {
t.Errorf("Translate(file.go, 3) = %d, want 2 (nearest non-deletion below)", got)
}
}
func TestBuildPositionToLineMap_DeletionAtEnd(t *testing.T) {
// If a deletion line is at the end with no subsequent non-deletion line, error
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,3 +1,2 @@ package main
line1
line2
-deleted at end
`
pm := BuildPositionToLineMap(diff)
_, err := pm.Translate("file.go", 4)
if err == nil {
t.Error("expected error for deletion at end with no subsequent line")
}
}
func TestBuildPositionToLineMap_NewFile(t *testing.T) {
diff := `diff --git a/new.go b/new.go
new file mode 100644
--- /dev/null
+++ b/new.go
@@ -0,0 +1,3 @@
+package main
+
+func init() {}
`
pm := BuildPositionToLineMap(diff)
tests := []struct {
pos int
wantLine int
}{
{2, 1}, // "+package main" -> line 1
{3, 2}, // "+" (empty line) -> line 2
{4, 3}, // "+func init() {}" -> line 3
}
for _, tt := range tests {
got, err := pm.Translate("new.go", tt.pos)
if err != nil {
t.Errorf("Translate(new.go, %d): unexpected error: %v", tt.pos, err)
continue
}
if got != tt.wantLine {
t.Errorf("Translate(new.go, %d) = %d, want %d", tt.pos, got, tt.wantLine)
}
}
}
func TestBuildPositionToLineMap_DeletedFile(t *testing.T) {
diff := `diff --git a/old.go b/old.go
deleted file mode 100644
--- a/old.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package main
-
-func old() {}
`
pm := BuildPositionToLineMap(diff)
// Deleted file has no new-file lines; positions should error
_, err := pm.Translate("old.go", 2)
if err == nil {
t.Error("expected error for deleted file position")
}
}
func TestBuildPositionToLineMap_BinaryFile(t *testing.T) {
diff := `diff --git a/image.png b/image.png
Binary files /dev/null and b/image.png differ
diff --git a/code.go b/code.go
--- a/code.go
+++ b/code.go
@@ -1,2 +1,3 @@
package main
+// added
func main() {}
`
pm := BuildPositionToLineMap(diff)
// Binary file should not be in the map
_, err := pm.Translate("image.png", 1)
if err == nil {
t.Error("expected error for binary file")
}
// code.go should still work
got, err := pm.Translate("code.go", 3)
if err != nil {
t.Fatalf("Translate(code.go, 3): unexpected error: %v", err)
}
if got != 2 {
t.Errorf("Translate(code.go, 3) = %d, want 2", got)
}
}
func TestBuildPositionToLineMap_OutOfRange(t *testing.T) {
diff := `diff --git a/file.go b/file.go
--- a/file.go
+++ b/file.go
@@ -1,2 +1,2 @@
line1
-old
+new
`
pm := BuildPositionToLineMap(diff)
// Position 0 is invalid
_, err := pm.Translate("file.go", 0)
if err == nil {
t.Error("expected error for position 0")
}
// Position 5 is out of range (only positions 1-4 exist)
_, err = pm.Translate("file.go", 5)
if err == nil {
t.Error("expected error for position 5 (out of range)")
}
// Unknown file
_, err = pm.Translate("unknown.go", 1)
if err == nil {
t.Error("expected error for unknown file")
}
}
func TestBuildPositionToLineMap_MultipleFiles(t *testing.T) {
diff := `diff --git a/a.go b/a.go
--- a/a.go
+++ b/a.go
@@ -1,2 +1,3 @@
package a
+// file a
func aFunc() {}
diff --git a/b.go b/b.go
--- a/b.go
+++ b/b.go
@@ -1,2 +1,3 @@
package b
+// file b
func bFunc() {}
`
pm := BuildPositionToLineMap(diff)
// a.go: pos 3 is "+// file a" -> new line 2
got, err := pm.Translate("a.go", 3)
if err != nil {
t.Fatalf("Translate(a.go, 3): %v", err)
}
if got != 2 {
t.Errorf("Translate(a.go, 3) = %d, want 2", got)
}
// b.go: pos 3 is "+// file b" -> new line 2
// Note: position resets per file
got, err = pm.Translate("b.go", 3)
if err != nil {
t.Fatalf("Translate(b.go, 3): %v", err)
}
if got != 2 {
t.Errorf("Translate(b.go, 3) = %d, want 2", got)
}
}
-327
View File
@@ -1,327 +0,0 @@
// Package github provides a client for the GitHub API.
// It supports pull request operations, file content retrieval, CI status checks,
// and directory listing for both github.com and GitHub Enterprise.
package github
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const (
defaultBaseURL = "https://api.github.com"
userAgent = "review-bot/1.0"
// maxResponseBytes limits successful response body reads to 10 MiB.
maxResponseBytes = 10 * 1024 * 1024
)
// APIError represents an HTTP error response from the GitHub API.
// It carries the status code so callers can distinguish between
// different failure modes (e.g. 404 vs 500).
//
// The Body field stores up to 4 KiB of the raw response for programmatic
// inspection. Error() truncates to 200 bytes for safe logging, but callers
// should avoid logging or propagating Body directly in production since it may
// contain sensitive details from the upstream server.
type APIError struct {
StatusCode int
Body string
}
func (e *APIError) Error() string {
body := e.Body
if len(body) > 200 {
body = body[:200] + "...(truncated)"
}
// Sanitize newlines to prevent log injection from upstream response bodies.
body = strings.ReplaceAll(body, "\n", " ")
body = strings.ReplaceAll(body, "\r", " ")
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
}
// IsNotFound reports whether an error is an API 404 response.
func IsNotFound(err error) bool {
if apiErr, ok := asAPIError(err); ok {
return apiErr.StatusCode == http.StatusNotFound
}
return false
}
// IsUnauthorized reports whether an error is an API 401 response.
func IsUnauthorized(err error) bool {
if apiErr, ok := asAPIError(err); ok {
return apiErr.StatusCode == http.StatusUnauthorized
}
return false
}
func asAPIError(err error) (*APIError, bool) {
if err == nil {
return nil, false
}
var target *APIError
if errors.As(err, &target) {
return target, true
}
return nil, false
}
// clientConfig holds optional configuration for NewClient.
type clientConfig struct {
allowInsecureHTTP bool
}
// ClientOption configures optional behavior of NewClient.
type ClientOption func(*clientConfig)
// AllowInsecureHTTP permits the client to use HTTP (non-TLS) base URLs.
// This should only be used for trusted internal deployments or testing.
func AllowInsecureHTTP() ClientOption {
return func(c *clientConfig) {
c.allowInsecureHTTP = true
}
}
// Client interacts with the GitHub API.
// A Client is safe for concurrent use by multiple goroutines.
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
// be called before any goroutines issue requests; they have no synchronization.
type Client struct {
baseURL string
token string
allowInsecureHTTP bool
httpClient *http.Client
// retryBackoff defines the delays between retry attempts for 429 responses.
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
// If nil, defaults to {1s, 2s}. Set to shorter durations in tests via SetRetryBackoff.
retryBackoff []time.Duration
}
// defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil).
// It rejects HTTPS→HTTP protocol downgrades (to prevent plaintext leakage) and strips
// the Authorization header on cross-host redirects to prevent credential leakage to
// third-party hosts (e.g. CDN redirects from GitHub).
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("stopped after 10 redirects")
}
// Guard: net/http guarantees len(via) >= 1 but this is undocumented;
// defend against zero-length to avoid panic on index out of range.
if len(via) == 0 {
return nil
}
prev := via[len(via)-1]
// Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext.
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
return fmt.Errorf("refusing redirect from HTTPS to HTTP (%s → %s)", prev.URL.Host, req.URL.Host)
}
// Strip Authorization on cross-host redirect to avoid leaking credentials
// to third-party hosts (GitHub legitimately redirects to CDN hosts).
if req.URL.Host != prev.URL.Host {
req.Header.Del("Authorization")
}
return nil
}
// NewClient creates a new GitHub API client.
// If baseURL is empty, it defaults to https://api.github.com.
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
// The baseURL must use HTTPS; pass AllowInsecureHTTP() as an option to permit HTTP
// for trusted internal deployments (e.g. local testing).
func NewClient(token, baseURL string, opts ...ClientOption) *Client {
if baseURL == "" {
baseURL = defaultBaseURL
}
cfg := clientConfig{}
for _, o := range opts {
o(&cfg)
}
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
allowInsecureHTTP: cfg.allowInsecureHTTP,
token: token,
httpClient: &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: defaultCheckRedirect,
},
}
}
// SetHTTPClient sets the underlying HTTP client used for requests.
// This is intended for test setup only to inject mock transports; it must be
// called before any goroutines issue requests.
//
// Passing nil restores the default client (30s timeout + auth-stripping
// CheckRedirect policy matching NewClient).
//
// Callers providing a non-nil client are responsible for configuring a safe
// CheckRedirect policy. Without one, the default net/http behavior will follow
// redirects and may forward the Authorization header to untrusted hosts.
func (c *Client) SetHTTPClient(hc *http.Client) {
if hc == nil {
hc = &http.Client{
Timeout: 30 * time.Second,
CheckRedirect: defaultCheckRedirect,
}
}
c.httpClient = hc
}
// SetRetryBackoff configures the retry backoff durations for testing.
// It must be called before any goroutines issue requests.
// In production the default {1s, 2s} applies.
func (c *Client) SetRetryBackoff(d []time.Duration) {
c.retryBackoff = d
}
// doRequest performs an HTTP request with retry on 429 rate limit responses.
// It respects the Retry-After header when present (capped at maxRetryAfter).
// Transport errors (network failures, context cancellation) are not retried.
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
const maxAttempts = 3
const maxRetryAfter = 120 * time.Second
var backoff []time.Duration
if c.retryBackoff != nil {
backoff = make([]time.Duration, len(c.retryBackoff))
copy(backoff, c.retryBackoff)
} else {
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
}
// maxErrorBodyBytes limits how much of an error response body is stored.
// Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers
// log APIError.Body directly. Error() further truncates to 200 bytes.
const maxErrorBodyBytes = 4 * 1024
// Reject non-HTTPS URLs early since the URL is immutable across retries.
if c.token != "" && !c.allowInsecureHTTP {
parsed, err := url.Parse(reqURL)
if err != nil {
return nil, fmt.Errorf("parse request URL: %w", err)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL)
}
}
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
if attempt > 0 {
var delay time.Duration
if attempt-1 < len(backoff) {
delay = backoff[attempt-1]
}
if delay > 0 {
timer := time.NewTimer(delay)
select {
case <-timer.C:
timer.Stop() // no-op after fire; kept for symmetry with the ctx.Done case
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
}
}
}
req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
if c.token != "" {
// Bearer is the OAuth2 standard and is accepted by GitHub for both
// classic PATs and fine-grained tokens. The alternative "token" scheme
// is GitHub-specific and offers no additional compatibility.
req.Header.Set("Authorization", "Bearer "+c.token)
}
req.Header.Set("User-Agent", userAgent)
if accept != "" {
req.Header.Set("Accept", accept)
} else {
req.Header.Set("Accept", "application/vnd.github+json")
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do request: %w", err)
}
body, done, err := c.handleResponse(resp, maxResponseBytes, maxErrorBodyBytes)
if done {
return body, err
}
lastErr = err
// Retry on 429 rate limit
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 {
// Check for Retry-After header and override backoff if present.
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
if ra := resp.Header.Get("Retry-After"); ra != "" {
if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 {
delay := time.Duration(seconds) * time.Second
if delay > maxRetryAfter {
delay = maxRetryAfter
}
if attempt < len(backoff) {
backoff[attempt] = delay
}
} else if retryAt, err := http.ParseTime(ra); err == nil {
delay := time.Until(retryAt)
if delay < 0 {
delay = 0
}
if delay > maxRetryAfter {
delay = maxRetryAfter
}
if attempt < len(backoff) {
backoff[attempt] = delay
}
}
}
continue
}
// Don't retry other errors
return nil, lastErr
}
return nil, lastErr
}
// handleResponse reads and closes the response body, returning the result.
// It uses defer to ensure the body is always closed regardless of code path.
// Returns (body, done, err) where done=true means the caller should return immediately.
func (c *Client) handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) {
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1))
if err != nil {
return nil, true, fmt.Errorf("read response body: %w", err)
}
if len(body) > maxRespBytes {
return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes)
}
return body, true, nil
}
errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes)))
if readErr != nil && len(errBody) == 0 {
errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr))
}
return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
}
// doGet is a convenience wrapper for GET requests with the default Accept header.
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
return c.doRequest(ctx, http.MethodGet, reqURL, "")
}
-556
View File
@@ -1,556 +0,0 @@
package github
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
func TestNewClient_DefaultBaseURL(t *testing.T) {
c := NewClient("test-token", "")
if c.baseURL != "https://api.github.com" {
t.Errorf("expected default base URL, got %q", c.baseURL)
}
}
func TestNewClient_CustomBaseURL(t *testing.T) {
c := NewClient("test-token", "https://github.concur.com/api/v3")
if c.baseURL != "https://github.concur.com/api/v3" {
t.Errorf("expected custom base URL, got %q", c.baseURL)
}
}
func TestNewClient_TrimsTrailingSlash(t *testing.T) {
c := NewClient("test-token", "https://github.concur.com/api/v3/")
if c.baseURL != "https://github.concur.com/api/v3" {
t.Errorf("expected trailing slash trimmed, got %q", c.baseURL)
}
}
func TestDoRequest_SetsAuthHeader(t *testing.T) {
var gotAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("my-token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotAuth != "Bearer my-token" {
t.Errorf("expected Bearer auth, got %q", gotAuth)
}
}
func TestDoRequest_SetsDefaultAcceptHeader(t *testing.T) {
var gotAccept string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAccept = r.Header.Get("Accept")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotAccept != "application/vnd.github+json" {
t.Errorf("expected default Accept header, got %q", gotAccept)
}
}
func TestDoRequest_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{10 * time.Millisecond, 10 * time.Millisecond})
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestDoRequest_429ExhaustsRetries(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error after exhausting retries")
}
apiErr, ok := err.(*APIError)
if !ok {
t.Fatalf("expected *APIError, got %T", err)
}
if apiErr.StatusCode != 429 {
t.Errorf("expected 429, got %d", apiErr.StatusCode)
}
if attempts != 3 {
t.Errorf("expected 3 attempts, got %d", attempts)
}
}
func TestDoRequest_404NoRetry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(404)
w.Write([]byte(`{"message":"not found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error for 404")
}
if attempts != 1 {
t.Errorf("expected 1 attempt (no retry on 404), got %d", attempts)
}
}
func TestDoRequest_401NoRetry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
w.WriteHeader(401)
w.Write([]byte(`{"message":"bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error for 401")
}
if attempts != 1 {
t.Errorf("expected 1 attempt (no retry on 401), got %d", attempts)
}
}
func TestIsNotFound(t *testing.T) {
err := &APIError{StatusCode: 404, Body: "not found"}
if !IsNotFound(err) {
t.Error("expected IsNotFound to return true for 404")
}
err2 := &APIError{StatusCode: 500, Body: "server error"}
if IsNotFound(err2) {
t.Error("expected IsNotFound to return false for 500")
}
}
func TestIsUnauthorized(t *testing.T) {
err := &APIError{StatusCode: 401, Body: "bad credentials"}
if !IsUnauthorized(err) {
t.Error("expected IsUnauthorized to return true for 401")
}
}
func TestAPIError_SanitizesNewlines(t *testing.T) {
err := &APIError{StatusCode: 500, Body: "line1\ninjected\rmore"}
msg := err.Error()
if strings.Contains(msg, "\n") || strings.Contains(msg, "\r") {
t.Errorf("expected newlines to be sanitized, got: %q", msg)
}
if !strings.Contains(msg, "line1 injected more") {
t.Errorf("expected sanitized body, got: %q", msg)
}
}
func TestDoRequest_429RetryAfterHeader(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow retry test in short mode")
}
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "1")
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
// Use short backoff; Retry-After should override
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
start := time.Now()
body, err := c.doGet(context.Background(), srv.URL+"/test")
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
// Retry-After: 1 means at least 1 second delay
if elapsed < 900*time.Millisecond {
t.Errorf("expected ~1s delay from Retry-After, got %v", elapsed)
}
}
func TestDoRequest_RetryAfterDoesNotMutateBackoff(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow retry test in short mode")
}
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.Header().Set("Retry-After", "1")
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify the original retryBackoff slice was not mutated
if c.retryBackoff[0] != 1*time.Millisecond {
t.Errorf("retryBackoff[0] was mutated: got %v, want 1ms", c.retryBackoff[0])
}
if c.retryBackoff[1] != 1*time.Millisecond {
t.Errorf("retryBackoff[1] was mutated: got %v, want 1ms", c.retryBackoff[1])
}
}
func TestDoRequest_429RetryAfterHTTPDate(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow Retry-After HTTP-date test in short mode")
}
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
// Use HTTP-date format (RFC 7231) — a time 2 seconds in the future.
future := time.Now().Add(2 * time.Second).UTC()
w.Header().Set("Retry-After", future.Format(http.TimeFormat))
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond, 1 * time.Millisecond})
start := time.Now()
body, err := c.doGet(context.Background(), srv.URL+"/test")
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
// HTTP-date was ~2s in the future; by the time client processes it,
// time.Until gives ~1-2s. Verify it's meaningfully delayed (not instant).
if elapsed < 500*time.Millisecond {
t.Errorf("expected meaningful delay from HTTP-date Retry-After, got %v", elapsed)
}
}
func TestDoRequest_429RetryAfterHTTPDateInPast(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
// Use a time in the past — should result in zero/immediate retry.
past := time.Now().Add(-10 * time.Second).UTC()
w.Header().Set("Retry-After", past.Format(http.TimeFormat))
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{5 * time.Second, 5 * time.Second})
start := time.Now()
_, err := c.doGet(context.Background(), srv.URL+"/test")
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
// Past date should override the 5s backoff to ~0
if elapsed > 500*time.Millisecond {
t.Errorf("expected near-instant retry for past HTTP-date, got %v", elapsed)
}
}
func TestDoRequest_SetsUserAgentHeader(t *testing.T) {
var gotUA string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotUA = r.Header.Get("User-Agent")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotUA != "review-bot/1.0" {
t.Errorf("expected User-Agent 'review-bot/1.0', got %q", gotUA)
}
}
func TestDoRequest_LimitsResponseBody(t *testing.T) {
// Verify that oversized responses return an error rather than silently truncating.
bigBody := strings.Repeat("x", maxResponseBytes+1024)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(bigBody))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error for oversized response body")
}
if !strings.Contains(err.Error(), "exceeded") {
t.Errorf("expected truncation error, got: %v", err)
}
}
func TestDoRequest_AcceptsExactlyAtLimit(t *testing.T) {
// A response body exactly equal to maxResponseBytes should succeed (not error).
exactBody := strings.Repeat("x", maxResponseBytes)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(exactBody))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error for exactly-at-limit body: %v", err)
}
if len(body) != maxResponseBytes {
t.Errorf("expected body length %d, got %d", maxResponseBytes, len(body))
}
}
func TestDoRequest_SkipsAuthWhenTokenEmpty(t *testing.T) {
var gotAuth string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuth = r.Header.Get("Authorization")
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
c := NewClient("", srv.URL, AllowInsecureHTTP()) // empty token
c.SetHTTPClient(srv.Client())
_, _ = c.doGet(context.Background(), srv.URL+"/test")
if gotAuth != "" {
t.Errorf("expected no Authorization header with empty token, got %q", gotAuth)
}
}
func TestNewClient_CheckRedirectStripsAuthOnCrossHost(t *testing.T) {
// Verify the CheckRedirect function is configured
c := NewClient("secret-token", "https://api.github.com")
if c.httpClient.CheckRedirect == nil {
t.Fatal("expected CheckRedirect to be set")
}
}
func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) {
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}}
req := &http.Request{
URL: &url.URL{Scheme: "http", Host: "api.github.com", Path: "/foo"},
Header: http.Header{"Authorization": []string{"Bearer token"}},
}
err := defaultCheckRedirect(req, []*http.Request{prev})
if err == nil {
t.Fatal("expected error on HTTPS→HTTP redirect")
}
if !strings.Contains(err.Error(), "refusing redirect from HTTPS to HTTP") {
t.Errorf("unexpected error message: %v", err)
}
}
func TestDefaultCheckRedirect_StripsAuthOnCrossHost(t *testing.T) {
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}}
req := &http.Request{
URL: &url.URL{Scheme: "https", Host: "objects.githubusercontent.com", Path: "/bar"},
Header: http.Header{"Authorization": []string{"Bearer token"}},
}
err := defaultCheckRedirect(req, []*http.Request{prev})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if auth := req.Header.Get("Authorization"); auth != "" {
t.Errorf("expected Authorization header to be stripped, got %q", auth)
}
}
func TestDefaultCheckRedirect_PreservesAuthOnSameHost(t *testing.T) {
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/foo"}}
req := &http.Request{
URL: &url.URL{Scheme: "https", Host: "api.github.com", Path: "/bar"},
Header: http.Header{"Authorization": []string{"Bearer token"}},
}
err := defaultCheckRedirect(req, []*http.Request{prev})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if auth := req.Header.Get("Authorization"); auth != "Bearer token" {
t.Errorf("expected Authorization to be preserved, got %q", auth)
}
}
func TestDoRequest_RejectsHTTPWithToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("{}"))
}))
defer srv.Close()
// Without AllowInsecureHTTP, should refuse to send token over HTTP
c := NewClient("secret-token", srv.URL)
c.SetHTTPClient(srv.Client())
_, err := c.doGet(context.Background(), srv.URL+"/test")
if err == nil {
t.Fatal("expected error when sending token over HTTP")
}
if !strings.Contains(err.Error(), "refusing to send credentials") {
t.Errorf("unexpected error message: %v", err)
}
}
func TestDoRequest_AllowsHTTPWithoutToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
// Without token, HTTP should be fine (no credentials to leak)
c := NewClient("", srv.URL)
c.SetHTTPClient(srv.Client())
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
}
func TestDoRequest_AllowsHTTPWithInsecureOption(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{"ok":true}`))
}))
defer srv.Close()
c := NewClient("secret-token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
body, err := c.doGet(context.Background(), srv.URL+"/test")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(body) != `{"ok":true}` {
t.Errorf("unexpected body: %s", body)
}
}
func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
c := NewClient("token", "https://api.github.com")
c.SetHTTPClient(nil)
if c.httpClient == nil {
t.Fatal("expected non-nil httpClient after SetHTTPClient(nil)")
}
if c.httpClient.Timeout != 30*time.Second {
t.Errorf("expected 30s timeout, got %v", c.httpClient.Timeout)
}
if c.httpClient.CheckRedirect == nil {
t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)")
}
}
-10
View File
@@ -1,10 +0,0 @@
package github_test
import (
"gitea.weiker.me/rodin/review-bot/github"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Compile-time interface conformance assertion.
// Verifies github.Client satisfies vcs.PRReader.
var _ vcs.PRReader = (*github.Client)(nil)
-68
View File
@@ -1,68 +0,0 @@
package github
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
)
// GetFileContentAtRef fetches a file at a specific ref from a repo.
// If ref is empty, the query parameter is omitted (uses default branch).
//
// Note: dot-segments ("." and "..") in the path are silently removed to
// prevent path traversal. This means a path like "foo/../bar" resolves
// to "foo/bar" rather than "bar".
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
if ref != "" {
reqURL += "?ref=" + url.QueryEscape(ref)
}
body, err := c.doGet(ctx, reqURL)
if err != nil {
return "", fmt.Errorf("fetch file %s: %w", path, err)
}
var resp struct {
Content string `json:"content"`
Encoding string `json:"encoding"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("parse file content JSON: %w", err)
}
if resp.Encoding != "base64" {
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path)
}
decoded, err := decodeBase64Content(resp.Content)
if err != nil {
return "", fmt.Errorf("decode base64 content for %s: %w", path, err)
}
return decoded, nil
}
// escapePath encodes each segment of a slash-separated path, stripping
// dot-segments to prevent path traversal.
func escapePath(p string) string {
parts := strings.Split(p, "/")
var clean []string
for _, part := range parts {
if part == "." || part == ".." || part == "" {
continue
}
clean = append(clean, url.PathEscape(part))
}
return strings.Join(clean, "/")
}
// decodeBase64Content decodes base64-encoded content from the GitHub contents API.
// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding.
func decodeBase64Content(encoded string) (string, error) {
cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded)
decoded, err := base64.StdEncoding.DecodeString(cleaned)
if err != nil {
return "", err
}
return string(decoded), nil
}
-229
View File
@@ -1,229 +0,0 @@
package github
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// pullRequestResponse is the GitHub API response for a pull request.
type pullRequestResponse struct {
Number int `json:"number"`
Title string `json:"title"`
Body string `json:"body"`
Head struct {
SHA string `json:"sha"`
Ref string `json:"ref"`
} `json:"head"`
Base struct {
Ref string `json:"ref"`
} `json:"base"`
}
// changedFileResponse is the GitHub API response for a changed file in a PR.
type changedFileResponse struct {
Filename string `json:"filename"`
Status string `json:"status"`
Patch string `json:"patch"`
}
// commitStatusResponse is the GitHub combined status API response.
type commitStatusResponse struct {
Statuses []struct {
Context string `json:"context"`
State string `json:"state"`
Description string `json:"description"`
TargetURL string `json:"target_url"`
} `json:"statuses"`
}
// checkRunsResponse is the GitHub check runs API response.
type checkRunsResponse struct {
CheckRuns []struct {
Name string `json:"name"`
Conclusion *string `json:"conclusion"`
Status string `json:"status"`
HTMLURL string `json:"html_url"`
} `json:"check_runs"`
}
// GetPullRequest fetches PR metadata from the GitHub API.
// Returns an *APIError wrapping the HTTP status on non-2xx responses (e.g.
// IsNotFound for 404, IsUnauthorized for 401). Network and context errors
// are wrapped but not typed as *APIError.
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doGet(ctx, reqURL)
if err != nil {
return nil, fmt.Errorf("fetch PR: %w", err)
}
var resp pullRequestResponse
if err := json.Unmarshal(body, &resp); err != nil {
return nil, fmt.Errorf("parse PR JSON: %w", err)
}
return &vcs.PullRequest{
Number: resp.Number,
Title: resp.Title,
Body: resp.Body,
Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref},
Base: vcs.BaseRef{Ref: resp.Base.Ref},
}, nil
}
// GetPullRequestDiff fetches the unified diff for a PR.
// Uses Accept: application/vnd.github.diff to get raw diff text.
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff")
if err != nil {
return "", fmt.Errorf("fetch diff: %w", err)
}
return string(body), nil
}
const (
// maxFilesPages is the upper bound on pagination loops for PR file listing,
// preventing unbounded iteration if the server always returns a full page.
maxFilesPages = 100
// maxCheckRunPages is the upper bound on pagination loops for check-run listing,
// preventing unbounded iteration if the server always returns a full page.
maxCheckRunPages = 100
)
// GetPullRequestFiles fetches the list of files changed in a PR.
// Paginates through all pages (100 per page) to collect all files.
// Returns nil (not an empty slice) when the PR has no changed files.
// Callers can safely range over or check len() on a nil slice.
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
var allFiles []vcs.ChangedFile
for page := 1; page <= maxFilesPages; page++ {
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page)
body, err := c.doGet(ctx, reqURL)
if err != nil {
return nil, fmt.Errorf("fetch PR files page %d: %w", page, err)
}
var files []changedFileResponse
if err := json.Unmarshal(body, &files); err != nil {
return nil, fmt.Errorf("parse PR files JSON: %w", err)
}
if len(files) == 0 {
break
}
for _, f := range files {
allFiles = append(allFiles, vcs.ChangedFile{
Filename: f.Filename,
Status: f.Status,
Patch: f.Patch,
})
}
if len(files) < 100 {
break
}
}
return allFiles, nil
}
// GetCommitStatuses fetches both commit statuses and check runs for a SHA,
// merging them into a unified []vcs.CommitStatus slice.
// Returns nil (not an empty slice) when there are no statuses or check runs.
// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the
// function returns immediately without attempting the check-runs endpoint.
// If the check-runs endpoint fails after statuses were fetched successfully,
// the function returns an error (not a partial result) so callers always get
// either a complete view or a clear error signal.
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) {
var result []vcs.CommitStatus
// Fetch commit statuses
statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha))
statusBody, err := c.doGet(ctx, statusURL)
if err != nil {
return nil, fmt.Errorf("fetch commit statuses: %w", err)
}
var statusResp commitStatusResponse
if err := json.Unmarshal(statusBody, &statusResp); err != nil {
return nil, fmt.Errorf("parse commit statuses JSON: %w", err)
}
for _, s := range statusResp.Statuses {
result = append(result, vcs.CommitStatus{
Context: s.Context,
Status: s.State,
Description: s.Description,
TargetURL: s.TargetURL,
})
}
// Fetch check runs (paginated)
for checkPage := 1; checkPage <= maxCheckRunPages; checkPage++ {
checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d",
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage)
checkBody, err := c.doGet(ctx, checkURL)
if err != nil {
return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err)
}
var checkResp checkRunsResponse
if err := json.Unmarshal(checkBody, &checkResp); err != nil {
return nil, fmt.Errorf("parse check runs JSON: %w", err)
}
for _, cr := range checkResp.CheckRuns {
result = append(result, vcs.CommitStatus{
Context: cr.Name,
Status: mapCheckRunStatus(cr.Conclusion),
Description: derefString(cr.Conclusion), // raw conclusion value (e.g. "success", "failure", "skipped")
TargetURL: cr.HTMLURL,
})
}
if len(checkResp.CheckRuns) < 100 {
break
}
}
return result, nil
}
// mapCheckRunStatus maps a GitHub check run conclusion to a vcs.CommitStatus status string.
// Conclusion alone determines the mapped state: nil conclusion means the run is
// still in progress (pending), regardless of the status field value.
//
// Mapping rules:
// - nil → "pending" (run still in progress or queued)
// - "success" → "success"
// - "failure", "action_required", "timed_out" → "failure"
// - "cancelled", "skipped", "neutral" → "success" (non-blocking per GitHub check suite semantics)
// - "stale", "waiting" → "pending"
// - unknown values → "pending" (conservative: treat unrecognized conclusions as incomplete)
func mapCheckRunStatus(conclusion *string) string {
if conclusion == nil {
// Still running or queued
return "pending"
}
switch *conclusion {
case "success":
return "success"
case "failure", "action_required", "timed_out":
return "failure"
case "cancelled", "skipped", "neutral":
return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics
case "stale", "waiting":
return "pending"
default:
return "pending"
}
}
// derefString safely dereferences a string pointer, returning empty string if nil.
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}
-676
View File
@@ -1,676 +0,0 @@
package github
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestGetPullRequest_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/repos/owner/repo/pulls/42" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]interface{}{
"number": 42,
"title": "Test PR",
"body": "Description",
"head": map[string]string{"sha": "abc123", "ref": "feature-branch"},
"base": map[string]string{"ref": "main"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pr.Number != 42 {
t.Errorf("expected number 42, got %d", pr.Number)
}
if pr.Title != "Test PR" {
t.Errorf("expected title 'Test PR', got %q", pr.Title)
}
if pr.Body != "Description" {
t.Errorf("expected body 'Description', got %q", pr.Body)
}
if pr.Head.SHA != "abc123" {
t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA)
}
if pr.Head.Ref != "feature-branch" {
t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref)
}
if pr.Base.Ref != "main" {
t.Errorf("expected base ref 'main', got %q", pr.Base.Ref)
}
}
func TestGetPullRequest_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
if !IsNotFound(err) {
t.Errorf("expected IsNotFound=true, got error: %v", err)
}
}
func TestGetPullRequest_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for 401")
}
if !IsUnauthorized(err) {
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
}
}
func TestGetPullRequest_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
json.NewEncoder(w).Encode(map[string]interface{}{
"number": 1,
"title": "PR",
"body": "",
"head": map[string]string{"sha": "abc", "ref": "br"},
"base": map[string]string{"ref": "main"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if pr.Number != 1 {
t.Errorf("expected number 1, got %d", pr.Number)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestGetPullRequest_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`{invalid json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for malformed JSON")
}
if !strings.Contains(err.Error(), "parse PR JSON") {
t.Errorf("expected parse error, got: %v", err)
}
}
func TestGetPullRequestDiff_HappyPath(t *testing.T) {
expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n"
var gotAccept string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAccept = r.Header.Get("Accept")
w.WriteHeader(200)
w.Write([]byte(expectedDiff))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff != expectedDiff {
t.Errorf("unexpected diff: %q", diff)
}
if gotAccept != "application/vnd.github.diff" {
t.Errorf("expected diff Accept header, got %q", gotAccept)
}
}
func TestGetPullRequestDiff_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetPullRequestDiff_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetPullRequestFiles_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode([]map[string]interface{}{
{"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"},
{"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 2 {
t.Fatalf("expected 2 files, got %d", len(files))
}
if files[0].Filename != "main.go" {
t.Errorf("expected filename 'main.go', got %q", files[0].Filename)
}
if files[0].Status != "modified" {
t.Errorf("expected status 'modified', got %q", files[0].Status)
}
if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" {
t.Errorf("unexpected patch: %q", files[0].Patch)
}
}
func TestGetPullRequestFiles_Pagination(t *testing.T) {
// Simulate > 100 files requiring pagination
page1Files := make([]map[string]string, 100)
for i := 0; i < 100; i++ {
page1Files[i] = map[string]string{
"filename": fmt.Sprintf("file%d.go", i),
"status": "modified",
"patch": fmt.Sprintf("patch%d", i),
}
}
page2Files := []map[string]string{
{"filename": "file100.go", "status": "added", "patch": "patch100"},
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
page := r.URL.Query().Get("page")
if page == "" || page == "1" {
json.NewEncoder(w).Encode(page1Files)
} else {
json.NewEncoder(w).Encode(page2Files)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 101 {
t.Errorf("expected 101 files (paginated), got %d", len(files))
}
if files[100].Filename != "file100.go" {
t.Errorf("expected last file 'file100.go', got %q", files[100].Filename)
}
if files[100].Patch != "patch100" {
t.Errorf("expected last patch 'patch100', got %q", files[100].Patch)
}
}
func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Binary files have no patch field in GitHub response
json.NewEncoder(w).Encode([]map[string]interface{}{
{"filename": "image.png", "status": "added"},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(files) != 1 {
t.Fatalf("expected 1 file, got %d", len(files))
}
if files[0].Patch != "" {
t.Errorf("expected empty patch for binary file, got %q", files[0].Patch)
}
}
func TestGetPullRequestFiles_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999)
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetPullRequestFiles_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestGetFileContentAtRef_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.URL.Query().Get("ref") != "abc123" {
t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref"))
}
json.NewEncoder(w).Encode(map[string]string{
"content": "cGFja2FnZSBtYWlu", // "package main" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "package main" {
t.Errorf("expected 'package main', got %q", content)
}
}
func TestGetFileContentAtRef_EmptyRef(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("ref") != "" {
t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref"))
}
json.NewEncoder(w).Encode(map[string]string{
"content": "aGVsbG8=", // "hello" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "hello" {
t.Errorf("expected 'hello', got %q", content)
}
}
func TestGetFileContentAtRef_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main")
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetFileContentAtRef_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetFileContentAtRef_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not valid json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestGetFileContentAtRef_429Retry(t *testing.T) {
attempts := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts++
if attempts == 1 {
w.WriteHeader(429)
w.Write([]byte(`{"message":"rate limit"}`))
return
}
json.NewEncoder(w).Encode(map[string]string{
"content": "b2s=", // "ok" in base64
"encoding": "base64",
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if content != "ok" {
t.Errorf("expected 'ok', got %q", content)
}
if attempts != 2 {
t.Errorf("expected 2 attempts, got %d", attempts)
}
}
func TestGetCommitStatuses_HappyPath(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/status"):
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []map[string]string{
{
"context": "ci/build",
"state": "success",
"description": "Build passed",
"target_url": "https://ci.example.com/1",
},
},
})
case strings.Contains(r.URL.Path, "/check-runs"):
conclusion := "success"
json.NewEncoder(w).Encode(map[string]interface{}{
"total_count": 1,
"check_runs": []map[string]interface{}{
{
"name": "lint",
"conclusion": &conclusion,
"status": "completed",
"html_url": "https://github.com/check/1",
},
},
})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
w.WriteHeader(404)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(statuses) != 2 {
t.Fatalf("expected 2 statuses, got %d", len(statuses))
}
// First should be from commit statuses
if statuses[0].Context != "ci/build" {
t.Errorf("expected context 'ci/build', got %q", statuses[0].Context)
}
if statuses[0].Status != "success" {
t.Errorf("expected status 'success', got %q", statuses[0].Status)
}
// Second should be from check runs
if statuses[1].Context != "lint" {
t.Errorf("expected context 'lint', got %q", statuses[1].Context)
}
if statuses[1].Status != "success" {
t.Errorf("expected status 'success', got %q", statuses[1].Status)
}
}
func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
tests := []struct {
conclusion *string
status string
want string
}{
{stringPtr("success"), "completed", "success"},
{stringPtr("failure"), "completed", "failure"},
{stringPtr("action_required"), "completed", "failure"},
{stringPtr("timed_out"), "completed", "failure"},
{stringPtr("cancelled"), "completed", "success"},
{stringPtr("skipped"), "completed", "success"},
{stringPtr("neutral"), "completed", "success"},
{nil, "in_progress", "pending"},
{nil, "queued", "pending"},
}
for _, tt := range tests {
name := "nil"
if tt.conclusion != nil {
name = *tt.conclusion
}
t.Run(name, func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/status") {
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []interface{}{},
})
return
}
json.NewEncoder(w).Encode(map[string]interface{}{
"total_count": 1,
"check_runs": []map[string]interface{}{
{
"name": "check",
"conclusion": tt.conclusion,
"status": tt.status,
"html_url": "https://github.com/check/1",
},
},
})
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(statuses) != 1 {
t.Fatalf("expected 1 status, got %d", len(statuses))
}
if statuses[0].Status != tt.want {
t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status)
}
})
}
}
func TestGetCommitStatuses_404(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
w.Write([]byte(`{"message":"Not Found"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha")
if err == nil {
t.Fatal("expected error for 404")
}
}
func TestGetCommitStatuses_401(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
if err == nil {
t.Fatal("expected error for 401")
}
}
func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte(`not json`))
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
if err == nil {
t.Fatal("expected error for malformed JSON")
}
}
func TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/status"):
// Statuses succeed
json.NewEncoder(w).Encode(map[string]interface{}{
"state": "success",
"statuses": []map[string]string{
{
"context": "ci/build",
"state": "success",
"description": "Build passed",
"target_url": "https://ci.example.com/1",
},
},
})
case strings.Contains(r.URL.Path, "/check-runs"):
// Check runs fail with 500
w.WriteHeader(500)
w.Write([]byte(`{"message":"Internal Server Error"}`))
default:
w.WriteHeader(404)
}
}))
defer srv.Close()
c := NewClient("token", srv.URL, AllowInsecureHTTP())
c.SetHTTPClient(srv.Client())
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
if err == nil {
t.Fatal("expected error when check-runs endpoint fails after statuses succeed")
}
if !strings.Contains(err.Error(), "fetch check runs") {
t.Errorf("expected check runs error, got: %v", err)
}
}
func stringPtr(s string) *string {
return &s
}
+27 -5
View File
@@ -2,6 +2,7 @@ package review
import (
"fmt"
"regexp"
"strings"
)
@@ -22,10 +23,29 @@ func GiteaEvent(verdict string) string {
}
}
// markdownSpecialChars matches characters that have special meaning in Markdown.
// We escape these to prevent untrusted input from breaking formatting.
// Uses a quoted string since raw strings can't contain backticks.
var markdownSpecialChars = regexp.MustCompile("([\\\\*_`\\[\\]()#<>|~])")
// sanitizeMarkdownText escapes special Markdown characters in untrusted text.
// This prevents markdown injection attacks where a malicious display name could
// break formatting, inject links, or create unexpected rendering.
func sanitizeMarkdownText(s string) string {
// First, remove any control characters and null bytes
cleaned := strings.Map(func(r rune) rune {
if r < 32 && r != '\t' && r != '\n' {
return -1 // drop the character
}
return r
}, s)
// Escape special Markdown characters by prepending backslash
return markdownSpecialChars.ReplaceAllString(cleaned, `\$1`)
}
// FormatMarkdownWithDisplay formats a ReviewResult with separate display name and sentinel name.
// Note: displayName is not HTML-escaped as Gitea sanitizes rendered Markdown.
// Persona display names are controlled by repo owners (trusted input).
// displayName is used for the header title, sentinelName is used for the cleanup sentinel.
// displayName is sanitized to prevent Markdown injection from untrusted remote persona metadata.
// sentinelName is used for the cleanup sentinel comment (machine-readable, not rendered).
// If displayName is empty, sentinelName is used for both.
func FormatMarkdownWithDisplay(result *ReviewResult, displayName, sentinelName string) string {
var sb strings.Builder
@@ -37,7 +57,8 @@ func FormatMarkdownWithDisplay(result *ReviewResult, displayName, sentinelName s
}
if headerName != "" {
title := CapitalizeFirst(headerName)
// Sanitize the header name to prevent Markdown injection
title := CapitalizeFirst(sanitizeMarkdownText(headerName))
sb.WriteString(fmt.Sprintf("# %s Review\n\n", title))
}
@@ -61,7 +82,8 @@ func FormatMarkdownWithDisplay(result *ReviewResult, displayName, sentinelName s
sb.WriteString(fmt.Sprintf("**%s** — %s\n", result.Verdict, result.Recommendation))
if sentinelName != "" {
sb.WriteString(fmt.Sprintf("\n---\n*Review by %s*\n", headerName))
// Sanitize headerName for the footer as well
sb.WriteString(fmt.Sprintf("\n---\n*Review by %s*\n", sanitizeMarkdownText(headerName)))
// Hidden sentinel for identifying this bot's reviews during cleanup
sb.WriteString(fmt.Sprintf("\n<!-- review-bot:%s -->\n", sentinelName))
}
+68
View File
@@ -214,3 +214,71 @@ func TestFormatMarkdownWithDisplay(t *testing.T) {
}
})
}
func TestSanitizeMarkdownText(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "plain text unchanged",
input: "Security Specialist",
want: "Security Specialist",
},
{
name: "escapes asterisks",
input: "**bold** attack",
want: `\*\*bold\*\* attack`,
},
{
name: "escapes brackets for links",
input: "[click me](http://evil.com)",
want: `\[click me\]\(http://evil.com\)`,
},
{
name: "escapes backticks",
input: "`code` injection",
want: "\\`code\\` injection",
},
{
name: "escapes angle brackets",
input: "<script>alert(1)</script>",
want: `\<script\>alert\(1\)\</script\>`,
},
{
name: "escapes hash for headers",
input: "# Fake Header",
want: `\# Fake Header`,
},
{
name: "escapes pipe for tables",
input: "col1 | col2",
want: `col1 \| col2`,
},
{
name: "removes control characters",
input: "hello\x00world\x1f",
want: "helloworld",
},
{
name: "preserves tabs and newlines",
input: "line1\n\tindented",
want: "line1\n\tindented",
},
{
name: "escapes tilde for strikethrough",
input: "~~strikethrough~~",
want: `\~\~strikethrough\~\~`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sanitizeMarkdownText(tt.input)
if got != tt.want {
t.Errorf("sanitizeMarkdownText(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
-7
View File
@@ -224,13 +224,6 @@ func checkYAMLDepth(node *yaml.Node, depth, maxDepth, maxNodes int, seen map[*ya
return nil
}
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
// without requiring filesystem access. Format is detected by source extension.
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
return parsePersona(data, source)
}
func validatePersona(p *Persona, source string) error {
if p.Name == "" {
return fmt.Errorf("persona %s: name is required", source)
+171
View File
@@ -0,0 +1,171 @@
package review
import (
"context"
"fmt"
"log/slog"
"sort"
"strings"
)
// PersonaFetcher abstracts fetching files from a remote repository.
// This allows persona loading to work with any Git host API.
type PersonaFetcher interface {
// ListContents returns file/directory entries at a path.
// Returns an error if the path doesn't exist or isn't accessible.
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
// GetFileContent returns the raw content of a file from the default branch.
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
}
// ContentEntry represents a file or directory entry.
type ContentEntry struct {
Name string // filename or directory name
Path string // full path from repo root
Type string // "file" or "dir"
}
// DefaultPersonasPath is the conventional location for repo-specific personas.
const DefaultPersonasPath = ".review-bot/personas"
// LoadRemotePersonas fetches personas from a remote repository's .review-bot/personas/ directory.
// Returns a map of persona name to Persona. If the directory doesn't exist or is empty,
// returns an empty map with no error (graceful fallback to built-in personas).
//
// Files larger than MaxPersonaFileSize are logged and skipped.
// Invalid YAML files are logged and skipped (partial success model).
// Only .yaml and .yml files are processed; other files are ignored.
func LoadRemotePersonas(ctx context.Context, fetcher PersonaFetcher, owner, repo string) (map[string]*Persona, error) {
return LoadRemotePersonasFromPath(ctx, fetcher, owner, repo, DefaultPersonasPath)
}
// LoadRemotePersonasFromPath loads personas from a custom path in a remote repository.
// It behaves the same as LoadRemotePersonas but allows specifying a path other than
// the default .review-bot/personas directory.
func LoadRemotePersonasFromPath(ctx context.Context, fetcher PersonaFetcher, owner, repo, path string) (map[string]*Persona, error) {
entries, err := fetcher.ListContents(ctx, owner, repo, path)
if err != nil {
// 404 is expected when repo doesn't have personas — return empty, not error
if isNotFoundError(err) {
slog.Debug("no remote personas directory found", "repo", fmt.Sprintf("%s/%s", owner, repo), "path", path)
return map[string]*Persona{}, nil
}
return nil, fmt.Errorf("list remote personas: %w", err)
}
// Cap the number of files to process to prevent resource exhaustion
// from repos with thousands of small files.
const maxPersonaFiles = 50
result := make(map[string]*Persona)
processed := 0
for _, entry := range entries {
if processed >= maxPersonaFiles {
slog.Warn("persona file limit reached", "limit", maxPersonaFiles, "repo", fmt.Sprintf("%s/%s", owner, repo))
break
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Skip directories and non-YAML files
if entry.Type != "file" {
continue
}
if !isYAMLFile(entry.Name) {
continue
}
content, err := fetcher.GetFileContent(ctx, owner, repo, entry.Path)
if err != nil {
slog.Warn("could not fetch remote persona file", "file", entry.Path, "error", err)
continue
}
// Check size before parsing (defense in depth)
if len(content) > MaxPersonaFileSize {
slog.Warn("remote persona file exceeds size limit", "file", entry.Path, "size", len(content), "limit", MaxPersonaFileSize)
continue
}
// YAML parsing uses parsePersona which has defenses against YAML DoS attacks:
// - MaxPersonaFileSize (above) caps raw input size before any parsing
// - maxPersonaFiles (above) limits the number of files processed per repo
// - unmarshalYAMLWithDepthLimit enforces MaxYAMLDepth to prevent stack exhaustion
// - checkYAMLDepth tracks node counts (MaxYAMLNodes) against "billion laughs" expansion
// - Alias cycles are detected and capped by seen-node tracking
// See persona.go for the implementation details.
persona, err := parsePersona([]byte(content), entry.Path)
if err != nil {
slog.Warn("could not parse remote persona file", "file", entry.Path, "error", err)
continue
}
result[persona.Name] = persona
processed++
slog.Debug("loaded remote persona", "name", persona.Name, "file", entry.Path)
}
return result, nil
}
// MergePersonas combines remote and built-in personas.
// Remote personas take precedence on name collision.
// Returns the merged map and a list of persona names in sorted order.
func MergePersonas(remote, builtin map[string]*Persona) (map[string]*Persona, []string) {
merged := make(map[string]*Persona)
// Add built-in first
for name, p := range builtin {
merged[name] = p
}
// Remote overrides built-in on collision
for name, p := range remote {
if _, exists := merged[name]; exists {
slog.Debug("remote persona overrides built-in", "name", name)
}
merged[name] = p
}
// Collect sorted names
names := make([]string, 0, len(merged))
for name := range merged {
names = append(names, name)
}
sort.Strings(names)
return merged, names
}
// LoadAllBuiltinPersonas loads all built-in personas into a map.
func LoadAllBuiltinPersonas() map[string]*Persona {
result := make(map[string]*Persona)
for _, name := range ListBuiltinPersonas() {
p, err := LoadBuiltinPersona(name)
if err != nil {
slog.Warn("could not load built-in persona", "name", name, "error", err)
continue
}
result[name] = p
}
return result
}
// isYAMLFile returns true if the filename has a YAML extension.
func isYAMLFile(name string) bool {
lower := strings.ToLower(name)
return strings.HasSuffix(lower, ".yaml") || strings.HasSuffix(lower, ".yml")
}
// isNotFoundError checks if an error indicates a 404 response.
// This is a simple string check to avoid importing the gitea package
// (which would create a circular dependency).
func isNotFoundError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
return strings.Contains(errStr, "HTTP 404")
}
+394
View File
@@ -0,0 +1,394 @@
package review
import (
"context"
"errors"
"testing"
)
// mockFetcher implements PersonaFetcher for testing.
type mockFetcher struct {
contents map[string][]ContentEntry // path -> entries
files map[string]string // path -> content
listErr error // error to return from ListContents
getFileErr map[string]error // path -> error for GetFileContent
listNotFound bool // return 404-style error
}
func newMockFetcher() *mockFetcher {
return &mockFetcher{
contents: make(map[string][]ContentEntry),
files: make(map[string]string),
getFileErr: make(map[string]error),
}
}
func (m *mockFetcher) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
if m.listNotFound {
return nil, errors.New("HTTP 404: not found")
}
if m.listErr != nil {
return nil, m.listErr
}
entries, ok := m.contents[path]
if !ok {
return nil, errors.New("HTTP 404: not found")
}
return entries, nil
}
func (m *mockFetcher) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
if err, ok := m.getFileErr[filepath]; ok {
return "", err
}
content, ok := m.files[filepath]
if !ok {
return "", errors.New("HTTP 404: file not found")
}
return content, nil
}
func TestLoadRemotePersonas_NoDirectory(t *testing.T) {
fetcher := newMockFetcher()
fetcher.listNotFound = true
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("expected no error for missing directory, got: %v", err)
}
if len(result) != 0 {
t.Errorf("expected empty map, got %d personas", len(result))
}
}
func TestLoadRemotePersonas_EmptyDirectory(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{}
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 0 {
t.Errorf("expected empty map, got %d personas", len(result))
}
}
func TestLoadRemotePersonas_SinglePersona(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "trading.yaml", Path: ".review-bot/personas/trading.yaml", Type: "file"},
}
fetcher.files[".review-bot/personas/trading.yaml"] = `
name: trading
display_name: Trading Expert
identity: You are a trading systems expert.
focus:
- order execution
- market data
`
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 1 {
t.Fatalf("expected 1 persona, got %d", len(result))
}
if result["trading"] == nil {
t.Fatal("expected 'trading' persona")
}
if result["trading"].DisplayName != "Trading Expert" {
t.Errorf("expected display name 'Trading Expert', got %q", result["trading"].DisplayName)
}
}
func TestLoadRemotePersonas_MultiplePersonas(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "one.yaml", Path: ".review-bot/personas/one.yaml", Type: "file"},
{Name: "two.yml", Path: ".review-bot/personas/two.yml", Type: "file"},
}
fetcher.files[".review-bot/personas/one.yaml"] = `
name: one
identity: First persona.
`
fetcher.files[".review-bot/personas/two.yml"] = `
name: two
identity: Second persona.
`
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 2 {
t.Fatalf("expected 2 personas, got %d", len(result))
}
if result["one"] == nil || result["two"] == nil {
t.Error("expected both personas to be loaded")
}
}
func TestLoadRemotePersonas_SkipsNonYAML(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
{Name: "readme.md", Path: ".review-bot/personas/readme.md", Type: "file"},
{Name: "config.json", Path: ".review-bot/personas/config.json", Type: "file"},
}
fetcher.files[".review-bot/personas/valid.yaml"] = `
name: valid
identity: Valid persona.
`
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 1 {
t.Fatalf("expected 1 persona (skipping non-YAML), got %d", len(result))
}
}
func TestLoadRemotePersonas_SkipsDirectories(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
{Name: "subdir", Path: ".review-bot/personas/subdir", Type: "dir"},
}
fetcher.files[".review-bot/personas/valid.yaml"] = `
name: valid
identity: Valid persona.
`
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 1 {
t.Fatalf("expected 1 persona (skipping dir), got %d", len(result))
}
}
func TestLoadRemotePersonas_SkipsInvalidYAML(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
{Name: "invalid.yaml", Path: ".review-bot/personas/invalid.yaml", Type: "file"},
}
fetcher.files[".review-bot/personas/valid.yaml"] = `
name: valid
identity: Valid persona.
`
fetcher.files[".review-bot/personas/invalid.yaml"] = `
this is not valid yaml: [unclosed bracket
`
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 1 {
t.Fatalf("expected 1 persona (skipping invalid), got %d", len(result))
}
if result["valid"] == nil {
t.Error("expected valid persona to be loaded")
}
}
func TestLoadRemotePersonas_SkipsOversizedFiles(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "huge.yaml", Path: ".review-bot/personas/huge.yaml", Type: "file"},
}
// Create content larger than MaxPersonaFileSize (64KB)
fetcher.files[".review-bot/personas/huge.yaml"] = `
name: huge
identity: ` + string(make([]byte, MaxPersonaFileSize+1000))
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 0 {
t.Errorf("expected 0 personas (oversized file skipped), got %d", len(result))
}
}
func TestLoadRemotePersonas_SkipsFetchErrors(t *testing.T) {
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
{Name: "error.yaml", Path: ".review-bot/personas/error.yaml", Type: "file"},
}
fetcher.files[".review-bot/personas/valid.yaml"] = `
name: valid
identity: Valid persona.
`
fetcher.getFileErr[".review-bot/personas/error.yaml"] = errors.New("network error")
result, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 1 {
t.Fatalf("expected 1 persona (skipping error), got %d", len(result))
}
}
func TestLoadRemotePersonas_ListContentsError(t *testing.T) {
fetcher := newMockFetcher()
fetcher.listErr = errors.New("server error")
_, err := LoadRemotePersonas(context.Background(), fetcher, "owner", "repo")
if err == nil {
t.Fatal("expected error for list contents failure")
}
}
func TestLoadRemotePersonas_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
fetcher := newMockFetcher()
fetcher.contents[DefaultPersonasPath] = []ContentEntry{
{Name: "one.yaml", Path: ".review-bot/personas/one.yaml", Type: "file"},
}
fetcher.files[".review-bot/personas/one.yaml"] = `
name: one
identity: One.
`
_, err := LoadRemotePersonas(ctx, fetcher, "owner", "repo")
if err == nil {
t.Fatal("expected context cancellation error")
}
}
func TestMergePersonas_NoOverlap(t *testing.T) {
remote := map[string]*Persona{
"trading": {Name: "trading", Identity: "Trading expert."},
}
builtin := map[string]*Persona{
"security": {Name: "security", Identity: "Security expert."},
}
merged, names := MergePersonas(remote, builtin)
if len(merged) != 2 {
t.Fatalf("expected 2 personas, got %d", len(merged))
}
if len(names) != 2 {
t.Fatalf("expected 2 names, got %d", len(names))
}
// Names should be sorted
if names[0] != "security" || names[1] != "trading" {
t.Errorf("expected sorted names [security, trading], got %v", names)
}
}
func TestMergePersonas_RemoteOverridesBuiltin(t *testing.T) {
remote := map[string]*Persona{
"security": {Name: "security", Identity: "Custom security expert."},
}
builtin := map[string]*Persona{
"security": {Name: "security", Identity: "Default security expert."},
}
merged, _ := MergePersonas(remote, builtin)
if merged["security"].Identity != "Custom security expert." {
t.Errorf("expected remote to override builtin, got identity: %q", merged["security"].Identity)
}
}
func TestMergePersonas_EmptyRemote(t *testing.T) {
remote := map[string]*Persona{}
builtin := map[string]*Persona{
"security": {Name: "security", Identity: "Security."},
}
merged, names := MergePersonas(remote, builtin)
if len(merged) != 1 {
t.Fatalf("expected 1 persona, got %d", len(merged))
}
if names[0] != "security" {
t.Errorf("expected 'security', got %q", names[0])
}
}
func TestMergePersonas_EmptyBuiltin(t *testing.T) {
remote := map[string]*Persona{
"trading": {Name: "trading", Identity: "Trading."},
}
builtin := map[string]*Persona{}
merged, names := MergePersonas(remote, builtin)
if len(merged) != 1 {
t.Fatalf("expected 1 persona, got %d", len(merged))
}
if names[0] != "trading" {
t.Errorf("expected 'trading', got %q", names[0])
}
}
func TestLoadAllBuiltinPersonas(t *testing.T) {
personas := LoadAllBuiltinPersonas()
// Should load at least the known built-in personas
expected := []string{"architect", "docs", "security"}
for _, name := range expected {
if personas[name] == nil {
t.Errorf("expected built-in persona %q to be loaded", name)
}
}
}
func TestIsYAMLFile(t *testing.T) {
tests := []struct {
name string
expected bool
}{
{"test.yaml", true},
{"test.yml", true},
{"test.YAML", true},
{"test.YML", true},
{"test.json", false},
{"test.md", false},
{"yaml", false},
{"", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := isYAMLFile(tc.name); got != tc.expected {
t.Errorf("isYAMLFile(%q) = %v, want %v", tc.name, got, tc.expected)
}
})
}
}
func TestIsNotFoundError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{"nil error", nil, false},
{"HTTP 404", errors.New("HTTP 404: not found"), true},
{"not found text", errors.New("path not found"), false},
{"server error", errors.New("server error"), false},
{"HTTP 500", errors.New("HTTP 500: internal error"), false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := isNotFoundError(tc.err); got != tc.expected {
t.Errorf("isNotFoundError(%v) = %v, want %v", tc.err, got, tc.expected)
}
})
}
}
-137
View File
@@ -1,137 +0,0 @@
package review
import (
"context"
"log/slog"
"strings"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// RepoPersonaPath is the directory path where repo-specific personas are stored.
const RepoPersonaPath = ".review-bot/personas"
// LoadRepoPersonas fetches personas from a repository's .review-bot/personas/ directory.
// Returns an empty map (not nil) if the directory doesn't exist or is empty.
// Individual parse failures are logged and skipped; the remaining personas are still returned.
// Auth errors and other non-404 errors are propagated.
// Files exceeding MaxPersonaFileSize are rejected to prevent resource exhaustion.
func LoadRepoPersonas(ctx context.Context, client vcs.FileReader, owner, repo string) (map[string]*Persona, error) {
result := make(map[string]*Persona)
entries, err := client.ListContents(ctx, owner, repo, RepoPersonaPath)
if err != nil {
// Check if this is a 404 (directory doesn't exist) - expected case
if isNotFoundError(err) {
slog.Debug("no repo personas directory found", "repo", owner+"/"+repo)
return result, nil
}
// Other errors (auth, server) should propagate
return nil, err
}
if len(entries) == 0 {
slog.Debug("repo personas directory is empty", "repo", owner+"/"+repo)
return result, nil
}
for _, entry := range entries {
if entry.Type != "file" {
continue
}
// Only process YAML files
if !isYAMLFile(entry.Name) {
continue
}
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "")
if err != nil {
slog.Warn("could not fetch repo persona file",
"file", entry.Path,
"repo", owner+"/"+repo,
"error", err)
continue
}
// Enforce size limit before parsing to prevent resource exhaustion
if len(content) > MaxPersonaFileSize {
slog.Warn("repo persona file exceeds maximum size",
"file", entry.Path,
"repo", owner+"/"+repo,
"size", len(content),
"max", MaxPersonaFileSize)
continue
}
persona, err := ParsePersonaBytes([]byte(content), entry.Path)
if err != nil {
slog.Warn("could not parse repo persona file",
"file", entry.Path,
"repo", owner+"/"+repo,
"error", err)
continue
}
result[persona.Name] = persona
slog.Debug("loaded repo persona",
"name", persona.Name,
"file", entry.Path,
"repo", owner+"/"+repo)
}
return result, nil
}
// MergePersonas combines built-in personas with repo personas.
// Repo personas take precedence on name collision.
// Returns a new map; inputs are not modified.
func MergePersonas(builtin, repo map[string]*Persona) map[string]*Persona {
result := make(map[string]*Persona, len(builtin)+len(repo))
// Copy built-in personas first
for name, p := range builtin {
result[name] = p
}
// Overlay repo personas (override on collision)
for name, p := range repo {
if _, exists := result[name]; exists {
slog.Debug("repo persona overrides built-in", "name", name)
}
result[name] = p
}
return result
}
// GetBuiltinPersonasMap returns all built-in personas as a map keyed by name.
// Returns an empty map (not nil) if loading fails.
func GetBuiltinPersonasMap() map[string]*Persona {
result := make(map[string]*Persona)
for _, name := range ListBuiltinPersonas() {
p, err := LoadBuiltinPersona(name)
if err != nil {
slog.Warn("could not load built-in persona", "name", name, "error", err)
continue
}
result[name] = p
}
return result
}
// isYAMLFile checks if a filename has a YAML extension.
func isYAMLFile(name string) bool {
lower := strings.ToLower(name)
return strings.HasSuffix(lower, ".yaml") || strings.HasSuffix(lower, ".yml")
}
// isNotFoundError checks if an error represents a 404 response.
// This uses a specific "HTTP 404" substring match rather than a generic "not found"
// match to avoid masking authentication failures or transport errors that might
// contain "not found" in their message.
func isNotFoundError(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "HTTP 404")
}
-412
View File
@@ -1,412 +0,0 @@
package review
import (
"context"
"errors"
"strings"
"testing"
"gitea.weiker.me/rodin/review-bot/vcs"
)
func TestParsePersonaBytes(t *testing.T) {
tests := []struct {
name string
data string
source string
wantName string
wantErr string
}{
{
name: "valid yaml",
data: "name: test\nidentity: test identity\nfocus:\n - testing\n",
source: "test.yaml",
wantName: "test",
},
{
name: "missing name",
data: "identity: test\n",
source: "test.yaml",
wantErr: "name is required",
},
{
name: "invalid yaml",
data: "not: valid:\n yaml: [broken",
source: "test.yaml",
wantErr: "parse",
},
{
name: "json format by extension",
data: `{"name": "jsontest", "identity": "json identity"}`,
source: "test.json",
wantName: "jsontest",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := ParsePersonaBytes([]byte(tt.data), tt.source)
if tt.wantErr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error = %q, want containing %q", err.Error(), tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if p.Name != tt.wantName {
t.Errorf("Name = %q, want %q", p.Name, tt.wantName)
}
})
}
}
// mockGiteaClient implements vcs.FileReader for testing.
type mockGiteaClient struct {
contents map[string][]vcs.ContentEntry // path -> entries
files map[string]string // path -> content
listErr error
fileErr map[string]error // path -> error
}
func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
if m.listErr != nil {
return nil, m.listErr
}
entries, ok := m.contents[path]
if !ok {
return nil, errors.New("list contents .review-bot/personas: HTTP 404: not found")
}
return entries, nil
}
func (m *mockGiteaClient) GetFileContent(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
if m.fileErr != nil {
if err, ok := m.fileErr[filepath]; ok {
return "", err
}
}
content, ok := m.files[filepath]
if !ok {
return "", errors.New("HTTP 404: file not found")
}
return content, nil
}
func TestLoadRepoPersonas(t *testing.T) {
ctx := context.Background()
t.Run("directory not found returns empty map", func(t *testing.T) {
client := &mockGiteaClient{} // No contents configured -> 404
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if personas == nil {
t.Error("expected empty map, got nil")
}
if len(personas) != 0 {
t.Errorf("expected 0 personas, got %d", len(personas))
}
})
t.Run("empty directory returns empty map", func(t *testing.T) {
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {},
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 0 {
t.Errorf("expected 0 personas, got %d", len(personas))
}
})
t.Run("loads valid personas", func(t *testing.T) {
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {
{Name: "trading.yaml", Path: ".review-bot/personas/trading.yaml", Type: "file"},
{Name: "crypto.yaml", Path: ".review-bot/personas/crypto.yaml", Type: "file"},
},
},
files: map[string]string{
".review-bot/personas/trading.yaml": "name: trading\ndisplay_name: Trading Expert\nidentity: You are a trading expert.\nfocus:\n - order handling\n - risk management\n",
".review-bot/personas/crypto.yaml": "name: crypto\ndisplay_name: Crypto Expert\nidentity: You are a cryptography expert.\nfocus:\n - key management\n - encryption\n",
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 2 {
t.Fatalf("expected 2 personas, got %d", len(personas))
}
if personas["trading"] == nil {
t.Error("expected trading persona")
}
if personas["crypto"] == nil {
t.Error("expected crypto persona")
}
if personas["trading"].DisplayName != "Trading Expert" {
t.Errorf("trading display name = %q, want %q", personas["trading"].DisplayName, "Trading Expert")
}
})
t.Run("skips invalid persona files", func(t *testing.T) {
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
{Name: "invalid.yaml", Path: ".review-bot/personas/invalid.yaml", Type: "file"},
},
},
files: map[string]string{
".review-bot/personas/valid.yaml": "name: valid\nidentity: Valid persona\n",
".review-bot/personas/invalid.yaml": "not valid yaml: [broken",
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 1 {
t.Fatalf("expected 1 persona (skipped invalid), got %d", len(personas))
}
if personas["valid"] == nil {
t.Error("expected valid persona")
}
})
t.Run("skips non-yaml files", func(t *testing.T) {
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
{Name: "README.md", Path: ".review-bot/personas/README.md", Type: "file"},
{Name: "notes.txt", Path: ".review-bot/personas/notes.txt", Type: "file"},
},
},
files: map[string]string{
".review-bot/personas/persona.yaml": "name: test\nidentity: Test persona\n",
".review-bot/personas/README.md": "# Personas\n\nPut your personas here.",
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 1 {
t.Fatalf("expected 1 persona (yaml only), got %d", len(personas))
}
})
t.Run("skips subdirectories", func(t *testing.T) {
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
{Name: "subdir", Path: ".review-bot/personas/subdir", Type: "dir"},
},
},
files: map[string]string{
".review-bot/personas/persona.yaml": "name: test\nidentity: Test persona\n",
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 1 {
t.Fatalf("expected 1 persona (files only), got %d", len(personas))
}
})
t.Run("propagates auth errors", func(t *testing.T) {
client := &mockGiteaClient{
listErr: errors.New("HTTP 401: unauthorized"),
}
_, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err == nil {
t.Fatal("expected error for auth failure")
}
if !strings.Contains(err.Error(), "401") {
t.Errorf("error = %q, want containing '401'", err.Error())
}
})
t.Run("skips files that fail to fetch", func(t *testing.T) {
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {
{Name: "good.yaml", Path: ".review-bot/personas/good.yaml", Type: "file"},
{Name: "bad.yaml", Path: ".review-bot/personas/bad.yaml", Type: "file"},
},
},
files: map[string]string{
".review-bot/personas/good.yaml": "name: good\nidentity: Good persona\n",
},
fileErr: map[string]error{
".review-bot/personas/bad.yaml": errors.New("HTTP 500: internal server error"),
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 1 {
t.Fatalf("expected 1 persona (skipped failed fetch), got %d", len(personas))
}
})
t.Run("skips oversized files", func(t *testing.T) {
oversizedContent := strings.Repeat("a", MaxPersonaFileSize+1)
client := &mockGiteaClient{
contents: map[string][]vcs.ContentEntry{
RepoPersonaPath: {
{Name: "normal.yaml", Path: ".review-bot/personas/normal.yaml", Type: "file"},
{Name: "huge.yaml", Path: ".review-bot/personas/huge.yaml", Type: "file"},
},
},
files: map[string]string{
".review-bot/personas/normal.yaml": "name: normal\nidentity: Normal sized persona\n",
".review-bot/personas/huge.yaml": oversizedContent,
},
}
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(personas) != 1 {
t.Fatalf("expected 1 persona (skipped oversized), got %d", len(personas))
}
if personas["normal"] == nil {
t.Error("expected normal persona")
}
})
}
func TestMergePersonas(t *testing.T) {
builtin := map[string]*Persona{
"security": {Name: "security", Identity: "Built-in security"},
"docs": {Name: "docs", Identity: "Built-in docs"},
}
repo := map[string]*Persona{
"security": {Name: "security", Identity: "Repo security override"},
"trading": {Name: "trading", Identity: "Repo trading"},
}
merged := MergePersonas(builtin, repo)
t.Run("repo overrides builtin on collision", func(t *testing.T) {
if merged["security"].Identity != "Repo security override" {
t.Errorf("security identity = %q, want repo override", merged["security"].Identity)
}
})
t.Run("builtin preserved when no collision", func(t *testing.T) {
if merged["docs"].Identity != "Built-in docs" {
t.Errorf("docs identity = %q, want built-in", merged["docs"].Identity)
}
})
t.Run("repo-only persona added", func(t *testing.T) {
if merged["trading"] == nil {
t.Error("expected trading persona from repo")
}
if merged["trading"].Identity != "Repo trading" {
t.Errorf("trading identity = %q, want repo", merged["trading"].Identity)
}
})
t.Run("original maps not modified", func(t *testing.T) {
if builtin["trading"] != nil {
t.Error("builtin map was modified")
}
if len(repo) != 2 {
t.Error("repo map was modified")
}
})
}
func TestGetBuiltinPersonasMap(t *testing.T) {
personas := GetBuiltinPersonasMap()
if len(personas) == 0 {
t.Fatal("expected at least one built-in persona")
}
expected := []string{"security", "architect", "docs"}
for _, name := range expected {
if personas[name] == nil {
t.Errorf("expected built-in persona %q", name)
}
}
for name, p := range personas {
if p.Name != name {
t.Errorf("persona %q has mismatched name %q", name, p.Name)
}
if p.Identity == "" {
t.Errorf("persona %q has empty identity", name)
}
}
}
func TestIsYAMLFile(t *testing.T) {
tests := []struct {
name string
want bool
}{
{"test.yaml", true},
{"test.yml", true},
{"test.YAML", true},
{"test.YML", true},
{"test.json", false},
{"test.md", false},
{"test.txt", false},
{"yaml", false},
{"yaml.md", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := isYAMLFile(tt.name); got != tt.want {
t.Errorf("isYAMLFile(%q) = %v, want %v", tt.name, got, tt.want)
}
})
}
}
func TestIsNotFoundError(t *testing.T) {
tests := []struct {
err error
want bool
}{
{nil, false},
{errors.New("HTTP 404: not found"), true},
{errors.New("HTTP 404"), true},
{errors.New("something not found"), false},
{errors.New("HTTP 401: unauthorized"), false},
{errors.New("connection refused"), false},
}
for _, tt := range tests {
name := "nil"
if tt.err != nil {
name = tt.err.Error()
}
t.Run(name, func(t *testing.T) {
if got := isNotFoundError(tt.err); got != tt.want {
t.Errorf("isNotFoundError(%v) = %v, want %v", tt.err, got, tt.want)
}
})
}
}
-11
View File
@@ -1,11 +0,0 @@
package vcs_test
import (
"gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// Compile-time assertion: the gitea.Adapter satisfies vcs.Client.
// (The raw gitea.Client does NOT satisfy vcs.Client due to signature differences;
// the Adapter bridges them.)
var _ vcs.Client = (*gitea.Adapter)(nil)
-43
View File
@@ -1,43 +0,0 @@
// Package vcs defines the shared VCS client interface and supporting types.
// Platform adapters (gitea, github) implement these interfaces so the core
// review logic can work with any VCS platform without platform-specific code.
package vcs
import "context"
// PRReader can fetch pull request metadata, diffs, and changed files.
type PRReader interface {
GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error)
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error)
GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error)
GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error)
}
// FileReader can fetch file contents and list directory entries.
type FileReader interface {
GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error)
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
}
// Reviewer can post, list, and delete pull request reviews.
type Reviewer interface {
PostReview(ctx context.Context, owner, repo string, number int, req ReviewRequest) (*Review, error)
ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error)
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error
}
// Identity can report who the authenticated user is.
type Identity interface {
GetAuthenticatedUser(ctx context.Context) (string, error)
}
// Client is the full VCS interface: PR reads, file reads, review management, and identity.
// Platform adapters (gitea, github) implement this interface.
type Client interface {
PRReader
FileReader
Reviewer
Identity
}
-98
View File
@@ -1,98 +0,0 @@
package vcs
// ReviewEvent is the event type for a pull request review action.
// Adapters must translate these action constants to/from platform-native values.
// For example, Gitea uses "APPROVED" as both action and state, while GitHub
// uses "APPROVE" for the action and returns "approved" as the state.
type ReviewEvent string
const (
// ReviewEventApprove approves the pull request.
ReviewEventApprove ReviewEvent = "APPROVE"
// ReviewEventRequestChanges requests changes to the pull request.
ReviewEventRequestChanges ReviewEvent = "REQUEST_CHANGES"
// ReviewEventComment posts a review comment without approval or rejection.
ReviewEventComment ReviewEvent = "COMMENT"
)
// BaseRef identifies the target branch of a pull request.
type BaseRef struct {
Ref string `json:"ref"`
}
// HeadRef identifies the source branch and latest commit of a pull request.
type HeadRef struct {
SHA string `json:"sha"`
Ref string `json:"ref"`
}
// UserInfo identifies a user by login name.
type UserInfo struct {
Login string `json:"login"`
}
// PullRequest holds relevant PR metadata.
type PullRequest struct {
Number int `json:"number"`
Title string `json:"title"`
Body string `json:"body"`
Head HeadRef `json:"head"`
Base BaseRef `json:"base"`
}
// ChangedFile represents a file modified in a PR.
type ChangedFile struct {
Filename string `json:"filename"`
Status string `json:"status"`
Patch string `json:"patch"`
}
// ContentEntry represents a file or directory entry from the contents API.
type ContentEntry struct {
Name string `json:"name"`
Path string `json:"path"`
Type string `json:"type"` // "file" or "dir"
}
// CommitStatus represents a single CI status entry for a commit.
type CommitStatus struct {
Status string `json:"status"`
Context string `json:"context"`
Description string `json:"description"`
TargetURL string `json:"target_url"`
}
// Review represents a pull request review.
type Review struct {
ID int64 `json:"id"`
Body string `json:"body"`
User UserInfo `json:"user"`
State string `json:"state"`
Stale bool `json:"stale"`
CommitID string `json:"commit_id"`
}
// ReviewComment represents an inline comment in a review.
// All adapters use GitHub diff-position convention:
// - Position is a 1-indexed offset from the @@ hunk line in the unified diff.
// - CommitID identifies the commit the comment is anchored to.
// It is optional; omit (empty string) for review-level comments that are
// not attached to a specific commit.
//
// Adapters are responsible for translating to/from platform-native formats
// (e.g. Gitea uses line numbers; GitHub uses diff positions natively).
type ReviewComment struct {
Path string `json:"path"`
Position int `json:"position"` // diff-position: 1-indexed offset from @@ hunk line
CommitID string `json:"commit_id"`
Body string `json:"body"`
}
// ReviewRequest is the payload for posting a review.
type ReviewRequest struct {
// Body is the top-level review comment.
Body string `json:"body"`
// Event is the review action (approve, request changes, or comment).
Event ReviewEvent `json:"event"`
Comments []ReviewComment `json:"comments,omitempty"`
}
-193
View File
@@ -1,193 +0,0 @@
package vcs
import (
"context"
"fmt"
"strconv"
"strings"
)
const (
// maxFilesInPath is the maximum number of files GetAllFilesInPath will fetch.
// Prevents unbounded resource consumption on very large directory trees.
maxFilesInPath = 10000
// maxTotalBytesInPath is the maximum total bytes GetAllFilesInPath will accumulate.
// Prevents memory exhaustion when fetching large repositories.
maxTotalBytesInPath = 100 * 1024 * 1024 // 100 MB
)
// GetAllFilesInPath recursively fetches all file contents under a path using the
// provided FileReader. Returns a map of filepath -> content for all files found.
// If the path points to an empty directory, returns an empty map.
//
// This function uses fail-fast error handling: any error from ListContents or
// GetFileContent aborts the entire traversal and returns the error immediately.
// This differs from gitea.Client.GetAllFilesInPath, which logs errors and continues.
// The fail-fast contract ensures callers can trust that a nil error means all files
// were successfully fetched.
//
// Resource limits: the traversal is bounded by maxFilesInPath (file count) and
// maxTotalBytesInPath (total accumulated bytes). The context is checked before each
// recursive call and file fetch to respect cancellation.
func GetAllFilesInPath(ctx context.Context, client FileReader, owner, repo, path string) (map[string]string, error) {
results := make(map[string]string)
totalBytes := 0
var walk func(string) error
walk = func(dir string) error {
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled during traversal: %w", err)
}
entries, err := client.ListContents(ctx, owner, repo, dir)
if err != nil {
return fmt.Errorf("list contents %q: %w", dir, err)
}
for _, entry := range entries {
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled during traversal: %w", err)
}
switch entry.Type {
case "file":
if len(results) >= maxFilesInPath {
return fmt.Errorf("exceeded max file count (%d) in path %q", maxFilesInPath, path)
}
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "")
if err != nil {
return fmt.Errorf("get file %q: %w", entry.Path, err)
}
totalBytes += len(content)
if totalBytes > maxTotalBytesInPath {
return fmt.Errorf("exceeded max total bytes (%d) in path %q", maxTotalBytesInPath, path)
}
results[entry.Path] = content
case "dir":
if err := walk(entry.Path); err != nil {
return err
}
}
}
return nil
}
if err := walk(path); err != nil {
return nil, err
}
return results, nil
}
// BuildLineToPositionMap parses a unified diff and returns a map of
// filename -> (new line number -> diff position). The diff position is a
// 1-indexed offset from the @@ hunk header line for each file.
// Only lines that appear in the new file (context lines and additions) are mapped.
// Deletion-only lines are not included.
func BuildLineToPositionMap(diff string) map[string]map[int]int {
result := make(map[string]map[int]int)
lines := strings.Split(diff, "\n")
var currentFile string
var position int
var newLine int
for _, line := range lines {
// Detect new file in diff
if strings.HasPrefix(line, "+++ b/") {
currentFile = strings.TrimPrefix(line, "+++ b/")
position = 0
newLine = 0
if result[currentFile] == nil {
result[currentFile] = make(map[int]int)
}
continue
}
// Skip --- lines (old file header)
if strings.HasPrefix(line, "--- ") {
continue
}
// Skip diff --git lines
if strings.HasPrefix(line, "diff --git") {
continue
}
// Skip index lines
if strings.HasPrefix(line, "index ") {
continue
}
// Parse hunk headers
if strings.HasPrefix(line, "@@") {
position++
// Extract new file start line from @@ -a,b +c,d @@
newLine = parseHunkNewStart(line)
continue
}
// We need a current file to map lines
if currentFile == "" {
continue
}
// Skip "\ No newline at end of file" markers — these are git diff
// metadata and not part of the file content.
if strings.HasPrefix(line, `\`) {
continue
}
// Process diff content lines
if strings.HasPrefix(line, "+") {
position++
result[currentFile][newLine] = position
newLine++
} else if strings.HasPrefix(line, "-") {
position++
// Deletion lines don't map to new line numbers
} else if strings.HasPrefix(line, " ") {
// Context line (space-prefixed).
// Only map if position > 0, which means we've seen a hunk header.
// Lines before the first hunk header (position == 0) are not part
// of any diff hunk and should be skipped.
if position > 0 {
position++
result[currentFile][newLine] = position
newLine++
}
}
}
return result
}
// parseHunkNewStart extracts the new-file starting line number from a hunk header.
// Format: @@ -old_start[,old_count] +new_start[,new_count] @@
func parseHunkNewStart(hunkLine string) int {
// Find the +N part
plusIdx := strings.Index(hunkLine, "+")
if plusIdx < 0 {
return 1
}
rest := hunkLine[plusIdx+1:]
// Find the end of the number (first non-digit after +)
endIdx := 0
for endIdx < len(rest) && rest[endIdx] >= '0' && rest[endIdx] <= '9' {
endIdx++
}
if endIdx == 0 {
return 1
}
n, err := strconv.Atoi(rest[:endIdx])
if err != nil {
return 1
}
return n
}
-331
View File
@@ -1,331 +0,0 @@
package vcs_test
import (
"context"
"fmt"
"strings"
"testing"
"gitea.weiker.me/rodin/review-bot/vcs"
)
// mockFileReader implements vcs.FileReader for testing.
type mockFileReader struct {
contents map[string][]vcs.ContentEntry // path -> entries
files map[string]string // path -> content
}
func (m *mockFileReader) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
content, ok := m.files[path]
if !ok {
return "", fmt.Errorf("HTTP 404: file not found: %s", path)
}
return content, nil
}
func (m *mockFileReader) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
entries, ok := m.contents[path]
if !ok {
return nil, fmt.Errorf("HTTP 404: path not found: %s", path)
}
return entries, nil
}
func TestGetAllFilesInPath(t *testing.T) {
ctx := context.Background()
t.Run("empty directory", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {},
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 0 {
t.Errorf("expected empty map, got %d entries", len(result))
}
})
t.Run("flat directory", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {
{Name: "main.go", Path: "src/main.go", Type: "file"},
{Name: "util.go", Path: "src/util.go", Type: "file"},
},
},
files: map[string]string{
"src/main.go": "package main",
"src/util.go": "package main\n// util",
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 2 {
t.Fatalf("expected 2 files, got %d", len(result))
}
if result["src/main.go"] != "package main" {
t.Errorf("main.go content = %q", result["src/main.go"])
}
if result["src/util.go"] != "package main\n// util" {
t.Errorf("util.go content = %q", result["src/util.go"])
}
})
t.Run("nested directories", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {
{Name: "main.go", Path: "src/main.go", Type: "file"},
{Name: "pkg", Path: "src/pkg", Type: "dir"},
},
"src/pkg": {
{Name: "lib.go", Path: "src/pkg/lib.go", Type: "file"},
{Name: "sub", Path: "src/pkg/sub", Type: "dir"},
},
"src/pkg/sub": {
{Name: "deep.go", Path: "src/pkg/sub/deep.go", Type: "file"},
},
},
files: map[string]string{
"src/main.go": "package main",
"src/pkg/lib.go": "package pkg",
"src/pkg/sub/deep.go": "package sub",
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 3 {
t.Fatalf("expected 3 files, got %d", len(result))
}
if result["src/main.go"] != "package main" {
t.Errorf("main.go content = %q", result["src/main.go"])
}
if result["src/pkg/lib.go"] != "package pkg" {
t.Errorf("lib.go content = %q", result["src/pkg/lib.go"])
}
if result["src/pkg/sub/deep.go"] != "package sub" {
t.Errorf("deep.go content = %q", result["src/pkg/sub/deep.go"])
}
})
t.Run("mixed files and dirs", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"root": {
{Name: "README.md", Path: "root/README.md", Type: "file"},
{Name: "docs", Path: "root/docs", Type: "dir"},
{Name: "config.yaml", Path: "root/config.yaml", Type: "file"},
},
"root/docs": {
{Name: "guide.md", Path: "root/docs/guide.md", Type: "file"},
},
},
files: map[string]string{
"root/README.md": "# Hello",
"root/config.yaml": "key: value",
"root/docs/guide.md": "## Guide",
},
}
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "root")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result) != 3 {
t.Fatalf("expected 3 files, got %d", len(result))
}
if result["root/README.md"] != "# Hello" {
t.Errorf("README content = %q", result["root/README.md"])
}
if result["root/docs/guide.md"] != "## Guide" {
t.Errorf("guide content = %q", result["root/docs/guide.md"])
}
})
}
func TestBuildLineToPositionMap(t *testing.T) {
t.Run("single hunk", func(t *testing.T) {
diff := "diff --git a/file.go b/file.go\nindex abc..def 100644\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n package main\n \n+// new comment\n func main() {}\n"
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["file.go"]
if !ok {
t.Fatal("expected file.go in result")
}
// Hunk header @@ is position 1
// Line 1: " package main" -> position 2
if fileMap[1] != 2 {
t.Errorf("line 1 position = %d, want 2", fileMap[1])
}
// Line 2: " " (context) -> position 3
if fileMap[2] != 3 {
t.Errorf("line 2 position = %d, want 3", fileMap[2])
}
// Line 3: "+// new comment" -> position 4
if fileMap[3] != 4 {
t.Errorf("line 3 position = %d, want 4", fileMap[3])
}
// Line 4: " func main() {}" -> position 5
if fileMap[4] != 5 {
t.Errorf("line 4 position = %d, want 5", fileMap[4])
}
})
t.Run("multi hunk", func(t *testing.T) {
diff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,3 @@\n package main\n \n-// old\n+// new\n@@ -10,3 +10,4 @@\n func foo() {\n+\t// added\n \treturn\n }\n"
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["file.go"]
if !ok {
t.Fatal("expected file.go in result")
}
// First hunk: @@ is position 1
// Line 1: " package main" -> position 2
if fileMap[1] != 2 {
t.Errorf("line 1 position = %d, want 2", fileMap[1])
}
// Line 3: "+// new" -> position 5 (after " ", "-// old" at pos 3,4)
if fileMap[3] != 5 {
t.Errorf("line 3 position = %d, want 5", fileMap[3])
}
// Second hunk: @@ is position 6
// Line 10: " func foo() {" -> position 7
if fileMap[10] != 7 {
t.Errorf("line 10 position = %d, want 7", fileMap[10])
}
// Line 11: "+\t// added" -> position 8
if fileMap[11] != 8 {
t.Errorf("line 11 position = %d, want 8", fileMap[11])
}
})
t.Run("deletion lines not in map", func(t *testing.T) {
diff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,4 +1,3 @@\n package main\n \n-// deleted line\n func main() {}\n"
result := vcs.BuildLineToPositionMap(diff)
fileMap, ok := result["file.go"]
if !ok {
t.Fatal("expected file.go in result")
}
// Line 1: " package main" -> position 2
if fileMap[1] != 2 {
t.Errorf("line 1 position = %d, want 2", fileMap[1])
}
// Line 3 in new file: " func main() {}" -> position 5 (after deletion at pos 4)
if fileMap[3] != 5 {
t.Errorf("line 3 position = %d, want 5", fileMap[3])
}
// Should only have 3 entries (lines 1, 2, 3 of new file)
if len(fileMap) != 3 {
t.Errorf("expected 3 mapped lines, got %d: %v", len(fileMap), fileMap)
}
})
t.Run("multiple files", func(t *testing.T) {
diff := "diff --git a/a.go b/a.go\n--- a/a.go\n+++ b/a.go\n@@ -1,2 +1,3 @@\n package a\n \n+// file a\ndiff --git a/b.go b/b.go\n--- a/b.go\n+++ b/b.go\n@@ -1,2 +1,3 @@\n package b\n \n+// file b\n"
result := vcs.BuildLineToPositionMap(diff)
if len(result) != 2 {
t.Fatalf("expected 2 files, got %d", len(result))
}
aMap, ok := result["a.go"]
if !ok {
t.Fatal("expected a.go in result")
}
bMap, ok := result["b.go"]
if !ok {
t.Fatal("expected b.go in result")
}
// a.go line 3: "+// file a" -> position 4
if aMap[3] != 4 {
t.Errorf("a.go line 3 position = %d, want 4", aMap[3])
}
// b.go line 3: "+// file b" -> position 4
if bMap[3] != 4 {
t.Errorf("b.go line 3 position = %d, want 4", bMap[3])
}
})
}
func TestGetAllFilesInPath_ErrorPropagation(t *testing.T) {
ctx := context.Background()
t.Run("ListContents error propagates", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
// "src" not in map, so ListContents will fail
},
}
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "list contents") {
t.Errorf("expected error about list contents, got: %v", err)
}
})
t.Run("GetFileContent error propagates", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {
{Name: "main.go", Path: "src/main.go", Type: "file"},
},
},
files: map[string]string{
// "src/main.go" not in files map, so GetFileContent will fail
},
}
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "get file") {
t.Errorf("expected error about get file, got: %v", err)
}
})
t.Run("nested ListContents error propagates", func(t *testing.T) {
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {
{Name: "pkg", Path: "src/pkg", Type: "dir"},
},
// "src/pkg" not in map, so recursive ListContents will fail
},
}
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "list contents") {
t.Errorf("expected error about list contents, got: %v", err)
}
})
t.Run("canceled context propagates", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
client := &mockFileReader{
contents: map[string][]vcs.ContentEntry{
"src": {
{Name: "main.go", Path: "src/main.go", Type: "file"},
},
},
files: map[string]string{
"src/main.go": "package main",
},
}
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
if err == nil {
t.Fatal("expected error from canceled context, got nil")
}
if !strings.Contains(err.Error(), "context canceled") {
t.Errorf("expected context cancellation error, got: %v", err)
}
})
}