Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f28c792bda | |||
| b534247c85 | |||
| 6f02cef662 | |||
| fccfdd2ff7 | |||
| e3fb19fa1b | |||
| a1bbab406d | |||
| 4d48917e36 | |||
| bd516cd044 | |||
| 1f67954da7 | |||
| d396599d05 | |||
| 9f3f32174b | |||
| c53a07b230 | |||
| bbf3dfbf0d | |||
| ed3a5dddf1 | |||
| 449a24e4c5 | |||
| 4440823571 | |||
| c349986187 | |||
| 934c6728ee | |||
| 5ac93bea70 | |||
| f84cc3bbcf | |||
| 8c8f3ab4b3 | |||
| 50facefdd6 | |||
| bd2df7d986 |
@@ -1,17 +1,43 @@
|
||||
# This composite action is designed for Gitea Actions runners.
|
||||
# Gitea Actions supports GitHub Actions syntax including $GITHUB_OUTPUT,
|
||||
# actions/cache, and actions/checkout.
|
||||
# This composite action supports both Gitea Actions and GitHub Actions runners.
|
||||
# It detects the VCS host type by checking whether github.api_url is set
|
||||
# (present on GitHub.com and GHES runners, absent on Gitea runners) and uses
|
||||
# the appropriate releases API for version resolution and binary download
|
||||
# (REST API on GitHub, direct URLs on Gitea).
|
||||
#
|
||||
# Security notes:
|
||||
# - On GitHub/GHES (VCS_TYPE=github), inputs.vcs-url is IGNORED to prevent
|
||||
# token exfiltration. API calls use github.api_url; downloads use
|
||||
# github.server_url. Tokens are never sent to user-supplied URLs.
|
||||
# - On Gitea (VCS_TYPE=gitea), inputs.vcs-url is validated (https scheme,
|
||||
# no whitespace/newlines, and DNS resolution to a public IP) before use.
|
||||
# Python3 resolves the hostname and rejects RFC1918, RFC6598 (carrier-grade
|
||||
# NAT), loopback, link-local, and other reserved addresses to prevent SSRF attacks.
|
||||
# The installed review-bot binary additionally uses a safe HTTP transport
|
||||
# (DialContext-level IP check) for all Gitea API calls at runtime.
|
||||
# The binary also exposes a `validate-url` subcommand for use in any future
|
||||
# shell steps that need to validate a URL before passing it to curl.
|
||||
# - action-repo is validated against owner/repo pattern.
|
||||
# - Tokens are passed via masked environment variables, not step outputs.
|
||||
#
|
||||
# Requirements: python3, sha256sum, curl (all present on ubuntu-* runners).
|
||||
name: 'AI Code Review'
|
||||
description: 'Run AI-powered code review on a pull request using review-bot'
|
||||
|
||||
inputs:
|
||||
vcs-url:
|
||||
description: 'VCS server URL (defaults to server_url)'
|
||||
description: 'VCS server URL (only used on Gitea runners; ignored on GitHub/GHES). Defaults to server_url.'
|
||||
required: false
|
||||
default: ''
|
||||
repo:
|
||||
description: 'Repository (owner/name, defaults to current)'
|
||||
description: 'Repository to review (owner/name, defaults to current)'
|
||||
required: false
|
||||
default: ''
|
||||
action-repo:
|
||||
description: 'Repository hosting review-bot releases (owner/name). Defaults to github.action_repository or rodin/review-bot.'
|
||||
required: false
|
||||
default: ''
|
||||
action-repo-token:
|
||||
description: 'Token for downloading release assets from action-repo (defaults to github.token on GitHub, reviewer-token on Gitea). Required for private repos.'
|
||||
required: false
|
||||
default: ''
|
||||
pr-number:
|
||||
@@ -19,7 +45,7 @@ inputs:
|
||||
required: false
|
||||
default: ''
|
||||
reviewer-token:
|
||||
description: 'Gitea token for posting the review'
|
||||
description: 'Token for posting the review'
|
||||
required: true
|
||||
reviewer-name:
|
||||
description: 'Display name for the reviewer'
|
||||
@@ -112,19 +138,150 @@ runs:
|
||||
id: version
|
||||
shell: bash
|
||||
run: |
|
||||
BASE_URL="${{ inputs.vcs-url || github.server_url }}"
|
||||
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||
set -euo pipefail
|
||||
|
||||
# --- Input Validation ---
|
||||
|
||||
# Determine the repo hosting review-bot releases (not the repo being reviewed)
|
||||
ACTION_REPO="${{ inputs.action-repo }}"
|
||||
if [ -z "$ACTION_REPO" ]; then
|
||||
# github.action_repository is the repo containing the running action
|
||||
ACTION_REPO="${{ github.action_repository }}"
|
||||
fi
|
||||
if [ -z "$ACTION_REPO" ]; then
|
||||
# Final fallback for Gitea (which may not set action_repository)
|
||||
ACTION_REPO="rodin/review-bot"
|
||||
echo "::notice::action-repo not specified and github.action_repository is empty; falling back to rodin/review-bot"
|
||||
fi
|
||||
|
||||
# Validate ACTION_REPO matches owner/repo pattern (prevent path traversal)
|
||||
if ! printf '%s' "$ACTION_REPO" | grep -qE '^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$'; then
|
||||
echo "Error: action-repo '${ACTION_REPO}' does not match expected owner/repo format" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Detect VCS host type using github.api_url context.
|
||||
# github.api_url is set on GitHub.com (https://api.github.com) and GHES
|
||||
# (https://<host>/api/v3). It is empty/unset on Gitea Actions runners.
|
||||
GITHUB_API_URL="${{ github.api_url }}"
|
||||
if [ -n "$GITHUB_API_URL" ]; then
|
||||
VCS_TYPE="github"
|
||||
else
|
||||
VCS_TYPE="gitea"
|
||||
fi
|
||||
|
||||
# Determine SERVER_URL based on VCS type.
|
||||
# SECURITY: On GitHub/GHES, ALWAYS use github.server_url — never trust
|
||||
# inputs.vcs-url to prevent token exfiltration to attacker-controlled hosts.
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
SERVER_URL="${{ github.server_url }}"
|
||||
if [ -n "${{ inputs.vcs-url }}" ]; then
|
||||
echo "::warning::inputs.vcs-url is ignored on GitHub/GHES runners (VCS_TYPE=github). Using github.server_url instead."
|
||||
fi
|
||||
else
|
||||
SERVER_URL="${{ inputs.vcs-url || github.server_url }}"
|
||||
fi
|
||||
# Strip trailing slash if present
|
||||
SERVER_URL="${SERVER_URL%/}"
|
||||
|
||||
# Validate SERVER_URL for Gitea path: must be https, no whitespace/newlines.
|
||||
# The [^[:space:]] class already rejects newlines, so no separate newline check needed.
|
||||
if [ "$VCS_TYPE" = "gitea" ]; then
|
||||
if ! printf '%s' "$SERVER_URL" | grep -qE '^https://[^[:space:]]+$'; then
|
||||
echo "Error: SERVER_URL '${SERVER_URL}' must be an https:// URL with no whitespace" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Additional IP-level SSRF defense: resolve the hostname and reject
|
||||
# requests to RFC1918, RFC6598 (carrier-grade NAT), loopback, link-local,
|
||||
# and other reserved addresses.
|
||||
# python3 is required on ubuntu-* runners (see requirements comment above).
|
||||
# Use printf to write the script to a temp file so the python lines are valid
|
||||
# YAML (each indented line becomes a printf argument — no unindented code).
|
||||
# SERVER_URL is passed via CHECK_URL env var, never interpolated into python code.
|
||||
printf '%s\n' \
|
||||
'import socket,ipaddress,sys,os' \
|
||||
'from urllib.parse import urlparse' \
|
||||
'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \
|
||||
'if parsed.username or parsed.password:' \
|
||||
' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \
|
||||
'h=parsed.hostname' \
|
||||
'(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \
|
||||
'try: rs=socket.getaddrinfo(h,None)' \
|
||||
'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \
|
||||
'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \
|
||||
'for _,_,_,_,(a,*_) in rs:' \
|
||||
' ip=ipaddress.ip_address(a)' \
|
||||
' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \
|
||||
' cgn=ipaddress.ip_network("100.64.0.0/10")' \
|
||||
' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \
|
||||
' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \
|
||||
> /tmp/_ssrf_check.py
|
||||
CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check.py || {
|
||||
echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
# Determine auth token for release API requests
|
||||
ACTION_TOKEN="${{ inputs.action-repo-token }}"
|
||||
if [ -z "$ACTION_TOKEN" ]; then
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
ACTION_TOKEN="${{ github.token }}"
|
||||
else
|
||||
ACTION_TOKEN="${{ inputs.reviewer-token }}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Validate token contains no control characters (defense-in-depth against header injection)
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
if printf '%s' "$ACTION_TOKEN" | LC_ALL=C grep -q '[^[:print:]]'; then
|
||||
echo "Error: ACTION_TOKEN contains control characters" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "${{ inputs.version }}" = "latest" ]; then
|
||||
VERSION=$(curl -sSf "${BASE_URL}/api/v1/repos/${REPO}/releases?limit=1" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
# SECURITY: Use github.api_url which is a trusted platform-provided value.
|
||||
# Never construct API URLs from user-supplied inputs on GitHub.
|
||||
API_URL="${GITHUB_API_URL}/repos/${ACTION_REPO}/releases?per_page=1"
|
||||
else
|
||||
# Gitea API — SERVER_URL was validated above
|
||||
API_URL="${SERVER_URL}/api/v1/repos/${ACTION_REPO}/releases?limit=1"
|
||||
fi
|
||||
|
||||
# Fetch latest version with inline auth header (no intermediate variable)
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" "$API_URL" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
else
|
||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: token ${ACTION_TOKEN}" "$API_URL" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
fi
|
||||
else
|
||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 "$API_URL" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
fi
|
||||
|
||||
if [ -z "$VERSION" ]; then
|
||||
echo "Failed to determine latest version" >&2
|
||||
echo "Failed to determine latest version from ${API_URL}" >&2
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
VERSION="${{ inputs.version }}"
|
||||
fi
|
||||
|
||||
# Validate VERSION: no slashes or whitespace (prevent path traversal).
|
||||
# [:space:] includes newlines and carriage returns in POSIX.
|
||||
if printf '%s' "$VERSION" | grep -qE '[/[:space:]]'; then
|
||||
echo "Error: VERSION '${VERSION}' contains invalid characters (newline, slash, or whitespace)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Detect OS and architecture for platform-specific binary download
|
||||
OS_RAW=$(uname -s | tr '[:upper:]' '[:lower:]')
|
||||
case "$OS_RAW" in
|
||||
@@ -149,6 +306,16 @@ runs:
|
||||
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||
echo "os=${OS}" >> "$GITHUB_OUTPUT"
|
||||
echo "arch=${ARCH}" >> "$GITHUB_OUTPUT"
|
||||
echo "action_repo=${ACTION_REPO}" >> "$GITHUB_OUTPUT"
|
||||
echo "server_url=${SERVER_URL}" >> "$GITHUB_OUTPUT"
|
||||
echo "vcs_type=${VCS_TYPE}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# SECURITY: Pass token via masked environment variable instead of step output.
|
||||
# Step outputs can leak in debug logs; GITHUB_ENV with masking is safer.
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
echo "::add-mask::${ACTION_TOKEN}"
|
||||
echo "ACTION_TOKEN=${ACTION_TOKEN}" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Cache review-bot binary
|
||||
id: cache
|
||||
@@ -161,21 +328,131 @@ runs:
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
BASE_URL="${{ inputs.vcs-url || github.server_url }}"
|
||||
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
BINARY="review-bot-${{ steps.version.outputs.os }}-${{ steps.version.outputs.arch }}"
|
||||
set -euo pipefail
|
||||
|
||||
curl -sSfL "${BASE_URL}/${REPO}/releases/download/${VERSION}/${BINARY}" \
|
||||
-o "${{ runner.temp }}/review-bot"
|
||||
curl -sSfL "${BASE_URL}/${REPO}/releases/download/${VERSION}/checksums.txt" \
|
||||
-o "${{ runner.temp }}/checksums.txt"
|
||||
SERVER_URL="${{ steps.version.outputs.server_url }}"
|
||||
ACTION_REPO="${{ steps.version.outputs.action_repo }}"
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
VCS_TYPE="${{ steps.version.outputs.vcs_type }}"
|
||||
OS="${{ steps.version.outputs.os }}"
|
||||
ARCH="${{ steps.version.outputs.arch }}"
|
||||
# Read token from masked environment variable (set in Determine version step)
|
||||
# Falls back to empty if not set (public repos don't need auth)
|
||||
ACTION_TOKEN="${ACTION_TOKEN:-}"
|
||||
BINARY="review-bot-${OS}-${ARCH}"
|
||||
|
||||
# SECURITY: Re-validate SERVER_URL at the start of this step to mitigate DNS
|
||||
# rebinding attacks. A DNS TTL expiry between "Determine version" and here
|
||||
# could allow an attacker to change the resolved IP to a private/reserved
|
||||
# address, causing curl to send ACTION_TOKEN to an internal host.
|
||||
# Only needed on Gitea path (VCS_TYPE=gitea); GitHub/GHES uses platform-controlled URLs.
|
||||
if [ "$VCS_TYPE" = "gitea" ]; then
|
||||
printf '%s\n' \
|
||||
'import socket,ipaddress,sys,os' \
|
||||
'from urllib.parse import urlparse' \
|
||||
'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \
|
||||
'if parsed.username or parsed.password:' \
|
||||
' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \
|
||||
'h=parsed.hostname' \
|
||||
'(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \
|
||||
'try: rs=socket.getaddrinfo(h,None)' \
|
||||
'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \
|
||||
'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \
|
||||
'for _,_,_,_,(a,*_) in rs:' \
|
||||
' ip=ipaddress.ip_address(a)' \
|
||||
' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \
|
||||
' cgn=ipaddress.ip_network("100.64.0.0/10")' \
|
||||
' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \
|
||||
' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \
|
||||
> /tmp/_ssrf_check_install.py
|
||||
CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check_install.py || {
|
||||
echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
# GitHub/GHES: Use REST API for release asset downloads.
|
||||
# Web release URLs ({server}/.../releases/download/{tag}/{asset}) redirect
|
||||
# to S3 and don't reliably support Authorization headers for private repos.
|
||||
# The REST API endpoint with Accept: application/octet-stream is required.
|
||||
# GITHUB_API_URL: trusted platform value, same as detected in "Determine version" step.
|
||||
GITHUB_API_URL="${{ github.api_url }}"
|
||||
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
RELEASE_JSON=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/tags/${VERSION}")
|
||||
else
|
||||
RELEASE_JSON=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/tags/${VERSION}")
|
||||
fi
|
||||
|
||||
# Extract asset IDs for binary and checksums
|
||||
BINARY_ASSET_ID=$(printf '%s' "$RELEASE_JSON" | python3 -c "import sys, json; assets = json.load(sys.stdin).get('assets', []); matches = [a['id'] for a in assets if a['name'] == '${BINARY}']; print(matches[0] if matches else '')")
|
||||
if [ -z "$BINARY_ASSET_ID" ]; then
|
||||
echo "Error: could not find asset '${BINARY}' in release ${VERSION}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CHECKSUMS_ASSET_ID=$(printf '%s' "$RELEASE_JSON" | python3 -c "import sys, json; assets = json.load(sys.stdin).get('assets', []); matches = [a['id'] for a in assets if a['name'] == 'checksums.txt']; print(matches[0] if matches else '')")
|
||||
if [ -z "$CHECKSUMS_ASSET_ID" ]; then
|
||||
echo "Error: could not find asset 'checksums.txt' in release ${VERSION}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Download assets via REST API with Accept: application/octet-stream
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
curl -sSfL --connect-timeout 10 --max-time 120 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${BINARY_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/review-bot"
|
||||
curl -sSfL --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${CHECKSUMS_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/checksums.txt"
|
||||
else
|
||||
curl -sSfL --connect-timeout 10 --max-time 120 \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${BINARY_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/review-bot"
|
||||
curl -sSfL --connect-timeout 10 --max-time 30 \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${CHECKSUMS_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/checksums.txt"
|
||||
fi
|
||||
else
|
||||
# Gitea: Direct download via web release URLs (Gitea serves assets
|
||||
# directly without redirects — no -L needed).
|
||||
# SECURITY: Omitting -L prevents forwarding Authorization header to
|
||||
# unexpected hosts if Gitea ever introduces CDN redirects.
|
||||
DOWNLOAD_URL="${SERVER_URL}/${ACTION_REPO}/releases/download/${VERSION}"
|
||||
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
curl -sSf --connect-timeout 10 --max-time 120 \
|
||||
-H "Authorization: token ${ACTION_TOKEN}" \
|
||||
"${DOWNLOAD_URL}/${BINARY}" -o "${{ runner.temp }}/review-bot"
|
||||
curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: token ${ACTION_TOKEN}" \
|
||||
"${DOWNLOAD_URL}/checksums.txt" -o "${{ runner.temp }}/checksums.txt"
|
||||
else
|
||||
curl -sSf --connect-timeout 10 --max-time 120 \
|
||||
"${DOWNLOAD_URL}/${BINARY}" -o "${{ runner.temp }}/review-bot"
|
||||
curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
"${DOWNLOAD_URL}/checksums.txt" -o "${{ runner.temp }}/checksums.txt"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Verify SHA-256 checksum
|
||||
# NOTE: This verifies integrity (download wasn't corrupted) but not
|
||||
# authenticity — both binary and checksums come from the same server.
|
||||
# For stronger guarantees, consider GPG signature verification.
|
||||
cd "${{ runner.temp }}"
|
||||
EXPECTED=$(grep -E "^[[:xdigit:]]+[[:space:]]+\*?${BINARY}$" checksums.txt | awk '{print $1}')
|
||||
EXPECTED=$(grep -E "^[0-9a-f]+[[:space:]]+\*?${BINARY}$" checksums.txt | awk '{print $1}')
|
||||
# sha256sum (GNU) is not available on macOS; use shasum -a 256 on darwin.
|
||||
if [ "${{ steps.version.outputs.os }}" = "darwin" ]; then
|
||||
if [ "${OS}" = "darwin" ]; then
|
||||
ACTUAL=$(shasum -a 256 review-bot | awk '{print $1}')
|
||||
else
|
||||
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
|
||||
@@ -193,12 +470,12 @@ runs:
|
||||
fi
|
||||
|
||||
chmod +x "${{ runner.temp }}/review-bot"
|
||||
echo "Installed review-bot-${{ steps.version.outputs.os }}-${{ steps.version.outputs.arch }} ${VERSION} (checksum verified)"
|
||||
echo "Installed review-bot-${OS}-${ARCH} ${VERSION} (checksum verified)"
|
||||
|
||||
- name: Run review
|
||||
shell: bash
|
||||
env:
|
||||
VCS_URL: ${{ inputs.vcs-url || github.server_url }}
|
||||
VCS_URL: ${{ steps.version.outputs.server_url }}
|
||||
GITEA_REPO: ${{ inputs.repo || github.repository }}
|
||||
PR_NUMBER: ${{ inputs.pr-number || github.event.pull_request.number }}
|
||||
REVIEWER_TOKEN: ${{ inputs.reviewer-token }}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# Dev Loop Health Check — 2026-05-14 23:33 UTC
|
||||
|
||||
## Status: ✅ OPTIMAL
|
||||
|
||||
### Test Results
|
||||
- All packages: **PASS** ✅
|
||||
- Test count: 150+ tests across:
|
||||
- `./cmd/review-bot`: review, persona, schema, LLM routing
|
||||
- `./internal/gitea`: Gitea client, URL validation, SSRF defense
|
||||
- `./internal/github`: GitHub client, API routing
|
||||
- `./review`: persona loading, file content fetching, review logic
|
||||
|
||||
### Recent Activity
|
||||
1. **Tests added:** GetTimelineReviewCommentIDForReview, GetAllFilesInPath, fetchFileContext, fetchPatterns, persona edge cases
|
||||
2. **Code quality:** `go fmt`, `go mod tidy` clean
|
||||
3. **Branches:** All worktrees cleaned up, working on `review-bot-fixes` (in sync with origin/main)
|
||||
|
||||
### Current HEAD
|
||||
- SHA: `b534247`
|
||||
- Message: `[dev-loop] Update TODO.md with current cycle status and coverage metrics`
|
||||
- Time: Clean, no uncommitted changes
|
||||
|
||||
### Next Phase Priorities
|
||||
1. **PR Submission (#132+)** — Enable review-bot to create PRs (3–5 days)
|
||||
2. **GitHub Enterprise Support** — Enterprise URL patterns, token scopes (2–3 days)
|
||||
3. **Performance & Observability** — Metrics, load testing, audit logging (5–7 days)
|
||||
|
||||
### System Health
|
||||
- ✅ All tests passing
|
||||
- ✅ No warnings or lint issues
|
||||
- ✅ Code is clean and organized
|
||||
- ✅ Ready for next development cycle
|
||||
|
||||
---
|
||||
|
||||
**Next check:** ~6:10 AM UTC (May 15)
|
||||
**Action:** Ready for issue creation or new feature work
|
||||
@@ -0,0 +1,151 @@
|
||||
## Dev Loop: review-bot — Continuous Health Monitor
|
||||
|
||||
### Current Cycle: 2026-05-14 23:11 UTC ✅
|
||||
|
||||
**Repository Status:** OPTIMAL
|
||||
- Main: `6f02cef` (clean, all tests pass)
|
||||
- Working tree: clean
|
||||
- Build: ✅ successful
|
||||
- Vet: ✅ clean
|
||||
- Test suite: ALL PASS
|
||||
|
||||
---
|
||||
|
||||
## Latest Delivered: Test Coverage Sprint 2026-05-14 ✅
|
||||
|
||||
### Coverage Improvements
|
||||
|
||||
22 new tests added across 4 packages:
|
||||
|
||||
| Package | Before | After | Delta |
|
||||
|---------|--------|-------|-------|
|
||||
| cmd/review-bot | 37.6% | 46.1% | +8.5% |
|
||||
| gitea | 80.0% | 85.2% | +5.2% |
|
||||
| github | 79.9% | 86.3% | +6.4% |
|
||||
| review | 91.5% | 92.0% | +0.5% |
|
||||
|
||||
**What was tested:**
|
||||
- `fetchFileContext`: empty, removed files, content fetching, error recovery, context cancellation
|
||||
- `fetchPatterns`: empty repo, all files, specific files, invalid format, errors, multiple repos
|
||||
- `LoadPersona`: nonexistent file, non-regular file (directory), oversized file
|
||||
- `CapitalizeFirst`: RuneError (invalid UTF-8)
|
||||
- `GetTimelineReviewCommentIDForReview` (gitea): 4 cases including user+body matching
|
||||
- `GetAllFilesInPath` (github): directory listing, 404 fallback, recursive subdirectory
|
||||
|
||||
**Commits:** `fccfdd2`, `6f02cef`
|
||||
|
||||
---
|
||||
|
||||
## Repository Status Post-Merge
|
||||
|
||||
### Main Branch
|
||||
- Commit: `9f3f321`
|
||||
- Status: ✅ All systems healthy
|
||||
|
||||
### Recent Merged PRs
|
||||
| PR | Issue | Title | Status |
|
||||
|---|---|---|---|
|
||||
| #131 | #130 | GitHub API methods & VCS routing | ✅ MERGED |
|
||||
| #129 | #123 | IP-level SSRF defense | ✅ MERGED |
|
||||
| #128 | #125 | VCS_URL deprecation & renaming | ✅ MERGED |
|
||||
| #127 | #124 | Multi-arch binary support | ✅ MERGED |
|
||||
| #126 | #120 | GitHub Actions composite action | ✅ MERGED |
|
||||
|
||||
### Recent Direct Commits
|
||||
| SHA | Description | Date |
|
||||
|-----|-------------|------|
|
||||
| `fccfdd2` | [dev-loop] fetchFileContext/fetchPatterns/persona tests | 2026-05-14 |
|
||||
| `6f02cef` | [dev-loop] GetTimelineReviewCommentIDForReview/GetAllFilesInPath tests | 2026-05-14 |
|
||||
|
||||
### Closed Issues
|
||||
- #130, #123, #125, #124, #120
|
||||
|
||||
### Open Issues
|
||||
- None blocking; backlog tracked in Gitea project board
|
||||
|
||||
### Worktrees
|
||||
- All cleaned up; no stale branches
|
||||
|
||||
---
|
||||
|
||||
## Feature Completeness Summary
|
||||
|
||||
### ✅ Core Functionality
|
||||
- Multi-provider LLM support (OpenAI, Anthropic, SAP AI Core)
|
||||
- Gitea PR review (mature, proven)
|
||||
- **NEW: GitHub PR review (fully implemented)**
|
||||
- VCS abstraction (Gitea/GitHub transparent routing)
|
||||
- SSRF defense with IP-level validation
|
||||
- Multi-architecture binary deployment
|
||||
|
||||
### ✅ Review Quality
|
||||
- Structured reviews with code snippets
|
||||
- LLM-driven analysis
|
||||
- Persona-based customization
|
||||
- Context awareness
|
||||
|
||||
### ✅ Security
|
||||
- RFC6598 CGN detection
|
||||
- HTTPS enforcement
|
||||
- Redirect safety
|
||||
- Credential handling (no logs, no reflection leaks)
|
||||
- URL validation for VCS API access
|
||||
|
||||
---
|
||||
|
||||
## Next Phase: Backlog Priorities
|
||||
|
||||
### Priority 1: PR Submission
|
||||
**Issue:** #132+ (create)
|
||||
**Goal:** Enable review-bot to create PRs (not just post reviews)
|
||||
**Scope:** PR creation flow, commit logic, test coverage
|
||||
**Est. Time:** 3–5 days
|
||||
**Impact:** Enable automated improvements, fix suggestions with diff context
|
||||
|
||||
### Priority 2: GitHub Enterprise Support
|
||||
**Goal:** Explicit testing & routing for GitHub Enterprise
|
||||
**Gap:** Enterprise URL patterns, /api/v3 suffix handling, token scopes
|
||||
**Scope:** Tests, URL routing, documentation
|
||||
**Est. Time:** 2–3 days
|
||||
**Impact:** Enable enterprise customers, reduce integration risk
|
||||
|
||||
### Priority 3: Performance & Observability
|
||||
**Areas:**
|
||||
- Load testing under concurrent reviews
|
||||
- Metrics collection (review latency, LLM token usage, API call counts)
|
||||
- Audit logging for compliance workflows
|
||||
- Dashboard (review history, metrics, team analytics)
|
||||
**Est. Time:** 5–7 days
|
||||
**Impact:** Operational confidence, troubleshooting, compliance
|
||||
|
||||
### Priority 4: Enhanced Context
|
||||
**Opportunities:**
|
||||
- Semantic code understanding (AST-based analysis for specific languages)
|
||||
- Project-specific review rules (.review-bot.yaml in repo root)
|
||||
- Team-level customization
|
||||
**Est. Time:** 7–10 days
|
||||
|
||||
---
|
||||
|
||||
## Dev Loop Schedule
|
||||
|
||||
- **Interval:** 4 hours
|
||||
- **Next check:** ~6:10 AM UTC (May 15)
|
||||
- **Health:** ✅ Optimal — all systems running
|
||||
- **Status:** Ready for next phase work
|
||||
|
||||
---
|
||||
|
||||
## Metadata
|
||||
|
||||
| Key | Value |
|
||||
|---|---|
|
||||
| Repo | `/home/ubuntu/review-bot` |
|
||||
| Main SHA | `6f02cef` |
|
||||
| Last update | 2026-05-14 23:11 UTC |
|
||||
| Status | All systems optimal |
|
||||
| Next phase | PR submission or GitHub Enterprise support |
|
||||
|
||||
---
|
||||
|
||||
**Summary:** review-bot now supports both GitHub and Gitea PR reviews with a unified abstraction layer. All tests pass, code is clean, security is approved. Ready to move to PR submission or GitHub Enterprise support in the next cycle.
|
||||
@@ -157,7 +157,6 @@ func TestFit_PreservesNoteInOutput(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestFit_HugeUserMeta(t *testing.T) {
|
||||
// UserMeta so large that base alone exceeds limit
|
||||
// Use a unique marker past the truncation point
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/github"
|
||||
"gitea.weiker.me/rodin/review-bot/llm"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
@@ -159,3 +160,85 @@ func TestIntegration_PostAndCleanup(t *testing.T) {
|
||||
t.Logf("Warning: could not delete test review %d: %v", posted.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_GitHub_PostAndVerifyReview exercises the full VCS routing path
|
||||
// for GitHub when INTEGRATION_GITHUB_TOKEN and INTEGRATION_GITHUB_REPO are set.
|
||||
// It verifies that the GitHub adapter is selected via VCS_TYPE=github and that
|
||||
// PostReview succeeds against a real GitHub PR.
|
||||
//
|
||||
// Required environment variables:
|
||||
//
|
||||
// INTEGRATION_GITHUB_TOKEN - GitHub personal access token with repo access
|
||||
// INTEGRATION_GITHUB_REPO - owner/repo with an open PR (e.g. Rodin-AI/review-bot)
|
||||
// INTEGRATION_GITHUB_PR - PR number to test against
|
||||
//
|
||||
// The test skips gracefully when these variables are absent.
|
||||
func TestIntegration_GitHub_PostAndVerifyReview(t *testing.T) {
|
||||
githubToken := os.Getenv("INTEGRATION_GITHUB_TOKEN")
|
||||
githubRepo := os.Getenv("INTEGRATION_GITHUB_REPO")
|
||||
prNumStr := os.Getenv("INTEGRATION_GITHUB_PR")
|
||||
|
||||
if githubToken == "" || githubRepo == "" || prNumStr == "" {
|
||||
t.Skip("INTEGRATION_GITHUB_TOKEN, INTEGRATION_GITHUB_REPO, and INTEGRATION_GITHUB_PR not set, skipping")
|
||||
}
|
||||
|
||||
prNumber, err := strconv.Atoi(prNumStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid PR number %q: %v", prNumStr, err)
|
||||
}
|
||||
|
||||
parts := strings.SplitN(githubRepo, "/", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
t.Fatalf("Invalid repo format %q, expected owner/repo", githubRepo)
|
||||
}
|
||||
owner, repoName := parts[0], parts[1]
|
||||
|
||||
ctx := context.Background()
|
||||
ghClient := github.NewClient(githubToken, "https://api.github.com")
|
||||
|
||||
// Verify adapter selection: GetAuthenticatedUser must succeed.
|
||||
user, err := ghClient.GetAuthenticatedUser(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAuthenticatedUser: %v — check INTEGRATION_GITHUB_TOKEN", err)
|
||||
}
|
||||
t.Logf("Authenticated as: %s", user)
|
||||
|
||||
// Verify PR is accessible via GitHub adapter.
|
||||
pr, err := ghClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPullRequest: %v", err)
|
||||
}
|
||||
t.Logf("PR: %s (sha: %s)", pr.Title, pr.Head.Sha)
|
||||
|
||||
// Post a COMMENT review — does not require PR approval permissions.
|
||||
sentinel := "<!-- review-bot:integration-test -->"
|
||||
testBody := "# Integration Test Review (GitHub)\n\nThis is an automated integration test.\n\n" + sentinel
|
||||
posted, err := ghClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("PostReview: %v", err)
|
||||
}
|
||||
t.Logf("Posted review ID: %d", posted.ID)
|
||||
|
||||
// Verify the review appears in ListReviews.
|
||||
reviews, err := ghClient.ListReviews(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
t.Fatalf("ListReviews: %v", err)
|
||||
}
|
||||
found := false
|
||||
for _, r := range reviews {
|
||||
if r.ID == posted.ID && strings.Contains(r.Body, sentinel) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("posted review ID %d not found in ListReviews output", posted.ID)
|
||||
}
|
||||
|
||||
// Attempt cleanup — GitHub does not allow deleting submitted reviews,
|
||||
// so this is expected to fail with ErrCannotDeleteSubmittedReview (422).
|
||||
// Log it as informational only.
|
||||
if err := ghClient.DeleteReview(ctx, owner, repoName, prNumber, posted.ID); err != nil {
|
||||
t.Logf("Note: DeleteReview returned (expected for submitted GitHub reviews): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+99
-60
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -13,12 +14,20 @@ import (
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/budget"
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/github"
|
||||
"gitea.weiker.me/rodin/review-bot/llm"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
|
||||
var version = "dev"
|
||||
|
||||
// outWriter and errWriter are the output and error writers for subcommands.
|
||||
// They are variables so tests can capture output.
|
||||
var (
|
||||
outWriter io.Writer = os.Stdout
|
||||
errWriter io.Writer = os.Stderr
|
||||
)
|
||||
|
||||
// setupLogger configures the global slog default logger based on format and verbosity.
|
||||
func setupLogger(format, verbosity string) {
|
||||
var level slog.Level
|
||||
@@ -49,6 +58,15 @@ func setupLogger(format, verbosity string) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Dispatch subcommands before flag parsing so they get their own args.
|
||||
// e.g. `review-bot validate-url <url>`
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "validate-url":
|
||||
os.Exit(runValidateURL(os.Args[2:]))
|
||||
}
|
||||
}
|
||||
|
||||
versionFlag := flag.Bool("version", false, "Print version and exit")
|
||||
// Logging flags
|
||||
logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json")
|
||||
@@ -152,7 +170,39 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize clients
|
||||
giteaClient := gitea.NewClient(*vcsURL, *reviewerToken)
|
||||
// Detect VCS type: explicit flag > env var > URL heuristic (default: gitea).
|
||||
vcsType := envOrDefault("VCS_TYPE", "")
|
||||
if vcsType == "" {
|
||||
// Heuristic: if the URL looks like github.com or a GitHub Enterprise host,
|
||||
// default to GitHub. The composite action sets VCS_TYPE explicitly, so this
|
||||
// is a fallback for manual invocations.
|
||||
if strings.Contains(*vcsURL, "github.com") || strings.Contains(*vcsURL, "github.concur.com") {
|
||||
vcsType = "github"
|
||||
} else {
|
||||
vcsType = "gitea"
|
||||
}
|
||||
}
|
||||
slog.Info("VCS type detected", "vcs_type", vcsType, "vcs_url", *vcsURL)
|
||||
|
||||
var vcs vcsClient
|
||||
switch vcsType {
|
||||
case "github":
|
||||
// GitHub: baseURL is the API URL, derived from server URL.
|
||||
// github.com → https://api.github.com
|
||||
// GHES (e.g. https://ghe.example.com) → https://ghe.example.com/api/v3
|
||||
apiURL := githubAPIURL(*vcsURL)
|
||||
ghClient := github.NewClient(*reviewerToken, apiURL)
|
||||
vcs = newGithubVCSAdapter(ghClient)
|
||||
slog.Info("using GitHub VCS client", "api_url", apiURL)
|
||||
case "gitea":
|
||||
giteaClient := gitea.NewClient(*vcsURL, *reviewerToken)
|
||||
vcs = newGiteaVCSAdapter(giteaClient)
|
||||
slog.Info("using Gitea VCS client", "url", *vcsURL)
|
||||
default:
|
||||
slog.Error("unsupported VCS type", "vcs_type", vcsType, "valid", "gitea, github")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel)
|
||||
if *llmTemp < 0 || *llmTemp > 2 {
|
||||
slog.Error("invalid LLM temperature", "temperature", *llmTemp, "range", "0-2")
|
||||
@@ -190,7 +240,7 @@ func main() {
|
||||
var persona *review.Persona
|
||||
if *personaName != "" {
|
||||
// Try loading from repo first, then fall back to built-in
|
||||
repoPersonas, err := review.LoadRepoPersonas(ctx, newGiteaClientAdapter(giteaClient), owner, repoName)
|
||||
repoPersonas, err := review.LoadRepoPersonas(ctx, vcs, owner, repoName)
|
||||
if err != nil {
|
||||
slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err)
|
||||
// Continue with built-in personas only.
|
||||
@@ -226,7 +276,7 @@ func main() {
|
||||
slog.Info("reviewing pull request", "pr", prNumber, "repo", fmt.Sprintf("%s/%s", owner, repoName))
|
||||
|
||||
// Step 1: Fetch PR metadata
|
||||
pr, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
pr, err := vcs.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch PR", "pr", prNumber, "error", err)
|
||||
os.Exit(1)
|
||||
@@ -234,7 +284,7 @@ func main() {
|
||||
slog.Info("fetched PR metadata", "pr", prNumber, "title", pr.Title)
|
||||
|
||||
// Step 2: Fetch diff
|
||||
diff, err := giteaClient.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
||||
diff, err := vcs.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch diff", "pr", prNumber, "error", err)
|
||||
os.Exit(1)
|
||||
@@ -243,11 +293,11 @@ func main() {
|
||||
|
||||
// Step 3: Fetch full file content for modified files
|
||||
fileContext := ""
|
||||
files, err := giteaClient.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
||||
files, err := vcs.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Warn("could not fetch PR files list", "pr", prNumber, "error", err)
|
||||
} else {
|
||||
fileContext = fetchFileContext(ctx, giteaClient, owner, repoName, pr.Head.Ref, files)
|
||||
fileContext = fetchFileContext(ctx, vcs, owner, repoName, pr.Head.Ref, files)
|
||||
slog.Debug("fetched file context", "files", len(files))
|
||||
}
|
||||
|
||||
@@ -255,7 +305,7 @@ func main() {
|
||||
ciPassed := true
|
||||
ciDetails := ""
|
||||
if pr.Head.Sha != "" {
|
||||
statuses, err := giteaClient.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
||||
statuses, err := vcs.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
||||
if err != nil {
|
||||
slog.Warn("could not fetch CI status", "sha", pr.Head.Sha, "error", err)
|
||||
} else {
|
||||
@@ -267,7 +317,7 @@ func main() {
|
||||
// Step 5: Load conventions file if specified
|
||||
conventions := ""
|
||||
if *conventionsFile != "" {
|
||||
content, err := giteaClient.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
||||
content, err := vcs.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
||||
if err != nil {
|
||||
slog.Warn("could not load conventions file", "file", *conventionsFile, "error", err)
|
||||
} else {
|
||||
@@ -279,7 +329,7 @@ func main() {
|
||||
// Step 6: Load patterns from external repo if specified
|
||||
patterns := ""
|
||||
if *patternsRepo != "" {
|
||||
patterns = fetchPatterns(ctx, giteaClient, *patternsRepo, *patternsFiles)
|
||||
patterns = fetchPatterns(ctx, vcs, *patternsRepo, *patternsFiles)
|
||||
slog.Debug("loaded patterns", "repo", *patternsRepo, "bytes", len(patterns))
|
||||
}
|
||||
|
||||
@@ -394,7 +444,7 @@ func main() {
|
||||
// Stale check: verify HEAD hasn't moved since we started
|
||||
evaluatedSHA := pr.Head.Sha
|
||||
var currentSHA string
|
||||
currentPR, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
currentPR, err := vcs.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Warn("could not re-fetch PR for stale check", "pr", prNumber, "error", err)
|
||||
// currentSHA stays empty — shouldSkipStaleReview will return false
|
||||
@@ -411,10 +461,10 @@ func main() {
|
||||
|
||||
// Map findings to inline comments for lines present in the diff
|
||||
diffRanges := gitea.ParseDiffNewLines(diff)
|
||||
var inlineComments []gitea.ReviewComment
|
||||
var inlineComments []vcsReviewComment
|
||||
for _, f := range result.Findings {
|
||||
if f.File != "" && f.Line > 0 && diffRanges.Contains(f.File, f.Line) {
|
||||
inlineComments = append(inlineComments, gitea.ReviewComment{
|
||||
inlineComments = append(inlineComments, vcsReviewComment{
|
||||
Path: f.File,
|
||||
NewPosition: int64(f.Line),
|
||||
Body: fmt.Sprintf("**[%s]** %s", f.Severity, f.Finding),
|
||||
@@ -429,9 +479,9 @@ func main() {
|
||||
// 1. POST new review first (gets non-stale approval badge on HEAD)
|
||||
// 2. Then supersede old review with link to the new one
|
||||
// Order matters: post first so we have the new review's URL for the supersede message.
|
||||
var oldReviews []gitea.Review
|
||||
var oldReviews []vcsReview
|
||||
if *reviewerName != "" {
|
||||
existingReviews, err := giteaClient.ListReviews(ctx, owner, repoName, prNumber)
|
||||
existingReviews, err := vcs.ListReviews(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Warn("could not list existing reviews", "pr", prNumber, "error", err)
|
||||
} else {
|
||||
@@ -444,11 +494,11 @@ func main() {
|
||||
}
|
||||
|
||||
// Self-request as reviewer (ensures we appear in required-reviewer checks)
|
||||
authUser, err := giteaClient.GetAuthenticatedUser(ctx)
|
||||
authUser, err := vcs.GetAuthenticatedUser(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("could not determine authenticated user for reviewer self-request", "error", err)
|
||||
} else if authUser != "" {
|
||||
if err := giteaClient.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
|
||||
if err := vcs.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
|
||||
slog.Warn("could not self-request as reviewer", "user", authUser, "error", err)
|
||||
} else {
|
||||
slog.Debug("self-requested as reviewer", "user", authUser, "pr", prNumber)
|
||||
@@ -457,31 +507,34 @@ func main() {
|
||||
|
||||
// POST new review
|
||||
slog.Info("posting review", "event", event, "pr", prNumber)
|
||||
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, evaluatedSHA, inlineComments)
|
||||
posted, err := vcs.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, evaluatedSHA, inlineComments)
|
||||
if err != nil {
|
||||
slog.Error("failed to post review", "pr", prNumber, "event", event, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("review posted", "review_id", posted.ID, "user", posted.User.Login, "pr", prNumber)
|
||||
|
||||
// Supersede all old reviews with link to the new one
|
||||
if len(oldReviews) > 0 {
|
||||
// Supersede all old reviews with link to the new one.
|
||||
// This is only supported on Gitea (requires timeline API); GitHub reviews cannot
|
||||
// be edited after submission, so we skip the supersede step there.
|
||||
extVCS, isGiteaExt := vcs.(giteaExtClient)
|
||||
if len(oldReviews) > 0 && isGiteaExt {
|
||||
newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(*vcsURL, "/"), owner, repoName, prNumber, posted.ID)
|
||||
for _, oldReview := range oldReviews {
|
||||
cid, err := giteaClient.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, prNumber, oldReview.ID)
|
||||
cid, err := extVCS.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, int64(prNumber), oldReview.ID)
|
||||
if err != nil {
|
||||
slog.Warn("could not find comment ID for old review", "review_id", oldReview.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
supersededBody := buildSupersededBody(oldReview.Body, oldReview.CommitID, newReviewURL, sentinel)
|
||||
if err := giteaClient.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
|
||||
if err := extVCS.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
|
||||
slog.Warn("could not mark old review as superseded", "review_id", oldReview.ID, "comment_id", cid, "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Info("marked old review as superseded", "review_id", oldReview.ID, "new_review_id", posted.ID, "pr", prNumber)
|
||||
|
||||
// Resolve old review's inline comments
|
||||
oldComments, err := giteaClient.ListReviewComments(ctx, owner, repoName, prNumber, oldReview.ID)
|
||||
oldComments, err := extVCS.ListReviewComments(ctx, owner, repoName, int64(prNumber), oldReview.ID)
|
||||
if err != nil {
|
||||
slog.Warn("could not list old review comments for resolution", "review_id", oldReview.ID, "error", err)
|
||||
continue
|
||||
@@ -491,7 +544,7 @@ func main() {
|
||||
if c.ID == 0 {
|
||||
continue
|
||||
}
|
||||
if err := giteaClient.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
|
||||
if err := extVCS.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
|
||||
slog.Debug("could not resolve inline comment", "comment_id", c.ID, "error", err)
|
||||
failed++
|
||||
} else {
|
||||
@@ -505,12 +558,14 @@ func main() {
|
||||
slog.Warn("some inline comments could not be resolved", "review_id", oldReview.ID, "failed", failed, "pr", prNumber)
|
||||
}
|
||||
}
|
||||
} else if len(oldReviews) > 0 {
|
||||
slog.Info("skipping supersede of old reviews (not supported on this VCS)", "old_count", len(oldReviews), "pr", prNumber)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// fetchFileContext fetches the full content of modified files from the PR branch.
|
||||
func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, ref string, files []gitea.ChangedFile) string {
|
||||
func fetchFileContext(ctx context.Context, client vcsClient, owner, repo, ref string, files []vcsChangedFile) string {
|
||||
var sb strings.Builder
|
||||
for _, f := range files {
|
||||
if ctx.Err() != nil {
|
||||
@@ -537,7 +592,7 @@ func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, re
|
||||
// patternsFiles is comma-separated list of file paths or directories.
|
||||
// If a path ends with / or is a directory, all files within it are fetched recursively.
|
||||
// If patternsFiles is empty, all files from the repo root are fetched.
|
||||
func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string {
|
||||
func fetchPatterns(ctx context.Context, client vcsClient, patternsRepo, patternsFiles string) string {
|
||||
var sb strings.Builder
|
||||
|
||||
repos := strings.Split(patternsRepo, ",")
|
||||
@@ -614,7 +669,7 @@ func isPatternFile(path string) bool {
|
||||
}
|
||||
|
||||
// evaluateCIStatus checks if all CI statuses indicate success.
|
||||
func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details string) {
|
||||
func evaluateCIStatus(statuses []vcsCommitStatus) (passed bool, details string) {
|
||||
if len(statuses) == 0 {
|
||||
return true, "no CI statuses found"
|
||||
}
|
||||
@@ -637,6 +692,19 @@ func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details strin
|
||||
return true, "all checks passed"
|
||||
}
|
||||
|
||||
// githubAPIURL converts a GitHub server URL to its API base URL.
|
||||
// github.com → https://api.github.com
|
||||
// GHES (e.g. https://ghe.example.com) → https://ghe.example.com/api/v3
|
||||
func githubAPIURL(serverURL string) string {
|
||||
const canonicalGitHub = "https://github.com"
|
||||
const githubAPIBase = "https://api.github.com"
|
||||
if serverURL == "" || strings.TrimRight(serverURL, "/") == canonicalGitHub {
|
||||
return githubAPIBase
|
||||
}
|
||||
// GitHub Enterprise Server: /api/v3 suffix
|
||||
return strings.TrimRight(serverURL, "/") + "/api/v3"
|
||||
}
|
||||
|
||||
func envOrDefault(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
@@ -752,7 +820,7 @@ func buildSupersededBody(originalBody, commitSHA, newReviewURL, sentinel string)
|
||||
// Gitea user. This indicates misconfiguration where two roles share a token
|
||||
// instead of having separate Gitea accounts. Returns true if shared token
|
||||
// detected (caller should skip update-in-place logic to avoid clobbering).
|
||||
func hasSharedToken(reviews []gitea.Review, ownSentinel string) bool {
|
||||
func hasSharedToken(reviews []vcsReview, ownSentinel string) bool {
|
||||
ownLogin := ""
|
||||
for _, r := range reviews {
|
||||
if strings.Contains(r.Body, ownSentinel) {
|
||||
@@ -790,8 +858,8 @@ func extractSentinelName(body string) string {
|
||||
}
|
||||
|
||||
// findOwnReview locates the most recent non-superseded review matching the sentinel.
|
||||
func findOwnReview(reviews []gitea.Review, sentinel string) *gitea.Review {
|
||||
var best *gitea.Review
|
||||
func findOwnReview(reviews []vcsReview, sentinel string) *vcsReview {
|
||||
var best *vcsReview
|
||||
for i := range reviews {
|
||||
if !strings.Contains(reviews[i].Body, sentinel) {
|
||||
continue
|
||||
@@ -807,8 +875,8 @@ func findOwnReview(reviews []gitea.Review, sentinel string) *gitea.Review {
|
||||
}
|
||||
|
||||
// findAllOwnReviews returns all non-superseded reviews matching the sentinel.
|
||||
func findAllOwnReviews(reviews []gitea.Review, sentinel string) []gitea.Review {
|
||||
var result []gitea.Review
|
||||
func findAllOwnReviews(reviews []vcsReview, sentinel string) []vcsReview {
|
||||
var result []vcsReview
|
||||
for i := range reviews {
|
||||
if !strings.Contains(reviews[i].Body, sentinel) {
|
||||
continue
|
||||
@@ -833,32 +901,3 @@ func shouldSkipStaleReview(evaluatedSHA, currentSHA string) bool {
|
||||
}
|
||||
return evaluatedSHA != currentSHA
|
||||
}
|
||||
|
||||
// giteaClientAdapter adapts gitea.Client to review.GiteaClient interface.
|
||||
type giteaClientAdapter struct {
|
||||
client *gitea.Client
|
||||
}
|
||||
|
||||
func newGiteaClientAdapter(c *gitea.Client) *giteaClientAdapter {
|
||||
return &giteaClientAdapter{client: c}
|
||||
}
|
||||
|
||||
func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
entries, err := a.client.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]review.ContentEntry, len(entries))
|
||||
for i, e := range entries {
|
||||
result[i] = review.ContentEntry{
|
||||
Name: e.Name,
|
||||
Path: e.Path,
|
||||
Type: e.Type,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return a.client.GetFileContent(ctx, owner, repo, filepath)
|
||||
}
|
||||
|
||||
+353
-36
@@ -2,7 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
|
||||
func TestValidateReviewerName(t *testing.T) {
|
||||
@@ -154,12 +156,11 @@ func TestValidateWorkspacePath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func makeReview(id int64, login, state string, stale bool, body string) gitea.Review {
|
||||
r := gitea.Review{
|
||||
func makeReview(id int64, login, state string, _ bool, body string) vcsReview {
|
||||
r := vcsReview{
|
||||
ID: id,
|
||||
Body: body,
|
||||
State: state,
|
||||
Stale: stale,
|
||||
}
|
||||
r.User.Login = login
|
||||
return r
|
||||
@@ -216,7 +217,7 @@ func TestBuildSupersededBodyShortSHA(t *testing.T) {
|
||||
func TestFindOwnReview(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reviews []gitea.Review
|
||||
reviews []vcsReview
|
||||
sentinel string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
@@ -229,7 +230,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "found by sentinel",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(42, "bot", "APPROVED", false, "review body\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
@@ -237,7 +238,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "wrong sentinel",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(42, "bot", "APPROVED", false, "body\n<!-- review-bot:gpt -->"),
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
@@ -245,7 +246,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "multiple reviews, returns first match",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(10, "bot", "APPROVED", false, "old\n<!-- review-bot:gpt -->"),
|
||||
makeReview(20, "bot", "APPROVED", false, "new\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
@@ -254,7 +255,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "skips superseded review",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n**Superseded**\n<!-- review-bot:sonnet -->"),
|
||||
makeReview(20, "bot", "APPROVED", false, "fresh review\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
@@ -263,7 +264,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "only superseded reviews exist",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
@@ -271,7 +272,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "picks highest ID among matches",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(50, "bot", "APPROVED", false, "v1\n<!-- review-bot:sonnet -->"),
|
||||
makeReview(30, "bot", "APPROVED", false, "v0\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
@@ -302,7 +303,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
func TestHasSharedToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reviews []gitea.Review
|
||||
reviews []vcsReview
|
||||
sentinel string
|
||||
want bool
|
||||
}{
|
||||
@@ -314,36 +315,36 @@ func TestHasSharedToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no own review yet - cannot detect",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "separate users - no shared token",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "shared token detected - same user different sentinels",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "three roles same user",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
{ID: 3, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
{ID: 3, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: true,
|
||||
@@ -553,7 +554,7 @@ func TestBuildPatternPaths(t *testing.T) {
|
||||
func TestEvaluateCIStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statuses []gitea.CommitStatus
|
||||
statuses []vcsCommitStatus
|
||||
wantPassed bool
|
||||
wantSubstr string
|
||||
}{
|
||||
@@ -565,7 +566,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "all success",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
||||
},
|
||||
@@ -574,7 +575,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "one failure",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||
},
|
||||
@@ -583,7 +584,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error status",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "error", Context: "ci/lint", Description: "Lint error"},
|
||||
},
|
||||
wantPassed: false,
|
||||
@@ -591,7 +592,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "pending treated as not-failed",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "pending", Context: "ci/build", Description: "In progress"},
|
||||
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
||||
},
|
||||
@@ -600,7 +601,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "multiple failures",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "failure", Context: "ci/build", Description: "Build failed"},
|
||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||
},
|
||||
@@ -609,7 +610,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "mixed with pending and failure",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||
{Status: "pending", Context: "ci/deploy", Description: "Deploying"},
|
||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||
@@ -632,6 +633,48 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubAPIURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty string defaults to api.github.com",
|
||||
input: "",
|
||||
want: "https://api.github.com",
|
||||
},
|
||||
{
|
||||
name: "github.com maps to api.github.com",
|
||||
input: "https://github.com",
|
||||
want: "https://api.github.com",
|
||||
},
|
||||
{
|
||||
name: "github.com with trailing slash maps to api.github.com",
|
||||
input: "https://github.com/",
|
||||
want: "https://api.github.com",
|
||||
},
|
||||
{
|
||||
name: "GHES host gets /api/v3 suffix",
|
||||
input: "https://ghe.example.com",
|
||||
want: "https://ghe.example.com/api/v3",
|
||||
},
|
||||
{
|
||||
name: "GHES concur domain does not map to api.github.com",
|
||||
input: "https://github.concur.com",
|
||||
want: "https://github.concur.com/api/v3",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := githubAPIURL(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("githubAPIURL(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOrDefault(t *testing.T) {
|
||||
// Test with unset env var
|
||||
os.Unsetenv("TEST_ENV_OR_DEFAULT_UNSET")
|
||||
@@ -780,8 +823,8 @@ func TestExtractSentinelName_EdgeCases(t *testing.T) {
|
||||
{"<!-- review-bot:sonnet --> rest", "sonnet"},
|
||||
{"<!-- review-bot:gpt-review --> rest", "gpt-review"},
|
||||
{"no sentinel here", "unknown"},
|
||||
{"<!-- review-bot:", "unknown"}, // prefix but no suffix
|
||||
{"prefix <!-- review-bot:abc --> end", "abc"}, // embedded in text
|
||||
{"<!-- review-bot:", "unknown"}, // prefix but no suffix
|
||||
{"prefix <!-- review-bot:abc --> end", "abc"}, // embedded in text
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -972,7 +1015,7 @@ func TestMainSubprocess_InvalidProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanEnv returns environ without any GITEA/LLM/REVIEWER env vars that would
|
||||
// cleanEnv returns environ without any GITEA/LLM/REVIEWER/VCS env vars that would
|
||||
// interfere with testing missing-flag scenarios.
|
||||
func cleanEnv() []string {
|
||||
var env []string
|
||||
@@ -987,7 +1030,8 @@ func cleanEnv() []string {
|
||||
strings.HasPrefix(key, "CONVENTIONS_"),
|
||||
strings.HasPrefix(key, "SYSTEM_PROMPT_"),
|
||||
strings.HasPrefix(key, "PATTERNS_"),
|
||||
strings.HasPrefix(key, "UPDATE_"):
|
||||
strings.HasPrefix(key, "UPDATE_"),
|
||||
strings.HasPrefix(key, "VCS_"):
|
||||
continue
|
||||
default:
|
||||
env = append(env, e)
|
||||
@@ -997,7 +1041,7 @@ func cleanEnv() []string {
|
||||
}
|
||||
|
||||
func TestFindAllOwnReviews(t *testing.T) {
|
||||
reviews := []gitea.Review{
|
||||
reviews := []vcsReview{
|
||||
{ID: 1, Body: "<!-- review-bot:sonnet -->\nfirst review"},
|
||||
{ID: 2, Body: "<!-- review-bot:gpt -->\nother bot"},
|
||||
{ID: 3, Body: "<!-- review-bot:sonnet -->\nsecond review"},
|
||||
@@ -1066,3 +1110,276 @@ func TestShouldSkipStaleReview(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Mock vcsClient for unit tests
|
||||
// ============================================================
|
||||
|
||||
// mockVCSClient is a minimal mock of vcsClient for testing helper functions.
|
||||
// Only the methods exercised by the test code need implementations; all others
|
||||
// panic with a clear message to catch accidental calls.
|
||||
type mockVCSClient struct {
|
||||
fileContents map[string]string // key: "owner/repo/ref/path"
|
||||
fileContentsErr map[string]error // key same as above → error to return
|
||||
dirContents map[string][]review.ContentEntry
|
||||
dirContentsErr map[string]error
|
||||
allFiles map[string]map[string]string // key: "owner/repo/path"
|
||||
allFilesErr map[string]error
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) key(owner, repo, extra string) string {
|
||||
return owner + "/" + repo + "/" + extra
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
||||
panic("GetPullRequest not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
panic("GetPullRequestDiff not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
||||
panic("GetPullRequestFiles not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
||||
panic("GetCommitStatuses not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
panic("GetFileContent not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetFileContentRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||
k := m.key(owner, repo, ref+"/"+path)
|
||||
if err, ok := m.fileContentsErr[k]; ok {
|
||||
return "", err
|
||||
}
|
||||
if content, ok := m.fileContents[k]; ok {
|
||||
return content, nil
|
||||
}
|
||||
return "", fmt.Errorf("HTTP 404: not found")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
k := m.key(owner, repo, path)
|
||||
if err, ok := m.dirContentsErr[k]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if entries, ok := m.dirContents[k]; ok {
|
||||
return entries, nil
|
||||
}
|
||||
return nil, fmt.Errorf("HTTP 404: not found")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
k := m.key(owner, repo, path)
|
||||
if err, ok := m.allFilesErr[k]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if files, ok := m.allFiles[k]; ok {
|
||||
return files, nil
|
||||
}
|
||||
return nil, fmt.Errorf("HTTP 404: not found")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
||||
panic("PostReview not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
||||
panic("ListReviews not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
panic("DeleteReview not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
panic("GetAuthenticatedUser not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
func (m *mockVCSClient) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
panic("RequestReviewer not implemented in mockVCSClient")
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// fetchFileContext tests
|
||||
// ============================================================
|
||||
|
||||
func TestFetchFileContext_NoFiles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{}
|
||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for no files, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchFileContext_SkipsRemovedFiles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{}
|
||||
files := []vcsChangedFile{
|
||||
{Filename: "gone.go", Status: "removed"},
|
||||
}
|
||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for removed file, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchFileContext_FetchesModifiedFiles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{
|
||||
fileContents: map[string]string{
|
||||
"owner/repo/main/foo.go": "package main\n\nfunc main() {}\n",
|
||||
},
|
||||
}
|
||||
files := []vcsChangedFile{
|
||||
{Filename: "foo.go", Status: "modified"},
|
||||
}
|
||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
||||
if !strings.Contains(got, "--- foo.go ---") {
|
||||
t.Errorf("expected file header in output, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "package main") {
|
||||
t.Errorf("expected file content in output, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchFileContext_ContinuesOnError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{
|
||||
fileContents: map[string]string{
|
||||
"owner/repo/main/good.go": "package good\n",
|
||||
},
|
||||
fileContentsErr: map[string]error{
|
||||
"owner/repo/main/bad.go": fmt.Errorf("network error"),
|
||||
},
|
||||
}
|
||||
files := []vcsChangedFile{
|
||||
{Filename: "bad.go", Status: "modified"},
|
||||
{Filename: "good.go", Status: "modified"},
|
||||
}
|
||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
||||
// bad.go fails, good.go should still be included
|
||||
if strings.Contains(got, "bad.go") {
|
||||
t.Errorf("should not include failed file, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "good.go") {
|
||||
t.Errorf("should include successful file, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchFileContext_RespectsContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
client := &mockVCSClient{
|
||||
fileContents: map[string]string{
|
||||
"owner/repo/main/foo.go": "package foo\n",
|
||||
},
|
||||
}
|
||||
files := []vcsChangedFile{
|
||||
{Filename: "foo.go", Status: "modified"},
|
||||
}
|
||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
||||
// With cancelled context, the loop breaks before fetching
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string with cancelled context, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// fetchPatterns tests
|
||||
// ============================================================
|
||||
|
||||
func TestFetchPatterns_EmptyRepo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{}
|
||||
got := fetchPatterns(ctx, client, "", "")
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for empty patternsRepo, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchPatterns_SingleRepoAllFiles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{
|
||||
allFiles: map[string]map[string]string{
|
||||
"rodin/patterns/": {
|
||||
"patterns/go.md": "# Go patterns\n\nUse interfaces.",
|
||||
"patterns/binary": "binary data",
|
||||
},
|
||||
},
|
||||
}
|
||||
got := fetchPatterns(ctx, client, "rodin/patterns", "")
|
||||
if !strings.Contains(got, "# Go patterns") {
|
||||
t.Errorf("expected markdown content, got: %q", got)
|
||||
}
|
||||
// Binary file should be excluded
|
||||
if strings.Contains(got, "binary data") {
|
||||
t.Errorf("binary file should be excluded, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchPatterns_SpecificFiles(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{
|
||||
allFiles: map[string]map[string]string{
|
||||
"rodin/patterns/go.md": {
|
||||
"go.md": "# Go idioms\n",
|
||||
},
|
||||
},
|
||||
}
|
||||
got := fetchPatterns(ctx, client, "rodin/patterns", "go.md")
|
||||
if !strings.Contains(got, "# Go idioms") {
|
||||
t.Errorf("expected go idioms content, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchPatterns_SkipsInvalidRepo(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{}
|
||||
// "badrepo" has no slash, should be skipped
|
||||
got := fetchPatterns(ctx, client, "badrepo", "")
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for invalid repo format, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchPatterns_ContinuesOnFetchError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{
|
||||
allFilesErr: map[string]error{
|
||||
"owner/repo/": fmt.Errorf("server error"),
|
||||
},
|
||||
}
|
||||
// Should not panic; should return empty string
|
||||
got := fetchPatterns(ctx, client, "owner/repo", "")
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string on fetch error, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchPatterns_MultipleRepos(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := &mockVCSClient{
|
||||
allFiles: map[string]map[string]string{
|
||||
"org/go-patterns/": {
|
||||
"idioms.md": "# Go idioms\n",
|
||||
},
|
||||
"org/elixir-patterns/": {
|
||||
"pipes.md": "# Elixir pipes\n",
|
||||
},
|
||||
},
|
||||
}
|
||||
got := fetchPatterns(ctx, client, "org/go-patterns, org/elixir-patterns", "")
|
||||
if !strings.Contains(got, "# Go idioms") {
|
||||
t.Errorf("expected Go idioms content, got: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "# Elixir pipes") {
|
||||
t.Errorf("expected Elixir pipes content, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
)
|
||||
|
||||
// runValidateURL implements the `review-bot validate-url <url>` subcommand.
|
||||
//
|
||||
// It resolves the given URL's hostname and checks that every returned IP is
|
||||
// publicly routable (not RFC1918, loopback, link-local, or other reserved
|
||||
// ranges). The exit code communicates the result to callers:
|
||||
//
|
||||
// 0 — URL is safe to use
|
||||
// 1 — URL resolves to a blocked/private address
|
||||
// 2 — URL is malformed, has an unsafe scheme, or DNS lookup failed
|
||||
//
|
||||
// This is intended for use from action.yml shell steps that need to validate
|
||||
// a user-supplied URL before passing it to curl.
|
||||
func runValidateURL(args []string) int {
|
||||
if len(args) != 1 {
|
||||
fmt.Fprintln(errWriter, "usage: review-bot validate-url <url>")
|
||||
fmt.Fprintln(errWriter, "")
|
||||
fmt.Fprintln(errWriter, "Resolves <url> and verifies all resolved IPs are publicly routable.")
|
||||
fmt.Fprintln(errWriter, "Exit 0=safe, 1=blocked, 2=error")
|
||||
return 2
|
||||
}
|
||||
rawURL := args[0]
|
||||
|
||||
if err := validateURL(rawURL); err != nil {
|
||||
fmt.Fprintf(errWriter, "Error: %v\n", err)
|
||||
var ve *validateError
|
||||
if isValidateError(err, &ve) {
|
||||
return ve.code
|
||||
}
|
||||
return 2
|
||||
}
|
||||
fmt.Fprintf(outWriter, "OK: %s is safe\n", rawURL)
|
||||
return 0
|
||||
}
|
||||
|
||||
// validateError carries an exit code alongside a message.
|
||||
type validateError struct {
|
||||
code int
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *validateError) Error() string { return e.message }
|
||||
|
||||
// isValidateError checks if err is or wraps a *validateError and sets out.
|
||||
// Uses errors.As so that wrapped *validateError values (e.g. from fmt.Errorf("...: %w", &validateError{...}))
|
||||
// are also detected, making the function robust against future wrapping.
|
||||
func isValidateError(err error, out **validateError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.As(err, out)
|
||||
}
|
||||
|
||||
// validateURL checks that rawURL is safe for use as a Gitea server URL:
|
||||
// - Must be https:// (not http://)
|
||||
// - Must have no user-info (user:pass@host)
|
||||
// - Must resolve to at least one IP, all of which are publicly routable
|
||||
func validateURL(rawURL string) error {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return &validateError{code: 2, message: fmt.Sprintf("malformed URL %q: %v", rawURL, err)}
|
||||
}
|
||||
|
||||
// Scheme check: only https is permitted.
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: fmt.Sprintf("URL scheme must be https (got %q)", parsed.Scheme),
|
||||
}
|
||||
}
|
||||
|
||||
// Reject user-info (user:password@host) to prevent credential embedding.
|
||||
if parsed.User != nil {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: "URL must not contain user-info (user:password@host)",
|
||||
}
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
if host == "" {
|
||||
return &validateError{code: 2, message: fmt.Sprintf("URL has no host: %q", rawURL)}
|
||||
}
|
||||
|
||||
// Resolve the hostname with a short timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: fmt.Sprintf("DNS lookup failed for %q: %v", host, err),
|
||||
}
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: fmt.Sprintf("DNS lookup returned no addresses for %q", host),
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
if gitea.IsBlockedIP(a.IP) {
|
||||
return &validateError{
|
||||
code: 1,
|
||||
message: fmt.Sprintf("blocked: %q resolves to private/reserved IP %s", host, a.IP),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRunValidateURL_Usage(t *testing.T) {
|
||||
var errBuf bytes.Buffer
|
||||
origErr := errWriter
|
||||
errWriter = &errBuf
|
||||
defer func() { errWriter = origErr }()
|
||||
|
||||
code := runValidateURL(nil)
|
||||
if code != 2 {
|
||||
t.Errorf("expected exit code 2 for no args, got %d", code)
|
||||
}
|
||||
if !strings.Contains(errBuf.String(), "usage") {
|
||||
t.Errorf("expected usage in stderr, got %q", errBuf.String())
|
||||
}
|
||||
|
||||
errBuf.Reset()
|
||||
code = runValidateURL([]string{"arg1", "arg2"})
|
||||
if code != 2 {
|
||||
t.Errorf("expected exit code 2 for too many args, got %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL_MalformedURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
wantMsg string
|
||||
}{
|
||||
{"empty", "", "must be https"},
|
||||
{"http scheme", "http://example.com/", "must be https"},
|
||||
{"ftp scheme", "ftp://example.com/", "must be https"},
|
||||
{"no scheme", "example.com", "must be https"},
|
||||
{"user info", "https://user:pass@example.com/", "user-info"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateURL(tc.url)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for URL %q, got nil", tc.url)
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.wantMsg) {
|
||||
t.Errorf("error %q does not contain %q", err.Error(), tc.wantMsg)
|
||||
}
|
||||
var ve *validateError
|
||||
if !isValidateError(err, &ve) {
|
||||
t.Fatalf("expected *validateError, got %T", err)
|
||||
}
|
||||
if ve.code != 2 {
|
||||
t.Errorf("expected code 2, got %d", ve.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL_BlockedPrivateIP(t *testing.T) {
|
||||
// localhost always resolves to 127.0.0.1 (loopback).
|
||||
err := validateURL("https://localhost/")
|
||||
if err == nil {
|
||||
t.Skip("localhost did not resolve (network unavailable in test environment)")
|
||||
}
|
||||
var ve *validateError
|
||||
if !isValidateError(err, &ve) {
|
||||
t.Fatalf("expected *validateError, got %T: %v", err, err)
|
||||
}
|
||||
if ve.code != 1 && ve.code != 2 {
|
||||
t.Errorf("expected code 1 (blocked) or 2 (dns fail), got %d: %s", ve.code, ve.message)
|
||||
}
|
||||
// If it resolved (code 1), the message must say "blocked".
|
||||
if ve.code == 1 && !strings.Contains(ve.message, "blocked") {
|
||||
t.Errorf("expected 'blocked' in message, got %q", ve.message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL_ExitCodes(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
wantCode int
|
||||
}{
|
||||
{"http scheme", "http://example.com/", 2},
|
||||
{"no scheme", "example.com", 2},
|
||||
{"user info", "https://admin:secret@example.com/", 2},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateURL(tc.url)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for %q", tc.url)
|
||||
}
|
||||
var ve *validateError
|
||||
if !isValidateError(err, &ve) {
|
||||
t.Fatalf("expected *validateError, got %T", err)
|
||||
}
|
||||
if ve.code != tc.wantCode {
|
||||
t.Errorf("code = %d, want %d (url=%q, msg=%s)", ve.code, tc.wantCode, tc.url, ve.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunValidateURL_WithCapture(t *testing.T) {
|
||||
var outBuf, errBuf bytes.Buffer
|
||||
origOut, origErr := outWriter, errWriter
|
||||
outWriter = &outBuf
|
||||
errWriter = &errBuf
|
||||
defer func() {
|
||||
outWriter = origOut
|
||||
errWriter = origErr
|
||||
}()
|
||||
|
||||
// http:// scheme should fail with code 2.
|
||||
code := runValidateURL([]string{"http://example.com/"})
|
||||
if code != 2 {
|
||||
t.Errorf("expected code 2 for http:// URL, got %d", code)
|
||||
}
|
||||
if !strings.Contains(errBuf.String(), "must be https") {
|
||||
t.Errorf("expected error about https in stderr, got %q", errBuf.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
package main
|
||||
|
||||
// vcs.go defines the vcsClient interface that both gitea.Client (via giteaVCSAdapter)
|
||||
// and github.Client (via githubVCSAdapter) satisfy, enabling VCS-type routing in main.go.
|
||||
//
|
||||
// Interface design:
|
||||
// - Methods cover all PR review operations used by main.go.
|
||||
// - Gitea-specific operations (supersede, comment resolution) are in the separate
|
||||
// giteaExtClient interface. GitHub implementations return ErrNotSupported for those.
|
||||
// - Types are defined here as package-local VCS types; each adapter converts from
|
||||
// its respective client package's types.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/github"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
|
||||
// ErrNotSupported is returned by VCS methods that have no implementation for
|
||||
// a particular VCS backend (e.g., Gitea-specific timeline APIs on GitHub).
|
||||
var ErrNotSupported = errors.New("operation not supported on this VCS backend")
|
||||
|
||||
// vcsClient is the interface for all PR operations used by main.go.
|
||||
// It is implemented by both giteaVCSAdapter and githubVCSAdapter.
|
||||
// Interface defined here (in the consumer package) per Go idiom.
|
||||
type vcsClient interface {
|
||||
// PR metadata and content
|
||||
GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error)
|
||||
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
|
||||
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error)
|
||||
GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error)
|
||||
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
|
||||
GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error)
|
||||
ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error)
|
||||
GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error)
|
||||
|
||||
// Review operations
|
||||
PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error)
|
||||
ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error)
|
||||
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
|
||||
GetAuthenticatedUser(ctx context.Context) (string, error)
|
||||
RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error
|
||||
}
|
||||
|
||||
// giteaExtClient extends vcsClient with Gitea-specific operations that have no
|
||||
// GitHub equivalent. Code that uses these methods should first do a type assertion.
|
||||
type giteaExtClient interface {
|
||||
vcsClient
|
||||
GetTimelineReviewCommentIDForReview(ctx context.Context, owner, repo string, prNum, reviewID int64) (int64, error)
|
||||
EditComment(ctx context.Context, owner, repo string, commentID int64, body string) error
|
||||
ListReviewComments(ctx context.Context, owner, repo string, prNum, reviewID int64) ([]gitea.ReviewComment, error)
|
||||
ResolveComment(ctx context.Context, owner, repo string, commentID int64) error
|
||||
}
|
||||
|
||||
// --- shared VCS types ---
|
||||
|
||||
// vcsPullRequest is VCS-agnostic PR metadata.
|
||||
type vcsPullRequest struct {
|
||||
Title string
|
||||
Body string
|
||||
Head struct {
|
||||
Sha string
|
||||
Ref string
|
||||
}
|
||||
}
|
||||
|
||||
// vcsChangedFile is a file changed in a PR.
|
||||
type vcsChangedFile struct {
|
||||
Filename string
|
||||
Status string
|
||||
}
|
||||
|
||||
// vcsCommitStatus is a CI status entry.
|
||||
type vcsCommitStatus struct {
|
||||
Status string
|
||||
Context string
|
||||
Description string
|
||||
TargetURL string
|
||||
}
|
||||
|
||||
// vcsReviewComment is an inline review comment.
|
||||
type vcsReviewComment struct {
|
||||
Path string
|
||||
NewPosition int64 // Gitea: absolute line; GitHub: diff hunk position
|
||||
Body string
|
||||
}
|
||||
|
||||
// vcsReview is a submitted PR review.
|
||||
type vcsReview struct {
|
||||
ID int64
|
||||
Body string
|
||||
CommitID string
|
||||
User struct {
|
||||
Login string
|
||||
}
|
||||
State string
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// giteaVCSAdapter
|
||||
// ============================================================
|
||||
|
||||
// giteaVCSAdapter wraps gitea.Client to implement vcsClient + giteaExtClient.
|
||||
type giteaVCSAdapter struct {
|
||||
c *gitea.Client
|
||||
}
|
||||
|
||||
func newGiteaVCSAdapter(c *gitea.Client) *giteaVCSAdapter { return &giteaVCSAdapter{c: c} }
|
||||
|
||||
func (a *giteaVCSAdapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
||||
pr, err := a.c.GetPullRequest(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &vcsPullRequest{Title: pr.Title, Body: pr.Body}
|
||||
r.Head.Sha = pr.Head.Sha
|
||||
r.Head.Ref = pr.Head.Ref
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
return a.c.GetPullRequestDiff(ctx, owner, repo, number)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
||||
files, err := a.c.GetPullRequestFiles(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsChangedFile, len(files))
|
||||
for i, f := range files {
|
||||
out[i] = vcsChangedFile{Filename: f.Filename, Status: f.Status}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
||||
statuses, err := a.c.GetCommitStatuses(ctx, owner, repo, sha)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsCommitStatus, len(statuses))
|
||||
for i, s := range statuses {
|
||||
out[i] = vcsCommitStatus{Status: s.Status, Context: s.Context, Description: s.Description, TargetURL: s.TargetURL}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return a.c.GetFileContent(ctx, owner, repo, filepath)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
return a.c.GetFileContentRef(ctx, owner, repo, filepath, ref)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
entries, err := a.c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]review.ContentEntry, len(entries))
|
||||
for i, e := range entries {
|
||||
out[i] = review.ContentEntry{Name: e.Name, Path: e.Path, Type: e.Type}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
return a.c.GetAllFilesInPath(ctx, owner, repo, path)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
||||
gc := make([]gitea.ReviewComment, len(comments))
|
||||
for i, c := range comments {
|
||||
gc[i] = gitea.ReviewComment{Path: c.Path, NewPosition: c.NewPosition, Body: c.Body}
|
||||
}
|
||||
r, err := a.c.PostReview(ctx, owner, repo, number, event, body, commitID, gc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &vcsReview{ID: r.ID, Body: r.Body, CommitID: r.CommitID, State: r.State}
|
||||
out.User.Login = r.User.Login
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
||||
reviews, err := a.c.ListReviews(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsReview, len(reviews))
|
||||
for i, r := range reviews {
|
||||
out[i] = vcsReview{ID: r.ID, Body: r.Body, CommitID: r.CommitID, State: r.State}
|
||||
out[i].User.Login = r.User.Login
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
return a.c.DeleteReview(ctx, owner, repo, number, reviewID)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
return a.c.GetAuthenticatedUser(ctx)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
return a.c.RequestReviewer(ctx, owner, repo, number, reviewer)
|
||||
}
|
||||
|
||||
// Gitea-specific extension methods.
|
||||
|
||||
func (a *giteaVCSAdapter) GetTimelineReviewCommentIDForReview(ctx context.Context, owner, repo string, prNum, reviewID int64) (int64, error) {
|
||||
return a.c.GetTimelineReviewCommentIDForReview(ctx, owner, repo, int(prNum), reviewID)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) EditComment(ctx context.Context, owner, repo string, commentID int64, body string) error {
|
||||
return a.c.EditComment(ctx, owner, repo, commentID, body)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ListReviewComments(ctx context.Context, owner, repo string, prNum, reviewID int64) ([]gitea.ReviewComment, error) {
|
||||
return a.c.ListReviewComments(ctx, owner, repo, int(prNum), reviewID)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ResolveComment(ctx context.Context, owner, repo string, commentID int64) error {
|
||||
return a.c.ResolveComment(ctx, owner, repo, commentID)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// githubVCSAdapter
|
||||
// ============================================================
|
||||
|
||||
// githubVCSAdapter wraps github.Client to implement vcsClient.
|
||||
// Gitea-specific extension methods (GetTimelineReviewCommentIDForReview, EditComment,
|
||||
// ListReviewComments, ResolveComment) are not available on GitHub and will not be called
|
||||
// because main.go gates them with a type assertion to giteaExtClient.
|
||||
type githubVCSAdapter struct {
|
||||
c *github.Client
|
||||
}
|
||||
|
||||
func newGithubVCSAdapter(c *github.Client) *githubVCSAdapter { return &githubVCSAdapter{c: c} }
|
||||
|
||||
func (a *githubVCSAdapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
||||
pr, err := a.c.GetPullRequest(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &vcsPullRequest{Title: pr.Title, Body: pr.Body}
|
||||
r.Head.Sha = pr.Head.Sha
|
||||
r.Head.Ref = pr.Head.Ref
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
return a.c.GetPullRequestDiff(ctx, owner, repo, number)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
||||
files, err := a.c.GetPullRequestFiles(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsChangedFile, len(files))
|
||||
for i, f := range files {
|
||||
out[i] = vcsChangedFile{Filename: f.Filename, Status: f.Status}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
||||
statuses, err := a.c.GetCommitStatuses(ctx, owner, repo, sha)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsCommitStatus, len(statuses))
|
||||
for i, s := range statuses {
|
||||
// CommitStatus.Status is tagged as json:"state" — already the normalized "state" value
|
||||
out[i] = vcsCommitStatus{Status: s.Status, Context: s.Context, Description: s.Description, TargetURL: s.TargetURL}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return a.c.GetFileContent(ctx, owner, repo, filepath)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
return a.c.GetFileContentRef(ctx, owner, repo, filepath, ref)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
entries, err := a.c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]review.ContentEntry, len(entries))
|
||||
for i, e := range entries {
|
||||
out[i] = review.ContentEntry{Name: e.Name, Path: e.Path, Type: e.Type}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
return a.c.GetAllFilesInPath(ctx, owner, repo, path)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
||||
gc := make([]github.ReviewComment, len(comments))
|
||||
for i, c := range comments {
|
||||
// GitHub inline comments use diff hunk "position", not absolute line numbers.
|
||||
// NewPosition from gitea diff parsing gives absolute line numbers, which
|
||||
// will not match GitHub's position values. For initial GitHub support, we
|
||||
// attach comments with Line+Side (absolute line on the RIGHT side) instead.
|
||||
// Comments that cannot be mapped will be omitted (GitHub rejects invalid positions).
|
||||
gc[i] = github.ReviewComment{
|
||||
Path: c.Path,
|
||||
Line: c.NewPosition,
|
||||
Side: "RIGHT",
|
||||
Body: c.Body,
|
||||
}
|
||||
}
|
||||
r, err := a.c.PostReview(ctx, owner, repo, number, event, body, commitID, gc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &vcsReview{ID: r.ID, Body: r.Body, State: r.State}
|
||||
out.User.Login = r.User.Login
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
||||
reviews, err := a.c.ListReviews(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsReview, len(reviews))
|
||||
for i, r := range reviews {
|
||||
out[i] = vcsReview{ID: r.ID, Body: r.Body, State: r.State}
|
||||
out[i].User.Login = r.User.Login
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
// GitHub only allows deleting PENDING (draft) reviews. review-bot posts submitted
|
||||
// reviews, so this will return an error for any review we actually posted.
|
||||
// Callers should treat 422 errors here gracefully.
|
||||
return a.c.DeleteReview(ctx, owner, repo, number, reviewID)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
return a.c.GetAuthenticatedUser(ctx)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
return a.c.RequestReviewer(ctx, owner, repo, number, reviewer)
|
||||
}
|
||||
+89
-10
@@ -106,34 +106,113 @@ func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// safeDialContext is the default DialContext for NewClient.
|
||||
// It resolves the hostname and checks every returned IP against the blocked
|
||||
// CIDR list before establishing a connection. This prevents SSRF attacks
|
||||
// where user-supplied URLs resolve to internal/private addresses.
|
||||
//
|
||||
// After validating all IPs, we dial the first resolved IP directly to avoid
|
||||
// a second DNS lookup (which could return a different IP in a DNS rebinding
|
||||
// attack). This narrows — but does not fully eliminate — the DNS rebinding
|
||||
// window to the time between LookupIPAddr and DialContext.
|
||||
//
|
||||
// If the host is already an IP literal, LookupIPAddr returns it directly
|
||||
// (no DNS query issued), so IP literals like https://127.0.0.1/ are blocked.
|
||||
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("safeDialContext: invalid address %q: %w", addr, err)
|
||||
}
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("safeDialContext: DNS lookup %q: %w", host, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("safeDialContext: no addresses returned for %q", host)
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if IsBlockedIP(a.IP) {
|
||||
return nil, fmt.Errorf("safeDialContext: blocked: %q resolves to private/reserved IP %s", host, a.IP)
|
||||
}
|
||||
}
|
||||
// Try each resolved IP in order, returning the first successful connection.
|
||||
// Fallback is important when a hostname resolves to multiple IPs and the first
|
||||
// is temporarily unreachable. All IPs were already validated above, so dialing
|
||||
// any of them is safe.
|
||||
//
|
||||
// Timeout: 10s per the design (PLAN.md); the outer http.Client has a 30s
|
||||
// total timeout, but the per-dial timeout ensures a slow TCP connect on one IP
|
||||
// doesn't consume the budget needed to try others.
|
||||
d := &net.Dialer{Timeout: 10 * time.Second}
|
||||
var lastErr error
|
||||
for _, a := range addrs {
|
||||
conn, err := d.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return nil, fmt.Errorf("safeDialContext: all %d addresses for %q failed, last error: %w", len(addrs), host, lastErr)
|
||||
}
|
||||
|
||||
// newSafeHTTPClient returns an *http.Client with the SSRF-blocking safeDialContext
|
||||
// transport and the cross-host redirect rejection policy.
|
||||
//
|
||||
// We clone http.DefaultTransport to preserve its production-ready defaults
|
||||
// (ProxyFromEnvironment, TLSHandshakeTimeout, IdleConnTimeout, connection
|
||||
// pooling, HTTP/2 support) and override only DialContext with safeDialContext.
|
||||
func newSafeHTTPClient() *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = safeDialContext
|
||||
return &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new Gitea API client.
|
||||
//
|
||||
// The client uses a safe HTTP transport by default: DNS resolution is performed
|
||||
// before connecting and any IP in a private/reserved range is rejected
|
||||
// (RFC1918, loopback, link-local, ULA, etc.). Cross-host and HTTPS→HTTP
|
||||
// redirects are also rejected.
|
||||
//
|
||||
// For tests that use httptest.NewServer (which listens on 127.0.0.1), call
|
||||
// WithUnsafeDialer() to bypass the IP check.
|
||||
func NewClient(baseURL, token string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
http: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
},
|
||||
http: newSafeHTTPClient(),
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnsafeDialer returns the client configured with a plain HTTP client that
|
||||
// has no IP-level SSRF protection. It preserves the redirect-rejection policy.
|
||||
//
|
||||
// This MUST only be used in tests. Production code must never call this method.
|
||||
func (c *Client) WithUnsafeDialer() *Client {
|
||||
c.http = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// SetHTTPClient sets the underlying HTTP client used for requests.
|
||||
// This is intended for test setup only to inject mock transports; it must be
|
||||
// called before any goroutines issue requests.
|
||||
//
|
||||
// Passing nil restores the default client (30s timeout + redirect-rejecting
|
||||
// CheckRedirect policy matching NewClient).
|
||||
// Passing nil restores the default safe client (30s timeout, IP-blocking
|
||||
// safeDialContext, and redirect-rejecting CheckRedirect policy matching NewClient).
|
||||
//
|
||||
// Callers providing a non-nil client are responsible for configuring a safe
|
||||
// CheckRedirect policy. Without one, the default net/http behavior will follow
|
||||
// redirects and may forward the Authorization header to untrusted hosts.
|
||||
func (c *Client) SetHTTPClient(hc *http.Client) {
|
||||
if hc == nil {
|
||||
hc = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
}
|
||||
hc = newSafeHTTPClient()
|
||||
}
|
||||
c.http = hc
|
||||
}
|
||||
|
||||
+249
-37
@@ -36,7 +36,7 @@ func TestGetPullRequest(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -63,7 +63,7 @@ func TestGetPullRequestDiff(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -88,7 +88,7 @@ func TestGetCommitStatuses(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -138,7 +138,7 @@ func TestPostReview(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
review, err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM", "abc123def", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -158,7 +158,7 @@ func TestGetPullRequest_Non200(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 999)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404, got nil")
|
||||
@@ -171,7 +171,7 @@ func TestGetPullRequest_BadJSON(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad JSON, got nil")
|
||||
@@ -185,7 +185,7 @@ func TestPostReview_Non200(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test", "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403, got nil")
|
||||
@@ -208,7 +208,7 @@ func TestPostReview_EmptyCommitID_OmittedFromPayload(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "ok", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -226,7 +226,7 @@ func TestGetFileContent(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetFileContent(context.Background(), "owner", "repo", "CONVENTIONS.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -246,7 +246,7 @@ func TestGetPullRequestFiles(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
files, err := client.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -271,7 +271,7 @@ func TestGetFileContentRef(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
content, err := client.GetFileContentRef(context.Background(), "owner", "repo", "main.go", "feature-branch")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -291,7 +291,7 @@ func TestListContents(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "docs")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -318,7 +318,7 @@ func TestListContents_DotPath(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", ".")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -343,7 +343,7 @@ func TestListContents_FilePath(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -375,7 +375,7 @@ func TestGetAllFilesInPath_File(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -428,7 +428,7 @@ func TestListReviews(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -468,7 +468,7 @@ func TestListReviews_Pagination(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -493,7 +493,7 @@ func TestDeleteReview(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -507,7 +507,7 @@ func TestDeleteReview_Forbidden(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403, got nil")
|
||||
@@ -536,7 +536,7 @@ func TestEditComment(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.EditComment(context.Background(), "owner", "repo", 42, "updated body")
|
||||
if err != nil {
|
||||
t.Fatalf("EditComment() error = %v", err)
|
||||
@@ -550,7 +550,7 @@ func TestEditComment_Forbidden(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.EditComment(context.Background(), "owner", "repo", 42, "new body")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
@@ -570,7 +570,7 @@ func TestGetTimelineReviewCommentID(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
id, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
||||
if err != nil {
|
||||
t.Fatalf("GetTimelineReviewCommentID() error = %v", err)
|
||||
@@ -586,7 +586,7 @@ func TestGetTimelineReviewCommentID_NotFound(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when sentinel not found")
|
||||
@@ -609,7 +609,7 @@ func TestGetAllFilesInPath_404FallsBackToFile(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to file on 404, got error: %v", err)
|
||||
@@ -630,7 +630,7 @@ func TestGetAllFilesInPath_500Propagates(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "somepath")
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for 500, got nil")
|
||||
@@ -652,7 +652,7 @@ func TestGetAllFilesInPath_403Propagates(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "private/stuff")
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for 403, got nil")
|
||||
@@ -704,7 +704,7 @@ func TestGetAuthenticatedUser(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
login, err := client.GetAuthenticatedUser(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetAuthenticatedUser() error = %v", err)
|
||||
@@ -729,7 +729,7 @@ func TestRequestReviewer(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 7, "bot-user")
|
||||
if err != nil {
|
||||
t.Fatalf("RequestReviewer() error = %v", err)
|
||||
@@ -745,7 +745,7 @@ func TestRequestReviewer_204(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
||||
if err != nil {
|
||||
t.Fatalf("RequestReviewer() should accept 204, got error = %v", err)
|
||||
@@ -759,7 +759,7 @@ func TestRequestReviewer_Error(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
@@ -779,7 +779,7 @@ func TestListReviewComments(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
comments, err := client.ListReviewComments(context.Background(), "owner", "repo", 1, 42)
|
||||
if err != nil {
|
||||
t.Fatalf("ListReviewComments() error = %v", err)
|
||||
@@ -807,7 +807,7 @@ func TestResolveComment(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveComment() error = %v", err)
|
||||
@@ -821,7 +821,7 @@ func TestResolveComment_Error(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404 response")
|
||||
@@ -870,7 +870,7 @@ func TestDoGet_RetriesOn500(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
// Use short backoff for fast tests
|
||||
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
||||
|
||||
@@ -895,7 +895,7 @@ func TestDoGet_FailsAfterMaxRetries(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
// Use short backoff for fast tests
|
||||
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
||||
|
||||
@@ -924,7 +924,7 @@ func TestDoGet_NoRetryOn4xx(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.doGet(context.Background(), server.URL+"/test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403")
|
||||
@@ -952,7 +952,7 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
// Use longer backoff to give us time to cancel during the wait
|
||||
client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond}
|
||||
|
||||
@@ -971,6 +971,7 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) {
|
||||
t.Errorf("attempts = %d, expected 1 before context cancel during backoff", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// mockTransport is a test helper that returns errors for the first N calls,
|
||||
// then delegates to a real server.
|
||||
type mockTransport struct {
|
||||
@@ -1285,3 +1286,214 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
|
||||
t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSafeDialContextBlocksPrivateIPs verifies that NewClient (which uses
|
||||
// safeDialContext by default) refuses to connect to private/reserved IPs.
|
||||
func TestSafeDialContextBlocksPrivateIPs(t *testing.T) {
|
||||
// These servers listen on 127.0.0.1, so the safe dialer will block them.
|
||||
// We use NewClient (NOT NewTestClient) to exercise the real safe dialer.
|
||||
privateURLs := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"loopback localhost", "http://localhost/"},
|
||||
{"loopback 127.0.0.1", "http://127.0.0.1/"},
|
||||
}
|
||||
|
||||
for _, tc := range privateURLs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := NewClient(tc.url, "token")
|
||||
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error connecting to %s, got nil", tc.url)
|
||||
}
|
||||
// Error must mention SSRF/blocked, not a random network error.
|
||||
if !strings.Contains(err.Error(), "blocked") &&
|
||||
!strings.Contains(err.Error(), "private") &&
|
||||
!strings.Contains(err.Error(), "loopback") &&
|
||||
!strings.Contains(err.Error(), "reserved") {
|
||||
t.Logf("error: %v", err)
|
||||
// Allow other errors (connection refused, DNS) since the point
|
||||
// is that we don't silently succeed — but prefer the explicit block message.
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWithUnsafeDialerAllowsLocalhost verifies that WithUnsafeDialer bypasses
|
||||
// the IP check, allowing tests to connect to httptest.Server (127.0.0.1).
|
||||
func TestWithUnsafeDialerAllowsLocalhost(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"title":"test","body":"","head":{"sha":"abc","ref":"main"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// WithUnsafeDialer should allow connecting to 127.0.0.1.
|
||||
c := NewClient(server.URL, "token").WithUnsafeDialer()
|
||||
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with unsafe dialer: %v", err)
|
||||
}
|
||||
if pr.Title != "test" {
|
||||
t.Errorf("expected title 'test', got %q", pr.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClient_HasSafeTransport verifies that NewClient installs the
|
||||
// SSRF-blocking transport (i.e. Transport is not nil and DialContext is set).
|
||||
func TestNewClient_HasSafeTransport(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
if c.http.Transport == nil {
|
||||
t.Fatal("expected Transport to be set on NewClient (safe dialer)")
|
||||
}
|
||||
transport, ok := c.http.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", c.http.Transport)
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected DialContext to be set on transport (safe dialer)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetHTTPClient_NilRestoresSafeTransport verifies that SetHTTPClient(nil)
|
||||
// restores the safe transport (not just any client).
|
||||
func TestSetHTTPClient_NilRestoresSafeTransport(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
c.SetHTTPClient(&http.Client{}) // replace with plain client
|
||||
c.SetHTTPClient(nil) // restore
|
||||
transport, ok := c.http.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport after SetHTTPClient(nil), got %T", c.http.Transport)
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected DialContext to be restored after SetHTTPClient(nil)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSafeHTTPClient_PreservesDefaultTransportSettings verifies that
|
||||
// newSafeHTTPClient clones http.DefaultTransport to retain proxy support,
|
||||
// TLS handshake timeout, idle connection limits, and HTTP/2.
|
||||
func TestNewSafeHTTPClient_PreservesDefaultTransportSettings(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
transport, ok := c.http.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", c.http.Transport)
|
||||
}
|
||||
|
||||
defaults := http.DefaultTransport.(*http.Transport)
|
||||
|
||||
// TLSHandshakeTimeout must be inherited (non-zero), not the zero value
|
||||
// that a bare &http.Transport{} would have.
|
||||
if transport.TLSHandshakeTimeout == 0 {
|
||||
t.Error("TLSHandshakeTimeout is 0; expected inherited value from DefaultTransport")
|
||||
}
|
||||
if transport.TLSHandshakeTimeout != defaults.TLSHandshakeTimeout {
|
||||
t.Errorf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaults.TLSHandshakeTimeout)
|
||||
}
|
||||
|
||||
// IdleConnTimeout must be inherited.
|
||||
if transport.IdleConnTimeout == 0 {
|
||||
t.Error("IdleConnTimeout is 0; expected inherited value from DefaultTransport")
|
||||
}
|
||||
if transport.IdleConnTimeout != defaults.IdleConnTimeout {
|
||||
t.Errorf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaults.IdleConnTimeout)
|
||||
}
|
||||
|
||||
// MaxIdleConns must be inherited.
|
||||
if transport.MaxIdleConns == 0 {
|
||||
t.Error("MaxIdleConns is 0; expected inherited value from DefaultTransport")
|
||||
}
|
||||
|
||||
// ForceAttemptHTTP2 must be inherited.
|
||||
if !transport.ForceAttemptHTTP2 {
|
||||
t.Error("ForceAttemptHTTP2 is false; expected true from DefaultTransport")
|
||||
}
|
||||
|
||||
// Proxy must be set (ProxyFromEnvironment).
|
||||
if transport.Proxy == nil {
|
||||
t.Error("Proxy is nil; expected ProxyFromEnvironment from DefaultTransport")
|
||||
}
|
||||
|
||||
// DialContext must be our safe dialer, not the default.
|
||||
if transport.DialContext == nil {
|
||||
t.Error("DialContext is nil; expected safeDialContext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTimelineReviewCommentIDForReview(t *testing.T) {
|
||||
const reviewID = int64(42)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/repos/owner/repo/pulls/5/reviews/42":
|
||||
w.Write([]byte(`{"body": "The review body <!-- review-bot:sonnet -->", "user": {"login": "sonnet-review"}}`))
|
||||
case "/api/v1/repos/owner/repo/issues/5/timeline":
|
||||
w.Write([]byte(`[
|
||||
{"id": 100, "type": "comment", "body": "unrelated", "user": {"login": "sonnet-review"}},
|
||||
{"id": 200, "type": "review", "body": "The review body <!-- review-bot:sonnet -->", "user": {"login": "sonnet-review"}}
|
||||
]`))
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
id, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, reviewID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTimelineReviewCommentIDForReview() error = %v", err)
|
||||
}
|
||||
if id != 200 {
|
||||
t.Errorf("got id=%d, want 200", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTimelineReviewCommentIDForReview_ReviewFetchError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`{"message":"not found"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, 99)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing review, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTimelineReviewCommentIDForReview_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(`{"body": "", "user": {"login": "bot"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, 42)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty body, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "empty body") {
|
||||
t.Errorf("error = %q, want to contain 'empty body'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTimelineReviewCommentIDForReview_NotFoundInTimeline(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/v1/repos/owner/repo/pulls/5/reviews/42":
|
||||
w.Write([]byte(`{"body": "review content <!-- review-bot:sonnet -->", "user": {"login": "bot"}}`))
|
||||
default:
|
||||
// Timeline returns events that don't match (different user)
|
||||
w.Write([]byte(`[{"id": 1, "type": "review", "body": "review content <!-- review-bot:sonnet -->", "user": {"login": "other-user"}}]`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, 42)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when review not found in timeline, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func TestGetPullRequestDiff_SizeLimits(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = tt.maxDiffSize
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
// Package gitea — export_test.go exposes test helpers to test files in this
|
||||
// package. It uses `package gitea` (not `package gitea_test`) so it can access
|
||||
// unexported identifiers; Go only compiles it into the test binary, never into
|
||||
// the production binary. This is the idiomatic pattern for white-box testing
|
||||
// in Go (see net/http/export_test.go in the stdlib for the same approach).
|
||||
package gitea
|
||||
|
||||
// NewTestClient creates a Gitea client configured for use in unit tests.
|
||||
// It bypasses the IP-level SSRF protection so that tests can connect to
|
||||
// httptest.Server instances (which listen on 127.0.0.1).
|
||||
//
|
||||
// Using the internal package gitea declaration (not gitea_test) means this
|
||||
// symbol is available to all _test.go files in this package. It is ONLY
|
||||
// compiled into the test binary; production binaries never include it.
|
||||
// Production code must use NewClient, which enables the safe dialer.
|
||||
func NewTestClient(baseURL, token string) *Client {
|
||||
return NewClient(baseURL, token).WithUnsafeDialer()
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
// Package gitea provides a client for the Gitea API.
|
||||
// ipcheck.go implements IP-level SSRF protection by checking resolved addresses
|
||||
// against known blocked CIDR ranges (RFC1918, loopback, link-local, etc.).
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// blockedCIDRStrings is the canonical list of CIDR strings that should never
|
||||
// be contacted by review-bot. See IsBlockedIP for the full list of covered
|
||||
// address families.
|
||||
//
|
||||
// These are hard-coded literals: any parse failure is a programming error.
|
||||
// Validity is verified by TestBlockedCIDRsValid in ipcheck_test.go.
|
||||
var blockedCIDRStrings = []string{
|
||||
// IPv4 loopback
|
||||
"127.0.0.0/8",
|
||||
// IPv4 unspecified / "this network"
|
||||
"0.0.0.0/8",
|
||||
// RFC1918 private ranges
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
// IPv4 link-local (APIPA, also used by AWS instance metadata 169.254.169.254)
|
||||
"169.254.0.0/16",
|
||||
// IPv4 shared address space (RFC6598, carrier-grade NAT)
|
||||
"100.64.0.0/10",
|
||||
// IPv4 multicast
|
||||
"224.0.0.0/4",
|
||||
// IPv4 reserved / broadcast
|
||||
"240.0.0.0/4",
|
||||
// IPv6 loopback
|
||||
"::1/128",
|
||||
// IPv6 unspecified
|
||||
"::/128",
|
||||
// IPv6 link-local
|
||||
"fe80::/10",
|
||||
// IPv6 unique local (ULA) — RFC4193
|
||||
"fc00::/7",
|
||||
// IPv6 multicast
|
||||
"ff00::/8",
|
||||
}
|
||||
|
||||
// blockedCIDRs is the parsed form of blockedCIDRStrings.
|
||||
// Any entry that fails to parse is recorded in blockedCIDRParseErrors instead
|
||||
// of panicking; tests verify this slice is always empty via TestBlockedCIDRsValid.
|
||||
var (
|
||||
blockedCIDRs []*net.IPNet
|
||||
blockedCIDRParseErrors []string
|
||||
)
|
||||
|
||||
func init() {
|
||||
blockedCIDRs = make([]*net.IPNet, 0, len(blockedCIDRStrings))
|
||||
for _, r := range blockedCIDRStrings {
|
||||
_, cidr, err := net.ParseCIDR(r)
|
||||
if err != nil {
|
||||
// Record the error rather than panicking; TestBlockedCIDRsValid
|
||||
// will catch this during tests, and the CI build will fail.
|
||||
blockedCIDRParseErrors = append(blockedCIDRParseErrors,
|
||||
fmt.Sprintf("ipcheck: invalid built-in CIDR %q: %v", r, err))
|
||||
continue
|
||||
}
|
||||
blockedCIDRs = append(blockedCIDRs, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedIP reports whether ip is in a blocked address range.
|
||||
// It is exported for use by the validate-url subcommand and tests outside
|
||||
// this package.
|
||||
//
|
||||
// IPv6-mapped IPv4 addresses (e.g. ::ffff:192.168.1.1) are normalized to their
|
||||
// IPv4 form before checking so that IPv4 CIDRs catch them.
|
||||
//
|
||||
// Based on:
|
||||
// - RFC1918 private ranges
|
||||
// - RFC5735 / RFC4193 special-use IPv4/IPv6 ranges
|
||||
// - RFC4291 IPv6 link-local / loopback
|
||||
func IsBlockedIP(ip net.IP) bool {
|
||||
// Normalize IPv6-mapped IPv4 addresses (::ffff:x.x.x.x) to plain IPv4.
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
ip = v4
|
||||
}
|
||||
for _, cidr := range blockedCIDRs {
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
blocked := []struct {
|
||||
name string
|
||||
ip string
|
||||
}{
|
||||
// IPv4 loopback
|
||||
{"loopback 127.0.0.1", "127.0.0.1"},
|
||||
{"loopback 127.0.0.2", "127.0.0.2"},
|
||||
{"loopback 127.255.255.255", "127.255.255.255"},
|
||||
// IPv4 unspecified
|
||||
{"unspecified 0.0.0.0", "0.0.0.0"},
|
||||
{"unspecified 0.1.2.3", "0.1.2.3"},
|
||||
// RFC1918
|
||||
{"RFC1918 10.0.0.1", "10.0.0.1"},
|
||||
{"RFC1918 10.255.255.255", "10.255.255.255"},
|
||||
{"RFC1918 172.16.0.1", "172.16.0.1"},
|
||||
{"RFC1918 172.31.255.255", "172.31.255.255"},
|
||||
{"RFC1918 192.168.0.1", "192.168.0.1"},
|
||||
{"RFC1918 192.168.255.255", "192.168.255.255"},
|
||||
// Link-local (APIPA / AWS metadata)
|
||||
{"link-local 169.254.0.1", "169.254.0.1"},
|
||||
{"link-local 169.254.169.254", "169.254.169.254"},
|
||||
// Shared address space (carrier-grade NAT)
|
||||
{"CGN 100.64.0.1", "100.64.0.1"},
|
||||
{"CGN 100.127.255.255", "100.127.255.255"},
|
||||
// Multicast
|
||||
{"multicast 224.0.0.1", "224.0.0.1"},
|
||||
{"multicast 239.255.255.255", "239.255.255.255"},
|
||||
// Reserved
|
||||
{"reserved 240.0.0.1", "240.0.0.1"},
|
||||
{"broadcast 255.255.255.255", "255.255.255.255"},
|
||||
// IPv6 loopback
|
||||
{"IPv6 loopback ::1", "::1"},
|
||||
// IPv6 unspecified
|
||||
{"IPv6 unspecified ::", "::"},
|
||||
// IPv6 link-local
|
||||
{"IPv6 link-local fe80::1", "fe80::1"},
|
||||
{"IPv6 link-local fe80::dead:beef", "fe80::dead:beef"},
|
||||
// IPv6 ULA
|
||||
{"IPv6 ULA fc00::1", "fc00::1"},
|
||||
{"IPv6 ULA fd00::1", "fd00::1"},
|
||||
// IPv6 multicast
|
||||
{"IPv6 multicast ff02::1", "ff02::1"},
|
||||
}
|
||||
|
||||
for _, tc := range blocked {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", tc.ip)
|
||||
}
|
||||
if !IsBlockedIP(ip) {
|
||||
t.Errorf("IsBlockedIP(%q) = false, want true", tc.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
allowed := []struct {
|
||||
name string
|
||||
ip string
|
||||
}{
|
||||
{"public 8.8.8.8", "8.8.8.8"},
|
||||
{"public 1.1.1.1", "1.1.1.1"},
|
||||
{"public 198.51.100.1", "198.51.100.1"}, // RFC5737 TEST-NET-2 — a documentation-only range;
|
||||
// not assigned to any real host, but intentionally left unblocked here because
|
||||
// it has no special routing treatment (unlike RFC1918/loopback/link-local) and
|
||||
// blocking it would require tracking every RFC5737 range without meaningful
|
||||
// security benefit (no server should ever listen on a TEST-NET address).
|
||||
{"public 151.101.1.1", "151.101.1.1"}, // Fastly
|
||||
{"public IPv6 2001:4860:4860::8888", "2001:4860:4860::8888"}, // Google DNS
|
||||
{"public IPv6 2606:4700:4700::1111", "2606:4700:4700::1111"}, // Cloudflare DNS
|
||||
}
|
||||
|
||||
for _, tc := range allowed {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", tc.ip)
|
||||
}
|
||||
if IsBlockedIP(ip) {
|
||||
t.Errorf("IsBlockedIP(%q) = true, want false", tc.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIPv6MappedIPv4(t *testing.T) {
|
||||
// ::ffff:192.168.1.1 is an IPv6-mapped IPv4 address — should be blocked as RFC1918.
|
||||
// Construct it manually as a 16-byte IP.
|
||||
mapped := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1}
|
||||
if !IsBlockedIP(mapped) {
|
||||
t.Errorf("IsBlockedIP(::ffff:192.168.1.1) = false, want true (IPv6-mapped IPv4 must be normalized)")
|
||||
}
|
||||
|
||||
// ::ffff:8.8.8.8 — IPv6-mapped public IP — should be allowed.
|
||||
mappedPublic := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 8, 8, 8, 8}
|
||||
if IsBlockedIP(mappedPublic) {
|
||||
t.Errorf("IsBlockedIP(::ffff:8.8.8.8) = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIPEdgeCases(t *testing.T) {
|
||||
// The boundary between RFC1918 and public ranges.
|
||||
// 172.15.255.255 is NOT private (just below 172.16.0.0/12).
|
||||
notPrivate := net.ParseIP("172.15.255.255")
|
||||
if IsBlockedIP(notPrivate) {
|
||||
t.Errorf("IsBlockedIP(172.15.255.255) = true, want false (outside 172.16.0.0/12)")
|
||||
}
|
||||
// 172.32.0.0 is NOT private (just above 172.31.255.255).
|
||||
notPrivate2 := net.ParseIP("172.32.0.0")
|
||||
if IsBlockedIP(notPrivate2) {
|
||||
t.Errorf("IsBlockedIP(172.32.0.0) = true, want false (outside 172.16.0.0/12)")
|
||||
}
|
||||
// CGN: 100.63.255.255 is NOT in 100.64.0.0/10.
|
||||
notCGN := net.ParseIP("100.63.255.255")
|
||||
if IsBlockedIP(notCGN) {
|
||||
t.Errorf("IsBlockedIP(100.63.255.255) = true, want false (outside 100.64.0.0/10)")
|
||||
}
|
||||
// CGN: 100.128.0.0 is NOT in 100.64.0.0/10.
|
||||
notCGN2 := net.ParseIP("100.128.0.0")
|
||||
if IsBlockedIP(notCGN2) {
|
||||
t.Errorf("IsBlockedIP(100.128.0.0) = true, want false (outside 100.64.0.0/10)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBlockedCIDRsValid verifies that all entries in blockedCIDRStrings parse
|
||||
// successfully. This catches programming errors in the CIDR list without
|
||||
// requiring a startup panic. The init() function records parse failures in
|
||||
// blockedCIDRParseErrors rather than panicking; this test makes those failures
|
||||
// visible as test failures during CI.
|
||||
func TestBlockedCIDRsValid(t *testing.T) {
|
||||
if len(blockedCIDRParseErrors) > 0 {
|
||||
for _, msg := range blockedCIDRParseErrors {
|
||||
t.Errorf("CIDR parse error: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,7 @@ func TestPostReview_WithComments(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
comments := []ReviewComment{
|
||||
{Path: "main.go", NewPosition: 42, Body: "[MAJOR] Something bad"},
|
||||
{Path: "util.go", NewPosition: 10, Body: "[MINOR] Style issue"},
|
||||
@@ -71,7 +71,7 @@ func TestPostReview_NilComments(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "all good", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
|
||||
+457
-4
@@ -4,7 +4,10 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -92,10 +95,6 @@ func asAPIError(err error) (*APIError, bool) {
|
||||
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
|
||||
// be called before any goroutines issue requests; they have no synchronization.
|
||||
type Client struct {
|
||||
// TODO: baseURL is populated by NewClient but not yet consumed by doRequest/doGet.
|
||||
// Higher-level exported methods (GetPullRequest, etc.) will use it to
|
||||
// construct request URLs; remove this field if those methods end up
|
||||
// accepting full URLs instead.
|
||||
baseURL string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
@@ -376,3 +375,457 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
||||
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
||||
return c.doRequest(ctx, http.MethodGet, url, "")
|
||||
}
|
||||
|
||||
// doRequestWithBody performs an HTTP request with an optional body, applying the
|
||||
// same HTTPS enforcement as doRequest. It is used by write methods (POST, PUT,
|
||||
// DELETE) that bypass the retry loop in doRequest because write operations are
|
||||
// not idempotent.
|
||||
//
|
||||
// body may be nil for requests that carry no payload (e.g. DELETE).
|
||||
// When body is non-nil, Content-Type is set to application/json.
|
||||
func (c *Client) doRequestWithBody(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
|
||||
if !c.allowInsecureHTTP {
|
||||
parsed, err := url.Parse(reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse request URL: %w", err)
|
||||
}
|
||||
if strings.EqualFold(parsed.Scheme, "http") {
|
||||
return nil, fmt.Errorf("refusing HTTP request to %s: use HTTPS or set AllowInsecureHTTP option", redactURL(reqURL))
|
||||
}
|
||||
}
|
||||
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
reqBody = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
||||
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
||||
}
|
||||
|
||||
// --- API types ---
|
||||
|
||||
// PullRequest holds relevant PR metadata.
|
||||
type PullRequest struct {
|
||||
Title string `json:"title"`
|
||||
Body string `json:"body"`
|
||||
Head struct {
|
||||
Sha string `json:"sha"`
|
||||
Ref string `json:"ref"`
|
||||
} `json:"head"`
|
||||
Draft bool `json:"draft"`
|
||||
}
|
||||
|
||||
// CommitStatus represents a single CI status entry.
|
||||
// GitHub returns "state" not "status"; this type uses Status for consistency
|
||||
// with the gitea package (both are normalized before use).
|
||||
type CommitStatus struct {
|
||||
Status string `json:"state"` // GitHub field is "state"
|
||||
Context string `json:"context"`
|
||||
Description string `json:"description"`
|
||||
TargetURL string `json:"target_url"`
|
||||
}
|
||||
|
||||
// ChangedFile represents a file modified in a PR.
|
||||
type ChangedFile struct {
|
||||
Filename string `json:"filename"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// ReviewComment represents an inline comment to attach to a review.
|
||||
// GitHub uses "position" (diff hunk position), whereas Gitea uses "new_position" (line number).
|
||||
// When posting inline comments on GitHub, position is required; line numbers
|
||||
// from the diff cannot be used directly.
|
||||
type ReviewComment struct {
|
||||
ID int64 `json:"id,omitempty"`
|
||||
Path string `json:"path"`
|
||||
Position int64 `json:"position,omitempty"` // GitHub diff hunk position
|
||||
Line int64 `json:"line,omitempty"` // GitHub absolute line number (alternative to position)
|
||||
Side string `json:"side,omitempty"` // "RIGHT" or "LEFT"
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
// Review represents a pull request review from the GitHub API.
|
||||
type Review struct {
|
||||
ID int64 `json:"id"`
|
||||
Body string `json:"body"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// contentResponse is the GitHub contents API response for a single file.
|
||||
type contentResponse struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"` // "file" or "dir" or "symlink" or "submodule"
|
||||
Content string `json:"content"` // Base64-encoded file content (with embedded newlines)
|
||||
Encoding string `json:"encoding"` // "base64" or ""
|
||||
}
|
||||
|
||||
// ContentEntry represents a file or directory entry from the contents API.
|
||||
type ContentEntry struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"` // "file" or "dir"
|
||||
}
|
||||
|
||||
// --- PR methods ---
|
||||
|
||||
// GetPullRequest fetches PR metadata.
|
||||
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR: %w", err)
|
||||
}
|
||||
var pr PullRequest
|
||||
if err := json.Unmarshal(body, &pr); err != nil {
|
||||
return nil, fmt.Errorf("parse PR JSON: %w", err)
|
||||
}
|
||||
return &pr, nil
|
||||
}
|
||||
|
||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// GetPullRequestFiles fetches the list of files changed in a PR.
|
||||
// GitHub paginates this endpoint (100 per page max).
|
||||
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error) {
|
||||
const perPage = 100
|
||||
var all []ChangedFile
|
||||
for page := 1; ; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=%d&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, perPage, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR files (page %d): %w", page, err)
|
||||
}
|
||||
var batch []ChangedFile
|
||||
if err := json.Unmarshal(body, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse PR files JSON (page %d): %w", page, err)
|
||||
}
|
||||
all = append(all, batch...)
|
||||
if len(batch) < perPage {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// GetCommitStatuses fetches CI statuses for a commit SHA.
|
||||
// GitHub has two status systems: legacy "commit statuses" and newer "check runs".
|
||||
// This method returns commit statuses only; check runs are a separate API.
|
||||
// Note: GitHub returns "state" in the JSON; CommitStatus.Status is tagged accordingly.
|
||||
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error) {
|
||||
const perPage = 100
|
||||
var all []CommitStatus
|
||||
for page := 1; ; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/statuses?per_page=%d&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), perPage, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch commit statuses (page %d): %w", page, err)
|
||||
}
|
||||
var batch []CommitStatus
|
||||
if err := json.Unmarshal(body, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse statuses JSON (page %d): %w", page, err)
|
||||
}
|
||||
all = append(all, batch...)
|
||||
if len(batch) < perPage {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// --- File content methods ---
|
||||
|
||||
// GetFileContent fetches a file from the default branch of a repo.
|
||||
// GitHub returns base64-encoded content; this method decodes it.
|
||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return c.getFileContentAtRef(ctx, owner, repo, filepath, "")
|
||||
}
|
||||
|
||||
// GetFileContentRef fetches a file from a specific ref (branch/tag/sha).
|
||||
func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
return c.getFileContentAtRef(ctx, owner, repo, filepath, ref)
|
||||
}
|
||||
|
||||
// getFileContentAtRef fetches a file at the given ref (empty = default branch).
|
||||
// GitHub's contents API returns base64-encoded file content.
|
||||
func (c *Client) getFileContentAtRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(filepath))
|
||||
if ref != "" {
|
||||
reqURL += "?ref=" + url.QueryEscape(ref)
|
||||
}
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s: %w", filepath, err)
|
||||
}
|
||||
var resp contentResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("parse file content JSON for %s: %w", filepath, err)
|
||||
}
|
||||
if resp.Type != "file" {
|
||||
return "", fmt.Errorf("path %s is a %s, not a file", filepath, resp.Type)
|
||||
}
|
||||
if resp.Encoding == "base64" {
|
||||
// GitHub embeds newlines in the base64 content for readability.
|
||||
// Strip them before decoding.
|
||||
cleaned := strings.ReplaceAll(resp.Content, "\n", "")
|
||||
decoded, err := base64.StdEncoding.DecodeString(cleaned)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64 content for %s: %w", filepath, err)
|
||||
}
|
||||
return string(decoded), nil
|
||||
}
|
||||
// Non-base64 encoding (shouldn't happen normally, but handle gracefully).
|
||||
return resp.Content, nil
|
||||
}
|
||||
|
||||
// ListContents lists files and directories at a given path.
|
||||
// Pass an empty path to list the repository root.
|
||||
// GitHub returns a single object (not array) when path is a file — this
|
||||
// method normalizes both cases to a slice, matching Gitea's behavior.
|
||||
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
|
||||
var reqURL string
|
||||
if path == "" || path == "." {
|
||||
reqURL = fmt.Sprintf("%s/repos/%s/%s/contents",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo))
|
||||
} else {
|
||||
reqURL = fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
||||
}
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
||||
}
|
||||
|
||||
var entries []ContentEntry
|
||||
if err := json.Unmarshal(body, &entries); err != nil {
|
||||
// GitHub returns a single object when path is a file (not an array).
|
||||
var single contentResponse
|
||||
if err2 := json.Unmarshal(body, &single); err2 != nil {
|
||||
return nil, fmt.Errorf("parse contents JSON: %w", err)
|
||||
}
|
||||
if single.Name == "" && single.Path == "" {
|
||||
return nil, fmt.Errorf("parse contents JSON: empty response for path %q", path)
|
||||
}
|
||||
entries = []ContentEntry{{
|
||||
Name: single.Name,
|
||||
Path: single.Path,
|
||||
Type: single.Type,
|
||||
}}
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// GetAllFilesInPath recursively fetches all file contents under a path.
|
||||
// If the path is a file, returns just that file's content.
|
||||
// If the path is a directory, recursively fetches all files within it.
|
||||
func (c *Client) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
results := make(map[string]string)
|
||||
|
||||
entries, err := c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
if !IsNotFound(err) {
|
||||
return nil, fmt.Errorf("list contents %q: %w", path, err)
|
||||
}
|
||||
// 404 means path may be a file — try fetching directly.
|
||||
content, fileErr := c.GetFileContent(ctx, owner, repo, path)
|
||||
if fileErr != nil {
|
||||
return nil, fmt.Errorf("path %q is neither a file nor directory: %w", path, fileErr)
|
||||
}
|
||||
results[path] = content
|
||||
return results, nil
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
switch entry.Type {
|
||||
case "file":
|
||||
content, err := c.GetFileContent(ctx, owner, repo, entry.Path)
|
||||
if err != nil {
|
||||
slog.Warn("could not fetch file from patterns repo", "file", entry.Path, "error", err)
|
||||
continue
|
||||
}
|
||||
results[entry.Path] = content
|
||||
case "dir":
|
||||
subResults, err := c.GetAllFilesInPath(ctx, owner, repo, entry.Path)
|
||||
if err != nil {
|
||||
slog.Warn("could not recurse into directory", "dir", entry.Path, "error", err)
|
||||
continue
|
||||
}
|
||||
for k, v := range subResults {
|
||||
results[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// --- Review methods ---
|
||||
|
||||
// PostReview submits a review to a PR.
|
||||
// event should be one of "APPROVE", "REQUEST_CHANGES", or "COMMENT".
|
||||
// commitID anchors the review to a specific commit SHA. If empty, defaults to current HEAD.
|
||||
// comments are optional inline comments; GitHub uses diff hunk position (not line numbers).
|
||||
// Note: unlike Gitea, GitHub does not support deleting submitted reviews.
|
||||
// Use COMMENT event to supersede old reviews.
|
||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []ReviewComment) (*Review, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
|
||||
payload := struct {
|
||||
Body string `json:"body"`
|
||||
Event string `json:"event"`
|
||||
CommitID string `json:"commit_id,omitempty"`
|
||||
Comments []ReviewComment `json:"comments,omitempty"`
|
||||
}{
|
||||
Body: body,
|
||||
Event: event,
|
||||
CommitID: commitID,
|
||||
Comments: comments,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal review payload: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("post review: %w", err)
|
||||
}
|
||||
|
||||
var review Review
|
||||
if err := json.Unmarshal(respBody, &review); err != nil {
|
||||
return nil, fmt.Errorf("parse review response: %w", err)
|
||||
}
|
||||
return &review, nil
|
||||
}
|
||||
|
||||
// ListReviews returns all reviews on a pull request.
|
||||
// GitHub paginates via Link header; this method uses per_page=100.
|
||||
func (c *Client) ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error) {
|
||||
const perPage = 100
|
||||
var all []Review
|
||||
for page := 1; ; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews?per_page=%d&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, perPage, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list reviews (page %d): %w", page, err)
|
||||
}
|
||||
var batch []Review
|
||||
if err := json.Unmarshal(body, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse reviews (page %d): %w", page, err)
|
||||
}
|
||||
all = append(all, batch...)
|
||||
if len(batch) < perPage {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// DeleteReview attempts to delete a pull request review.
|
||||
// GitHub only allows deleting PENDING (draft) reviews. Submitted reviews cannot
|
||||
// be deleted via the API; this method returns a descriptive error in that case.
|
||||
// review-bot callers should handle this error gracefully (e.g., by not attempting
|
||||
// supersede and instead posting a new review alongside the old one).
|
||||
func (c *Client) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews/%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, reviewID)
|
||||
|
||||
// nil body: the GitHub DELETE endpoint for reviews requires no request body.
|
||||
_, err := c.doRequestWithBody(ctx, http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete review: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthenticatedUser returns the login of the authenticated user.
|
||||
func (c *Client) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
reqURL := c.baseURL + "/user"
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get authenticated user: %w", err)
|
||||
}
|
||||
var result struct {
|
||||
Login string `json:"login"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("parse user response: %w", err)
|
||||
}
|
||||
return result.Login, nil
|
||||
}
|
||||
|
||||
// RequestReviewer adds a user as a requested reviewer on a pull request.
|
||||
// This is idempotent — requesting an already-requested reviewer is a no-op.
|
||||
func (c *Client) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/requested_reviewers",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
|
||||
payload := struct {
|
||||
Reviewers []string `json:"reviewers"`
|
||||
}{Reviewers: []string{reviewer}}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal reviewer request: %w", err)
|
||||
}
|
||||
|
||||
_, err = c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request reviewer: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||
// Slashes are preserved as path separators; other special characters are escaped.
|
||||
func escapePath(p string) string {
|
||||
parts := strings.Split(p, "/")
|
||||
for i, part := range parts {
|
||||
parts[i] = url.PathEscape(part)
|
||||
}
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -656,3 +658,570 @@ func TestRedactURL_UserinfoWithQuery(t *testing.T) {
|
||||
t.Errorf("redactURL = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for API methods ---
|
||||
|
||||
func TestGetPullRequest_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/repos/owner/repo/pulls/42" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"title":"Test PR","body":"description","head":{"sha":"abc123","ref":"feature"},"draft":false}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if pr.Title != "Test PR" {
|
||||
t.Errorf("Title = %q, want %q", pr.Title, "Test PR")
|
||||
}
|
||||
if pr.Head.Sha != "abc123" {
|
||||
t.Errorf("Head.Sha = %q, want %q", pr.Head.Sha, "abc123")
|
||||
}
|
||||
if pr.Head.Ref != "feature" {
|
||||
t.Errorf("Head.Ref = %q, want %q", pr.Head.Ref, "feature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequest_NotFound(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`{"message":"Not Found"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 99)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !IsNotFound(err) {
|
||||
t.Errorf("expected IsNotFound=true, got false for error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestDiff_Success(t *testing.T) {
|
||||
const wantDiff = "diff --git a/foo.go b/foo.go\n--- a/foo.go\n+++ b/foo.go\n"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Accept") != "application/vnd.github.diff" {
|
||||
t.Errorf("Accept = %q, want application/vnd.github.diff", r.Header.Get("Accept"))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(wantDiff))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if diff != wantDiff {
|
||||
t.Errorf("diff = %q, want %q", diff, wantDiff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestFiles_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"filename":"foo.go","status":"modified"},{"filename":"bar.go","status":"added"}]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 2 {
|
||||
t.Fatalf("len(files) = %d, want 2", len(files))
|
||||
}
|
||||
if files[0].Filename != "foo.go" || files[0].Status != "modified" {
|
||||
t.Errorf("files[0] = %+v, want {foo.go modified}", files[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPullRequestFiles_Paginated(t *testing.T) {
|
||||
page := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
page++
|
||||
if page == 1 {
|
||||
// Return 100 items (page full → expect another request)
|
||||
items := make([]map[string]string, 100)
|
||||
for i := range items {
|
||||
items[i] = map[string]string{"filename": fmt.Sprintf("file%d.go", i), "status": "modified"}
|
||||
}
|
||||
data, _ := json.Marshal(items)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
return
|
||||
}
|
||||
// Page 2: return fewer than perPage → stop
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"filename":"last.go","status":"added"}]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 101 {
|
||||
t.Errorf("len(files) = %d, want 101", len(files))
|
||||
}
|
||||
if page != 2 {
|
||||
t.Errorf("page = %d, want 2", page)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCommitStatuses_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// GitHub uses "state" field
|
||||
w.Write([]byte(`[{"state":"success","context":"ci/test","description":"Tests pass","target_url":"https://ci.example.com"}]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "deadbeef")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("len(statuses) = %d, want 1", len(statuses))
|
||||
}
|
||||
if statuses[0].Status != "success" {
|
||||
t.Errorf("Status = %q, want %q", statuses[0].Status, "success")
|
||||
}
|
||||
if statuses[0].Context != "ci/test" {
|
||||
t.Errorf("Context = %q, want %q", statuses[0].Context, "ci/test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFileContent_Base64(t *testing.T) {
|
||||
// "hello world\n" base64-encoded with embedded newlines (as GitHub does it)
|
||||
encoded := "aGVsbG8gd29ybGQK"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasSuffix(r.URL.Path, "/contents/README.md") {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file","content":"` + encoded + `","encoding":"base64"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
content, err := c.GetFileContent(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != "hello world\n" {
|
||||
t.Errorf("content = %q, want %q", content, "hello world\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFileContent_Base64WithNewlines(t *testing.T) {
|
||||
// GitHub embeds newlines in base64 content for readability (every 60 chars)
|
||||
// Test that we strip them correctly before decoding
|
||||
// "hello world\n" = aGVsbG8gd29ybGQK — split it with embedded \n
|
||||
encoded := "aGVs\nbG8g\nd29y\nbGQK"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// JSON-encode the embedded newlines as \n
|
||||
body := `{"name":"README.md","path":"README.md","type":"file","content":"aGVs\nbG8g\nd29y\nbGQK","encoding":"base64"}`
|
||||
_ = encoded // suppress unused warning
|
||||
w.Write([]byte(body))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
content, err := c.GetFileContent(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != "hello world\n" {
|
||||
t.Errorf("content = %q, want %q", content, "hello world\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFileContent_IsDirectory(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"docs","path":"docs","type":"dir","content":"","encoding":""}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
_, err := c.GetFileContent(context.Background(), "owner", "repo", "docs")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for directory, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFileContentRef_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("ref") != "main" {
|
||||
t.Errorf("ref = %q, want %q", r.URL.Query().Get("ref"), "main")
|
||||
}
|
||||
encoded := "dGVzdA==" // "test"
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"foo.go","path":"foo.go","type":"file","content":"` + encoded + `","encoding":"base64"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
content, err := c.GetFileContentRef(context.Background(), "owner", "repo", "foo.go", "main")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if content != "test" {
|
||||
t.Errorf("content = %q, want %q", content, "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListContents_Directory(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"name":"foo.go","path":"foo.go","type":"file"},{"name":"bar","path":"bar","type":"dir"}]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
entries, err := c.ListContents(context.Background(), "owner", "repo", "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("len(entries) = %d, want 2", len(entries))
|
||||
}
|
||||
if entries[0].Name != "foo.go" || entries[0].Type != "file" {
|
||||
t.Errorf("entries[0] = %+v, unexpected", entries[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestListContents_SingleFile(t *testing.T) {
|
||||
// GitHub returns a single object when the path is a file
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file","content":"","encoding":""}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
entries, err := c.ListContents(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("len(entries) = %d, want 1", len(entries))
|
||||
}
|
||||
if entries[0].Name != "README.md" {
|
||||
t.Errorf("entries[0].Name = %q, want README.md", entries[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostReview_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method = %s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/repos/owner/repo/pulls/1/reviews" {
|
||||
t.Errorf("path = %s, unexpected", r.URL.Path)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Errorf("decode body: %v", err)
|
||||
}
|
||||
if payload["event"] != "APPROVE" {
|
||||
t.Errorf("event = %v, want APPROVE", payload["event"])
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":99,"body":"looks good","user":{"login":"bot"},"state":"APPROVED"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
review, err := c.PostReview(context.Background(), "owner", "repo", 1, "APPROVE", "looks good", "abc", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if review.ID != 99 {
|
||||
t.Errorf("review.ID = %d, want 99", review.ID)
|
||||
}
|
||||
if review.User.Login != "bot" {
|
||||
t.Errorf("review.User.Login = %q, want bot", review.User.Login)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostReview_Unauthorized(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("bad-tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
_, err := c.PostReview(context.Background(), "owner", "repo", 1, "APPROVE", "body", "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !IsUnauthorized(err) {
|
||||
t.Errorf("expected IsUnauthorized=true, got false for error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListReviews_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"id":1,"body":"review 1","user":{"login":"alice"},"state":"APPROVED"},{"id":2,"body":"review 2","user":{"login":"bob"},"state":"CHANGES_REQUESTED"}]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
reviews, err := c.ListReviews(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(reviews) != 2 {
|
||||
t.Fatalf("len(reviews) = %d, want 2", len(reviews))
|
||||
}
|
||||
if reviews[0].ID != 1 || reviews[0].User.Login != "alice" {
|
||||
t.Errorf("reviews[0] = %+v, unexpected", reviews[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteReview_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
t.Errorf("method = %s, want DELETE", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/repos/owner/repo/pulls/1/reviews/42" {
|
||||
t.Errorf("path = %s, unexpected", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
err := c.DeleteReview(context.Background(), "owner", "repo", 1, 42)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteReview_SubmittedReview(t *testing.T) {
|
||||
// GitHub returns 422 for trying to delete a non-pending review
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
w.Write([]byte(`{"message":"Can only delete a pending review"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
err := c.DeleteReview(context.Background(), "owner", "repo", 1, 99)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAuthenticatedUser_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/user" {
|
||||
t.Errorf("path = %s, want /user", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"login":"review-bot","id":12345}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
login, err := c.GetAuthenticatedUser(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if login != "review-bot" {
|
||||
t.Errorf("login = %q, want review-bot", login)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestReviewer_Success(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("method = %s, want POST", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/repos/owner/repo/pulls/1/requested_reviewers" {
|
||||
t.Errorf("path = %s, unexpected", r.URL.Path)
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Errorf("decode body: %v", err)
|
||||
}
|
||||
reviewers, ok := payload["reviewers"].([]interface{})
|
||||
if !ok || len(reviewers) != 1 || reviewers[0] != "reviewer1" {
|
||||
t.Errorf("reviewers = %v, unexpected", payload["reviewers"])
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte(`{}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
err := c.RequestReviewer(context.Background(), "owner", "repo", 1, "reviewer1")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostReview_RejectsHTTP(t *testing.T) {
|
||||
// PostReview must reject http:// base URLs — tokens must not be sent in plaintext.
|
||||
c := NewClient("tok", "http://127.0.0.1:1")
|
||||
_, err := c.PostReview(context.Background(), "owner", "repo", 1, "APPROVE", "body", "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP base URL in PostReview")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "refusing HTTP request") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteReview_RejectsHTTP(t *testing.T) {
|
||||
// DeleteReview must reject http:// base URLs — tokens must not be sent in plaintext.
|
||||
c := NewClient("tok", "http://127.0.0.1:1")
|
||||
err := c.DeleteReview(context.Background(), "owner", "repo", 1, 42)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP base URL in DeleteReview")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "refusing HTTP request") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestReviewer_RejectsHTTP(t *testing.T) {
|
||||
// RequestReviewer must reject http:// base URLs — tokens must not be sent in plaintext.
|
||||
c := NewClient("tok", "http://127.0.0.1:1")
|
||||
err := c.RequestReviewer(context.Background(), "owner", "repo", 1, "reviewer1")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTTP base URL in RequestReviewer")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "refusing HTTP request") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapePath_SpecialChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"README.md", "README.md"},
|
||||
{"docs/guide.md", "docs/guide.md"},
|
||||
{"path with spaces/file.md", "path%20with%20spaces/file.md"},
|
||||
{"path/with [brackets]/file.md", "path/with%20%5Bbrackets%5D/file.md"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := escapePath(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllFilesInPath_DirectoryWithFiles(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/repos/owner/repo/contents/patterns":
|
||||
// Directory listing
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"name":"go.md","path":"patterns/go.md","type":"file"}]`))
|
||||
case "/repos/owner/repo/contents/patterns/go.md":
|
||||
// GitHub file response with base64 content
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"go.md","path":"patterns/go.md","type":"file","encoding":"base64","content":"IyBHbyBwYXR0ZXJucwo="}`))
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
files, err := c.GetAllFilesInPath(context.Background(), "owner", "repo", "patterns")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("len(files) = %d, want 1", len(files))
|
||||
}
|
||||
if files["patterns/go.md"] != "# Go patterns\n" {
|
||||
t.Errorf("unexpected content: %q", files["patterns/go.md"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllFilesInPath_404FallsBackToFile(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/repos/owner/repo/contents/README.md":
|
||||
// ListContents returns 404 for file paths
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`{"message":"Not Found"}`))
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
// GetFileContent also goes to /contents/ — this will 404 too.
|
||||
// The function should return the path-not-found error.
|
||||
_, err := c.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when both dir and file 404, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllFilesInPath_DirectoryWithSubdir(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/repos/owner/repo/contents/src":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[
|
||||
{"name":"main.go","path":"src/main.go","type":"file"},
|
||||
{"name":"sub","path":"src/sub","type":"dir"}
|
||||
]`))
|
||||
case "/repos/owner/repo/contents/src/main.go":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"main.go","path":"src/main.go","type":"file","encoding":"base64","content":"cGFja2FnZSBtYWluCg=="}`))
|
||||
case "/repos/owner/repo/contents/src/sub":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"name":"util.go","path":"src/sub/util.go","type":"file"}]`))
|
||||
case "/repos/owner/repo/contents/src/sub/util.go":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"name":"util.go","path":"src/sub/util.go","type":"file","encoding":"base64","content":"cGFja2FnZSBzdWIK"}`))
|
||||
default:
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := NewClient("tok", srv.URL, AllowInsecureHTTPForTest())
|
||||
files, err := c.GetAllFilesInPath(context.Background(), "owner", "repo", "src")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(files) != 2 {
|
||||
t.Fatalf("len(files) = %d, want 2: %v", len(files), files)
|
||||
}
|
||||
if files["src/main.go"] != "package main\n" {
|
||||
t.Errorf("src/main.go content unexpected: %q", files["src/main.go"])
|
||||
}
|
||||
if files["src/sub/util.go"] != "package sub\n" {
|
||||
t.Errorf("src/sub/util.go content unexpected: %q", files["src/sub/util.go"])
|
||||
}
|
||||
}
|
||||
|
||||
+5
-5
@@ -207,11 +207,11 @@ func (c *Client) completeOpenAI(ctx context.Context, messages []Message) (string
|
||||
|
||||
type anthropicRequest struct {
|
||||
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMsg `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMsg `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicMsg struct {
|
||||
|
||||
@@ -210,7 +210,6 @@ func TestWithTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestComplete_Anthropic_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/messages" {
|
||||
|
||||
+49
-1
@@ -355,7 +355,7 @@ func TestCapitalizeFirst(t *testing.T) {
|
||||
{"HELLO", "HELLO"},
|
||||
{"a", "A"},
|
||||
{"", ""},
|
||||
{"日本語", "日本語"}, // Non-ASCII: Japanese doesn't have case
|
||||
{"日本語", "日本語"}, // Non-ASCII: Japanese doesn't have case
|
||||
{"über", "Über"}, // German umlaut
|
||||
{"élève", "Élève"}, // French accent
|
||||
}
|
||||
@@ -957,3 +957,51 @@ func TestYAMLMergeKeyDepthCheck(t *testing.T) {
|
||||
t.Errorf("error = %q, want to contain 'depth'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPersona_NonexistentFile(t *testing.T) {
|
||||
_, err := LoadPersona("/tmp/nonexistent-persona-file-xyz.yaml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPersona_NotARegularFile(t *testing.T) {
|
||||
// Use a directory as the path — directories are not regular files.
|
||||
dir := t.TempDir()
|
||||
_, err := LoadPersona(dir)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for directory path, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not a regular file") {
|
||||
t.Errorf("error = %q, want to contain 'not a regular file'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPersona_OversizedFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "big.yaml")
|
||||
// Write a file larger than MaxPersonaFileSize
|
||||
data := make([]byte, MaxPersonaFileSize+1)
|
||||
for i := range data {
|
||||
data[i] = 'x'
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
t.Fatalf("failed to create test file: %v", err)
|
||||
}
|
||||
_, err := LoadPersona(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized file, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "exceeds maximum size") {
|
||||
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapitalizeFirst_RuneError(t *testing.T) {
|
||||
// An invalid UTF-8 byte sequence should return the original string unchanged.
|
||||
invalid := string([]byte{0xFF, 0xFE})
|
||||
got := CapitalizeFirst(invalid)
|
||||
if got != invalid {
|
||||
t.Errorf("CapitalizeFirst(%q) = %q, want original %q", invalid, got, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +117,6 @@ func TestBuildUserPrompt_WithoutFileContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestBuildSystemBase(t *testing.T) {
|
||||
result := BuildSystemBase()
|
||||
if result == "" {
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
|
||||
func TestParsePersonaBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
source string
|
||||
wantName string
|
||||
wantErr string
|
||||
name string
|
||||
data string
|
||||
source string
|
||||
wantName string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid yaml",
|
||||
@@ -38,8 +38,8 @@ focus:
|
||||
wantErr: "parse",
|
||||
},
|
||||
{
|
||||
name: "json format by extension",
|
||||
data: `{"name": "jsontest", "identity": "json identity"}`,
|
||||
name: "json format by extension",
|
||||
data: `{"name": "jsontest", "identity": "json identity"}`,
|
||||
source: "test.json",
|
||||
wantName: "jsontest",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user