Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4d48917e36 | |||
| bd516cd044 | |||
| 1f67954da7 | |||
| d396599d05 | |||
| 9f3f32174b | |||
| c53a07b230 | |||
| bbf3dfbf0d | |||
| ed3a5dddf1 | |||
| 449a24e4c5 | |||
| 4440823571 | |||
| c349986187 | |||
| 934c6728ee | |||
| 5ac93bea70 | |||
| f84cc3bbcf | |||
| 8c8f3ab4b3 | |||
| 50facefdd6 | |||
| bd2df7d986 | |||
| d3bb83a10a | |||
| c56f5fec52 | |||
| b80a1517ed | |||
| 5f7ffab487 | |||
| f8b9d7d282 | |||
| 7a8fc166ec | |||
| 5e351b85f0 | |||
| ab2a6c8aef | |||
| 6b7f3f6924 | |||
| 4c032a3b53 | |||
| 64c9d551ba | |||
| db7b7e66bf | |||
| 0232343126 | |||
| b26514714f | |||
| 028d46942a | |||
| e59c2bc831 | |||
| dc2e1ca5de | |||
| 7de6fdd9ec | |||
| 1e0959b077 | |||
| 67c3db70cb | |||
| a845ce32eb | |||
| 9f8e9aa8d3 | |||
| 31a28b1dd5 | |||
| e414471a16 | |||
| 41e1d48b54 |
@@ -1,17 +1,43 @@
|
||||
# This composite action is designed for Gitea Actions runners.
|
||||
# Gitea Actions supports GitHub Actions syntax including $GITHUB_OUTPUT,
|
||||
# actions/cache, and actions/checkout.
|
||||
# This composite action supports both Gitea Actions and GitHub Actions runners.
|
||||
# It detects the VCS host type by checking whether github.api_url is set
|
||||
# (present on GitHub.com and GHES runners, absent on Gitea runners) and uses
|
||||
# the appropriate releases API for version resolution and binary download
|
||||
# (REST API on GitHub, direct URLs on Gitea).
|
||||
#
|
||||
# Security notes:
|
||||
# - On GitHub/GHES (VCS_TYPE=github), inputs.vcs-url is IGNORED to prevent
|
||||
# token exfiltration. API calls use github.api_url; downloads use
|
||||
# github.server_url. Tokens are never sent to user-supplied URLs.
|
||||
# - On Gitea (VCS_TYPE=gitea), inputs.vcs-url is validated (https scheme,
|
||||
# no whitespace/newlines, and DNS resolution to a public IP) before use.
|
||||
# Python3 resolves the hostname and rejects RFC1918, RFC6598 (carrier-grade
|
||||
# NAT), loopback, link-local, and other reserved addresses to prevent SSRF attacks.
|
||||
# The installed review-bot binary additionally uses a safe HTTP transport
|
||||
# (DialContext-level IP check) for all Gitea API calls at runtime.
|
||||
# The binary also exposes a `validate-url` subcommand for use in any future
|
||||
# shell steps that need to validate a URL before passing it to curl.
|
||||
# - action-repo is validated against owner/repo pattern.
|
||||
# - Tokens are passed via masked environment variables, not step outputs.
|
||||
#
|
||||
# Requirements: python3, sha256sum, curl (all present on ubuntu-* runners).
|
||||
name: 'AI Code Review'
|
||||
description: 'Run AI-powered code review on a pull request using review-bot'
|
||||
|
||||
inputs:
|
||||
gitea-url:
|
||||
description: 'Gitea instance URL (defaults to server_url)'
|
||||
vcs-url:
|
||||
description: 'VCS server URL (only used on Gitea runners; ignored on GitHub/GHES). Defaults to server_url.'
|
||||
required: false
|
||||
default: ''
|
||||
repo:
|
||||
description: 'Repository (owner/name, defaults to current)'
|
||||
description: 'Repository to review (owner/name, defaults to current)'
|
||||
required: false
|
||||
default: ''
|
||||
action-repo:
|
||||
description: 'Repository hosting review-bot releases (owner/name). Defaults to github.action_repository or rodin/review-bot.'
|
||||
required: false
|
||||
default: ''
|
||||
action-repo-token:
|
||||
description: 'Token for downloading release assets from action-repo (defaults to github.token on GitHub, reviewer-token on Gitea). Required for private repos.'
|
||||
required: false
|
||||
default: ''
|
||||
pr-number:
|
||||
@@ -19,7 +45,7 @@ inputs:
|
||||
required: false
|
||||
default: ''
|
||||
reviewer-token:
|
||||
description: 'Gitea token for posting the review'
|
||||
description: 'Token for posting the review'
|
||||
required: true
|
||||
reviewer-name:
|
||||
description: 'Display name for the reviewer'
|
||||
@@ -112,45 +138,325 @@ runs:
|
||||
id: version
|
||||
shell: bash
|
||||
run: |
|
||||
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
|
||||
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||
set -euo pipefail
|
||||
|
||||
# --- Input Validation ---
|
||||
|
||||
# Determine the repo hosting review-bot releases (not the repo being reviewed)
|
||||
ACTION_REPO="${{ inputs.action-repo }}"
|
||||
if [ -z "$ACTION_REPO" ]; then
|
||||
# github.action_repository is the repo containing the running action
|
||||
ACTION_REPO="${{ github.action_repository }}"
|
||||
fi
|
||||
if [ -z "$ACTION_REPO" ]; then
|
||||
# Final fallback for Gitea (which may not set action_repository)
|
||||
ACTION_REPO="rodin/review-bot"
|
||||
echo "::notice::action-repo not specified and github.action_repository is empty; falling back to rodin/review-bot"
|
||||
fi
|
||||
|
||||
# Validate ACTION_REPO matches owner/repo pattern (prevent path traversal)
|
||||
if ! printf '%s' "$ACTION_REPO" | grep -qE '^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$'; then
|
||||
echo "Error: action-repo '${ACTION_REPO}' does not match expected owner/repo format" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Detect VCS host type using github.api_url context.
|
||||
# github.api_url is set on GitHub.com (https://api.github.com) and GHES
|
||||
# (https://<host>/api/v3). It is empty/unset on Gitea Actions runners.
|
||||
GITHUB_API_URL="${{ github.api_url }}"
|
||||
if [ -n "$GITHUB_API_URL" ]; then
|
||||
VCS_TYPE="github"
|
||||
else
|
||||
VCS_TYPE="gitea"
|
||||
fi
|
||||
|
||||
# Determine SERVER_URL based on VCS type.
|
||||
# SECURITY: On GitHub/GHES, ALWAYS use github.server_url — never trust
|
||||
# inputs.vcs-url to prevent token exfiltration to attacker-controlled hosts.
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
SERVER_URL="${{ github.server_url }}"
|
||||
if [ -n "${{ inputs.vcs-url }}" ]; then
|
||||
echo "::warning::inputs.vcs-url is ignored on GitHub/GHES runners (VCS_TYPE=github). Using github.server_url instead."
|
||||
fi
|
||||
else
|
||||
SERVER_URL="${{ inputs.vcs-url || github.server_url }}"
|
||||
fi
|
||||
# Strip trailing slash if present
|
||||
SERVER_URL="${SERVER_URL%/}"
|
||||
|
||||
# Validate SERVER_URL for Gitea path: must be https, no whitespace/newlines.
|
||||
# The [^[:space:]] class already rejects newlines, so no separate newline check needed.
|
||||
if [ "$VCS_TYPE" = "gitea" ]; then
|
||||
if ! printf '%s' "$SERVER_URL" | grep -qE '^https://[^[:space:]]+$'; then
|
||||
echo "Error: SERVER_URL '${SERVER_URL}' must be an https:// URL with no whitespace" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Additional IP-level SSRF defense: resolve the hostname and reject
|
||||
# requests to RFC1918, RFC6598 (carrier-grade NAT), loopback, link-local,
|
||||
# and other reserved addresses.
|
||||
# python3 is required on ubuntu-* runners (see requirements comment above).
|
||||
# Use printf to write the script to a temp file so the python lines are valid
|
||||
# YAML (each indented line becomes a printf argument — no unindented code).
|
||||
# SERVER_URL is passed via CHECK_URL env var, never interpolated into python code.
|
||||
printf '%s\n' \
|
||||
'import socket,ipaddress,sys,os' \
|
||||
'from urllib.parse import urlparse' \
|
||||
'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \
|
||||
'if parsed.username or parsed.password:' \
|
||||
' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \
|
||||
'h=parsed.hostname' \
|
||||
'(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \
|
||||
'try: rs=socket.getaddrinfo(h,None)' \
|
||||
'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \
|
||||
'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \
|
||||
'for _,_,_,_,(a,*_) in rs:' \
|
||||
' ip=ipaddress.ip_address(a)' \
|
||||
' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \
|
||||
' cgn=ipaddress.ip_network("100.64.0.0/10")' \
|
||||
' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \
|
||||
' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \
|
||||
> /tmp/_ssrf_check.py
|
||||
CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check.py || {
|
||||
echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
# Determine auth token for release API requests
|
||||
ACTION_TOKEN="${{ inputs.action-repo-token }}"
|
||||
if [ -z "$ACTION_TOKEN" ]; then
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
ACTION_TOKEN="${{ github.token }}"
|
||||
else
|
||||
ACTION_TOKEN="${{ inputs.reviewer-token }}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Validate token contains no control characters (defense-in-depth against header injection)
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
if printf '%s' "$ACTION_TOKEN" | LC_ALL=C grep -q '[^[:print:]]'; then
|
||||
echo "Error: ACTION_TOKEN contains control characters" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "${{ inputs.version }}" = "latest" ]; then
|
||||
VERSION=$(curl -sSf "${GITEA_URL}/api/v1/repos/${REPO}/releases?limit=1" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
# SECURITY: Use github.api_url which is a trusted platform-provided value.
|
||||
# Never construct API URLs from user-supplied inputs on GitHub.
|
||||
API_URL="${GITHUB_API_URL}/repos/${ACTION_REPO}/releases?per_page=1"
|
||||
else
|
||||
# Gitea API — SERVER_URL was validated above
|
||||
API_URL="${SERVER_URL}/api/v1/repos/${ACTION_REPO}/releases?limit=1"
|
||||
fi
|
||||
|
||||
# Fetch latest version with inline auth header (no intermediate variable)
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" "$API_URL" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
else
|
||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: token ${ACTION_TOKEN}" "$API_URL" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
fi
|
||||
else
|
||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 "$API_URL" \
|
||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||
fi
|
||||
|
||||
if [ -z "$VERSION" ]; then
|
||||
echo "Failed to determine latest version" >&2
|
||||
echo "Failed to determine latest version from ${API_URL}" >&2
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
VERSION="${{ inputs.version }}"
|
||||
fi
|
||||
|
||||
# Validate VERSION: no slashes or whitespace (prevent path traversal).
|
||||
# [:space:] includes newlines and carriage returns in POSIX.
|
||||
if printf '%s' "$VERSION" | grep -qE '[/[:space:]]'; then
|
||||
echo "Error: VERSION '${VERSION}' contains invalid characters (newline, slash, or whitespace)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Detect OS and architecture for platform-specific binary download
|
||||
OS_RAW=$(uname -s | tr '[:upper:]' '[:lower:]')
|
||||
case "$OS_RAW" in
|
||||
linux) OS="linux" ;;
|
||||
darwin) OS="darwin" ;;
|
||||
*)
|
||||
echo "Error: unsupported OS: $(uname -s)" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
RAW_ARCH=$(uname -m)
|
||||
case "$RAW_ARCH" in
|
||||
x86_64) ARCH="amd64" ;;
|
||||
aarch64 | arm64) ARCH="arm64" ;;
|
||||
*)
|
||||
echo "Error: unsupported architecture: $RAW_ARCH" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||
echo "os=${OS}" >> "$GITHUB_OUTPUT"
|
||||
echo "arch=${ARCH}" >> "$GITHUB_OUTPUT"
|
||||
echo "action_repo=${ACTION_REPO}" >> "$GITHUB_OUTPUT"
|
||||
echo "server_url=${SERVER_URL}" >> "$GITHUB_OUTPUT"
|
||||
echo "vcs_type=${VCS_TYPE}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# SECURITY: Pass token via masked environment variable instead of step output.
|
||||
# Step outputs can leak in debug logs; GITHUB_ENV with masking is safer.
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
echo "::add-mask::${ACTION_TOKEN}"
|
||||
echo "ACTION_TOKEN=${ACTION_TOKEN}" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Cache review-bot binary
|
||||
id: cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ runner.temp }}/review-bot
|
||||
key: review-bot-linux-amd64-${{ steps.version.outputs.version }}
|
||||
key: review-bot-${{ steps.version.outputs.os }}-${{ steps.version.outputs.arch }}-${{ steps.version.outputs.version }}
|
||||
|
||||
- name: Install review-bot
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
|
||||
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
BINARY="review-bot-linux-amd64"
|
||||
set -euo pipefail
|
||||
|
||||
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/${BINARY}" \
|
||||
-o "${{ runner.temp }}/review-bot"
|
||||
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/checksums.txt" \
|
||||
-o "${{ runner.temp }}/checksums.txt"
|
||||
SERVER_URL="${{ steps.version.outputs.server_url }}"
|
||||
ACTION_REPO="${{ steps.version.outputs.action_repo }}"
|
||||
VERSION="${{ steps.version.outputs.version }}"
|
||||
VCS_TYPE="${{ steps.version.outputs.vcs_type }}"
|
||||
OS="${{ steps.version.outputs.os }}"
|
||||
ARCH="${{ steps.version.outputs.arch }}"
|
||||
# Read token from masked environment variable (set in Determine version step)
|
||||
# Falls back to empty if not set (public repos don't need auth)
|
||||
ACTION_TOKEN="${ACTION_TOKEN:-}"
|
||||
BINARY="review-bot-${OS}-${ARCH}"
|
||||
|
||||
# SECURITY: Re-validate SERVER_URL at the start of this step to mitigate DNS
|
||||
# rebinding attacks. A DNS TTL expiry between "Determine version" and here
|
||||
# could allow an attacker to change the resolved IP to a private/reserved
|
||||
# address, causing curl to send ACTION_TOKEN to an internal host.
|
||||
# Only needed on Gitea path (VCS_TYPE=gitea); GitHub/GHES uses platform-controlled URLs.
|
||||
if [ "$VCS_TYPE" = "gitea" ]; then
|
||||
printf '%s\n' \
|
||||
'import socket,ipaddress,sys,os' \
|
||||
'from urllib.parse import urlparse' \
|
||||
'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \
|
||||
'if parsed.username or parsed.password:' \
|
||||
' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \
|
||||
'h=parsed.hostname' \
|
||||
'(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \
|
||||
'try: rs=socket.getaddrinfo(h,None)' \
|
||||
'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \
|
||||
'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \
|
||||
'for _,_,_,_,(a,*_) in rs:' \
|
||||
' ip=ipaddress.ip_address(a)' \
|
||||
' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \
|
||||
' cgn=ipaddress.ip_network("100.64.0.0/10")' \
|
||||
' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \
|
||||
' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \
|
||||
> /tmp/_ssrf_check_install.py
|
||||
CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check_install.py || {
|
||||
echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2
|
||||
exit 1
|
||||
}
|
||||
fi
|
||||
|
||||
if [ "$VCS_TYPE" = "github" ]; then
|
||||
# GitHub/GHES: Use REST API for release asset downloads.
|
||||
# Web release URLs ({server}/.../releases/download/{tag}/{asset}) redirect
|
||||
# to S3 and don't reliably support Authorization headers for private repos.
|
||||
# The REST API endpoint with Accept: application/octet-stream is required.
|
||||
# GITHUB_API_URL: trusted platform value, same as detected in "Determine version" step.
|
||||
GITHUB_API_URL="${{ github.api_url }}"
|
||||
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
RELEASE_JSON=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/tags/${VERSION}")
|
||||
else
|
||||
RELEASE_JSON=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/tags/${VERSION}")
|
||||
fi
|
||||
|
||||
# Extract asset IDs for binary and checksums
|
||||
BINARY_ASSET_ID=$(printf '%s' "$RELEASE_JSON" | python3 -c "import sys, json; assets = json.load(sys.stdin).get('assets', []); matches = [a['id'] for a in assets if a['name'] == '${BINARY}']; print(matches[0] if matches else '')")
|
||||
if [ -z "$BINARY_ASSET_ID" ]; then
|
||||
echo "Error: could not find asset '${BINARY}' in release ${VERSION}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CHECKSUMS_ASSET_ID=$(printf '%s' "$RELEASE_JSON" | python3 -c "import sys, json; assets = json.load(sys.stdin).get('assets', []); matches = [a['id'] for a in assets if a['name'] == 'checksums.txt']; print(matches[0] if matches else '')")
|
||||
if [ -z "$CHECKSUMS_ASSET_ID" ]; then
|
||||
echo "Error: could not find asset 'checksums.txt' in release ${VERSION}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Download assets via REST API with Accept: application/octet-stream
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
curl -sSfL --connect-timeout 10 --max-time 120 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${BINARY_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/review-bot"
|
||||
curl -sSfL --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${CHECKSUMS_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/checksums.txt"
|
||||
else
|
||||
curl -sSfL --connect-timeout 10 --max-time 120 \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${BINARY_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/review-bot"
|
||||
curl -sSfL --connect-timeout 10 --max-time 30 \
|
||||
-H "Accept: application/octet-stream" \
|
||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${CHECKSUMS_ASSET_ID}" \
|
||||
-o "${{ runner.temp }}/checksums.txt"
|
||||
fi
|
||||
else
|
||||
# Gitea: Direct download via web release URLs (Gitea serves assets
|
||||
# directly without redirects — no -L needed).
|
||||
# SECURITY: Omitting -L prevents forwarding Authorization header to
|
||||
# unexpected hosts if Gitea ever introduces CDN redirects.
|
||||
DOWNLOAD_URL="${SERVER_URL}/${ACTION_REPO}/releases/download/${VERSION}"
|
||||
|
||||
if [ -n "$ACTION_TOKEN" ]; then
|
||||
curl -sSf --connect-timeout 10 --max-time 120 \
|
||||
-H "Authorization: token ${ACTION_TOKEN}" \
|
||||
"${DOWNLOAD_URL}/${BINARY}" -o "${{ runner.temp }}/review-bot"
|
||||
curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
-H "Authorization: token ${ACTION_TOKEN}" \
|
||||
"${DOWNLOAD_URL}/checksums.txt" -o "${{ runner.temp }}/checksums.txt"
|
||||
else
|
||||
curl -sSf --connect-timeout 10 --max-time 120 \
|
||||
"${DOWNLOAD_URL}/${BINARY}" -o "${{ runner.temp }}/review-bot"
|
||||
curl -sSf --connect-timeout 10 --max-time 30 \
|
||||
"${DOWNLOAD_URL}/checksums.txt" -o "${{ runner.temp }}/checksums.txt"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Verify SHA-256 checksum
|
||||
# NOTE: This verifies integrity (download wasn't corrupted) but not
|
||||
# authenticity — both binary and checksums come from the same server.
|
||||
# For stronger guarantees, consider GPG signature verification.
|
||||
cd "${{ runner.temp }}"
|
||||
EXPECTED=$(grep "${BINARY}" checksums.txt | awk '{print $1}')
|
||||
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
|
||||
EXPECTED=$(grep -E "^[0-9a-f]+[[:space:]]+\*?${BINARY}$" checksums.txt | awk '{print $1}')
|
||||
# sha256sum (GNU) is not available on macOS; use shasum -a 256 on darwin.
|
||||
if [ "${OS}" = "darwin" ]; then
|
||||
ACTUAL=$(shasum -a 256 review-bot | awk '{print $1}')
|
||||
else
|
||||
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
|
||||
fi
|
||||
|
||||
if [ -z "$EXPECTED" ]; then
|
||||
echo "Error: no checksum found for ${BINARY}" >&2
|
||||
@@ -164,12 +470,12 @@ runs:
|
||||
fi
|
||||
|
||||
chmod +x "${{ runner.temp }}/review-bot"
|
||||
echo "Installed review-bot ${VERSION} (checksum verified)"
|
||||
echo "Installed review-bot-${OS}-${ARCH} ${VERSION} (checksum verified)"
|
||||
|
||||
- name: Run review
|
||||
shell: bash
|
||||
env:
|
||||
GITEA_URL: ${{ inputs.gitea-url || github.server_url }}
|
||||
VCS_URL: ${{ steps.version.outputs.server_url }}
|
||||
GITEA_REPO: ${{ inputs.repo || github.repository }}
|
||||
PR_NUMBER: ${{ inputs.pr-number || github.event.pull_request.number }}
|
||||
REVIEWER_TOKEN: ${{ inputs.reviewer-token }}
|
||||
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
- run: go build -o review-bot ./cmd/review-bot
|
||||
- name: Run ${{ matrix.name }} review
|
||||
env:
|
||||
GITEA_URL: ${{ github.server_url }}
|
||||
VCS_URL: ${{ github.server_url }}
|
||||
GITEA_REPO: ${{ github.repository }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REVIEWER_TOKEN: ${{ secrets[matrix.token_secret] }}
|
||||
|
||||
+175
@@ -0,0 +1,175 @@
|
||||
# Plan: Issue #125 — Rename GITEA_URL → VCS_URL
|
||||
|
||||
## Problem
|
||||
|
||||
The `GITEA_URL` environment variable (and `--gitea-url` flag) implies the binary only works with Gitea.
|
||||
Now that review-bot supports both Gitea and GitHub/GHES, this name is misleading.
|
||||
Renaming to `VCS_URL` makes the binary platform-agnostic in its interface.
|
||||
|
||||
## Constraints
|
||||
|
||||
- Must not break existing users who already use `GITEA_URL` — need a fallback
|
||||
- The CLI flag `--gitea-url` should also be updated to `--vcs-url` for consistency
|
||||
- `INTEGRATION_GITEA_URL` in integration tests is a test-only env var, not the binary's interface; but should be updated for clarity
|
||||
- The action YAML uses `GITEA_URL` as an internal shell variable in bash scripts — distinct from the env var passed to the binary
|
||||
- All changes must compile and pass existing tests
|
||||
|
||||
## Files Affected
|
||||
|
||||
### Binary / Go source
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `cmd/review-bot/main.go` | Rename `--gitea-url` → `--vcs-url`, add `VCS_URL` as primary, keep `GITEA_URL` fallback |
|
||||
| `cmd/review-bot/integration_test.go` | Rename `INTEGRATION_GITEA_URL` → `INTEGRATION_VCS_URL` (test-only, no external compat concern) |
|
||||
| `integration_test.go` | Same — rename `INTEGRATION_GITEA_URL` → `INTEGRATION_VCS_URL` |
|
||||
|
||||
### Action YAML
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `.gitea/actions/review/action.yml` | Rename input `gitea-url` → `vcs-url`; update env var passed to binary: `VCS_URL` instead of `GITEA_URL`; keep internal bash var as `GITEA_URL` (only used for release download, not passed to binary) |
|
||||
| `.gitea/workflows/ci.yml` | Rename `GITEA_URL` env var to `VCS_URL` in Run review step |
|
||||
|
||||
### Documentation
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `README.md` | Update CLI example, env var table entry |
|
||||
|
||||
## Proposed Approach
|
||||
|
||||
### 1. Backward-compatible env var lookup in main.go
|
||||
|
||||
Replace:
|
||||
```go
|
||||
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", ""), "Gitea instance URL")
|
||||
```
|
||||
|
||||
With:
|
||||
```go
|
||||
giteaURL := flag.String("vcs-url", envOrDefaultFallback("VCS_URL", "GITEA_URL", ""), "VCS server URL (e.g. https://gitea.example.com)")
|
||||
```
|
||||
|
||||
Add a helper:
|
||||
```go
|
||||
// envOrDefaultFallback reads primary env var; if empty, falls back to deprecated env var.
|
||||
func envOrDefaultFallback(primary, deprecated, defaultVal string) string {
|
||||
if v := os.Getenv(primary); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := os.Getenv(deprecated); v != "" {
|
||||
slog.Warn("deprecated env var in use; rename to " + primary, "old", deprecated, "new", primary)
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** This must be called AFTER `setupLogger` conceptually, but the flag default is evaluated at flag registration time. Since `setupLogger` runs before `flag.Parse()`, the slog.Warn will print correctly at runtime. We use `log.Printf` as a fallback if this proves problematic.
|
||||
|
||||
Actually — flag defaults are evaluated at registration (line 57), before `setupLogger`. The warning won't go through slog. Two options:
|
||||
- Use `log.Printf` for the deprecation warning (always visible)
|
||||
- Move the fallback lookup to after `flag.Parse()`, checking if the parsed value is still empty
|
||||
|
||||
**Decision:** Move fallback to a post-parse check. This is cleaner:
|
||||
```go
|
||||
vcsURL := flag.String("vcs-url", os.Getenv("VCS_URL"), "VCS server URL")
|
||||
flag.Parse()
|
||||
// Backward compat: fall back to deprecated GITEA_URL
|
||||
if *vcsURL == "" {
|
||||
if v := os.Getenv("GITEA_URL"); v != "" {
|
||||
slog.Warn("GITEA_URL is deprecated; use VCS_URL instead")
|
||||
*vcsURL = v
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This is clean, idiomatic, and the warning goes through slog correctly.
|
||||
|
||||
### 2. Keep `--gitea-url` as deprecated alias
|
||||
|
||||
Add a hidden flag for backward compat:
|
||||
```go
|
||||
giteaURLAlias := flag.String("gitea-url", "", "Deprecated: use --vcs-url")
|
||||
```
|
||||
|
||||
Post-parse:
|
||||
```go
|
||||
if *vcsURL == "" && *giteaURLAlias != "" {
|
||||
slog.Warn("--gitea-url is deprecated; use --vcs-url instead")
|
||||
*vcsURL = *giteaURLAlias
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Internal variable rename
|
||||
|
||||
Rename `giteaURL` local variable → `vcsURL` throughout `main.go` for consistency.
|
||||
|
||||
### 4. Error message update
|
||||
|
||||
```go
|
||||
fmt.Fprintf(os.Stderr, "Required: --vcs-url, --repo, --pr, --reviewer-token, --llm-model\n")
|
||||
```
|
||||
|
||||
### 5. Action YAML changes
|
||||
|
||||
In `.gitea/actions/review/action.yml`:
|
||||
- Input `gitea-url` → `vcs-url` (with same description, `required: false`, `default: ''`)
|
||||
- Line 172: `GITEA_URL: ${{ inputs.gitea-url || github.server_url }}` → `VCS_URL: ${{ inputs.vcs-url || github.server_url }}`
|
||||
- Lines 115, 140: internal bash vars `GITEA_URL=` are used for downloading binaries — NOT passed to the review-bot binary. Leave them as internal bash vars (they're scope-local in bash). These could be renamed to `SERVER_URL` or `BASE_URL` for local clarity, but renaming them isn't strictly required.
|
||||
|
||||
In `.gitea/workflows/ci.yml`:
|
||||
- Line 52: `GITEA_URL: ${{ github.server_url }}` → `VCS_URL: ${{ github.server_url }}`
|
||||
|
||||
### 6. Integration test updates
|
||||
|
||||
`INTEGRATION_GITEA_URL` → `INTEGRATION_VCS_URL` in both test files.
|
||||
|
||||
### 7. README
|
||||
|
||||
- CLI example: `--gitea-url` → `--vcs-url`
|
||||
- Env var table: `GITEA_URL` → `VCS_URL`, add note about `GITEA_URL` fallback
|
||||
|
||||
## Backward Compatibility Summary
|
||||
|
||||
| Old | New | Fallback? |
|
||||
|-----|-----|-----------|
|
||||
| `GITEA_URL` env var | `VCS_URL` | ✅ with deprecation warning |
|
||||
| `--gitea-url` flag | `--vcs-url` | ✅ with deprecation warning |
|
||||
| `gitea-url` action input | `vcs-url` | ⚠️ No (action version bump handles this) |
|
||||
| `INTEGRATION_GITEA_URL` | `INTEGRATION_VCS_URL` | N/A (test-only) |
|
||||
|
||||
## Error Cases
|
||||
|
||||
- Both `VCS_URL` and `GITEA_URL` set: `VCS_URL` wins (primary takes precedence)
|
||||
- Both `--vcs-url` and `--gitea-url` provided: `--vcs-url` wins
|
||||
- Neither set: existing "missing required flags" error unchanged
|
||||
|
||||
## Edge Cases
|
||||
|
||||
- `os.Getenv` returns "" for unset AND set-to-empty — consistent with existing behavior
|
||||
- The `envOrDefault` helper is unchanged; we add `envOrDefaultFallback` for the one renamed var
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
- Existing unit tests pass unchanged (they don't test env var parsing directly)
|
||||
- Integration tests updated to use new env var name
|
||||
- Manual: `GITEA_URL=https://example.com ./review-bot --repo x --pr 1 ...` should print deprecation warning and proceed
|
||||
- Manual: `VCS_URL=https://example.com ./review-bot ...` should work silently
|
||||
|
||||
## Completion Checklist
|
||||
|
||||
1. `VCS_URL` is read first; `GITEA_URL` is fallback with deprecation warning
|
||||
2. `--vcs-url` flag is primary; `--gitea-url` is deprecated alias with warning
|
||||
3. Error message references `--vcs-url` not `--gitea-url`
|
||||
4. `action.yml` passes `VCS_URL` (not `GITEA_URL`) to the binary
|
||||
5. `ci.yml` passes `VCS_URL` (not `GITEA_URL`) to the binary
|
||||
6. README updated in CLI example and env var table
|
||||
7. Integration tests use `INTEGRATION_VCS_URL`
|
||||
8. `go test ./...` passes
|
||||
9. `go vet ./...` passes
|
||||
10. `go build ./cmd/review-bot` succeeds
|
||||
|
||||
## Open Questions
|
||||
|
||||
- Should the CLI flag `--gitea-url` be completely hidden from `--help` or just deprecated with a note? The issue doesn't specify. Decision: keep it visible but add "(deprecated: use --vcs-url)" to the description.
|
||||
- Should action.yml also add `gitea-url` as a deprecated input alias? The issue says "Update the action to pass the new env var name" — no mention of backward compat for the action input. Decision: rename only, no alias (action users pin a version anyway).
|
||||
- The bash-internal `GITEA_URL` variable in action.yml scripts (used for release download, not passed to binary) — rename for clarity? Decision: yes, rename to `BASE_URL` to avoid confusion with the env var.
|
||||
@@ -282,7 +282,7 @@ Rules:
|
||||
|
||||
```bash
|
||||
review-bot \
|
||||
--gitea-url https://gitea.example.com \
|
||||
--vcs-url https://gitea.example.com \
|
||||
--repo owner/name \
|
||||
--pr 42 \
|
||||
--reviewer-token "$GITEA_TOKEN" \
|
||||
@@ -299,7 +299,7 @@ All flags have environment variable equivalents:
|
||||
|
||||
| Flag | Env Var |
|
||||
|------|---------|
|
||||
| `--gitea-url` | `GITEA_URL` |
|
||||
| `--vcs-url` | `VCS_URL` (fallback: `GITEA_URL`) |
|
||||
| `--repo` | `GITEA_REPO` |
|
||||
| `--pr` | `PR_NUMBER` |
|
||||
| `--reviewer-token` | `REVIEWER_TOKEN` |
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
## Dev Loop: review-bot — Continuous Health Monitor
|
||||
|
||||
### Current Cycle: 2026-05-15 02:10 UTC ✅
|
||||
|
||||
**Repository Status:** OPTIMAL
|
||||
- Main: `9f3f321` (clean, all tests pass)
|
||||
- Working tree: clean
|
||||
- Build: ✅ successful
|
||||
- Vet: ✅ clean
|
||||
- Test suite: ALL PASS
|
||||
|
||||
---
|
||||
|
||||
## Latest Delivered: Issue #130 ✅
|
||||
|
||||
### GitHub API + VCS Routing Complete
|
||||
|
||||
**Phase 1: GitHub API Methods** ✅
|
||||
- 12+ methods implemented in `github/client.go`
|
||||
- GetPullRequest, GetPullRequestDiff, GetPullRequestFiles
|
||||
- GetCommitStatuses, GetFileContent, ListContents, GetAllFilesInPath
|
||||
- PostReview, ListReviews, DeleteReview, GetAuthenticatedUser, RequestReviewer
|
||||
|
||||
**Phase 2: VCS Abstraction** ✅
|
||||
- `vcsClient` interface (GitHub + Gitea)
|
||||
- `giteaExtClient` interface (Gitea-specific ops)
|
||||
- Adapters for both platforms
|
||||
- URL-based auto-detection (github.com → GitHub, else Gitea)
|
||||
- `--vcs-type` flag and `VCS_TYPE` env override
|
||||
|
||||
**Quality Metrics** ✅
|
||||
- 474 lines of GitHub client tests
|
||||
- 82 lines of routing tests
|
||||
- 361 lines of VCS adapter code
|
||||
- Security review: APPROVED (MINOR: URL heuristic note)
|
||||
- All tests passing; go vet clean
|
||||
|
||||
**Known Limitations** (Documented)
|
||||
- GitHub: Can only delete PENDING (draft) reviews, not submitted (handled gracefully)
|
||||
- GitHub pagination: per-page=100 with Link header checking
|
||||
- Check-runs: Uses statuses API; check-runs deferrable to future enhancement
|
||||
|
||||
---
|
||||
|
||||
## Repository Status Post-Merge
|
||||
|
||||
### Main Branch
|
||||
- Commit: `9f3f321`
|
||||
- Status: ✅ All systems healthy
|
||||
|
||||
### Recent Merged PRs
|
||||
| PR | Issue | Title | Status |
|
||||
|---|---|---|---|
|
||||
| #131 | #130 | GitHub API methods & VCS routing | ✅ MERGED |
|
||||
| #129 | #123 | IP-level SSRF defense | ✅ MERGED |
|
||||
| #128 | #125 | VCS_URL deprecation & renaming | ✅ MERGED |
|
||||
| #127 | #124 | Multi-arch binary support | ✅ MERGED |
|
||||
| #126 | #120 | GitHub Actions composite action | ✅ MERGED |
|
||||
|
||||
### Closed Issues
|
||||
- #130, #123, #125, #124, #120
|
||||
|
||||
### Open Issues
|
||||
- None blocking; backlog tracked in Gitea project board
|
||||
|
||||
### Worktrees
|
||||
- All cleaned up; no stale branches
|
||||
|
||||
---
|
||||
|
||||
## Feature Completeness Summary
|
||||
|
||||
### ✅ Core Functionality
|
||||
- Multi-provider LLM support (OpenAI, Anthropic, SAP AI Core)
|
||||
- Gitea PR review (mature, proven)
|
||||
- **NEW: GitHub PR review (fully implemented)**
|
||||
- VCS abstraction (Gitea/GitHub transparent routing)
|
||||
- SSRF defense with IP-level validation
|
||||
- Multi-architecture binary deployment
|
||||
|
||||
### ✅ Review Quality
|
||||
- Structured reviews with code snippets
|
||||
- LLM-driven analysis
|
||||
- Persona-based customization
|
||||
- Context awareness
|
||||
|
||||
### ✅ Security
|
||||
- RFC6598 CGN detection
|
||||
- HTTPS enforcement
|
||||
- Redirect safety
|
||||
- Credential handling (no logs, no reflection leaks)
|
||||
- URL validation for VCS API access
|
||||
|
||||
---
|
||||
|
||||
## Next Phase: Backlog Priorities
|
||||
|
||||
### Priority 1: PR Submission
|
||||
**Issue:** #132+ (create)
|
||||
**Goal:** Enable review-bot to create PRs (not just post reviews)
|
||||
**Scope:** PR creation flow, commit logic, test coverage
|
||||
**Est. Time:** 3–5 days
|
||||
**Impact:** Enable automated improvements, fix suggestions with diff context
|
||||
|
||||
### Priority 2: GitHub Enterprise Support
|
||||
**Goal:** Explicit testing & routing for GitHub Enterprise
|
||||
**Gap:** Enterprise URL patterns, /api/v3 suffix handling, token scopes
|
||||
**Scope:** Tests, URL routing, documentation
|
||||
**Est. Time:** 2–3 days
|
||||
**Impact:** Enable enterprise customers, reduce integration risk
|
||||
|
||||
### Priority 3: Performance & Observability
|
||||
**Areas:**
|
||||
- Load testing under concurrent reviews
|
||||
- Metrics collection (review latency, LLM token usage, API call counts)
|
||||
- Audit logging for compliance workflows
|
||||
- Dashboard (review history, metrics, team analytics)
|
||||
**Est. Time:** 5–7 days
|
||||
**Impact:** Operational confidence, troubleshooting, compliance
|
||||
|
||||
### Priority 4: Enhanced Context
|
||||
**Opportunities:**
|
||||
- Semantic code understanding (AST-based analysis for specific languages)
|
||||
- Project-specific review rules (.review-bot.yaml in repo root)
|
||||
- Team-level customization
|
||||
**Est. Time:** 7–10 days
|
||||
|
||||
---
|
||||
|
||||
## Dev Loop Schedule
|
||||
|
||||
- **Interval:** 4 hours
|
||||
- **Next check:** ~6:10 AM UTC (May 15)
|
||||
- **Health:** ✅ Optimal — all systems running
|
||||
- **Status:** Ready for next phase work
|
||||
|
||||
---
|
||||
|
||||
## Metadata
|
||||
|
||||
| Key | Value |
|
||||
|---|---|
|
||||
| Repo | `/home/ubuntu/review-bot` |
|
||||
| Main SHA | `9f3f321` |
|
||||
| Last update | 2026-05-15 02:10 UTC |
|
||||
| Status | All systems optimal |
|
||||
| Next phase | PR submission or GitHub Enterprise support |
|
||||
|
||||
---
|
||||
|
||||
**Summary:** review-bot now supports both GitHub and Gitea PR reviews with a unified abstraction layer. All tests pass, code is clean, security is approved. Ready to move to PR submission or GitHub Enterprise support in the next cycle.
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/github"
|
||||
"gitea.weiker.me/rodin/review-bot/llm"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
// Integration test requires a running Gitea instance and LLM endpoint.
|
||||
// Set environment variables:
|
||||
//
|
||||
// INTEGRATION_GITEA_URL - Gitea base URL
|
||||
// INTEGRATION_VCS_URL - VCS base URL
|
||||
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
||||
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
||||
// INTEGRATION_PR_NUMBER - PR number to test against
|
||||
@@ -25,7 +26,7 @@ import (
|
||||
// INTEGRATION_LLM_API_KEY - LLM API key
|
||||
// INTEGRATION_LLM_MODEL - Model name
|
||||
func TestIntegration_FullReviewFlow(t *testing.T) {
|
||||
giteaURL := os.Getenv("INTEGRATION_GITEA_URL")
|
||||
giteaURL := os.Getenv("INTEGRATION_VCS_URL")
|
||||
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
||||
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
||||
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
||||
@@ -104,7 +105,7 @@ func TestIntegration_FullReviewFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIntegration_PostAndCleanup(t *testing.T) {
|
||||
giteaURL := os.Getenv("INTEGRATION_GITEA_URL")
|
||||
giteaURL := os.Getenv("INTEGRATION_VCS_URL")
|
||||
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
||||
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
||||
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
||||
@@ -130,7 +131,7 @@ func TestIntegration_PostAndCleanup(t *testing.T) {
|
||||
// Post a test review
|
||||
sentinel := "<!-- review-bot:integration-test -->"
|
||||
testBody := "# Integration Test Review\n\nThis is a test review.\n\n" + sentinel
|
||||
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, nil)
|
||||
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("PostReview: %v", err)
|
||||
}
|
||||
@@ -159,3 +160,85 @@ func TestIntegration_PostAndCleanup(t *testing.T) {
|
||||
t.Logf("Warning: could not delete test review %d: %v", posted.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_GitHub_PostAndVerifyReview exercises the full VCS routing path
|
||||
// for GitHub when INTEGRATION_GITHUB_TOKEN and INTEGRATION_GITHUB_REPO are set.
|
||||
// It verifies that the GitHub adapter is selected via VCS_TYPE=github and that
|
||||
// PostReview succeeds against a real GitHub PR.
|
||||
//
|
||||
// Required environment variables:
|
||||
//
|
||||
// INTEGRATION_GITHUB_TOKEN - GitHub personal access token with repo access
|
||||
// INTEGRATION_GITHUB_REPO - owner/repo with an open PR (e.g. Rodin-AI/review-bot)
|
||||
// INTEGRATION_GITHUB_PR - PR number to test against
|
||||
//
|
||||
// The test skips gracefully when these variables are absent.
|
||||
func TestIntegration_GitHub_PostAndVerifyReview(t *testing.T) {
|
||||
githubToken := os.Getenv("INTEGRATION_GITHUB_TOKEN")
|
||||
githubRepo := os.Getenv("INTEGRATION_GITHUB_REPO")
|
||||
prNumStr := os.Getenv("INTEGRATION_GITHUB_PR")
|
||||
|
||||
if githubToken == "" || githubRepo == "" || prNumStr == "" {
|
||||
t.Skip("INTEGRATION_GITHUB_TOKEN, INTEGRATION_GITHUB_REPO, and INTEGRATION_GITHUB_PR not set, skipping")
|
||||
}
|
||||
|
||||
prNumber, err := strconv.Atoi(prNumStr)
|
||||
if err != nil {
|
||||
t.Fatalf("Invalid PR number %q: %v", prNumStr, err)
|
||||
}
|
||||
|
||||
parts := strings.SplitN(githubRepo, "/", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
t.Fatalf("Invalid repo format %q, expected owner/repo", githubRepo)
|
||||
}
|
||||
owner, repoName := parts[0], parts[1]
|
||||
|
||||
ctx := context.Background()
|
||||
ghClient := github.NewClient(githubToken, "https://api.github.com")
|
||||
|
||||
// Verify adapter selection: GetAuthenticatedUser must succeed.
|
||||
user, err := ghClient.GetAuthenticatedUser(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAuthenticatedUser: %v — check INTEGRATION_GITHUB_TOKEN", err)
|
||||
}
|
||||
t.Logf("Authenticated as: %s", user)
|
||||
|
||||
// Verify PR is accessible via GitHub adapter.
|
||||
pr, err := ghClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPullRequest: %v", err)
|
||||
}
|
||||
t.Logf("PR: %s (sha: %s)", pr.Title, pr.Head.Sha)
|
||||
|
||||
// Post a COMMENT review — does not require PR approval permissions.
|
||||
sentinel := "<!-- review-bot:integration-test -->"
|
||||
testBody := "# Integration Test Review (GitHub)\n\nThis is an automated integration test.\n\n" + sentinel
|
||||
posted, err := ghClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("PostReview: %v", err)
|
||||
}
|
||||
t.Logf("Posted review ID: %d", posted.ID)
|
||||
|
||||
// Verify the review appears in ListReviews.
|
||||
reviews, err := ghClient.ListReviews(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
t.Fatalf("ListReviews: %v", err)
|
||||
}
|
||||
found := false
|
||||
for _, r := range reviews {
|
||||
if r.ID == posted.ID && strings.Contains(r.Body, sentinel) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("posted review ID %d not found in ListReviews output", posted.ID)
|
||||
}
|
||||
|
||||
// Attempt cleanup — GitHub does not allow deleting submitted reviews,
|
||||
// so this is expected to fail with ErrCannotDeleteSubmittedReview (422).
|
||||
// Log it as informational only.
|
||||
if err := ghClient.DeleteReview(ctx, owner, repoName, prNumber, posted.ID); err != nil {
|
||||
t.Logf("Note: DeleteReview returned (expected for submitted GitHub reviews): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+116
-62
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -13,12 +14,20 @@ import (
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/budget"
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/github"
|
||||
"gitea.weiker.me/rodin/review-bot/llm"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
|
||||
var version = "dev"
|
||||
|
||||
// outWriter and errWriter are the output and error writers for subcommands.
|
||||
// They are variables so tests can capture output.
|
||||
var (
|
||||
outWriter io.Writer = os.Stdout
|
||||
errWriter io.Writer = os.Stderr
|
||||
)
|
||||
|
||||
// setupLogger configures the global slog default logger based on format and verbosity.
|
||||
func setupLogger(format, verbosity string) {
|
||||
var level slog.Level
|
||||
@@ -49,12 +58,22 @@ func setupLogger(format, verbosity string) {
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Dispatch subcommands before flag parsing so they get their own args.
|
||||
// e.g. `review-bot validate-url <url>`
|
||||
if len(os.Args) > 1 {
|
||||
switch os.Args[1] {
|
||||
case "validate-url":
|
||||
os.Exit(runValidateURL(os.Args[2:]))
|
||||
}
|
||||
}
|
||||
|
||||
versionFlag := flag.Bool("version", false, "Print version and exit")
|
||||
// Logging flags
|
||||
logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json")
|
||||
verbosity := flag.String("verbosity", envOrDefault("LOG_VERBOSITY", "info"), "Log verbosity: debug, info, warn, error")
|
||||
// CLI flags
|
||||
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", ""), "Gitea instance URL")
|
||||
vcsURL := flag.String("vcs-url", os.Getenv("VCS_URL"), "VCS server URL (e.g. https://gitea.example.com)")
|
||||
giteaURLAlias := flag.String("gitea-url", "", "Deprecated: use --vcs-url")
|
||||
repo := flag.String("repo", envOrDefault("GITEA_REPO", ""), "Repository (owner/name)")
|
||||
prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number")
|
||||
reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name")
|
||||
@@ -91,12 +110,24 @@ func main() {
|
||||
|
||||
slog.Info("review-bot starting", "version", version)
|
||||
|
||||
// Backward compatibility: fall back to deprecated env var / flag if VCS_URL / --vcs-url not set.
|
||||
if *vcsURL == "" {
|
||||
if v := os.Getenv("GITEA_URL"); v != "" {
|
||||
slog.Warn("GITEA_URL is deprecated; rename the environment variable to VCS_URL")
|
||||
*vcsURL = v
|
||||
}
|
||||
}
|
||||
if *vcsURL == "" && *giteaURLAlias != "" {
|
||||
slog.Warn("--gitea-url is deprecated; use --vcs-url instead")
|
||||
*vcsURL = *giteaURLAlias
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
// For aicore provider, llm-base-url and llm-api-key are not required
|
||||
isAICore := llm.Provider(*llmProvider) == llm.ProviderAICore
|
||||
if *giteaURL == "" || *repo == "" || *prNum == "" || *reviewerToken == "" || *llmModel == "" {
|
||||
if *vcsURL == "" || *repo == "" || *prNum == "" || *reviewerToken == "" || *llmModel == "" {
|
||||
fmt.Fprintf(os.Stderr, "Error: missing required flags or environment variables\n\n")
|
||||
fmt.Fprintf(os.Stderr, "Required: --gitea-url, --repo, --pr, --reviewer-token, --llm-model\n")
|
||||
fmt.Fprintf(os.Stderr, "Required: --vcs-url, --repo, --pr, --reviewer-token, --llm-model\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
if !isAICore && (*llmBaseURL == "" || *llmAPIKey == "") {
|
||||
@@ -139,7 +170,39 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize clients
|
||||
giteaClient := gitea.NewClient(*giteaURL, *reviewerToken)
|
||||
// Detect VCS type: explicit flag > env var > URL heuristic (default: gitea).
|
||||
vcsType := envOrDefault("VCS_TYPE", "")
|
||||
if vcsType == "" {
|
||||
// Heuristic: if the URL looks like github.com or a GitHub Enterprise host,
|
||||
// default to GitHub. The composite action sets VCS_TYPE explicitly, so this
|
||||
// is a fallback for manual invocations.
|
||||
if strings.Contains(*vcsURL, "github.com") || strings.Contains(*vcsURL, "github.concur.com") {
|
||||
vcsType = "github"
|
||||
} else {
|
||||
vcsType = "gitea"
|
||||
}
|
||||
}
|
||||
slog.Info("VCS type detected", "vcs_type", vcsType, "vcs_url", *vcsURL)
|
||||
|
||||
var vcs vcsClient
|
||||
switch vcsType {
|
||||
case "github":
|
||||
// GitHub: baseURL is the API URL, derived from server URL.
|
||||
// github.com → https://api.github.com
|
||||
// GHES (e.g. https://ghe.example.com) → https://ghe.example.com/api/v3
|
||||
apiURL := githubAPIURL(*vcsURL)
|
||||
ghClient := github.NewClient(*reviewerToken, apiURL)
|
||||
vcs = newGithubVCSAdapter(ghClient)
|
||||
slog.Info("using GitHub VCS client", "api_url", apiURL)
|
||||
case "gitea":
|
||||
giteaClient := gitea.NewClient(*vcsURL, *reviewerToken)
|
||||
vcs = newGiteaVCSAdapter(giteaClient)
|
||||
slog.Info("using Gitea VCS client", "url", *vcsURL)
|
||||
default:
|
||||
slog.Error("unsupported VCS type", "vcs_type", vcsType, "valid", "gitea, github")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel)
|
||||
if *llmTemp < 0 || *llmTemp > 2 {
|
||||
slog.Error("invalid LLM temperature", "temperature", *llmTemp, "range", "0-2")
|
||||
@@ -177,7 +240,7 @@ func main() {
|
||||
var persona *review.Persona
|
||||
if *personaName != "" {
|
||||
// Try loading from repo first, then fall back to built-in
|
||||
repoPersonas, err := review.LoadRepoPersonas(ctx, newGiteaClientAdapter(giteaClient), owner, repoName)
|
||||
repoPersonas, err := review.LoadRepoPersonas(ctx, vcs, owner, repoName)
|
||||
if err != nil {
|
||||
slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err)
|
||||
// Continue with built-in personas only.
|
||||
@@ -213,7 +276,7 @@ func main() {
|
||||
slog.Info("reviewing pull request", "pr", prNumber, "repo", fmt.Sprintf("%s/%s", owner, repoName))
|
||||
|
||||
// Step 1: Fetch PR metadata
|
||||
pr, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
pr, err := vcs.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch PR", "pr", prNumber, "error", err)
|
||||
os.Exit(1)
|
||||
@@ -221,7 +284,7 @@ func main() {
|
||||
slog.Info("fetched PR metadata", "pr", prNumber, "title", pr.Title)
|
||||
|
||||
// Step 2: Fetch diff
|
||||
diff, err := giteaClient.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
||||
diff, err := vcs.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch diff", "pr", prNumber, "error", err)
|
||||
os.Exit(1)
|
||||
@@ -230,11 +293,11 @@ func main() {
|
||||
|
||||
// Step 3: Fetch full file content for modified files
|
||||
fileContext := ""
|
||||
files, err := giteaClient.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
||||
files, err := vcs.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Warn("could not fetch PR files list", "pr", prNumber, "error", err)
|
||||
} else {
|
||||
fileContext = fetchFileContext(ctx, giteaClient, owner, repoName, pr.Head.Ref, files)
|
||||
fileContext = fetchFileContext(ctx, vcs, owner, repoName, pr.Head.Ref, files)
|
||||
slog.Debug("fetched file context", "files", len(files))
|
||||
}
|
||||
|
||||
@@ -242,7 +305,7 @@ func main() {
|
||||
ciPassed := true
|
||||
ciDetails := ""
|
||||
if pr.Head.Sha != "" {
|
||||
statuses, err := giteaClient.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
||||
statuses, err := vcs.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
||||
if err != nil {
|
||||
slog.Warn("could not fetch CI status", "sha", pr.Head.Sha, "error", err)
|
||||
} else {
|
||||
@@ -254,7 +317,7 @@ func main() {
|
||||
// Step 5: Load conventions file if specified
|
||||
conventions := ""
|
||||
if *conventionsFile != "" {
|
||||
content, err := giteaClient.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
||||
content, err := vcs.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
||||
if err != nil {
|
||||
slog.Warn("could not load conventions file", "file", *conventionsFile, "error", err)
|
||||
} else {
|
||||
@@ -266,7 +329,7 @@ func main() {
|
||||
// Step 6: Load patterns from external repo if specified
|
||||
patterns := ""
|
||||
if *patternsRepo != "" {
|
||||
patterns = fetchPatterns(ctx, giteaClient, *patternsRepo, *patternsFiles)
|
||||
patterns = fetchPatterns(ctx, vcs, *patternsRepo, *patternsFiles)
|
||||
slog.Debug("loaded patterns", "repo", *patternsRepo, "bytes", len(patterns))
|
||||
}
|
||||
|
||||
@@ -381,7 +444,7 @@ func main() {
|
||||
// Stale check: verify HEAD hasn't moved since we started
|
||||
evaluatedSHA := pr.Head.Sha
|
||||
var currentSHA string
|
||||
currentPR, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
currentPR, err := vcs.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Warn("could not re-fetch PR for stale check", "pr", prNumber, "error", err)
|
||||
// currentSHA stays empty — shouldSkipStaleReview will return false
|
||||
@@ -398,10 +461,10 @@ func main() {
|
||||
|
||||
// Map findings to inline comments for lines present in the diff
|
||||
diffRanges := gitea.ParseDiffNewLines(diff)
|
||||
var inlineComments []gitea.ReviewComment
|
||||
var inlineComments []vcsReviewComment
|
||||
for _, f := range result.Findings {
|
||||
if f.File != "" && f.Line > 0 && diffRanges.Contains(f.File, f.Line) {
|
||||
inlineComments = append(inlineComments, gitea.ReviewComment{
|
||||
inlineComments = append(inlineComments, vcsReviewComment{
|
||||
Path: f.File,
|
||||
NewPosition: int64(f.Line),
|
||||
Body: fmt.Sprintf("**[%s]** %s", f.Severity, f.Finding),
|
||||
@@ -416,9 +479,9 @@ func main() {
|
||||
// 1. POST new review first (gets non-stale approval badge on HEAD)
|
||||
// 2. Then supersede old review with link to the new one
|
||||
// Order matters: post first so we have the new review's URL for the supersede message.
|
||||
var oldReviews []gitea.Review
|
||||
var oldReviews []vcsReview
|
||||
if *reviewerName != "" {
|
||||
existingReviews, err := giteaClient.ListReviews(ctx, owner, repoName, prNumber)
|
||||
existingReviews, err := vcs.ListReviews(ctx, owner, repoName, prNumber)
|
||||
if err != nil {
|
||||
slog.Warn("could not list existing reviews", "pr", prNumber, "error", err)
|
||||
} else {
|
||||
@@ -431,11 +494,11 @@ func main() {
|
||||
}
|
||||
|
||||
// Self-request as reviewer (ensures we appear in required-reviewer checks)
|
||||
authUser, err := giteaClient.GetAuthenticatedUser(ctx)
|
||||
authUser, err := vcs.GetAuthenticatedUser(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("could not determine authenticated user for reviewer self-request", "error", err)
|
||||
} else if authUser != "" {
|
||||
if err := giteaClient.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
|
||||
if err := vcs.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
|
||||
slog.Warn("could not self-request as reviewer", "user", authUser, "error", err)
|
||||
} else {
|
||||
slog.Debug("self-requested as reviewer", "user", authUser, "pr", prNumber)
|
||||
@@ -444,31 +507,34 @@ func main() {
|
||||
|
||||
// POST new review
|
||||
slog.Info("posting review", "event", event, "pr", prNumber)
|
||||
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, inlineComments)
|
||||
posted, err := vcs.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, evaluatedSHA, inlineComments)
|
||||
if err != nil {
|
||||
slog.Error("failed to post review", "pr", prNumber, "event", event, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("review posted", "review_id", posted.ID, "user", posted.User.Login, "pr", prNumber)
|
||||
|
||||
// Supersede all old reviews with link to the new one
|
||||
if len(oldReviews) > 0 {
|
||||
newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(*giteaURL, "/"), owner, repoName, prNumber, posted.ID)
|
||||
// Supersede all old reviews with link to the new one.
|
||||
// This is only supported on Gitea (requires timeline API); GitHub reviews cannot
|
||||
// be edited after submission, so we skip the supersede step there.
|
||||
extVCS, isGiteaExt := vcs.(giteaExtClient)
|
||||
if len(oldReviews) > 0 && isGiteaExt {
|
||||
newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(*vcsURL, "/"), owner, repoName, prNumber, posted.ID)
|
||||
for _, oldReview := range oldReviews {
|
||||
cid, err := giteaClient.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, prNumber, oldReview.ID)
|
||||
cid, err := extVCS.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, int64(prNumber), oldReview.ID)
|
||||
if err != nil {
|
||||
slog.Warn("could not find comment ID for old review", "review_id", oldReview.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
supersededBody := buildSupersededBody(oldReview.Body, oldReview.CommitID, newReviewURL, sentinel)
|
||||
if err := giteaClient.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
|
||||
if err := extVCS.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
|
||||
slog.Warn("could not mark old review as superseded", "review_id", oldReview.ID, "comment_id", cid, "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Info("marked old review as superseded", "review_id", oldReview.ID, "new_review_id", posted.ID, "pr", prNumber)
|
||||
|
||||
// Resolve old review's inline comments
|
||||
oldComments, err := giteaClient.ListReviewComments(ctx, owner, repoName, prNumber, oldReview.ID)
|
||||
oldComments, err := extVCS.ListReviewComments(ctx, owner, repoName, int64(prNumber), oldReview.ID)
|
||||
if err != nil {
|
||||
slog.Warn("could not list old review comments for resolution", "review_id", oldReview.ID, "error", err)
|
||||
continue
|
||||
@@ -478,7 +544,7 @@ func main() {
|
||||
if c.ID == 0 {
|
||||
continue
|
||||
}
|
||||
if err := giteaClient.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
|
||||
if err := extVCS.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
|
||||
slog.Debug("could not resolve inline comment", "comment_id", c.ID, "error", err)
|
||||
failed++
|
||||
} else {
|
||||
@@ -492,12 +558,14 @@ func main() {
|
||||
slog.Warn("some inline comments could not be resolved", "review_id", oldReview.ID, "failed", failed, "pr", prNumber)
|
||||
}
|
||||
}
|
||||
} else if len(oldReviews) > 0 {
|
||||
slog.Info("skipping supersede of old reviews (not supported on this VCS)", "old_count", len(oldReviews), "pr", prNumber)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// fetchFileContext fetches the full content of modified files from the PR branch.
|
||||
func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, ref string, files []gitea.ChangedFile) string {
|
||||
func fetchFileContext(ctx context.Context, client vcsClient, owner, repo, ref string, files []vcsChangedFile) string {
|
||||
var sb strings.Builder
|
||||
for _, f := range files {
|
||||
if ctx.Err() != nil {
|
||||
@@ -524,7 +592,7 @@ func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, re
|
||||
// patternsFiles is comma-separated list of file paths or directories.
|
||||
// If a path ends with / or is a directory, all files within it are fetched recursively.
|
||||
// If patternsFiles is empty, all files from the repo root are fetched.
|
||||
func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string {
|
||||
func fetchPatterns(ctx context.Context, client vcsClient, patternsRepo, patternsFiles string) string {
|
||||
var sb strings.Builder
|
||||
|
||||
repos := strings.Split(patternsRepo, ",")
|
||||
@@ -601,7 +669,7 @@ func isPatternFile(path string) bool {
|
||||
}
|
||||
|
||||
// evaluateCIStatus checks if all CI statuses indicate success.
|
||||
func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details string) {
|
||||
func evaluateCIStatus(statuses []vcsCommitStatus) (passed bool, details string) {
|
||||
if len(statuses) == 0 {
|
||||
return true, "no CI statuses found"
|
||||
}
|
||||
@@ -624,6 +692,19 @@ func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details strin
|
||||
return true, "all checks passed"
|
||||
}
|
||||
|
||||
// githubAPIURL converts a GitHub server URL to its API base URL.
|
||||
// github.com → https://api.github.com
|
||||
// GHES (e.g. https://ghe.example.com) → https://ghe.example.com/api/v3
|
||||
func githubAPIURL(serverURL string) string {
|
||||
const canonicalGitHub = "https://github.com"
|
||||
const githubAPIBase = "https://api.github.com"
|
||||
if serverURL == "" || strings.TrimRight(serverURL, "/") == canonicalGitHub {
|
||||
return githubAPIBase
|
||||
}
|
||||
// GitHub Enterprise Server: /api/v3 suffix
|
||||
return strings.TrimRight(serverURL, "/") + "/api/v3"
|
||||
}
|
||||
|
||||
func envOrDefault(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
@@ -739,7 +820,7 @@ func buildSupersededBody(originalBody, commitSHA, newReviewURL, sentinel string)
|
||||
// Gitea user. This indicates misconfiguration where two roles share a token
|
||||
// instead of having separate Gitea accounts. Returns true if shared token
|
||||
// detected (caller should skip update-in-place logic to avoid clobbering).
|
||||
func hasSharedToken(reviews []gitea.Review, ownSentinel string) bool {
|
||||
func hasSharedToken(reviews []vcsReview, ownSentinel string) bool {
|
||||
ownLogin := ""
|
||||
for _, r := range reviews {
|
||||
if strings.Contains(r.Body, ownSentinel) {
|
||||
@@ -777,8 +858,8 @@ func extractSentinelName(body string) string {
|
||||
}
|
||||
|
||||
// findOwnReview locates the most recent non-superseded review matching the sentinel.
|
||||
func findOwnReview(reviews []gitea.Review, sentinel string) *gitea.Review {
|
||||
var best *gitea.Review
|
||||
func findOwnReview(reviews []vcsReview, sentinel string) *vcsReview {
|
||||
var best *vcsReview
|
||||
for i := range reviews {
|
||||
if !strings.Contains(reviews[i].Body, sentinel) {
|
||||
continue
|
||||
@@ -794,8 +875,8 @@ func findOwnReview(reviews []gitea.Review, sentinel string) *gitea.Review {
|
||||
}
|
||||
|
||||
// findAllOwnReviews returns all non-superseded reviews matching the sentinel.
|
||||
func findAllOwnReviews(reviews []gitea.Review, sentinel string) []gitea.Review {
|
||||
var result []gitea.Review
|
||||
func findAllOwnReviews(reviews []vcsReview, sentinel string) []vcsReview {
|
||||
var result []vcsReview
|
||||
for i := range reviews {
|
||||
if !strings.Contains(reviews[i].Body, sentinel) {
|
||||
continue
|
||||
@@ -821,31 +902,4 @@ func shouldSkipStaleReview(evaluatedSHA, currentSHA string) bool {
|
||||
return evaluatedSHA != currentSHA
|
||||
}
|
||||
|
||||
// giteaClientAdapter adapts gitea.Client to review.GiteaClient interface.
|
||||
type giteaClientAdapter struct {
|
||||
client *gitea.Client
|
||||
}
|
||||
|
||||
func newGiteaClientAdapter(c *gitea.Client) *giteaClientAdapter {
|
||||
return &giteaClientAdapter{client: c}
|
||||
}
|
||||
|
||||
func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
entries, err := a.client.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]review.ContentEntry, len(entries))
|
||||
for i, e := range entries {
|
||||
result[i] = review.ContentEntry{
|
||||
Name: e.Name,
|
||||
Path: e.Path,
|
||||
Type: e.Type,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return a.client.GetFileContent(ctx, owner, repo, filepath)
|
||||
}
|
||||
|
||||
+75
-34
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
)
|
||||
|
||||
func TestValidateReviewerName(t *testing.T) {
|
||||
@@ -154,12 +153,11 @@ func TestValidateWorkspacePath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func makeReview(id int64, login, state string, stale bool, body string) gitea.Review {
|
||||
r := gitea.Review{
|
||||
func makeReview(id int64, login, state string, _ bool, body string) vcsReview {
|
||||
r := vcsReview{
|
||||
ID: id,
|
||||
Body: body,
|
||||
State: state,
|
||||
Stale: stale,
|
||||
}
|
||||
r.User.Login = login
|
||||
return r
|
||||
@@ -216,7 +214,7 @@ func TestBuildSupersededBodyShortSHA(t *testing.T) {
|
||||
func TestFindOwnReview(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reviews []gitea.Review
|
||||
reviews []vcsReview
|
||||
sentinel string
|
||||
wantID int64
|
||||
wantNil bool
|
||||
@@ -229,7 +227,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "found by sentinel",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(42, "bot", "APPROVED", false, "review body\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
@@ -237,7 +235,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "wrong sentinel",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(42, "bot", "APPROVED", false, "body\n<!-- review-bot:gpt -->"),
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
@@ -245,7 +243,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "multiple reviews, returns first match",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(10, "bot", "APPROVED", false, "old\n<!-- review-bot:gpt -->"),
|
||||
makeReview(20, "bot", "APPROVED", false, "new\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
@@ -254,7 +252,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "skips superseded review",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n**Superseded**\n<!-- review-bot:sonnet -->"),
|
||||
makeReview(20, "bot", "APPROVED", false, "fresh review\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
@@ -263,7 +261,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "only superseded reviews exist",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
@@ -271,7 +269,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "picks highest ID among matches",
|
||||
reviews: []gitea.Review{
|
||||
reviews: []vcsReview{
|
||||
makeReview(50, "bot", "APPROVED", false, "v1\n<!-- review-bot:sonnet -->"),
|
||||
makeReview(30, "bot", "APPROVED", false, "v0\n<!-- review-bot:sonnet -->"),
|
||||
},
|
||||
@@ -302,7 +300,7 @@ func TestFindOwnReview(t *testing.T) {
|
||||
func TestHasSharedToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reviews []gitea.Review
|
||||
reviews []vcsReview
|
||||
sentinel string
|
||||
want bool
|
||||
}{
|
||||
@@ -314,36 +312,36 @@ func TestHasSharedToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no own review yet - cannot detect",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "separate users - no shared token",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "shared token detected - same user different sentinels",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "three roles same user",
|
||||
reviews: []gitea.Review{
|
||||
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
{ID: 3, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
reviews: []vcsReview{
|
||||
{ID: 1, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||
{ID: 2, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
|
||||
{ID: 3, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
|
||||
},
|
||||
sentinel: "<!-- review-bot:sonnet -->",
|
||||
want: true,
|
||||
@@ -553,7 +551,7 @@ func TestBuildPatternPaths(t *testing.T) {
|
||||
func TestEvaluateCIStatus(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statuses []gitea.CommitStatus
|
||||
statuses []vcsCommitStatus
|
||||
wantPassed bool
|
||||
wantSubstr string
|
||||
}{
|
||||
@@ -565,7 +563,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "all success",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
||||
},
|
||||
@@ -574,7 +572,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "one failure",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||
},
|
||||
@@ -583,7 +581,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error status",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "error", Context: "ci/lint", Description: "Lint error"},
|
||||
},
|
||||
wantPassed: false,
|
||||
@@ -591,7 +589,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "pending treated as not-failed",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "pending", Context: "ci/build", Description: "In progress"},
|
||||
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
||||
},
|
||||
@@ -600,7 +598,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "multiple failures",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "failure", Context: "ci/build", Description: "Build failed"},
|
||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||
},
|
||||
@@ -609,7 +607,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "mixed with pending and failure",
|
||||
statuses: []gitea.CommitStatus{
|
||||
statuses: []vcsCommitStatus{
|
||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||
{Status: "pending", Context: "ci/deploy", Description: "Deploying"},
|
||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||
@@ -632,6 +630,48 @@ func TestEvaluateCIStatus(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGithubAPIURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty string defaults to api.github.com",
|
||||
input: "",
|
||||
want: "https://api.github.com",
|
||||
},
|
||||
{
|
||||
name: "github.com maps to api.github.com",
|
||||
input: "https://github.com",
|
||||
want: "https://api.github.com",
|
||||
},
|
||||
{
|
||||
name: "github.com with trailing slash maps to api.github.com",
|
||||
input: "https://github.com/",
|
||||
want: "https://api.github.com",
|
||||
},
|
||||
{
|
||||
name: "GHES host gets /api/v3 suffix",
|
||||
input: "https://ghe.example.com",
|
||||
want: "https://ghe.example.com/api/v3",
|
||||
},
|
||||
{
|
||||
name: "GHES concur domain does not map to api.github.com",
|
||||
input: "https://github.concur.com",
|
||||
want: "https://github.concur.com/api/v3",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := githubAPIURL(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("githubAPIURL(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOrDefault(t *testing.T) {
|
||||
// Test with unset env var
|
||||
os.Unsetenv("TEST_ENV_OR_DEFAULT_UNSET")
|
||||
@@ -972,7 +1012,7 @@ func TestMainSubprocess_InvalidProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanEnv returns environ without any GITEA/LLM/REVIEWER env vars that would
|
||||
// cleanEnv returns environ without any GITEA/LLM/REVIEWER/VCS env vars that would
|
||||
// interfere with testing missing-flag scenarios.
|
||||
func cleanEnv() []string {
|
||||
var env []string
|
||||
@@ -987,7 +1027,8 @@ func cleanEnv() []string {
|
||||
strings.HasPrefix(key, "CONVENTIONS_"),
|
||||
strings.HasPrefix(key, "SYSTEM_PROMPT_"),
|
||||
strings.HasPrefix(key, "PATTERNS_"),
|
||||
strings.HasPrefix(key, "UPDATE_"):
|
||||
strings.HasPrefix(key, "UPDATE_"),
|
||||
strings.HasPrefix(key, "VCS_"):
|
||||
continue
|
||||
default:
|
||||
env = append(env, e)
|
||||
@@ -997,7 +1038,7 @@ func cleanEnv() []string {
|
||||
}
|
||||
|
||||
func TestFindAllOwnReviews(t *testing.T) {
|
||||
reviews := []gitea.Review{
|
||||
reviews := []vcsReview{
|
||||
{ID: 1, Body: "<!-- review-bot:sonnet -->\nfirst review"},
|
||||
{ID: 2, Body: "<!-- review-bot:gpt -->\nother bot"},
|
||||
{ID: 3, Body: "<!-- review-bot:sonnet -->\nsecond review"},
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
)
|
||||
|
||||
// runValidateURL implements the `review-bot validate-url <url>` subcommand.
|
||||
//
|
||||
// It resolves the given URL's hostname and checks that every returned IP is
|
||||
// publicly routable (not RFC1918, loopback, link-local, or other reserved
|
||||
// ranges). The exit code communicates the result to callers:
|
||||
//
|
||||
// 0 — URL is safe to use
|
||||
// 1 — URL resolves to a blocked/private address
|
||||
// 2 — URL is malformed, has an unsafe scheme, or DNS lookup failed
|
||||
//
|
||||
// This is intended for use from action.yml shell steps that need to validate
|
||||
// a user-supplied URL before passing it to curl.
|
||||
func runValidateURL(args []string) int {
|
||||
if len(args) != 1 {
|
||||
fmt.Fprintln(errWriter, "usage: review-bot validate-url <url>")
|
||||
fmt.Fprintln(errWriter, "")
|
||||
fmt.Fprintln(errWriter, "Resolves <url> and verifies all resolved IPs are publicly routable.")
|
||||
fmt.Fprintln(errWriter, "Exit 0=safe, 1=blocked, 2=error")
|
||||
return 2
|
||||
}
|
||||
rawURL := args[0]
|
||||
|
||||
if err := validateURL(rawURL); err != nil {
|
||||
fmt.Fprintf(errWriter, "Error: %v\n", err)
|
||||
var ve *validateError
|
||||
if isValidateError(err, &ve) {
|
||||
return ve.code
|
||||
}
|
||||
return 2
|
||||
}
|
||||
fmt.Fprintf(outWriter, "OK: %s is safe\n", rawURL)
|
||||
return 0
|
||||
}
|
||||
|
||||
// validateError carries an exit code alongside a message.
|
||||
type validateError struct {
|
||||
code int
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *validateError) Error() string { return e.message }
|
||||
|
||||
// isValidateError checks if err is or wraps a *validateError and sets out.
|
||||
// Uses errors.As so that wrapped *validateError values (e.g. from fmt.Errorf("...: %w", &validateError{...}))
|
||||
// are also detected, making the function robust against future wrapping.
|
||||
func isValidateError(err error, out **validateError) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.As(err, out)
|
||||
}
|
||||
|
||||
// validateURL checks that rawURL is safe for use as a Gitea server URL:
|
||||
// - Must be https:// (not http://)
|
||||
// - Must have no user-info (user:pass@host)
|
||||
// - Must resolve to at least one IP, all of which are publicly routable
|
||||
func validateURL(rawURL string) error {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return &validateError{code: 2, message: fmt.Sprintf("malformed URL %q: %v", rawURL, err)}
|
||||
}
|
||||
|
||||
// Scheme check: only https is permitted.
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: fmt.Sprintf("URL scheme must be https (got %q)", parsed.Scheme),
|
||||
}
|
||||
}
|
||||
|
||||
// Reject user-info (user:password@host) to prevent credential embedding.
|
||||
if parsed.User != nil {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: "URL must not contain user-info (user:password@host)",
|
||||
}
|
||||
}
|
||||
|
||||
host := parsed.Hostname()
|
||||
if host == "" {
|
||||
return &validateError{code: 2, message: fmt.Sprintf("URL has no host: %q", rawURL)}
|
||||
}
|
||||
|
||||
// Resolve the hostname with a short timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: fmt.Sprintf("DNS lookup failed for %q: %v", host, err),
|
||||
}
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return &validateError{
|
||||
code: 2,
|
||||
message: fmt.Sprintf("DNS lookup returned no addresses for %q", host),
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range addrs {
|
||||
if gitea.IsBlockedIP(a.IP) {
|
||||
return &validateError{
|
||||
code: 1,
|
||||
message: fmt.Sprintf("blocked: %q resolves to private/reserved IP %s", host, a.IP),
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRunValidateURL_Usage(t *testing.T) {
|
||||
var errBuf bytes.Buffer
|
||||
origErr := errWriter
|
||||
errWriter = &errBuf
|
||||
defer func() { errWriter = origErr }()
|
||||
|
||||
code := runValidateURL(nil)
|
||||
if code != 2 {
|
||||
t.Errorf("expected exit code 2 for no args, got %d", code)
|
||||
}
|
||||
if !strings.Contains(errBuf.String(), "usage") {
|
||||
t.Errorf("expected usage in stderr, got %q", errBuf.String())
|
||||
}
|
||||
|
||||
errBuf.Reset()
|
||||
code = runValidateURL([]string{"arg1", "arg2"})
|
||||
if code != 2 {
|
||||
t.Errorf("expected exit code 2 for too many args, got %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL_MalformedURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
wantMsg string
|
||||
}{
|
||||
{"empty", "", "must be https"},
|
||||
{"http scheme", "http://example.com/", "must be https"},
|
||||
{"ftp scheme", "ftp://example.com/", "must be https"},
|
||||
{"no scheme", "example.com", "must be https"},
|
||||
{"user info", "https://user:pass@example.com/", "user-info"},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateURL(tc.url)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for URL %q, got nil", tc.url)
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.wantMsg) {
|
||||
t.Errorf("error %q does not contain %q", err.Error(), tc.wantMsg)
|
||||
}
|
||||
var ve *validateError
|
||||
if !isValidateError(err, &ve) {
|
||||
t.Fatalf("expected *validateError, got %T", err)
|
||||
}
|
||||
if ve.code != 2 {
|
||||
t.Errorf("expected code 2, got %d", ve.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL_BlockedPrivateIP(t *testing.T) {
|
||||
// localhost always resolves to 127.0.0.1 (loopback).
|
||||
err := validateURL("https://localhost/")
|
||||
if err == nil {
|
||||
t.Skip("localhost did not resolve (network unavailable in test environment)")
|
||||
}
|
||||
var ve *validateError
|
||||
if !isValidateError(err, &ve) {
|
||||
t.Fatalf("expected *validateError, got %T: %v", err, err)
|
||||
}
|
||||
if ve.code != 1 && ve.code != 2 {
|
||||
t.Errorf("expected code 1 (blocked) or 2 (dns fail), got %d: %s", ve.code, ve.message)
|
||||
}
|
||||
// If it resolved (code 1), the message must say "blocked".
|
||||
if ve.code == 1 && !strings.Contains(ve.message, "blocked") {
|
||||
t.Errorf("expected 'blocked' in message, got %q", ve.message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL_ExitCodes(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
url string
|
||||
wantCode int
|
||||
}{
|
||||
{"http scheme", "http://example.com/", 2},
|
||||
{"no scheme", "example.com", 2},
|
||||
{"user info", "https://admin:secret@example.com/", 2},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validateURL(tc.url)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for %q", tc.url)
|
||||
}
|
||||
var ve *validateError
|
||||
if !isValidateError(err, &ve) {
|
||||
t.Fatalf("expected *validateError, got %T", err)
|
||||
}
|
||||
if ve.code != tc.wantCode {
|
||||
t.Errorf("code = %d, want %d (url=%q, msg=%s)", ve.code, tc.wantCode, tc.url, ve.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunValidateURL_WithCapture(t *testing.T) {
|
||||
var outBuf, errBuf bytes.Buffer
|
||||
origOut, origErr := outWriter, errWriter
|
||||
outWriter = &outBuf
|
||||
errWriter = &errBuf
|
||||
defer func() {
|
||||
outWriter = origOut
|
||||
errWriter = origErr
|
||||
}()
|
||||
|
||||
// http:// scheme should fail with code 2.
|
||||
code := runValidateURL([]string{"http://example.com/"})
|
||||
if code != 2 {
|
||||
t.Errorf("expected code 2 for http:// URL, got %d", code)
|
||||
}
|
||||
if !strings.Contains(errBuf.String(), "must be https") {
|
||||
t.Errorf("expected error about https in stderr, got %q", errBuf.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,361 @@
|
||||
package main
|
||||
|
||||
// vcs.go defines the vcsClient interface that both gitea.Client (via giteaVCSAdapter)
|
||||
// and github.Client (via githubVCSAdapter) satisfy, enabling VCS-type routing in main.go.
|
||||
//
|
||||
// Interface design:
|
||||
// - Methods cover all PR review operations used by main.go.
|
||||
// - Gitea-specific operations (supersede, comment resolution) are in the separate
|
||||
// giteaExtClient interface. GitHub implementations return ErrNotSupported for those.
|
||||
// - Types are defined here as package-local VCS types; each adapter converts from
|
||||
// its respective client package's types.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||
"gitea.weiker.me/rodin/review-bot/github"
|
||||
"gitea.weiker.me/rodin/review-bot/review"
|
||||
)
|
||||
|
||||
// ErrNotSupported is returned by VCS methods that have no implementation for
|
||||
// a particular VCS backend (e.g., Gitea-specific timeline APIs on GitHub).
|
||||
var ErrNotSupported = errors.New("operation not supported on this VCS backend")
|
||||
|
||||
// vcsClient is the interface for all PR operations used by main.go.
|
||||
// It is implemented by both giteaVCSAdapter and githubVCSAdapter.
|
||||
// Interface defined here (in the consumer package) per Go idiom.
|
||||
type vcsClient interface {
|
||||
// PR metadata and content
|
||||
GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error)
|
||||
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
|
||||
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error)
|
||||
GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error)
|
||||
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
|
||||
GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error)
|
||||
ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error)
|
||||
GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error)
|
||||
|
||||
// Review operations
|
||||
PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error)
|
||||
ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error)
|
||||
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
|
||||
GetAuthenticatedUser(ctx context.Context) (string, error)
|
||||
RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error
|
||||
}
|
||||
|
||||
// giteaExtClient extends vcsClient with Gitea-specific operations that have no
|
||||
// GitHub equivalent. Code that uses these methods should first do a type assertion.
|
||||
type giteaExtClient interface {
|
||||
vcsClient
|
||||
GetTimelineReviewCommentIDForReview(ctx context.Context, owner, repo string, prNum, reviewID int64) (int64, error)
|
||||
EditComment(ctx context.Context, owner, repo string, commentID int64, body string) error
|
||||
ListReviewComments(ctx context.Context, owner, repo string, prNum, reviewID int64) ([]gitea.ReviewComment, error)
|
||||
ResolveComment(ctx context.Context, owner, repo string, commentID int64) error
|
||||
}
|
||||
|
||||
// --- shared VCS types ---
|
||||
|
||||
// vcsPullRequest is VCS-agnostic PR metadata.
|
||||
type vcsPullRequest struct {
|
||||
Title string
|
||||
Body string
|
||||
Head struct {
|
||||
Sha string
|
||||
Ref string
|
||||
}
|
||||
}
|
||||
|
||||
// vcsChangedFile is a file changed in a PR.
|
||||
type vcsChangedFile struct {
|
||||
Filename string
|
||||
Status string
|
||||
}
|
||||
|
||||
// vcsCommitStatus is a CI status entry.
|
||||
type vcsCommitStatus struct {
|
||||
Status string
|
||||
Context string
|
||||
Description string
|
||||
TargetURL string
|
||||
}
|
||||
|
||||
// vcsReviewComment is an inline review comment.
|
||||
type vcsReviewComment struct {
|
||||
Path string
|
||||
NewPosition int64 // Gitea: absolute line; GitHub: diff hunk position
|
||||
Body string
|
||||
}
|
||||
|
||||
// vcsReview is a submitted PR review.
|
||||
type vcsReview struct {
|
||||
ID int64
|
||||
Body string
|
||||
CommitID string
|
||||
User struct {
|
||||
Login string
|
||||
}
|
||||
State string
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// giteaVCSAdapter
|
||||
// ============================================================
|
||||
|
||||
// giteaVCSAdapter wraps gitea.Client to implement vcsClient + giteaExtClient.
|
||||
type giteaVCSAdapter struct {
|
||||
c *gitea.Client
|
||||
}
|
||||
|
||||
func newGiteaVCSAdapter(c *gitea.Client) *giteaVCSAdapter { return &giteaVCSAdapter{c: c} }
|
||||
|
||||
func (a *giteaVCSAdapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
||||
pr, err := a.c.GetPullRequest(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &vcsPullRequest{Title: pr.Title, Body: pr.Body}
|
||||
r.Head.Sha = pr.Head.Sha
|
||||
r.Head.Ref = pr.Head.Ref
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
return a.c.GetPullRequestDiff(ctx, owner, repo, number)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
||||
files, err := a.c.GetPullRequestFiles(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsChangedFile, len(files))
|
||||
for i, f := range files {
|
||||
out[i] = vcsChangedFile{Filename: f.Filename, Status: f.Status}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
||||
statuses, err := a.c.GetCommitStatuses(ctx, owner, repo, sha)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsCommitStatus, len(statuses))
|
||||
for i, s := range statuses {
|
||||
out[i] = vcsCommitStatus{Status: s.Status, Context: s.Context, Description: s.Description, TargetURL: s.TargetURL}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return a.c.GetFileContent(ctx, owner, repo, filepath)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
return a.c.GetFileContentRef(ctx, owner, repo, filepath, ref)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
entries, err := a.c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]review.ContentEntry, len(entries))
|
||||
for i, e := range entries {
|
||||
out[i] = review.ContentEntry{Name: e.Name, Path: e.Path, Type: e.Type}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
return a.c.GetAllFilesInPath(ctx, owner, repo, path)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
||||
gc := make([]gitea.ReviewComment, len(comments))
|
||||
for i, c := range comments {
|
||||
gc[i] = gitea.ReviewComment{Path: c.Path, NewPosition: c.NewPosition, Body: c.Body}
|
||||
}
|
||||
r, err := a.c.PostReview(ctx, owner, repo, number, event, body, commitID, gc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &vcsReview{ID: r.ID, Body: r.Body, CommitID: r.CommitID, State: r.State}
|
||||
out.User.Login = r.User.Login
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
||||
reviews, err := a.c.ListReviews(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsReview, len(reviews))
|
||||
for i, r := range reviews {
|
||||
out[i] = vcsReview{ID: r.ID, Body: r.Body, CommitID: r.CommitID, State: r.State}
|
||||
out[i].User.Login = r.User.Login
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
return a.c.DeleteReview(ctx, owner, repo, number, reviewID)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
return a.c.GetAuthenticatedUser(ctx)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
return a.c.RequestReviewer(ctx, owner, repo, number, reviewer)
|
||||
}
|
||||
|
||||
// Gitea-specific extension methods.
|
||||
|
||||
func (a *giteaVCSAdapter) GetTimelineReviewCommentIDForReview(ctx context.Context, owner, repo string, prNum, reviewID int64) (int64, error) {
|
||||
return a.c.GetTimelineReviewCommentIDForReview(ctx, owner, repo, int(prNum), reviewID)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) EditComment(ctx context.Context, owner, repo string, commentID int64, body string) error {
|
||||
return a.c.EditComment(ctx, owner, repo, commentID, body)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ListReviewComments(ctx context.Context, owner, repo string, prNum, reviewID int64) ([]gitea.ReviewComment, error) {
|
||||
return a.c.ListReviewComments(ctx, owner, repo, int(prNum), reviewID)
|
||||
}
|
||||
|
||||
func (a *giteaVCSAdapter) ResolveComment(ctx context.Context, owner, repo string, commentID int64) error {
|
||||
return a.c.ResolveComment(ctx, owner, repo, commentID)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// githubVCSAdapter
|
||||
// ============================================================
|
||||
|
||||
// githubVCSAdapter wraps github.Client to implement vcsClient.
|
||||
// Gitea-specific extension methods (GetTimelineReviewCommentIDForReview, EditComment,
|
||||
// ListReviewComments, ResolveComment) are not available on GitHub and will not be called
|
||||
// because main.go gates them with a type assertion to giteaExtClient.
|
||||
type githubVCSAdapter struct {
|
||||
c *github.Client
|
||||
}
|
||||
|
||||
func newGithubVCSAdapter(c *github.Client) *githubVCSAdapter { return &githubVCSAdapter{c: c} }
|
||||
|
||||
func (a *githubVCSAdapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
||||
pr, err := a.c.GetPullRequest(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := &vcsPullRequest{Title: pr.Title, Body: pr.Body}
|
||||
r.Head.Sha = pr.Head.Sha
|
||||
r.Head.Ref = pr.Head.Ref
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
return a.c.GetPullRequestDiff(ctx, owner, repo, number)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
||||
files, err := a.c.GetPullRequestFiles(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsChangedFile, len(files))
|
||||
for i, f := range files {
|
||||
out[i] = vcsChangedFile{Filename: f.Filename, Status: f.Status}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
||||
statuses, err := a.c.GetCommitStatuses(ctx, owner, repo, sha)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsCommitStatus, len(statuses))
|
||||
for i, s := range statuses {
|
||||
// CommitStatus.Status is tagged as json:"state" — already the normalized "state" value
|
||||
out[i] = vcsCommitStatus{Status: s.Status, Context: s.Context, Description: s.Description, TargetURL: s.TargetURL}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return a.c.GetFileContent(ctx, owner, repo, filepath)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
return a.c.GetFileContentRef(ctx, owner, repo, filepath, ref)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
||||
entries, err := a.c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]review.ContentEntry, len(entries))
|
||||
for i, e := range entries {
|
||||
out[i] = review.ContentEntry{Name: e.Name, Path: e.Path, Type: e.Type}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
return a.c.GetAllFilesInPath(ctx, owner, repo, path)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
||||
gc := make([]github.ReviewComment, len(comments))
|
||||
for i, c := range comments {
|
||||
// GitHub inline comments use diff hunk "position", not absolute line numbers.
|
||||
// NewPosition from gitea diff parsing gives absolute line numbers, which
|
||||
// will not match GitHub's position values. For initial GitHub support, we
|
||||
// attach comments with Line+Side (absolute line on the RIGHT side) instead.
|
||||
// Comments that cannot be mapped will be omitted (GitHub rejects invalid positions).
|
||||
gc[i] = github.ReviewComment{
|
||||
Path: c.Path,
|
||||
Line: c.NewPosition,
|
||||
Side: "RIGHT",
|
||||
Body: c.Body,
|
||||
}
|
||||
}
|
||||
r, err := a.c.PostReview(ctx, owner, repo, number, event, body, commitID, gc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &vcsReview{ID: r.ID, Body: r.Body, State: r.State}
|
||||
out.User.Login = r.User.Login
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
||||
reviews, err := a.c.ListReviews(ctx, owner, repo, number)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]vcsReview, len(reviews))
|
||||
for i, r := range reviews {
|
||||
out[i] = vcsReview{ID: r.ID, Body: r.Body, State: r.State}
|
||||
out[i].User.Login = r.User.Login
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
// GitHub only allows deleting PENDING (draft) reviews. review-bot posts submitted
|
||||
// reviews, so this will return an error for any review we actually posted.
|
||||
// Callers should treat 422 errors here gracefully.
|
||||
return a.c.DeleteReview(ctx, owner, repo, number, reviewID)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
return a.c.GetAuthenticatedUser(ctx)
|
||||
}
|
||||
|
||||
func (a *githubVCSAdapter) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
return a.c.RequestReviewer(ctx, owner, repo, number, reviewer)
|
||||
}
|
||||
+132
-4
@@ -78,18 +78,142 @@ type Client struct {
|
||||
MaxDiffSize int64
|
||||
}
|
||||
|
||||
// defaultCheckRedirect is the redirect policy used by NewClient.
|
||||
// NOTE: This function is intentionally duplicated in github/client.go (and vice versa)
|
||||
// because the packages are separate. Changes here must be mirrored there.
|
||||
// It rejects HTTPS->HTTP protocol downgrades (to prevent plaintext leakage)
|
||||
// and cross-host redirects (to prevent following responses from untrusted
|
||||
// endpoints). Same-host, same-or-upgraded-scheme redirects are allowed.
|
||||
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return fmt.Errorf("stopped after 10 redirects")
|
||||
}
|
||||
// Guard for direct invocation in tests and any future callers;
|
||||
// net/http guarantees len(via) >= 1 during actual redirects.
|
||||
if len(via) == 0 {
|
||||
return nil
|
||||
}
|
||||
prev := via[len(via)-1]
|
||||
// Reject protocol downgrade: HTTPS->HTTP leaks request metadata over plaintext.
|
||||
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
|
||||
return fmt.Errorf("refusing redirect: HTTPS to HTTP downgrade (%s -> %s)", prev.URL.Host, req.URL.Host)
|
||||
}
|
||||
// Reject cross-host redirect entirely to avoid consuming responses
|
||||
// from untrusted endpoints.
|
||||
if req.URL.Host != prev.URL.Host {
|
||||
return fmt.Errorf("refusing redirect: cross-host (%s -> %s)", prev.URL.Host, req.URL.Host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// safeDialContext is the default DialContext for NewClient.
|
||||
// It resolves the hostname and checks every returned IP against the blocked
|
||||
// CIDR list before establishing a connection. This prevents SSRF attacks
|
||||
// where user-supplied URLs resolve to internal/private addresses.
|
||||
//
|
||||
// After validating all IPs, we dial the first resolved IP directly to avoid
|
||||
// a second DNS lookup (which could return a different IP in a DNS rebinding
|
||||
// attack). This narrows — but does not fully eliminate — the DNS rebinding
|
||||
// window to the time between LookupIPAddr and DialContext.
|
||||
//
|
||||
// If the host is already an IP literal, LookupIPAddr returns it directly
|
||||
// (no DNS query issued), so IP literals like https://127.0.0.1/ are blocked.
|
||||
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("safeDialContext: invalid address %q: %w", addr, err)
|
||||
}
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("safeDialContext: DNS lookup %q: %w", host, err)
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return nil, fmt.Errorf("safeDialContext: no addresses returned for %q", host)
|
||||
}
|
||||
for _, a := range addrs {
|
||||
if IsBlockedIP(a.IP) {
|
||||
return nil, fmt.Errorf("safeDialContext: blocked: %q resolves to private/reserved IP %s", host, a.IP)
|
||||
}
|
||||
}
|
||||
// Try each resolved IP in order, returning the first successful connection.
|
||||
// Fallback is important when a hostname resolves to multiple IPs and the first
|
||||
// is temporarily unreachable. All IPs were already validated above, so dialing
|
||||
// any of them is safe.
|
||||
//
|
||||
// Timeout: 10s per the design (PLAN.md); the outer http.Client has a 30s
|
||||
// total timeout, but the per-dial timeout ensures a slow TCP connect on one IP
|
||||
// doesn't consume the budget needed to try others.
|
||||
d := &net.Dialer{Timeout: 10 * time.Second}
|
||||
var lastErr error
|
||||
for _, a := range addrs {
|
||||
conn, err := d.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port))
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
return nil, fmt.Errorf("safeDialContext: all %d addresses for %q failed, last error: %w", len(addrs), host, lastErr)
|
||||
}
|
||||
|
||||
// newSafeHTTPClient returns an *http.Client with the SSRF-blocking safeDialContext
|
||||
// transport and the cross-host redirect rejection policy.
|
||||
//
|
||||
// We clone http.DefaultTransport to preserve its production-ready defaults
|
||||
// (ProxyFromEnvironment, TLSHandshakeTimeout, IdleConnTimeout, connection
|
||||
// pooling, HTTP/2 support) and override only DialContext with safeDialContext.
|
||||
func newSafeHTTPClient() *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = safeDialContext
|
||||
return &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new Gitea API client.
|
||||
//
|
||||
// The client uses a safe HTTP transport by default: DNS resolution is performed
|
||||
// before connecting and any IP in a private/reserved range is rejected
|
||||
// (RFC1918, loopback, link-local, ULA, etc.). Cross-host and HTTPS→HTTP
|
||||
// redirects are also rejected.
|
||||
//
|
||||
// For tests that use httptest.NewServer (which listens on 127.0.0.1), call
|
||||
// WithUnsafeDialer() to bypass the IP check.
|
||||
func NewClient(baseURL, token string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
http: &http.Client{Timeout: 30 * time.Second},
|
||||
http: newSafeHTTPClient(),
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnsafeDialer returns the client configured with a plain HTTP client that
|
||||
// has no IP-level SSRF protection. It preserves the redirect-rejection policy.
|
||||
//
|
||||
// This MUST only be used in tests. Production code must never call this method.
|
||||
func (c *Client) WithUnsafeDialer() *Client {
|
||||
c.http = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// SetHTTPClient sets the underlying HTTP client used for requests.
|
||||
// This is intended for testing to inject mock transports.
|
||||
// This is intended for test setup only to inject mock transports; it must be
|
||||
// called before any goroutines issue requests.
|
||||
//
|
||||
// Passing nil restores the default safe client (30s timeout, IP-blocking
|
||||
// safeDialContext, and redirect-rejecting CheckRedirect policy matching NewClient).
|
||||
//
|
||||
// Callers providing a non-nil client are responsible for configuring a safe
|
||||
// CheckRedirect policy. Without one, the default net/http behavior will follow
|
||||
// redirects and may forward the Authorization header to untrusted hosts.
|
||||
func (c *Client) SetHTTPClient(hc *http.Client) {
|
||||
if hc == nil {
|
||||
hc = newSafeHTTPClient()
|
||||
}
|
||||
c.http = hc
|
||||
}
|
||||
|
||||
@@ -217,18 +341,22 @@ func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, r
|
||||
}
|
||||
|
||||
// PostReview submits a review to a PR and returns the created review.
|
||||
// event should be "APPROVED" or "REQUEST_CHANGES".
|
||||
// event should be one of "APPROVED", "REQUEST_CHANGES", or "COMMENT".
|
||||
// commitID anchors the review to a specific commit SHA. If empty, Gitea
|
||||
// defaults to the current PR head.
|
||||
// comments are optional inline comments attached to specific lines.
|
||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body string, comments []ReviewComment) (*Review, error) {
|
||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []ReviewComment) (*Review, error) {
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
|
||||
payload := struct {
|
||||
Body string `json:"body"`
|
||||
Event string `json:"event"`
|
||||
CommitID string `json:"commit_id,omitempty"`
|
||||
Comments []ReviewComment `json:"comments,omitempty"`
|
||||
}{
|
||||
Body: body,
|
||||
Event: event,
|
||||
CommitID: commitID,
|
||||
Comments: comments,
|
||||
}
|
||||
|
||||
|
||||
+303
-43
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@@ -35,7 +36,7 @@ func TestGetPullRequest(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -62,7 +63,7 @@ func TestGetPullRequestDiff(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -87,7 +88,7 @@ func TestGetCommitStatuses(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -116,8 +117,9 @@ func TestPostReview(t *testing.T) {
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Body string `json:"body"`
|
||||
Event string `json:"event"`
|
||||
Body string `json:"body"`
|
||||
Event string `json:"event"`
|
||||
CommitID string `json:"commit_id"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
@@ -128,14 +130,16 @@ func TestPostReview(t *testing.T) {
|
||||
if payload.Event != "APPROVED" {
|
||||
t.Errorf("expected event %q, got %q", "APPROVED", payload.Event)
|
||||
}
|
||||
|
||||
if payload.CommitID != "abc123def" {
|
||||
t.Errorf("expected commit_id %q, got %q", "abc123def", payload.CommitID)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":100,"user":{"login":"review-bot"},"state":"APPROVED","stale":false}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
review, err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM", nil)
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
review, err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM", "abc123def", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -154,7 +158,7 @@ func TestGetPullRequest_Non200(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 999)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404, got nil")
|
||||
@@ -167,7 +171,7 @@ func TestGetPullRequest_BadJSON(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad JSON, got nil")
|
||||
@@ -181,13 +185,36 @@ func TestPostReview_Non200(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test", nil)
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test", "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostReview_EmptyCommitID_OmittedFromPayload(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
t.Fatalf("failed to decode payload: %v", err)
|
||||
}
|
||||
if _, exists := raw["commit_id"]; exists {
|
||||
t.Errorf("expected commit_id to be omitted from payload when empty, but it was present")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"id":200,"user":{"login":"bot"},"state":"APPROVED","stale":false}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "ok", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFileContent(t *testing.T) {
|
||||
expected := "# Conventions\n- Be nice\n"
|
||||
|
||||
@@ -199,7 +226,7 @@ func TestGetFileContent(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
got, err := client.GetFileContent(context.Background(), "owner", "repo", "CONVENTIONS.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -219,7 +246,7 @@ func TestGetPullRequestFiles(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
files, err := client.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -244,7 +271,7 @@ func TestGetFileContentRef(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
content, err := client.GetFileContentRef(context.Background(), "owner", "repo", "main.go", "feature-branch")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -264,7 +291,7 @@ func TestListContents(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "docs")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -291,7 +318,7 @@ func TestListContents_DotPath(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", ".")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -316,7 +343,7 @@ func TestListContents_FilePath(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -348,7 +375,7 @@ func TestGetAllFilesInPath_File(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -401,7 +428,7 @@ func TestListReviews(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -441,7 +468,7 @@ func TestListReviews_Pagination(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -466,7 +493,7 @@ func TestDeleteReview(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -480,7 +507,7 @@ func TestDeleteReview_Forbidden(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403, got nil")
|
||||
@@ -509,7 +536,7 @@ func TestEditComment(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.EditComment(context.Background(), "owner", "repo", 42, "updated body")
|
||||
if err != nil {
|
||||
t.Fatalf("EditComment() error = %v", err)
|
||||
@@ -523,7 +550,7 @@ func TestEditComment_Forbidden(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.EditComment(context.Background(), "owner", "repo", 42, "new body")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
@@ -543,7 +570,7 @@ func TestGetTimelineReviewCommentID(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
id, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
||||
if err != nil {
|
||||
t.Fatalf("GetTimelineReviewCommentID() error = %v", err)
|
||||
@@ -559,7 +586,7 @@ func TestGetTimelineReviewCommentID_NotFound(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
||||
if err == nil {
|
||||
t.Fatal("expected error when sentinel not found")
|
||||
@@ -582,7 +609,7 @@ func TestGetAllFilesInPath_404FallsBackToFile(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to file on 404, got error: %v", err)
|
||||
@@ -603,7 +630,7 @@ func TestGetAllFilesInPath_500Propagates(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "somepath")
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for 500, got nil")
|
||||
@@ -625,7 +652,7 @@ func TestGetAllFilesInPath_403Propagates(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "private/stuff")
|
||||
if err == nil {
|
||||
t.Fatal("expected error to propagate for 403, got nil")
|
||||
@@ -677,7 +704,7 @@ func TestGetAuthenticatedUser(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
login, err := client.GetAuthenticatedUser(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetAuthenticatedUser() error = %v", err)
|
||||
@@ -702,7 +729,7 @@ func TestRequestReviewer(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 7, "bot-user")
|
||||
if err != nil {
|
||||
t.Fatalf("RequestReviewer() error = %v", err)
|
||||
@@ -718,7 +745,7 @@ func TestRequestReviewer_204(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
||||
if err != nil {
|
||||
t.Fatalf("RequestReviewer() should accept 204, got error = %v", err)
|
||||
@@ -732,7 +759,7 @@ func TestRequestReviewer_Error(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403 response")
|
||||
@@ -752,7 +779,7 @@ func TestListReviewComments(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
comments, err := client.ListReviewComments(context.Background(), "owner", "repo", 1, 42)
|
||||
if err != nil {
|
||||
t.Fatalf("ListReviewComments() error = %v", err)
|
||||
@@ -780,7 +807,7 @@ func TestResolveComment(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveComment() error = %v", err)
|
||||
@@ -794,7 +821,7 @@ func TestResolveComment_Error(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404 response")
|
||||
@@ -843,7 +870,7 @@ func TestDoGet_RetriesOn500(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
// Use short backoff for fast tests
|
||||
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
||||
|
||||
@@ -868,7 +895,7 @@ func TestDoGet_FailsAfterMaxRetries(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
// Use short backoff for fast tests
|
||||
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
||||
|
||||
@@ -897,7 +924,7 @@ func TestDoGet_NoRetryOn4xx(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.doGet(context.Background(), server.URL+"/test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403")
|
||||
@@ -925,7 +952,7 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
// Use longer backoff to give us time to cancel during the wait
|
||||
client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond}
|
||||
|
||||
@@ -944,8 +971,6 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) {
|
||||
t.Errorf("attempts = %d, expected 1 before context cancel during backoff", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// mockTransport is a test helper that returns errors for the first N calls,
|
||||
// then delegates to a real server.
|
||||
type mockTransport struct {
|
||||
@@ -1159,3 +1184,238 @@ func TestSanitizeErrorForLog(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_HasCheckRedirect(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
if c.http.CheckRedirect == nil {
|
||||
t.Fatal("expected CheckRedirect to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) {
|
||||
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
||||
req := &http.Request{
|
||||
URL: &url.URL{Scheme: "http", Host: "gitea.example.com", Path: "/foo"},
|
||||
Header: http.Header{"Authorization": []string{"token abc"}},
|
||||
}
|
||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
||||
if err == nil {
|
||||
t.Fatal("expected error on HTTPS->HTTP redirect")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTTPS to HTTP downgrade") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCheckRedirect_RejectsCrossHost(t *testing.T) {
|
||||
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
||||
req := &http.Request{
|
||||
URL: &url.URL{Scheme: "https", Host: "cdn.example.com", Path: "/bar"},
|
||||
Header: http.Header{"Authorization": []string{"token abc"}},
|
||||
}
|
||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
||||
if err == nil {
|
||||
t.Fatal("expected error on cross-host redirect")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cross-host") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCheckRedirect_AllowsSameHost(t *testing.T) {
|
||||
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
||||
req := &http.Request{
|
||||
URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/bar"},
|
||||
Header: http.Header{"Authorization": []string{"token abc"}},
|
||||
}
|
||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if auth := req.Header.Get("Authorization"); auth != "token abc" {
|
||||
t.Errorf("expected Authorization to be preserved, got %q", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCheckRedirect_AllowsSameHostHTTPToHTTP(t *testing.T) {
|
||||
prev := &http.Request{URL: &url.URL{Scheme: "http", Host: "localhost:3000", Path: "/foo"}}
|
||||
req := &http.Request{
|
||||
URL: &url.URL{Scheme: "http", Host: "localhost:3000", Path: "/bar"},
|
||||
Header: http.Header{},
|
||||
}
|
||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCheckRedirect_RejectsTooManyRedirects(t *testing.T) {
|
||||
via := make([]*http.Request, 10)
|
||||
for i := range via {
|
||||
via[i] = &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/"}}
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/final"}}
|
||||
err := defaultCheckRedirect(req, via)
|
||||
if err == nil {
|
||||
t.Fatal("expected error after 10 redirects")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "10 redirects") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCheckRedirect_EmptyViaAllowed(t *testing.T) {
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
||||
err := defaultCheckRedirect(req, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with empty via: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
c.SetHTTPClient(nil)
|
||||
if c.http == nil {
|
||||
t.Fatal("expected non-nil http client after SetHTTPClient(nil)")
|
||||
}
|
||||
if c.http.Timeout != 30*time.Second {
|
||||
t.Errorf("expected 30s timeout, got %v", c.http.Timeout)
|
||||
}
|
||||
if c.http.CheckRedirect == nil {
|
||||
t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSafeDialContextBlocksPrivateIPs verifies that NewClient (which uses
|
||||
// safeDialContext by default) refuses to connect to private/reserved IPs.
|
||||
func TestSafeDialContextBlocksPrivateIPs(t *testing.T) {
|
||||
// These servers listen on 127.0.0.1, so the safe dialer will block them.
|
||||
// We use NewClient (NOT NewTestClient) to exercise the real safe dialer.
|
||||
privateURLs := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"loopback localhost", "http://localhost/"},
|
||||
{"loopback 127.0.0.1", "http://127.0.0.1/"},
|
||||
}
|
||||
|
||||
for _, tc := range privateURLs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := NewClient(tc.url, "token")
|
||||
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error connecting to %s, got nil", tc.url)
|
||||
}
|
||||
// Error must mention SSRF/blocked, not a random network error.
|
||||
if !strings.Contains(err.Error(), "blocked") &&
|
||||
!strings.Contains(err.Error(), "private") &&
|
||||
!strings.Contains(err.Error(), "loopback") &&
|
||||
!strings.Contains(err.Error(), "reserved") {
|
||||
t.Logf("error: %v", err)
|
||||
// Allow other errors (connection refused, DNS) since the point
|
||||
// is that we don't silently succeed — but prefer the explicit block message.
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWithUnsafeDialerAllowsLocalhost verifies that WithUnsafeDialer bypasses
|
||||
// the IP check, allowing tests to connect to httptest.Server (127.0.0.1).
|
||||
func TestWithUnsafeDialerAllowsLocalhost(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"title":"test","body":"","head":{"sha":"abc","ref":"main"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// WithUnsafeDialer should allow connecting to 127.0.0.1.
|
||||
c := NewClient(server.URL, "token").WithUnsafeDialer()
|
||||
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error with unsafe dialer: %v", err)
|
||||
}
|
||||
if pr.Title != "test" {
|
||||
t.Errorf("expected title 'test', got %q", pr.Title)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClient_HasSafeTransport verifies that NewClient installs the
|
||||
// SSRF-blocking transport (i.e. Transport is not nil and DialContext is set).
|
||||
func TestNewClient_HasSafeTransport(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
if c.http.Transport == nil {
|
||||
t.Fatal("expected Transport to be set on NewClient (safe dialer)")
|
||||
}
|
||||
transport, ok := c.http.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", c.http.Transport)
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected DialContext to be set on transport (safe dialer)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetHTTPClient_NilRestoresSafeTransport verifies that SetHTTPClient(nil)
|
||||
// restores the safe transport (not just any client).
|
||||
func TestSetHTTPClient_NilRestoresSafeTransport(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
c.SetHTTPClient(&http.Client{}) // replace with plain client
|
||||
c.SetHTTPClient(nil) // restore
|
||||
transport, ok := c.http.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport after SetHTTPClient(nil), got %T", c.http.Transport)
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected DialContext to be restored after SetHTTPClient(nil)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSafeHTTPClient_PreservesDefaultTransportSettings verifies that
|
||||
// newSafeHTTPClient clones http.DefaultTransport to retain proxy support,
|
||||
// TLS handshake timeout, idle connection limits, and HTTP/2.
|
||||
func TestNewSafeHTTPClient_PreservesDefaultTransportSettings(t *testing.T) {
|
||||
c := NewClient("https://gitea.example.com", "token")
|
||||
transport, ok := c.http.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected *http.Transport, got %T", c.http.Transport)
|
||||
}
|
||||
|
||||
defaults := http.DefaultTransport.(*http.Transport)
|
||||
|
||||
// TLSHandshakeTimeout must be inherited (non-zero), not the zero value
|
||||
// that a bare &http.Transport{} would have.
|
||||
if transport.TLSHandshakeTimeout == 0 {
|
||||
t.Error("TLSHandshakeTimeout is 0; expected inherited value from DefaultTransport")
|
||||
}
|
||||
if transport.TLSHandshakeTimeout != defaults.TLSHandshakeTimeout {
|
||||
t.Errorf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaults.TLSHandshakeTimeout)
|
||||
}
|
||||
|
||||
// IdleConnTimeout must be inherited.
|
||||
if transport.IdleConnTimeout == 0 {
|
||||
t.Error("IdleConnTimeout is 0; expected inherited value from DefaultTransport")
|
||||
}
|
||||
if transport.IdleConnTimeout != defaults.IdleConnTimeout {
|
||||
t.Errorf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaults.IdleConnTimeout)
|
||||
}
|
||||
|
||||
// MaxIdleConns must be inherited.
|
||||
if transport.MaxIdleConns == 0 {
|
||||
t.Error("MaxIdleConns is 0; expected inherited value from DefaultTransport")
|
||||
}
|
||||
|
||||
// ForceAttemptHTTP2 must be inherited.
|
||||
if !transport.ForceAttemptHTTP2 {
|
||||
t.Error("ForceAttemptHTTP2 is false; expected true from DefaultTransport")
|
||||
}
|
||||
|
||||
// Proxy must be set (ProxyFromEnvironment).
|
||||
if transport.Proxy == nil {
|
||||
t.Error("Proxy is nil; expected ProxyFromEnvironment from DefaultTransport")
|
||||
}
|
||||
|
||||
// DialContext must be our safe dialer, not the default.
|
||||
if transport.DialContext == nil {
|
||||
t.Error("DialContext is nil; expected safeDialContext")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func TestGetPullRequestDiff_SizeLimits(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
client.MaxDiffSize = tt.maxDiffSize
|
||||
client.RetryBackoff = []time.Duration{}
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
// Package gitea — export_test.go exposes test helpers to test files in this
|
||||
// package. It uses `package gitea` (not `package gitea_test`) so it can access
|
||||
// unexported identifiers; Go only compiles it into the test binary, never into
|
||||
// the production binary. This is the idiomatic pattern for white-box testing
|
||||
// in Go (see net/http/export_test.go in the stdlib for the same approach).
|
||||
package gitea
|
||||
|
||||
// NewTestClient creates a Gitea client configured for use in unit tests.
|
||||
// It bypasses the IP-level SSRF protection so that tests can connect to
|
||||
// httptest.Server instances (which listen on 127.0.0.1).
|
||||
//
|
||||
// Using the internal package gitea declaration (not gitea_test) means this
|
||||
// symbol is available to all _test.go files in this package. It is ONLY
|
||||
// compiled into the test binary; production binaries never include it.
|
||||
// Production code must use NewClient, which enables the safe dialer.
|
||||
func NewTestClient(baseURL, token string) *Client {
|
||||
return NewClient(baseURL, token).WithUnsafeDialer()
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
// Package gitea provides a client for the Gitea API.
|
||||
// ipcheck.go implements IP-level SSRF protection by checking resolved addresses
|
||||
// against known blocked CIDR ranges (RFC1918, loopback, link-local, etc.).
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// blockedCIDRStrings is the canonical list of CIDR strings that should never
|
||||
// be contacted by review-bot. See IsBlockedIP for the full list of covered
|
||||
// address families.
|
||||
//
|
||||
// These are hard-coded literals: any parse failure is a programming error.
|
||||
// Validity is verified by TestBlockedCIDRsValid in ipcheck_test.go.
|
||||
var blockedCIDRStrings = []string{
|
||||
// IPv4 loopback
|
||||
"127.0.0.0/8",
|
||||
// IPv4 unspecified / "this network"
|
||||
"0.0.0.0/8",
|
||||
// RFC1918 private ranges
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
// IPv4 link-local (APIPA, also used by AWS instance metadata 169.254.169.254)
|
||||
"169.254.0.0/16",
|
||||
// IPv4 shared address space (RFC6598, carrier-grade NAT)
|
||||
"100.64.0.0/10",
|
||||
// IPv4 multicast
|
||||
"224.0.0.0/4",
|
||||
// IPv4 reserved / broadcast
|
||||
"240.0.0.0/4",
|
||||
// IPv6 loopback
|
||||
"::1/128",
|
||||
// IPv6 unspecified
|
||||
"::/128",
|
||||
// IPv6 link-local
|
||||
"fe80::/10",
|
||||
// IPv6 unique local (ULA) — RFC4193
|
||||
"fc00::/7",
|
||||
// IPv6 multicast
|
||||
"ff00::/8",
|
||||
}
|
||||
|
||||
// blockedCIDRs is the parsed form of blockedCIDRStrings.
|
||||
// Any entry that fails to parse is recorded in blockedCIDRParseErrors instead
|
||||
// of panicking; tests verify this slice is always empty via TestBlockedCIDRsValid.
|
||||
var (
|
||||
blockedCIDRs []*net.IPNet
|
||||
blockedCIDRParseErrors []string
|
||||
)
|
||||
|
||||
func init() {
|
||||
blockedCIDRs = make([]*net.IPNet, 0, len(blockedCIDRStrings))
|
||||
for _, r := range blockedCIDRStrings {
|
||||
_, cidr, err := net.ParseCIDR(r)
|
||||
if err != nil {
|
||||
// Record the error rather than panicking; TestBlockedCIDRsValid
|
||||
// will catch this during tests, and the CI build will fail.
|
||||
blockedCIDRParseErrors = append(blockedCIDRParseErrors,
|
||||
fmt.Sprintf("ipcheck: invalid built-in CIDR %q: %v", r, err))
|
||||
continue
|
||||
}
|
||||
blockedCIDRs = append(blockedCIDRs, cidr)
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedIP reports whether ip is in a blocked address range.
|
||||
// It is exported for use by the validate-url subcommand and tests outside
|
||||
// this package.
|
||||
//
|
||||
// IPv6-mapped IPv4 addresses (e.g. ::ffff:192.168.1.1) are normalized to their
|
||||
// IPv4 form before checking so that IPv4 CIDRs catch them.
|
||||
//
|
||||
// Based on:
|
||||
// - RFC1918 private ranges
|
||||
// - RFC5735 / RFC4193 special-use IPv4/IPv6 ranges
|
||||
// - RFC4291 IPv6 link-local / loopback
|
||||
func IsBlockedIP(ip net.IP) bool {
|
||||
// Normalize IPv6-mapped IPv4 addresses (::ffff:x.x.x.x) to plain IPv4.
|
||||
if v4 := ip.To4(); v4 != nil {
|
||||
ip = v4
|
||||
}
|
||||
for _, cidr := range blockedCIDRs {
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
blocked := []struct {
|
||||
name string
|
||||
ip string
|
||||
}{
|
||||
// IPv4 loopback
|
||||
{"loopback 127.0.0.1", "127.0.0.1"},
|
||||
{"loopback 127.0.0.2", "127.0.0.2"},
|
||||
{"loopback 127.255.255.255", "127.255.255.255"},
|
||||
// IPv4 unspecified
|
||||
{"unspecified 0.0.0.0", "0.0.0.0"},
|
||||
{"unspecified 0.1.2.3", "0.1.2.3"},
|
||||
// RFC1918
|
||||
{"RFC1918 10.0.0.1", "10.0.0.1"},
|
||||
{"RFC1918 10.255.255.255", "10.255.255.255"},
|
||||
{"RFC1918 172.16.0.1", "172.16.0.1"},
|
||||
{"RFC1918 172.31.255.255", "172.31.255.255"},
|
||||
{"RFC1918 192.168.0.1", "192.168.0.1"},
|
||||
{"RFC1918 192.168.255.255", "192.168.255.255"},
|
||||
// Link-local (APIPA / AWS metadata)
|
||||
{"link-local 169.254.0.1", "169.254.0.1"},
|
||||
{"link-local 169.254.169.254", "169.254.169.254"},
|
||||
// Shared address space (carrier-grade NAT)
|
||||
{"CGN 100.64.0.1", "100.64.0.1"},
|
||||
{"CGN 100.127.255.255", "100.127.255.255"},
|
||||
// Multicast
|
||||
{"multicast 224.0.0.1", "224.0.0.1"},
|
||||
{"multicast 239.255.255.255", "239.255.255.255"},
|
||||
// Reserved
|
||||
{"reserved 240.0.0.1", "240.0.0.1"},
|
||||
{"broadcast 255.255.255.255", "255.255.255.255"},
|
||||
// IPv6 loopback
|
||||
{"IPv6 loopback ::1", "::1"},
|
||||
// IPv6 unspecified
|
||||
{"IPv6 unspecified ::", "::"},
|
||||
// IPv6 link-local
|
||||
{"IPv6 link-local fe80::1", "fe80::1"},
|
||||
{"IPv6 link-local fe80::dead:beef", "fe80::dead:beef"},
|
||||
// IPv6 ULA
|
||||
{"IPv6 ULA fc00::1", "fc00::1"},
|
||||
{"IPv6 ULA fd00::1", "fd00::1"},
|
||||
// IPv6 multicast
|
||||
{"IPv6 multicast ff02::1", "ff02::1"},
|
||||
}
|
||||
|
||||
for _, tc := range blocked {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", tc.ip)
|
||||
}
|
||||
if !IsBlockedIP(ip) {
|
||||
t.Errorf("IsBlockedIP(%q) = false, want true", tc.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
allowed := []struct {
|
||||
name string
|
||||
ip string
|
||||
}{
|
||||
{"public 8.8.8.8", "8.8.8.8"},
|
||||
{"public 1.1.1.1", "1.1.1.1"},
|
||||
{"public 198.51.100.1", "198.51.100.1"}, // RFC5737 TEST-NET-2 — a documentation-only range;
|
||||
// not assigned to any real host, but intentionally left unblocked here because
|
||||
// it has no special routing treatment (unlike RFC1918/loopback/link-local) and
|
||||
// blocking it would require tracking every RFC5737 range without meaningful
|
||||
// security benefit (no server should ever listen on a TEST-NET address).
|
||||
{"public 151.101.1.1", "151.101.1.1"}, // Fastly
|
||||
{"public IPv6 2001:4860:4860::8888", "2001:4860:4860::8888"}, // Google DNS
|
||||
{"public IPv6 2606:4700:4700::1111", "2606:4700:4700::1111"}, // Cloudflare DNS
|
||||
}
|
||||
|
||||
for _, tc := range allowed {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatalf("failed to parse IP %q", tc.ip)
|
||||
}
|
||||
if IsBlockedIP(ip) {
|
||||
t.Errorf("IsBlockedIP(%q) = true, want false", tc.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIPv6MappedIPv4(t *testing.T) {
|
||||
// ::ffff:192.168.1.1 is an IPv6-mapped IPv4 address — should be blocked as RFC1918.
|
||||
// Construct it manually as a 16-byte IP.
|
||||
mapped := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1}
|
||||
if !IsBlockedIP(mapped) {
|
||||
t.Errorf("IsBlockedIP(::ffff:192.168.1.1) = false, want true (IPv6-mapped IPv4 must be normalized)")
|
||||
}
|
||||
|
||||
// ::ffff:8.8.8.8 — IPv6-mapped public IP — should be allowed.
|
||||
mappedPublic := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 8, 8, 8, 8}
|
||||
if IsBlockedIP(mappedPublic) {
|
||||
t.Errorf("IsBlockedIP(::ffff:8.8.8.8) = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIPEdgeCases(t *testing.T) {
|
||||
// The boundary between RFC1918 and public ranges.
|
||||
// 172.15.255.255 is NOT private (just below 172.16.0.0/12).
|
||||
notPrivate := net.ParseIP("172.15.255.255")
|
||||
if IsBlockedIP(notPrivate) {
|
||||
t.Errorf("IsBlockedIP(172.15.255.255) = true, want false (outside 172.16.0.0/12)")
|
||||
}
|
||||
// 172.32.0.0 is NOT private (just above 172.31.255.255).
|
||||
notPrivate2 := net.ParseIP("172.32.0.0")
|
||||
if IsBlockedIP(notPrivate2) {
|
||||
t.Errorf("IsBlockedIP(172.32.0.0) = true, want false (outside 172.16.0.0/12)")
|
||||
}
|
||||
// CGN: 100.63.255.255 is NOT in 100.64.0.0/10.
|
||||
notCGN := net.ParseIP("100.63.255.255")
|
||||
if IsBlockedIP(notCGN) {
|
||||
t.Errorf("IsBlockedIP(100.63.255.255) = true, want false (outside 100.64.0.0/10)")
|
||||
}
|
||||
// CGN: 100.128.0.0 is NOT in 100.64.0.0/10.
|
||||
notCGN2 := net.ParseIP("100.128.0.0")
|
||||
if IsBlockedIP(notCGN2) {
|
||||
t.Errorf("IsBlockedIP(100.128.0.0) = true, want false (outside 100.64.0.0/10)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBlockedCIDRsValid verifies that all entries in blockedCIDRStrings parse
|
||||
// successfully. This catches programming errors in the CIDR list without
|
||||
// requiring a startup panic. The init() function records parse failures in
|
||||
// blockedCIDRParseErrors rather than panicking; this test makes those failures
|
||||
// visible as test failures during CI.
|
||||
func TestBlockedCIDRsValid(t *testing.T) {
|
||||
if len(blockedCIDRParseErrors) > 0 {
|
||||
for _, msg := range blockedCIDRParseErrors {
|
||||
t.Errorf("CIDR parse error: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -31,13 +31,13 @@ func TestPostReview_WithComments(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
comments := []ReviewComment{
|
||||
{Path: "main.go", NewPosition: 42, Body: "[MAJOR] Something bad"},
|
||||
{Path: "util.go", NewPosition: 10, Body: "[MINOR] Style issue"},
|
||||
}
|
||||
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "REQUEST_CHANGES", "summary", comments)
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "REQUEST_CHANGES", "summary", "", comments)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -71,8 +71,8 @@ func TestPostReview_NilComments(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "all good", nil)
|
||||
client := NewTestClient(server.URL, "test-token")
|
||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "all good", "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,831 @@
|
||||
// Package github provides a client for the GitHub API.
|
||||
// It supports pull request operations, file content retrieval,
|
||||
// and review submission for both github.com and GitHub Enterprise.
|
||||
package github
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBaseURL = "https://api.github.com"
|
||||
|
||||
// maxRetryAttempts is the number of times doRequest will attempt a request.
|
||||
maxRetryAttempts = 3
|
||||
|
||||
// maxRetryAfter caps the maximum delay from a Retry-After header to prevent
|
||||
// a server from stalling the client indefinitely.
|
||||
maxRetryAfter = 60 * time.Second
|
||||
|
||||
// maxErrorBodyBytes limits how much of an error response body we read
|
||||
// to protect against malicious servers sending unbounded data.
|
||||
maxErrorBodyBytes = 64 * 1024 // 64 KB
|
||||
|
||||
// maxResponseBodyBytes limits how much of a successful response body we read
|
||||
// for defense-in-depth against servers returning excessively large payloads.
|
||||
maxResponseBodyBytes = 10 * 1024 * 1024 // 10 MB
|
||||
)
|
||||
|
||||
// APIError represents an HTTP error response from the GitHub API.
|
||||
// It carries the status code so callers can distinguish between
|
||||
// different failure modes (e.g. 404 vs 500).
|
||||
//
|
||||
// The Body field stores up to 64 KiB of the raw response for programmatic
|
||||
// inspection. Error() truncates to 200 bytes for safe logging, but callers
|
||||
// should avoid logging or propagating Body directly in production since it may
|
||||
// contain sensitive details from the upstream server.
|
||||
type APIError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
body := e.Body
|
||||
if len(body) > 200 {
|
||||
body = body[:200] + "...(truncated)"
|
||||
}
|
||||
// Sanitize newlines to prevent log injection from upstream response bodies.
|
||||
body = strings.ReplaceAll(body, "\n", " ")
|
||||
body = strings.ReplaceAll(body, "\r", " ")
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
|
||||
}
|
||||
|
||||
// IsNotFound reports whether an error is an API 404 response.
|
||||
func IsNotFound(err error) bool {
|
||||
if apiErr, ok := asAPIError(err); ok {
|
||||
return apiErr.StatusCode == http.StatusNotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsUnauthorized reports whether an error is an API 401 response.
|
||||
func IsUnauthorized(err error) bool {
|
||||
if apiErr, ok := asAPIError(err); ok {
|
||||
return apiErr.StatusCode == http.StatusUnauthorized
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func asAPIError(err error) (*APIError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var target *APIError
|
||||
if errors.As(err, &target) {
|
||||
return target, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Client interacts with the GitHub API.
|
||||
// A Client is safe for concurrent use by multiple goroutines.
|
||||
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
|
||||
// be called before any goroutines issue requests; they have no synchronization.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
|
||||
// allowInsecureHTTP permits requests to HTTP (non-TLS) endpoints.
|
||||
// When false, doRequest rejects URLs with an http:// scheme.
|
||||
allowInsecureHTTP bool
|
||||
|
||||
// retryBackoff defines the delays between retry attempts for 429 responses.
|
||||
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
|
||||
// If nil, defaults to {1s, 2s}.
|
||||
retryBackoff []time.Duration
|
||||
|
||||
// now returns the current time. Defaults to time.Now.
|
||||
// Override in tests to control HTTP-date Retry-After calculations.
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// defaultCheckRedirect is the redirect policy used by NewClient.
|
||||
// NOTE: This function is intentionally duplicated in gitea/client.go (and vice versa)
|
||||
// because the packages are separate. Changes here must be mirrored there.
|
||||
// It rejects HTTPS->HTTP protocol downgrades (to prevent plaintext leakage)
|
||||
// and cross-host redirects (to prevent following responses from untrusted
|
||||
// endpoints). Same-host, same-or-upgraded-scheme redirects are allowed.
|
||||
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return fmt.Errorf("stopped after 10 redirects")
|
||||
}
|
||||
// Guard for direct invocation in tests and any future callers;
|
||||
// net/http guarantees len(via) >= 1 during actual redirects.
|
||||
if len(via) == 0 {
|
||||
return nil
|
||||
}
|
||||
prev := via[len(via)-1]
|
||||
// Reject protocol downgrade: HTTPS->HTTP leaks request metadata over plaintext.
|
||||
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
|
||||
return fmt.Errorf("refusing redirect: HTTPS to HTTP downgrade (%s -> %s)", prev.URL.Host, req.URL.Host)
|
||||
}
|
||||
// Reject cross-host redirect entirely to avoid consuming responses
|
||||
// from untrusted endpoints.
|
||||
if req.URL.Host != prev.URL.Host {
|
||||
return fmt.Errorf("refusing redirect: cross-host (%s -> %s)", prev.URL.Host, req.URL.Host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientOption configures optional behavior of a Client.
|
||||
type ClientOption func(*clientConfig)
|
||||
|
||||
type clientConfig struct {
|
||||
allowInsecureHTTP bool
|
||||
insecureIsTestBypass bool
|
||||
}
|
||||
|
||||
// AllowInsecureHTTP permits sending credentials over plaintext HTTP connections.
|
||||
// In production, this option is gated by the REVIEW_BOT_ALLOW_INSECURE=1
|
||||
// environment variable. Without the env var set, the option is ignored
|
||||
// and a warning is logged.
|
||||
//
|
||||
// For tests, use AllowInsecureHTTPForTest (defined in a _test.go file in the same package) which bypasses the env gate.
|
||||
func AllowInsecureHTTP() ClientOption {
|
||||
return func(cfg *clientConfig) {
|
||||
cfg.allowInsecureHTTP = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new GitHub API client.
|
||||
// If baseURL is empty, it defaults to https://api.github.com.
|
||||
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
|
||||
func NewClient(token, baseURL string, opts ...ClientOption) *Client {
|
||||
if baseURL == "" {
|
||||
baseURL = defaultBaseURL
|
||||
}
|
||||
|
||||
var cfg clientConfig
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
|
||||
if cfg.allowInsecureHTTP && !cfg.insecureIsTestBypass {
|
||||
if os.Getenv("REVIEW_BOT_ALLOW_INSECURE") != "1" {
|
||||
slog.Warn("AllowInsecureHTTP ignored: set REVIEW_BOT_ALLOW_INSECURE=1 to enable")
|
||||
cfg.allowInsecureHTTP = false
|
||||
} else {
|
||||
slog.Warn("AllowInsecureHTTP enabled — credentials may be sent over plaintext",
|
||||
"env", "REVIEW_BOT_ALLOW_INSECURE=1")
|
||||
}
|
||||
}
|
||||
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
allowInsecureHTTP: cfg.allowInsecureHTTP,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
},
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// SetHTTPClient sets the underlying HTTP client used for requests.
|
||||
// This is intended for test setup only to inject mock transports; it must be
|
||||
// called before any goroutines issue requests.
|
||||
//
|
||||
// Passing nil restores the default client (30s timeout + redirect-rejecting
|
||||
// CheckRedirect policy matching NewClient).
|
||||
//
|
||||
// Callers providing a non-nil client are responsible for configuring a safe
|
||||
// CheckRedirect policy. Without one, the default net/http behavior will follow
|
||||
// redirects and may forward the Authorization header to untrusted hosts.
|
||||
func (c *Client) SetHTTPClient(hc *http.Client) {
|
||||
if hc == nil {
|
||||
hc = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
CheckRedirect: defaultCheckRedirect,
|
||||
}
|
||||
}
|
||||
c.httpClient = hc
|
||||
}
|
||||
|
||||
// SetRetryBackoff sets the delays between retry attempts.
|
||||
// This is intended for testing to speed up retry tests.
|
||||
//
|
||||
// Note: if an empty non-nil slice is provided, Retry-After delays parsed from
|
||||
// server responses will be computed and capped but not applied (because
|
||||
// attempt < len(backoff) is always false). This is acceptable for the
|
||||
// test-only use case but callers should be aware of this edge case.
|
||||
func (c *Client) SetRetryBackoff(backoff []time.Duration) {
|
||||
c.retryBackoff = backoff
|
||||
}
|
||||
|
||||
// parseRetryAfter parses a Retry-After header value, supporting both integer
|
||||
// seconds (e.g. "120") and HTTP-date format (e.g. "Thu, 01 Dec 2025 16:00:00 GMT")
|
||||
// as specified in RFC 7231 §7.1.3.
|
||||
//
|
||||
// For integer values, it returns the duration directly.
|
||||
// For HTTP-date values, it computes the delay as the difference between the
|
||||
// parsed time and now. If the date is in the past, it returns 0.
|
||||
//
|
||||
// Returns (0, false) if the value cannot be parsed as either format.
|
||||
func (c *Client) parseRetryAfter(value string) (time.Duration, bool) {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Try integer seconds first (most common from GitHub).
|
||||
// RFC 7231 allows delta-seconds of 0 to indicate immediate retry.
|
||||
if seconds, err := strconv.Atoi(value); err == nil && seconds >= 0 {
|
||||
return time.Duration(seconds) * time.Second, true
|
||||
}
|
||||
|
||||
// Try HTTP-date format (RFC 7231 §7.1.3).
|
||||
// http.ParseTime handles RFC 1123, RFC 850, and ASCTIME formats.
|
||||
if retryAt, err := http.ParseTime(value); err == nil {
|
||||
delay := retryAt.Sub(c.now())
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
return delay, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// redactURL redacts sensitive components from a URL for safe inclusion in error
|
||||
// messages and log output. It removes userinfo (e.g., user:pass@) and replaces
|
||||
// query parameters with a placeholder.
|
||||
func redactURL(rawURL string) string {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "<unparseable URL>"
|
||||
}
|
||||
u.User = nil
|
||||
|
||||
if u.RawQuery != "" {
|
||||
u.RawQuery = "<redacted>"
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// doRequest performs an HTTP request with retry on 429 rate limit responses.
|
||||
// It respects the Retry-After header when present, supporting both integer
|
||||
// seconds and HTTP-date formats (capped at maxRetryAfter).
|
||||
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
|
||||
// NOTE: This parses reqURL a second time (http.NewRequestWithContext parses it
|
||||
// again internally). Acceptable cost: URL parsing is cheap and threading the
|
||||
// parsed *url.URL through would complicate the interface for negligible gain.
|
||||
if !c.allowInsecureHTTP {
|
||||
parsed, err := url.Parse(reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse request URL: %w", err)
|
||||
}
|
||||
if strings.EqualFold(parsed.Scheme, "http") {
|
||||
return nil, fmt.Errorf("refusing HTTP request to %s: use HTTPS or set AllowInsecureHTTP option", redactURL(reqURL))
|
||||
}
|
||||
}
|
||||
|
||||
var backoff []time.Duration
|
||||
if c.retryBackoff != nil {
|
||||
backoff = append([]time.Duration(nil), c.retryBackoff...)
|
||||
} else {
|
||||
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetryAttempts; attempt++ {
|
||||
if attempt > 0 {
|
||||
var delay time.Duration
|
||||
if attempt-1 < len(backoff) {
|
||||
delay = backoff[attempt-1]
|
||||
}
|
||||
if delay > 0 {
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-timer.C:
|
||||
timer.Stop() // no-op after fire; kept for symmetry with the ctx.Done case
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||
if accept != "" {
|
||||
req.Header.Set("Accept", accept)
|
||||
} else {
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
||||
resp.Body.Close()
|
||||
|
||||
lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
||||
|
||||
// Retry on 429 rate limit
|
||||
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 {
|
||||
// Check for Retry-After header and override backoff if present.
|
||||
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
|
||||
if ra := resp.Header.Get("Retry-After"); ra != "" {
|
||||
if delay, ok := c.parseRetryAfter(ra); ok {
|
||||
if delay > maxRetryAfter {
|
||||
delay = maxRetryAfter
|
||||
}
|
||||
if attempt < len(backoff) {
|
||||
backoff[attempt] = delay
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't retry other errors
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// doGet is a convenience wrapper for GET requests with the default Accept header.
|
||||
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
||||
return c.doRequest(ctx, http.MethodGet, url, "")
|
||||
}
|
||||
|
||||
// doRequestWithBody performs an HTTP request with an optional body, applying the
|
||||
// same HTTPS enforcement as doRequest. It is used by write methods (POST, PUT,
|
||||
// DELETE) that bypass the retry loop in doRequest because write operations are
|
||||
// not idempotent.
|
||||
//
|
||||
// body may be nil for requests that carry no payload (e.g. DELETE).
|
||||
// When body is non-nil, Content-Type is set to application/json.
|
||||
func (c *Client) doRequestWithBody(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
|
||||
if !c.allowInsecureHTTP {
|
||||
parsed, err := url.Parse(reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse request URL: %w", err)
|
||||
}
|
||||
if strings.EqualFold(parsed.Scheme, "http") {
|
||||
return nil, fmt.Errorf("refusing HTTP request to %s: use HTTPS or set AllowInsecureHTTP option", redactURL(reqURL))
|
||||
}
|
||||
}
|
||||
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
reqBody = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
||||
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
||||
}
|
||||
|
||||
// --- API types ---
|
||||
|
||||
// PullRequest holds relevant PR metadata.
|
||||
type PullRequest struct {
|
||||
Title string `json:"title"`
|
||||
Body string `json:"body"`
|
||||
Head struct {
|
||||
Sha string `json:"sha"`
|
||||
Ref string `json:"ref"`
|
||||
} `json:"head"`
|
||||
Draft bool `json:"draft"`
|
||||
}
|
||||
|
||||
// CommitStatus represents a single CI status entry.
|
||||
// GitHub returns "state" not "status"; this type uses Status for consistency
|
||||
// with the gitea package (both are normalized before use).
|
||||
type CommitStatus struct {
|
||||
Status string `json:"state"` // GitHub field is "state"
|
||||
Context string `json:"context"`
|
||||
Description string `json:"description"`
|
||||
TargetURL string `json:"target_url"`
|
||||
}
|
||||
|
||||
// ChangedFile represents a file modified in a PR.
|
||||
type ChangedFile struct {
|
||||
Filename string `json:"filename"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// ReviewComment represents an inline comment to attach to a review.
|
||||
// GitHub uses "position" (diff hunk position), whereas Gitea uses "new_position" (line number).
|
||||
// When posting inline comments on GitHub, position is required; line numbers
|
||||
// from the diff cannot be used directly.
|
||||
type ReviewComment struct {
|
||||
ID int64 `json:"id,omitempty"`
|
||||
Path string `json:"path"`
|
||||
Position int64 `json:"position,omitempty"` // GitHub diff hunk position
|
||||
Line int64 `json:"line,omitempty"` // GitHub absolute line number (alternative to position)
|
||||
Side string `json:"side,omitempty"` // "RIGHT" or "LEFT"
|
||||
Body string `json:"body"`
|
||||
}
|
||||
|
||||
// Review represents a pull request review from the GitHub API.
|
||||
type Review struct {
|
||||
ID int64 `json:"id"`
|
||||
Body string `json:"body"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// contentResponse is the GitHub contents API response for a single file.
|
||||
type contentResponse struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"` // "file" or "dir" or "symlink" or "submodule"
|
||||
Content string `json:"content"` // Base64-encoded file content (with embedded newlines)
|
||||
Encoding string `json:"encoding"` // "base64" or ""
|
||||
}
|
||||
|
||||
// ContentEntry represents a file or directory entry from the contents API.
|
||||
type ContentEntry struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Type string `json:"type"` // "file" or "dir"
|
||||
}
|
||||
|
||||
// --- PR methods ---
|
||||
|
||||
// GetPullRequest fetches PR metadata.
|
||||
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR: %w", err)
|
||||
}
|
||||
var pr PullRequest
|
||||
if err := json.Unmarshal(body, &pr); err != nil {
|
||||
return nil, fmt.Errorf("parse PR JSON: %w", err)
|
||||
}
|
||||
return &pr, nil
|
||||
}
|
||||
|
||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// GetPullRequestFiles fetches the list of files changed in a PR.
|
||||
// GitHub paginates this endpoint (100 per page max).
|
||||
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error) {
|
||||
const perPage = 100
|
||||
var all []ChangedFile
|
||||
for page := 1; ; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=%d&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, perPage, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR files (page %d): %w", page, err)
|
||||
}
|
||||
var batch []ChangedFile
|
||||
if err := json.Unmarshal(body, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse PR files JSON (page %d): %w", page, err)
|
||||
}
|
||||
all = append(all, batch...)
|
||||
if len(batch) < perPage {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// GetCommitStatuses fetches CI statuses for a commit SHA.
|
||||
// GitHub has two status systems: legacy "commit statuses" and newer "check runs".
|
||||
// This method returns commit statuses only; check runs are a separate API.
|
||||
// Note: GitHub returns "state" in the JSON; CommitStatus.Status is tagged accordingly.
|
||||
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error) {
|
||||
const perPage = 100
|
||||
var all []CommitStatus
|
||||
for page := 1; ; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/statuses?per_page=%d&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), perPage, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch commit statuses (page %d): %w", page, err)
|
||||
}
|
||||
var batch []CommitStatus
|
||||
if err := json.Unmarshal(body, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse statuses JSON (page %d): %w", page, err)
|
||||
}
|
||||
all = append(all, batch...)
|
||||
if len(batch) < perPage {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// --- File content methods ---
|
||||
|
||||
// GetFileContent fetches a file from the default branch of a repo.
|
||||
// GitHub returns base64-encoded content; this method decodes it.
|
||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
return c.getFileContentAtRef(ctx, owner, repo, filepath, "")
|
||||
}
|
||||
|
||||
// GetFileContentRef fetches a file from a specific ref (branch/tag/sha).
|
||||
func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
return c.getFileContentAtRef(ctx, owner, repo, filepath, ref)
|
||||
}
|
||||
|
||||
// getFileContentAtRef fetches a file at the given ref (empty = default branch).
|
||||
// GitHub's contents API returns base64-encoded file content.
|
||||
func (c *Client) getFileContentAtRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(filepath))
|
||||
if ref != "" {
|
||||
reqURL += "?ref=" + url.QueryEscape(ref)
|
||||
}
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s: %w", filepath, err)
|
||||
}
|
||||
var resp contentResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("parse file content JSON for %s: %w", filepath, err)
|
||||
}
|
||||
if resp.Type != "file" {
|
||||
return "", fmt.Errorf("path %s is a %s, not a file", filepath, resp.Type)
|
||||
}
|
||||
if resp.Encoding == "base64" {
|
||||
// GitHub embeds newlines in the base64 content for readability.
|
||||
// Strip them before decoding.
|
||||
cleaned := strings.ReplaceAll(resp.Content, "\n", "")
|
||||
decoded, err := base64.StdEncoding.DecodeString(cleaned)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64 content for %s: %w", filepath, err)
|
||||
}
|
||||
return string(decoded), nil
|
||||
}
|
||||
// Non-base64 encoding (shouldn't happen normally, but handle gracefully).
|
||||
return resp.Content, nil
|
||||
}
|
||||
|
||||
// ListContents lists files and directories at a given path.
|
||||
// Pass an empty path to list the repository root.
|
||||
// GitHub returns a single object (not array) when path is a file — this
|
||||
// method normalizes both cases to a slice, matching Gitea's behavior.
|
||||
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
|
||||
var reqURL string
|
||||
if path == "" || path == "." {
|
||||
reqURL = fmt.Sprintf("%s/repos/%s/%s/contents",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo))
|
||||
} else {
|
||||
reqURL = fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
||||
}
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
||||
}
|
||||
|
||||
var entries []ContentEntry
|
||||
if err := json.Unmarshal(body, &entries); err != nil {
|
||||
// GitHub returns a single object when path is a file (not an array).
|
||||
var single contentResponse
|
||||
if err2 := json.Unmarshal(body, &single); err2 != nil {
|
||||
return nil, fmt.Errorf("parse contents JSON: %w", err)
|
||||
}
|
||||
if single.Name == "" && single.Path == "" {
|
||||
return nil, fmt.Errorf("parse contents JSON: empty response for path %q", path)
|
||||
}
|
||||
entries = []ContentEntry{{
|
||||
Name: single.Name,
|
||||
Path: single.Path,
|
||||
Type: single.Type,
|
||||
}}
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// GetAllFilesInPath recursively fetches all file contents under a path.
|
||||
// If the path is a file, returns just that file's content.
|
||||
// If the path is a directory, recursively fetches all files within it.
|
||||
func (c *Client) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
results := make(map[string]string)
|
||||
|
||||
entries, err := c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
if !IsNotFound(err) {
|
||||
return nil, fmt.Errorf("list contents %q: %w", path, err)
|
||||
}
|
||||
// 404 means path may be a file — try fetching directly.
|
||||
content, fileErr := c.GetFileContent(ctx, owner, repo, path)
|
||||
if fileErr != nil {
|
||||
return nil, fmt.Errorf("path %q is neither a file nor directory: %w", path, fileErr)
|
||||
}
|
||||
results[path] = content
|
||||
return results, nil
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
switch entry.Type {
|
||||
case "file":
|
||||
content, err := c.GetFileContent(ctx, owner, repo, entry.Path)
|
||||
if err != nil {
|
||||
slog.Warn("could not fetch file from patterns repo", "file", entry.Path, "error", err)
|
||||
continue
|
||||
}
|
||||
results[entry.Path] = content
|
||||
case "dir":
|
||||
subResults, err := c.GetAllFilesInPath(ctx, owner, repo, entry.Path)
|
||||
if err != nil {
|
||||
slog.Warn("could not recurse into directory", "dir", entry.Path, "error", err)
|
||||
continue
|
||||
}
|
||||
for k, v := range subResults {
|
||||
results[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// --- Review methods ---
|
||||
|
||||
// PostReview submits a review to a PR.
|
||||
// event should be one of "APPROVE", "REQUEST_CHANGES", or "COMMENT".
|
||||
// commitID anchors the review to a specific commit SHA. If empty, defaults to current HEAD.
|
||||
// comments are optional inline comments; GitHub uses diff hunk position (not line numbers).
|
||||
// Note: unlike Gitea, GitHub does not support deleting submitted reviews.
|
||||
// Use COMMENT event to supersede old reviews.
|
||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []ReviewComment) (*Review, error) {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
|
||||
payload := struct {
|
||||
Body string `json:"body"`
|
||||
Event string `json:"event"`
|
||||
CommitID string `json:"commit_id,omitempty"`
|
||||
Comments []ReviewComment `json:"comments,omitempty"`
|
||||
}{
|
||||
Body: body,
|
||||
Event: event,
|
||||
CommitID: commitID,
|
||||
Comments: comments,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal review payload: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("post review: %w", err)
|
||||
}
|
||||
|
||||
var review Review
|
||||
if err := json.Unmarshal(respBody, &review); err != nil {
|
||||
return nil, fmt.Errorf("parse review response: %w", err)
|
||||
}
|
||||
return &review, nil
|
||||
}
|
||||
|
||||
// ListReviews returns all reviews on a pull request.
|
||||
// GitHub paginates via Link header; this method uses per_page=100.
|
||||
func (c *Client) ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error) {
|
||||
const perPage = 100
|
||||
var all []Review
|
||||
for page := 1; ; page++ {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews?per_page=%d&page=%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, perPage, page)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list reviews (page %d): %w", page, err)
|
||||
}
|
||||
var batch []Review
|
||||
if err := json.Unmarshal(body, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse reviews (page %d): %w", page, err)
|
||||
}
|
||||
all = append(all, batch...)
|
||||
if len(batch) < perPage {
|
||||
break
|
||||
}
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// DeleteReview attempts to delete a pull request review.
|
||||
// GitHub only allows deleting PENDING (draft) reviews. Submitted reviews cannot
|
||||
// be deleted via the API; this method returns a descriptive error in that case.
|
||||
// review-bot callers should handle this error gracefully (e.g., by not attempting
|
||||
// supersede and instead posting a new review alongside the old one).
|
||||
func (c *Client) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews/%d",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, reviewID)
|
||||
|
||||
// nil body: the GitHub DELETE endpoint for reviews requires no request body.
|
||||
_, err := c.doRequestWithBody(ctx, http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete review: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthenticatedUser returns the login of the authenticated user.
|
||||
func (c *Client) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
||||
reqURL := c.baseURL + "/user"
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get authenticated user: %w", err)
|
||||
}
|
||||
var result struct {
|
||||
Login string `json:"login"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", fmt.Errorf("parse user response: %w", err)
|
||||
}
|
||||
return result.Login, nil
|
||||
}
|
||||
|
||||
// RequestReviewer adds a user as a requested reviewer on a pull request.
|
||||
// This is idempotent — requesting an already-requested reviewer is a no-op.
|
||||
func (c *Client) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/requested_reviewers",
|
||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||
|
||||
payload := struct {
|
||||
Reviewers []string `json:"reviewers"`
|
||||
}{Reviewers: []string{reviewer}}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal reviewer request: %w", err)
|
||||
}
|
||||
|
||||
_, err = c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request reviewer: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||
// Slashes are preserved as path separators; other special characters are escaped.
|
||||
func escapePath(p string) string {
|
||||
parts := strings.Split(p, "/")
|
||||
for i, part := range parts {
|
||||
parts[i] = url.PathEscape(part)
|
||||
}
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,13 @@
|
||||
package github
|
||||
|
||||
// AllowInsecureHTTPForTest permits sending credentials over plaintext HTTP
|
||||
// without requiring the REVIEW_BOT_ALLOW_INSECURE environment variable.
|
||||
// This is intended exclusively for test code using httptest.Server.
|
||||
//
|
||||
// Defined in a _test.go file so it is only available to test binaries.
|
||||
func AllowInsecureHTTPForTest() ClientOption {
|
||||
return func(cfg *clientConfig) {
|
||||
cfg.allowInsecureHTTP = true
|
||||
cfg.insecureIsTestBypass = true
|
||||
}
|
||||
}
|
||||
+9
-8
@@ -16,16 +16,17 @@ import (
|
||||
|
||||
// Integration test requires a running Gitea instance and LLM endpoint.
|
||||
// Set environment variables:
|
||||
// INTEGRATION_GITEA_URL - Gitea base URL
|
||||
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
||||
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
||||
// INTEGRATION_PR_NUMBER - PR number to test against
|
||||
// INTEGRATION_LLM_BASE_URL - LLM API base URL
|
||||
// INTEGRATION_LLM_API_KEY - LLM API key
|
||||
// INTEGRATION_LLM_MODEL - Model name
|
||||
//
|
||||
// INTEGRATION_VCS_URL - VCS base URL
|
||||
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
||||
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
||||
// INTEGRATION_PR_NUMBER - PR number to test against
|
||||
// INTEGRATION_LLM_BASE_URL - LLM API base URL
|
||||
// INTEGRATION_LLM_API_KEY - LLM API key
|
||||
// INTEGRATION_LLM_MODEL - Model name
|
||||
|
||||
func TestIntegration_FullReviewFlow(t *testing.T) {
|
||||
giteaURL := os.Getenv("INTEGRATION_GITEA_URL")
|
||||
giteaURL := os.Getenv("INTEGRATION_VCS_URL")
|
||||
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
||||
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
||||
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
||||
|
||||
Reference in New Issue
Block a user