Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6316007eb1 | |||
| b380e7fcae | |||
| 30798ff023 | |||
| 6e8e744816 | |||
| 1194bc758c | |||
| 80af5037b2 | |||
| 5b2fa0b9af | |||
| 491df7cb1f | |||
| 1fcc0b738a | |||
| fce5f2d184 | |||
| af72c64b7f | |||
| 1bc3f206ba | |||
| c10bb72117 | |||
| ae91c8aef5 | |||
| 75f65fbf5d | |||
| 5b43afc6d4 | |||
| d1ef1e21e5 | |||
| 8e4c1cc32e | |||
| ec03dc2373 | |||
| 1749d95727 | |||
| 7c83365fc4 | |||
| 6be5e306aa | |||
| cd6cd93bf0 | |||
| c889724dda | |||
| 1ac51669ed | |||
| 2e6f46f28d | |||
| 3fc31c0822 | |||
| 2b611dbd0b | |||
| 3abb611baf | |||
| dd003c66d5 |
@@ -1,43 +1,17 @@
|
|||||||
# This composite action supports both Gitea Actions and GitHub Actions runners.
|
# This composite action is designed for Gitea Actions runners.
|
||||||
# It detects the VCS host type by checking whether github.api_url is set
|
# Gitea Actions supports GitHub Actions syntax including $GITHUB_OUTPUT,
|
||||||
# (present on GitHub.com and GHES runners, absent on Gitea runners) and uses
|
# actions/cache, and actions/checkout.
|
||||||
# the appropriate releases API for version resolution and binary download
|
|
||||||
# (REST API on GitHub, direct URLs on Gitea).
|
|
||||||
#
|
|
||||||
# Security notes:
|
|
||||||
# - On GitHub/GHES (VCS_TYPE=github), inputs.vcs-url is IGNORED to prevent
|
|
||||||
# token exfiltration. API calls use github.api_url; downloads use
|
|
||||||
# github.server_url. Tokens are never sent to user-supplied URLs.
|
|
||||||
# - On Gitea (VCS_TYPE=gitea), inputs.vcs-url is validated (https scheme,
|
|
||||||
# no whitespace/newlines, and DNS resolution to a public IP) before use.
|
|
||||||
# Python3 resolves the hostname and rejects RFC1918, RFC6598 (carrier-grade
|
|
||||||
# NAT), loopback, link-local, and other reserved addresses to prevent SSRF attacks.
|
|
||||||
# The installed review-bot binary additionally uses a safe HTTP transport
|
|
||||||
# (DialContext-level IP check) for all Gitea API calls at runtime.
|
|
||||||
# The binary also exposes a `validate-url` subcommand for use in any future
|
|
||||||
# shell steps that need to validate a URL before passing it to curl.
|
|
||||||
# - action-repo is validated against owner/repo pattern.
|
|
||||||
# - Tokens are passed via masked environment variables, not step outputs.
|
|
||||||
#
|
|
||||||
# Requirements: python3, sha256sum, curl (all present on ubuntu-* runners).
|
# Requirements: python3, sha256sum, curl (all present on ubuntu-* runners).
|
||||||
name: 'AI Code Review'
|
name: 'AI Code Review'
|
||||||
description: 'Run AI-powered code review on a pull request using review-bot'
|
description: 'Run AI-powered code review on a pull request using review-bot'
|
||||||
|
|
||||||
inputs:
|
inputs:
|
||||||
vcs-url:
|
gitea-url:
|
||||||
description: 'VCS server URL (only used on Gitea runners; ignored on GitHub/GHES). Defaults to server_url.'
|
description: 'Gitea instance URL (defaults to server_url)'
|
||||||
required: false
|
required: false
|
||||||
default: ''
|
default: ''
|
||||||
repo:
|
repo:
|
||||||
description: 'Repository to review (owner/name, defaults to current)'
|
description: 'Repository (owner/name, defaults to current)'
|
||||||
required: false
|
|
||||||
default: ''
|
|
||||||
action-repo:
|
|
||||||
description: 'Repository hosting review-bot releases (owner/name). Defaults to github.action_repository or rodin/review-bot.'
|
|
||||||
required: false
|
|
||||||
default: ''
|
|
||||||
action-repo-token:
|
|
||||||
description: 'Token for downloading release assets from action-repo (defaults to github.token on GitHub, reviewer-token on Gitea). Required for private repos.'
|
|
||||||
required: false
|
required: false
|
||||||
default: ''
|
default: ''
|
||||||
pr-number:
|
pr-number:
|
||||||
@@ -45,7 +19,7 @@ inputs:
|
|||||||
required: false
|
required: false
|
||||||
default: ''
|
default: ''
|
||||||
reviewer-token:
|
reviewer-token:
|
||||||
description: 'Token for posting the review'
|
description: 'Gitea token for posting the review'
|
||||||
required: true
|
required: true
|
||||||
reviewer-name:
|
reviewer-name:
|
||||||
description: 'Display name for the reviewer'
|
description: 'Display name for the reviewer'
|
||||||
@@ -130,17 +104,6 @@ inputs:
|
|||||||
description: 'Path to custom persona JSON file'
|
description: 'Path to custom persona JSON file'
|
||||||
required: false
|
required: false
|
||||||
default: ''
|
default: ''
|
||||||
doc-map:
|
|
||||||
description: >-
|
|
||||||
Path to a YAML file mapping source path globs to governing design docs.
|
|
||||||
review-bot intersects the map with changed PR paths and injects matching
|
|
||||||
docs as context alongside the diff.
|
|
||||||
required: false
|
|
||||||
default: ''
|
|
||||||
doc-map-max-bytes:
|
|
||||||
description: 'Maximum bytes of injected doc content from doc-map (default 102400 = 100KB)'
|
|
||||||
required: false
|
|
||||||
default: '102400'
|
|
||||||
|
|
||||||
runs:
|
runs:
|
||||||
using: 'composite'
|
using: 'composite'
|
||||||
@@ -149,325 +112,45 @@ runs:
|
|||||||
id: version
|
id: version
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
set -euo pipefail
|
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
|
||||||
|
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||||
# --- Input Validation ---
|
|
||||||
|
|
||||||
# Determine the repo hosting review-bot releases (not the repo being reviewed)
|
|
||||||
ACTION_REPO="${{ inputs.action-repo }}"
|
|
||||||
if [ -z "$ACTION_REPO" ]; then
|
|
||||||
# github.action_repository is the repo containing the running action
|
|
||||||
ACTION_REPO="${{ github.action_repository }}"
|
|
||||||
fi
|
|
||||||
if [ -z "$ACTION_REPO" ]; then
|
|
||||||
# Final fallback for Gitea (which may not set action_repository)
|
|
||||||
ACTION_REPO="rodin/review-bot"
|
|
||||||
echo "::notice::action-repo not specified and github.action_repository is empty; falling back to rodin/review-bot"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Validate ACTION_REPO matches owner/repo pattern (prevent path traversal)
|
|
||||||
if ! printf '%s' "$ACTION_REPO" | grep -qE '^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$'; then
|
|
||||||
echo "Error: action-repo '${ACTION_REPO}' does not match expected owner/repo format" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Detect VCS host type using github.api_url context.
|
|
||||||
# github.api_url is set on GitHub.com (https://api.github.com) and GHES
|
|
||||||
# (https://<host>/api/v3). It is empty/unset on Gitea Actions runners.
|
|
||||||
GITHUB_API_URL="${{ github.api_url }}"
|
|
||||||
if [ -n "$GITHUB_API_URL" ]; then
|
|
||||||
VCS_TYPE="github"
|
|
||||||
else
|
|
||||||
VCS_TYPE="gitea"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Determine SERVER_URL based on VCS type.
|
|
||||||
# SECURITY: On GitHub/GHES, ALWAYS use github.server_url — never trust
|
|
||||||
# inputs.vcs-url to prevent token exfiltration to attacker-controlled hosts.
|
|
||||||
if [ "$VCS_TYPE" = "github" ]; then
|
|
||||||
SERVER_URL="${{ github.server_url }}"
|
|
||||||
if [ -n "${{ inputs.vcs-url }}" ]; then
|
|
||||||
echo "::warning::inputs.vcs-url is ignored on GitHub/GHES runners (VCS_TYPE=github). Using github.server_url instead."
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
SERVER_URL="${{ inputs.vcs-url || github.server_url }}"
|
|
||||||
fi
|
|
||||||
# Strip trailing slash if present
|
|
||||||
SERVER_URL="${SERVER_URL%/}"
|
|
||||||
|
|
||||||
# Validate SERVER_URL for Gitea path: must be https, no whitespace/newlines.
|
|
||||||
# The [^[:space:]] class already rejects newlines, so no separate newline check needed.
|
|
||||||
if [ "$VCS_TYPE" = "gitea" ]; then
|
|
||||||
if ! printf '%s' "$SERVER_URL" | grep -qE '^https://[^[:space:]]+$'; then
|
|
||||||
echo "Error: SERVER_URL '${SERVER_URL}' must be an https:// URL with no whitespace" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Additional IP-level SSRF defense: resolve the hostname and reject
|
|
||||||
# requests to RFC1918, RFC6598 (carrier-grade NAT), loopback, link-local,
|
|
||||||
# and other reserved addresses.
|
|
||||||
# python3 is required on ubuntu-* runners (see requirements comment above).
|
|
||||||
# Use printf to write the script to a temp file so the python lines are valid
|
|
||||||
# YAML (each indented line becomes a printf argument — no unindented code).
|
|
||||||
# SERVER_URL is passed via CHECK_URL env var, never interpolated into python code.
|
|
||||||
printf '%s\n' \
|
|
||||||
'import socket,ipaddress,sys,os' \
|
|
||||||
'from urllib.parse import urlparse' \
|
|
||||||
'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \
|
|
||||||
'if parsed.username or parsed.password:' \
|
|
||||||
' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \
|
|
||||||
'h=parsed.hostname' \
|
|
||||||
'(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \
|
|
||||||
'try: rs=socket.getaddrinfo(h,None)' \
|
|
||||||
'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \
|
|
||||||
'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \
|
|
||||||
'for _,_,_,_,(a,*_) in rs:' \
|
|
||||||
' ip=ipaddress.ip_address(a)' \
|
|
||||||
' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \
|
|
||||||
' cgn=ipaddress.ip_network("100.64.0.0/10")' \
|
|
||||||
' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \
|
|
||||||
' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \
|
|
||||||
> /tmp/_ssrf_check.py
|
|
||||||
CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check.py || {
|
|
||||||
echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Determine auth token for release API requests
|
|
||||||
ACTION_TOKEN="${{ inputs.action-repo-token }}"
|
|
||||||
if [ -z "$ACTION_TOKEN" ]; then
|
|
||||||
if [ "$VCS_TYPE" = "github" ]; then
|
|
||||||
ACTION_TOKEN="${{ github.token }}"
|
|
||||||
else
|
|
||||||
ACTION_TOKEN="${{ inputs.reviewer-token }}"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Validate token contains no control characters (defense-in-depth against header injection)
|
|
||||||
if [ -n "$ACTION_TOKEN" ]; then
|
|
||||||
if printf '%s' "$ACTION_TOKEN" | LC_ALL=C grep -q '[^[:print:]]'; then
|
|
||||||
echo "Error: ACTION_TOKEN contains control characters" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ "${{ inputs.version }}" = "latest" ]; then
|
if [ "${{ inputs.version }}" = "latest" ]; then
|
||||||
if [ "$VCS_TYPE" = "github" ]; then
|
VERSION=$(curl -sSf "${GITEA_URL}/api/v1/repos/${REPO}/releases?limit=1" \
|
||||||
# SECURITY: Use github.api_url which is a trusted platform-provided value.
|
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||||
# Never construct API URLs from user-supplied inputs on GitHub.
|
|
||||||
API_URL="${GITHUB_API_URL}/repos/${ACTION_REPO}/releases?per_page=1"
|
|
||||||
else
|
|
||||||
# Gitea API — SERVER_URL was validated above
|
|
||||||
API_URL="${SERVER_URL}/api/v1/repos/${ACTION_REPO}/releases?limit=1"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Fetch latest version with inline auth header (no intermediate variable)
|
|
||||||
if [ -n "$ACTION_TOKEN" ]; then
|
|
||||||
if [ "$VCS_TYPE" = "github" ]; then
|
|
||||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
|
||||||
-H "Authorization: Bearer ${ACTION_TOKEN}" "$API_URL" \
|
|
||||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
|
||||||
else
|
|
||||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
|
||||||
-H "Authorization: token ${ACTION_TOKEN}" "$API_URL" \
|
|
||||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
VERSION=$(curl -sSf --connect-timeout 10 --max-time 30 "$API_URL" \
|
|
||||||
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -z "$VERSION" ]; then
|
if [ -z "$VERSION" ]; then
|
||||||
echo "Failed to determine latest version from ${API_URL}" >&2
|
echo "Failed to determine latest version" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
VERSION="${{ inputs.version }}"
|
VERSION="${{ inputs.version }}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Validate VERSION: no slashes or whitespace (prevent path traversal).
|
|
||||||
# [:space:] includes newlines and carriage returns in POSIX.
|
|
||||||
if printf '%s' "$VERSION" | grep -qE '[/[:space:]]'; then
|
|
||||||
echo "Error: VERSION '${VERSION}' contains invalid characters (newline, slash, or whitespace)" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Detect OS and architecture for platform-specific binary download
|
|
||||||
OS_RAW=$(uname -s | tr '[:upper:]' '[:lower:]')
|
|
||||||
case "$OS_RAW" in
|
|
||||||
linux) OS="linux" ;;
|
|
||||||
darwin) OS="darwin" ;;
|
|
||||||
*)
|
|
||||||
echo "Error: unsupported OS: $(uname -s)" >&2
|
|
||||||
exit 1
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
RAW_ARCH=$(uname -m)
|
|
||||||
case "$RAW_ARCH" in
|
|
||||||
x86_64) ARCH="amd64" ;;
|
|
||||||
aarch64 | arm64) ARCH="arm64" ;;
|
|
||||||
*)
|
|
||||||
echo "Error: unsupported architecture: $RAW_ARCH" >&2
|
|
||||||
exit 1
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
echo "os=${OS}" >> "$GITHUB_OUTPUT"
|
|
||||||
echo "arch=${ARCH}" >> "$GITHUB_OUTPUT"
|
|
||||||
echo "action_repo=${ACTION_REPO}" >> "$GITHUB_OUTPUT"
|
|
||||||
echo "server_url=${SERVER_URL}" >> "$GITHUB_OUTPUT"
|
|
||||||
echo "vcs_type=${VCS_TYPE}" >> "$GITHUB_OUTPUT"
|
|
||||||
|
|
||||||
# SECURITY: Pass token via masked environment variable instead of step output.
|
|
||||||
# Step outputs can leak in debug logs; GITHUB_ENV with masking is safer.
|
|
||||||
if [ -n "$ACTION_TOKEN" ]; then
|
|
||||||
echo "::add-mask::${ACTION_TOKEN}"
|
|
||||||
echo "ACTION_TOKEN=${ACTION_TOKEN}" >> "$GITHUB_ENV"
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Cache review-bot binary
|
- name: Cache review-bot binary
|
||||||
id: cache
|
id: cache
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ${{ runner.temp }}/review-bot
|
path: ${{ runner.temp }}/review-bot
|
||||||
key: review-bot-${{ steps.version.outputs.os }}-${{ steps.version.outputs.arch }}-${{ steps.version.outputs.version }}
|
key: review-bot-linux-amd64-${{ steps.version.outputs.version }}
|
||||||
|
|
||||||
- name: Install review-bot
|
- name: Install review-bot
|
||||||
if: steps.cache.outputs.cache-hit != 'true'
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
set -euo pipefail
|
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
|
||||||
|
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||||
SERVER_URL="${{ steps.version.outputs.server_url }}"
|
|
||||||
ACTION_REPO="${{ steps.version.outputs.action_repo }}"
|
|
||||||
VERSION="${{ steps.version.outputs.version }}"
|
VERSION="${{ steps.version.outputs.version }}"
|
||||||
VCS_TYPE="${{ steps.version.outputs.vcs_type }}"
|
BINARY="review-bot-linux-amd64"
|
||||||
OS="${{ steps.version.outputs.os }}"
|
|
||||||
ARCH="${{ steps.version.outputs.arch }}"
|
|
||||||
# Read token from masked environment variable (set in Determine version step)
|
|
||||||
# Falls back to empty if not set (public repos don't need auth)
|
|
||||||
ACTION_TOKEN="${ACTION_TOKEN:-}"
|
|
||||||
BINARY="review-bot-${OS}-${ARCH}"
|
|
||||||
|
|
||||||
# SECURITY: Re-validate SERVER_URL at the start of this step to mitigate DNS
|
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/${BINARY}" \
|
||||||
# rebinding attacks. A DNS TTL expiry between "Determine version" and here
|
-o "${{ runner.temp }}/review-bot"
|
||||||
# could allow an attacker to change the resolved IP to a private/reserved
|
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/checksums.txt" \
|
||||||
# address, causing curl to send ACTION_TOKEN to an internal host.
|
-o "${{ runner.temp }}/checksums.txt"
|
||||||
# Only needed on Gitea path (VCS_TYPE=gitea); GitHub/GHES uses platform-controlled URLs.
|
|
||||||
if [ "$VCS_TYPE" = "gitea" ]; then
|
|
||||||
printf '%s\n' \
|
|
||||||
'import socket,ipaddress,sys,os' \
|
|
||||||
'from urllib.parse import urlparse' \
|
|
||||||
'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \
|
|
||||||
'if parsed.username or parsed.password:' \
|
|
||||||
' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \
|
|
||||||
'h=parsed.hostname' \
|
|
||||||
'(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \
|
|
||||||
'try: rs=socket.getaddrinfo(h,None)' \
|
|
||||||
'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \
|
|
||||||
'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \
|
|
||||||
'for _,_,_,_,(a,*_) in rs:' \
|
|
||||||
' ip=ipaddress.ip_address(a)' \
|
|
||||||
' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \
|
|
||||||
' cgn=ipaddress.ip_network("100.64.0.0/10")' \
|
|
||||||
' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \
|
|
||||||
' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \
|
|
||||||
> /tmp/_ssrf_check_install.py
|
|
||||||
CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check_install.py || {
|
|
||||||
echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ "$VCS_TYPE" = "github" ]; then
|
|
||||||
# GitHub/GHES: Use REST API for release asset downloads.
|
|
||||||
# Web release URLs ({server}/.../releases/download/{tag}/{asset}) redirect
|
|
||||||
# to S3 and don't reliably support Authorization headers for private repos.
|
|
||||||
# The REST API endpoint with Accept: application/octet-stream is required.
|
|
||||||
# GITHUB_API_URL: trusted platform value, same as detected in "Determine version" step.
|
|
||||||
GITHUB_API_URL="${{ github.api_url }}"
|
|
||||||
|
|
||||||
if [ -n "$ACTION_TOKEN" ]; then
|
|
||||||
RELEASE_JSON=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
|
||||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
|
||||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/tags/${VERSION}")
|
|
||||||
else
|
|
||||||
RELEASE_JSON=$(curl -sSf --connect-timeout 10 --max-time 30 \
|
|
||||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/tags/${VERSION}")
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Extract asset IDs for binary and checksums
|
|
||||||
BINARY_ASSET_ID=$(printf '%s' "$RELEASE_JSON" | python3 -c "import sys, json; assets = json.load(sys.stdin).get('assets', []); matches = [a['id'] for a in assets if a['name'] == '${BINARY}']; print(matches[0] if matches else '')")
|
|
||||||
if [ -z "$BINARY_ASSET_ID" ]; then
|
|
||||||
echo "Error: could not find asset '${BINARY}' in release ${VERSION}" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
CHECKSUMS_ASSET_ID=$(printf '%s' "$RELEASE_JSON" | python3 -c "import sys, json; assets = json.load(sys.stdin).get('assets', []); matches = [a['id'] for a in assets if a['name'] == 'checksums.txt']; print(matches[0] if matches else '')")
|
|
||||||
if [ -z "$CHECKSUMS_ASSET_ID" ]; then
|
|
||||||
echo "Error: could not find asset 'checksums.txt' in release ${VERSION}" >&2
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Download assets via REST API with Accept: application/octet-stream
|
|
||||||
if [ -n "$ACTION_TOKEN" ]; then
|
|
||||||
curl -sSfL --connect-timeout 10 --max-time 120 \
|
|
||||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
|
||||||
-H "Accept: application/octet-stream" \
|
|
||||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${BINARY_ASSET_ID}" \
|
|
||||||
-o "${{ runner.temp }}/review-bot"
|
|
||||||
curl -sSfL --connect-timeout 10 --max-time 30 \
|
|
||||||
-H "Authorization: Bearer ${ACTION_TOKEN}" \
|
|
||||||
-H "Accept: application/octet-stream" \
|
|
||||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${CHECKSUMS_ASSET_ID}" \
|
|
||||||
-o "${{ runner.temp }}/checksums.txt"
|
|
||||||
else
|
|
||||||
curl -sSfL --connect-timeout 10 --max-time 120 \
|
|
||||||
-H "Accept: application/octet-stream" \
|
|
||||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${BINARY_ASSET_ID}" \
|
|
||||||
-o "${{ runner.temp }}/review-bot"
|
|
||||||
curl -sSfL --connect-timeout 10 --max-time 30 \
|
|
||||||
-H "Accept: application/octet-stream" \
|
|
||||||
"${GITHUB_API_URL}/repos/${ACTION_REPO}/releases/assets/${CHECKSUMS_ASSET_ID}" \
|
|
||||||
-o "${{ runner.temp }}/checksums.txt"
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
# Gitea: Direct download via web release URLs (Gitea serves assets
|
|
||||||
# directly without redirects — no -L needed).
|
|
||||||
# SECURITY: Omitting -L prevents forwarding Authorization header to
|
|
||||||
# unexpected hosts if Gitea ever introduces CDN redirects.
|
|
||||||
DOWNLOAD_URL="${SERVER_URL}/${ACTION_REPO}/releases/download/${VERSION}"
|
|
||||||
|
|
||||||
if [ -n "$ACTION_TOKEN" ]; then
|
|
||||||
curl -sSf --connect-timeout 10 --max-time 120 \
|
|
||||||
-H "Authorization: token ${ACTION_TOKEN}" \
|
|
||||||
"${DOWNLOAD_URL}/${BINARY}" -o "${{ runner.temp }}/review-bot"
|
|
||||||
curl -sSf --connect-timeout 10 --max-time 30 \
|
|
||||||
-H "Authorization: token ${ACTION_TOKEN}" \
|
|
||||||
"${DOWNLOAD_URL}/checksums.txt" -o "${{ runner.temp }}/checksums.txt"
|
|
||||||
else
|
|
||||||
curl -sSf --connect-timeout 10 --max-time 120 \
|
|
||||||
"${DOWNLOAD_URL}/${BINARY}" -o "${{ runner.temp }}/review-bot"
|
|
||||||
curl -sSf --connect-timeout 10 --max-time 30 \
|
|
||||||
"${DOWNLOAD_URL}/checksums.txt" -o "${{ runner.temp }}/checksums.txt"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Verify SHA-256 checksum
|
# Verify SHA-256 checksum
|
||||||
# NOTE: This verifies integrity (download wasn't corrupted) but not
|
|
||||||
# authenticity — both binary and checksums come from the same server.
|
|
||||||
# For stronger guarantees, consider GPG signature verification.
|
|
||||||
cd "${{ runner.temp }}"
|
cd "${{ runner.temp }}"
|
||||||
EXPECTED=$(grep -E "^[0-9a-f]+[[:space:]]+\*?${BINARY}$" checksums.txt | awk '{print $1}')
|
EXPECTED=$(grep "${BINARY}" checksums.txt | awk '{print $1}')
|
||||||
# sha256sum (GNU) is not available on macOS; use shasum -a 256 on darwin.
|
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
|
||||||
if [ "${OS}" = "darwin" ]; then
|
|
||||||
ACTUAL=$(shasum -a 256 review-bot | awk '{print $1}')
|
|
||||||
else
|
|
||||||
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -z "$EXPECTED" ]; then
|
if [ -z "$EXPECTED" ]; then
|
||||||
echo "Error: no checksum found for ${BINARY}" >&2
|
echo "Error: no checksum found for ${BINARY}" >&2
|
||||||
@@ -481,12 +164,12 @@ runs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
chmod +x "${{ runner.temp }}/review-bot"
|
chmod +x "${{ runner.temp }}/review-bot"
|
||||||
echo "Installed review-bot-${OS}-${ARCH} ${VERSION} (checksum verified)"
|
echo "Installed review-bot ${VERSION} (checksum verified)"
|
||||||
|
|
||||||
- name: Run review
|
- name: Run review
|
||||||
shell: bash
|
shell: bash
|
||||||
env:
|
env:
|
||||||
VCS_URL: ${{ steps.version.outputs.server_url }}
|
GITEA_URL: ${{ inputs.gitea-url || github.server_url }}
|
||||||
GITEA_REPO: ${{ inputs.repo || github.repository }}
|
GITEA_REPO: ${{ inputs.repo || github.repository }}
|
||||||
PR_NUMBER: ${{ inputs.pr-number || github.event.pull_request.number }}
|
PR_NUMBER: ${{ inputs.pr-number || github.event.pull_request.number }}
|
||||||
REVIEWER_TOKEN: ${{ inputs.reviewer-token }}
|
REVIEWER_TOKEN: ${{ inputs.reviewer-token }}
|
||||||
@@ -504,8 +187,6 @@ runs:
|
|||||||
SYSTEM_PROMPT_FILE: ${{ inputs.system-prompt-file }}
|
SYSTEM_PROMPT_FILE: ${{ inputs.system-prompt-file }}
|
||||||
PERSONA: ${{ inputs.persona }}
|
PERSONA: ${{ inputs.persona }}
|
||||||
PERSONA_FILE: ${{ inputs.persona-file }}
|
PERSONA_FILE: ${{ inputs.persona-file }}
|
||||||
DOC_MAP_FILE: ${{ inputs.doc-map }}
|
|
||||||
DOC_MAP_MAX_BYTES: ${{ inputs.doc-map-max-bytes }}
|
|
||||||
AICORE_CLIENT_ID: ${{ inputs.aicore-client-id }}
|
AICORE_CLIENT_ID: ${{ inputs.aicore-client-id }}
|
||||||
AICORE_CLIENT_SECRET: ${{ inputs.aicore-client-secret }}
|
AICORE_CLIENT_SECRET: ${{ inputs.aicore-client-secret }}
|
||||||
AICORE_AUTH_URL: ${{ inputs.aicore-auth-url }}
|
AICORE_AUTH_URL: ${{ inputs.aicore-auth-url }}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ jobs:
|
|||||||
- run: go build -o review-bot ./cmd/review-bot
|
- run: go build -o review-bot ./cmd/review-bot
|
||||||
- name: Run ${{ matrix.name }} review
|
- name: Run ${{ matrix.name }} review
|
||||||
env:
|
env:
|
||||||
VCS_URL: ${{ github.server_url }}
|
GITEA_URL: ${{ github.server_url }}
|
||||||
GITEA_REPO: ${{ github.repository }}
|
GITEA_REPO: ${{ github.repository }}
|
||||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
REVIEWER_TOKEN: ${{ secrets[matrix.token_secret] }}
|
REVIEWER_TOKEN: ${{ secrets[matrix.token_secret] }}
|
||||||
|
|||||||
@@ -0,0 +1,200 @@
|
|||||||
|
# This composite action is designed for Gitea Actions runners.
|
||||||
|
# Gitea Actions supports GitHub Actions syntax including $GITHUB_OUTPUT,
|
||||||
|
# actions/cache, and actions/checkout.
|
||||||
|
# Requirements: python3, sha256sum, curl (all present on ubuntu-* runners).
|
||||||
|
name: 'AI Code Review'
|
||||||
|
description: 'Run AI-powered code review on a pull request using review-bot'
|
||||||
|
|
||||||
|
inputs:
|
||||||
|
gitea-url:
|
||||||
|
description: 'Gitea instance URL (defaults to server_url)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
repo:
|
||||||
|
description: 'Repository (owner/name, defaults to current)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
pr-number:
|
||||||
|
description: 'Pull request number (defaults to current PR)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
reviewer-token:
|
||||||
|
description: 'Gitea token for posting the review'
|
||||||
|
required: true
|
||||||
|
reviewer-name:
|
||||||
|
description: 'Display name for the reviewer'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
llm-base-url:
|
||||||
|
description: 'OpenAI-compatible LLM API base URL (not required for aicore provider)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
llm-api-key:
|
||||||
|
description: 'LLM API key (not required for aicore provider)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
llm-model:
|
||||||
|
description: 'LLM model name'
|
||||||
|
required: true
|
||||||
|
llm-provider:
|
||||||
|
description: 'LLM API provider: openai, anthropic, or aicore (default openai)'
|
||||||
|
required: false
|
||||||
|
default: 'openai'
|
||||||
|
aicore-client-id:
|
||||||
|
description: 'SAP AI Core client ID (required for aicore provider)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
aicore-client-secret:
|
||||||
|
description: 'SAP AI Core client secret (required for aicore provider)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
aicore-auth-url:
|
||||||
|
description: 'SAP AI Core authentication URL (required for aicore provider)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
aicore-api-url:
|
||||||
|
description: 'SAP AI Core API URL (required for aicore provider)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
aicore-resource-group:
|
||||||
|
description: 'SAP AI Core resource group (default: default)'
|
||||||
|
required: false
|
||||||
|
default: 'default'
|
||||||
|
conventions-file:
|
||||||
|
description: 'Path to conventions file in the repo (e.g. CLAUDE.md)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
patterns-repo:
|
||||||
|
description: 'Comma-separated repos with language patterns (e.g. rodin/elixir-patterns,rodin/phoenix-conventions)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
patterns-files:
|
||||||
|
description: 'Comma-separated file paths or directories to fetch from patterns repos'
|
||||||
|
required: false
|
||||||
|
default: 'README.md'
|
||||||
|
temperature:
|
||||||
|
description: 'LLM temperature (0 = server default)'
|
||||||
|
required: false
|
||||||
|
default: '0'
|
||||||
|
timeout:
|
||||||
|
description: 'LLM request timeout in seconds (default 300)'
|
||||||
|
required: false
|
||||||
|
default: '300'
|
||||||
|
version:
|
||||||
|
description: 'review-bot version to install (e.g. v0.1.0, defaults to latest)'
|
||||||
|
required: false
|
||||||
|
default: 'latest'
|
||||||
|
dry-run:
|
||||||
|
description: 'Print review to stdout instead of posting'
|
||||||
|
required: false
|
||||||
|
default: 'false'
|
||||||
|
update-existing:
|
||||||
|
description: 'Delete previous review from same bot after posting new one. Accepts: true/1/yes or false/0/no (default true)'
|
||||||
|
required: false
|
||||||
|
default: 'true'
|
||||||
|
system-prompt-file:
|
||||||
|
description: 'Local file with additional system prompt instructions (e.g. security review focus)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
persona:
|
||||||
|
description: 'Built-in persona name (security, architect, docs)'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
persona-file:
|
||||||
|
description: 'Path to custom persona JSON file'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: 'composite'
|
||||||
|
steps:
|
||||||
|
- name: Determine version
|
||||||
|
id: version
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
|
||||||
|
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||||
|
if [ "${{ inputs.version }}" = "latest" ]; then
|
||||||
|
VERSION=$(curl -sSf "${GITEA_URL}/api/v1/repos/${REPO}/releases?limit=1" \
|
||||||
|
| python3 -c "import sys, json; releases = json.load(sys.stdin); print(releases[0]['tag_name'] if releases else '')")
|
||||||
|
if [ -z "$VERSION" ]; then
|
||||||
|
echo "Failed to determine latest version" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
VERSION="${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
echo "version=${VERSION}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Cache review-bot binary
|
||||||
|
id: cache
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ${{ runner.temp }}/review-bot
|
||||||
|
key: review-bot-linux-amd64-${{ steps.version.outputs.version }}
|
||||||
|
|
||||||
|
- name: Install review-bot
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
GITEA_URL="${{ inputs.gitea-url || github.server_url }}"
|
||||||
|
REPO="${{ inputs.repo || 'rodin/review-bot' }}"
|
||||||
|
VERSION="${{ steps.version.outputs.version }}"
|
||||||
|
BINARY="review-bot-linux-amd64"
|
||||||
|
|
||||||
|
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/${BINARY}" \
|
||||||
|
-o "${{ runner.temp }}/review-bot"
|
||||||
|
curl -sSfL "${GITEA_URL}/${REPO}/releases/download/${VERSION}/checksums.txt" \
|
||||||
|
-o "${{ runner.temp }}/checksums.txt"
|
||||||
|
|
||||||
|
# Verify SHA-256 checksum
|
||||||
|
cd "${{ runner.temp }}"
|
||||||
|
EXPECTED=$(grep "${BINARY}" checksums.txt | awk '{print $1}')
|
||||||
|
ACTUAL=$(sha256sum review-bot | awk '{print $1}')
|
||||||
|
|
||||||
|
if [ -z "$EXPECTED" ]; then
|
||||||
|
echo "Error: no checksum found for ${BINARY}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [ "$EXPECTED" != "$ACTUAL" ]; then
|
||||||
|
echo "Error: checksum mismatch!" >&2
|
||||||
|
echo " Expected: $EXPECTED" >&2
|
||||||
|
echo " Actual: $ACTUAL" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
chmod +x "${{ runner.temp }}/review-bot"
|
||||||
|
echo "Installed review-bot ${VERSION} (checksum verified)"
|
||||||
|
|
||||||
|
- name: Run review
|
||||||
|
shell: bash
|
||||||
|
env:
|
||||||
|
GITHUB_SERVER_URL: ${{ inputs.gitea-url || github.server_url }}
|
||||||
|
GITHUB_REPOSITORY: ${{ inputs.repo || github.repository }}
|
||||||
|
PR_NUMBER: ${{ inputs.pr-number || github.event.pull_request.number }}
|
||||||
|
REVIEWER_TOKEN: ${{ inputs.reviewer-token }}
|
||||||
|
REVIEWER_NAME: ${{ inputs.reviewer-name }}
|
||||||
|
LLM_BASE_URL: ${{ inputs.llm-base-url }}
|
||||||
|
LLM_API_KEY: ${{ inputs.llm-api-key }}
|
||||||
|
LLM_MODEL: ${{ inputs.llm-model }}
|
||||||
|
CONVENTIONS_FILE: ${{ inputs.conventions-file }}
|
||||||
|
PATTERNS_REPO: ${{ inputs.patterns-repo }}
|
||||||
|
PATTERNS_FILES: ${{ inputs.patterns-files }}
|
||||||
|
LLM_TEMPERATURE: ${{ inputs.temperature }}
|
||||||
|
LLM_TIMEOUT: ${{ inputs.timeout }}
|
||||||
|
LLM_PROVIDER: ${{ inputs.llm-provider }}
|
||||||
|
UPDATE_EXISTING: ${{ inputs.update-existing }}
|
||||||
|
SYSTEM_PROMPT_FILE: ${{ inputs.system-prompt-file }}
|
||||||
|
PERSONA: ${{ inputs.persona }}
|
||||||
|
PERSONA_FILE: ${{ inputs.persona-file }}
|
||||||
|
AICORE_CLIENT_ID: ${{ inputs.aicore-client-id }}
|
||||||
|
AICORE_CLIENT_SECRET: ${{ inputs.aicore-client-secret }}
|
||||||
|
AICORE_AUTH_URL: ${{ inputs.aicore-auth-url }}
|
||||||
|
AICORE_API_URL: ${{ inputs.aicore-api-url }}
|
||||||
|
AICORE_RESOURCE_GROUP: ${{ inputs.aicore-resource-group }}
|
||||||
|
run: |
|
||||||
|
ARGS=""
|
||||||
|
if [ "${{ inputs.dry-run }}" = "true" ]; then
|
||||||
|
ARGS="--dry-run"
|
||||||
|
fi
|
||||||
|
${{ runner.temp }}/review-bot $ARGS
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-24.04
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.26'
|
||||||
|
- run: go test ./...
|
||||||
|
- run: go vet ./...
|
||||||
|
- run: go build -o review-bot ./cmd/review-bot
|
||||||
|
|
||||||
|
# Self-review using native SAP AI Core provider
|
||||||
|
# Models must match SAP AI Core deployments
|
||||||
|
# Available models: gpt-5, anthropic--claude-4.6-sonnet, anthropic--claude-4.6-opus
|
||||||
|
# Removed gpt-4.1, gpt-5-mini, gpt-4.1-mini - not deployed on AI Core
|
||||||
|
review:
|
||||||
|
runs-on: ubuntu-24.04
|
||||||
|
if: github.event_name == 'pull_request'
|
||||||
|
needs: test
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- name: sonnet
|
||||||
|
token_secret: SONNET_REVIEW_TOKEN
|
||||||
|
model: anthropic--claude-4.6-sonnet
|
||||||
|
- name: gpt
|
||||||
|
token_secret: GPT_REVIEW_TOKEN
|
||||||
|
model: gpt-5
|
||||||
|
- name: security
|
||||||
|
token_secret: SECURITY_REVIEW_TOKEN
|
||||||
|
model: gpt-5
|
||||||
|
patterns_repo: rodin/security-patterns
|
||||||
|
patterns_files: "."
|
||||||
|
system_prompt_file: SECURITY_REVIEW.md
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.26'
|
||||||
|
- run: go build -o review-bot ./cmd/review-bot
|
||||||
|
- name: Run ${{ matrix.name }} review
|
||||||
|
env:
|
||||||
|
GITHUB_SERVER_URL: ${{ github.server_url }}
|
||||||
|
GITHUB_REPOSITORY: ${{ github.repository }}
|
||||||
|
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
|
REVIEWER_TOKEN: ${{ secrets[matrix.token_secret] }}
|
||||||
|
REVIEWER_NAME: ${{ matrix.name }}
|
||||||
|
LLM_PROVIDER: aicore
|
||||||
|
LLM_MODEL: ${{ matrix.model }}
|
||||||
|
AICORE_CLIENT_ID: ${{ secrets.AICORE_CLIENT_ID }}
|
||||||
|
AICORE_CLIENT_SECRET: ${{ secrets.AICORE_CLIENT_SECRET }}
|
||||||
|
AICORE_AUTH_URL: ${{ secrets.AICORE_AUTH_URL }}
|
||||||
|
AICORE_API_URL: ${{ secrets.AICORE_API_URL }}
|
||||||
|
AICORE_RESOURCE_GROUP: ${{ secrets.AICORE_RESOURCE_GROUP }}
|
||||||
|
CONVENTIONS_FILE: "CONVENTIONS.md"
|
||||||
|
PATTERNS_REPO: ${{ matrix.patterns_repo || 'rodin/go-patterns' }}
|
||||||
|
PATTERNS_FILES: ${{ matrix.patterns_files || 'README.md,patterns/' }}
|
||||||
|
LLM_TIMEOUT: "600"
|
||||||
|
SYSTEM_PROMPT_FILE: ${{ matrix.system_prompt_file }}
|
||||||
|
run: ./review-bot
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
name: PR Ready Gate
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [synchronize]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
clear-labels:
|
||||||
|
runs-on: ubuntu-24.04
|
||||||
|
# Always run - curl commands are safe if labels don't exist
|
||||||
|
steps:
|
||||||
|
- name: Remove ready and self-reviewed labels, reassign to author
|
||||||
|
env:
|
||||||
|
GITEA_TOKEN: ${{ secrets.RODIN_TOKEN }}
|
||||||
|
run: |
|
||||||
|
PR_NUMBER=${{ github.event.pull_request.number }}
|
||||||
|
AUTHOR=${{ github.event.pull_request.user.login }}
|
||||||
|
READY_LABEL_ID=38
|
||||||
|
SELF_REVIEWED_LABEL_ID=37
|
||||||
|
|
||||||
|
# Remove ready label if present
|
||||||
|
curl -sS -X DELETE \
|
||||||
|
-H "Authorization: token $GITEA_TOKEN" \
|
||||||
|
"https://gitea.weiker.me/api/v1/repos/${{ github.repository }}/issues/${PR_NUMBER}/labels/${READY_LABEL_ID}" || true
|
||||||
|
|
||||||
|
# Remove self-reviewed label if present
|
||||||
|
curl -sS -X DELETE \
|
||||||
|
-H "Authorization: token $GITEA_TOKEN" \
|
||||||
|
"https://gitea.weiker.me/api/v1/repos/${{ github.repository }}/issues/${PR_NUMBER}/labels/${SELF_REVIEWED_LABEL_ID}" || true
|
||||||
|
|
||||||
|
# Reassign to author
|
||||||
|
curl -sS -X PATCH \
|
||||||
|
-H "Authorization: token $GITEA_TOKEN" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d "{\"assignees\": [\"${AUTHOR}\"]}" \
|
||||||
|
"https://gitea.weiker.me/api/v1/repos/${{ github.repository }}/pulls/${PR_NUMBER}"
|
||||||
|
|
||||||
|
echo "Cleared ready/self-reviewed labels and reassigned PR #${PR_NUMBER} to ${AUTHOR}"
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
runs-on: ubuntu-24.04
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.26'
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
go vet ./...
|
||||||
|
go test ./...
|
||||||
|
|
||||||
|
- name: Build binaries
|
||||||
|
run: |
|
||||||
|
VERSION=${GITHUB_REF_NAME}
|
||||||
|
mkdir -p dist
|
||||||
|
|
||||||
|
GOOS=linux GOARCH=amd64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-linux-amd64 ./cmd/review-bot
|
||||||
|
GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-linux-arm64 ./cmd/review-bot
|
||||||
|
GOOS=darwin GOARCH=amd64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-darwin-amd64 ./cmd/review-bot
|
||||||
|
GOOS=darwin GOARCH=arm64 go build -ldflags "-s -w -X main.version=${VERSION}" -o dist/review-bot-darwin-arm64 ./cmd/review-bot
|
||||||
|
|
||||||
|
cd dist && sha256sum * > checksums.txt
|
||||||
|
|
||||||
|
- name: Create release and upload assets
|
||||||
|
env:
|
||||||
|
GITEA_TOKEN: ${{ secrets.RELEASE_TOKEN }}
|
||||||
|
run: |
|
||||||
|
VERSION=${GITHUB_REF_NAME}
|
||||||
|
GITEA_URL="${{ github.server_url }}"
|
||||||
|
REPO="${{ github.repository }}"
|
||||||
|
|
||||||
|
# Create release (or find existing one for this tag)
|
||||||
|
HTTP_CODE=$(curl -s -o /tmp/release_response.json -w "%{http_code}" -X POST \
|
||||||
|
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
"${GITEA_URL}/api/v1/repos/${REPO}/releases" \
|
||||||
|
-d "{\"tag_name\": \"${VERSION}\", \"name\": \"${VERSION}\", \"body\": \"Release ${VERSION}\", \"draft\": false, \"prerelease\": false}")
|
||||||
|
|
||||||
|
if [ "$HTTP_CODE" = "409" ]; then
|
||||||
|
echo "Release for ${VERSION} already exists, fetching existing..."
|
||||||
|
curl -sSf -o /tmp/release_response.json \
|
||||||
|
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||||
|
"${GITEA_URL}/api/v1/repos/${REPO}/releases/tags/${VERSION}"
|
||||||
|
elif [ "$HTTP_CODE" != "201" ]; then
|
||||||
|
echo "Failed to create release (HTTP ${HTTP_CODE})" >&2
|
||||||
|
cat /tmp/release_response.json >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Parse release ID (python3 available on ubuntu-24.04 runners)
|
||||||
|
RELEASE_ID=$(python3 -c "import json; print(json.load(open('/tmp/release_response.json'))['id'])")
|
||||||
|
|
||||||
|
if [ -z "$RELEASE_ID" ]; then
|
||||||
|
echo "Failed to parse release ID" >&2
|
||||||
|
cat /tmp/release_response.json >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Release ID: ${RELEASE_ID}"
|
||||||
|
|
||||||
|
# Upload each asset (idempotent: delete existing asset with same name first)
|
||||||
|
for file in dist/*; do
|
||||||
|
filename=$(basename "$file")
|
||||||
|
echo "Uploading ${filename}..."
|
||||||
|
|
||||||
|
# Check if asset already exists and delete it
|
||||||
|
EXISTING_ID=$(export ASSET_NAME="${filename}"; curl -sS \
|
||||||
|
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||||
|
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets" \
|
||||||
|
| python3 -c "import json,sys,os; name=os.environ['ASSET_NAME']; assets=json.load(sys.stdin); print(next((str(a['id']) for a in assets if a['name']==name),''))" 2>/dev/null)
|
||||||
|
|
||||||
|
if [ -n "$EXISTING_ID" ]; then
|
||||||
|
echo " Asset ${filename} already exists (id=${EXISTING_ID}), deleting..."
|
||||||
|
curl -sSf -X DELETE \
|
||||||
|
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||||
|
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets/${EXISTING_ID}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
curl -sSf -X POST \
|
||||||
|
-H "Authorization: token ${GITEA_TOKEN}" \
|
||||||
|
-H "Content-Type: application/octet-stream" \
|
||||||
|
"${GITEA_URL}/api/v1/repos/${REPO}/releases/${RELEASE_ID}/assets?name=$(printf '%s' "${filename}" | jq -sRr @uri)" \
|
||||||
|
--data-binary "@${file}"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Release ${VERSION} created with assets"
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
# CHANGELOG
|
|
||||||
|
|
||||||
## Unreleased
|
|
||||||
|
|
||||||
### Added
|
|
||||||
|
|
||||||
- **`doc-map` input** (`--doc-map` flag / `DOC_MAP_FILE` env var): Path to a YAML file mapping source path globs to governing design docs. review-bot intersects the map with changed PR paths and injects matching docs into the system prompt under a `## Design Documents` heading. ([#137](https://gitea.weiker.me/rodin/review-bot/issues/137))
|
|
||||||
- **`doc-map-max-bytes` input** (`--doc-map-max-bytes` flag / `DOC_MAP_MAX_BYTES` env var): Cap on total injected design doc content in bytes. Default: 102400 (100 KB). Prevents accidental context overflow when a PR touches many modules.
|
|
||||||
- **`DesignDocs` budget section**: Design docs are included in the context budget and trimmed after conventions, before file context, if the total exceeds the model's context limit.
|
|
||||||
|
|
||||||
### Doc-map config format
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
mappings:
|
|
||||||
- paths:
|
|
||||||
- "lib/gargoyle/engine/signal_risk/**"
|
|
||||||
docs:
|
|
||||||
- docs/domain/contexts/risk/risk-controls.md
|
|
||||||
- paths:
|
|
||||||
- "lib/gargoyle/trading/**"
|
|
||||||
docs:
|
|
||||||
- docs/domain/contexts/trading/
|
|
||||||
```
|
|
||||||
|
|
||||||
- `paths` — glob patterns (including `**`) matched against changed file paths in the PR
|
|
||||||
- `docs` — local file paths or directories (all `.md` files under a directory) to inject
|
|
||||||
- Multiple mappings can reference the same doc; docs are deduplicated
|
|
||||||
- Missing doc files: warn and skip (review continues without them)
|
|
||||||
- No matching paths: no docs injected, review runs normally
|
|
||||||
|
|
||||||
## v0.3.2
|
|
||||||
|
|
||||||
- Previous releases tracked in Gitea release notes.
|
|
||||||
+1
-1
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
| Package | Use Case | Scope |
|
| Package | Use Case | Scope |
|
||||||
|---------|----------|-------|
|
|---------|----------|-------|
|
||||||
| `github.com/goccy/go-yaml` | YAML parsing and AST inspection (subpkgs: `ast`, `parser`) | production |
|
| `gopkg.in/yaml.v3` | YAML parsing (persona files, config) | production |
|
||||||
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
|
| `github.com/google/go-cmp` | Test comparisons (`cmp.Diff`) | test only |
|
||||||
|
|
||||||
**Any import not in this table or the Go standard library is forbidden.**
|
**Any import not in this table or the Go standard library is forbidden.**
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
# Dev Loop Health Check — 2026-05-15 01:33 UTC
|
|
||||||
|
|
||||||
## Status: ✅ OPTIMAL
|
|
||||||
|
|
||||||
### Test Results
|
|
||||||
- All packages: **PASS** ✅ (6/6, fresh -count=1 run)
|
|
||||||
- Build: ✅ successful
|
|
||||||
- Vet: ✅ clean
|
|
||||||
|
|
||||||
### Coverage (current)
|
|
||||||
|
|
||||||
| Package | Coverage |
|
|
||||||
|---------|----------|
|
|
||||||
| budget | 91.8% |
|
|
||||||
| cmd/review-bot | 46.1% |
|
|
||||||
| gitea | 85.2% |
|
|
||||||
| github | 86.3% |
|
|
||||||
| llm | 81.3% |
|
|
||||||
| review | 92.0% |
|
|
||||||
|
|
||||||
### Recent Activity (since last check 01:28 UTC)
|
|
||||||
- Pulled `d0b0b0b` (dev-loop health update from 01:28 cycle)
|
|
||||||
- No new commits from dev work
|
|
||||||
- No open issues or PRs
|
|
||||||
- Working tree: clean, up to date with origin/main
|
|
||||||
|
|
||||||
### Notes on Coverage
|
|
||||||
- `cmd/review-bot` at 46.1% — main() itself at 26.5%; lowest coverage package
|
|
||||||
- Potential: integration test harness (issue #TBD)
|
|
||||||
- `vcs.go` adapter wrappers intentionally 0% — thin delegation, real logic tested in gitea/github packages
|
|
||||||
|
|
||||||
### Next Phase Priorities
|
|
||||||
1. **PR Submission (#132+)** — Enable review-bot to create PRs
|
|
||||||
2. **`github.Client.DismissReview`** — method referenced in orphaned files, not in client.go; file issue
|
|
||||||
3. **GitHub Enterprise Support** — Enterprise URL patterns, token scopes
|
|
||||||
4. **Increase cmd/review-bot coverage** — integration test harness for main()
|
|
||||||
5. **Performance & Observability** — Metrics, load testing, audit logging
|
|
||||||
|
|
||||||
### System Health
|
|
||||||
- ✅ All tests passing
|
|
||||||
- ✅ No warnings or lint issues
|
|
||||||
- ✅ Code clean, working tree clean
|
|
||||||
- ✅ No open issues or PRs on Gitea
|
|
||||||
- ✅ Ready for next development cycle
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**Previous check:** 2026-05-15 01:28 UTC
|
|
||||||
**This check:** 2026-05-15 01:33 UTC
|
|
||||||
**Action:** NONE — healthy, no work to do
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
=============================================================================
|
|
||||||
REVIEW-BOT DEV LOOP STATUS — 2026-05-15 01:48 UTC (post-sync)
|
|
||||||
=============================================================================
|
|
||||||
|
|
||||||
OVERALL STATUS: ✅ OPTIMAL
|
|
||||||
|
|
||||||
Test Results (fresh run post-sync):
|
|
||||||
- All 6 packages: PASS ✅
|
|
||||||
- Build: ✅ clean
|
|
||||||
- Vet: ✅ clean
|
|
||||||
- Fresh run: -count=1 verified
|
|
||||||
|
|
||||||
Recent Major Changes (synced from origin/main):
|
|
||||||
- Significant new GitHub client methods (~360 lines added)
|
|
||||||
- New validateurl package for URL validation
|
|
||||||
- New vcs adapter layer for VCS abstraction
|
|
||||||
- New gitea/ipcheck package for IP validation
|
|
||||||
- Expanded integration tests in cmd/review-bot
|
|
||||||
- All changes verified passing tests
|
|
||||||
|
|
||||||
Coverage (current post-sync):
|
|
||||||
- review: 92.0%
|
|
||||||
- budget: 91.8%
|
|
||||||
- github: 86.3%
|
|
||||||
- gitea: 85.2%
|
|
||||||
- llm: 81.3%
|
|
||||||
- cmd/review-bot: 46.1%
|
|
||||||
|
|
||||||
Repository:
|
|
||||||
- Branch: main (synced with origin — 4ffa6b6)
|
|
||||||
- Working tree: clean
|
|
||||||
- Open issues: 0
|
|
||||||
- Open PRs: 0
|
|
||||||
|
|
||||||
System Health: ✅ GREEN
|
|
||||||
✓ All tests passing (33 commits synced)
|
|
||||||
✓ No warnings
|
|
||||||
✓ Code clean
|
|
||||||
✓ Ready for feature work
|
|
||||||
|
|
||||||
Next Cycle: Ready to pick up feature work
|
|
||||||
|
|
||||||
=============================================================================
|
|
||||||
-175
@@ -1,175 +0,0 @@
|
|||||||
# Plan: Issue #125 — Rename GITEA_URL → VCS_URL
|
|
||||||
|
|
||||||
## Problem
|
|
||||||
|
|
||||||
The `GITEA_URL` environment variable (and `--gitea-url` flag) implies the binary only works with Gitea.
|
|
||||||
Now that review-bot supports both Gitea and GitHub/GHES, this name is misleading.
|
|
||||||
Renaming to `VCS_URL` makes the binary platform-agnostic in its interface.
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
|
|
||||||
- Must not break existing users who already use `GITEA_URL` — need a fallback
|
|
||||||
- The CLI flag `--gitea-url` should also be updated to `--vcs-url` for consistency
|
|
||||||
- `INTEGRATION_GITEA_URL` in integration tests is a test-only env var, not the binary's interface; but should be updated for clarity
|
|
||||||
- The action YAML uses `GITEA_URL` as an internal shell variable in bash scripts — distinct from the env var passed to the binary
|
|
||||||
- All changes must compile and pass existing tests
|
|
||||||
|
|
||||||
## Files Affected
|
|
||||||
|
|
||||||
### Binary / Go source
|
|
||||||
| File | Change |
|
|
||||||
|------|--------|
|
|
||||||
| `cmd/review-bot/main.go` | Rename `--gitea-url` → `--vcs-url`, add `VCS_URL` as primary, keep `GITEA_URL` fallback |
|
|
||||||
| `cmd/review-bot/integration_test.go` | Rename `INTEGRATION_GITEA_URL` → `INTEGRATION_VCS_URL` (test-only, no external compat concern) |
|
|
||||||
| `integration_test.go` | Same — rename `INTEGRATION_GITEA_URL` → `INTEGRATION_VCS_URL` |
|
|
||||||
|
|
||||||
### Action YAML
|
|
||||||
| File | Change |
|
|
||||||
|------|--------|
|
|
||||||
| `.gitea/actions/review/action.yml` | Rename input `gitea-url` → `vcs-url`; update env var passed to binary: `VCS_URL` instead of `GITEA_URL`; keep internal bash var as `GITEA_URL` (only used for release download, not passed to binary) |
|
|
||||||
| `.gitea/workflows/ci.yml` | Rename `GITEA_URL` env var to `VCS_URL` in Run review step |
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
| File | Change |
|
|
||||||
|------|--------|
|
|
||||||
| `README.md` | Update CLI example, env var table entry |
|
|
||||||
|
|
||||||
## Proposed Approach
|
|
||||||
|
|
||||||
### 1. Backward-compatible env var lookup in main.go
|
|
||||||
|
|
||||||
Replace:
|
|
||||||
```go
|
|
||||||
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", ""), "Gitea instance URL")
|
|
||||||
```
|
|
||||||
|
|
||||||
With:
|
|
||||||
```go
|
|
||||||
giteaURL := flag.String("vcs-url", envOrDefaultFallback("VCS_URL", "GITEA_URL", ""), "VCS server URL (e.g. https://gitea.example.com)")
|
|
||||||
```
|
|
||||||
|
|
||||||
Add a helper:
|
|
||||||
```go
|
|
||||||
// envOrDefaultFallback reads primary env var; if empty, falls back to deprecated env var.
|
|
||||||
func envOrDefaultFallback(primary, deprecated, defaultVal string) string {
|
|
||||||
if v := os.Getenv(primary); v != "" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
if v := os.Getenv(deprecated); v != "" {
|
|
||||||
slog.Warn("deprecated env var in use; rename to " + primary, "old", deprecated, "new", primary)
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
return defaultVal
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note:** This must be called AFTER `setupLogger` conceptually, but the flag default is evaluated at flag registration time. Since `setupLogger` runs before `flag.Parse()`, the slog.Warn will print correctly at runtime. We use `log.Printf` as a fallback if this proves problematic.
|
|
||||||
|
|
||||||
Actually — flag defaults are evaluated at registration (line 57), before `setupLogger`. The warning won't go through slog. Two options:
|
|
||||||
- Use `log.Printf` for the deprecation warning (always visible)
|
|
||||||
- Move the fallback lookup to after `flag.Parse()`, checking if the parsed value is still empty
|
|
||||||
|
|
||||||
**Decision:** Move fallback to a post-parse check. This is cleaner:
|
|
||||||
```go
|
|
||||||
vcsURL := flag.String("vcs-url", os.Getenv("VCS_URL"), "VCS server URL")
|
|
||||||
flag.Parse()
|
|
||||||
// Backward compat: fall back to deprecated GITEA_URL
|
|
||||||
if *vcsURL == "" {
|
|
||||||
if v := os.Getenv("GITEA_URL"); v != "" {
|
|
||||||
slog.Warn("GITEA_URL is deprecated; use VCS_URL instead")
|
|
||||||
*vcsURL = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
This is clean, idiomatic, and the warning goes through slog correctly.
|
|
||||||
|
|
||||||
### 2. Keep `--gitea-url` as deprecated alias
|
|
||||||
|
|
||||||
Add a hidden flag for backward compat:
|
|
||||||
```go
|
|
||||||
giteaURLAlias := flag.String("gitea-url", "", "Deprecated: use --vcs-url")
|
|
||||||
```
|
|
||||||
|
|
||||||
Post-parse:
|
|
||||||
```go
|
|
||||||
if *vcsURL == "" && *giteaURLAlias != "" {
|
|
||||||
slog.Warn("--gitea-url is deprecated; use --vcs-url instead")
|
|
||||||
*vcsURL = *giteaURLAlias
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Internal variable rename
|
|
||||||
|
|
||||||
Rename `giteaURL` local variable → `vcsURL` throughout `main.go` for consistency.
|
|
||||||
|
|
||||||
### 4. Error message update
|
|
||||||
|
|
||||||
```go
|
|
||||||
fmt.Fprintf(os.Stderr, "Required: --vcs-url, --repo, --pr, --reviewer-token, --llm-model\n")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. Action YAML changes
|
|
||||||
|
|
||||||
In `.gitea/actions/review/action.yml`:
|
|
||||||
- Input `gitea-url` → `vcs-url` (with same description, `required: false`, `default: ''`)
|
|
||||||
- Line 172: `GITEA_URL: ${{ inputs.gitea-url || github.server_url }}` → `VCS_URL: ${{ inputs.vcs-url || github.server_url }}`
|
|
||||||
- Lines 115, 140: internal bash vars `GITEA_URL=` are used for downloading binaries — NOT passed to the review-bot binary. Leave them as internal bash vars (they're scope-local in bash). These could be renamed to `SERVER_URL` or `BASE_URL` for local clarity, but renaming them isn't strictly required.
|
|
||||||
|
|
||||||
In `.gitea/workflows/ci.yml`:
|
|
||||||
- Line 52: `GITEA_URL: ${{ github.server_url }}` → `VCS_URL: ${{ github.server_url }}`
|
|
||||||
|
|
||||||
### 6. Integration test updates
|
|
||||||
|
|
||||||
`INTEGRATION_GITEA_URL` → `INTEGRATION_VCS_URL` in both test files.
|
|
||||||
|
|
||||||
### 7. README
|
|
||||||
|
|
||||||
- CLI example: `--gitea-url` → `--vcs-url`
|
|
||||||
- Env var table: `GITEA_URL` → `VCS_URL`, add note about `GITEA_URL` fallback
|
|
||||||
|
|
||||||
## Backward Compatibility Summary
|
|
||||||
|
|
||||||
| Old | New | Fallback? |
|
|
||||||
|-----|-----|-----------|
|
|
||||||
| `GITEA_URL` env var | `VCS_URL` | ✅ with deprecation warning |
|
|
||||||
| `--gitea-url` flag | `--vcs-url` | ✅ with deprecation warning |
|
|
||||||
| `gitea-url` action input | `vcs-url` | ⚠️ No (action version bump handles this) |
|
|
||||||
| `INTEGRATION_GITEA_URL` | `INTEGRATION_VCS_URL` | N/A (test-only) |
|
|
||||||
|
|
||||||
## Error Cases
|
|
||||||
|
|
||||||
- Both `VCS_URL` and `GITEA_URL` set: `VCS_URL` wins (primary takes precedence)
|
|
||||||
- Both `--vcs-url` and `--gitea-url` provided: `--vcs-url` wins
|
|
||||||
- Neither set: existing "missing required flags" error unchanged
|
|
||||||
|
|
||||||
## Edge Cases
|
|
||||||
|
|
||||||
- `os.Getenv` returns "" for unset AND set-to-empty — consistent with existing behavior
|
|
||||||
- The `envOrDefault` helper is unchanged; we add `envOrDefaultFallback` for the one renamed var
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
|
|
||||||
- Existing unit tests pass unchanged (they don't test env var parsing directly)
|
|
||||||
- Integration tests updated to use new env var name
|
|
||||||
- Manual: `GITEA_URL=https://example.com ./review-bot --repo x --pr 1 ...` should print deprecation warning and proceed
|
|
||||||
- Manual: `VCS_URL=https://example.com ./review-bot ...` should work silently
|
|
||||||
|
|
||||||
## Completion Checklist
|
|
||||||
|
|
||||||
1. `VCS_URL` is read first; `GITEA_URL` is fallback with deprecation warning
|
|
||||||
2. `--vcs-url` flag is primary; `--gitea-url` is deprecated alias with warning
|
|
||||||
3. Error message references `--vcs-url` not `--gitea-url`
|
|
||||||
4. `action.yml` passes `VCS_URL` (not `GITEA_URL`) to the binary
|
|
||||||
5. `ci.yml` passes `VCS_URL` (not `GITEA_URL`) to the binary
|
|
||||||
6. README updated in CLI example and env var table
|
|
||||||
7. Integration tests use `INTEGRATION_VCS_URL`
|
|
||||||
8. `go test ./...` passes
|
|
||||||
9. `go vet ./...` passes
|
|
||||||
10. `go build ./cmd/review-bot` succeeds
|
|
||||||
|
|
||||||
## Open Questions
|
|
||||||
|
|
||||||
- Should the CLI flag `--gitea-url` be completely hidden from `--help` or just deprecated with a note? The issue doesn't specify. Decision: keep it visible but add "(deprecated: use --vcs-url)" to the description.
|
|
||||||
- Should action.yml also add `gitea-url` as a deprecated input alias? The issue says "Update the action to pass the new env var name" — no mention of backward compat for the action input. Decision: rename only, no alias (action users pin a version anyway).
|
|
||||||
- The bash-internal `GITEA_URL` variable in action.yml scripts (used for release download, not passed to binary) — rename for clarity? Decision: yes, rename to `BASE_URL` to avoid confusion with the env var.
|
|
||||||
-194
@@ -1,194 +0,0 @@
|
|||||||
# Plan: Issue #137 — doc-map input for path-scoped doc injection
|
|
||||||
|
|
||||||
## Problem
|
|
||||||
|
|
||||||
review-bot currently injects context via `patterns-repo` (external VCS repos) and `conventions-file` (a single file from the reviewed repo). There is no mechanism to inject local repo documentation files scoped to the paths changed in a PR.
|
|
||||||
|
|
||||||
First consumer: `grgl/gargoyle#778` wants a "doc adherence" reviewer that checks code against the module's governing design doc, without injecting every doc in the tree.
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
|
|
||||||
- Must work with existing `budget.Fit` architecture (docs go into `SystemBase` section, never trimmed — or added as a new section below `Conventions`)
|
|
||||||
- Must not fail the review if doc files are missing (warn + skip)
|
|
||||||
- Context guard: default 100KB total injected doc content (configurable)
|
|
||||||
- YAML parsing must use `github.com/goccy/go-yaml` (the only approved YAML library)
|
|
||||||
- No new third-party dependencies (Go standard library + approved packages only)
|
|
||||||
- Path security: doc files must be read via VCS API (not local filesystem), so they are always fetched from the PR head ref within the repo workspace — same path used by `conventions-file` loading
|
|
||||||
|
|
||||||
Wait — re-reading the issue: the issue says "local repo files". In the CI action context, the action runner has the repo checked out. The design doc says "read each doc file from the local checkout". But review-bot has no local checkout — it runs as a binary and reads files via VCS API. Let me reconcile:
|
|
||||||
|
|
||||||
- `conventions-file` uses `vcs.GetFileContent` (fetches from VCS API, default branch)
|
|
||||||
- The doc-map docs should also be read via VCS API
|
|
||||||
- The doc-map config file itself (`doc-map` input) is a local file in the workspace (like `system-prompt-file`)
|
|
||||||
- The doc paths inside the config ARE relative to the repo root, to be fetched via VCS API
|
|
||||||
|
|
||||||
**Conclusion:** The `doc-map` YAML file is read from local filesystem (like `system-prompt-file`). The doc files listed inside are fetched from the VCS API.
|
|
||||||
|
|
||||||
Actually, re-reading more carefully: "Read each doc file (or all .md files under a directory) from the local checkout". But review-bot doesn't have a local checkout. Since `system-prompt-file` and `conventions-file` are both read locally, I should follow the same approach consistently.
|
|
||||||
|
|
||||||
**Final decision:** The `doc-map` config file is local (passed via `--doc-map` flag, read with `os.ReadFile` after workspace validation). The listed doc paths (and directory expansion) are read via VCS `GetFileContent` / `GetAllFilesInPath` — matching the `conventions-file` pattern for consistency, and enabling it to work on any branch (not just the checked-out one).
|
|
||||||
|
|
||||||
## Proposed Approach
|
|
||||||
|
|
||||||
### New files
|
|
||||||
|
|
||||||
1. `review/docmap.go` — `DocMap` type, YAML parsing, glob matching, doc loading logic
|
|
||||||
2. `review/docmap_test.go` — unit tests
|
|
||||||
|
|
||||||
### Modified files
|
|
||||||
|
|
||||||
1. `cmd/review-bot/main.go` — add `--doc-map` flag, wire up in Step 6c
|
|
||||||
2. `.gitea/actions/review/action.yml` — add `doc-map` input, pass as `DOC_MAP_FILE` env var
|
|
||||||
3. `budget/budget.go` — add `DesignDocs` section (between `SystemBase`/`Conventions` and `Diff`)
|
|
||||||
4. `CHANGELOG.md` — update
|
|
||||||
|
|
||||||
### DocMap types (review/docmap.go)
|
|
||||||
|
|
||||||
```go
|
|
||||||
// DocMapping maps a set of path globs to doc files/directories.
|
|
||||||
type DocMapping struct {
|
|
||||||
Paths []string `yaml:"paths"` // glob patterns
|
|
||||||
Docs []string `yaml:"docs"` // file paths or directories
|
|
||||||
}
|
|
||||||
|
|
||||||
// DocMapConfig is the top-level YAML structure.
|
|
||||||
type DocMapConfig struct {
|
|
||||||
Mappings []DocMapping `yaml:"mappings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DocMapOptions controls doc loading behavior.
|
|
||||||
type DocMapOptions struct {
|
|
||||||
MaxBytes int // default 100*1024
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Key functions
|
|
||||||
|
|
||||||
```go
|
|
||||||
// ParseDocMapConfig parses the YAML config file from a local path.
|
|
||||||
func ParseDocMapConfig(path string) (*DocMapConfig, error)
|
|
||||||
|
|
||||||
// MatchDocs returns deduplicated doc paths for the given changed files.
|
|
||||||
func MatchDocs(cfg *DocMapConfig, changedFiles []string) []string
|
|
||||||
|
|
||||||
// LoadMatchingDocs fetches doc content via VCS, respecting size limit.
|
|
||||||
// Returns (content, error). Missing files are warned and skipped.
|
|
||||||
func LoadMatchingDocs(ctx context.Context, fetcher DocFetcher, owner, repo string, docPaths []string, opts DocMapOptions) (string, error)
|
|
||||||
```
|
|
||||||
|
|
||||||
### DocFetcher interface
|
|
||||||
|
|
||||||
```go
|
|
||||||
// DocFetcher fetches files and directory listings from VCS.
|
|
||||||
// Subset of vcsClient, defined here to keep review package free of cmd-level deps.
|
|
||||||
type DocFetcher interface {
|
|
||||||
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
|
|
||||||
GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Glob matching
|
|
||||||
|
|
||||||
Use `path.Match` from the Go standard library. It matches patterns like `lib/gargoyle/engine/signal_risk/**`. The `**` glob is NOT natively supported by `path.Match`, so we need either:
|
|
||||||
|
|
||||||
a) Use `filepath.Match` which also doesn't support `**`
|
|
||||||
b) Implement simple `**` support: `**` matches any number of path segments
|
|
||||||
|
|
||||||
**Decision:** Implement minimal `**` support: split path on `/`, split pattern on `/`, match each segment with `filepath.Match`. When a pattern segment is `**`, it consumes any number of remaining segments. This covers the primary use case without a new dependency.
|
|
||||||
|
|
||||||
### Budget integration
|
|
||||||
|
|
||||||
Add `DesignDocs` field to `budget.Sections`. Position: after `Conventions`, before `FileContext` (trimming order: Patterns → Conventions → DesignDocs → FileContext → Diff). Inject under `## Design Documents` heading in system prompt.
|
|
||||||
|
|
||||||
### Context size guard
|
|
||||||
|
|
||||||
Accumulate doc content bytes. If total would exceed `MaxBytes`, truncate last doc with a notice and stop loading more.
|
|
||||||
|
|
||||||
## State/Data Model
|
|
||||||
|
|
||||||
```
|
|
||||||
DocMapConfig
|
|
||||||
└── []DocMapping
|
|
||||||
├── Paths []string (glob patterns against changed file paths)
|
|
||||||
└── Docs []string (local doc paths or directories in target repo)
|
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Parse doc-map YAML → DocMapConfig
|
|
||||||
2. GetPullRequestFiles → []string of changed paths
|
|
||||||
3. MatchDocs(cfg, changedPaths) → deduplicated []string doc paths
|
|
||||||
4. For each doc path:
|
|
||||||
- If path ends with "/" or is a "directory" → GetAllFilesInPath, filter .md
|
|
||||||
- Otherwise → GetFileContent
|
|
||||||
5. Accumulate, respect size limit
|
|
||||||
6. Inject into system prompt
|
|
||||||
```
|
|
||||||
|
|
||||||
## Error Cases
|
|
||||||
|
|
||||||
| Situation | Behavior |
|
|
||||||
|-----------|----------|
|
|
||||||
| `--doc-map` file not found | Fatal error (like `--system-prompt-file`) |
|
|
||||||
| `--doc-map` file invalid YAML | Fatal error with descriptive message |
|
|
||||||
| Unknown keys in YAML | Log warning, continue |
|
|
||||||
| Doc file not found in VCS | Log warning, skip |
|
|
||||||
| Doc directory empty | Log debug, skip |
|
|
||||||
| Total size exceeds limit | Truncate with notice, log warning |
|
|
||||||
| No changed paths match | No docs injected, review runs normally |
|
|
||||||
| `paths` list empty in a mapping | Skip that mapping (no match possible) |
|
|
||||||
| `docs` list empty in a mapping | Skip that mapping (nothing to inject) |
|
|
||||||
|
|
||||||
## Edge Cases
|
|
||||||
|
|
||||||
- Empty `mappings` list → no docs injected, no error
|
|
||||||
- Same doc matched by multiple mappings → deduplicate by path
|
|
||||||
- Directory with no `.md` files → skip silently (log debug)
|
|
||||||
- Very large single doc file → counts against limit, may truncate
|
|
||||||
- Symlinks/special files in VCS → GetFileContent handles or errors (warn + skip)
|
|
||||||
- `doc-map` path outside workspace → fatal error (validateWorkspacePath)
|
|
||||||
- Directory path specified as `docs` entry without trailing `/` → check if it's a directory via ListContents or GetAllFilesInPath; if error, try GetFileContent
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
|
|
||||||
### Unit tests (review/docmap_test.go)
|
|
||||||
|
|
||||||
1. **ParseDocMapConfig** — valid YAML, invalid YAML, unknown keys (warning), empty file
|
|
||||||
2. **MatchDocs** — no match, single match, multi-match, deduplication, `**` glob, exact match
|
|
||||||
3. **LoadMatchingDocs** — with mock DocFetcher:
|
|
||||||
- file path → content returned
|
|
||||||
- missing file → warned + skipped
|
|
||||||
- directory path → expands .md files
|
|
||||||
- directory with no .md → empty
|
|
||||||
- size guard → truncation with notice
|
|
||||||
- deduplication in combined results
|
|
||||||
|
|
||||||
### Integration coverage
|
|
||||||
|
|
||||||
The existing `main_test.go` tests cover flag wiring — add a test for `--doc-map` flag parsing and workspace path validation.
|
|
||||||
|
|
||||||
## Open Questions
|
|
||||||
|
|
||||||
1. **Directory detection**: The issue says "directory paths expand to all .md files". But review-bot has no local filesystem. When a `docs` entry is `docs/domain/contexts/trading/`, we can call `GetAllFilesInPath`. But what if someone writes `docs/domain/contexts/trading` (no trailing slash)? We could try GetFileContent first, and if it fails with a 404 or "is directory" error, fall back to GetAllFilesInPath. OR we could just always call GetAllFilesInPath and if it returns content, use it; if it returns empty, try GetFileContent.
|
|
||||||
**Decision**: Try GetAllFilesInPath first (always). If it returns ≥1 file, treat as directory. If it returns 0 files AND no error, try GetFileContent. If GetAllFilesInPath returns an error, try GetFileContent.
|
|
||||||
|
|
||||||
2. **Budget section placement**: The issue says docs go in "system prompt after system-prompt-file content". That means docs are part of the system prompt. Current budget: SystemBase (includes additionalPrompt) → Patterns → Conventions. I'll add DesignDocs after Conventions (trim after Conventions). Docs are injected into system prompt via `buildResult`.
|
|
||||||
**Decision**: DesignDocs section in budget, trimmed after Conventions, before FileContext.
|
|
||||||
|
|
||||||
3. **Configurable size limit**: The issue says "configurable". Add `--doc-map-max-bytes` flag (default 102400). Pass via `DocMapOptions`.
|
|
||||||
**Decision**: Add flag. Default 100KB (102400 bytes).
|
|
||||||
|
|
||||||
## Completion Checklist
|
|
||||||
|
|
||||||
1. `doc-map` input added to action.yml with correct env var passthrough
|
|
||||||
2. `--doc-map` and `--doc-map-max-bytes` flags parsed in main.go
|
|
||||||
3. `doc-map` file validated with `validateWorkspacePath` before reading
|
|
||||||
4. YAML parsed with `go-yaml`, unknown keys warned not errored
|
|
||||||
5. Glob matching handles `**` segments
|
|
||||||
6. Changed files list from PR drives intersection (not hardcoded)
|
|
||||||
7. Docs deduplicated before fetching
|
|
||||||
8. Missing doc files: warn + skip, not fatal
|
|
||||||
9. Context size guard truncates with notice, logs warning
|
|
||||||
10. `DesignDocs` section added to `budget.Sections` and `buildResult`
|
|
||||||
11. Tests cover: match, no-match, dedup, missing file, directory expansion, size guard, YAML parse error
|
|
||||||
12. `go test ./...` passes
|
|
||||||
13. `go vet ./...` passes
|
|
||||||
14. CHANGELOG updated
|
|
||||||
@@ -6,11 +6,10 @@ AI-powered code review bot for Gitea pull requests. Fetches diff + context, send
|
|||||||
|
|
||||||
- **Multi-provider**: OpenAI-compatible, Anthropic Messages API, and SAP AI Core
|
- **Multi-provider**: OpenAI-compatible, Anthropic Messages API, and SAP AI Core
|
||||||
- **Context-aware**: Fetches full file content, conventions, language patterns, CI status
|
- **Context-aware**: Fetches full file content, conventions, language patterns, CI status
|
||||||
- **Path-scoped docs**: `doc-map` config injects only the governing design docs for changed paths
|
|
||||||
- **Smart budget**: Automatically trims context to fit model token limits
|
- **Smart budget**: Automatically trims context to fit model token limits
|
||||||
- **Idempotent reviews**: Posts new review, then cleans up stale ones (one review per bot)
|
- **Idempotent reviews**: Posts new review, then cleans up stale ones (one review per bot)
|
||||||
- **Custom prompts**: Load additional instructions from a file (e.g. security-focused review)
|
- **Custom prompts**: Load additional instructions from a file (e.g. security-focused review)
|
||||||
- **Minimal dependencies**: Go stdlib + `github.com/goccy/go-yaml` only
|
- **Minimal dependencies**: Go stdlib + `gopkg.in/yaml.v3` only
|
||||||
|
|
||||||
## Quick Start: Composite Action
|
## Quick Start: Composite Action
|
||||||
|
|
||||||
@@ -283,7 +282,7 @@ Rules:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
review-bot \
|
review-bot \
|
||||||
--vcs-url https://gitea.example.com \
|
--gitea-url https://gitea.example.com \
|
||||||
--repo owner/name \
|
--repo owner/name \
|
||||||
--pr 42 \
|
--pr 42 \
|
||||||
--reviewer-token "$GITEA_TOKEN" \
|
--reviewer-token "$GITEA_TOKEN" \
|
||||||
@@ -300,7 +299,7 @@ All flags have environment variable equivalents:
|
|||||||
|
|
||||||
| Flag | Env Var |
|
| Flag | Env Var |
|
||||||
|------|---------|
|
|------|---------|
|
||||||
| `--vcs-url` | `VCS_URL` (fallback: `GITEA_URL`) |
|
| `--gitea-url` | `GITEA_URL` |
|
||||||
| `--repo` | `GITEA_REPO` |
|
| `--repo` | `GITEA_REPO` |
|
||||||
| `--pr` | `PR_NUMBER` |
|
| `--pr` | `PR_NUMBER` |
|
||||||
| `--reviewer-token` | `REVIEWER_TOKEN` |
|
| `--reviewer-token` | `REVIEWER_TOKEN` |
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
## Dev Loop Status: 2026-05-15 02:28 UTC
|
|
||||||
|
|
||||||
**Repository:** review-bot (rodin/review-bot on Gitea)
|
|
||||||
**Status:** ✅ OPTIMAL
|
|
||||||
|
|
||||||
### Health Check
|
|
||||||
|
|
||||||
- **Working tree:** clean
|
|
||||||
- **Branch:** main (up to date with origin)
|
|
||||||
- **Build:** ✅ passes (`go build ./cmd/review-bot`)
|
|
||||||
- **Tests:** ✅ ALL PASS (6/6 packages)
|
|
||||||
- **Vet:** ✅ clean
|
|
||||||
- **Open issues:** 0
|
|
||||||
- **Open PRs:** 0
|
|
||||||
|
|
||||||
### Recent Changes
|
|
||||||
|
|
||||||
Last commit: `dcfd360` (2026-05-15 01:48) — health check post-sync
|
|
||||||
|
|
||||||
### Coverage
|
|
||||||
|
|
||||||
| Package | Coverage |
|
|
||||||
|---------|----------|
|
|
||||||
| cmd/review-bot | 46.1% |
|
|
||||||
| gitea | 85.2% |
|
|
||||||
| github | 86.3% |
|
|
||||||
| review | 92.0% |
|
|
||||||
|
|
||||||
### Next Priority
|
|
||||||
|
|
||||||
- Increase cmd/review-bot coverage (lowest at 46.1%)
|
|
||||||
- Monitor prod logs for edge cases
|
|
||||||
- VCS integration stable; GitHub + Gitea paths clear
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
_Dev-loop cycle complete at 02:28 UTC._
|
|
||||||
+1
-7
@@ -63,8 +63,7 @@ type Sections struct {
|
|||||||
SystemBase string // Core instructions (never trimmed)
|
SystemBase string // Core instructions (never trimmed)
|
||||||
Patterns string // Language patterns (trimmed first)
|
Patterns string // Language patterns (trimmed first)
|
||||||
Conventions string // Repo conventions (trimmed second)
|
Conventions string // Repo conventions (trimmed second)
|
||||||
DesignDocs string // Path-scoped design documents (trimmed third)
|
FileContext string // Full file content (trimmed third)
|
||||||
FileContext string // Full file content (trimmed fourth)
|
|
||||||
Diff string // The actual diff (trimmed last, only truncated)
|
Diff string // The actual diff (trimmed last, only truncated)
|
||||||
UserMeta string // PR title, description, CI status (truncated only if base exceeds budget)
|
UserMeta string // PR title, description, CI status (truncated only if base exceeds budget)
|
||||||
}
|
}
|
||||||
@@ -104,7 +103,6 @@ func Fit(model string, sections Sections) Result {
|
|||||||
entries := []entry{
|
entries := []entry{
|
||||||
{"patterns", §ions.Patterns},
|
{"patterns", §ions.Patterns},
|
||||||
{"conventions", §ions.Conventions},
|
{"conventions", §ions.Conventions},
|
||||||
{"design docs", §ions.DesignDocs},
|
|
||||||
{"file context", §ions.FileContext},
|
{"file context", §ions.FileContext},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,10 +185,6 @@ func buildResult(s Sections, trimmed []string, estTokens int) Result {
|
|||||||
sys.WriteString("\n\n## Repository Conventions\n\nThe repository has the following coding conventions that must be respected:\n\n")
|
sys.WriteString("\n\n## Repository Conventions\n\nThe repository has the following coding conventions that must be respected:\n\n")
|
||||||
sys.WriteString(s.Conventions)
|
sys.WriteString(s.Conventions)
|
||||||
}
|
}
|
||||||
if s.DesignDocs != "" {
|
|
||||||
sys.WriteString("\n\n## Design Documents\n\nThe following design documents govern the changed code. Review the diff for adherence:\n\n")
|
|
||||||
sys.WriteString(s.DesignDocs)
|
|
||||||
}
|
|
||||||
|
|
||||||
var usr strings.Builder
|
var usr strings.Builder
|
||||||
usr.WriteString(s.UserMeta)
|
usr.WriteString(s.UserMeta)
|
||||||
|
|||||||
@@ -157,6 +157,7 @@ func TestFit_PreservesNoteInOutput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func TestFit_HugeUserMeta(t *testing.T) {
|
func TestFit_HugeUserMeta(t *testing.T) {
|
||||||
// UserMeta so large that base alone exceeds limit
|
// UserMeta so large that base alone exceeds limit
|
||||||
// Use a unique marker past the truncation point
|
// Use a unique marker past the truncation point
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||||
"gitea.weiker.me/rodin/review-bot/github"
|
|
||||||
"gitea.weiker.me/rodin/review-bot/llm"
|
"gitea.weiker.me/rodin/review-bot/llm"
|
||||||
"gitea.weiker.me/rodin/review-bot/review"
|
"gitea.weiker.me/rodin/review-bot/review"
|
||||||
)
|
)
|
||||||
@@ -18,7 +17,7 @@ import (
|
|||||||
// Integration test requires a running Gitea instance and LLM endpoint.
|
// Integration test requires a running Gitea instance and LLM endpoint.
|
||||||
// Set environment variables:
|
// Set environment variables:
|
||||||
//
|
//
|
||||||
// INTEGRATION_VCS_URL - VCS base URL
|
// INTEGRATION_GITEA_URL - Gitea base URL
|
||||||
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
||||||
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
||||||
// INTEGRATION_PR_NUMBER - PR number to test against
|
// INTEGRATION_PR_NUMBER - PR number to test against
|
||||||
@@ -26,7 +25,7 @@ import (
|
|||||||
// INTEGRATION_LLM_API_KEY - LLM API key
|
// INTEGRATION_LLM_API_KEY - LLM API key
|
||||||
// INTEGRATION_LLM_MODEL - Model name
|
// INTEGRATION_LLM_MODEL - Model name
|
||||||
func TestIntegration_FullReviewFlow(t *testing.T) {
|
func TestIntegration_FullReviewFlow(t *testing.T) {
|
||||||
giteaURL := os.Getenv("INTEGRATION_VCS_URL")
|
giteaURL := os.Getenv("INTEGRATION_GITEA_URL")
|
||||||
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
||||||
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
||||||
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
||||||
@@ -105,7 +104,7 @@ func TestIntegration_FullReviewFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIntegration_PostAndCleanup(t *testing.T) {
|
func TestIntegration_PostAndCleanup(t *testing.T) {
|
||||||
giteaURL := os.Getenv("INTEGRATION_VCS_URL")
|
giteaURL := os.Getenv("INTEGRATION_GITEA_URL")
|
||||||
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
||||||
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
||||||
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
||||||
@@ -131,7 +130,7 @@ func TestIntegration_PostAndCleanup(t *testing.T) {
|
|||||||
// Post a test review
|
// Post a test review
|
||||||
sentinel := "<!-- review-bot:integration-test -->"
|
sentinel := "<!-- review-bot:integration-test -->"
|
||||||
testBody := "# Integration Test Review\n\nThis is a test review.\n\n" + sentinel
|
testBody := "# Integration Test Review\n\nThis is a test review.\n\n" + sentinel
|
||||||
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, "", nil)
|
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("PostReview: %v", err)
|
t.Fatalf("PostReview: %v", err)
|
||||||
}
|
}
|
||||||
@@ -160,85 +159,3 @@ func TestIntegration_PostAndCleanup(t *testing.T) {
|
|||||||
t.Logf("Warning: could not delete test review %d: %v", posted.ID, err)
|
t.Logf("Warning: could not delete test review %d: %v", posted.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestIntegration_GitHub_PostAndVerifyReview exercises the full VCS routing path
|
|
||||||
// for GitHub when INTEGRATION_GITHUB_TOKEN and INTEGRATION_GITHUB_REPO are set.
|
|
||||||
// It verifies that the GitHub adapter is selected via VCS_TYPE=github and that
|
|
||||||
// PostReview succeeds against a real GitHub PR.
|
|
||||||
//
|
|
||||||
// Required environment variables:
|
|
||||||
//
|
|
||||||
// INTEGRATION_GITHUB_TOKEN - GitHub personal access token with repo access
|
|
||||||
// INTEGRATION_GITHUB_REPO - owner/repo with an open PR (e.g. Rodin-AI/review-bot)
|
|
||||||
// INTEGRATION_GITHUB_PR - PR number to test against
|
|
||||||
//
|
|
||||||
// The test skips gracefully when these variables are absent.
|
|
||||||
func TestIntegration_GitHub_PostAndVerifyReview(t *testing.T) {
|
|
||||||
githubToken := os.Getenv("INTEGRATION_GITHUB_TOKEN")
|
|
||||||
githubRepo := os.Getenv("INTEGRATION_GITHUB_REPO")
|
|
||||||
prNumStr := os.Getenv("INTEGRATION_GITHUB_PR")
|
|
||||||
|
|
||||||
if githubToken == "" || githubRepo == "" || prNumStr == "" {
|
|
||||||
t.Skip("INTEGRATION_GITHUB_TOKEN, INTEGRATION_GITHUB_REPO, and INTEGRATION_GITHUB_PR not set, skipping")
|
|
||||||
}
|
|
||||||
|
|
||||||
prNumber, err := strconv.Atoi(prNumStr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Invalid PR number %q: %v", prNumStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.SplitN(githubRepo, "/", 2)
|
|
||||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
|
||||||
t.Fatalf("Invalid repo format %q, expected owner/repo", githubRepo)
|
|
||||||
}
|
|
||||||
owner, repoName := parts[0], parts[1]
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
ghClient := github.NewClient(githubToken, "https://api.github.com")
|
|
||||||
|
|
||||||
// Verify adapter selection: GetAuthenticatedUser must succeed.
|
|
||||||
user, err := ghClient.GetAuthenticatedUser(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetAuthenticatedUser: %v — check INTEGRATION_GITHUB_TOKEN", err)
|
|
||||||
}
|
|
||||||
t.Logf("Authenticated as: %s", user)
|
|
||||||
|
|
||||||
// Verify PR is accessible via GitHub adapter.
|
|
||||||
pr, err := ghClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetPullRequest: %v", err)
|
|
||||||
}
|
|
||||||
t.Logf("PR: %s (sha: %s)", pr.Title, pr.Head.Sha)
|
|
||||||
|
|
||||||
// Post a COMMENT review — does not require PR approval permissions.
|
|
||||||
sentinel := "<!-- review-bot:integration-test -->"
|
|
||||||
testBody := "# Integration Test Review (GitHub)\n\nThis is an automated integration test.\n\n" + sentinel
|
|
||||||
posted, err := ghClient.PostReview(ctx, owner, repoName, prNumber, "COMMENT", testBody, "", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("PostReview: %v", err)
|
|
||||||
}
|
|
||||||
t.Logf("Posted review ID: %d", posted.ID)
|
|
||||||
|
|
||||||
// Verify the review appears in ListReviews.
|
|
||||||
reviews, err := ghClient.ListReviews(ctx, owner, repoName, prNumber)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ListReviews: %v", err)
|
|
||||||
}
|
|
||||||
found := false
|
|
||||||
for _, r := range reviews {
|
|
||||||
if r.ID == posted.ID && strings.Contains(r.Body, sentinel) {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Errorf("posted review ID %d not found in ListReviews output", posted.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt cleanup — GitHub does not allow deleting submitted reviews,
|
|
||||||
// so this is expected to fail with ErrCannotDeleteSubmittedReview (422).
|
|
||||||
// Log it as informational only.
|
|
||||||
if err := ghClient.DeleteReview(ctx, owner, repoName, prNumber, posted.ID); err != nil {
|
|
||||||
t.Logf("Note: DeleteReview returned (expected for submitted GitHub reviews): %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+76
-176
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -14,20 +13,13 @@ import (
|
|||||||
|
|
||||||
"gitea.weiker.me/rodin/review-bot/budget"
|
"gitea.weiker.me/rodin/review-bot/budget"
|
||||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||||
"gitea.weiker.me/rodin/review-bot/github"
|
|
||||||
"gitea.weiker.me/rodin/review-bot/llm"
|
"gitea.weiker.me/rodin/review-bot/llm"
|
||||||
"gitea.weiker.me/rodin/review-bot/review"
|
"gitea.weiker.me/rodin/review-bot/review"
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var version = "dev"
|
var version = "dev"
|
||||||
|
|
||||||
// outWriter and errWriter are the output and error writers for subcommands.
|
|
||||||
// They are variables so tests can capture output.
|
|
||||||
var (
|
|
||||||
outWriter io.Writer = os.Stdout
|
|
||||||
errWriter io.Writer = os.Stderr
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupLogger configures the global slog default logger based on format and verbosity.
|
// setupLogger configures the global slog default logger based on format and verbosity.
|
||||||
func setupLogger(format, verbosity string) {
|
func setupLogger(format, verbosity string) {
|
||||||
var level slog.Level
|
var level slog.Level
|
||||||
@@ -58,23 +50,13 @@ func setupLogger(format, verbosity string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Dispatch subcommands before flag parsing so they get their own args.
|
|
||||||
// e.g. `review-bot validate-url <url>`
|
|
||||||
if len(os.Args) > 1 {
|
|
||||||
switch os.Args[1] {
|
|
||||||
case "validate-url":
|
|
||||||
os.Exit(runValidateURL(os.Args[2:]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
versionFlag := flag.Bool("version", false, "Print version and exit")
|
versionFlag := flag.Bool("version", false, "Print version and exit")
|
||||||
// Logging flags
|
// Logging flags
|
||||||
logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json")
|
logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json")
|
||||||
verbosity := flag.String("verbosity", envOrDefault("LOG_VERBOSITY", "info"), "Log verbosity: debug, info, warn, error")
|
verbosity := flag.String("verbosity", envOrDefault("LOG_VERBOSITY", "info"), "Log verbosity: debug, info, warn, error")
|
||||||
// CLI flags
|
// CLI flags
|
||||||
vcsURL := flag.String("vcs-url", os.Getenv("VCS_URL"), "VCS server URL (e.g. https://gitea.example.com)")
|
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", envOrDefault("GITHUB_SERVER_URL", "")), "Gitea instance URL")
|
||||||
giteaURLAlias := flag.String("gitea-url", "", "Deprecated: use --vcs-url")
|
repo := flag.String("repo", envOrDefault("GITEA_REPO", envOrDefault("GITHUB_REPOSITORY", "")), "Repository (owner/name)")
|
||||||
repo := flag.String("repo", envOrDefault("GITEA_REPO", ""), "Repository (owner/name)")
|
|
||||||
prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number")
|
prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number")
|
||||||
reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name")
|
reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name")
|
||||||
reviewerToken := flag.String("reviewer-token", envOrDefault("REVIEWER_TOKEN", ""), "Gitea token for posting review")
|
reviewerToken := flag.String("reviewer-token", envOrDefault("REVIEWER_TOKEN", ""), "Gitea token for posting review")
|
||||||
@@ -84,7 +66,7 @@ func main() {
|
|||||||
conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)")
|
conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)")
|
||||||
systemPromptFile := flag.String("system-prompt-file", envOrDefault("SYSTEM_PROMPT_FILE", ""), "Local file with additional system prompt instructions")
|
systemPromptFile := flag.String("system-prompt-file", envOrDefault("SYSTEM_PROMPT_FILE", ""), "Local file with additional system prompt instructions")
|
||||||
patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)")
|
patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)")
|
||||||
patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", ""), "Comma-separated file paths to fetch from patterns repo (empty = all files)")
|
patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", "README.md"), "Comma-separated file paths to fetch from patterns repo")
|
||||||
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
|
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
|
||||||
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
|
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
|
||||||
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
|
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
|
||||||
@@ -97,8 +79,6 @@ func main() {
|
|||||||
aicoreAuthURL := flag.String("aicore-auth-url", envOrDefault("AICORE_AUTH_URL", ""), "SAP AI Core auth URL (for provider=aicore)")
|
aicoreAuthURL := flag.String("aicore-auth-url", envOrDefault("AICORE_AUTH_URL", ""), "SAP AI Core auth URL (for provider=aicore)")
|
||||||
aicoreAPIURL := flag.String("aicore-api-url", envOrDefault("AICORE_API_URL", ""), "SAP AI Core API URL (for provider=aicore)")
|
aicoreAPIURL := flag.String("aicore-api-url", envOrDefault("AICORE_API_URL", ""), "SAP AI Core API URL (for provider=aicore)")
|
||||||
aicoreResourceGroup := flag.String("aicore-resource-group", envOrDefault("AICORE_RESOURCE_GROUP", "default"), "SAP AI Core resource group (for provider=aicore)")
|
aicoreResourceGroup := flag.String("aicore-resource-group", envOrDefault("AICORE_RESOURCE_GROUP", "default"), "SAP AI Core resource group (for provider=aicore)")
|
||||||
docMapFile := flag.String("doc-map", envOrDefault("DOC_MAP_FILE", ""), "Path to YAML file mapping source path globs to governing design docs")
|
|
||||||
docMapMaxBytes := flag.Int("doc-map-max-bytes", envOrDefaultInt("DOC_MAP_MAX_BYTES", review.DefaultDocMapMaxBytes), "Maximum bytes of injected doc content (default 102400)")
|
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
@@ -112,24 +92,12 @@ func main() {
|
|||||||
|
|
||||||
slog.Info("review-bot starting", "version", version)
|
slog.Info("review-bot starting", "version", version)
|
||||||
|
|
||||||
// Backward compatibility: fall back to deprecated env var / flag if VCS_URL / --vcs-url not set.
|
|
||||||
if *vcsURL == "" {
|
|
||||||
if v := os.Getenv("GITEA_URL"); v != "" {
|
|
||||||
slog.Warn("GITEA_URL is deprecated; rename the environment variable to VCS_URL")
|
|
||||||
*vcsURL = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if *vcsURL == "" && *giteaURLAlias != "" {
|
|
||||||
slog.Warn("--gitea-url is deprecated; use --vcs-url instead")
|
|
||||||
*vcsURL = *giteaURLAlias
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate required fields
|
// Validate required fields
|
||||||
// For aicore provider, llm-base-url and llm-api-key are not required
|
// For aicore provider, llm-base-url and llm-api-key are not required
|
||||||
isAICore := llm.Provider(*llmProvider) == llm.ProviderAICore
|
isAICore := llm.Provider(*llmProvider) == llm.ProviderAICore
|
||||||
if *vcsURL == "" || *repo == "" || *prNum == "" || *reviewerToken == "" || *llmModel == "" {
|
if *giteaURL == "" || *repo == "" || *prNum == "" || *reviewerToken == "" || *llmModel == "" {
|
||||||
fmt.Fprintf(os.Stderr, "Error: missing required flags or environment variables\n\n")
|
fmt.Fprintf(os.Stderr, "Error: missing required flags or environment variables\n\n")
|
||||||
fmt.Fprintf(os.Stderr, "Required: --vcs-url, --repo, --pr, --reviewer-token, --llm-model\n")
|
fmt.Fprintf(os.Stderr, "Required: --gitea-url, --repo, --pr, --reviewer-token, --llm-model\n")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if !isAICore && (*llmBaseURL == "" || *llmAPIKey == "") {
|
if !isAICore && (*llmBaseURL == "" || *llmAPIKey == "") {
|
||||||
@@ -172,39 +140,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize clients
|
// Initialize clients
|
||||||
// Detect VCS type: explicit flag > env var > URL heuristic (default: gitea).
|
giteaClient := gitea.NewClient(*giteaURL, *reviewerToken)
|
||||||
vcsType := envOrDefault("VCS_TYPE", "")
|
|
||||||
if vcsType == "" {
|
|
||||||
// Heuristic: if the URL looks like github.com or a GitHub Enterprise host,
|
|
||||||
// default to GitHub. The composite action sets VCS_TYPE explicitly, so this
|
|
||||||
// is a fallback for manual invocations.
|
|
||||||
if strings.Contains(*vcsURL, "github.com") || strings.Contains(*vcsURL, "github.concur.com") {
|
|
||||||
vcsType = "github"
|
|
||||||
} else {
|
|
||||||
vcsType = "gitea"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slog.Info("VCS type detected", "vcs_type", vcsType, "vcs_url", *vcsURL)
|
|
||||||
|
|
||||||
var vcs vcsClient
|
|
||||||
switch vcsType {
|
|
||||||
case "github":
|
|
||||||
// GitHub: baseURL is the API URL, derived from server URL.
|
|
||||||
// github.com → https://api.github.com
|
|
||||||
// GHES (e.g. https://ghe.example.com) → https://ghe.example.com/api/v3
|
|
||||||
apiURL := githubAPIURL(*vcsURL)
|
|
||||||
ghClient := github.NewClient(*reviewerToken, apiURL)
|
|
||||||
vcs = newGithubVCSAdapter(ghClient)
|
|
||||||
slog.Info("using GitHub VCS client", "api_url", apiURL)
|
|
||||||
case "gitea":
|
|
||||||
giteaClient := gitea.NewClient(*vcsURL, *reviewerToken)
|
|
||||||
vcs = newGiteaVCSAdapter(giteaClient)
|
|
||||||
slog.Info("using Gitea VCS client", "url", *vcsURL)
|
|
||||||
default:
|
|
||||||
slog.Error("unsupported VCS type", "vcs_type", vcsType, "valid", "gitea, github")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel)
|
llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel)
|
||||||
if *llmTemp < 0 || *llmTemp > 2 {
|
if *llmTemp < 0 || *llmTemp > 2 {
|
||||||
slog.Error("invalid LLM temperature", "temperature", *llmTemp, "range", "0-2")
|
slog.Error("invalid LLM temperature", "temperature", *llmTemp, "range", "0-2")
|
||||||
@@ -242,7 +178,7 @@ func main() {
|
|||||||
var persona *review.Persona
|
var persona *review.Persona
|
||||||
if *personaName != "" {
|
if *personaName != "" {
|
||||||
// Try loading from repo first, then fall back to built-in
|
// Try loading from repo first, then fall back to built-in
|
||||||
repoPersonas, err := review.LoadRepoPersonas(ctx, vcs, owner, repoName)
|
repoPersonas, err := review.LoadRepoPersonas(ctx, newGiteaClientAdapter(giteaClient), owner, repoName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err)
|
slog.Warn("could not load repo personas", "repo", owner+"/"+repoName, "error", err)
|
||||||
// Continue with built-in personas only.
|
// Continue with built-in personas only.
|
||||||
@@ -278,7 +214,7 @@ func main() {
|
|||||||
slog.Info("reviewing pull request", "pr", prNumber, "repo", fmt.Sprintf("%s/%s", owner, repoName))
|
slog.Info("reviewing pull request", "pr", prNumber, "repo", fmt.Sprintf("%s/%s", owner, repoName))
|
||||||
|
|
||||||
// Step 1: Fetch PR metadata
|
// Step 1: Fetch PR metadata
|
||||||
pr, err := vcs.GetPullRequest(ctx, owner, repoName, prNumber)
|
pr, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to fetch PR", "pr", prNumber, "error", err)
|
slog.Error("failed to fetch PR", "pr", prNumber, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@@ -286,7 +222,7 @@ func main() {
|
|||||||
slog.Info("fetched PR metadata", "pr", prNumber, "title", pr.Title)
|
slog.Info("fetched PR metadata", "pr", prNumber, "title", pr.Title)
|
||||||
|
|
||||||
// Step 2: Fetch diff
|
// Step 2: Fetch diff
|
||||||
diff, err := vcs.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
diff, err := giteaClient.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to fetch diff", "pr", prNumber, "error", err)
|
slog.Error("failed to fetch diff", "pr", prNumber, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@@ -295,11 +231,11 @@ func main() {
|
|||||||
|
|
||||||
// Step 3: Fetch full file content for modified files
|
// Step 3: Fetch full file content for modified files
|
||||||
fileContext := ""
|
fileContext := ""
|
||||||
files, err := vcs.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
files, err := giteaClient.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not fetch PR files list", "pr", prNumber, "error", err)
|
slog.Warn("could not fetch PR files list", "pr", prNumber, "error", err)
|
||||||
} else {
|
} else {
|
||||||
fileContext = fetchFileContext(ctx, vcs, owner, repoName, pr.Head.Ref, files)
|
fileContext = fetchFileContext(ctx, giteaClient, owner, repoName, pr.Head.Ref, files)
|
||||||
slog.Debug("fetched file context", "files", len(files))
|
slog.Debug("fetched file context", "files", len(files))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -307,7 +243,7 @@ func main() {
|
|||||||
ciPassed := true
|
ciPassed := true
|
||||||
ciDetails := ""
|
ciDetails := ""
|
||||||
if pr.Head.Sha != "" {
|
if pr.Head.Sha != "" {
|
||||||
statuses, err := vcs.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
statuses, err := giteaClient.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not fetch CI status", "sha", pr.Head.Sha, "error", err)
|
slog.Warn("could not fetch CI status", "sha", pr.Head.Sha, "error", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -319,7 +255,7 @@ func main() {
|
|||||||
// Step 5: Load conventions file if specified
|
// Step 5: Load conventions file if specified
|
||||||
conventions := ""
|
conventions := ""
|
||||||
if *conventionsFile != "" {
|
if *conventionsFile != "" {
|
||||||
content, err := vcs.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
content, err := giteaClient.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not load conventions file", "file", *conventionsFile, "error", err)
|
slog.Warn("could not load conventions file", "file", *conventionsFile, "error", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -331,7 +267,7 @@ func main() {
|
|||||||
// Step 6: Load patterns from external repo if specified
|
// Step 6: Load patterns from external repo if specified
|
||||||
patterns := ""
|
patterns := ""
|
||||||
if *patternsRepo != "" {
|
if *patternsRepo != "" {
|
||||||
patterns = fetchPatterns(ctx, vcs, *patternsRepo, *patternsFiles)
|
patterns = fetchPatterns(ctx, giteaClient, *patternsRepo, *patternsFiles)
|
||||||
slog.Debug("loaded patterns", "repo", *patternsRepo, "bytes", len(patterns))
|
slog.Debug("loaded patterns", "repo", *patternsRepo, "bytes", len(patterns))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -352,46 +288,6 @@ func main() {
|
|||||||
slog.Debug("loaded system prompt file", "file", *systemPromptFile, "bytes", len(additionalPrompt))
|
slog.Debug("loaded system prompt file", "file", *systemPromptFile, "bytes", len(additionalPrompt))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 6c: Load path-scoped design docs if doc-map specified
|
|
||||||
designDocs := ""
|
|
||||||
if *docMapFile != "" {
|
|
||||||
resolvedDocMap, err := validateWorkspacePath(*docMapFile, "doc-map")
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("invalid doc-map path", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
docMapCfg, err := review.ParseDocMapConfig(resolvedDocMap)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to parse doc-map file", "file", *docMapFile, "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect changed file paths from the PR for intersection.
|
|
||||||
var changedPaths []string
|
|
||||||
for _, f := range files {
|
|
||||||
changedPaths = append(changedPaths, f.Filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
matchedDocs := review.MatchDocs(docMapCfg, changedPaths)
|
|
||||||
slog.Debug("doc-map: matched docs", "count", len(matchedDocs), "docs", matchedDocs)
|
|
||||||
|
|
||||||
if len(matchedDocs) > 0 {
|
|
||||||
docMapOpts := review.DocMapOptions{MaxBytes: *docMapMaxBytes}
|
|
||||||
designDocs, err = review.LoadMatchingDocs(ctx, vcs, owner, repoName, matchedDocs, docMapOpts)
|
|
||||||
if err != nil {
|
|
||||||
// Non-fatal: individual missing files are already warned; log and continue.
|
|
||||||
slog.Warn("doc-map: partial failure loading docs", "error", err)
|
|
||||||
}
|
|
||||||
if designDocs != "" {
|
|
||||||
slog.Info("doc-map: injected design docs", "matched", len(matchedDocs), "bytes", len(designDocs))
|
|
||||||
} else {
|
|
||||||
slog.Debug("doc-map: no doc content loaded (all files missing or empty)")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
slog.Debug("doc-map: no changed paths matched any mapping")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 7: Budget-aware prompt assembly
|
// Step 7: Budget-aware prompt assembly
|
||||||
var systemBase string
|
var systemBase string
|
||||||
if persona != nil {
|
if persona != nil {
|
||||||
@@ -407,7 +303,6 @@ func main() {
|
|||||||
SystemBase: systemBase,
|
SystemBase: systemBase,
|
||||||
Patterns: patterns,
|
Patterns: patterns,
|
||||||
Conventions: conventions,
|
Conventions: conventions,
|
||||||
DesignDocs: designDocs,
|
|
||||||
FileContext: fileContext,
|
FileContext: fileContext,
|
||||||
Diff: diff,
|
Diff: diff,
|
||||||
UserMeta: review.BuildUserMeta(pr.Title, pr.Body, ciPassed, ciDetails),
|
UserMeta: review.BuildUserMeta(pr.Title, pr.Body, ciPassed, ciDetails),
|
||||||
@@ -487,7 +382,7 @@ func main() {
|
|||||||
// Stale check: verify HEAD hasn't moved since we started
|
// Stale check: verify HEAD hasn't moved since we started
|
||||||
evaluatedSHA := pr.Head.Sha
|
evaluatedSHA := pr.Head.Sha
|
||||||
var currentSHA string
|
var currentSHA string
|
||||||
currentPR, err := vcs.GetPullRequest(ctx, owner, repoName, prNumber)
|
currentPR, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not re-fetch PR for stale check", "pr", prNumber, "error", err)
|
slog.Warn("could not re-fetch PR for stale check", "pr", prNumber, "error", err)
|
||||||
// currentSHA stays empty — shouldSkipStaleReview will return false
|
// currentSHA stays empty — shouldSkipStaleReview will return false
|
||||||
@@ -504,10 +399,10 @@ func main() {
|
|||||||
|
|
||||||
// Map findings to inline comments for lines present in the diff
|
// Map findings to inline comments for lines present in the diff
|
||||||
diffRanges := gitea.ParseDiffNewLines(diff)
|
diffRanges := gitea.ParseDiffNewLines(diff)
|
||||||
var inlineComments []vcsReviewComment
|
var inlineComments []gitea.ReviewComment
|
||||||
for _, f := range result.Findings {
|
for _, f := range result.Findings {
|
||||||
if f.File != "" && f.Line > 0 && diffRanges.Contains(f.File, f.Line) {
|
if f.File != "" && f.Line > 0 && diffRanges.Contains(f.File, f.Line) {
|
||||||
inlineComments = append(inlineComments, vcsReviewComment{
|
inlineComments = append(inlineComments, gitea.ReviewComment{
|
||||||
Path: f.File,
|
Path: f.File,
|
||||||
NewPosition: int64(f.Line),
|
NewPosition: int64(f.Line),
|
||||||
Body: fmt.Sprintf("**[%s]** %s", f.Severity, f.Finding),
|
Body: fmt.Sprintf("**[%s]** %s", f.Severity, f.Finding),
|
||||||
@@ -522,9 +417,9 @@ func main() {
|
|||||||
// 1. POST new review first (gets non-stale approval badge on HEAD)
|
// 1. POST new review first (gets non-stale approval badge on HEAD)
|
||||||
// 2. Then supersede old review with link to the new one
|
// 2. Then supersede old review with link to the new one
|
||||||
// Order matters: post first so we have the new review's URL for the supersede message.
|
// Order matters: post first so we have the new review's URL for the supersede message.
|
||||||
var oldReviews []vcsReview
|
var oldReviews []gitea.Review
|
||||||
if *reviewerName != "" {
|
if *reviewerName != "" {
|
||||||
existingReviews, err := vcs.ListReviews(ctx, owner, repoName, prNumber)
|
existingReviews, err := giteaClient.ListReviews(ctx, owner, repoName, prNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not list existing reviews", "pr", prNumber, "error", err)
|
slog.Warn("could not list existing reviews", "pr", prNumber, "error", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -537,11 +432,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Self-request as reviewer (ensures we appear in required-reviewer checks)
|
// Self-request as reviewer (ensures we appear in required-reviewer checks)
|
||||||
authUser, err := vcs.GetAuthenticatedUser(ctx)
|
authUser, err := giteaClient.GetAuthenticatedUser(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not determine authenticated user for reviewer self-request", "error", err)
|
slog.Warn("could not determine authenticated user for reviewer self-request", "error", err)
|
||||||
} else if authUser != "" {
|
} else if authUser != "" {
|
||||||
if err := vcs.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
|
if err := giteaClient.RequestReviewer(ctx, owner, repoName, prNumber, authUser); err != nil {
|
||||||
slog.Warn("could not self-request as reviewer", "user", authUser, "error", err)
|
slog.Warn("could not self-request as reviewer", "user", authUser, "error", err)
|
||||||
} else {
|
} else {
|
||||||
slog.Debug("self-requested as reviewer", "user", authUser, "pr", prNumber)
|
slog.Debug("self-requested as reviewer", "user", authUser, "pr", prNumber)
|
||||||
@@ -550,34 +445,31 @@ func main() {
|
|||||||
|
|
||||||
// POST new review
|
// POST new review
|
||||||
slog.Info("posting review", "event", event, "pr", prNumber)
|
slog.Info("posting review", "event", event, "pr", prNumber)
|
||||||
posted, err := vcs.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, evaluatedSHA, inlineComments)
|
posted, err := giteaClient.PostReview(ctx, owner, repoName, prNumber, event, reviewBody, inlineComments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to post review", "pr", prNumber, "event", event, "error", err)
|
slog.Error("failed to post review", "pr", prNumber, "event", event, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
slog.Info("review posted", "review_id", posted.ID, "user", posted.User.Login, "pr", prNumber)
|
slog.Info("review posted", "review_id", posted.ID, "user", posted.User.Login, "pr", prNumber)
|
||||||
|
|
||||||
// Supersede all old reviews with link to the new one.
|
// Supersede all old reviews with link to the new one
|
||||||
// This is only supported on Gitea (requires timeline API); GitHub reviews cannot
|
if len(oldReviews) > 0 {
|
||||||
// be edited after submission, so we skip the supersede step there.
|
newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(*giteaURL, "/"), owner, repoName, prNumber, posted.ID)
|
||||||
extVCS, isGiteaExt := vcs.(giteaExtClient)
|
|
||||||
if len(oldReviews) > 0 && isGiteaExt {
|
|
||||||
newReviewURL := fmt.Sprintf("%s/%s/%s/pulls/%d#pullrequestreview-%d", strings.TrimRight(*vcsURL, "/"), owner, repoName, prNumber, posted.ID)
|
|
||||||
for _, oldReview := range oldReviews {
|
for _, oldReview := range oldReviews {
|
||||||
cid, err := extVCS.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, int64(prNumber), oldReview.ID)
|
cid, err := giteaClient.GetTimelineReviewCommentIDForReview(ctx, owner, repoName, prNumber, oldReview.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not find comment ID for old review", "review_id", oldReview.ID, "error", err)
|
slog.Warn("could not find comment ID for old review", "review_id", oldReview.ID, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
supersededBody := buildSupersededBody(oldReview.Body, oldReview.CommitID, newReviewURL, sentinel)
|
supersededBody := buildSupersededBody(oldReview.Body, oldReview.CommitID, newReviewURL, sentinel)
|
||||||
if err := extVCS.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
|
if err := giteaClient.EditComment(ctx, owner, repoName, cid, supersededBody); err != nil {
|
||||||
slog.Warn("could not mark old review as superseded", "review_id", oldReview.ID, "comment_id", cid, "error", err)
|
slog.Warn("could not mark old review as superseded", "review_id", oldReview.ID, "comment_id", cid, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
slog.Info("marked old review as superseded", "review_id", oldReview.ID, "new_review_id", posted.ID, "pr", prNumber)
|
slog.Info("marked old review as superseded", "review_id", oldReview.ID, "new_review_id", posted.ID, "pr", prNumber)
|
||||||
|
|
||||||
// Resolve old review's inline comments
|
// Resolve old review's inline comments
|
||||||
oldComments, err := extVCS.ListReviewComments(ctx, owner, repoName, int64(prNumber), oldReview.ID)
|
oldComments, err := giteaClient.ListReviewComments(ctx, owner, repoName, prNumber, oldReview.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not list old review comments for resolution", "review_id", oldReview.ID, "error", err)
|
slog.Warn("could not list old review comments for resolution", "review_id", oldReview.ID, "error", err)
|
||||||
continue
|
continue
|
||||||
@@ -587,7 +479,7 @@ func main() {
|
|||||||
if c.ID == 0 {
|
if c.ID == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := extVCS.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
|
if err := giteaClient.ResolveComment(ctx, owner, repoName, c.ID); err != nil {
|
||||||
slog.Debug("could not resolve inline comment", "comment_id", c.ID, "error", err)
|
slog.Debug("could not resolve inline comment", "comment_id", c.ID, "error", err)
|
||||||
failed++
|
failed++
|
||||||
} else {
|
} else {
|
||||||
@@ -601,14 +493,12 @@ func main() {
|
|||||||
slog.Warn("some inline comments could not be resolved", "review_id", oldReview.ID, "failed", failed, "pr", prNumber)
|
slog.Warn("some inline comments could not be resolved", "review_id", oldReview.ID, "failed", failed, "pr", prNumber)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if len(oldReviews) > 0 {
|
|
||||||
slog.Info("skipping supersede of old reviews (not supported on this VCS)", "old_count", len(oldReviews), "pr", prNumber)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchFileContext fetches the full content of modified files from the PR branch.
|
// fetchFileContext fetches the full content of modified files from the PR branch.
|
||||||
func fetchFileContext(ctx context.Context, client vcsClient, owner, repo, ref string, files []vcsChangedFile) string {
|
func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, ref string, files []gitea.ChangedFile) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, f := range files {
|
for _, f := range files {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -634,25 +524,11 @@ func fetchFileContext(ctx context.Context, client vcsClient, owner, repo, ref st
|
|||||||
// patternsRepo is comma-separated list of owner/name repos.
|
// patternsRepo is comma-separated list of owner/name repos.
|
||||||
// patternsFiles is comma-separated list of file paths or directories.
|
// patternsFiles is comma-separated list of file paths or directories.
|
||||||
// If a path ends with / or is a directory, all files within it are fetched recursively.
|
// If a path ends with / or is a directory, all files within it are fetched recursively.
|
||||||
// If patternsFiles is empty, all files from the repo root are fetched.
|
func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string {
|
||||||
func fetchPatterns(ctx context.Context, client vcsClient, patternsRepo, patternsFiles string) string {
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
repos := strings.Split(patternsRepo, ",")
|
repos := strings.Split(patternsRepo, ",")
|
||||||
|
paths := strings.Split(patternsFiles, ",")
|
||||||
// Build the list of paths to fetch
|
|
||||||
var paths []string
|
|
||||||
if patternsFiles == "" {
|
|
||||||
// Empty patternsFiles means "fetch all files from repo root"
|
|
||||||
paths = []string{""}
|
|
||||||
} else {
|
|
||||||
for _, p := range strings.Split(patternsFiles, ",") {
|
|
||||||
p = strings.TrimSpace(p)
|
|
||||||
if p != "" {
|
|
||||||
paths = append(paths, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, repoRef := range repos {
|
for _, repoRef := range repos {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -673,6 +549,11 @@ func fetchPatterns(ctx context.Context, client vcsClient, patternsRepo, patterns
|
|||||||
var repoSkippedFiles []string
|
var repoSkippedFiles []string
|
||||||
|
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if path == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
files, err := client.GetAllFilesInPath(ctx, owner, repo, path)
|
files, err := client.GetAllFilesInPath(ctx, owner, repo, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not fetch patterns", "path", path, "repo", repoRef, "error", err)
|
slog.Warn("could not fetch patterns", "path", path, "repo", repoRef, "error", err)
|
||||||
@@ -712,7 +593,7 @@ func isPatternFile(path string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluateCIStatus checks if all CI statuses indicate success.
|
// evaluateCIStatus checks if all CI statuses indicate success.
|
||||||
func evaluateCIStatus(statuses []vcsCommitStatus) (passed bool, details string) {
|
func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details string) {
|
||||||
if len(statuses) == 0 {
|
if len(statuses) == 0 {
|
||||||
return true, "no CI statuses found"
|
return true, "no CI statuses found"
|
||||||
}
|
}
|
||||||
@@ -735,19 +616,6 @@ func evaluateCIStatus(statuses []vcsCommitStatus) (passed bool, details string)
|
|||||||
return true, "all checks passed"
|
return true, "all checks passed"
|
||||||
}
|
}
|
||||||
|
|
||||||
// githubAPIURL converts a GitHub server URL to its API base URL.
|
|
||||||
// github.com → https://api.github.com
|
|
||||||
// GHES (e.g. https://ghe.example.com) → https://ghe.example.com/api/v3
|
|
||||||
func githubAPIURL(serverURL string) string {
|
|
||||||
const canonicalGitHub = "https://github.com"
|
|
||||||
const githubAPIBase = "https://api.github.com"
|
|
||||||
if serverURL == "" || strings.TrimRight(serverURL, "/") == canonicalGitHub {
|
|
||||||
return githubAPIBase
|
|
||||||
}
|
|
||||||
// GitHub Enterprise Server: /api/v3 suffix
|
|
||||||
return strings.TrimRight(serverURL, "/") + "/api/v3"
|
|
||||||
}
|
|
||||||
|
|
||||||
func envOrDefault(key, defaultVal string) string {
|
func envOrDefault(key, defaultVal string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
return v
|
return v
|
||||||
@@ -863,7 +731,7 @@ func buildSupersededBody(originalBody, commitSHA, newReviewURL, sentinel string)
|
|||||||
// Gitea user. This indicates misconfiguration where two roles share a token
|
// Gitea user. This indicates misconfiguration where two roles share a token
|
||||||
// instead of having separate Gitea accounts. Returns true if shared token
|
// instead of having separate Gitea accounts. Returns true if shared token
|
||||||
// detected (caller should skip update-in-place logic to avoid clobbering).
|
// detected (caller should skip update-in-place logic to avoid clobbering).
|
||||||
func hasSharedToken(reviews []vcsReview, ownSentinel string) bool {
|
func hasSharedToken(reviews []gitea.Review, ownSentinel string) bool {
|
||||||
ownLogin := ""
|
ownLogin := ""
|
||||||
for _, r := range reviews {
|
for _, r := range reviews {
|
||||||
if strings.Contains(r.Body, ownSentinel) {
|
if strings.Contains(r.Body, ownSentinel) {
|
||||||
@@ -901,8 +769,8 @@ func extractSentinelName(body string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findOwnReview locates the most recent non-superseded review matching the sentinel.
|
// findOwnReview locates the most recent non-superseded review matching the sentinel.
|
||||||
func findOwnReview(reviews []vcsReview, sentinel string) *vcsReview {
|
func findOwnReview(reviews []gitea.Review, sentinel string) *gitea.Review {
|
||||||
var best *vcsReview
|
var best *gitea.Review
|
||||||
for i := range reviews {
|
for i := range reviews {
|
||||||
if !strings.Contains(reviews[i].Body, sentinel) {
|
if !strings.Contains(reviews[i].Body, sentinel) {
|
||||||
continue
|
continue
|
||||||
@@ -918,8 +786,8 @@ func findOwnReview(reviews []vcsReview, sentinel string) *vcsReview {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findAllOwnReviews returns all non-superseded reviews matching the sentinel.
|
// findAllOwnReviews returns all non-superseded reviews matching the sentinel.
|
||||||
func findAllOwnReviews(reviews []vcsReview, sentinel string) []vcsReview {
|
func findAllOwnReviews(reviews []gitea.Review, sentinel string) []gitea.Review {
|
||||||
var result []vcsReview
|
var result []gitea.Review
|
||||||
for i := range reviews {
|
for i := range reviews {
|
||||||
if !strings.Contains(reviews[i].Body, sentinel) {
|
if !strings.Contains(reviews[i].Body, sentinel) {
|
||||||
continue
|
continue
|
||||||
@@ -944,3 +812,35 @@ func shouldSkipStaleReview(evaluatedSHA, currentSHA string) bool {
|
|||||||
}
|
}
|
||||||
return evaluatedSHA != currentSHA
|
return evaluatedSHA != currentSHA
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// giteaClientAdapter adapts gitea.Client to vcs.FileReader interface.
|
||||||
|
type giteaClientAdapter struct {
|
||||||
|
client *gitea.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGiteaClientAdapter(c *gitea.Client) *giteaClientAdapter {
|
||||||
|
return &giteaClientAdapter{client: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *giteaClientAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
|
||||||
|
entries, err := a.client.ListContents(ctx, owner, repo, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := make([]vcs.ContentEntry, len(entries))
|
||||||
|
for i, e := range entries {
|
||||||
|
result[i] = vcs.ContentEntry{
|
||||||
|
Name: e.Name,
|
||||||
|
Path: e.Path,
|
||||||
|
Type: e.Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *giteaClientAdapter) GetFileContent(ctx context.Context, owner, repo, filePath, ref string) (string, error) {
|
||||||
|
if ref != "" {
|
||||||
|
return a.client.GetFileContentRef(ctx, owner, repo, filePath, ref)
|
||||||
|
}
|
||||||
|
return a.client.GetFileContent(ctx, owner, repo, filePath)
|
||||||
|
}
|
||||||
|
|||||||
+36
-399
@@ -2,9 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@@ -12,7 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gitea.weiker.me/rodin/review-bot/review"
|
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateReviewerName(t *testing.T) {
|
func TestValidateReviewerName(t *testing.T) {
|
||||||
@@ -156,11 +154,12 @@ func TestValidateWorkspacePath(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeReview(id int64, login, state string, _ bool, body string) vcsReview {
|
func makeReview(id int64, login, state string, stale bool, body string) gitea.Review {
|
||||||
r := vcsReview{
|
r := gitea.Review{
|
||||||
ID: id,
|
ID: id,
|
||||||
Body: body,
|
Body: body,
|
||||||
State: state,
|
State: state,
|
||||||
|
Stale: stale,
|
||||||
}
|
}
|
||||||
r.User.Login = login
|
r.User.Login = login
|
||||||
return r
|
return r
|
||||||
@@ -217,7 +216,7 @@ func TestBuildSupersededBodyShortSHA(t *testing.T) {
|
|||||||
func TestFindOwnReview(t *testing.T) {
|
func TestFindOwnReview(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
reviews []vcsReview
|
reviews []gitea.Review
|
||||||
sentinel string
|
sentinel string
|
||||||
wantID int64
|
wantID int64
|
||||||
wantNil bool
|
wantNil bool
|
||||||
@@ -230,7 +229,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "found by sentinel",
|
name: "found by sentinel",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
makeReview(42, "bot", "APPROVED", false, "review body\n<!-- review-bot:sonnet -->"),
|
makeReview(42, "bot", "APPROVED", false, "review body\n<!-- review-bot:sonnet -->"),
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
@@ -238,7 +237,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "wrong sentinel",
|
name: "wrong sentinel",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
makeReview(42, "bot", "APPROVED", false, "body\n<!-- review-bot:gpt -->"),
|
makeReview(42, "bot", "APPROVED", false, "body\n<!-- review-bot:gpt -->"),
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
@@ -246,7 +245,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple reviews, returns first match",
|
name: "multiple reviews, returns first match",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
makeReview(10, "bot", "APPROVED", false, "old\n<!-- review-bot:gpt -->"),
|
makeReview(10, "bot", "APPROVED", false, "old\n<!-- review-bot:gpt -->"),
|
||||||
makeReview(20, "bot", "APPROVED", false, "new\n<!-- review-bot:sonnet -->"),
|
makeReview(20, "bot", "APPROVED", false, "new\n<!-- review-bot:sonnet -->"),
|
||||||
},
|
},
|
||||||
@@ -255,7 +254,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "skips superseded review",
|
name: "skips superseded review",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n**Superseded**\n<!-- review-bot:sonnet -->"),
|
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n**Superseded**\n<!-- review-bot:sonnet -->"),
|
||||||
makeReview(20, "bot", "APPROVED", false, "fresh review\n<!-- review-bot:sonnet -->"),
|
makeReview(20, "bot", "APPROVED", false, "fresh review\n<!-- review-bot:sonnet -->"),
|
||||||
},
|
},
|
||||||
@@ -264,7 +263,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "only superseded reviews exist",
|
name: "only superseded reviews exist",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n<!-- review-bot:sonnet -->"),
|
makeReview(10, "bot", "APPROVED", false, "~~Original review~~\n\n<!-- review-bot:sonnet -->"),
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
@@ -272,7 +271,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "picks highest ID among matches",
|
name: "picks highest ID among matches",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
makeReview(50, "bot", "APPROVED", false, "v1\n<!-- review-bot:sonnet -->"),
|
makeReview(50, "bot", "APPROVED", false, "v1\n<!-- review-bot:sonnet -->"),
|
||||||
makeReview(30, "bot", "APPROVED", false, "v0\n<!-- review-bot:sonnet -->"),
|
makeReview(30, "bot", "APPROVED", false, "v0\n<!-- review-bot:sonnet -->"),
|
||||||
},
|
},
|
||||||
@@ -303,7 +302,7 @@ func TestFindOwnReview(t *testing.T) {
|
|||||||
func TestHasSharedToken(t *testing.T) {
|
func TestHasSharedToken(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
reviews []vcsReview
|
reviews []gitea.Review
|
||||||
sentinel string
|
sentinel string
|
||||||
want bool
|
want bool
|
||||||
}{
|
}{
|
||||||
@@ -315,36 +314,36 @@ func TestHasSharedToken(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no own review yet - cannot detect",
|
name: "no own review yet - cannot detect",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
{ID: 1, User: struct{ Login string }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
|
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "other"}, Body: "<!-- review-bot:gpt --> body"},
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
want: false,
|
want: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "separate users - no shared token",
|
name: "separate users - no shared token",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
{ID: 1, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||||
{ID: 2, User: struct{ Login string }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "security-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
want: false,
|
want: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "shared token detected - same user different sentinels",
|
name: "shared token detected - same user different sentinels",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
{ID: 1, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||||
{ID: 2, User: struct{ Login string }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "sonnet-review-bot"}, Body: "<!-- review-bot:security --> body"},
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
want: true,
|
want: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "three roles same user",
|
name: "three roles same user",
|
||||||
reviews: []vcsReview{
|
reviews: []gitea.Review{
|
||||||
{ID: 1, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
{ID: 1, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:sonnet --> body"},
|
||||||
{ID: 2, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
|
{ID: 2, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:security --> body"},
|
||||||
{ID: 3, User: struct{ Login string }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
|
{ID: 3, User: struct{ Login string `json:"login"` }{Login: "bot"}, Body: "<!-- review-bot:gpt --> body"},
|
||||||
},
|
},
|
||||||
sentinel: "<!-- review-bot:sonnet -->",
|
sentinel: "<!-- review-bot:sonnet -->",
|
||||||
want: true,
|
want: true,
|
||||||
@@ -505,56 +504,10 @@ func TestIsPatternFile(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBuildPatternPaths verifies the path-building logic for fetchPatterns.
|
|
||||||
// Empty patternsFiles means "fetch all from root" (represented as [""]).
|
|
||||||
func TestBuildPatternPaths(t *testing.T) {
|
|
||||||
buildPaths := func(patternsFiles string) []string {
|
|
||||||
if patternsFiles == "" {
|
|
||||||
return []string{""}
|
|
||||||
}
|
|
||||||
var paths []string
|
|
||||||
for _, p := range strings.Split(patternsFiles, ",") {
|
|
||||||
p = strings.TrimSpace(p)
|
|
||||||
if p != "" {
|
|
||||||
paths = append(paths, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return paths
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
want []string
|
|
||||||
}{
|
|
||||||
{"empty fetches root", "", []string{""}},
|
|
||||||
{"single file", "README.md", []string{"README.md"}},
|
|
||||||
{"multiple files", "README.md,PATTERNS.md", []string{"README.md", "PATTERNS.md"}},
|
|
||||||
{"trims whitespace", " foo.md , bar.md ", []string{"foo.md", "bar.md"}},
|
|
||||||
{"skips empty between commas", "foo.md,,bar.md", []string{"foo.md", "bar.md"}},
|
|
||||||
{"directory path", "patterns/", []string{"patterns/"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := buildPaths(tc.input)
|
|
||||||
if len(got) != len(tc.want) {
|
|
||||||
t.Errorf("buildPaths(%q) = %v, want %v", tc.input, got, tc.want)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i := range got {
|
|
||||||
if got[i] != tc.want[i] {
|
|
||||||
t.Errorf("buildPaths(%q)[%d] = %q, want %q", tc.input, i, got[i], tc.want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEvaluateCIStatus(t *testing.T) {
|
func TestEvaluateCIStatus(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
statuses []vcsCommitStatus
|
statuses []gitea.CommitStatus
|
||||||
wantPassed bool
|
wantPassed bool
|
||||||
wantSubstr string
|
wantSubstr string
|
||||||
}{
|
}{
|
||||||
@@ -566,7 +519,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all success",
|
name: "all success",
|
||||||
statuses: []vcsCommitStatus{
|
statuses: []gitea.CommitStatus{
|
||||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||||
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
||||||
},
|
},
|
||||||
@@ -575,7 +528,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one failure",
|
name: "one failure",
|
||||||
statuses: []vcsCommitStatus{
|
statuses: []gitea.CommitStatus{
|
||||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||||
},
|
},
|
||||||
@@ -584,7 +537,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error status",
|
name: "error status",
|
||||||
statuses: []vcsCommitStatus{
|
statuses: []gitea.CommitStatus{
|
||||||
{Status: "error", Context: "ci/lint", Description: "Lint error"},
|
{Status: "error", Context: "ci/lint", Description: "Lint error"},
|
||||||
},
|
},
|
||||||
wantPassed: false,
|
wantPassed: false,
|
||||||
@@ -592,7 +545,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "pending treated as not-failed",
|
name: "pending treated as not-failed",
|
||||||
statuses: []vcsCommitStatus{
|
statuses: []gitea.CommitStatus{
|
||||||
{Status: "pending", Context: "ci/build", Description: "In progress"},
|
{Status: "pending", Context: "ci/build", Description: "In progress"},
|
||||||
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
{Status: "success", Context: "ci/test", Description: "Tests passed"},
|
||||||
},
|
},
|
||||||
@@ -601,7 +554,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple failures",
|
name: "multiple failures",
|
||||||
statuses: []vcsCommitStatus{
|
statuses: []gitea.CommitStatus{
|
||||||
{Status: "failure", Context: "ci/build", Description: "Build failed"},
|
{Status: "failure", Context: "ci/build", Description: "Build failed"},
|
||||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||||
},
|
},
|
||||||
@@ -610,7 +563,7 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed with pending and failure",
|
name: "mixed with pending and failure",
|
||||||
statuses: []vcsCommitStatus{
|
statuses: []gitea.CommitStatus{
|
||||||
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
{Status: "success", Context: "ci/build", Description: "Build passed"},
|
||||||
{Status: "pending", Context: "ci/deploy", Description: "Deploying"},
|
{Status: "pending", Context: "ci/deploy", Description: "Deploying"},
|
||||||
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
{Status: "failure", Context: "ci/test", Description: "Tests failed"},
|
||||||
@@ -633,48 +586,6 @@ func TestEvaluateCIStatus(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGithubAPIURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty string defaults to api.github.com",
|
|
||||||
input: "",
|
|
||||||
want: "https://api.github.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "github.com maps to api.github.com",
|
|
||||||
input: "https://github.com",
|
|
||||||
want: "https://api.github.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "github.com with trailing slash maps to api.github.com",
|
|
||||||
input: "https://github.com/",
|
|
||||||
want: "https://api.github.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "GHES host gets /api/v3 suffix",
|
|
||||||
input: "https://ghe.example.com",
|
|
||||||
want: "https://ghe.example.com/api/v3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "GHES concur domain does not map to api.github.com",
|
|
||||||
input: "https://github.concur.com",
|
|
||||||
want: "https://github.concur.com/api/v3",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := githubAPIURL(tt.input)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("githubAPIURL(%q) = %q, want %q", tt.input, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnvOrDefault(t *testing.T) {
|
func TestEnvOrDefault(t *testing.T) {
|
||||||
// Test with unset env var
|
// Test with unset env var
|
||||||
os.Unsetenv("TEST_ENV_OR_DEFAULT_UNSET")
|
os.Unsetenv("TEST_ENV_OR_DEFAULT_UNSET")
|
||||||
@@ -823,8 +734,8 @@ func TestExtractSentinelName_EdgeCases(t *testing.T) {
|
|||||||
{"<!-- review-bot:sonnet --> rest", "sonnet"},
|
{"<!-- review-bot:sonnet --> rest", "sonnet"},
|
||||||
{"<!-- review-bot:gpt-review --> rest", "gpt-review"},
|
{"<!-- review-bot:gpt-review --> rest", "gpt-review"},
|
||||||
{"no sentinel here", "unknown"},
|
{"no sentinel here", "unknown"},
|
||||||
{"<!-- review-bot:", "unknown"}, // prefix but no suffix
|
{"<!-- review-bot:", "unknown"}, // prefix but no suffix
|
||||||
{"prefix <!-- review-bot:abc --> end", "abc"}, // embedded in text
|
{"prefix <!-- review-bot:abc --> end", "abc"}, // embedded in text
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
@@ -1015,7 +926,7 @@ func TestMainSubprocess_InvalidProvider(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanEnv returns environ without any GITEA/LLM/REVIEWER/VCS env vars that would
|
// cleanEnv returns environ without any GITEA/LLM/REVIEWER env vars that would
|
||||||
// interfere with testing missing-flag scenarios.
|
// interfere with testing missing-flag scenarios.
|
||||||
func cleanEnv() []string {
|
func cleanEnv() []string {
|
||||||
var env []string
|
var env []string
|
||||||
@@ -1030,8 +941,7 @@ func cleanEnv() []string {
|
|||||||
strings.HasPrefix(key, "CONVENTIONS_"),
|
strings.HasPrefix(key, "CONVENTIONS_"),
|
||||||
strings.HasPrefix(key, "SYSTEM_PROMPT_"),
|
strings.HasPrefix(key, "SYSTEM_PROMPT_"),
|
||||||
strings.HasPrefix(key, "PATTERNS_"),
|
strings.HasPrefix(key, "PATTERNS_"),
|
||||||
strings.HasPrefix(key, "UPDATE_"),
|
strings.HasPrefix(key, "UPDATE_"):
|
||||||
strings.HasPrefix(key, "VCS_"):
|
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
env = append(env, e)
|
env = append(env, e)
|
||||||
@@ -1041,7 +951,7 @@ func cleanEnv() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFindAllOwnReviews(t *testing.T) {
|
func TestFindAllOwnReviews(t *testing.T) {
|
||||||
reviews := []vcsReview{
|
reviews := []gitea.Review{
|
||||||
{ID: 1, Body: "<!-- review-bot:sonnet -->\nfirst review"},
|
{ID: 1, Body: "<!-- review-bot:sonnet -->\nfirst review"},
|
||||||
{ID: 2, Body: "<!-- review-bot:gpt -->\nother bot"},
|
{ID: 2, Body: "<!-- review-bot:gpt -->\nother bot"},
|
||||||
{ID: 3, Body: "<!-- review-bot:sonnet -->\nsecond review"},
|
{ID: 3, Body: "<!-- review-bot:sonnet -->\nsecond review"},
|
||||||
@@ -1110,276 +1020,3 @@ func TestShouldSkipStaleReview(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Mock vcsClient for unit tests
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// mockVCSClient is a minimal mock of vcsClient for testing helper functions.
|
|
||||||
// Only the methods exercised by the test code need implementations; all others
|
|
||||||
// panic with a clear message to catch accidental calls.
|
|
||||||
type mockVCSClient struct {
|
|
||||||
fileContents map[string]string // key: "owner/repo/ref/path"
|
|
||||||
fileContentsErr map[string]error // key same as above → error to return
|
|
||||||
dirContents map[string][]review.ContentEntry
|
|
||||||
dirContentsErr map[string]error
|
|
||||||
allFiles map[string]map[string]string // key: "owner/repo/path"
|
|
||||||
allFilesErr map[string]error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) key(owner, repo, extra string) string {
|
|
||||||
return owner + "/" + repo + "/" + extra
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
|
||||||
panic("GetPullRequest not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
|
||||||
panic("GetPullRequestDiff not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
|
||||||
panic("GetPullRequestFiles not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
|
||||||
panic("GetCommitStatuses not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
|
||||||
panic("GetFileContent not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetFileContentRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
|
||||||
k := m.key(owner, repo, ref+"/"+path)
|
|
||||||
if err, ok := m.fileContentsErr[k]; ok {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if content, ok := m.fileContents[k]; ok {
|
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("HTTP 404: not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
|
||||||
k := m.key(owner, repo, path)
|
|
||||||
if err, ok := m.dirContentsErr[k]; ok {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if entries, ok := m.dirContents[k]; ok {
|
|
||||||
return entries, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("HTTP 404: not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
|
||||||
k := m.key(owner, repo, path)
|
|
||||||
if err, ok := m.allFilesErr[k]; ok {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if files, ok := m.allFiles[k]; ok {
|
|
||||||
return files, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("HTTP 404: not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
|
||||||
panic("PostReview not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
|
||||||
panic("ListReviews not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
|
||||||
panic("DeleteReview not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
|
||||||
panic("GetAuthenticatedUser not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockVCSClient) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
|
||||||
panic("RequestReviewer not implemented in mockVCSClient")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// fetchFileContext tests
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func TestFetchFileContext_NoFiles(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{}
|
|
||||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", nil)
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string for no files, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchFileContext_SkipsRemovedFiles(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{}
|
|
||||||
files := []vcsChangedFile{
|
|
||||||
{Filename: "gone.go", Status: "removed"},
|
|
||||||
}
|
|
||||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string for removed file, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchFileContext_FetchesModifiedFiles(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{
|
|
||||||
fileContents: map[string]string{
|
|
||||||
"owner/repo/main/foo.go": "package main\n\nfunc main() {}\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
files := []vcsChangedFile{
|
|
||||||
{Filename: "foo.go", Status: "modified"},
|
|
||||||
}
|
|
||||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
|
||||||
if !strings.Contains(got, "--- foo.go ---") {
|
|
||||||
t.Errorf("expected file header in output, got: %q", got)
|
|
||||||
}
|
|
||||||
if !strings.Contains(got, "package main") {
|
|
||||||
t.Errorf("expected file content in output, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchFileContext_ContinuesOnError(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{
|
|
||||||
fileContents: map[string]string{
|
|
||||||
"owner/repo/main/good.go": "package good\n",
|
|
||||||
},
|
|
||||||
fileContentsErr: map[string]error{
|
|
||||||
"owner/repo/main/bad.go": fmt.Errorf("network error"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
files := []vcsChangedFile{
|
|
||||||
{Filename: "bad.go", Status: "modified"},
|
|
||||||
{Filename: "good.go", Status: "modified"},
|
|
||||||
}
|
|
||||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
|
||||||
// bad.go fails, good.go should still be included
|
|
||||||
if strings.Contains(got, "bad.go") {
|
|
||||||
t.Errorf("should not include failed file, got: %q", got)
|
|
||||||
}
|
|
||||||
if !strings.Contains(got, "good.go") {
|
|
||||||
t.Errorf("should include successful file, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchFileContext_RespectsContextCancellation(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // Cancel immediately
|
|
||||||
|
|
||||||
client := &mockVCSClient{
|
|
||||||
fileContents: map[string]string{
|
|
||||||
"owner/repo/main/foo.go": "package foo\n",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
files := []vcsChangedFile{
|
|
||||||
{Filename: "foo.go", Status: "modified"},
|
|
||||||
}
|
|
||||||
got := fetchFileContext(ctx, client, "owner", "repo", "main", files)
|
|
||||||
// With cancelled context, the loop breaks before fetching
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string with cancelled context, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// fetchPatterns tests
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func TestFetchPatterns_EmptyRepo(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{}
|
|
||||||
got := fetchPatterns(ctx, client, "", "")
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string for empty patternsRepo, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchPatterns_SingleRepoAllFiles(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{
|
|
||||||
allFiles: map[string]map[string]string{
|
|
||||||
"rodin/patterns/": {
|
|
||||||
"patterns/go.md": "# Go patterns\n\nUse interfaces.",
|
|
||||||
"patterns/binary": "binary data",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := fetchPatterns(ctx, client, "rodin/patterns", "")
|
|
||||||
if !strings.Contains(got, "# Go patterns") {
|
|
||||||
t.Errorf("expected markdown content, got: %q", got)
|
|
||||||
}
|
|
||||||
// Binary file should be excluded
|
|
||||||
if strings.Contains(got, "binary data") {
|
|
||||||
t.Errorf("binary file should be excluded, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchPatterns_SpecificFiles(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{
|
|
||||||
allFiles: map[string]map[string]string{
|
|
||||||
"rodin/patterns/go.md": {
|
|
||||||
"go.md": "# Go idioms\n",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := fetchPatterns(ctx, client, "rodin/patterns", "go.md")
|
|
||||||
if !strings.Contains(got, "# Go idioms") {
|
|
||||||
t.Errorf("expected go idioms content, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchPatterns_SkipsInvalidRepo(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{}
|
|
||||||
// "badrepo" has no slash, should be skipped
|
|
||||||
got := fetchPatterns(ctx, client, "badrepo", "")
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string for invalid repo format, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchPatterns_ContinuesOnFetchError(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{
|
|
||||||
allFilesErr: map[string]error{
|
|
||||||
"owner/repo/": fmt.Errorf("server error"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// Should not panic; should return empty string
|
|
||||||
got := fetchPatterns(ctx, client, "owner/repo", "")
|
|
||||||
if got != "" {
|
|
||||||
t.Errorf("expected empty string on fetch error, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchPatterns_MultipleRepos(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
client := &mockVCSClient{
|
|
||||||
allFiles: map[string]map[string]string{
|
|
||||||
"org/go-patterns/": {
|
|
||||||
"idioms.md": "# Go idioms\n",
|
|
||||||
},
|
|
||||||
"org/elixir-patterns/": {
|
|
||||||
"pipes.md": "# Elixir pipes\n",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := fetchPatterns(ctx, client, "org/go-patterns, org/elixir-patterns", "")
|
|
||||||
if !strings.Contains(got, "# Go idioms") {
|
|
||||||
t.Errorf("expected Go idioms content, got: %q", got)
|
|
||||||
}
|
|
||||||
if !strings.Contains(got, "# Elixir pipes") {
|
|
||||||
t.Errorf("expected Elixir pipes content, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,125 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
|
||||||
)
|
|
||||||
|
|
||||||
// runValidateURL implements the `review-bot validate-url <url>` subcommand.
|
|
||||||
//
|
|
||||||
// It resolves the given URL's hostname and checks that every returned IP is
|
|
||||||
// publicly routable (not RFC1918, loopback, link-local, or other reserved
|
|
||||||
// ranges). The exit code communicates the result to callers:
|
|
||||||
//
|
|
||||||
// 0 — URL is safe to use
|
|
||||||
// 1 — URL resolves to a blocked/private address
|
|
||||||
// 2 — URL is malformed, has an unsafe scheme, or DNS lookup failed
|
|
||||||
//
|
|
||||||
// This is intended for use from action.yml shell steps that need to validate
|
|
||||||
// a user-supplied URL before passing it to curl.
|
|
||||||
func runValidateURL(args []string) int {
|
|
||||||
if len(args) != 1 {
|
|
||||||
fmt.Fprintln(errWriter, "usage: review-bot validate-url <url>")
|
|
||||||
fmt.Fprintln(errWriter, "")
|
|
||||||
fmt.Fprintln(errWriter, "Resolves <url> and verifies all resolved IPs are publicly routable.")
|
|
||||||
fmt.Fprintln(errWriter, "Exit 0=safe, 1=blocked, 2=error")
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
rawURL := args[0]
|
|
||||||
|
|
||||||
if err := validateURL(rawURL); err != nil {
|
|
||||||
fmt.Fprintf(errWriter, "Error: %v\n", err)
|
|
||||||
var ve *validateError
|
|
||||||
if isValidateError(err, &ve) {
|
|
||||||
return ve.code
|
|
||||||
}
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
fmt.Fprintf(outWriter, "OK: %s is safe\n", rawURL)
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateError carries an exit code alongside a message.
|
|
||||||
type validateError struct {
|
|
||||||
code int
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *validateError) Error() string { return e.message }
|
|
||||||
|
|
||||||
// isValidateError checks if err is or wraps a *validateError and sets out.
|
|
||||||
// Uses errors.As so that wrapped *validateError values (e.g. from fmt.Errorf("...: %w", &validateError{...}))
|
|
||||||
// are also detected, making the function robust against future wrapping.
|
|
||||||
func isValidateError(err error, out **validateError) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return errors.As(err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateURL checks that rawURL is safe for use as a Gitea server URL:
|
|
||||||
// - Must be https:// (not http://)
|
|
||||||
// - Must have no user-info (user:pass@host)
|
|
||||||
// - Must resolve to at least one IP, all of which are publicly routable
|
|
||||||
func validateURL(rawURL string) error {
|
|
||||||
parsed, err := url.Parse(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
return &validateError{code: 2, message: fmt.Sprintf("malformed URL %q: %v", rawURL, err)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scheme check: only https is permitted.
|
|
||||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
|
||||||
return &validateError{
|
|
||||||
code: 2,
|
|
||||||
message: fmt.Sprintf("URL scheme must be https (got %q)", parsed.Scheme),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reject user-info (user:password@host) to prevent credential embedding.
|
|
||||||
if parsed.User != nil {
|
|
||||||
return &validateError{
|
|
||||||
code: 2,
|
|
||||||
message: "URL must not contain user-info (user:password@host)",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
host := parsed.Hostname()
|
|
||||||
if host == "" {
|
|
||||||
return &validateError{code: 2, message: fmt.Sprintf("URL has no host: %q", rawURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve the hostname with a short timeout.
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return &validateError{
|
|
||||||
code: 2,
|
|
||||||
message: fmt.Sprintf("DNS lookup failed for %q: %v", host, err),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(addrs) == 0 {
|
|
||||||
return &validateError{
|
|
||||||
code: 2,
|
|
||||||
message: fmt.Sprintf("DNS lookup returned no addresses for %q", host),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, a := range addrs {
|
|
||||||
if gitea.IsBlockedIP(a.IP) {
|
|
||||||
return &validateError{
|
|
||||||
code: 1,
|
|
||||||
message: fmt.Sprintf("blocked: %q resolves to private/reserved IP %s", host, a.IP),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRunValidateURL_Usage(t *testing.T) {
|
|
||||||
var errBuf bytes.Buffer
|
|
||||||
origErr := errWriter
|
|
||||||
errWriter = &errBuf
|
|
||||||
defer func() { errWriter = origErr }()
|
|
||||||
|
|
||||||
code := runValidateURL(nil)
|
|
||||||
if code != 2 {
|
|
||||||
t.Errorf("expected exit code 2 for no args, got %d", code)
|
|
||||||
}
|
|
||||||
if !strings.Contains(errBuf.String(), "usage") {
|
|
||||||
t.Errorf("expected usage in stderr, got %q", errBuf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
errBuf.Reset()
|
|
||||||
code = runValidateURL([]string{"arg1", "arg2"})
|
|
||||||
if code != 2 {
|
|
||||||
t.Errorf("expected exit code 2 for too many args, got %d", code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateURL_MalformedURL(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
url string
|
|
||||||
wantMsg string
|
|
||||||
}{
|
|
||||||
{"empty", "", "must be https"},
|
|
||||||
{"http scheme", "http://example.com/", "must be https"},
|
|
||||||
{"ftp scheme", "ftp://example.com/", "must be https"},
|
|
||||||
{"no scheme", "example.com", "must be https"},
|
|
||||||
{"user info", "https://user:pass@example.com/", "user-info"},
|
|
||||||
}
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := validateURL(tc.url)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expected error for URL %q, got nil", tc.url)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), tc.wantMsg) {
|
|
||||||
t.Errorf("error %q does not contain %q", err.Error(), tc.wantMsg)
|
|
||||||
}
|
|
||||||
var ve *validateError
|
|
||||||
if !isValidateError(err, &ve) {
|
|
||||||
t.Fatalf("expected *validateError, got %T", err)
|
|
||||||
}
|
|
||||||
if ve.code != 2 {
|
|
||||||
t.Errorf("expected code 2, got %d", ve.code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateURL_BlockedPrivateIP(t *testing.T) {
|
|
||||||
// localhost always resolves to 127.0.0.1 (loopback).
|
|
||||||
err := validateURL("https://localhost/")
|
|
||||||
if err == nil {
|
|
||||||
t.Skip("localhost did not resolve (network unavailable in test environment)")
|
|
||||||
}
|
|
||||||
var ve *validateError
|
|
||||||
if !isValidateError(err, &ve) {
|
|
||||||
t.Fatalf("expected *validateError, got %T: %v", err, err)
|
|
||||||
}
|
|
||||||
if ve.code != 1 && ve.code != 2 {
|
|
||||||
t.Errorf("expected code 1 (blocked) or 2 (dns fail), got %d: %s", ve.code, ve.message)
|
|
||||||
}
|
|
||||||
// If it resolved (code 1), the message must say "blocked".
|
|
||||||
if ve.code == 1 && !strings.Contains(ve.message, "blocked") {
|
|
||||||
t.Errorf("expected 'blocked' in message, got %q", ve.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateURL_ExitCodes(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
url string
|
|
||||||
wantCode int
|
|
||||||
}{
|
|
||||||
{"http scheme", "http://example.com/", 2},
|
|
||||||
{"no scheme", "example.com", 2},
|
|
||||||
{"user info", "https://admin:secret@example.com/", 2},
|
|
||||||
}
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := validateURL(tc.url)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error for %q", tc.url)
|
|
||||||
}
|
|
||||||
var ve *validateError
|
|
||||||
if !isValidateError(err, &ve) {
|
|
||||||
t.Fatalf("expected *validateError, got %T", err)
|
|
||||||
}
|
|
||||||
if ve.code != tc.wantCode {
|
|
||||||
t.Errorf("code = %d, want %d (url=%q, msg=%s)", ve.code, tc.wantCode, tc.url, ve.message)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunValidateURL_WithCapture(t *testing.T) {
|
|
||||||
var outBuf, errBuf bytes.Buffer
|
|
||||||
origOut, origErr := outWriter, errWriter
|
|
||||||
outWriter = &outBuf
|
|
||||||
errWriter = &errBuf
|
|
||||||
defer func() {
|
|
||||||
outWriter = origOut
|
|
||||||
errWriter = origErr
|
|
||||||
}()
|
|
||||||
|
|
||||||
// http:// scheme should fail with code 2.
|
|
||||||
code := runValidateURL([]string{"http://example.com/"})
|
|
||||||
if code != 2 {
|
|
||||||
t.Errorf("expected code 2 for http:// URL, got %d", code)
|
|
||||||
}
|
|
||||||
if !strings.Contains(errBuf.String(), "must be https") {
|
|
||||||
t.Errorf("expected error about https in stderr, got %q", errBuf.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// vcs.go defines the vcsClient interface that both gitea.Client (via giteaVCSAdapter)
|
|
||||||
// and github.Client (via githubVCSAdapter) satisfy, enabling VCS-type routing in main.go.
|
|
||||||
//
|
|
||||||
// Interface design:
|
|
||||||
// - Methods cover all PR review operations used by main.go.
|
|
||||||
// - Gitea-specific operations (supersede, comment resolution) are in the separate
|
|
||||||
// giteaExtClient interface. GitHub implementations return ErrNotSupported for those.
|
|
||||||
// - Types are defined here as package-local VCS types; each adapter converts from
|
|
||||||
// its respective client package's types.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"gitea.weiker.me/rodin/review-bot/gitea"
|
|
||||||
"gitea.weiker.me/rodin/review-bot/github"
|
|
||||||
"gitea.weiker.me/rodin/review-bot/review"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrNotSupported is returned by VCS methods that have no implementation for
|
|
||||||
// a particular VCS backend (e.g., Gitea-specific timeline APIs on GitHub).
|
|
||||||
var ErrNotSupported = errors.New("operation not supported on this VCS backend")
|
|
||||||
|
|
||||||
// vcsClient is the interface for all PR operations used by main.go.
|
|
||||||
// It is implemented by both giteaVCSAdapter and githubVCSAdapter.
|
|
||||||
// Interface defined here (in the consumer package) per Go idiom.
|
|
||||||
type vcsClient interface {
|
|
||||||
// PR metadata and content
|
|
||||||
GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error)
|
|
||||||
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
|
|
||||||
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error)
|
|
||||||
GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error)
|
|
||||||
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
|
|
||||||
GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error)
|
|
||||||
ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error)
|
|
||||||
GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error)
|
|
||||||
|
|
||||||
// Review operations
|
|
||||||
PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error)
|
|
||||||
ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error)
|
|
||||||
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
|
|
||||||
GetAuthenticatedUser(ctx context.Context) (string, error)
|
|
||||||
RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// giteaExtClient extends vcsClient with Gitea-specific operations that have no
|
|
||||||
// GitHub equivalent. Code that uses these methods should first do a type assertion.
|
|
||||||
type giteaExtClient interface {
|
|
||||||
vcsClient
|
|
||||||
GetTimelineReviewCommentIDForReview(ctx context.Context, owner, repo string, prNum, reviewID int64) (int64, error)
|
|
||||||
EditComment(ctx context.Context, owner, repo string, commentID int64, body string) error
|
|
||||||
ListReviewComments(ctx context.Context, owner, repo string, prNum, reviewID int64) ([]gitea.ReviewComment, error)
|
|
||||||
ResolveComment(ctx context.Context, owner, repo string, commentID int64) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- shared VCS types ---
|
|
||||||
|
|
||||||
// vcsPullRequest is VCS-agnostic PR metadata.
|
|
||||||
type vcsPullRequest struct {
|
|
||||||
Title string
|
|
||||||
Body string
|
|
||||||
Head struct {
|
|
||||||
Sha string
|
|
||||||
Ref string
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// vcsChangedFile is a file changed in a PR.
|
|
||||||
type vcsChangedFile struct {
|
|
||||||
Filename string
|
|
||||||
Status string
|
|
||||||
}
|
|
||||||
|
|
||||||
// vcsCommitStatus is a CI status entry.
|
|
||||||
type vcsCommitStatus struct {
|
|
||||||
Status string
|
|
||||||
Context string
|
|
||||||
Description string
|
|
||||||
TargetURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
// vcsReviewComment is an inline review comment.
|
|
||||||
type vcsReviewComment struct {
|
|
||||||
Path string
|
|
||||||
NewPosition int64 // Gitea: absolute line; GitHub: diff hunk position
|
|
||||||
Body string
|
|
||||||
}
|
|
||||||
|
|
||||||
// vcsReview is a submitted PR review.
|
|
||||||
type vcsReview struct {
|
|
||||||
ID int64
|
|
||||||
Body string
|
|
||||||
CommitID string
|
|
||||||
User struct {
|
|
||||||
Login string
|
|
||||||
}
|
|
||||||
State string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// giteaVCSAdapter
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// giteaVCSAdapter wraps gitea.Client to implement vcsClient + giteaExtClient.
|
|
||||||
type giteaVCSAdapter struct {
|
|
||||||
c *gitea.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGiteaVCSAdapter(c *gitea.Client) *giteaVCSAdapter { return &giteaVCSAdapter{c: c} }
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
|
||||||
pr, err := a.c.GetPullRequest(ctx, owner, repo, number)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r := &vcsPullRequest{Title: pr.Title, Body: pr.Body}
|
|
||||||
r.Head.Sha = pr.Head.Sha
|
|
||||||
r.Head.Ref = pr.Head.Ref
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
|
||||||
return a.c.GetPullRequestDiff(ctx, owner, repo, number)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
|
||||||
files, err := a.c.GetPullRequestFiles(ctx, owner, repo, number)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]vcsChangedFile, len(files))
|
|
||||||
for i, f := range files {
|
|
||||||
out[i] = vcsChangedFile{Filename: f.Filename, Status: f.Status}
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
|
||||||
statuses, err := a.c.GetCommitStatuses(ctx, owner, repo, sha)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]vcsCommitStatus, len(statuses))
|
|
||||||
for i, s := range statuses {
|
|
||||||
out[i] = vcsCommitStatus{Status: s.Status, Context: s.Context, Description: s.Description, TargetURL: s.TargetURL}
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
|
||||||
return a.c.GetFileContent(ctx, owner, repo, filepath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
|
||||||
return a.c.GetFileContentRef(ctx, owner, repo, filepath, ref)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
|
||||||
entries, err := a.c.ListContents(ctx, owner, repo, path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]review.ContentEntry, len(entries))
|
|
||||||
for i, e := range entries {
|
|
||||||
out[i] = review.ContentEntry{Name: e.Name, Path: e.Path, Type: e.Type}
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
|
||||||
return a.c.GetAllFilesInPath(ctx, owner, repo, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
|
||||||
gc := make([]gitea.ReviewComment, len(comments))
|
|
||||||
for i, c := range comments {
|
|
||||||
gc[i] = gitea.ReviewComment{Path: c.Path, NewPosition: c.NewPosition, Body: c.Body}
|
|
||||||
}
|
|
||||||
r, err := a.c.PostReview(ctx, owner, repo, number, event, body, commitID, gc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := &vcsReview{ID: r.ID, Body: r.Body, CommitID: r.CommitID, State: r.State}
|
|
||||||
out.User.Login = r.User.Login
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
|
||||||
reviews, err := a.c.ListReviews(ctx, owner, repo, number)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]vcsReview, len(reviews))
|
|
||||||
for i, r := range reviews {
|
|
||||||
out[i] = vcsReview{ID: r.ID, Body: r.Body, CommitID: r.CommitID, State: r.State}
|
|
||||||
out[i].User.Login = r.User.Login
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
|
||||||
return a.c.DeleteReview(ctx, owner, repo, number, reviewID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
|
||||||
return a.c.GetAuthenticatedUser(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
|
||||||
return a.c.RequestReviewer(ctx, owner, repo, number, reviewer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gitea-specific extension methods.
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) GetTimelineReviewCommentIDForReview(ctx context.Context, owner, repo string, prNum, reviewID int64) (int64, error) {
|
|
||||||
return a.c.GetTimelineReviewCommentIDForReview(ctx, owner, repo, int(prNum), reviewID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) EditComment(ctx context.Context, owner, repo string, commentID int64, body string) error {
|
|
||||||
return a.c.EditComment(ctx, owner, repo, commentID, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) ListReviewComments(ctx context.Context, owner, repo string, prNum, reviewID int64) ([]gitea.ReviewComment, error) {
|
|
||||||
return a.c.ListReviewComments(ctx, owner, repo, int(prNum), reviewID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *giteaVCSAdapter) ResolveComment(ctx context.Context, owner, repo string, commentID int64) error {
|
|
||||||
return a.c.ResolveComment(ctx, owner, repo, commentID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// githubVCSAdapter
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
// githubVCSAdapter wraps github.Client to implement vcsClient.
|
|
||||||
// Gitea-specific extension methods (GetTimelineReviewCommentIDForReview, EditComment,
|
|
||||||
// ListReviewComments, ResolveComment) are not available on GitHub and will not be called
|
|
||||||
// because main.go gates them with a type assertion to giteaExtClient.
|
|
||||||
type githubVCSAdapter struct {
|
|
||||||
c *github.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGithubVCSAdapter(c *github.Client) *githubVCSAdapter { return &githubVCSAdapter{c: c} }
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcsPullRequest, error) {
|
|
||||||
pr, err := a.c.GetPullRequest(ctx, owner, repo, number)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r := &vcsPullRequest{Title: pr.Title, Body: pr.Body}
|
|
||||||
r.Head.Sha = pr.Head.Sha
|
|
||||||
r.Head.Ref = pr.Head.Ref
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
|
||||||
return a.c.GetPullRequestDiff(ctx, owner, repo, number)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcsChangedFile, error) {
|
|
||||||
files, err := a.c.GetPullRequestFiles(ctx, owner, repo, number)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]vcsChangedFile, len(files))
|
|
||||||
for i, f := range files {
|
|
||||||
out[i] = vcsChangedFile{Filename: f.Filename, Status: f.Status}
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcsCommitStatus, error) {
|
|
||||||
statuses, err := a.c.GetCommitStatuses(ctx, owner, repo, sha)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]vcsCommitStatus, len(statuses))
|
|
||||||
for i, s := range statuses {
|
|
||||||
// CommitStatus.Status is tagged as json:"state" — already the normalized "state" value
|
|
||||||
out[i] = vcsCommitStatus{Status: s.Status, Context: s.Context, Description: s.Description, TargetURL: s.TargetURL}
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
|
||||||
return a.c.GetFileContent(ctx, owner, repo, filepath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
|
||||||
return a.c.GetFileContentRef(ctx, owner, repo, filepath, ref)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) ListContents(ctx context.Context, owner, repo, path string) ([]review.ContentEntry, error) {
|
|
||||||
entries, err := a.c.ListContents(ctx, owner, repo, path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]review.ContentEntry, len(entries))
|
|
||||||
for i, e := range entries {
|
|
||||||
out[i] = review.ContentEntry{Name: e.Name, Path: e.Path, Type: e.Type}
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
|
||||||
return a.c.GetAllFilesInPath(ctx, owner, repo, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []vcsReviewComment) (*vcsReview, error) {
|
|
||||||
gc := make([]github.ReviewComment, len(comments))
|
|
||||||
for i, c := range comments {
|
|
||||||
// GitHub inline comments use diff hunk "position", not absolute line numbers.
|
|
||||||
// NewPosition from gitea diff parsing gives absolute line numbers, which
|
|
||||||
// will not match GitHub's position values. For initial GitHub support, we
|
|
||||||
// attach comments with Line+Side (absolute line on the RIGHT side) instead.
|
|
||||||
// Comments that cannot be mapped will be omitted (GitHub rejects invalid positions).
|
|
||||||
gc[i] = github.ReviewComment{
|
|
||||||
Path: c.Path,
|
|
||||||
Line: c.NewPosition,
|
|
||||||
Side: "RIGHT",
|
|
||||||
Body: c.Body,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r, err := a.c.PostReview(ctx, owner, repo, number, event, body, commitID, gc)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := &vcsReview{ID: r.ID, Body: r.Body, State: r.State}
|
|
||||||
out.User.Login = r.User.Login
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) ListReviews(ctx context.Context, owner, repo string, number int) ([]vcsReview, error) {
|
|
||||||
reviews, err := a.c.ListReviews(ctx, owner, repo, number)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
out := make([]vcsReview, len(reviews))
|
|
||||||
for i, r := range reviews {
|
|
||||||
out[i] = vcsReview{ID: r.ID, Body: r.Body, State: r.State}
|
|
||||||
out[i].User.Login = r.User.Login
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
|
||||||
// GitHub only allows deleting PENDING (draft) reviews. review-bot posts submitted
|
|
||||||
// reviews, so this will return an error for any review we actually posted.
|
|
||||||
// Callers should treat 422 errors here gracefully.
|
|
||||||
return a.c.DeleteReview(ctx, owner, repo, number, reviewID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
|
||||||
return a.c.GetAuthenticatedUser(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *githubVCSAdapter) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
|
||||||
return a.c.RequestReviewer(ctx, owner, repo, number, reviewer)
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,7 @@ JSON is awkward for persona files that contain multi-line text (identity, severi
|
|||||||
- Backwards compatibility: existing JSON personas must continue to work
|
- Backwards compatibility: existing JSON personas must continue to work
|
||||||
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
|
- Security: protect against DoS via deeply nested YAML (AIKIDO-2024-10486)
|
||||||
- Consistency: use `.yaml` extension (not `.yml`)
|
- Consistency: use `.yaml` extension (not `.yml`)
|
||||||
- Library: use `github.com/goccy/go-yaml` v1.16.0+ (approved in CONVENTIONS.md); we implement custom AST-based depth/node-count checks for precise alias-aware validation
|
- Library: use `gopkg.in/yaml.v3` (approved in CONVENTIONS.md) with explicit depth limiting
|
||||||
|
|
||||||
## Proposed Approach
|
## Proposed Approach
|
||||||
|
|
||||||
@@ -33,16 +33,37 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
|
|
||||||
### YAML Parsing with Depth Protection
|
### YAML Parsing with Depth Protection
|
||||||
|
|
||||||
We implement a custom AST-based depth/node-count walk (`checkYAMLDepth` in
|
```go
|
||||||
`review/persona.go`) rather than relying on library decoder options. Key design
|
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||||
decisions:
|
var node yaml.Node
|
||||||
|
dec := yaml.NewDecoder(bytes.NewReader(data))
|
||||||
|
if err := dec.Decode(&node); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkYAMLDepth(&node, 0, maxDepth); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return node.Decode(out)
|
||||||
|
}
|
||||||
|
|
||||||
- **Library:** `github.com/goccy/go-yaml` with `ast.Node`-based traversal
|
func checkYAMLDepth(node *yaml.Node, depth, maxDepth int) error {
|
||||||
- **Dual-map tracking:** `validated` (depth-aware short-circuit) + `visiting` (cycle detection)
|
if depth > maxDepth {
|
||||||
- **Node-count limit:** Conservative overcounting bounds total validation work
|
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||||
- **Alias-aware depth:** Aliases increment depth and are re-checked when encountered at greater depths
|
}
|
||||||
|
// Handle alias nodes by following the Alias pointer
|
||||||
|
if node.Kind == yaml.AliasNode && node.Alias != nil {
|
||||||
|
return checkYAMLDepth(node.Alias, depth, maxDepth)
|
||||||
|
}
|
||||||
|
for _, child := range node.Content {
|
||||||
|
if err := checkYAMLDepth(child, depth+1, maxDepth); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
See `review/persona.go:checkYAMLDepth` for the authoritative implementation.
|
The `gopkg.in/yaml.v3` library does not have built-in depth protection, so we implement explicit depth checking by first decoding into a `yaml.Node`, walking the tree to verify depth (including alias resolution), then decoding into the target struct.
|
||||||
|
|
||||||
## State/Data Model
|
## State/Data Model
|
||||||
|
|
||||||
@@ -53,7 +74,7 @@ No new state. Same `Persona` struct, just different parsing.
|
|||||||
| Error | Handling |
|
| Error | Handling |
|
||||||
|-------|----------|
|
|-------|----------|
|
||||||
| Invalid YAML syntax | Return parse error with source file |
|
| Invalid YAML syntax | Return parse error with source file |
|
||||||
| Deeply nested YAML | Custom AST walk (`checkYAMLDepth`) rejects before decode |
|
| Deeply nested YAML | Library rejects (v1.16.0+ fix) |
|
||||||
| Unknown extension | Fall back to JSON parsing |
|
| Unknown extension | Fall back to JSON parsing |
|
||||||
| Missing required fields | Validation rejects after parse |
|
| Missing required fields | Validation rejects after parse |
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,268 @@
|
|||||||
|
# GitHub Support for review-bot
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
|
||||||
|
AI code reviews on GitHub PRs using SAP AI Core as the LLM provider.
|
||||||
|
|
||||||
|
## Non-Goals
|
||||||
|
|
||||||
|
- Auto-detection of platform (explicit `--provider` flag is fine)
|
||||||
|
- Unifying into one abstraction layer for its own sake
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
|
||||||
|
1. **Same features on both platforms** — anything review-bot does on Gitea should work on GitHub
|
||||||
|
2. **Testable** — small interfaces, dependency injection, no global state
|
||||||
|
3. **Interface from working code** — extract from gitea/, don't invent in vacuum
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 1: Feature Inventory
|
||||||
|
|
||||||
|
What does review-bot actually do?
|
||||||
|
|
||||||
|
### Core Review Flow
|
||||||
|
|
||||||
|
| Feature | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| Get PR metadata | Title, body, head SHA, base ref |
|
||||||
|
| Get PR diff | Unified diff format |
|
||||||
|
| Get PR files | List of changed files with status |
|
||||||
|
| Get file content | Raw file at ref |
|
||||||
|
| List directory | Enumerate files in path |
|
||||||
|
| Post review | Body + inline comments + verdict |
|
||||||
|
|
||||||
|
### Review Management
|
||||||
|
|
||||||
|
| Feature | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| List reviews | Get existing reviews on PR |
|
||||||
|
| Delete review | Remove old review before re-posting |
|
||||||
|
| Get authenticated user | Who am I? |
|
||||||
|
|
||||||
|
### Platform-Specific (not in shared interface)
|
||||||
|
|
||||||
|
| Feature | Gitea | GitHub |
|
||||||
|
|---------|-------|--------|
|
||||||
|
| Resolve comment | Yes | No equivalent |
|
||||||
|
| Timeline API | Yes | No equivalent |
|
||||||
|
|
||||||
|
These stay on gitea.Client directly. Callers that need them type-assert.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 2: GitHub API Mapping
|
||||||
|
|
||||||
|
| Feature | Gitea API | GitHub API |
|
||||||
|
|---------|-----------|------------|
|
||||||
|
| Get PR | `GET /api/v1/repos/.../pulls/{n}` | `GET /repos/.../pulls/{n}` |
|
||||||
|
| Get diff | `.diff` suffix | `Accept: application/vnd.github.diff` header |
|
||||||
|
| Get files | `GET .../pulls/{n}/files` | Same |
|
||||||
|
| Get file content | `GET .../raw/{path}?ref=` | `GET .../contents/{path}?ref=` + base64 decode |
|
||||||
|
| List directory | `GET .../contents/{path}` | Same |
|
||||||
|
| Post review | `POST .../pulls/{n}/reviews` | Same (adapter handles comment schema) |
|
||||||
|
| List reviews | `GET .../pulls/{n}/reviews` | Same |
|
||||||
|
| Delete review | `DELETE .../pulls/{n}/reviews/{id}` | Same |
|
||||||
|
| Get user | `GET /api/v1/user` | `GET /user` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 3: Interface Design
|
||||||
|
|
||||||
|
**Principle:** Extract from working gitea/ code. The interface is discovered, not invented.
|
||||||
|
|
||||||
|
### Small, role-based interfaces
|
||||||
|
|
||||||
|
```go
|
||||||
|
// vcs/interfaces.go
|
||||||
|
|
||||||
|
type PRReader interface {
|
||||||
|
GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error)
|
||||||
|
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
|
||||||
|
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileReader interface {
|
||||||
|
GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error)
|
||||||
|
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Reviewer interface {
|
||||||
|
PostReview(ctx context.Context, owner, repo string, number int, req ReviewRequest) (*Review, error)
|
||||||
|
ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error)
|
||||||
|
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Identity interface {
|
||||||
|
GetAuthenticatedUser(ctx context.Context) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client combines all for callers that need everything
|
||||||
|
type Client interface {
|
||||||
|
PRReader
|
||||||
|
FileReader
|
||||||
|
Reviewer
|
||||||
|
Identity
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Types
|
||||||
|
|
||||||
|
Use what gitea/ already has. Move to vcs/types.go or re-export.
|
||||||
|
|
||||||
|
```go
|
||||||
|
type PullRequest struct { ... } // from gitea.PullRequest
|
||||||
|
type ChangedFile struct { ... } // from gitea.ChangedFile
|
||||||
|
type ContentEntry struct { ... } // from gitea.ContentEntry
|
||||||
|
type Review struct { ... } // from gitea.Review
|
||||||
|
type ReviewRequest struct { ... } // new, for PostReview input
|
||||||
|
type ReviewComment struct { ... } // from gitea.ReviewComment
|
||||||
|
```
|
||||||
|
|
||||||
|
### Adapter responsibilities
|
||||||
|
|
||||||
|
Each adapter (gitea, github) handles:
|
||||||
|
- API URL construction
|
||||||
|
- Auth header format (`token` vs `Bearer`)
|
||||||
|
- Request/response mapping
|
||||||
|
- Comment schema translation (line numbers, commit IDs, etc.)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 4: Test Plan
|
||||||
|
|
||||||
|
### Unit Tests (mock HTTP)
|
||||||
|
|
||||||
|
```
|
||||||
|
github/
|
||||||
|
pr_test.go # TestGetPullRequest, TestGetDiff, TestGetFiles
|
||||||
|
files_test.go # TestGetFileContent, TestListContents
|
||||||
|
review_test.go # TestPostReview, TestListReviews, TestDeleteReview
|
||||||
|
identity_test.go # TestGetAuthenticatedUser
|
||||||
|
```
|
||||||
|
|
||||||
|
Per method: happy path, 404, 401, 429, malformed response.
|
||||||
|
|
||||||
|
### Integration Tests
|
||||||
|
|
||||||
|
Against github.com/aweiker/ai-core-review-bot:
|
||||||
|
- Fetch real PR
|
||||||
|
- Fetch real file
|
||||||
|
- Post + delete review (clean up)
|
||||||
|
|
||||||
|
### End-to-End
|
||||||
|
|
||||||
|
Open PR on test repo, run full review-bot, verify review appears.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 5: Implementation Phases
|
||||||
|
|
||||||
|
### Phase 1: Extract interfaces from gitea/
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- Create `vcs/interfaces.go` with interfaces extracted from gitea/client.go signatures
|
||||||
|
- Create `vcs/types.go` — move or alias types from gitea/
|
||||||
|
- Verify gitea.Client satisfies vcs.Client (compile-time check)
|
||||||
|
|
||||||
|
**Exit criteria:** `var _ vcs.Client = (*gitea.Client)(nil)` compiles.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 2: Gitea adapter (if needed)
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- If gitea.Client method signatures don't match exactly, create wrapper
|
||||||
|
- Keep gitea/ working exactly as before
|
||||||
|
|
||||||
|
**Exit criteria:** Existing tests pass. No behavior change.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 3: GitHub client — PRReader
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- `github/client.go` — struct, constructor, HTTP helpers
|
||||||
|
- `github/pr.go` — GetPullRequest, GetPullRequestDiff, GetPullRequestFiles
|
||||||
|
- Unit tests
|
||||||
|
|
||||||
|
**Exit criteria:** `go test ./github/...` passes for PR methods.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 4: GitHub client — FileReader
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- `github/files.go` — GetFileContent, ListContents
|
||||||
|
- Unit tests
|
||||||
|
|
||||||
|
**Exit criteria:** Unit tests pass.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 5: GitHub client — Reviewer + Identity
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- `github/review.go` — PostReview, ListReviews, DeleteReview
|
||||||
|
- `github/identity.go` — GetAuthenticatedUser
|
||||||
|
- Unit tests
|
||||||
|
|
||||||
|
**Exit criteria:** Unit tests pass.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 6: Integration tests
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- `integration/github_test.go`
|
||||||
|
- Test against real GitHub
|
||||||
|
|
||||||
|
**Exit criteria:** All integration tests pass.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 7: Wire into cmd/review-bot
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- Add `--provider github|gitea` flag (default: gitea for backward compat)
|
||||||
|
- Select client based on flag
|
||||||
|
- Update to use vcs interfaces where it makes sense
|
||||||
|
|
||||||
|
**Exit criteria:**
|
||||||
|
- `./review-bot --provider github ...` works
|
||||||
|
- `./review-bot --provider gitea ...` works (same as before)
|
||||||
|
- Existing Gitea workflows unchanged
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase 8: GitHub Actions workflow + releases
|
||||||
|
|
||||||
|
**Work:**
|
||||||
|
- `.github/workflows/ci.yml` — test on PR
|
||||||
|
- `.github/workflows/release.yml` — publish binary to GitHub releases
|
||||||
|
- `.github/actions/review/action.yml` — composite action
|
||||||
|
- Action downloads binary from github.com/aweiker/ai-core-review-bot releases
|
||||||
|
|
||||||
|
**Exit criteria:**
|
||||||
|
- CI runs on github.com/aweiker/ai-core-review-bot
|
||||||
|
- Release creates downloadable binary
|
||||||
|
- Review action posts review successfully
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Part 6: Decisions
|
||||||
|
|
||||||
|
| Question | Decision |
|
||||||
|
|----------|----------|
|
||||||
|
| Auth token | Workflow `GITHUB_TOKEN` (automatic) |
|
||||||
|
| Binary distribution | GitHub releases on aweiker/ai-core-review-bot |
|
||||||
|
| Comment schema | Adapter's job — translate ReviewComment to platform format |
|
||||||
|
| Default provider | `gitea` for backward compatibility |
|
||||||
|
| Shared types | vcs/types.go (extracted from gitea/) |
|
||||||
|
| Platform-specific features | Stay on concrete client, not interface |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
8 phases. Start by extracting interfaces from working gitea/ code, not inventing them. GitHub implements the same interfaces. Each phase has clear exit criteria.
|
||||||
+30
-213
@@ -11,7 +11,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -48,12 +47,6 @@ func IsServerError(err error) bool {
|
|||||||
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
|
return errors.As(err, &apiErr) && apiErr.StatusCode >= 500 && apiErr.StatusCode < 600
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultMaxDiffSize is the default maximum diff size in bytes (10 MB).
|
|
||||||
const DefaultMaxDiffSize = 10 * 1024 * 1024
|
|
||||||
|
|
||||||
// ErrDiffTooLarge is returned when a PR diff exceeds the configured MaxDiffSize.
|
|
||||||
var ErrDiffTooLarge = errors.New("diff size exceeds maximum allowed size")
|
|
||||||
|
|
||||||
// Client interacts with the Gitea API.
|
// Client interacts with the Gitea API.
|
||||||
// A Client is safe for concurrent use by multiple goroutines.
|
// A Client is safe for concurrent use by multiple goroutines.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -68,152 +61,20 @@ type Client struct {
|
|||||||
// This field must be configured before the first request is made.
|
// This field must be configured before the first request is made.
|
||||||
// Modifying it while requests are in flight is not safe.
|
// Modifying it while requests are in flight is not safe.
|
||||||
RetryBackoff []time.Duration
|
RetryBackoff []time.Duration
|
||||||
|
|
||||||
// MaxDiffSize is the maximum number of bytes allowed when fetching a PR diff.
|
|
||||||
// If zero, defaults to DefaultMaxDiffSize (10 MB). Set to any negative value
|
|
||||||
// (or math.MaxInt64) to disable the limit.
|
|
||||||
//
|
|
||||||
// This field must be configured before the first request is made.
|
|
||||||
// Modifying it while requests are in flight is not safe.
|
|
||||||
MaxDiffSize int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultCheckRedirect is the redirect policy used by NewClient.
|
|
||||||
// NOTE: This function is intentionally duplicated in github/client.go (and vice versa)
|
|
||||||
// because the packages are separate. Changes here must be mirrored there.
|
|
||||||
// It rejects HTTPS->HTTP protocol downgrades (to prevent plaintext leakage)
|
|
||||||
// and cross-host redirects (to prevent following responses from untrusted
|
|
||||||
// endpoints). Same-host, same-or-upgraded-scheme redirects are allowed.
|
|
||||||
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
|
|
||||||
if len(via) >= 10 {
|
|
||||||
return fmt.Errorf("stopped after 10 redirects")
|
|
||||||
}
|
|
||||||
// Guard for direct invocation in tests and any future callers;
|
|
||||||
// net/http guarantees len(via) >= 1 during actual redirects.
|
|
||||||
if len(via) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
prev := via[len(via)-1]
|
|
||||||
// Reject protocol downgrade: HTTPS->HTTP leaks request metadata over plaintext.
|
|
||||||
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
|
|
||||||
return fmt.Errorf("refusing redirect: HTTPS to HTTP downgrade (%s -> %s)", prev.URL.Host, req.URL.Host)
|
|
||||||
}
|
|
||||||
// Reject cross-host redirect entirely to avoid consuming responses
|
|
||||||
// from untrusted endpoints.
|
|
||||||
if req.URL.Host != prev.URL.Host {
|
|
||||||
return fmt.Errorf("refusing redirect: cross-host (%s -> %s)", prev.URL.Host, req.URL.Host)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// safeDialContext is the default DialContext for NewClient.
|
|
||||||
// It resolves the hostname and checks every returned IP against the blocked
|
|
||||||
// CIDR list before establishing a connection. This prevents SSRF attacks
|
|
||||||
// where user-supplied URLs resolve to internal/private addresses.
|
|
||||||
//
|
|
||||||
// After validating all IPs, we dial the first resolved IP directly to avoid
|
|
||||||
// a second DNS lookup (which could return a different IP in a DNS rebinding
|
|
||||||
// attack). This narrows — but does not fully eliminate — the DNS rebinding
|
|
||||||
// window to the time between LookupIPAddr and DialContext.
|
|
||||||
//
|
|
||||||
// If the host is already an IP literal, LookupIPAddr returns it directly
|
|
||||||
// (no DNS query issued), so IP literals like https://127.0.0.1/ are blocked.
|
|
||||||
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("safeDialContext: invalid address %q: %w", addr, err)
|
|
||||||
}
|
|
||||||
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("safeDialContext: DNS lookup %q: %w", host, err)
|
|
||||||
}
|
|
||||||
if len(addrs) == 0 {
|
|
||||||
return nil, fmt.Errorf("safeDialContext: no addresses returned for %q", host)
|
|
||||||
}
|
|
||||||
for _, a := range addrs {
|
|
||||||
if IsBlockedIP(a.IP) {
|
|
||||||
return nil, fmt.Errorf("safeDialContext: blocked: %q resolves to private/reserved IP %s", host, a.IP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Try each resolved IP in order, returning the first successful connection.
|
|
||||||
// Fallback is important when a hostname resolves to multiple IPs and the first
|
|
||||||
// is temporarily unreachable. All IPs were already validated above, so dialing
|
|
||||||
// any of them is safe.
|
|
||||||
//
|
|
||||||
// Timeout: 10s per the design (PLAN.md); the outer http.Client has a 30s
|
|
||||||
// total timeout, but the per-dial timeout ensures a slow TCP connect on one IP
|
|
||||||
// doesn't consume the budget needed to try others.
|
|
||||||
d := &net.Dialer{Timeout: 10 * time.Second}
|
|
||||||
var lastErr error
|
|
||||||
for _, a := range addrs {
|
|
||||||
conn, err := d.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port))
|
|
||||||
if err == nil {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
lastErr = err
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("safeDialContext: all %d addresses for %q failed, last error: %w", len(addrs), host, lastErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newSafeHTTPClient returns an *http.Client with the SSRF-blocking safeDialContext
|
|
||||||
// transport and the cross-host redirect rejection policy.
|
|
||||||
//
|
|
||||||
// We clone http.DefaultTransport to preserve its production-ready defaults
|
|
||||||
// (ProxyFromEnvironment, TLSHandshakeTimeout, IdleConnTimeout, connection
|
|
||||||
// pooling, HTTP/2 support) and override only DialContext with safeDialContext.
|
|
||||||
func newSafeHTTPClient() *http.Client {
|
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
transport.DialContext = safeDialContext
|
|
||||||
return &http.Client{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
Transport: transport,
|
|
||||||
CheckRedirect: defaultCheckRedirect,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new Gitea API client.
|
// NewClient creates a new Gitea API client.
|
||||||
//
|
|
||||||
// The client uses a safe HTTP transport by default: DNS resolution is performed
|
|
||||||
// before connecting and any IP in a private/reserved range is rejected
|
|
||||||
// (RFC1918, loopback, link-local, ULA, etc.). Cross-host and HTTPS→HTTP
|
|
||||||
// redirects are also rejected.
|
|
||||||
//
|
|
||||||
// For tests that use httptest.NewServer (which listens on 127.0.0.1), call
|
|
||||||
// WithUnsafeDialer() to bypass the IP check.
|
|
||||||
func NewClient(baseURL, token string) *Client {
|
func NewClient(baseURL, token string) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
baseURL: strings.TrimRight(baseURL, "/"),
|
baseURL: strings.TrimRight(baseURL, "/"),
|
||||||
token: token,
|
token: token,
|
||||||
http: newSafeHTTPClient(),
|
http: &http.Client{Timeout: 30 * time.Second},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithUnsafeDialer returns the client configured with a plain HTTP client that
|
|
||||||
// has no IP-level SSRF protection. It preserves the redirect-rejection policy.
|
|
||||||
//
|
|
||||||
// This MUST only be used in tests. Production code must never call this method.
|
|
||||||
func (c *Client) WithUnsafeDialer() *Client {
|
|
||||||
c.http = &http.Client{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
CheckRedirect: defaultCheckRedirect,
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHTTPClient sets the underlying HTTP client used for requests.
|
// SetHTTPClient sets the underlying HTTP client used for requests.
|
||||||
// This is intended for test setup only to inject mock transports; it must be
|
// This is intended for testing to inject mock transports.
|
||||||
// called before any goroutines issue requests.
|
|
||||||
//
|
|
||||||
// Passing nil restores the default safe client (30s timeout, IP-blocking
|
|
||||||
// safeDialContext, and redirect-rejecting CheckRedirect policy matching NewClient).
|
|
||||||
//
|
|
||||||
// Callers providing a non-nil client are responsible for configuring a safe
|
|
||||||
// CheckRedirect policy. Without one, the default net/http behavior will follow
|
|
||||||
// redirects and may forward the Authorization header to untrusted hosts.
|
|
||||||
func (c *Client) SetHTTPClient(hc *http.Client) {
|
func (c *Client) SetHTTPClient(hc *http.Client) {
|
||||||
if hc == nil {
|
|
||||||
hc = newSafeHTTPClient()
|
|
||||||
}
|
|
||||||
c.http = hc
|
c.http = hc
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,28 +125,9 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||||
// It enforces MaxDiffSize to prevent unbounded memory allocation.
|
|
||||||
// Returns ErrDiffTooLarge if the diff exceeds the configured limit.
|
|
||||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||||
|
body, err := c.doGet(ctx, reqURL)
|
||||||
maxSize := c.MaxDiffSize
|
|
||||||
if maxSize == 0 {
|
|
||||||
maxSize = DefaultMaxDiffSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// When the limit is disabled (negative) or set to math.MaxInt64 (which
|
|
||||||
// would overflow the +1 detection and silently disable enforcement),
|
|
||||||
// use the standard unlimited doGet path.
|
|
||||||
if maxSize < 0 || maxSize == math.MaxInt64 {
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("fetch diff: %w", err)
|
|
||||||
}
|
|
||||||
return string(body), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := c.doGetLimited(ctx, reqURL, maxSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("fetch diff: %w", err)
|
return "", fmt.Errorf("fetch diff: %w", err)
|
||||||
}
|
}
|
||||||
@@ -341,22 +183,18 @@ func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, r
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PostReview submits a review to a PR and returns the created review.
|
// PostReview submits a review to a PR and returns the created review.
|
||||||
// event should be one of "APPROVED", "REQUEST_CHANGES", or "COMMENT".
|
// event should be "APPROVED" or "REQUEST_CHANGES".
|
||||||
// commitID anchors the review to a specific commit SHA. If empty, Gitea
|
|
||||||
// defaults to the current PR head.
|
|
||||||
// comments are optional inline comments attached to specific lines.
|
// comments are optional inline comments attached to specific lines.
|
||||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []ReviewComment) (*Review, error) {
|
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body string, comments []ReviewComment) (*Review, error) {
|
||||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||||
|
|
||||||
payload := struct {
|
payload := struct {
|
||||||
Body string `json:"body"`
|
Body string `json:"body"`
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
CommitID string `json:"commit_id,omitempty"`
|
|
||||||
Comments []ReviewComment `json:"comments,omitempty"`
|
Comments []ReviewComment `json:"comments,omitempty"`
|
||||||
}{
|
}{
|
||||||
Body: body,
|
Body: body,
|
||||||
Event: event,
|
Event: event,
|
||||||
CommitID: commitID,
|
|
||||||
Comments: comments,
|
Comments: comments,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -454,9 +292,9 @@ func isRetriableSyscallError(err error) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// redactURL strips query parameters and userinfo credentials from a URL for
|
// redactURL strips query parameters from a URL for safe logging.
|
||||||
// safe logging. This prevents accidental exposure of sensitive data (tokens in
|
// This prevents accidental exposure of sensitive data that future callers
|
||||||
// query strings, or user:pass in the authority) in log output.
|
// might pass via query strings.
|
||||||
func redactURL(rawURL string) string {
|
func redactURL(rawURL string) string {
|
||||||
parsed, err := url.Parse(rawURL)
|
parsed, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -464,9 +302,6 @@ func redactURL(rawURL string) string {
|
|||||||
// potentially logging something sensitive.
|
// potentially logging something sensitive.
|
||||||
return "[invalid URL]"
|
return "[invalid URL]"
|
||||||
}
|
}
|
||||||
if parsed.User != nil {
|
|
||||||
parsed.User = url.User("REDACTED")
|
|
||||||
}
|
|
||||||
if parsed.RawQuery != "" {
|
if parsed.RawQuery != "" {
|
||||||
parsed.RawQuery = "[redacted]"
|
parsed.RawQuery = "[redacted]"
|
||||||
}
|
}
|
||||||
@@ -487,12 +322,10 @@ func sanitizeErrorForLog(err error) string {
|
|||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
// doGetWithReader performs an HTTP GET request with retry on 5xx errors and
|
// doGet performs an HTTP GET request with retry on 5xx errors and temporary
|
||||||
// temporary network errors. Retries up to 3 times with exponential backoff
|
// network errors. Retries up to 3 times with exponential backoff (1s, 2s delays
|
||||||
// (1s, 2s delays by default; configurable via Client.RetryBackoff for testing).
|
// by default; configurable via Client.RetryBackoff for testing).
|
||||||
// The readBody function is called with the response body on success (2xx) and
|
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||||
// is responsible for reading and closing it.
|
|
||||||
func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody func(io.ReadCloser) ([]byte, error)) ([]byte, error) {
|
|
||||||
const maxAttempts = 3
|
const maxAttempts = 3
|
||||||
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
|
// backoff[i] is the delay before attempt i+1 (i.e., after attempt i fails).
|
||||||
// First attempt (i=0) has no delay; retries wait 1s then 2s by default.
|
// First attempt (i=0) has no delay; retries wait 1s then 2s by default.
|
||||||
@@ -557,7 +390,12 @@ func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody fu
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
return readBody(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error path: limit how much we read from potentially malicious server
|
// Error path: limit how much we read from potentially malicious server
|
||||||
@@ -575,39 +413,6 @@ func (c *Client) doGetWithReader(ctx context.Context, reqURL string, readBody fu
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// doGet performs an HTTP GET request with retry, reading the full response body.
|
|
||||||
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
|
||||||
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
|
|
||||||
defer body.Close()
|
|
||||||
return io.ReadAll(body)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// doGetLimited performs an HTTP GET request with retry but enforces a maximum
|
|
||||||
// response body size. Returns ErrDiffTooLarge if the response exceeds maxBytes.
|
|
||||||
// It reads maxBytes+1 (clamped to avoid overflow) to detect truncation without
|
|
||||||
// buffering the entire body.
|
|
||||||
func (c *Client) doGetLimited(ctx context.Context, reqURL string, maxBytes int64) ([]byte, error) {
|
|
||||||
return c.doGetWithReader(ctx, reqURL, func(body io.ReadCloser) ([]byte, error) {
|
|
||||||
defer body.Close()
|
|
||||||
// Read up to maxBytes+1 to detect overflow.
|
|
||||||
// Clamp to prevent integer overflow when maxBytes == math.MaxInt64.
|
|
||||||
limitBytes := maxBytes + 1
|
|
||||||
if limitBytes <= 0 {
|
|
||||||
limitBytes = math.MaxInt64
|
|
||||||
}
|
|
||||||
limited := io.LimitReader(body, limitBytes)
|
|
||||||
data, err := io.ReadAll(limited)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if int64(len(data)) > maxBytes {
|
|
||||||
return nil, fmt.Errorf("%w: response exceeds %d bytes", ErrDiffTooLarge, maxBytes)
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||||
// Slashes are preserved as path separators; other special characters are escaped.
|
// Slashes are preserved as path separators; other special characters are escaped.
|
||||||
// Input should be a relative path (no leading slash). Already-encoded segments
|
// Input should be a relative path (no leading slash). Already-encoded segments
|
||||||
@@ -1026,3 +831,15 @@ func (c *Client) ResolveComment(ctx context.Context, owner, repo string, comment
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DismissReview dismisses a review on a pull request.
|
||||||
|
// This is a stub for the vcs.Reviewer interface; full implementation is Phase 2.
|
||||||
|
func (c *Client) DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error {
|
||||||
|
return fmt.Errorf("dismiss review %d on %s/%s#%d: %w", reviewID, owner, repo, number, errors.ErrUnsupported)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileContentAtRef fetches a file at a specific ref from a repo.
|
||||||
|
// This delegates to GetFileContentRef for the Gitea implementation.
|
||||||
|
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||||
|
return c.GetFileContentRef(ctx, owner, repo, path, ref)
|
||||||
|
}
|
||||||
|
|||||||
+42
-395
@@ -9,7 +9,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -36,7 +35,7 @@ func TestGetPullRequest(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
got, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
got, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -63,7 +62,7 @@ func TestGetPullRequestDiff(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 5)
|
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -88,7 +87,7 @@ func TestGetCommitStatuses(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
got, err := client.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
got, err := client.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -117,9 +116,8 @@ func TestPostReview(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var payload struct {
|
var payload struct {
|
||||||
Body string `json:"body"`
|
Body string `json:"body"`
|
||||||
Event string `json:"event"`
|
Event string `json:"event"`
|
||||||
CommitID string `json:"commit_id"`
|
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||||
t.Fatalf("failed to decode payload: %v", err)
|
t.Fatalf("failed to decode payload: %v", err)
|
||||||
@@ -130,16 +128,14 @@ func TestPostReview(t *testing.T) {
|
|||||||
if payload.Event != "APPROVED" {
|
if payload.Event != "APPROVED" {
|
||||||
t.Errorf("expected event %q, got %q", "APPROVED", payload.Event)
|
t.Errorf("expected event %q, got %q", "APPROVED", payload.Event)
|
||||||
}
|
}
|
||||||
if payload.CommitID != "abc123def" {
|
|
||||||
t.Errorf("expected commit_id %q, got %q", "abc123def", payload.CommitID)
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`{"id":100,"user":{"login":"review-bot"},"state":"APPROVED","stale":false}`))
|
w.Write([]byte(`{"id":100,"user":{"login":"review-bot"},"state":"APPROVED","stale":false}`))
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
review, err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM", "abc123def", nil)
|
review, err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -158,7 +154,7 @@ func TestGetPullRequest_Non200(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 999)
|
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 999)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 404, got nil")
|
t.Fatal("expected error for 404, got nil")
|
||||||
@@ -171,7 +167,7 @@ func TestGetPullRequest_BadJSON(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for bad JSON, got nil")
|
t.Fatal("expected error for bad JSON, got nil")
|
||||||
@@ -185,36 +181,13 @@ func TestPostReview_Non200(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test", "", nil)
|
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test", nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 403, got nil")
|
t.Fatal("expected error for 403, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostReview_EmptyCommitID_OmittedFromPayload(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, _ := io.ReadAll(r.Body)
|
|
||||||
var raw map[string]interface{}
|
|
||||||
if err := json.Unmarshal(body, &raw); err != nil {
|
|
||||||
t.Fatalf("failed to decode payload: %v", err)
|
|
||||||
}
|
|
||||||
if _, exists := raw["commit_id"]; exists {
|
|
||||||
t.Errorf("expected commit_id to be omitted from payload when empty, but it was present")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte(`{"id":200,"user":{"login":"bot"},"state":"APPROVED","stale":false}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
|
||||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "ok", "", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetFileContent(t *testing.T) {
|
func TestGetFileContent(t *testing.T) {
|
||||||
expected := "# Conventions\n- Be nice\n"
|
expected := "# Conventions\n- Be nice\n"
|
||||||
|
|
||||||
@@ -226,7 +199,7 @@ func TestGetFileContent(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
got, err := client.GetFileContent(context.Background(), "owner", "repo", "CONVENTIONS.md")
|
got, err := client.GetFileContent(context.Background(), "owner", "repo", "CONVENTIONS.md")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -246,7 +219,7 @@ func TestGetPullRequestFiles(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
files, err := client.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
files, err := client.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -271,7 +244,7 @@ func TestGetFileContentRef(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
content, err := client.GetFileContentRef(context.Background(), "owner", "repo", "main.go", "feature-branch")
|
content, err := client.GetFileContentRef(context.Background(), "owner", "repo", "main.go", "feature-branch")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -291,7 +264,7 @@ func TestListContents(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "docs")
|
entries, err := client.ListContents(context.Background(), "owner", "repo", "docs")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -318,7 +291,7 @@ func TestListContents_DotPath(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
entries, err := client.ListContents(context.Background(), "owner", "repo", ".")
|
entries, err := client.ListContents(context.Background(), "owner", "repo", ".")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -343,7 +316,7 @@ func TestListContents_FilePath(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md")
|
entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -375,7 +348,7 @@ func TestGetAllFilesInPath_File(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -428,7 +401,7 @@ func TestListReviews(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -468,7 +441,7 @@ func TestListReviews_Pagination(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -493,7 +466,7 @@ func TestDeleteReview(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
@@ -507,7 +480,7 @@ func TestDeleteReview_Forbidden(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 403, got nil")
|
t.Fatal("expected error for 403, got nil")
|
||||||
@@ -536,7 +509,7 @@ func TestEditComment(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.EditComment(context.Background(), "owner", "repo", 42, "updated body")
|
err := client.EditComment(context.Background(), "owner", "repo", 42, "updated body")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("EditComment() error = %v", err)
|
t.Fatalf("EditComment() error = %v", err)
|
||||||
@@ -550,7 +523,7 @@ func TestEditComment_Forbidden(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.EditComment(context.Background(), "owner", "repo", 42, "new body")
|
err := client.EditComment(context.Background(), "owner", "repo", 42, "new body")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 403 response")
|
t.Fatal("expected error for 403 response")
|
||||||
@@ -570,7 +543,7 @@ func TestGetTimelineReviewCommentID(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
id, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
id, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetTimelineReviewCommentID() error = %v", err)
|
t.Fatalf("GetTimelineReviewCommentID() error = %v", err)
|
||||||
@@ -586,7 +559,7 @@ func TestGetTimelineReviewCommentID_NotFound(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
_, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "<!-- review-bot:sonnet -->")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error when sentinel not found")
|
t.Fatal("expected error when sentinel not found")
|
||||||
@@ -609,7 +582,7 @@ func TestGetAllFilesInPath_404FallsBackToFile(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected fallback to file on 404, got error: %v", err)
|
t.Fatalf("expected fallback to file on 404, got error: %v", err)
|
||||||
@@ -630,7 +603,7 @@ func TestGetAllFilesInPath_500Propagates(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "somepath")
|
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "somepath")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error to propagate for 500, got nil")
|
t.Fatal("expected error to propagate for 500, got nil")
|
||||||
@@ -652,7 +625,7 @@ func TestGetAllFilesInPath_403Propagates(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "private/stuff")
|
_, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "private/stuff")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error to propagate for 403, got nil")
|
t.Fatal("expected error to propagate for 403, got nil")
|
||||||
@@ -704,7 +677,7 @@ func TestGetAuthenticatedUser(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
login, err := client.GetAuthenticatedUser(context.Background())
|
login, err := client.GetAuthenticatedUser(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetAuthenticatedUser() error = %v", err)
|
t.Fatalf("GetAuthenticatedUser() error = %v", err)
|
||||||
@@ -729,7 +702,7 @@ func TestRequestReviewer(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 7, "bot-user")
|
err := client.RequestReviewer(context.Background(), "owner", "repo", 7, "bot-user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("RequestReviewer() error = %v", err)
|
t.Fatalf("RequestReviewer() error = %v", err)
|
||||||
@@ -745,7 +718,7 @@ func TestRequestReviewer_204(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("RequestReviewer() should accept 204, got error = %v", err)
|
t.Fatalf("RequestReviewer() should accept 204, got error = %v", err)
|
||||||
@@ -759,7 +732,7 @@ func TestRequestReviewer_Error(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 403 response")
|
t.Fatal("expected error for 403 response")
|
||||||
@@ -779,7 +752,7 @@ func TestListReviewComments(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
comments, err := client.ListReviewComments(context.Background(), "owner", "repo", 1, 42)
|
comments, err := client.ListReviewComments(context.Background(), "owner", "repo", 1, 42)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ListReviewComments() error = %v", err)
|
t.Fatalf("ListReviewComments() error = %v", err)
|
||||||
@@ -807,7 +780,7 @@ func TestResolveComment(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ResolveComment() error = %v", err)
|
t.Fatalf("ResolveComment() error = %v", err)
|
||||||
@@ -821,7 +794,7 @@ func TestResolveComment_Error(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
err := client.ResolveComment(context.Background(), "owner", "repo", 99)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 404 response")
|
t.Fatal("expected error for 404 response")
|
||||||
@@ -870,7 +843,7 @@ func TestDoGet_RetriesOn500(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
// Use short backoff for fast tests
|
// Use short backoff for fast tests
|
||||||
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
||||||
|
|
||||||
@@ -895,7 +868,7 @@ func TestDoGet_FailsAfterMaxRetries(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
// Use short backoff for fast tests
|
// Use short backoff for fast tests
|
||||||
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond}
|
||||||
|
|
||||||
@@ -924,7 +897,7 @@ func TestDoGet_NoRetryOn4xx(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.doGet(context.Background(), server.URL+"/test")
|
_, err := client.doGet(context.Background(), server.URL+"/test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 403")
|
t.Fatal("expected error for 403")
|
||||||
@@ -952,7 +925,7 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) {
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
// Use longer backoff to give us time to cancel during the wait
|
// Use longer backoff to give us time to cancel during the wait
|
||||||
client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond}
|
client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond}
|
||||||
|
|
||||||
@@ -972,6 +945,7 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// mockTransport is a test helper that returns errors for the first N calls,
|
// mockTransport is a test helper that returns errors for the first N calls,
|
||||||
// then delegates to a real server.
|
// then delegates to a real server.
|
||||||
type mockTransport struct {
|
type mockTransport struct {
|
||||||
@@ -1118,21 +1092,6 @@ func TestRedactURL(t *testing.T) {
|
|||||||
input: "",
|
input: "",
|
||||||
want: "",
|
want: "",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "with userinfo - redacts credentials",
|
|
||||||
input: "https://admin:secret@gitea.example.com/api/v1/repos",
|
|
||||||
want: "https://REDACTED@gitea.example.com/api/v1/repos",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with userinfo and query params",
|
|
||||||
input: "https://user:pass@example.com/path?token=abc",
|
|
||||||
want: "https://REDACTED@example.com/path?[redacted]",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "username only - no password",
|
|
||||||
input: "https://user@example.com/path",
|
|
||||||
want: "https://REDACTED@example.com/path",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
@@ -1185,315 +1144,3 @@ func TestSanitizeErrorForLog(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClient_HasCheckRedirect(t *testing.T) {
|
|
||||||
c := NewClient("https://gitea.example.com", "token")
|
|
||||||
if c.http.CheckRedirect == nil {
|
|
||||||
t.Fatal("expected CheckRedirect to be set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultCheckRedirect_RejectsHTTPSToHTTP(t *testing.T) {
|
|
||||||
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
|
||||||
req := &http.Request{
|
|
||||||
URL: &url.URL{Scheme: "http", Host: "gitea.example.com", Path: "/foo"},
|
|
||||||
Header: http.Header{"Authorization": []string{"token abc"}},
|
|
||||||
}
|
|
||||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error on HTTPS->HTTP redirect")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "HTTPS to HTTP downgrade") {
|
|
||||||
t.Errorf("unexpected error message: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultCheckRedirect_RejectsCrossHost(t *testing.T) {
|
|
||||||
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
|
||||||
req := &http.Request{
|
|
||||||
URL: &url.URL{Scheme: "https", Host: "cdn.example.com", Path: "/bar"},
|
|
||||||
Header: http.Header{"Authorization": []string{"token abc"}},
|
|
||||||
}
|
|
||||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error on cross-host redirect")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "cross-host") {
|
|
||||||
t.Errorf("unexpected error message: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultCheckRedirect_AllowsSameHost(t *testing.T) {
|
|
||||||
prev := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
|
||||||
req := &http.Request{
|
|
||||||
URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/bar"},
|
|
||||||
Header: http.Header{"Authorization": []string{"token abc"}},
|
|
||||||
}
|
|
||||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if auth := req.Header.Get("Authorization"); auth != "token abc" {
|
|
||||||
t.Errorf("expected Authorization to be preserved, got %q", auth)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultCheckRedirect_AllowsSameHostHTTPToHTTP(t *testing.T) {
|
|
||||||
prev := &http.Request{URL: &url.URL{Scheme: "http", Host: "localhost:3000", Path: "/foo"}}
|
|
||||||
req := &http.Request{
|
|
||||||
URL: &url.URL{Scheme: "http", Host: "localhost:3000", Path: "/bar"},
|
|
||||||
Header: http.Header{},
|
|
||||||
}
|
|
||||||
err := defaultCheckRedirect(req, []*http.Request{prev})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultCheckRedirect_RejectsTooManyRedirects(t *testing.T) {
|
|
||||||
via := make([]*http.Request, 10)
|
|
||||||
for i := range via {
|
|
||||||
via[i] = &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/"}}
|
|
||||||
}
|
|
||||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/final"}}
|
|
||||||
err := defaultCheckRedirect(req, via)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error after 10 redirects")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "10 redirects") {
|
|
||||||
t.Errorf("unexpected error message: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultCheckRedirect_EmptyViaAllowed(t *testing.T) {
|
|
||||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "gitea.example.com", Path: "/foo"}}
|
|
||||||
err := defaultCheckRedirect(req, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error with empty via: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetHTTPClient_NilRestoresDefault(t *testing.T) {
|
|
||||||
c := NewClient("https://gitea.example.com", "token")
|
|
||||||
c.SetHTTPClient(nil)
|
|
||||||
if c.http == nil {
|
|
||||||
t.Fatal("expected non-nil http client after SetHTTPClient(nil)")
|
|
||||||
}
|
|
||||||
if c.http.Timeout != 30*time.Second {
|
|
||||||
t.Errorf("expected 30s timeout, got %v", c.http.Timeout)
|
|
||||||
}
|
|
||||||
if c.http.CheckRedirect == nil {
|
|
||||||
t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSafeDialContextBlocksPrivateIPs verifies that NewClient (which uses
|
|
||||||
// safeDialContext by default) refuses to connect to private/reserved IPs.
|
|
||||||
func TestSafeDialContextBlocksPrivateIPs(t *testing.T) {
|
|
||||||
// These servers listen on 127.0.0.1, so the safe dialer will block them.
|
|
||||||
// We use NewClient (NOT NewTestClient) to exercise the real safe dialer.
|
|
||||||
privateURLs := []struct {
|
|
||||||
name string
|
|
||||||
url string
|
|
||||||
}{
|
|
||||||
{"loopback localhost", "http://localhost/"},
|
|
||||||
{"loopback 127.0.0.1", "http://127.0.0.1/"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range privateURLs {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
c := NewClient(tc.url, "token")
|
|
||||||
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("expected error connecting to %s, got nil", tc.url)
|
|
||||||
}
|
|
||||||
// Error must mention SSRF/blocked, not a random network error.
|
|
||||||
if !strings.Contains(err.Error(), "blocked") &&
|
|
||||||
!strings.Contains(err.Error(), "private") &&
|
|
||||||
!strings.Contains(err.Error(), "loopback") &&
|
|
||||||
!strings.Contains(err.Error(), "reserved") {
|
|
||||||
t.Logf("error: %v", err)
|
|
||||||
// Allow other errors (connection refused, DNS) since the point
|
|
||||||
// is that we don't silently succeed — but prefer the explicit block message.
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestWithUnsafeDialerAllowsLocalhost verifies that WithUnsafeDialer bypasses
|
|
||||||
// the IP check, allowing tests to connect to httptest.Server (127.0.0.1).
|
|
||||||
func TestWithUnsafeDialerAllowsLocalhost(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Write([]byte(`{"title":"test","body":"","head":{"sha":"abc","ref":"main"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
// WithUnsafeDialer should allow connecting to 127.0.0.1.
|
|
||||||
c := NewClient(server.URL, "token").WithUnsafeDialer()
|
|
||||||
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error with unsafe dialer: %v", err)
|
|
||||||
}
|
|
||||||
if pr.Title != "test" {
|
|
||||||
t.Errorf("expected title 'test', got %q", pr.Title)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNewClient_HasSafeTransport verifies that NewClient installs the
|
|
||||||
// SSRF-blocking transport (i.e. Transport is not nil and DialContext is set).
|
|
||||||
func TestNewClient_HasSafeTransport(t *testing.T) {
|
|
||||||
c := NewClient("https://gitea.example.com", "token")
|
|
||||||
if c.http.Transport == nil {
|
|
||||||
t.Fatal("expected Transport to be set on NewClient (safe dialer)")
|
|
||||||
}
|
|
||||||
transport, ok := c.http.Transport.(*http.Transport)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *http.Transport, got %T", c.http.Transport)
|
|
||||||
}
|
|
||||||
if transport.DialContext == nil {
|
|
||||||
t.Fatal("expected DialContext to be set on transport (safe dialer)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestSetHTTPClient_NilRestoresSafeTransport verifies that SetHTTPClient(nil)
|
|
||||||
// restores the safe transport (not just any client).
|
|
||||||
func TestSetHTTPClient_NilRestoresSafeTransport(t *testing.T) {
|
|
||||||
c := NewClient("https://gitea.example.com", "token")
|
|
||||||
c.SetHTTPClient(&http.Client{}) // replace with plain client
|
|
||||||
c.SetHTTPClient(nil) // restore
|
|
||||||
transport, ok := c.http.Transport.(*http.Transport)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *http.Transport after SetHTTPClient(nil), got %T", c.http.Transport)
|
|
||||||
}
|
|
||||||
if transport.DialContext == nil {
|
|
||||||
t.Fatal("expected DialContext to be restored after SetHTTPClient(nil)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNewSafeHTTPClient_PreservesDefaultTransportSettings verifies that
|
|
||||||
// newSafeHTTPClient clones http.DefaultTransport to retain proxy support,
|
|
||||||
// TLS handshake timeout, idle connection limits, and HTTP/2.
|
|
||||||
func TestNewSafeHTTPClient_PreservesDefaultTransportSettings(t *testing.T) {
|
|
||||||
c := NewClient("https://gitea.example.com", "token")
|
|
||||||
transport, ok := c.http.Transport.(*http.Transport)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected *http.Transport, got %T", c.http.Transport)
|
|
||||||
}
|
|
||||||
|
|
||||||
defaults := http.DefaultTransport.(*http.Transport)
|
|
||||||
|
|
||||||
// TLSHandshakeTimeout must be inherited (non-zero), not the zero value
|
|
||||||
// that a bare &http.Transport{} would have.
|
|
||||||
if transport.TLSHandshakeTimeout == 0 {
|
|
||||||
t.Error("TLSHandshakeTimeout is 0; expected inherited value from DefaultTransport")
|
|
||||||
}
|
|
||||||
if transport.TLSHandshakeTimeout != defaults.TLSHandshakeTimeout {
|
|
||||||
t.Errorf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaults.TLSHandshakeTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IdleConnTimeout must be inherited.
|
|
||||||
if transport.IdleConnTimeout == 0 {
|
|
||||||
t.Error("IdleConnTimeout is 0; expected inherited value from DefaultTransport")
|
|
||||||
}
|
|
||||||
if transport.IdleConnTimeout != defaults.IdleConnTimeout {
|
|
||||||
t.Errorf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaults.IdleConnTimeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaxIdleConns must be inherited.
|
|
||||||
if transport.MaxIdleConns == 0 {
|
|
||||||
t.Error("MaxIdleConns is 0; expected inherited value from DefaultTransport")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForceAttemptHTTP2 must be inherited.
|
|
||||||
if !transport.ForceAttemptHTTP2 {
|
|
||||||
t.Error("ForceAttemptHTTP2 is false; expected true from DefaultTransport")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Proxy must be set (ProxyFromEnvironment).
|
|
||||||
if transport.Proxy == nil {
|
|
||||||
t.Error("Proxy is nil; expected ProxyFromEnvironment from DefaultTransport")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext must be our safe dialer, not the default.
|
|
||||||
if transport.DialContext == nil {
|
|
||||||
t.Error("DialContext is nil; expected safeDialContext")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTimelineReviewCommentIDForReview(t *testing.T) {
|
|
||||||
const reviewID = int64(42)
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.URL.Path {
|
|
||||||
case "/api/v1/repos/owner/repo/pulls/5/reviews/42":
|
|
||||||
w.Write([]byte(`{"body": "The review body <!-- review-bot:sonnet -->", "user": {"login": "sonnet-review"}}`))
|
|
||||||
case "/api/v1/repos/owner/repo/issues/5/timeline":
|
|
||||||
w.Write([]byte(`[
|
|
||||||
{"id": 100, "type": "comment", "body": "unrelated", "user": {"login": "sonnet-review"}},
|
|
||||||
{"id": 200, "type": "review", "body": "The review body <!-- review-bot:sonnet -->", "user": {"login": "sonnet-review"}}
|
|
||||||
]`))
|
|
||||||
default:
|
|
||||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
|
||||||
id, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, reviewID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetTimelineReviewCommentIDForReview() error = %v", err)
|
|
||||||
}
|
|
||||||
if id != 200 {
|
|
||||||
t.Errorf("got id=%d, want 200", id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTimelineReviewCommentIDForReview_ReviewFetchError(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
w.Write([]byte(`{"message":"not found"}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
|
||||||
_, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, 99)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for missing review, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTimelineReviewCommentIDForReview_EmptyBody(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte(`{"body": "", "user": {"login": "bot"}}`))
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
|
||||||
_, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, 42)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for empty body, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "empty body") {
|
|
||||||
t.Errorf("error = %q, want to contain 'empty body'", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetTimelineReviewCommentIDForReview_NotFoundInTimeline(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.URL.Path {
|
|
||||||
case "/api/v1/repos/owner/repo/pulls/5/reviews/42":
|
|
||||||
w.Write([]byte(`{"body": "review content <!-- review-bot:sonnet -->", "user": {"login": "bot"}}`))
|
|
||||||
default:
|
|
||||||
// Timeline returns events that don't match (different user)
|
|
||||||
w.Write([]byte(`[{"id": 1, "type": "review", "body": "review content <!-- review-bot:sonnet -->", "user": {"login": "other-user"}}]`))
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
|
||||||
_, err := client.GetTimelineReviewCommentIDForReview(context.Background(), "owner", "repo", 5, 42)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error when review not found in timeline, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build phase2
|
||||||
|
|
||||||
|
package gitea_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compile-time interface conformance assertions.
|
||||||
|
// These will verify gitea.Client satisfies vcs interfaces once the Phase 2
|
||||||
|
// adapter bridges the method signature gaps:
|
||||||
|
//
|
||||||
|
// - PRReader: GetPullRequest returns *gitea.PullRequest (needs *vcs.PullRequest)
|
||||||
|
// - PRReader: GetPullRequestFiles returns []gitea.ChangedFile (needs []vcs.ChangedFile)
|
||||||
|
// - FileReader: GetFileContent lacks ref parameter
|
||||||
|
// - Reviewer: PostReview uses (event, body, comments) instead of vcs.ReviewRequest
|
||||||
|
//
|
||||||
|
// Remove the phase2 build tag once the adapter is complete.
|
||||||
|
var (
|
||||||
|
_ vcs.PRReader = (*gitea.Client)(nil)
|
||||||
|
_ vcs.FileReader = (*gitea.Client)(nil)
|
||||||
|
_ vcs.Reviewer = (*gitea.Client)(nil)
|
||||||
|
_ vcs.Identity = (*gitea.Client)(nil)
|
||||||
|
)
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
package gitea
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetPullRequestDiff_SizeLimits(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
diff string
|
|
||||||
maxDiffSize int64
|
|
||||||
wantErr error
|
|
||||||
wantDiff string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "exceeds max size",
|
|
||||||
diff: strings.Repeat("+ added line\n", 1000), // ~13 KB
|
|
||||||
maxDiffSize: 100,
|
|
||||||
wantErr: ErrDiffTooLarge,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "within max size",
|
|
||||||
diff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
|
|
||||||
maxDiffSize: 1024,
|
|
||||||
wantDiff: "diff --git a/f.go b/f.go\n--- a/f.go\n+++ b/f.go\n@@ -1 +1 @@\n-old\n+new\n",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "exactly at limit",
|
|
||||||
diff: strings.Repeat("x", 50),
|
|
||||||
maxDiffSize: 50,
|
|
||||||
wantDiff: strings.Repeat("x", 50),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "one byte over limit",
|
|
||||||
diff: strings.Repeat("x", 51),
|
|
||||||
maxDiffSize: 50,
|
|
||||||
wantErr: ErrDiffTooLarge,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "disabled limit",
|
|
||||||
diff: strings.Repeat("x", 10000),
|
|
||||||
maxDiffSize: -1,
|
|
||||||
wantDiff: strings.Repeat("x", 10000),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "math.MaxInt64 treated as disabled",
|
|
||||||
diff: strings.Repeat("x", 10000),
|
|
||||||
maxDiffSize: math.MaxInt64,
|
|
||||||
wantDiff: strings.Repeat("x", 10000),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "default limit",
|
|
||||||
diff: "diff content",
|
|
||||||
maxDiffSize: 0, // zero means use DefaultMaxDiffSize
|
|
||||||
wantDiff: "diff content",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte(tt.diff)) //nolint:errcheck // test handler
|
|
||||||
}))
|
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
|
||||||
client.MaxDiffSize = tt.maxDiffSize
|
|
||||||
client.RetryBackoff = []time.Duration{}
|
|
||||||
|
|
||||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
|
||||||
|
|
||||||
if tt.wantErr != nil {
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, tt.wantErr) {
|
|
||||||
t.Errorf("expected %v, got: %v", tt.wantErr, err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if got != tt.wantDiff {
|
|
||||||
t.Errorf("diff mismatch: got length %d, want length %d", len(got), len(tt.wantDiff))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
// Package gitea — export_test.go exposes test helpers to test files in this
|
|
||||||
// package. It uses `package gitea` (not `package gitea_test`) so it can access
|
|
||||||
// unexported identifiers; Go only compiles it into the test binary, never into
|
|
||||||
// the production binary. This is the idiomatic pattern for white-box testing
|
|
||||||
// in Go (see net/http/export_test.go in the stdlib for the same approach).
|
|
||||||
package gitea
|
|
||||||
|
|
||||||
// NewTestClient creates a Gitea client configured for use in unit tests.
|
|
||||||
// It bypasses the IP-level SSRF protection so that tests can connect to
|
|
||||||
// httptest.Server instances (which listen on 127.0.0.1).
|
|
||||||
//
|
|
||||||
// Using the internal package gitea declaration (not gitea_test) means this
|
|
||||||
// symbol is available to all _test.go files in this package. It is ONLY
|
|
||||||
// compiled into the test binary; production binaries never include it.
|
|
||||||
// Production code must use NewClient, which enables the safe dialer.
|
|
||||||
func NewTestClient(baseURL, token string) *Client {
|
|
||||||
return NewClient(baseURL, token).WithUnsafeDialer()
|
|
||||||
}
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
// Package gitea provides a client for the Gitea API.
|
|
||||||
// ipcheck.go implements IP-level SSRF protection by checking resolved addresses
|
|
||||||
// against known blocked CIDR ranges (RFC1918, loopback, link-local, etc.).
|
|
||||||
package gitea
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// blockedCIDRStrings is the canonical list of CIDR strings that should never
|
|
||||||
// be contacted by review-bot. See IsBlockedIP for the full list of covered
|
|
||||||
// address families.
|
|
||||||
//
|
|
||||||
// These are hard-coded literals: any parse failure is a programming error.
|
|
||||||
// Validity is verified by TestBlockedCIDRsValid in ipcheck_test.go.
|
|
||||||
var blockedCIDRStrings = []string{
|
|
||||||
// IPv4 loopback
|
|
||||||
"127.0.0.0/8",
|
|
||||||
// IPv4 unspecified / "this network"
|
|
||||||
"0.0.0.0/8",
|
|
||||||
// RFC1918 private ranges
|
|
||||||
"10.0.0.0/8",
|
|
||||||
"172.16.0.0/12",
|
|
||||||
"192.168.0.0/16",
|
|
||||||
// IPv4 link-local (APIPA, also used by AWS instance metadata 169.254.169.254)
|
|
||||||
"169.254.0.0/16",
|
|
||||||
// IPv4 shared address space (RFC6598, carrier-grade NAT)
|
|
||||||
"100.64.0.0/10",
|
|
||||||
// IPv4 multicast
|
|
||||||
"224.0.0.0/4",
|
|
||||||
// IPv4 reserved / broadcast
|
|
||||||
"240.0.0.0/4",
|
|
||||||
// IPv6 loopback
|
|
||||||
"::1/128",
|
|
||||||
// IPv6 unspecified
|
|
||||||
"::/128",
|
|
||||||
// IPv6 link-local
|
|
||||||
"fe80::/10",
|
|
||||||
// IPv6 unique local (ULA) — RFC4193
|
|
||||||
"fc00::/7",
|
|
||||||
// IPv6 multicast
|
|
||||||
"ff00::/8",
|
|
||||||
}
|
|
||||||
|
|
||||||
// blockedCIDRs is the parsed form of blockedCIDRStrings.
|
|
||||||
// Any entry that fails to parse is recorded in blockedCIDRParseErrors instead
|
|
||||||
// of panicking; tests verify this slice is always empty via TestBlockedCIDRsValid.
|
|
||||||
var (
|
|
||||||
blockedCIDRs []*net.IPNet
|
|
||||||
blockedCIDRParseErrors []string
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
blockedCIDRs = make([]*net.IPNet, 0, len(blockedCIDRStrings))
|
|
||||||
for _, r := range blockedCIDRStrings {
|
|
||||||
_, cidr, err := net.ParseCIDR(r)
|
|
||||||
if err != nil {
|
|
||||||
// Record the error rather than panicking; TestBlockedCIDRsValid
|
|
||||||
// will catch this during tests, and the CI build will fail.
|
|
||||||
blockedCIDRParseErrors = append(blockedCIDRParseErrors,
|
|
||||||
fmt.Sprintf("ipcheck: invalid built-in CIDR %q: %v", r, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
blockedCIDRs = append(blockedCIDRs, cidr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsBlockedIP reports whether ip is in a blocked address range.
|
|
||||||
// It is exported for use by the validate-url subcommand and tests outside
|
|
||||||
// this package.
|
|
||||||
//
|
|
||||||
// IPv6-mapped IPv4 addresses (e.g. ::ffff:192.168.1.1) are normalized to their
|
|
||||||
// IPv4 form before checking so that IPv4 CIDRs catch them.
|
|
||||||
//
|
|
||||||
// Based on:
|
|
||||||
// - RFC1918 private ranges
|
|
||||||
// - RFC5735 / RFC4193 special-use IPv4/IPv6 ranges
|
|
||||||
// - RFC4291 IPv6 link-local / loopback
|
|
||||||
func IsBlockedIP(ip net.IP) bool {
|
|
||||||
// Normalize IPv6-mapped IPv4 addresses (::ffff:x.x.x.x) to plain IPv4.
|
|
||||||
if v4 := ip.To4(); v4 != nil {
|
|
||||||
ip = v4
|
|
||||||
}
|
|
||||||
for _, cidr := range blockedCIDRs {
|
|
||||||
if cidr.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
package gitea
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIsBlockedIP(t *testing.T) {
|
|
||||||
blocked := []struct {
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
// IPv4 loopback
|
|
||||||
{"loopback 127.0.0.1", "127.0.0.1"},
|
|
||||||
{"loopback 127.0.0.2", "127.0.0.2"},
|
|
||||||
{"loopback 127.255.255.255", "127.255.255.255"},
|
|
||||||
// IPv4 unspecified
|
|
||||||
{"unspecified 0.0.0.0", "0.0.0.0"},
|
|
||||||
{"unspecified 0.1.2.3", "0.1.2.3"},
|
|
||||||
// RFC1918
|
|
||||||
{"RFC1918 10.0.0.1", "10.0.0.1"},
|
|
||||||
{"RFC1918 10.255.255.255", "10.255.255.255"},
|
|
||||||
{"RFC1918 172.16.0.1", "172.16.0.1"},
|
|
||||||
{"RFC1918 172.31.255.255", "172.31.255.255"},
|
|
||||||
{"RFC1918 192.168.0.1", "192.168.0.1"},
|
|
||||||
{"RFC1918 192.168.255.255", "192.168.255.255"},
|
|
||||||
// Link-local (APIPA / AWS metadata)
|
|
||||||
{"link-local 169.254.0.1", "169.254.0.1"},
|
|
||||||
{"link-local 169.254.169.254", "169.254.169.254"},
|
|
||||||
// Shared address space (carrier-grade NAT)
|
|
||||||
{"CGN 100.64.0.1", "100.64.0.1"},
|
|
||||||
{"CGN 100.127.255.255", "100.127.255.255"},
|
|
||||||
// Multicast
|
|
||||||
{"multicast 224.0.0.1", "224.0.0.1"},
|
|
||||||
{"multicast 239.255.255.255", "239.255.255.255"},
|
|
||||||
// Reserved
|
|
||||||
{"reserved 240.0.0.1", "240.0.0.1"},
|
|
||||||
{"broadcast 255.255.255.255", "255.255.255.255"},
|
|
||||||
// IPv6 loopback
|
|
||||||
{"IPv6 loopback ::1", "::1"},
|
|
||||||
// IPv6 unspecified
|
|
||||||
{"IPv6 unspecified ::", "::"},
|
|
||||||
// IPv6 link-local
|
|
||||||
{"IPv6 link-local fe80::1", "fe80::1"},
|
|
||||||
{"IPv6 link-local fe80::dead:beef", "fe80::dead:beef"},
|
|
||||||
// IPv6 ULA
|
|
||||||
{"IPv6 ULA fc00::1", "fc00::1"},
|
|
||||||
{"IPv6 ULA fd00::1", "fd00::1"},
|
|
||||||
// IPv6 multicast
|
|
||||||
{"IPv6 multicast ff02::1", "ff02::1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range blocked {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
ip := net.ParseIP(tc.ip)
|
|
||||||
if ip == nil {
|
|
||||||
t.Fatalf("failed to parse IP %q", tc.ip)
|
|
||||||
}
|
|
||||||
if !IsBlockedIP(ip) {
|
|
||||||
t.Errorf("IsBlockedIP(%q) = false, want true", tc.ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
allowed := []struct {
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
}{
|
|
||||||
{"public 8.8.8.8", "8.8.8.8"},
|
|
||||||
{"public 1.1.1.1", "1.1.1.1"},
|
|
||||||
{"public 198.51.100.1", "198.51.100.1"}, // RFC5737 TEST-NET-2 — a documentation-only range;
|
|
||||||
// not assigned to any real host, but intentionally left unblocked here because
|
|
||||||
// it has no special routing treatment (unlike RFC1918/loopback/link-local) and
|
|
||||||
// blocking it would require tracking every RFC5737 range without meaningful
|
|
||||||
// security benefit (no server should ever listen on a TEST-NET address).
|
|
||||||
{"public 151.101.1.1", "151.101.1.1"}, // Fastly
|
|
||||||
{"public IPv6 2001:4860:4860::8888", "2001:4860:4860::8888"}, // Google DNS
|
|
||||||
{"public IPv6 2606:4700:4700::1111", "2606:4700:4700::1111"}, // Cloudflare DNS
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range allowed {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
ip := net.ParseIP(tc.ip)
|
|
||||||
if ip == nil {
|
|
||||||
t.Fatalf("failed to parse IP %q", tc.ip)
|
|
||||||
}
|
|
||||||
if IsBlockedIP(ip) {
|
|
||||||
t.Errorf("IsBlockedIP(%q) = true, want false", tc.ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsBlockedIPv6MappedIPv4(t *testing.T) {
|
|
||||||
// ::ffff:192.168.1.1 is an IPv6-mapped IPv4 address — should be blocked as RFC1918.
|
|
||||||
// Construct it manually as a 16-byte IP.
|
|
||||||
mapped := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1}
|
|
||||||
if !IsBlockedIP(mapped) {
|
|
||||||
t.Errorf("IsBlockedIP(::ffff:192.168.1.1) = false, want true (IPv6-mapped IPv4 must be normalized)")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ::ffff:8.8.8.8 — IPv6-mapped public IP — should be allowed.
|
|
||||||
mappedPublic := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 8, 8, 8, 8}
|
|
||||||
if IsBlockedIP(mappedPublic) {
|
|
||||||
t.Errorf("IsBlockedIP(::ffff:8.8.8.8) = true, want false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsBlockedIPEdgeCases(t *testing.T) {
|
|
||||||
// The boundary between RFC1918 and public ranges.
|
|
||||||
// 172.15.255.255 is NOT private (just below 172.16.0.0/12).
|
|
||||||
notPrivate := net.ParseIP("172.15.255.255")
|
|
||||||
if IsBlockedIP(notPrivate) {
|
|
||||||
t.Errorf("IsBlockedIP(172.15.255.255) = true, want false (outside 172.16.0.0/12)")
|
|
||||||
}
|
|
||||||
// 172.32.0.0 is NOT private (just above 172.31.255.255).
|
|
||||||
notPrivate2 := net.ParseIP("172.32.0.0")
|
|
||||||
if IsBlockedIP(notPrivate2) {
|
|
||||||
t.Errorf("IsBlockedIP(172.32.0.0) = true, want false (outside 172.16.0.0/12)")
|
|
||||||
}
|
|
||||||
// CGN: 100.63.255.255 is NOT in 100.64.0.0/10.
|
|
||||||
notCGN := net.ParseIP("100.63.255.255")
|
|
||||||
if IsBlockedIP(notCGN) {
|
|
||||||
t.Errorf("IsBlockedIP(100.63.255.255) = true, want false (outside 100.64.0.0/10)")
|
|
||||||
}
|
|
||||||
// CGN: 100.128.0.0 is NOT in 100.64.0.0/10.
|
|
||||||
notCGN2 := net.ParseIP("100.128.0.0")
|
|
||||||
if IsBlockedIP(notCGN2) {
|
|
||||||
t.Errorf("IsBlockedIP(100.128.0.0) = true, want false (outside 100.64.0.0/10)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBlockedCIDRsValid verifies that all entries in blockedCIDRStrings parse
|
|
||||||
// successfully. This catches programming errors in the CIDR list without
|
|
||||||
// requiring a startup panic. The init() function records parse failures in
|
|
||||||
// blockedCIDRParseErrors rather than panicking; this test makes those failures
|
|
||||||
// visible as test failures during CI.
|
|
||||||
func TestBlockedCIDRsValid(t *testing.T) {
|
|
||||||
if len(blockedCIDRParseErrors) > 0 {
|
|
||||||
for _, msg := range blockedCIDRParseErrors {
|
|
||||||
t.Errorf("CIDR parse error: %s", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -31,13 +31,13 @@ func TestPostReview_WithComments(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
comments := []ReviewComment{
|
comments := []ReviewComment{
|
||||||
{Path: "main.go", NewPosition: 42, Body: "[MAJOR] Something bad"},
|
{Path: "main.go", NewPosition: 42, Body: "[MAJOR] Something bad"},
|
||||||
{Path: "util.go", NewPosition: 10, Body: "[MINOR] Style issue"},
|
{Path: "util.go", NewPosition: 10, Body: "[MINOR] Style issue"},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "REQUEST_CHANGES", "summary", "", comments)
|
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "REQUEST_CHANGES", "summary", comments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -71,8 +71,8 @@ func TestPostReview_NilComments(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
client := NewTestClient(server.URL, "test-token")
|
client := NewClient(server.URL, "test-token")
|
||||||
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "all good", "", nil)
|
_, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "all good", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+129
-619
@@ -1,20 +1,15 @@
|
|||||||
// Package github provides a client for the GitHub API.
|
// Package github provides a client for the GitHub API.
|
||||||
// It supports pull request operations, file content retrieval,
|
// It supports pull request operations, file content retrieval, CI status checks,
|
||||||
// and review submission for both github.com and GitHub Enterprise.
|
// and directory listing for both github.com and GitHub Enterprise.
|
||||||
package github
|
package github
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -22,28 +17,17 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultBaseURL = "https://api.github.com"
|
defaultBaseURL = "https://api.github.com"
|
||||||
|
userAgent = "review-bot/1.0"
|
||||||
|
|
||||||
// maxRetryAttempts is the number of times doRequest will attempt a request.
|
// maxResponseBytes limits successful response body reads to 10 MiB.
|
||||||
maxRetryAttempts = 3
|
maxResponseBytes = 10 * 1024 * 1024
|
||||||
|
|
||||||
// maxRetryAfter caps the maximum delay from a Retry-After header to prevent
|
|
||||||
// a server from stalling the client indefinitely.
|
|
||||||
maxRetryAfter = 60 * time.Second
|
|
||||||
|
|
||||||
// maxErrorBodyBytes limits how much of an error response body we read
|
|
||||||
// to protect against malicious servers sending unbounded data.
|
|
||||||
maxErrorBodyBytes = 64 * 1024 // 64 KB
|
|
||||||
|
|
||||||
// maxResponseBodyBytes limits how much of a successful response body we read
|
|
||||||
// for defense-in-depth against servers returning excessively large payloads.
|
|
||||||
maxResponseBodyBytes = 10 * 1024 * 1024 // 10 MB
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// APIError represents an HTTP error response from the GitHub API.
|
// APIError represents an HTTP error response from the GitHub API.
|
||||||
// It carries the status code so callers can distinguish between
|
// It carries the status code so callers can distinguish between
|
||||||
// different failure modes (e.g. 404 vs 500).
|
// different failure modes (e.g. 404 vs 500).
|
||||||
//
|
//
|
||||||
// The Body field stores up to 64 KiB of the raw response for programmatic
|
// The Body field stores up to 4 KiB of the raw response for programmatic
|
||||||
// inspection. Error() truncates to 200 bytes for safe logging, but callers
|
// inspection. Error() truncates to 200 bytes for safe logging, but callers
|
||||||
// should avoid logging or propagating Body directly in production since it may
|
// should avoid logging or propagating Body directly in production since it may
|
||||||
// contain sensitive details from the upstream server.
|
// contain sensitive details from the upstream server.
|
||||||
@@ -63,6 +47,13 @@ func (e *APIError) Error() string {
|
|||||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
|
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SafeError returns the error string without response body content,
|
||||||
|
// suitable for logging in contexts where upstream response data should
|
||||||
|
// not be exposed.
|
||||||
|
func (e *APIError) SafeError() string {
|
||||||
|
return fmt.Sprintf("HTTP %d", e.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// IsNotFound reports whether an error is an API 404 response.
|
// IsNotFound reports whether an error is an API 404 response.
|
||||||
func IsNotFound(err error) bool {
|
func IsNotFound(err error) bool {
|
||||||
if apiErr, ok := asAPIError(err); ok {
|
if apiErr, ok := asAPIError(err); ok {
|
||||||
@@ -90,109 +81,85 @@ func asAPIError(err error) (*APIError, bool) {
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clientConfig holds optional configuration for NewClient.
|
||||||
|
type clientConfig struct {
|
||||||
|
allowInsecureHTTP bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientOption configures optional behavior of NewClient.
|
||||||
|
type ClientOption func(*clientConfig)
|
||||||
|
|
||||||
|
// AllowInsecureHTTP permits the client to use HTTP (non-TLS) base URLs.
|
||||||
|
// This should only be used for trusted internal deployments or testing.
|
||||||
|
func AllowInsecureHTTP() ClientOption {
|
||||||
|
return func(c *clientConfig) {
|
||||||
|
c.allowInsecureHTTP = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Client interacts with the GitHub API.
|
// Client interacts with the GitHub API.
|
||||||
// A Client is safe for concurrent use by multiple goroutines.
|
// A Client is safe for concurrent use by multiple goroutines.
|
||||||
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
|
// SetHTTPClient and SetRetryBackoff are intended for test setup only and must
|
||||||
// be called before any goroutines issue requests; they have no synchronization.
|
// be called before any goroutines issue requests; they have no synchronization.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
baseURL string
|
baseURL string
|
||||||
token string
|
token string
|
||||||
httpClient *http.Client
|
|
||||||
|
|
||||||
// allowInsecureHTTP permits requests to HTTP (non-TLS) endpoints.
|
|
||||||
// When false, doRequest rejects URLs with an http:// scheme.
|
|
||||||
allowInsecureHTTP bool
|
allowInsecureHTTP bool
|
||||||
|
httpClient *http.Client
|
||||||
|
|
||||||
// retryBackoff defines the delays between retry attempts for 429 responses.
|
// retryBackoff defines the delays between retry attempts for 429 responses.
|
||||||
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
|
// retryBackoff[i] is the delay before attempt i+1 (after attempt i fails).
|
||||||
// If nil, defaults to {1s, 2s}.
|
// If nil, defaults to {1s, 2s}. Set to shorter durations in tests via SetRetryBackoff.
|
||||||
retryBackoff []time.Duration
|
retryBackoff []time.Duration
|
||||||
|
|
||||||
// now returns the current time. Defaults to time.Now.
|
|
||||||
// Override in tests to control HTTP-date Retry-After calculations.
|
|
||||||
now func() time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultCheckRedirect is the redirect policy used by NewClient.
|
// defaultCheckRedirect is the redirect policy used by NewClient and SetHTTPClient(nil).
|
||||||
// NOTE: This function is intentionally duplicated in gitea/client.go (and vice versa)
|
// It rejects HTTPS→HTTP protocol downgrades (to prevent plaintext leakage) and strips
|
||||||
// because the packages are separate. Changes here must be mirrored there.
|
// the Authorization header on cross-host redirects to prevent credential leakage to
|
||||||
// It rejects HTTPS->HTTP protocol downgrades (to prevent plaintext leakage)
|
// third-party hosts (e.g. CDN redirects from GitHub).
|
||||||
// and cross-host redirects (to prevent following responses from untrusted
|
|
||||||
// endpoints). Same-host, same-or-upgraded-scheme redirects are allowed.
|
|
||||||
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
|
func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
|
||||||
if len(via) >= 10 {
|
if len(via) >= 10 {
|
||||||
return fmt.Errorf("stopped after 10 redirects")
|
return fmt.Errorf("stopped after 10 redirects")
|
||||||
}
|
}
|
||||||
// Guard for direct invocation in tests and any future callers;
|
// Guard: net/http guarantees len(via) >= 1 but this is undocumented;
|
||||||
// net/http guarantees len(via) >= 1 during actual redirects.
|
// defend against zero-length to avoid panic on index out of range.
|
||||||
if len(via) == 0 {
|
if len(via) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
prev := via[len(via)-1]
|
prev := via[len(via)-1]
|
||||||
// Reject protocol downgrade: HTTPS->HTTP leaks request metadata over plaintext.
|
// Reject protocol downgrade: HTTPS→HTTP leaks request metadata over plaintext.
|
||||||
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
|
if prev.URL.Scheme == "https" && req.URL.Scheme == "http" {
|
||||||
return fmt.Errorf("refusing redirect: HTTPS to HTTP downgrade (%s -> %s)", prev.URL.Host, req.URL.Host)
|
return fmt.Errorf("refusing redirect from HTTPS to HTTP (%s → %s)", prev.URL.Host, req.URL.Host)
|
||||||
}
|
}
|
||||||
// Reject cross-host redirect entirely to avoid consuming responses
|
// Strip Authorization on cross-host redirect to avoid leaking credentials
|
||||||
// from untrusted endpoints.
|
// to third-party hosts (GitHub legitimately redirects to CDN hosts).
|
||||||
if req.URL.Host != prev.URL.Host {
|
if req.URL.Host != prev.URL.Host {
|
||||||
return fmt.Errorf("refusing redirect: cross-host (%s -> %s)", prev.URL.Host, req.URL.Host)
|
req.Header.Del("Authorization")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientOption configures optional behavior of a Client.
|
|
||||||
type ClientOption func(*clientConfig)
|
|
||||||
|
|
||||||
type clientConfig struct {
|
|
||||||
allowInsecureHTTP bool
|
|
||||||
insecureIsTestBypass bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllowInsecureHTTP permits sending credentials over plaintext HTTP connections.
|
|
||||||
// In production, this option is gated by the REVIEW_BOT_ALLOW_INSECURE=1
|
|
||||||
// environment variable. Without the env var set, the option is ignored
|
|
||||||
// and a warning is logged.
|
|
||||||
//
|
|
||||||
// For tests, use AllowInsecureHTTPForTest (defined in a _test.go file in the same package) which bypasses the env gate.
|
|
||||||
func AllowInsecureHTTP() ClientOption {
|
|
||||||
return func(cfg *clientConfig) {
|
|
||||||
cfg.allowInsecureHTTP = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a new GitHub API client.
|
// NewClient creates a new GitHub API client.
|
||||||
// If baseURL is empty, it defaults to https://api.github.com.
|
// If baseURL is empty, it defaults to https://api.github.com.
|
||||||
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
|
// For GitHub Enterprise, pass the API base URL (e.g. https://github.concur.com/api/v3).
|
||||||
|
// The baseURL must use HTTPS; pass AllowInsecureHTTP() as an option to permit HTTP
|
||||||
|
// for trusted internal deployments (e.g. local testing).
|
||||||
func NewClient(token, baseURL string, opts ...ClientOption) *Client {
|
func NewClient(token, baseURL string, opts ...ClientOption) *Client {
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = defaultBaseURL
|
baseURL = defaultBaseURL
|
||||||
}
|
}
|
||||||
|
cfg := clientConfig{}
|
||||||
var cfg clientConfig
|
for _, o := range opts {
|
||||||
for _, opt := range opts {
|
o(&cfg)
|
||||||
opt(&cfg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.allowInsecureHTTP && !cfg.insecureIsTestBypass {
|
|
||||||
if os.Getenv("REVIEW_BOT_ALLOW_INSECURE") != "1" {
|
|
||||||
slog.Warn("AllowInsecureHTTP ignored: set REVIEW_BOT_ALLOW_INSECURE=1 to enable")
|
|
||||||
cfg.allowInsecureHTTP = false
|
|
||||||
} else {
|
|
||||||
slog.Warn("AllowInsecureHTTP enabled — credentials may be sent over plaintext",
|
|
||||||
"env", "REVIEW_BOT_ALLOW_INSECURE=1")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
baseURL: strings.TrimRight(baseURL, "/"),
|
baseURL: strings.TrimRight(baseURL, "/"),
|
||||||
token: token,
|
|
||||||
allowInsecureHTTP: cfg.allowInsecureHTTP,
|
allowInsecureHTTP: cfg.allowInsecureHTTP,
|
||||||
|
token: token,
|
||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
CheckRedirect: defaultCheckRedirect,
|
CheckRedirect: defaultCheckRedirect,
|
||||||
},
|
},
|
||||||
now: time.Now,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -200,7 +167,7 @@ func NewClient(token, baseURL string, opts ...ClientOption) *Client {
|
|||||||
// This is intended for test setup only to inject mock transports; it must be
|
// This is intended for test setup only to inject mock transports; it must be
|
||||||
// called before any goroutines issue requests.
|
// called before any goroutines issue requests.
|
||||||
//
|
//
|
||||||
// Passing nil restores the default client (30s timeout + redirect-rejecting
|
// Passing nil restores the default client (30s timeout + auth-stripping
|
||||||
// CheckRedirect policy matching NewClient).
|
// CheckRedirect policy matching NewClient).
|
||||||
//
|
//
|
||||||
// Callers providing a non-nil client are responsible for configuring a safe
|
// Callers providing a non-nil client are responsible for configuring a safe
|
||||||
@@ -212,94 +179,56 @@ func (c *Client) SetHTTPClient(hc *http.Client) {
|
|||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
CheckRedirect: defaultCheckRedirect,
|
CheckRedirect: defaultCheckRedirect,
|
||||||
}
|
}
|
||||||
|
} else if hc.CheckRedirect == nil {
|
||||||
|
// Enforce safe redirect policy when caller provides a client without one.
|
||||||
|
// The default net/http behavior follows up to 10 redirects and forwards
|
||||||
|
// all headers (including Authorization) to any host, which can leak
|
||||||
|
// credentials on cross-host redirects.
|
||||||
|
hc.CheckRedirect = defaultCheckRedirect
|
||||||
}
|
}
|
||||||
c.httpClient = hc
|
c.httpClient = hc
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRetryBackoff sets the delays between retry attempts.
|
// SetRetryBackoff configures the retry backoff durations for testing.
|
||||||
// This is intended for testing to speed up retry tests.
|
// It must be called before any goroutines issue requests.
|
||||||
//
|
// In production the default {1s, 2s} applies.
|
||||||
// Note: if an empty non-nil slice is provided, Retry-After delays parsed from
|
func (c *Client) SetRetryBackoff(d []time.Duration) {
|
||||||
// server responses will be computed and capped but not applied (because
|
c.retryBackoff = d
|
||||||
// attempt < len(backoff) is always false). This is acceptable for the
|
|
||||||
// test-only use case but callers should be aware of this edge case.
|
|
||||||
func (c *Client) SetRetryBackoff(backoff []time.Duration) {
|
|
||||||
c.retryBackoff = backoff
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRetryAfter parses a Retry-After header value, supporting both integer
|
|
||||||
// seconds (e.g. "120") and HTTP-date format (e.g. "Thu, 01 Dec 2025 16:00:00 GMT")
|
|
||||||
// as specified in RFC 7231 §7.1.3.
|
|
||||||
//
|
|
||||||
// For integer values, it returns the duration directly.
|
|
||||||
// For HTTP-date values, it computes the delay as the difference between the
|
|
||||||
// parsed time and now. If the date is in the past, it returns 0.
|
|
||||||
//
|
|
||||||
// Returns (0, false) if the value cannot be parsed as either format.
|
|
||||||
func (c *Client) parseRetryAfter(value string) (time.Duration, bool) {
|
|
||||||
value = strings.TrimSpace(value)
|
|
||||||
|
|
||||||
// Try integer seconds first (most common from GitHub).
|
|
||||||
// RFC 7231 allows delta-seconds of 0 to indicate immediate retry.
|
|
||||||
if seconds, err := strconv.Atoi(value); err == nil && seconds >= 0 {
|
|
||||||
return time.Duration(seconds) * time.Second, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try HTTP-date format (RFC 7231 §7.1.3).
|
|
||||||
// http.ParseTime handles RFC 1123, RFC 850, and ASCTIME formats.
|
|
||||||
if retryAt, err := http.ParseTime(value); err == nil {
|
|
||||||
delay := retryAt.Sub(c.now())
|
|
||||||
if delay < 0 {
|
|
||||||
delay = 0
|
|
||||||
}
|
|
||||||
return delay, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// redactURL redacts sensitive components from a URL for safe inclusion in error
|
|
||||||
// messages and log output. It removes userinfo (e.g., user:pass@) and replaces
|
|
||||||
// query parameters with a placeholder.
|
|
||||||
func redactURL(rawURL string) string {
|
|
||||||
u, err := url.Parse(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
return "<unparseable URL>"
|
|
||||||
}
|
|
||||||
u.User = nil
|
|
||||||
|
|
||||||
if u.RawQuery != "" {
|
|
||||||
u.RawQuery = "<redacted>"
|
|
||||||
}
|
|
||||||
return u.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// doRequest performs an HTTP request with retry on 429 rate limit responses.
|
// doRequest performs an HTTP request with retry on 429 rate limit responses.
|
||||||
// It respects the Retry-After header when present, supporting both integer
|
// It respects the Retry-After header when present (capped at maxRetryAfter).
|
||||||
// seconds and HTTP-date formats (capped at maxRetryAfter).
|
// Transport errors (network failures, context cancellation) are not retried.
|
||||||
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
|
func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept string) ([]byte, error) {
|
||||||
// NOTE: This parses reqURL a second time (http.NewRequestWithContext parses it
|
const maxAttempts = 3
|
||||||
// again internally). Acceptable cost: URL parsing is cheap and threading the
|
const maxRetryAfter = 120 * time.Second
|
||||||
// parsed *url.URL through would complicate the interface for negligible gain.
|
|
||||||
if !c.allowInsecureHTTP {
|
|
||||||
parsed, err := url.Parse(reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse request URL: %w", err)
|
|
||||||
}
|
|
||||||
if strings.EqualFold(parsed.Scheme, "http") {
|
|
||||||
return nil, fmt.Errorf("refusing HTTP request to %s: use HTTPS or set AllowInsecureHTTP option", redactURL(reqURL))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var backoff []time.Duration
|
var backoff []time.Duration
|
||||||
if c.retryBackoff != nil {
|
if c.retryBackoff != nil {
|
||||||
backoff = append([]time.Duration(nil), c.retryBackoff...)
|
backoff = make([]time.Duration, len(c.retryBackoff))
|
||||||
|
copy(backoff, c.retryBackoff)
|
||||||
} else {
|
} else {
|
||||||
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
|
backoff = []time.Duration{1 * time.Second, 2 * time.Second}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maxErrorBodyBytes limits how much of an error response body is stored.
|
||||||
|
// Kept small (4 KiB) to reduce the risk of sensitive data leakage if callers
|
||||||
|
// log APIError.Body directly. Error() further truncates to 200 bytes.
|
||||||
|
const maxErrorBodyBytes = 4 * 1024
|
||||||
|
|
||||||
|
// Reject non-HTTPS URLs early since the URL is immutable across retries.
|
||||||
|
if c.token != "" && !c.allowInsecureHTTP {
|
||||||
|
parsed, err := url.Parse(reqURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse request URL: %w", err)
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||||
|
return nil, fmt.Errorf("refusing to send credentials over non-HTTPS URL %q (use AllowInsecureHTTP option for trusted networks)", reqURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < maxRetryAttempts; attempt++ {
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
var delay time.Duration
|
var delay time.Duration
|
||||||
if attempt-1 < len(backoff) {
|
if attempt-1 < len(backoff) {
|
||||||
@@ -321,7 +250,13 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
if c.token != "" {
|
||||||
|
// Bearer is the OAuth2 standard and is accepted by GitHub for both
|
||||||
|
// classic PATs and fine-grained tokens. The alternative "token" scheme
|
||||||
|
// is GitHub-specific and offers no additional compatibility.
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.token)
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", userAgent)
|
||||||
if accept != "" {
|
if accept != "" {
|
||||||
req.Header.Set("Accept", accept)
|
req.Header.Set("Accept", accept)
|
||||||
} else {
|
} else {
|
||||||
@@ -330,29 +265,34 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
|||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Transport errors (DNS, TLS, timeout) yield nil resp; no body to close.
|
||||||
return nil, fmt.Errorf("do request: %w", err)
|
return nil, fmt.Errorf("do request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
body, done, err := handleResponse(resp, maxResponseBytes, maxErrorBodyBytes)
|
||||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
|
if done {
|
||||||
resp.Body.Close()
|
return body, err
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read response body: %w", err)
|
|
||||||
}
|
|
||||||
return body, nil
|
|
||||||
}
|
}
|
||||||
|
lastErr = err
|
||||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
lastErr = &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
|
||||||
|
|
||||||
// Retry on 429 rate limit
|
// Retry on 429 rate limit
|
||||||
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxRetryAttempts-1 {
|
if resp.StatusCode == http.StatusTooManyRequests && attempt < maxAttempts-1 {
|
||||||
// Check for Retry-After header and override backoff if present.
|
// Check for Retry-After header and override backoff if present.
|
||||||
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
|
// Supports both integer seconds (common) and HTTP-date format (RFC 7231).
|
||||||
if ra := resp.Header.Get("Retry-After"); ra != "" {
|
if ra := resp.Header.Get("Retry-After"); ra != "" {
|
||||||
if delay, ok := c.parseRetryAfter(ra); ok {
|
if seconds, err := strconv.Atoi(ra); err == nil && seconds > 0 {
|
||||||
|
delay := time.Duration(seconds) * time.Second
|
||||||
|
if delay > maxRetryAfter {
|
||||||
|
delay = maxRetryAfter
|
||||||
|
}
|
||||||
|
if attempt < len(backoff) {
|
||||||
|
backoff[attempt] = delay
|
||||||
|
}
|
||||||
|
} else if retryAt, err := http.ParseTime(ra); err == nil {
|
||||||
|
delay := time.Until(retryAt)
|
||||||
|
if delay < 0 {
|
||||||
|
delay = 0
|
||||||
|
}
|
||||||
if delay > maxRetryAfter {
|
if delay > maxRetryAfter {
|
||||||
delay = maxRetryAfter
|
delay = maxRetryAfter
|
||||||
}
|
}
|
||||||
@@ -371,461 +311,31 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, accept st
|
|||||||
return nil, lastErr
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// doGet is a convenience wrapper for GET requests with the default Accept header.
|
// handleResponse reads and closes the response body, returning the result.
|
||||||
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
// It uses defer to ensure the body is always closed regardless of code path.
|
||||||
return c.doRequest(ctx, http.MethodGet, url, "")
|
// Returns (body, done, err) where done=true means the caller should return immediately.
|
||||||
}
|
func handleResponse(resp *http.Response, maxRespBytes int, maxErrBytes int) ([]byte, bool, error) {
|
||||||
|
|
||||||
// doRequestWithBody performs an HTTP request with an optional body, applying the
|
|
||||||
// same HTTPS enforcement as doRequest. It is used by write methods (POST, PUT,
|
|
||||||
// DELETE) that bypass the retry loop in doRequest because write operations are
|
|
||||||
// not idempotent.
|
|
||||||
//
|
|
||||||
// body may be nil for requests that carry no payload (e.g. DELETE).
|
|
||||||
// When body is non-nil, Content-Type is set to application/json.
|
|
||||||
func (c *Client) doRequestWithBody(ctx context.Context, method, reqURL string, body []byte) ([]byte, error) {
|
|
||||||
if !c.allowInsecureHTTP {
|
|
||||||
parsed, err := url.Parse(reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse request URL: %w", err)
|
|
||||||
}
|
|
||||||
if strings.EqualFold(parsed.Scheme, "http") {
|
|
||||||
return nil, fmt.Errorf("refusing HTTP request to %s: use HTTPS or set AllowInsecureHTTP option", redactURL(reqURL))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var reqBody io.Reader
|
|
||||||
if body != nil {
|
|
||||||
reqBody = bytes.NewReader(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, reqURL, reqBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+c.token)
|
|
||||||
req.Header.Set("Accept", "application/vnd.github+json")
|
|
||||||
if body != nil {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("do request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxRespBytes)+1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("read response body: %w", err)
|
return nil, true, fmt.Errorf("read response body: %w", err)
|
||||||
}
|
}
|
||||||
return respBody, nil
|
if len(body) > maxRespBytes {
|
||||||
}
|
return nil, true, fmt.Errorf("response body exceeded %d bytes (truncated)", maxRespBytes)
|
||||||
|
|
||||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
|
|
||||||
return nil, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- API types ---
|
|
||||||
|
|
||||||
// PullRequest holds relevant PR metadata.
|
|
||||||
type PullRequest struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
Body string `json:"body"`
|
|
||||||
Head struct {
|
|
||||||
Sha string `json:"sha"`
|
|
||||||
Ref string `json:"ref"`
|
|
||||||
} `json:"head"`
|
|
||||||
Draft bool `json:"draft"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommitStatus represents a single CI status entry.
|
|
||||||
// GitHub returns "state" not "status"; this type uses Status for consistency
|
|
||||||
// with the gitea package (both are normalized before use).
|
|
||||||
type CommitStatus struct {
|
|
||||||
Status string `json:"state"` // GitHub field is "state"
|
|
||||||
Context string `json:"context"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
TargetURL string `json:"target_url"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChangedFile represents a file modified in a PR.
|
|
||||||
type ChangedFile struct {
|
|
||||||
Filename string `json:"filename"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReviewComment represents an inline comment to attach to a review.
|
|
||||||
// GitHub uses "position" (diff hunk position), whereas Gitea uses "new_position" (line number).
|
|
||||||
// When posting inline comments on GitHub, position is required; line numbers
|
|
||||||
// from the diff cannot be used directly.
|
|
||||||
type ReviewComment struct {
|
|
||||||
ID int64 `json:"id,omitempty"`
|
|
||||||
Path string `json:"path"`
|
|
||||||
Position int64 `json:"position,omitempty"` // GitHub diff hunk position
|
|
||||||
Line int64 `json:"line,omitempty"` // GitHub absolute line number (alternative to position)
|
|
||||||
Side string `json:"side,omitempty"` // "RIGHT" or "LEFT"
|
|
||||||
Body string `json:"body"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Review represents a pull request review from the GitHub API.
|
|
||||||
type Review struct {
|
|
||||||
ID int64 `json:"id"`
|
|
||||||
Body string `json:"body"`
|
|
||||||
User struct {
|
|
||||||
Login string `json:"login"`
|
|
||||||
} `json:"user"`
|
|
||||||
State string `json:"state"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// contentResponse is the GitHub contents API response for a single file.
|
|
||||||
type contentResponse struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Path string `json:"path"`
|
|
||||||
Type string `json:"type"` // "file" or "dir" or "symlink" or "submodule"
|
|
||||||
Content string `json:"content"` // Base64-encoded file content (with embedded newlines)
|
|
||||||
Encoding string `json:"encoding"` // "base64" or ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContentEntry represents a file or directory entry from the contents API.
|
|
||||||
type ContentEntry struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Path string `json:"path"`
|
|
||||||
Type string `json:"type"` // "file" or "dir"
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- PR methods ---
|
|
||||||
|
|
||||||
// GetPullRequest fetches PR metadata.
|
|
||||||
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error) {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("fetch PR: %w", err)
|
|
||||||
}
|
|
||||||
var pr PullRequest
|
|
||||||
if err := json.Unmarshal(body, &pr); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse PR JSON: %w", err)
|
|
||||||
}
|
|
||||||
return &pr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
|
||||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
|
||||||
body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff")
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("fetch diff: %w", err)
|
|
||||||
}
|
|
||||||
return string(body), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPullRequestFiles fetches the list of files changed in a PR.
|
|
||||||
// GitHub paginates this endpoint (100 per page max).
|
|
||||||
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error) {
|
|
||||||
const perPage = 100
|
|
||||||
var all []ChangedFile
|
|
||||||
for page := 1; ; page++ {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=%d&page=%d",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, perPage, page)
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("fetch PR files (page %d): %w", page, err)
|
|
||||||
}
|
|
||||||
var batch []ChangedFile
|
|
||||||
if err := json.Unmarshal(body, &batch); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse PR files JSON (page %d): %w", page, err)
|
|
||||||
}
|
|
||||||
all = append(all, batch...)
|
|
||||||
if len(batch) < perPage {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
return body, true, nil
|
||||||
}
|
}
|
||||||
return all, nil
|
|
||||||
|
errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, int64(maxErrBytes)))
|
||||||
|
if readErr != nil && len(errBody) == 0 {
|
||||||
|
errBody = []byte(fmt.Sprintf("[error reading response body: %v]", readErr))
|
||||||
|
}
|
||||||
|
return nil, false, &APIError{StatusCode: resp.StatusCode, Body: string(errBody)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCommitStatuses fetches CI statuses for a commit SHA.
|
// doGet is a convenience wrapper for GET requests with the default Accept header.
|
||||||
// GitHub has two status systems: legacy "commit statuses" and newer "check runs".
|
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||||
// This method returns commit statuses only; check runs are a separate API.
|
return c.doRequest(ctx, http.MethodGet, reqURL, "")
|
||||||
// Note: GitHub returns "state" in the JSON; CommitStatus.Status is tagged accordingly.
|
|
||||||
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error) {
|
|
||||||
const perPage = 100
|
|
||||||
var all []CommitStatus
|
|
||||||
for page := 1; ; page++ {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/statuses?per_page=%d&page=%d",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), perPage, page)
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("fetch commit statuses (page %d): %w", page, err)
|
|
||||||
}
|
|
||||||
var batch []CommitStatus
|
|
||||||
if err := json.Unmarshal(body, &batch); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse statuses JSON (page %d): %w", page, err)
|
|
||||||
}
|
|
||||||
all = append(all, batch...)
|
|
||||||
if len(batch) < perPage {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return all, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- File content methods ---
|
|
||||||
|
|
||||||
// GetFileContent fetches a file from the default branch of a repo.
|
|
||||||
// GitHub returns base64-encoded content; this method decodes it.
|
|
||||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
|
||||||
return c.getFileContentAtRef(ctx, owner, repo, filepath, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFileContentRef fetches a file from a specific ref (branch/tag/sha).
|
|
||||||
func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
|
||||||
return c.getFileContentAtRef(ctx, owner, repo, filepath, ref)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFileContentAtRef fetches a file at the given ref (empty = default branch).
|
|
||||||
// GitHub's contents API returns base64-encoded file content.
|
|
||||||
func (c *Client) getFileContentAtRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(filepath))
|
|
||||||
if ref != "" {
|
|
||||||
reqURL += "?ref=" + url.QueryEscape(ref)
|
|
||||||
}
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("fetch file %s: %w", filepath, err)
|
|
||||||
}
|
|
||||||
var resp contentResponse
|
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
|
||||||
return "", fmt.Errorf("parse file content JSON for %s: %w", filepath, err)
|
|
||||||
}
|
|
||||||
if resp.Type != "file" {
|
|
||||||
return "", fmt.Errorf("path %s is a %s, not a file", filepath, resp.Type)
|
|
||||||
}
|
|
||||||
if resp.Encoding == "base64" {
|
|
||||||
// GitHub embeds newlines in the base64 content for readability.
|
|
||||||
// Strip them before decoding.
|
|
||||||
cleaned := strings.ReplaceAll(resp.Content, "\n", "")
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(cleaned)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("decode base64 content for %s: %w", filepath, err)
|
|
||||||
}
|
|
||||||
return string(decoded), nil
|
|
||||||
}
|
|
||||||
// Non-base64 encoding (shouldn't happen normally, but handle gracefully).
|
|
||||||
return resp.Content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListContents lists files and directories at a given path.
|
|
||||||
// Pass an empty path to list the repository root.
|
|
||||||
// GitHub returns a single object (not array) when path is a file — this
|
|
||||||
// method normalizes both cases to a slice, matching Gitea's behavior.
|
|
||||||
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
|
|
||||||
var reqURL string
|
|
||||||
if path == "" || path == "." {
|
|
||||||
reqURL = fmt.Sprintf("%s/repos/%s/%s/contents",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo))
|
|
||||||
} else {
|
|
||||||
reqURL = fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
|
||||||
}
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var entries []ContentEntry
|
|
||||||
if err := json.Unmarshal(body, &entries); err != nil {
|
|
||||||
// GitHub returns a single object when path is a file (not an array).
|
|
||||||
var single contentResponse
|
|
||||||
if err2 := json.Unmarshal(body, &single); err2 != nil {
|
|
||||||
return nil, fmt.Errorf("parse contents JSON: %w", err)
|
|
||||||
}
|
|
||||||
if single.Name == "" && single.Path == "" {
|
|
||||||
return nil, fmt.Errorf("parse contents JSON: empty response for path %q", path)
|
|
||||||
}
|
|
||||||
entries = []ContentEntry{{
|
|
||||||
Name: single.Name,
|
|
||||||
Path: single.Path,
|
|
||||||
Type: single.Type,
|
|
||||||
}}
|
|
||||||
}
|
|
||||||
return entries, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllFilesInPath recursively fetches all file contents under a path.
|
|
||||||
// If the path is a file, returns just that file's content.
|
|
||||||
// If the path is a directory, recursively fetches all files within it.
|
|
||||||
func (c *Client) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
|
||||||
results := make(map[string]string)
|
|
||||||
|
|
||||||
entries, err := c.ListContents(ctx, owner, repo, path)
|
|
||||||
if err != nil {
|
|
||||||
if !IsNotFound(err) {
|
|
||||||
return nil, fmt.Errorf("list contents %q: %w", path, err)
|
|
||||||
}
|
|
||||||
// 404 means path may be a file — try fetching directly.
|
|
||||||
content, fileErr := c.GetFileContent(ctx, owner, repo, path)
|
|
||||||
if fileErr != nil {
|
|
||||||
return nil, fmt.Errorf("path %q is neither a file nor directory: %w", path, fileErr)
|
|
||||||
}
|
|
||||||
results[path] = content
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, entry := range entries {
|
|
||||||
switch entry.Type {
|
|
||||||
case "file":
|
|
||||||
content, err := c.GetFileContent(ctx, owner, repo, entry.Path)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("could not fetch file from patterns repo", "file", entry.Path, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
results[entry.Path] = content
|
|
||||||
case "dir":
|
|
||||||
subResults, err := c.GetAllFilesInPath(ctx, owner, repo, entry.Path)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("could not recurse into directory", "dir", entry.Path, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for k, v := range subResults {
|
|
||||||
results[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Review methods ---
|
|
||||||
|
|
||||||
// PostReview submits a review to a PR.
|
|
||||||
// event should be one of "APPROVE", "REQUEST_CHANGES", or "COMMENT".
|
|
||||||
// commitID anchors the review to a specific commit SHA. If empty, defaults to current HEAD.
|
|
||||||
// comments are optional inline comments; GitHub uses diff hunk position (not line numbers).
|
|
||||||
// Note: unlike Gitea, GitHub does not support deleting submitted reviews.
|
|
||||||
// Use COMMENT event to supersede old reviews.
|
|
||||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body, commitID string, comments []ReviewComment) (*Review, error) {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
|
||||||
|
|
||||||
payload := struct {
|
|
||||||
Body string `json:"body"`
|
|
||||||
Event string `json:"event"`
|
|
||||||
CommitID string `json:"commit_id,omitempty"`
|
|
||||||
Comments []ReviewComment `json:"comments,omitempty"`
|
|
||||||
}{
|
|
||||||
Body: body,
|
|
||||||
Event: event,
|
|
||||||
CommitID: commitID,
|
|
||||||
Comments: comments,
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshal review payload: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
respBody, err := c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("post review: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var review Review
|
|
||||||
if err := json.Unmarshal(respBody, &review); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse review response: %w", err)
|
|
||||||
}
|
|
||||||
return &review, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListReviews returns all reviews on a pull request.
|
|
||||||
// GitHub paginates via Link header; this method uses per_page=100.
|
|
||||||
func (c *Client) ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error) {
|
|
||||||
const perPage = 100
|
|
||||||
var all []Review
|
|
||||||
for page := 1; ; page++ {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews?per_page=%d&page=%d",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, perPage, page)
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list reviews (page %d): %w", page, err)
|
|
||||||
}
|
|
||||||
var batch []Review
|
|
||||||
if err := json.Unmarshal(body, &batch); err != nil {
|
|
||||||
return nil, fmt.Errorf("parse reviews (page %d): %w", page, err)
|
|
||||||
}
|
|
||||||
all = append(all, batch...)
|
|
||||||
if len(batch) < perPage {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return all, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteReview attempts to delete a pull request review.
|
|
||||||
// GitHub only allows deleting PENDING (draft) reviews. Submitted reviews cannot
|
|
||||||
// be deleted via the API; this method returns a descriptive error in that case.
|
|
||||||
// review-bot callers should handle this error gracefully (e.g., by not attempting
|
|
||||||
// supersede and instead posting a new review alongside the old one).
|
|
||||||
func (c *Client) DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/reviews/%d",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, reviewID)
|
|
||||||
|
|
||||||
// nil body: the GitHub DELETE endpoint for reviews requires no request body.
|
|
||||||
_, err := c.doRequestWithBody(ctx, http.MethodDelete, reqURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("delete review: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAuthenticatedUser returns the login of the authenticated user.
|
|
||||||
func (c *Client) GetAuthenticatedUser(ctx context.Context) (string, error) {
|
|
||||||
reqURL := c.baseURL + "/user"
|
|
||||||
body, err := c.doGet(ctx, reqURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("get authenticated user: %w", err)
|
|
||||||
}
|
|
||||||
var result struct {
|
|
||||||
Login string `json:"login"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &result); err != nil {
|
|
||||||
return "", fmt.Errorf("parse user response: %w", err)
|
|
||||||
}
|
|
||||||
return result.Login, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestReviewer adds a user as a requested reviewer on a pull request.
|
|
||||||
// This is idempotent — requesting an already-requested reviewer is a no-op.
|
|
||||||
func (c *Client) RequestReviewer(ctx context.Context, owner, repo string, number int, reviewer string) error {
|
|
||||||
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/requested_reviewers",
|
|
||||||
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
|
||||||
|
|
||||||
payload := struct {
|
|
||||||
Reviewers []string `json:"reviewers"`
|
|
||||||
}{Reviewers: []string{reviewer}}
|
|
||||||
data, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal reviewer request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = c.doRequestWithBody(ctx, http.MethodPost, reqURL, data)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("request reviewer: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- helpers ---
|
|
||||||
|
|
||||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
|
||||||
// Slashes are preserved as path separators; other special characters are escaped.
|
|
||||||
func escapePath(p string) string {
|
|
||||||
parts := strings.Split(p, "/")
|
|
||||||
for i, part := range parts {
|
|
||||||
parts[i] = url.PathEscape(part)
|
|
||||||
}
|
|
||||||
return strings.Join(parts, "/")
|
|
||||||
}
|
}
|
||||||
|
|||||||
+445
-1073
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,13 @@
|
|||||||
|
package github_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gitea.weiker.me/rodin/review-bot/github"
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compile-time interface conformance assertions.
|
||||||
|
// These verify github.Client satisfies vcs.PRReader and vcs.FileReader.
|
||||||
|
var (
|
||||||
|
_ vcs.PRReader = (*github.Client)(nil)
|
||||||
|
_ vcs.FileReader = (*github.Client)(nil)
|
||||||
|
)
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package github
|
|
||||||
|
|
||||||
// AllowInsecureHTTPForTest permits sending credentials over plaintext HTTP
|
|
||||||
// without requiring the REVIEW_BOT_ALLOW_INSECURE environment variable.
|
|
||||||
// This is intended exclusively for test code using httptest.Server.
|
|
||||||
//
|
|
||||||
// Defined in a _test.go file so it is only available to test binaries.
|
|
||||||
func AllowInsecureHTTPForTest() ClientOption {
|
|
||||||
return func(cfg *clientConfig) {
|
|
||||||
cfg.allowInsecureHTTP = true
|
|
||||||
cfg.insecureIsTestBypass = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+135
@@ -0,0 +1,135 @@
|
|||||||
|
package github
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetFileContent fetches a file from a repo at the given ref.
|
||||||
|
// Delegates to GetFileContentAtRef with the provided ref.
|
||||||
|
func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||||
|
return c.GetFileContentAtRef(ctx, owner, repo, path, ref)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileContentAtRef fetches a file at a specific ref from a repo.
|
||||||
|
// If ref is empty, the query parameter is omitted (uses default branch).
|
||||||
|
//
|
||||||
|
// Note: dot-segments ("." and "..") in the path are silently removed to
|
||||||
|
// prevent path traversal. This means a path like "foo/../bar" resolves
|
||||||
|
// to "foo/bar" rather than "bar".
|
||||||
|
func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||||
|
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||||
|
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
||||||
|
if ref != "" {
|
||||||
|
reqURL += "?ref=" + url.QueryEscape(ref)
|
||||||
|
}
|
||||||
|
body, err := c.doGet(ctx, reqURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("fetch file %s: %w", path, err)
|
||||||
|
}
|
||||||
|
var resp struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Encoding string `json:"encoding"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return "", fmt.Errorf("parse file content JSON: %w", err)
|
||||||
|
}
|
||||||
|
if resp.Encoding != "base64" {
|
||||||
|
return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path)
|
||||||
|
}
|
||||||
|
decoded, err := decodeBase64Content(resp.Content)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode base64 content for %s: %w", path, err)
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListContents lists files and directories at a given path in a repo.
|
||||||
|
// Returns the directory listing from the GitHub contents API.
|
||||||
|
// If the path points to a single file (not a directory), the API returns
|
||||||
|
// a JSON object instead of an array; this is handled by returning a
|
||||||
|
// single-element slice.
|
||||||
|
//
|
||||||
|
// Note: dot-segments ("." and "..") in the path are silently removed to
|
||||||
|
// prevent path traversal. This means a path like "foo/../bar" resolves
|
||||||
|
// to "foo/bar" rather than "bar".
|
||||||
|
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
|
||||||
|
reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s",
|
||||||
|
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path))
|
||||||
|
body, err := c.doGet(ctx, reqURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// The GitHub contents API returns an array for directories and an object
|
||||||
|
// for single files. Try array first (common case), then fall back to object.
|
||||||
|
// An empty array ([]) is valid — it represents an empty directory — and
|
||||||
|
// results in a zero-length slice returned without error.
|
||||||
|
var entries []entry
|
||||||
|
if err := json.Unmarshal(body, &entries); err != nil {
|
||||||
|
var single entry
|
||||||
|
if err2 := json.Unmarshal(body, &single); err2 != nil {
|
||||||
|
return nil, fmt.Errorf("parse contents JSON: as array: %w; as object: %w", err, err2)
|
||||||
|
}
|
||||||
|
// Guard against empty objects ({}) or unexpected shapes that
|
||||||
|
// unmarshal successfully but carry no useful data.
|
||||||
|
if single.Name == "" && single.Path == "" && single.Type == "" {
|
||||||
|
return nil, fmt.Errorf("parse contents JSON: unexpected response format")
|
||||||
|
}
|
||||||
|
entries = []entry{single}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]vcs.ContentEntry, len(entries))
|
||||||
|
for i, e := range entries {
|
||||||
|
result[i] = vcs.ContentEntry{
|
||||||
|
Name: e.Name,
|
||||||
|
Path: e.Path,
|
||||||
|
Type: e.Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||||
|
// Slashes are preserved as path separators; other special characters are escaped.
|
||||||
|
// Dot-segments ("." and "..") and empty segments (from consecutive slashes like
|
||||||
|
// "a//b") are silently removed to prevent path traversal and produce canonical
|
||||||
|
// paths. This is intentional: callers may receive a different path than requested
|
||||||
|
// without error. The function is package-private, and all callers
|
||||||
|
// (GetFileContentAtRef, ListContents) already handle missing-file errors from the
|
||||||
|
// API if the cleaned path doesn't match what the caller intended.
|
||||||
|
func escapePath(p string) string {
|
||||||
|
parts := strings.Split(p, "/")
|
||||||
|
var clean []string
|
||||||
|
for _, part := range parts {
|
||||||
|
if part == "." || part == ".." || part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
clean = append(clean, url.PathEscape(part))
|
||||||
|
}
|
||||||
|
return strings.Join(clean, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeBase64Content decodes base64-encoded content from the GitHub contents API.
|
||||||
|
// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding.
|
||||||
|
func decodeBase64Content(encoded string) (string, error) {
|
||||||
|
// GitHub inserts newlines in base64 content
|
||||||
|
cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded)
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(decoded), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,334 @@
|
|||||||
|
package github
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) {
|
||||||
|
var gotRef string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotRef = r.URL.Query().Get("ref")
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"content": "dGVzdA==", // "test" in base64
|
||||||
|
"encoding": "base64",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
// Call with empty ref — should not include ref param
|
||||||
|
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != "test" {
|
||||||
|
t.Errorf("expected 'test', got %q", content)
|
||||||
|
}
|
||||||
|
if gotRef != "" {
|
||||||
|
t.Errorf("expected empty ref, got %q", gotRef)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContent_WithRef(t *testing.T) {
|
||||||
|
var gotRef string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotRef = r.URL.Query().Get("ref")
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"content": "dGVzdA==",
|
||||||
|
"encoding": "base64",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if gotRef != "abc123" {
|
||||||
|
t.Errorf("expected ref 'abc123', got %q", gotRef)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContent_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContent_401(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContent_429Retry(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.WriteHeader(429)
|
||||||
|
w.Write([]byte(`{"message":"rate limit"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"content": "b2s=",
|
||||||
|
"encoding": "base64",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
|
||||||
|
|
||||||
|
content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != "ok" {
|
||||||
|
t.Errorf("expected 'ok', got %q", content)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContent_MalformedJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`not json`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContents_HappyPath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/repos/owner/repo/contents/src" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode([]map[string]string{
|
||||||
|
{"name": "main.go", "path": "src/main.go", "type": "file"},
|
||||||
|
{"name": "lib", "path": "src/lib", "type": "dir"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
entries, err := c.ListContents(context.Background(), "owner", "repo", "src")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Fatalf("expected 2 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Name != "main.go" {
|
||||||
|
t.Errorf("expected name 'main.go', got %q", entries[0].Name)
|
||||||
|
}
|
||||||
|
if entries[0].Path != "src/main.go" {
|
||||||
|
t.Errorf("expected path 'src/main.go', got %q", entries[0].Path)
|
||||||
|
}
|
||||||
|
if entries[0].Type != "file" {
|
||||||
|
t.Errorf("expected type 'file', got %q", entries[0].Type)
|
||||||
|
}
|
||||||
|
if entries[1].Name != "lib" {
|
||||||
|
t.Errorf("expected name 'lib', got %q", entries[1].Name)
|
||||||
|
}
|
||||||
|
if entries[1].Type != "dir" {
|
||||||
|
t.Errorf("expected type 'dir', got %q", entries[1].Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContents_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.ListContents(context.Background(), "owner", "repo", "missing")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContents_401(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.ListContents(context.Background(), "owner", "repo", "src")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContents_429Retry(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.WriteHeader(429)
|
||||||
|
w.Write([]byte(`{"message":"rate limit"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode([]map[string]string{
|
||||||
|
{"name": "file.go", "path": "file.go", "type": "file"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
|
||||||
|
|
||||||
|
entries, err := c.ListContents(context.Background(), "owner", "repo", ".")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContents_MalformedJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`not json`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.ListContents(context.Background(), "owner", "repo", "src")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeBase64Content(t *testing.T) {
|
||||||
|
// Test with newlines (GitHub's format)
|
||||||
|
encoded := "cGFja2FnZSBt\nYWlu"
|
||||||
|
decoded, err := decodeBase64Content(encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if decoded != "package main" {
|
||||||
|
t.Errorf("expected 'package main', got %q", decoded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeBase64Content_Invalid(t *testing.T) {
|
||||||
|
_, err := decodeBase64Content("not!!!valid!!!base64")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid base64")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEscapePath_RejectsDotSegments(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"src/main.go", "src/main.go"},
|
||||||
|
{"../etc/passwd", "etc/passwd"},
|
||||||
|
{"./src/../main.go", "src/main.go"},
|
||||||
|
{"a/b/c", "a/b/c"},
|
||||||
|
{"file with spaces.go", "file%20with%20spaces.go"},
|
||||||
|
{"a/./b/../c", "a/b/c"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := escapePath(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeBase64Content_CRLF(t *testing.T) {
|
||||||
|
// Base64 of "hello world" with CRLF line breaks inserted
|
||||||
|
encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ="
|
||||||
|
decoded, err := decodeBase64Content(encoded)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if decoded != "hello world" {
|
||||||
|
t.Errorf("expected 'hello world', got %q", decoded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContents_SingleFile(t *testing.T) {
|
||||||
|
// GitHub Contents API returns a JSON object (not array) for single-file paths
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
entries, err := c.ListContents(context.Background(), "owner", "repo", "README.md")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("expected 1 entry, got %d", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Name != "README.md" {
|
||||||
|
t.Errorf("expected name 'README.md', got %q", entries[0].Name)
|
||||||
|
}
|
||||||
|
if entries[0].Type != "file" {
|
||||||
|
t.Errorf("expected type 'file', got %q", entries[0].Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
+212
@@ -0,0 +1,212 @@
|
|||||||
|
package github
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pullRequestResponse is the GitHub API response for a pull request.
|
||||||
|
type pullRequestResponse struct {
|
||||||
|
Number int `json:"number"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
Head struct {
|
||||||
|
SHA string `json:"sha"`
|
||||||
|
Ref string `json:"ref"`
|
||||||
|
} `json:"head"`
|
||||||
|
Base struct {
|
||||||
|
Ref string `json:"ref"`
|
||||||
|
} `json:"base"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// changedFileResponse is the GitHub API response for a changed file in a PR.
|
||||||
|
type changedFileResponse struct {
|
||||||
|
Filename string `json:"filename"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Patch string `json:"patch"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// commitStatusResponse is the GitHub combined status API response.
|
||||||
|
type commitStatusResponse struct {
|
||||||
|
Statuses []struct {
|
||||||
|
Context string `json:"context"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
TargetURL string `json:"target_url"`
|
||||||
|
} `json:"statuses"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkRunsResponse is the GitHub check runs API response.
|
||||||
|
type checkRunsResponse struct {
|
||||||
|
CheckRuns []struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Conclusion *string `json:"conclusion"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
HTMLURL string `json:"html_url"`
|
||||||
|
} `json:"check_runs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPullRequest fetches PR metadata.
|
||||||
|
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*vcs.PullRequest, error) {
|
||||||
|
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||||
|
body, err := c.doGet(ctx, reqURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("fetch PR: %w", err)
|
||||||
|
}
|
||||||
|
var resp pullRequestResponse
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse PR JSON: %w", err)
|
||||||
|
}
|
||||||
|
return &vcs.PullRequest{
|
||||||
|
Number: resp.Number,
|
||||||
|
Title: resp.Title,
|
||||||
|
Body: resp.Body,
|
||||||
|
Head: vcs.HeadRef{SHA: resp.Head.SHA, Ref: resp.Head.Ref},
|
||||||
|
Base: vcs.BaseRef{Ref: resp.Base.Ref},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||||
|
// Uses Accept: application/vnd.github.diff to get raw diff text.
|
||||||
|
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||||
|
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d", c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number)
|
||||||
|
body, err := c.doRequest(ctx, http.MethodGet, reqURL, "application/vnd.github.diff")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("fetch diff: %w", err)
|
||||||
|
}
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxPages is the upper bound on pagination loops to prevent unbounded iteration
|
||||||
|
// in case the server returns a full page indefinitely.
|
||||||
|
const maxPages = 100
|
||||||
|
|
||||||
|
// GetPullRequestFiles fetches the list of files changed in a PR.
|
||||||
|
// Paginates through all pages (100 per page) to collect all files.
|
||||||
|
// Returns nil (not an empty slice) when the PR has no changed files.
|
||||||
|
// Callers can safely range over or check len() on a nil slice.
|
||||||
|
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]vcs.ChangedFile, error) {
|
||||||
|
var allFiles []vcs.ChangedFile
|
||||||
|
|
||||||
|
for page := 1; page <= maxPages; page++ {
|
||||||
|
reqURL := fmt.Sprintf("%s/repos/%s/%s/pulls/%d/files?per_page=100&page=%d",
|
||||||
|
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), number, page)
|
||||||
|
body, err := c.doGet(ctx, reqURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("fetch PR files page %d: %w", page, err)
|
||||||
|
}
|
||||||
|
var files []changedFileResponse
|
||||||
|
if err := json.Unmarshal(body, &files); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse PR files JSON: %w", err)
|
||||||
|
}
|
||||||
|
if len(files) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
for _, f := range files {
|
||||||
|
allFiles = append(allFiles, vcs.ChangedFile{
|
||||||
|
Filename: f.Filename,
|
||||||
|
Status: f.Status,
|
||||||
|
Patch: f.Patch,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(files) < 100 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allFiles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCommitStatuses fetches both commit statuses and check runs for a SHA,
|
||||||
|
// merging them into a unified []vcs.CommitStatus slice.
|
||||||
|
// Returns nil (not an empty slice) when there are no statuses or check runs.
|
||||||
|
// If the commit statuses endpoint fails (e.g. 404 for an unknown SHA), the
|
||||||
|
// function returns immediately without attempting the check-runs endpoint.
|
||||||
|
// If the check-runs endpoint fails after statuses were fetched successfully,
|
||||||
|
// the function returns an error (not a partial result) so callers always get
|
||||||
|
// either a complete view or a clear error signal.
|
||||||
|
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]vcs.CommitStatus, error) {
|
||||||
|
var result []vcs.CommitStatus
|
||||||
|
|
||||||
|
// Fetch commit statuses
|
||||||
|
statusURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/status",
|
||||||
|
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha))
|
||||||
|
statusBody, err := c.doGet(ctx, statusURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("fetch commit statuses: %w", err)
|
||||||
|
}
|
||||||
|
var statusResp commitStatusResponse
|
||||||
|
if err := json.Unmarshal(statusBody, &statusResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse commit statuses JSON: %w", err)
|
||||||
|
}
|
||||||
|
for _, s := range statusResp.Statuses {
|
||||||
|
result = append(result, vcs.CommitStatus{
|
||||||
|
Context: s.Context,
|
||||||
|
Status: s.State,
|
||||||
|
Description: s.Description,
|
||||||
|
TargetURL: s.TargetURL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch check runs (paginated)
|
||||||
|
for checkPage := 1; checkPage <= maxPages; checkPage++ {
|
||||||
|
checkURL := fmt.Sprintf("%s/repos/%s/%s/commits/%s/check-runs?per_page=100&page=%d",
|
||||||
|
c.baseURL, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(sha), checkPage)
|
||||||
|
checkBody, err := c.doGet(ctx, checkURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("fetch check runs page %d: %w", checkPage, err)
|
||||||
|
}
|
||||||
|
var checkResp checkRunsResponse
|
||||||
|
if err := json.Unmarshal(checkBody, &checkResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse check runs JSON: %w", err)
|
||||||
|
}
|
||||||
|
for _, cr := range checkResp.CheckRuns {
|
||||||
|
result = append(result, vcs.CommitStatus{
|
||||||
|
Context: cr.Name,
|
||||||
|
Status: mapCheckRunStatus(cr.Conclusion),
|
||||||
|
Description: derefString(cr.Conclusion),
|
||||||
|
TargetURL: cr.HTMLURL,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(checkResp.CheckRuns) < 100 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapCheckRunStatus maps a check run conclusion to a vcs.CommitStatus status string.
|
||||||
|
// Conclusion alone determines the mapped state: nil conclusion means the run is
|
||||||
|
// still in progress (pending), regardless of the status field value.
|
||||||
|
func mapCheckRunStatus(conclusion *string) string {
|
||||||
|
if conclusion == nil {
|
||||||
|
// Still running or queued
|
||||||
|
return "pending"
|
||||||
|
}
|
||||||
|
switch *conclusion {
|
||||||
|
case "success":
|
||||||
|
return "success"
|
||||||
|
case "failure", "action_required", "timed_out":
|
||||||
|
return "failure"
|
||||||
|
case "cancelled", "skipped", "neutral":
|
||||||
|
return "success" // non-blocking: these do not indicate a blocking failure per GitHub check suite semantics
|
||||||
|
case "stale", "waiting":
|
||||||
|
return "pending"
|
||||||
|
default:
|
||||||
|
return "pending"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// derefString safely dereferences a string pointer, returning empty string if nil.
|
||||||
|
func derefString(s *string) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *s
|
||||||
|
}
|
||||||
@@ -0,0 +1,637 @@
|
|||||||
|
package github
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetPullRequest_HappyPath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/repos/owner/repo/pulls/42" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"number": 42,
|
||||||
|
"title": "Test PR",
|
||||||
|
"body": "Description",
|
||||||
|
"head": map[string]string{"sha": "abc123", "ref": "feature-branch"},
|
||||||
|
"base": map[string]string{"ref": "main"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if pr.Number != 42 {
|
||||||
|
t.Errorf("expected number 42, got %d", pr.Number)
|
||||||
|
}
|
||||||
|
if pr.Title != "Test PR" {
|
||||||
|
t.Errorf("expected title 'Test PR', got %q", pr.Title)
|
||||||
|
}
|
||||||
|
if pr.Body != "Description" {
|
||||||
|
t.Errorf("expected body 'Description', got %q", pr.Body)
|
||||||
|
}
|
||||||
|
if pr.Head.SHA != "abc123" {
|
||||||
|
t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA)
|
||||||
|
}
|
||||||
|
if pr.Head.Ref != "feature-branch" {
|
||||||
|
t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref)
|
||||||
|
}
|
||||||
|
if pr.Base.Ref != "main" {
|
||||||
|
t.Errorf("expected base ref 'main', got %q", pr.Base.Ref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequest_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
if !IsNotFound(err) {
|
||||||
|
t.Errorf("expected IsNotFound=true, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequest_401(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401")
|
||||||
|
}
|
||||||
|
if !IsUnauthorized(err) {
|
||||||
|
t.Errorf("expected IsUnauthorized=true, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequest_429Retry(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.WriteHeader(429)
|
||||||
|
w.Write([]byte(`{"message":"rate limit"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"number": 1,
|
||||||
|
"title": "PR",
|
||||||
|
"body": "",
|
||||||
|
"head": map[string]string{"sha": "abc", "ref": "br"},
|
||||||
|
"base": map[string]string{"ref": "main"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
|
||||||
|
|
||||||
|
pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if pr.Number != 1 {
|
||||||
|
t.Errorf("expected number 1, got %d", pr.Number)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequest_MalformedJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`{invalid json`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "parse PR JSON") {
|
||||||
|
t.Errorf("expected parse error, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestDiff_HappyPath(t *testing.T) {
|
||||||
|
expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n"
|
||||||
|
var gotAccept string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotAccept = r.Header.Get("Accept")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(expectedDiff))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if diff != expectedDiff {
|
||||||
|
t.Errorf("unexpected diff: %q", diff)
|
||||||
|
}
|
||||||
|
if gotAccept != "application/vnd.github.diff" {
|
||||||
|
t.Errorf("expected diff Accept header, got %q", gotAccept)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestDiff_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestDiff_401(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestFiles_HappyPath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
json.NewEncoder(w).Encode([]map[string]interface{}{
|
||||||
|
{"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"},
|
||||||
|
{"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 2 {
|
||||||
|
t.Fatalf("expected 2 files, got %d", len(files))
|
||||||
|
}
|
||||||
|
if files[0].Filename != "main.go" {
|
||||||
|
t.Errorf("expected filename 'main.go', got %q", files[0].Filename)
|
||||||
|
}
|
||||||
|
if files[0].Status != "modified" {
|
||||||
|
t.Errorf("expected status 'modified', got %q", files[0].Status)
|
||||||
|
}
|
||||||
|
if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" {
|
||||||
|
t.Errorf("unexpected patch: %q", files[0].Patch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestFiles_Pagination(t *testing.T) {
|
||||||
|
// Simulate > 100 files requiring pagination
|
||||||
|
page1Files := make([]map[string]string, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
page1Files[i] = map[string]string{
|
||||||
|
"filename": fmt.Sprintf("file%d.go", i),
|
||||||
|
"status": "modified",
|
||||||
|
"patch": fmt.Sprintf("patch%d", i),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
page2Files := []map[string]string{
|
||||||
|
{"filename": "file100.go", "status": "added", "patch": "patch100"},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
page := r.URL.Query().Get("page")
|
||||||
|
if page == "" || page == "1" {
|
||||||
|
json.NewEncoder(w).Encode(page1Files)
|
||||||
|
} else {
|
||||||
|
json.NewEncoder(w).Encode(page2Files)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 101 {
|
||||||
|
t.Errorf("expected 101 files (paginated), got %d", len(files))
|
||||||
|
}
|
||||||
|
if files[100].Filename != "file100.go" {
|
||||||
|
t.Errorf("expected last file 'file100.go', got %q", files[100].Filename)
|
||||||
|
}
|
||||||
|
if files[100].Patch != "patch100" {
|
||||||
|
t.Errorf("expected last patch 'patch100', got %q", files[100].Patch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Binary files have no patch field in GitHub response
|
||||||
|
json.NewEncoder(w).Encode([]map[string]interface{}{
|
||||||
|
{"filename": "image.png", "status": "added"},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(files) != 1 {
|
||||||
|
t.Fatalf("expected 1 file, got %d", len(files))
|
||||||
|
}
|
||||||
|
if files[0].Patch != "" {
|
||||||
|
t.Errorf("expected empty patch for binary file, got %q", files[0].Patch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestFiles_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPullRequestFiles_MalformedJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`not json`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentAtRef_HappyPath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" {
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.URL.Query().Get("ref") != "abc123" {
|
||||||
|
t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref"))
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"content": "cGFja2FnZSBtYWlu", // "package main" in base64
|
||||||
|
"encoding": "base64",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != "package main" {
|
||||||
|
t.Errorf("expected 'package main', got %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentAtRef_EmptyRef(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("ref") != "" {
|
||||||
|
t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref"))
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"content": "aGVsbG8=", // "hello" in base64
|
||||||
|
"encoding": "base64",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != "hello" {
|
||||||
|
t.Errorf("expected 'hello', got %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentAtRef_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentAtRef_401(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentAtRef_MalformedJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`not valid json`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFileContentAtRef_429Retry(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.WriteHeader(429)
|
||||||
|
w.Write([]byte(`{"message":"rate limit"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"content": "b2s=", // "ok" in base64
|
||||||
|
"encoding": "base64",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})
|
||||||
|
|
||||||
|
content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if content != "ok" {
|
||||||
|
t.Errorf("expected 'ok', got %q", content)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommitStatuses_HappyPath(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(r.URL.Path, "/status"):
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"state": "success",
|
||||||
|
"statuses": []map[string]string{
|
||||||
|
{
|
||||||
|
"context": "ci/build",
|
||||||
|
"state": "success",
|
||||||
|
"description": "Build passed",
|
||||||
|
"target_url": "https://ci.example.com/1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case strings.Contains(r.URL.Path, "/check-runs"):
|
||||||
|
conclusion := "success"
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"total_count": 1,
|
||||||
|
"check_runs": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "lint",
|
||||||
|
"conclusion": &conclusion,
|
||||||
|
"status": "completed",
|
||||||
|
"html_url": "https://github.com/check/1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||||
|
w.WriteHeader(404)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(statuses) != 2 {
|
||||||
|
t.Fatalf("expected 2 statuses, got %d", len(statuses))
|
||||||
|
}
|
||||||
|
// First should be from commit statuses
|
||||||
|
if statuses[0].Context != "ci/build" {
|
||||||
|
t.Errorf("expected context 'ci/build', got %q", statuses[0].Context)
|
||||||
|
}
|
||||||
|
if statuses[0].Status != "success" {
|
||||||
|
t.Errorf("expected status 'success', got %q", statuses[0].Status)
|
||||||
|
}
|
||||||
|
// Second should be from check runs
|
||||||
|
if statuses[1].Context != "lint" {
|
||||||
|
t.Errorf("expected context 'lint', got %q", statuses[1].Context)
|
||||||
|
}
|
||||||
|
if statuses[1].Status != "success" {
|
||||||
|
t.Errorf("expected status 'success', got %q", statuses[1].Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
conclusion *string
|
||||||
|
status string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{stringPtr("success"), "completed", "success"},
|
||||||
|
{stringPtr("failure"), "completed", "failure"},
|
||||||
|
{stringPtr("action_required"), "completed", "failure"},
|
||||||
|
{stringPtr("timed_out"), "completed", "failure"},
|
||||||
|
{stringPtr("cancelled"), "completed", "success"},
|
||||||
|
{stringPtr("skipped"), "completed", "success"},
|
||||||
|
{stringPtr("neutral"), "completed", "success"},
|
||||||
|
{nil, "in_progress", "pending"},
|
||||||
|
{nil, "queued", "pending"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
name := "nil"
|
||||||
|
if tt.conclusion != nil {
|
||||||
|
name = *tt.conclusion
|
||||||
|
}
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/status") {
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"state": "success",
|
||||||
|
"statuses": []interface{}{},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"total_count": 1,
|
||||||
|
"check_runs": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"name": "check",
|
||||||
|
"conclusion": tt.conclusion,
|
||||||
|
"status": tt.status,
|
||||||
|
"html_url": "https://github.com/check/1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(statuses) != 1 {
|
||||||
|
t.Fatalf("expected 1 status, got %d", len(statuses))
|
||||||
|
}
|
||||||
|
if statuses[0].Status != tt.want {
|
||||||
|
t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommitStatuses_404(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte(`{"message":"Not Found"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 404")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommitStatuses_401(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(401)
|
||||||
|
w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(`not json`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := NewClient("token", srv.URL, AllowInsecureHTTP())
|
||||||
|
c.SetHTTPClient(srv.Client())
|
||||||
|
|
||||||
|
_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for malformed JSON")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringPtr(s string) *string {
|
||||||
|
return &s
|
||||||
|
}
|
||||||
@@ -2,4 +2,4 @@ module gitea.weiker.me/rodin/review-bot
|
|||||||
|
|
||||||
go 1.26.2
|
go 1.26.2
|
||||||
|
|
||||||
require github.com/goccy/go-yaml v1.19.2
|
require gopkg.in/yaml.v3 v3.0.1
|
||||||
|
|||||||
@@ -1,2 +1,4 @@
|
|||||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
+8
-9
@@ -16,17 +16,16 @@ import (
|
|||||||
|
|
||||||
// Integration test requires a running Gitea instance and LLM endpoint.
|
// Integration test requires a running Gitea instance and LLM endpoint.
|
||||||
// Set environment variables:
|
// Set environment variables:
|
||||||
//
|
// INTEGRATION_GITEA_URL - Gitea base URL
|
||||||
// INTEGRATION_VCS_URL - VCS base URL
|
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
||||||
// INTEGRATION_GITEA_TOKEN - Gitea API token with repo access
|
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
||||||
// INTEGRATION_GITEA_REPO - owner/repo with an open PR
|
// INTEGRATION_PR_NUMBER - PR number to test against
|
||||||
// INTEGRATION_PR_NUMBER - PR number to test against
|
// INTEGRATION_LLM_BASE_URL - LLM API base URL
|
||||||
// INTEGRATION_LLM_BASE_URL - LLM API base URL
|
// INTEGRATION_LLM_API_KEY - LLM API key
|
||||||
// INTEGRATION_LLM_API_KEY - LLM API key
|
// INTEGRATION_LLM_MODEL - Model name
|
||||||
// INTEGRATION_LLM_MODEL - Model name
|
|
||||||
|
|
||||||
func TestIntegration_FullReviewFlow(t *testing.T) {
|
func TestIntegration_FullReviewFlow(t *testing.T) {
|
||||||
giteaURL := os.Getenv("INTEGRATION_VCS_URL")
|
giteaURL := os.Getenv("INTEGRATION_GITEA_URL")
|
||||||
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
giteaToken := os.Getenv("INTEGRATION_GITEA_TOKEN")
|
||||||
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
giteaRepo := os.Getenv("INTEGRATION_GITEA_REPO")
|
||||||
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
prNumStr := os.Getenv("INTEGRATION_PR_NUMBER")
|
||||||
|
|||||||
+5
-5
@@ -207,11 +207,11 @@ func (c *Client) completeOpenAI(ctx context.Context, messages []Message) (string
|
|||||||
|
|
||||||
type anthropicRequest struct {
|
type anthropicRequest struct {
|
||||||
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens"`
|
MaxTokens int `json:"max_tokens"`
|
||||||
System string `json:"system,omitempty"`
|
System string `json:"system,omitempty"`
|
||||||
Messages []anthropicMsg `json:"messages"`
|
Messages []anthropicMsg `json:"messages"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type anthropicMsg struct {
|
type anthropicMsg struct {
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ func TestWithTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func TestComplete_Anthropic_Success(t *testing.T) {
|
func TestComplete_Anthropic_Success(t *testing.T) {
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path != "/messages" {
|
if r.URL.Path != "/messages" {
|
||||||
|
|||||||
@@ -1,303 +0,0 @@
|
|||||||
// Package review provides doc-map parsing and doc injection for path-scoped
|
|
||||||
// design document context in AI code reviews.
|
|
||||||
package review
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// DefaultDocMapMaxBytes is the default cap on total injected doc content.
|
|
||||||
DefaultDocMapMaxBytes = 100 * 1024 // 100 KB
|
|
||||||
)
|
|
||||||
|
|
||||||
// DocMapping maps a set of path glob patterns to governing doc files/directories.
|
|
||||||
type DocMapping struct {
|
|
||||||
Paths []string `yaml:"paths"` // glob patterns matched against changed PR files
|
|
||||||
Docs []string `yaml:"docs"` // doc file paths or directories in the reviewed repo
|
|
||||||
}
|
|
||||||
|
|
||||||
// DocMapConfig is the top-level structure of a doc-map YAML file.
|
|
||||||
type DocMapConfig struct {
|
|
||||||
Mappings []DocMapping `yaml:"mappings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DocMapOptions configures behavior for doc loading.
|
|
||||||
type DocMapOptions struct {
|
|
||||||
// MaxBytes caps the total size of injected doc content. Default: DefaultDocMapMaxBytes.
|
|
||||||
MaxBytes int
|
|
||||||
}
|
|
||||||
|
|
||||||
// DocFetcher reads file and directory content from a VCS repository.
|
|
||||||
// It is a subset of vcsClient, defined here to keep the review package free
|
|
||||||
// of cmd-level dependencies.
|
|
||||||
type DocFetcher interface {
|
|
||||||
// GetFileContent returns the content of a single file at default branch.
|
|
||||||
GetFileContent(ctx context.Context, owner, repo, path string) (string, error)
|
|
||||||
// GetAllFilesInPath returns all files (path → content) under a directory.
|
|
||||||
GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseDocMapConfig reads and parses a doc-map YAML file from a local path.
|
|
||||||
// Unknown top-level keys produce a warning but are not fatal.
|
|
||||||
func ParseDocMapConfig(localPath string) (*DocMapConfig, error) {
|
|
||||||
data, err := readFileBytes(localPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read doc-map file %q: %w", localPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg DocMapConfig
|
|
||||||
if err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()); err != nil {
|
|
||||||
// Re-parse without strict mode to log which keys are unknown.
|
|
||||||
var relaxed DocMapConfig
|
|
||||||
if err2 := yaml.Unmarshal(data, &relaxed); err2 != nil {
|
|
||||||
return nil, fmt.Errorf("parse doc-map YAML %q: %w", localPath, err)
|
|
||||||
}
|
|
||||||
slog.Warn("doc-map YAML contains unknown keys (ignored)", "file", localPath, "error", err)
|
|
||||||
cfg = relaxed
|
|
||||||
}
|
|
||||||
return &cfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MatchDocs returns deduplicated doc paths for the given changed file paths.
|
|
||||||
// A mapping matches if any of its path globs matches any of the changed files.
|
|
||||||
func MatchDocs(cfg *DocMapConfig, changedFiles []string) []string {
|
|
||||||
seen := make(map[string]struct{})
|
|
||||||
var result []string
|
|
||||||
|
|
||||||
for _, mapping := range cfg.Mappings {
|
|
||||||
if len(mapping.Paths) == 0 || len(mapping.Docs) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if mappingMatches(mapping.Paths, changedFiles) {
|
|
||||||
for _, doc := range mapping.Docs {
|
|
||||||
if doc == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok := seen[doc]; !ok {
|
|
||||||
seen[doc] = struct{}{}
|
|
||||||
result = append(result, doc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// mappingMatches returns true if any glob in patterns matches any file in files.
|
|
||||||
func mappingMatches(patterns, files []string) bool {
|
|
||||||
for _, pat := range patterns {
|
|
||||||
for _, f := range files {
|
|
||||||
if globMatch(pat, f) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// globMatch matches a path against a glob pattern that may contain **.
|
|
||||||
// It supports:
|
|
||||||
// - Standard path.Match patterns (*, ?, [range])
|
|
||||||
// - ** as a path segment that matches zero or more segments
|
|
||||||
// - Trailing /** to match a directory and all its contents
|
|
||||||
//
|
|
||||||
// The pattern and path use forward slash as separator.
|
|
||||||
func globMatch(pattern, path string) bool {
|
|
||||||
return globMatchParts(splitPath(pattern), splitPath(path))
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitPath splits a slash-separated path into non-empty parts.
|
|
||||||
func splitPath(p string) []string {
|
|
||||||
// Clean and split on "/"
|
|
||||||
parts := strings.Split(p, "/")
|
|
||||||
result := make([]string, 0, len(parts))
|
|
||||||
for _, part := range parts {
|
|
||||||
if part != "" {
|
|
||||||
result = append(result, part)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// globMatchParts recursively matches pattern parts against path parts.
|
|
||||||
func globMatchParts(patParts, pathParts []string) bool {
|
|
||||||
for len(patParts) > 0 {
|
|
||||||
pat := patParts[0]
|
|
||||||
if pat == "**" {
|
|
||||||
patParts = patParts[1:]
|
|
||||||
if len(patParts) == 0 {
|
|
||||||
// Trailing **: matches any remaining path (including empty).
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// ** in the middle: try matching the rest at every position.
|
|
||||||
for i := 0; i <= len(pathParts); i++ {
|
|
||||||
if globMatchParts(patParts, pathParts[i:]) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// Non-** segment: path must have a segment here.
|
|
||||||
if len(pathParts) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
matched, err := filepath.Match(pat, pathParts[0])
|
|
||||||
if err != nil || !matched {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
patParts = patParts[1:]
|
|
||||||
pathParts = pathParts[1:]
|
|
||||||
}
|
|
||||||
// All pattern parts consumed; path must also be consumed.
|
|
||||||
return len(pathParts) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadMatchingDocs fetches content for the given doc paths via VCS and returns
|
|
||||||
// a formatted string suitable for injection into the system prompt.
|
|
||||||
//
|
|
||||||
// Behavior:
|
|
||||||
// - Paths that look like directories (end with /, or GetAllFilesInPath returns files)
|
|
||||||
// are expanded to all .md files under them.
|
|
||||||
// - Missing files are logged as warnings and skipped.
|
|
||||||
// - Total content is capped at opts.MaxBytes; truncation is noted inline.
|
|
||||||
func LoadMatchingDocs(ctx context.Context, fetcher DocFetcher, owner, repo string, docPaths []string, opts DocMapOptions) (string, error) {
|
|
||||||
if opts.MaxBytes <= 0 {
|
|
||||||
opts.MaxBytes = DefaultDocMapMaxBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
totalBytes := 0
|
|
||||||
limitReached := false
|
|
||||||
|
|
||||||
for _, docPath := range docPaths {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if limitReached {
|
|
||||||
slog.Warn("doc-map: context size limit reached, skipping remaining docs",
|
|
||||||
"remaining_path", docPath, "limit_bytes", opts.MaxBytes)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
entries, err := loadDocEntries(ctx, fetcher, owner, repo, docPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("doc-map: could not load doc, skipping", "path", docPath, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if len(entries) == 0 {
|
|
||||||
slog.Debug("doc-map: no .md files found under path", "path", docPath)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, entry := range entries {
|
|
||||||
if limitReached {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
available := opts.MaxBytes - totalBytes
|
|
||||||
if available <= 0 {
|
|
||||||
limitReached = true
|
|
||||||
sb.WriteString("\n\n> ⚠️ Design document context truncated — size limit reached.\n")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
content := entry.content
|
|
||||||
truncated := false
|
|
||||||
if len(content) > available {
|
|
||||||
content = truncateUTF8(content, available)
|
|
||||||
truncated = true
|
|
||||||
limitReached = true
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString("### ")
|
|
||||||
sb.WriteString(entry.path)
|
|
||||||
sb.WriteString("\n\n")
|
|
||||||
sb.WriteString(content)
|
|
||||||
sb.WriteString("\n")
|
|
||||||
if truncated {
|
|
||||||
sb.WriteString("\n> ⚠️ (truncated — size limit reached)\n")
|
|
||||||
}
|
|
||||||
totalBytes += len(content)
|
|
||||||
slog.Debug("doc-map: injected doc", "path", entry.path, "bytes", len(content))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sb.Len() == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return sb.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// docEntry holds a single doc file path and content.
|
|
||||||
type docEntry struct {
|
|
||||||
path string
|
|
||||||
content string
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadDocEntries returns the doc content for a given path.
|
|
||||||
// If the path is a directory, all .md files under it are returned.
|
|
||||||
// If it's a file, a single entry is returned.
|
|
||||||
func loadDocEntries(ctx context.Context, fetcher DocFetcher, owner, repo, docPath string) ([]docEntry, error) {
|
|
||||||
// Try directory expansion first.
|
|
||||||
files, err := fetcher.GetAllFilesInPath(ctx, owner, repo, docPath)
|
|
||||||
if err == nil && len(files) > 0 {
|
|
||||||
// Filter for .md files only.
|
|
||||||
var entries []docEntry
|
|
||||||
for path, content := range files {
|
|
||||||
if isMDFile(path) {
|
|
||||||
entries = append(entries, docEntry{path: path, content: content})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Sort for deterministic output.
|
|
||||||
sortDocEntries(entries)
|
|
||||||
return entries, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try as a single file.
|
|
||||||
content, fileErr := fetcher.GetFileContent(ctx, owner, repo, docPath)
|
|
||||||
if fileErr != nil {
|
|
||||||
// Return the file error (more specific than directory error).
|
|
||||||
return nil, fmt.Errorf("fetch doc %q: %w", docPath, fileErr)
|
|
||||||
}
|
|
||||||
return []docEntry{{path: docPath, content: content}}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isMDFile returns true if the file has a .md extension.
|
|
||||||
func isMDFile(path string) bool {
|
|
||||||
return strings.HasSuffix(strings.ToLower(path), ".md")
|
|
||||||
}
|
|
||||||
|
|
||||||
// sortDocEntries sorts entries by path for deterministic output.
|
|
||||||
func sortDocEntries(entries []docEntry) {
|
|
||||||
// Simple insertion sort (doc lists are small).
|
|
||||||
for i := 1; i < len(entries); i++ {
|
|
||||||
for j := i; j > 0 && entries[j].path < entries[j-1].path; j-- {
|
|
||||||
entries[j], entries[j-1] = entries[j-1], entries[j]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// readFileBytes reads the contents of a local file.
|
|
||||||
func readFileBytes(path string) ([]byte, error) {
|
|
||||||
return os.ReadFile(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncateUTF8 truncates s to at most maxBytes without splitting multi-byte
|
|
||||||
// UTF-8 characters. Returns a valid UTF-8 string of at most maxBytes bytes.
|
|
||||||
func truncateUTF8(s string, maxBytes int) string {
|
|
||||||
if len(s) <= maxBytes {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
for maxBytes > 0 && !utf8.RuneStart(s[maxBytes]) {
|
|
||||||
maxBytes--
|
|
||||||
}
|
|
||||||
return s[:maxBytes]
|
|
||||||
}
|
|
||||||
@@ -1,394 +0,0 @@
|
|||||||
package review
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// fakeDocFetcher is a mock DocFetcher for tests.
|
|
||||||
type fakeDocFetcher struct {
|
|
||||||
files map[string]string // path -> content
|
|
||||||
dirs map[string]map[string]string // dir path -> (file path -> content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDocFetcher) GetFileContent(_ context.Context, _, _, path string) (string, error) {
|
|
||||||
if content, ok := f.files[path]; ok {
|
|
||||||
return content, nil
|
|
||||||
}
|
|
||||||
return "", errors.New("file not found: " + path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeDocFetcher) GetAllFilesInPath(_ context.Context, _, _, path string) (map[string]string, error) {
|
|
||||||
if files, ok := f.dirs[path]; ok {
|
|
||||||
return files, nil
|
|
||||||
}
|
|
||||||
// Return empty (not an error) for unknown directories.
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// ParseDocMapConfig
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func TestParseDocMapConfig_Valid(t *testing.T) {
|
|
||||||
yaml := `
|
|
||||||
mappings:
|
|
||||||
- paths:
|
|
||||||
- "lib/foo/**"
|
|
||||||
docs:
|
|
||||||
- docs/foo.md
|
|
||||||
- paths:
|
|
||||||
- "lib/bar/**"
|
|
||||||
- "lib/baz.go"
|
|
||||||
docs:
|
|
||||||
- docs/bar.md
|
|
||||||
- docs/shared/
|
|
||||||
`
|
|
||||||
f := writeTempYAML(t, yaml)
|
|
||||||
cfg, err := ParseDocMapConfig(f)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if len(cfg.Mappings) != 2 {
|
|
||||||
t.Fatalf("expected 2 mappings, got %d", len(cfg.Mappings))
|
|
||||||
}
|
|
||||||
if cfg.Mappings[0].Paths[0] != "lib/foo/**" {
|
|
||||||
t.Errorf("unexpected path: %q", cfg.Mappings[0].Paths[0])
|
|
||||||
}
|
|
||||||
if cfg.Mappings[1].Docs[1] != "docs/shared/" {
|
|
||||||
t.Errorf("unexpected doc: %q", cfg.Mappings[1].Docs[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseDocMapConfig_InvalidYAML(t *testing.T) {
|
|
||||||
f := writeTempYAML(t, "mappings: [{{invalid")
|
|
||||||
_, err := ParseDocMapConfig(f)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for invalid YAML, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseDocMapConfig_EmptyMappings(t *testing.T) {
|
|
||||||
f := writeTempYAML(t, "mappings: []\n")
|
|
||||||
cfg, err := ParseDocMapConfig(f)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if len(cfg.Mappings) != 0 {
|
|
||||||
t.Errorf("expected 0 mappings, got %d", len(cfg.Mappings))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseDocMapConfig_UnknownKeys(t *testing.T) {
|
|
||||||
// Unknown keys should produce a warning but not fail.
|
|
||||||
yaml := `
|
|
||||||
mappings:
|
|
||||||
- paths: ["lib/foo/**"]
|
|
||||||
docs: ["docs/foo.md"]
|
|
||||||
extra_key: ignored
|
|
||||||
`
|
|
||||||
f := writeTempYAML(t, yaml)
|
|
||||||
// Should succeed (lenient parsing).
|
|
||||||
cfg, err := ParseDocMapConfig(f)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for unknown keys: %v", err)
|
|
||||||
}
|
|
||||||
if len(cfg.Mappings) != 1 {
|
|
||||||
t.Errorf("expected 1 mapping, got %d", len(cfg.Mappings))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseDocMapConfig_FileNotFound(t *testing.T) {
|
|
||||||
_, err := ParseDocMapConfig("/nonexistent/path/doc-map.yml")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for missing file, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// MatchDocs
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func TestMatchDocs_NoMatch(t *testing.T) {
|
|
||||||
cfg := &DocMapConfig{
|
|
||||||
Mappings: []DocMapping{
|
|
||||||
{Paths: []string{"lib/foo/**"}, Docs: []string{"docs/foo.md"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := MatchDocs(cfg, []string{"lib/bar/baz.go"})
|
|
||||||
if len(got) != 0 {
|
|
||||||
t.Errorf("expected no matches, got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchDocs_SingleMatch(t *testing.T) {
|
|
||||||
cfg := &DocMapConfig{
|
|
||||||
Mappings: []DocMapping{
|
|
||||||
{Paths: []string{"lib/foo/**"}, Docs: []string{"docs/foo.md"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := MatchDocs(cfg, []string{"lib/foo/bar.go"})
|
|
||||||
if len(got) != 1 || got[0] != "docs/foo.md" {
|
|
||||||
t.Errorf("expected [docs/foo.md], got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchDocs_MultipleMatchesDeduplicated(t *testing.T) {
|
|
||||||
cfg := &DocMapConfig{
|
|
||||||
Mappings: []DocMapping{
|
|
||||||
{Paths: []string{"lib/foo/**"}, Docs: []string{"docs/shared.md", "docs/foo.md"}},
|
|
||||||
{Paths: []string{"lib/bar/**"}, Docs: []string{"docs/shared.md", "docs/bar.md"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := MatchDocs(cfg, []string{"lib/foo/a.go", "lib/bar/b.go"})
|
|
||||||
// Both match; docs/shared.md should appear only once.
|
|
||||||
wantSet := map[string]bool{
|
|
||||||
"docs/shared.md": true,
|
|
||||||
"docs/foo.md": true,
|
|
||||||
"docs/bar.md": true,
|
|
||||||
}
|
|
||||||
if len(got) != 3 {
|
|
||||||
t.Errorf("expected 3 docs, got %d: %v", len(got), got)
|
|
||||||
}
|
|
||||||
for _, d := range got {
|
|
||||||
if !wantSet[d] {
|
|
||||||
t.Errorf("unexpected doc: %q", d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchDocs_EmptyPaths(t *testing.T) {
|
|
||||||
// Mapping with empty paths list should not match anything.
|
|
||||||
cfg := &DocMapConfig{
|
|
||||||
Mappings: []DocMapping{
|
|
||||||
{Paths: []string{}, Docs: []string{"docs/foo.md"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := MatchDocs(cfg, []string{"lib/foo/bar.go"})
|
|
||||||
if len(got) != 0 {
|
|
||||||
t.Errorf("expected no matches for empty paths, got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchDocs_EmptyDocs(t *testing.T) {
|
|
||||||
// Mapping with empty docs list should produce nothing.
|
|
||||||
cfg := &DocMapConfig{
|
|
||||||
Mappings: []DocMapping{
|
|
||||||
{Paths: []string{"lib/foo/**"}, Docs: []string{}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := MatchDocs(cfg, []string{"lib/foo/bar.go"})
|
|
||||||
if len(got) != 0 {
|
|
||||||
t.Errorf("expected no docs for empty docs list, got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMatchDocs_ExactMatch(t *testing.T) {
|
|
||||||
cfg := &DocMapConfig{
|
|
||||||
Mappings: []DocMapping{
|
|
||||||
{Paths: []string{"lib/baz.go"}, Docs: []string{"docs/baz.md"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got := MatchDocs(cfg, []string{"lib/baz.go"})
|
|
||||||
if len(got) != 1 || got[0] != "docs/baz.md" {
|
|
||||||
t.Errorf("expected [docs/baz.md], got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// globMatch
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func TestGlobMatch(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
pattern string
|
|
||||||
path string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"exact match", "lib/foo/bar.go", "lib/foo/bar.go", true},
|
|
||||||
{"exact no match", "lib/foo/bar.go", "lib/foo/baz.go", false},
|
|
||||||
{"star wildcard", "lib/foo/*.go", "lib/foo/bar.go", true},
|
|
||||||
{"star no match cross-dir", "lib/foo/*.go", "lib/foo/sub/bar.go", false},
|
|
||||||
{"trailing doublestar", "lib/foo/**", "lib/foo/bar.go", true},
|
|
||||||
{"trailing doublestar nested", "lib/foo/**", "lib/foo/sub/deep/bar.go", true},
|
|
||||||
// Note: trailing ** matches the parent path too; PR file lists contain file paths
|
|
||||||
// (not directories), so this corner case does not arise in practice.
|
|
||||||
{"trailing doublestar matches parent", "lib/foo/**", "lib/foo", true},
|
|
||||||
{"doublestar in middle", "lib/**/bar.go", "lib/foo/sub/bar.go", true},
|
|
||||||
{"doublestar in middle no match", "lib/**/bar.go", "lib/foo/sub/baz.go", false},
|
|
||||||
{"leading doublestar", "**/bar.go", "lib/foo/bar.go", true},
|
|
||||||
{"leading doublestar top-level", "**/bar.go", "bar.go", true},
|
|
||||||
{"question mark", "lib/foo/ba?.go", "lib/foo/bar.go", true},
|
|
||||||
{"question mark no match", "lib/foo/ba?.go", "lib/foo/ba.go", false},
|
|
||||||
{"star matches none in segment", "lib/*/bar.go", "lib/bar.go", false},
|
|
||||||
{"star single segment", "lib/*/bar.go", "lib/foo/bar.go", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got := globMatch(tc.pattern, tc.path)
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("globMatch(%q, %q) = %v, want %v", tc.pattern, tc.path, got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// LoadMatchingDocs
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_FileInjection(t *testing.T) {
|
|
||||||
fetcher := &fakeDocFetcher{
|
|
||||||
files: map[string]string{
|
|
||||||
"docs/foo.md": "# Foo Design\n\nThis is the foo doc.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{"docs/foo.md"}, DocMapOptions{MaxBytes: DefaultDocMapMaxBytes})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "# Foo Design") {
|
|
||||||
t.Errorf("expected doc content, got: %q", content)
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "### docs/foo.md") {
|
|
||||||
t.Errorf("expected heading with path, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_MissingFileSkipped(t *testing.T) {
|
|
||||||
fetcher := &fakeDocFetcher{
|
|
||||||
files: map[string]string{
|
|
||||||
"docs/present.md": "present",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{"docs/missing.md", "docs/present.md"}, DocMapOptions{MaxBytes: DefaultDocMapMaxBytes})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "present") {
|
|
||||||
t.Errorf("expected present doc content, got: %q", content)
|
|
||||||
}
|
|
||||||
// Missing file should be skipped, not cause a failure.
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_DirectoryExpansion(t *testing.T) {
|
|
||||||
fetcher := &fakeDocFetcher{
|
|
||||||
dirs: map[string]map[string]string{
|
|
||||||
"docs/domain/": {
|
|
||||||
"docs/domain/a.md": "# A",
|
|
||||||
"docs/domain/b.md": "# B",
|
|
||||||
"docs/domain/c.go": "package domain", // should be skipped (not .md)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{"docs/domain/"}, DocMapOptions{MaxBytes: DefaultDocMapMaxBytes})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "# A") {
|
|
||||||
t.Errorf("expected doc A content, got: %q", content)
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "# B") {
|
|
||||||
t.Errorf("expected doc B content, got: %q", content)
|
|
||||||
}
|
|
||||||
if strings.Contains(content, "package domain") {
|
|
||||||
t.Errorf("non-.md file should not be injected, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_DirectoryNoMDFiles(t *testing.T) {
|
|
||||||
fetcher := &fakeDocFetcher{
|
|
||||||
dirs: map[string]map[string]string{
|
|
||||||
"src/": {
|
|
||||||
"src/main.go": "package main",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{"src/"}, DocMapOptions{MaxBytes: DefaultDocMapMaxBytes})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
t.Errorf("expected empty content for dir with no .md files, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_NoMatchingPaths(t *testing.T) {
|
|
||||||
fetcher := &fakeDocFetcher{}
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{}, DocMapOptions{MaxBytes: DefaultDocMapMaxBytes})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if content != "" {
|
|
||||||
t.Errorf("expected empty content for no paths, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_ContextSizeGuard(t *testing.T) {
|
|
||||||
bigContent := strings.Repeat("x", 200)
|
|
||||||
fetcher := &fakeDocFetcher{
|
|
||||||
files: map[string]string{
|
|
||||||
"docs/a.md": bigContent,
|
|
||||||
"docs/b.md": bigContent,
|
|
||||||
"docs/c.md": bigContent,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// Limit to 350 bytes — enough for a.md fully and part of b.md.
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{"docs/a.md", "docs/b.md", "docs/c.md"}, DocMapOptions{MaxBytes: 350})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if len(content) > 600 {
|
|
||||||
t.Errorf("content too large, expected ≤600 bytes total, got %d", len(content))
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "truncated") {
|
|
||||||
t.Errorf("expected truncation notice, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadMatchingDocs_Deduplication(t *testing.T) {
|
|
||||||
fetcher := &fakeDocFetcher{
|
|
||||||
files: map[string]string{
|
|
||||||
"docs/shared.md": "shared content",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// MatchDocs deduplicates before calling LoadMatchingDocs, but test it with
|
|
||||||
// duplicates in input too.
|
|
||||||
content, err := LoadMatchingDocs(context.Background(), fetcher, "owner", "repo",
|
|
||||||
[]string{"docs/shared.md"}, DocMapOptions{MaxBytes: DefaultDocMapMaxBytes})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if !strings.Contains(content, "shared content") {
|
|
||||||
t.Errorf("expected shared content, got: %q", content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Helpers
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
func writeTempYAML(t *testing.T, content string) string {
|
|
||||||
t.Helper()
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), "doc-map-*.yml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create temp file: %v", err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
if _, err := f.WriteString(content); err != nil {
|
|
||||||
t.Fatalf("failed to write temp file: %v", err)
|
|
||||||
}
|
|
||||||
return filepath.Clean(f.Name())
|
|
||||||
}
|
|
||||||
+38
-146
@@ -5,15 +5,12 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"gopkg.in/yaml.v3"
|
||||||
"github.com/goccy/go-yaml/ast"
|
|
||||||
"github.com/goccy/go-yaml/parser"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed personas/*.yaml
|
//go:embed personas/*.yaml
|
||||||
@@ -121,7 +118,9 @@ func ListBuiltinPersonas() []string {
|
|||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
seen[personaName] = true
|
if !seen[personaName] {
|
||||||
|
seen[personaName] = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
names := make([]string, 0, len(seen))
|
names := make([]string, 0, len(seen))
|
||||||
for name := range seen {
|
for name := range seen {
|
||||||
@@ -143,19 +142,10 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth)
|
err = unmarshalYAMLWithDepthLimit(data, &p, MaxYAMLDepth)
|
||||||
} else {
|
} else {
|
||||||
// Use json.Decoder with DisallowUnknownFields for consistency with
|
// Use json.Decoder with DisallowUnknownFields for consistency with
|
||||||
// YAML's Strict() - both reject unknown fields to catch typos.
|
// YAML's KnownFields(true) - both reject unknown fields to catch typos.
|
||||||
dec := json.NewDecoder(bytes.NewReader(data))
|
dec := json.NewDecoder(bytes.NewReader(data))
|
||||||
dec.DisallowUnknownFields()
|
dec.DisallowUnknownFields()
|
||||||
err = dec.Decode(&p)
|
err = dec.Decode(&p)
|
||||||
if err == nil {
|
|
||||||
// Reject trailing content after the first valid JSON object.
|
|
||||||
// Without this check, input like `{"name":"x"}garbage` would
|
|
||||||
// silently succeed because Decoder stops after one object.
|
|
||||||
var dummy json.RawMessage
|
|
||||||
if err2 := dec.Decode(&dummy); err2 != io.EOF {
|
|
||||||
err = fmt.Errorf("unexpected trailing content after JSON object")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
return nil, fmt.Errorf("parse persona %s: %w", source, err)
|
||||||
@@ -166,164 +156,70 @@ func parsePersona(data []byte, source string) (*Persona, error) {
|
|||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalYAMLWithDepthLimit unmarshals YAML data with three safety checks:
|
// unmarshalYAMLWithDepthLimit unmarshals YAML data with explicit depth limiting
|
||||||
// - Depth limiting: rejects AST trees exceeding maxDepth to prevent stack exhaustion.
|
// and strict field checking. This protects against stack exhaustion from deeply
|
||||||
// - Multi-document rejection: prevents silent data loss from ignored extra documents.
|
// nested structures and catches typos in field names.
|
||||||
// - Strict field checking: rejects unknown YAML keys to catch typos early.
|
// Multi-document YAML files are rejected to prevent silent data loss.
|
||||||
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
func unmarshalYAMLWithDepthLimit(data []byte, out any, maxDepth int) error {
|
||||||
// First pass: parse into AST to check depth limits, node counts, and
|
// First pass: decode into a yaml.Node to check depth limits and node counts.
|
||||||
// multi-document rejection. This prevents stack exhaustion before we
|
// This prevents stack exhaustion before we attempt to decode into structs.
|
||||||
// attempt to decode into structs.
|
var node yaml.Node
|
||||||
file, err := parser.ParseBytes(data, 0)
|
dec := yaml.NewDecoder(bytes.NewReader(data))
|
||||||
if err != nil {
|
if err := dec.Decode(&node); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject empty YAML input (whitespace-only, comment-only, or truly empty files).
|
|
||||||
// The parser returns a single doc with nil body for these cases.
|
|
||||||
if len(file.Docs) == 0 || file.Docs[0].Body == nil {
|
|
||||||
return fmt.Errorf("empty YAML document")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reject multi-document YAML files - silently ignoring additional documents
|
// Reject multi-document YAML files - silently ignoring additional documents
|
||||||
// could lead to confusing behavior where users think their changes take effect.
|
// could lead to confusing behavior where users think their changes take effect.
|
||||||
if len(file.Docs) > 1 {
|
var extra yaml.Node
|
||||||
|
if dec.Decode(&extra) == nil {
|
||||||
return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed")
|
return fmt.Errorf("multi-document YAML is not supported; only single-document files are allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeCount := 0
|
nodeCount := 0
|
||||||
if err := checkYAMLDepth(file.Docs[0].Body, 0, maxDepth, MaxYAMLNodes, make(map[ast.Node]int), make(map[ast.Node]bool), &nodeCount); err != nil {
|
if err := checkYAMLDepth(&node, 0, maxDepth, MaxYAMLNodes, make(map[*yaml.Node]struct{}), &nodeCount); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second pass: decode with strict field checking enabled.
|
// Second pass: decode with strict field checking enabled.
|
||||||
// Strict() rejects unknown keys, catching typos like "focuss" or "identiy".
|
// KnownFields(true) rejects unknown keys, catching typos like "focuss" or "identiy".
|
||||||
//
|
// We must re-decode from the original data because yaml.Node.Decode() doesn't
|
||||||
// Safety note: goccy/go-yaml's decoder does not expand YAML aliases
|
// support the KnownFields option.
|
||||||
// recursively — it resolves them via the pre-built AST, which our first
|
strictDec := yaml.NewDecoder(bytes.NewReader(data))
|
||||||
// pass already depth-checked. Alias chains that would exceed depth limits
|
strictDec.KnownFields(true)
|
||||||
// are caught above; the decoder merely reads the resolved scalar values.
|
return strictDec.Decode(out)
|
||||||
dec := yaml.NewDecoder(bytes.NewReader(data), yaml.Strict())
|
|
||||||
return dec.Decode(out)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkYAMLDepth recursively checks that YAML AST nodes don't exceed the depth
|
// checkYAMLDepth recursively checks that YAML nodes don't exceed the depth limit
|
||||||
// limit or the total node count limit. It uses two tracking maps:
|
// or the total node count limit. It also detects alias cycles to prevent infinite
|
||||||
// - validated: maps each node to the maximum depth at which it was previously
|
// recursion from crafted YAML with self-referential aliases.
|
||||||
// checked. If a node is revisited at a deeper depth (e.g., via an alias),
|
func checkYAMLDepth(node *yaml.Node, depth, maxDepth, maxNodes int, seen map[*yaml.Node]struct{}, nodeCount *int) error {
|
||||||
// we re-check it to ensure the combined effective depth doesn't exceed limits.
|
|
||||||
// - visiting: per-path recursion stack for true cycle detection. A node on the
|
|
||||||
// current path is a cycle (alias loop); we return nil to avoid infinite recursion.
|
|
||||||
//
|
|
||||||
// This design prevents the alias depth bypass where an anchored subtree validated
|
|
||||||
// at a shallow depth could be referenced via alias at a greater depth, effectively
|
|
||||||
// exceeding MaxYAMLDepth.
|
|
||||||
func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[ast.Node]int, visiting map[ast.Node]bool, nodeCount *int) error {
|
|
||||||
if node == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if depth > maxDepth {
|
if depth > maxDepth {
|
||||||
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
return fmt.Errorf("YAML nesting depth exceeds maximum (%d)", maxDepth)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cycle detection: if we're currently visiting this node on the current
|
|
||||||
// recursion path, it's a cycle (e.g., alias pointing to an ancestor).
|
|
||||||
// Return nil to break the cycle without error — cycles are a structural
|
|
||||||
// property, not a depth violation.
|
|
||||||
if visiting[node] {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
// Track total nodes visited as defense-in-depth against wide-but-shallow attacks.
|
||||||
// Placed after cycle detection but before the depth-aware short-circuit. This means
|
|
||||||
// nodes revisited at shallower depths (via aliases) are counted each time they are
|
|
||||||
// encountered — intentional conservative overcounting. This bounds the total work
|
|
||||||
// performed during validation rather than tracking unique nodes, which is the safer
|
|
||||||
// security posture for untrusted YAML input.
|
|
||||||
*nodeCount++
|
*nodeCount++
|
||||||
if *nodeCount > maxNodes {
|
if *nodeCount > maxNodes {
|
||||||
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
return fmt.Errorf("YAML node count exceeds maximum (%d)", maxNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Depth-aware short-circuit: skip re-validation only when the current visit
|
// Cycle detection: if we've seen this node before, we're in a cycle.
|
||||||
// depth is the same or shallower than the depth at which this node was
|
if _, ok := seen[node]; ok {
|
||||||
// previously validated. A shallower (or equal) current depth means the
|
return nil // Already validated this subtree, skip to avoid infinite recursion.
|
||||||
// prior, deeper validation already covered any subtree depth violations.
|
|
||||||
// If the current depth exceeds the previous validation depth (e.g., an alias
|
|
||||||
// references this node deeper in the tree), we must re-traverse to ensure
|
|
||||||
// the combined effective depth doesn't exceed maxDepth.
|
|
||||||
//
|
|
||||||
// Note: using ast.Node (interface) as map key relies on pointer identity,
|
|
||||||
// which is correct because all goccy/go-yaml AST node types are pointer
|
|
||||||
// receivers (*MappingNode, *SequenceNode, etc.), never value types.
|
|
||||||
if prevDepth, ok := validated[node]; ok && depth <= prevDepth {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
validated[node] = depth
|
seen[node] = struct{}{}
|
||||||
|
|
||||||
// Mark as visiting (on the current recursion path) for cycle detection.
|
// Handle alias nodes: follow the alias to its anchor target.
|
||||||
visiting[node] = true
|
// Increment depth when following aliases since they expand the effective structure.
|
||||||
defer func() { visiting[node] = false }()
|
if node.Kind == yaml.AliasNode && node.Alias != nil {
|
||||||
|
return checkYAMLDepth(node.Alias, depth+1, maxDepth, maxNodes, seen, nodeCount)
|
||||||
|
}
|
||||||
|
|
||||||
// Walk children based on node type.
|
for _, child := range node.Content {
|
||||||
switch n := node.(type) {
|
if err := checkYAMLDepth(child, depth+1, maxDepth, maxNodes, seen, nodeCount); err != nil {
|
||||||
case *ast.MappingNode:
|
|
||||||
for _, value := range n.Values {
|
|
||||||
if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case *ast.MappingValueNode:
|
|
||||||
// Both Key and Value are visited at depth+1 relative to this
|
|
||||||
// MappingValueNode. Since MappingNode visits its MappingValueNode
|
|
||||||
// children at depth+1 as well, keys and values end up at depth+2
|
|
||||||
// from the parent MappingNode. This is intentional: it mirrors the
|
|
||||||
// actual nesting structure (mapping → key-value pair → key/value).
|
|
||||||
if err := checkYAMLDepth(n.Key, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case *ast.SequenceNode:
|
|
||||||
for _, value := range n.Values {
|
|
||||||
if err := checkYAMLDepth(value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case *ast.AliasNode:
|
|
||||||
// Follow alias to its target, incrementing depth since aliases expand
|
|
||||||
// the effective structure.
|
|
||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case *ast.AnchorNode:
|
|
||||||
// Increment depth for anchor values as a conservative measure: the
|
|
||||||
// anchor definition itself is structural, and treating it as a depth
|
|
||||||
// level ensures that deeply nested anchors are caught at definition
|
|
||||||
// time rather than only when referenced via alias. This +1 is
|
|
||||||
// asymmetric with alias (which also increments) — by design, the
|
|
||||||
// effective depth budget for anchored-then-aliased content is reduced
|
|
||||||
// because both the definition site and the reference site each consume
|
|
||||||
// a level, making deeply nested anchor/alias pairs hit the limit sooner.
|
|
||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case *ast.TagNode:
|
|
||||||
if err := checkYAMLDepth(n.Value, depth+1, maxDepth, maxNodes, validated, visiting, nodeCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case *ast.MergeKeyNode:
|
|
||||||
// MergeKeyNode represents the literal "<<" merge key token. It has no
|
|
||||||
// child nodes — the value side of a merge (e.g., *alias) lives in the
|
|
||||||
// parent MappingValueNode.Value, which is already recursed into above.
|
|
||||||
// Explicitly listed here (rather than in the default case) to prevent
|
|
||||||
// future library changes from silently bypassing depth checks.
|
|
||||||
default:
|
|
||||||
// Scalar leaf nodes (StringNode, IntegerNode, FloatNode, BoolNode,
|
|
||||||
// NullNode, InfinityNode, NanNode, LiteralNode) have no children to
|
|
||||||
// recurse into.
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -331,11 +227,7 @@ func checkYAMLDepth(node ast.Node, depth, maxDepth, maxNodes int, validated map[
|
|||||||
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
// ParsePersonaBytes parses persona data from bytes with a source label for errors.
|
||||||
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
// This is useful for parsing personas fetched from external sources (e.g., Gitea API)
|
||||||
// without requiring filesystem access. Format is detected by source extension.
|
// without requiring filesystem access. Format is detected by source extension.
|
||||||
// Input is bounded by MaxPersonaFileSize to prevent resource exhaustion.
|
|
||||||
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
func ParsePersonaBytes(data []byte, source string) (*Persona, error) {
|
||||||
if len(data) > MaxPersonaFileSize {
|
|
||||||
return nil, fmt.Errorf("persona data from %s exceeds maximum size (%d bytes, limit %d)", source, len(data), MaxPersonaFileSize)
|
|
||||||
}
|
|
||||||
return parsePersona(data, source)
|
return parsePersona(data, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+42
-271
@@ -7,7 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml/ast"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoadBuiltinPersona(t *testing.T) {
|
func TestLoadBuiltinPersona(t *testing.T) {
|
||||||
@@ -355,7 +355,7 @@ func TestCapitalizeFirst(t *testing.T) {
|
|||||||
{"HELLO", "HELLO"},
|
{"HELLO", "HELLO"},
|
||||||
{"a", "A"},
|
{"a", "A"},
|
||||||
{"", ""},
|
{"", ""},
|
||||||
{"日本語", "日本語"}, // Non-ASCII: Japanese doesn't have case
|
{"日本語", "日本語"}, // Non-ASCII: Japanese doesn't have case
|
||||||
{"über", "Über"}, // German umlaut
|
{"über", "Über"}, // German umlaut
|
||||||
{"élève", "Élève"}, // French accent
|
{"élève", "Élève"}, // French accent
|
||||||
}
|
}
|
||||||
@@ -459,14 +459,7 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
|||||||
path := filepath.Join(dir, "deeply-nested.yaml")
|
path := filepath.Join(dir, "deeply-nested.yaml")
|
||||||
|
|
||||||
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
// Build a deeply nested YAML structure that exceeds MaxYAMLDepth (20).
|
||||||
// Depth accumulation trace for "nested: \n level0: \n level1: ...":
|
// Each level adds 2 to the depth count (key + value mapping).
|
||||||
// - Document root parsed at depth 0
|
|
||||||
// - Root MappingNode children (MappingValueNodes) visited at depth 1
|
|
||||||
// - "nested" MappingValueNode: key at depth 2, value at depth 2
|
|
||||||
// - Each levelN adds depth via MappingValueNode traversal (key + value)
|
|
||||||
// - Exact depth per level depends on AST structure (MappingNode wrapping),
|
|
||||||
// but 25 levels reliably exceeds MaxYAMLDepth (20) with comfortable margin.
|
|
||||||
// The test uses 25 levels rather than exactly 21 to avoid brittleness.
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
sb.WriteString("name: test\nidentity: test\nnested:\n")
|
||||||
indent := " "
|
indent := " "
|
||||||
@@ -490,35 +483,6 @@ func TestYAMLDeeplyNestedRejection(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestYAMLEmptyFileRejection(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
content string
|
|
||||||
}{
|
|
||||||
{"completely_empty", ""},
|
|
||||||
{"whitespace_only", " \n\n "},
|
|
||||||
{"comment_only", "# just a comment\n"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, tc.name+".yaml")
|
|
||||||
if err := os.WriteFile(path, []byte(tc.content), 0644); err != nil {
|
|
||||||
t.Fatalf("failed to write test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := LoadPersona(path)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for empty YAML input, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "empty YAML document") {
|
|
||||||
t.Errorf("expected error containing %q, got: %v", "empty YAML document", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestYAMLFileSizeLimit(t *testing.T) {
|
func TestYAMLFileSizeLimit(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
path := filepath.Join(dir, "huge.yaml")
|
path := filepath.Join(dir, "huge.yaml")
|
||||||
@@ -540,41 +504,41 @@ func TestYAMLFileSizeLimit(t *testing.T) {
|
|||||||
|
|
||||||
func TestYAMLAliasCycleDetection(t *testing.T) {
|
func TestYAMLAliasCycleDetection(t *testing.T) {
|
||||||
// Test that our checkYAMLDepth function handles alias cycles gracefully
|
// Test that our checkYAMLDepth function handles alias cycles gracefully
|
||||||
// by using the visiting map to prevent infinite recursion.
|
// by using the seen map to prevent infinite recursion.
|
||||||
|
// We test this directly because go-yaml's parser handles most cycles
|
||||||
|
// at parse time, but we need to ensure our checker is robust.
|
||||||
|
|
||||||
// Create a node structure where an alias points to a parent node,
|
// Create a node structure where an alias points to a parent node,
|
||||||
// simulating what could happen with crafted input.
|
// simulating what could happen with malicious input that bypasses
|
||||||
parent := &ast.MappingNode{
|
// go-yaml's cycle detection.
|
||||||
Values: []*ast.MappingValueNode{
|
parent := &yaml.Node{
|
||||||
{
|
Kind: yaml.MappingNode,
|
||||||
Key: &ast.StringNode{Value: "name"},
|
Content: []*yaml.Node{
|
||||||
Value: &ast.StringNode{Value: "test"},
|
{Kind: yaml.ScalarNode, Value: "name"},
|
||||||
},
|
{Kind: yaml.ScalarNode, Value: "test"},
|
||||||
|
{Kind: yaml.ScalarNode, Value: "nested"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a child that aliases back to the parent (artificial cycle)
|
// Create a child that aliases back to the parent (artificial cycle)
|
||||||
aliasToParent := &ast.AliasNode{
|
aliasToParent := &yaml.Node{
|
||||||
Value: parent,
|
Kind: yaml.AliasNode,
|
||||||
|
Alias: parent,
|
||||||
}
|
}
|
||||||
parent.Values = append(parent.Values, &ast.MappingValueNode{
|
parent.Content = append(parent.Content, aliasToParent)
|
||||||
Key: &ast.StringNode{Value: "nested"},
|
|
||||||
Value: aliasToParent,
|
|
||||||
})
|
|
||||||
|
|
||||||
nodeCount := 0
|
nodeCount := 0
|
||||||
validated := make(map[ast.Node]int)
|
seen := make(map[*yaml.Node]struct{})
|
||||||
visiting := make(map[ast.Node]bool)
|
|
||||||
|
|
||||||
// This should NOT hang or stack overflow - cycle detection prevents infinite recursion
|
// This should NOT hang or stack overflow - the seen map prevents infinite recursion
|
||||||
err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount)
|
err := checkYAMLDepth(parent, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error traversing cyclic structure: %v", err)
|
t.Errorf("unexpected error traversing cyclic structure: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify we tracked the parent in the validated map
|
// Verify we tracked the parent in the seen map
|
||||||
if _, ok := validated[parent]; !ok {
|
if _, ok := seen[parent]; !ok {
|
||||||
t.Error("parent node not tracked in validated map")
|
t.Error("parent node not tracked in seen map")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -630,82 +594,36 @@ func TestYAMLNodeCountLimit(t *testing.T) {
|
|||||||
func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) {
|
func TestCheckYAMLDepthCycleDetectionDirect(t *testing.T) {
|
||||||
// Direct test of cycle detection in checkYAMLDepth by creating
|
// Direct test of cycle detection in checkYAMLDepth by creating
|
||||||
// a node structure with an artificial cycle.
|
// a node structure with an artificial cycle.
|
||||||
node := &ast.MappingNode{
|
// This tests the seen map logic independent of go-yaml's parsing.
|
||||||
Values: []*ast.MappingValueNode{
|
node := &yaml.Node{
|
||||||
{
|
Kind: yaml.MappingNode,
|
||||||
Key: &ast.StringNode{Value: "key"},
|
Content: []*yaml.Node{
|
||||||
Value: &ast.StringNode{Value: "value"},
|
{Kind: yaml.ScalarNode, Value: "key"},
|
||||||
},
|
{Kind: yaml.ScalarNode, Value: "value"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a cycle by making a child reference the parent
|
// Create a cycle by making a child reference the parent
|
||||||
cycleChild := &ast.AliasNode{
|
cycleChild := &yaml.Node{
|
||||||
Value: node, // Points back to the parent
|
Kind: yaml.AliasNode,
|
||||||
|
Alias: node, // Points back to the parent
|
||||||
}
|
}
|
||||||
node.Values = append(node.Values, &ast.MappingValueNode{
|
node.Content = append(node.Content,
|
||||||
Key: &ast.StringNode{Value: "cyclic"},
|
&yaml.Node{Kind: yaml.ScalarNode, Value: "cyclic"},
|
||||||
Value: cycleChild,
|
cycleChild,
|
||||||
})
|
)
|
||||||
|
|
||||||
nodeCount := 0
|
nodeCount := 0
|
||||||
validated := make(map[ast.Node]int)
|
seen := make(map[*yaml.Node]struct{})
|
||||||
visiting := make(map[ast.Node]bool)
|
err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, seen, &nodeCount)
|
||||||
err := checkYAMLDepth(node, 0, MaxYAMLDepth, MaxYAMLNodes, validated, visiting, &nodeCount)
|
|
||||||
|
|
||||||
// Should complete without infinite recursion due to cycle detection
|
// Should complete without infinite recursion due to cycle detection
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
// The validated map should contain multiple entries
|
// The seen map should contain multiple entries
|
||||||
if len(validated) < 2 {
|
if len(seen) < 2 {
|
||||||
t.Errorf("validated map has %d entries, expected at least 2", len(validated))
|
t.Errorf("seen map has %d entries, expected at least 2", len(seen))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestYAMLAliasDepthBypass(t *testing.T) {
|
|
||||||
// Test that an anchored subtree first validated at a shallow depth is
|
|
||||||
// re-checked when referenced via alias at a deeper position. Without the
|
|
||||||
// depth-aware validated map, the alias reference would skip re-checking
|
|
||||||
// and allow the effective nesting to exceed MaxYAMLDepth.
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, "alias-depth-bypass.yaml")
|
|
||||||
|
|
||||||
// Build YAML with an anchor at shallow depth containing a subtree near the limit,
|
|
||||||
// then reference it via alias deep enough that effective depth exceeds MaxYAMLDepth.
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString("name: test\nidentity: test\n")
|
|
||||||
|
|
||||||
// Create the anchored subtree at depth 1 (key level) that nests 15 levels deep.
|
|
||||||
sb.WriteString("anchor_key: &deep_anchor\n")
|
|
||||||
for i := 0; i < 15; i++ {
|
|
||||||
sb.WriteString(strings.Repeat(" ", i+1))
|
|
||||||
sb.WriteString(fmt.Sprintf("level%d:\n", i))
|
|
||||||
}
|
|
||||||
sb.WriteString(strings.Repeat(" ", 16))
|
|
||||||
sb.WriteString("leaf: value\n")
|
|
||||||
|
|
||||||
// Create a wrapper that nests 6 levels deep, then references the anchor.
|
|
||||||
// Effective depth at alias target = 6 (wrapper nesting) + 1 (alias) + 15 (subtree) = 22 > 20
|
|
||||||
sb.WriteString("wrapper:\n")
|
|
||||||
for i := 0; i < 6; i++ {
|
|
||||||
sb.WriteString(strings.Repeat(" ", i+1))
|
|
||||||
sb.WriteString(fmt.Sprintf("n%d:\n", i))
|
|
||||||
}
|
|
||||||
sb.WriteString(strings.Repeat(" ", 7))
|
|
||||||
sb.WriteString("alias_ref: *deep_anchor\n")
|
|
||||||
|
|
||||||
if err := os.WriteFile(path, []byte(sb.String()), 0644); err != nil {
|
|
||||||
t.Fatalf("failed to write test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := LoadPersona(path)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for alias depth bypass, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "nesting depth exceeds") {
|
|
||||||
t.Errorf("error = %q, want containing 'nesting depth exceeds'", err.Error())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -858,150 +776,3 @@ identity: test identity
|
|||||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
t.Errorf("Name = %q, want %q", p.Name, "test")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJSONTrailingContentRejected(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
content string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "trailing garbage after object",
|
|
||||||
content: `{"name":"test","identity":"test identity"}garbage`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "two JSON objects",
|
|
||||||
content: `{"name":"test","identity":"test identity"}{"name":"other"}`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trailing array",
|
|
||||||
content: `{"name":"test","identity":"test identity"}[]`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, "test.json")
|
|
||||||
if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil {
|
|
||||||
t.Fatalf("failed to write test file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := LoadPersona(path)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for trailing content, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "trailing content") {
|
|
||||||
t.Errorf("error = %q, want to contain 'trailing content'", err.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsePersonaBytesSizeLimit(t *testing.T) {
|
|
||||||
// ParsePersonaBytes should reject input exceeding MaxPersonaFileSize
|
|
||||||
oversized := make([]byte, MaxPersonaFileSize+1)
|
|
||||||
for i := range oversized {
|
|
||||||
oversized[i] = 'x'
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := ParsePersonaBytes(oversized, "oversized.yaml")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for oversized input, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "exceeds maximum size") {
|
|
||||||
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Just under the limit should not trigger size error (may fail parse, but not size)
|
|
||||||
underLimit := []byte("name: test\nidentity: test persona\n")
|
|
||||||
p, err := ParsePersonaBytes(underLimit, "valid.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error for valid input: %v", err)
|
|
||||||
}
|
|
||||||
if p.Name != "test" {
|
|
||||||
t.Errorf("Name = %q, want %q", p.Name, "test")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestYAMLMergeKeyDepthCheck(t *testing.T) {
|
|
||||||
// Verify that YAML merge keys (<<: *alias) are properly handled by the
|
|
||||||
// depth checker. The merge key content is in the MappingValueNode.Value
|
|
||||||
// (an AliasNode), not in the MergeKeyNode itself.
|
|
||||||
p, err := ParsePersonaBytes([]byte("name: merge-test\nidentity: test\n"), "merge.yaml")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("basic parse failed: %v", err)
|
|
||||||
}
|
|
||||||
if p.Name != "merge-test" {
|
|
||||||
t.Errorf("Name = %q, want %q", p.Name, "merge-test")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that deeply nested merge keys still hit depth limit.
|
|
||||||
// Build YAML with merge key content nested beyond MaxYAMLDepth.
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString("name: deep-merge\nidentity: deep merge persona\n")
|
|
||||||
sb.WriteString("anchor: &deep\n")
|
|
||||||
indent := " "
|
|
||||||
for i := 0; i < MaxYAMLDepth+5; i++ {
|
|
||||||
sb.WriteString(indent)
|
|
||||||
sb.WriteString(fmt.Sprintf("level%d:\n", i))
|
|
||||||
indent += " "
|
|
||||||
}
|
|
||||||
sb.WriteString(indent + "leaf: value\n")
|
|
||||||
sb.WriteString("target:\n <<: *deep\n")
|
|
||||||
|
|
||||||
_, err = ParsePersonaBytes([]byte(sb.String()), "deep-merge.yaml")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for deeply nested merge key content, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "depth") {
|
|
||||||
t.Errorf("error = %q, want to contain 'depth'", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadPersona_NonexistentFile(t *testing.T) {
|
|
||||||
_, err := LoadPersona("/tmp/nonexistent-persona-file-xyz.yaml")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for nonexistent file, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadPersona_NotARegularFile(t *testing.T) {
|
|
||||||
// Use a directory as the path — directories are not regular files.
|
|
||||||
dir := t.TempDir()
|
|
||||||
_, err := LoadPersona(dir)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for directory path, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "not a regular file") {
|
|
||||||
t.Errorf("error = %q, want to contain 'not a regular file'", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadPersona_OversizedFile(t *testing.T) {
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, "big.yaml")
|
|
||||||
// Write a file larger than MaxPersonaFileSize
|
|
||||||
data := make([]byte, MaxPersonaFileSize+1)
|
|
||||||
for i := range data {
|
|
||||||
data[i] = 'x'
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
|
||||||
t.Fatalf("failed to create test file: %v", err)
|
|
||||||
}
|
|
||||||
_, err := LoadPersona(path)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for oversized file, got nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "exceeds maximum size") {
|
|
||||||
t.Errorf("error = %q, want to contain 'exceeds maximum size'", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCapitalizeFirst_RuneError(t *testing.T) {
|
|
||||||
// An invalid UTF-8 byte sequence should return the original string unchanged.
|
|
||||||
invalid := string([]byte{0xFF, 0xFE})
|
|
||||||
got := CapitalizeFirst(invalid)
|
|
||||||
if got != invalid {
|
|
||||||
t.Errorf("CapitalizeFirst(%q) = %q, want original %q", invalid, got, invalid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ func TestBuildUserPrompt_WithoutFileContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func TestBuildSystemBase(t *testing.T) {
|
func TestBuildSystemBase(t *testing.T) {
|
||||||
result := BuildSystemBase()
|
result := BuildSystemBase()
|
||||||
if result == "" {
|
if result == "" {
|
||||||
|
|||||||
+4
-17
@@ -4,32 +4,19 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RepoPersonaPath is the directory path where repo-specific personas are stored.
|
// RepoPersonaPath is the directory path where repo-specific personas are stored.
|
||||||
const RepoPersonaPath = ".review-bot/personas"
|
const RepoPersonaPath = ".review-bot/personas"
|
||||||
|
|
||||||
// GiteaClient defines the subset of gitea.Client methods needed for loading repo personas.
|
|
||||||
// This interface allows for easier testing and decouples the review package from gitea.
|
|
||||||
type GiteaClient interface {
|
|
||||||
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
|
|
||||||
GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContentEntry represents a file or directory entry from the contents API.
|
|
||||||
// This mirrors gitea.ContentEntry to avoid import cycles.
|
|
||||||
type ContentEntry struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Path string `json:"path"`
|
|
||||||
Type string `json:"type"` // "file" or "dir"
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadRepoPersonas fetches personas from a repository's .review-bot/personas/ directory.
|
// LoadRepoPersonas fetches personas from a repository's .review-bot/personas/ directory.
|
||||||
// Returns an empty map (not nil) if the directory doesn't exist or is empty.
|
// Returns an empty map (not nil) if the directory doesn't exist or is empty.
|
||||||
// Individual parse failures are logged and skipped; the remaining personas are still returned.
|
// Individual parse failures are logged and skipped; the remaining personas are still returned.
|
||||||
// Auth errors and other non-404 errors are propagated.
|
// Auth errors and other non-404 errors are propagated.
|
||||||
// Files exceeding MaxPersonaFileSize are rejected to prevent resource exhaustion.
|
// Files exceeding MaxPersonaFileSize are rejected to prevent resource exhaustion.
|
||||||
func LoadRepoPersonas(ctx context.Context, client GiteaClient, owner, repo string) (map[string]*Persona, error) {
|
func LoadRepoPersonas(ctx context.Context, client vcs.FileReader, owner, repo string) (map[string]*Persona, error) {
|
||||||
result := make(map[string]*Persona)
|
result := make(map[string]*Persona)
|
||||||
|
|
||||||
entries, err := client.ListContents(ctx, owner, repo, RepoPersonaPath)
|
entries, err := client.ListContents(ctx, owner, repo, RepoPersonaPath)
|
||||||
@@ -57,7 +44,7 @@ func LoadRepoPersonas(ctx context.Context, client GiteaClient, owner, repo strin
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
content, err := client.GetFileContent(ctx, owner, repo, entry.Path)
|
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("could not fetch repo persona file",
|
slog.Warn("could not fetch repo persona file",
|
||||||
"file", entry.Path,
|
"file", entry.Path,
|
||||||
|
|||||||
+24
-55
@@ -5,6 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParsePersonaBytes(t *testing.T) {
|
func TestParsePersonaBytes(t *testing.T) {
|
||||||
@@ -17,11 +19,7 @@ func TestParsePersonaBytes(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid yaml",
|
name: "valid yaml",
|
||||||
data: `name: test
|
data: "name: test\nidentity: test identity\nfocus:\n - testing\n",
|
||||||
identity: test identity
|
|
||||||
focus:
|
|
||||||
- testing
|
|
||||||
`,
|
|
||||||
source: "test.yaml",
|
source: "test.yaml",
|
||||||
wantName: "test",
|
wantName: "test",
|
||||||
},
|
},
|
||||||
@@ -67,15 +65,15 @@ focus:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockGiteaClient implements GiteaClient for testing.
|
// mockGiteaClient implements vcs.FileReader for testing.
|
||||||
type mockGiteaClient struct {
|
type mockGiteaClient struct {
|
||||||
contents map[string][]ContentEntry // path -> entries
|
contents map[string][]vcs.ContentEntry // path -> entries
|
||||||
files map[string]string // path -> content
|
files map[string]string // path -> content
|
||||||
listErr error
|
listErr error
|
||||||
fileErr map[string]error // path -> error
|
fileErr map[string]error // path -> error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
|
func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
|
||||||
if m.listErr != nil {
|
if m.listErr != nil {
|
||||||
return nil, m.listErr
|
return nil, m.listErr
|
||||||
}
|
}
|
||||||
@@ -86,7 +84,7 @@ func (m *mockGiteaClient) ListContents(ctx context.Context, owner, repo, path st
|
|||||||
return entries, nil
|
return entries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockGiteaClient) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
func (m *mockGiteaClient) GetFileContent(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||||
if m.fileErr != nil {
|
if m.fileErr != nil {
|
||||||
if err, ok := m.fileErr[filepath]; ok {
|
if err, ok := m.fileErr[filepath]; ok {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -118,7 +116,7 @@ func TestLoadRepoPersonas(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("empty directory returns empty map", func(t *testing.T) {
|
t.Run("empty directory returns empty map", func(t *testing.T) {
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {},
|
RepoPersonaPath: {},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -133,27 +131,15 @@ func TestLoadRepoPersonas(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("loads valid personas", func(t *testing.T) {
|
t.Run("loads valid personas", func(t *testing.T) {
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {
|
RepoPersonaPath: {
|
||||||
{Name: "trading.yaml", Path: ".review-bot/personas/trading.yaml", Type: "file"},
|
{Name: "trading.yaml", Path: ".review-bot/personas/trading.yaml", Type: "file"},
|
||||||
{Name: "crypto.yaml", Path: ".review-bot/personas/crypto.yaml", Type: "file"},
|
{Name: "crypto.yaml", Path: ".review-bot/personas/crypto.yaml", Type: "file"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
files: map[string]string{
|
files: map[string]string{
|
||||||
".review-bot/personas/trading.yaml": `name: trading
|
".review-bot/personas/trading.yaml": "name: trading\ndisplay_name: Trading Expert\nidentity: You are a trading expert.\nfocus:\n - order handling\n - risk management\n",
|
||||||
display_name: Trading Expert
|
".review-bot/personas/crypto.yaml": "name: crypto\ndisplay_name: Crypto Expert\nidentity: You are a cryptography expert.\nfocus:\n - key management\n - encryption\n",
|
||||||
identity: You are a trading expert.
|
|
||||||
focus:
|
|
||||||
- order handling
|
|
||||||
- risk management
|
|
||||||
`,
|
|
||||||
".review-bot/personas/crypto.yaml": `name: crypto
|
|
||||||
display_name: Crypto Expert
|
|
||||||
identity: You are a cryptography expert.
|
|
||||||
focus:
|
|
||||||
- key management
|
|
||||||
- encryption
|
|
||||||
`,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
||||||
@@ -176,16 +162,14 @@ focus:
|
|||||||
|
|
||||||
t.Run("skips invalid persona files", func(t *testing.T) {
|
t.Run("skips invalid persona files", func(t *testing.T) {
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {
|
RepoPersonaPath: {
|
||||||
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
|
{Name: "valid.yaml", Path: ".review-bot/personas/valid.yaml", Type: "file"},
|
||||||
{Name: "invalid.yaml", Path: ".review-bot/personas/invalid.yaml", Type: "file"},
|
{Name: "invalid.yaml", Path: ".review-bot/personas/invalid.yaml", Type: "file"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
files: map[string]string{
|
files: map[string]string{
|
||||||
".review-bot/personas/valid.yaml": `name: valid
|
".review-bot/personas/valid.yaml": "name: valid\nidentity: Valid persona\n",
|
||||||
identity: Valid persona
|
|
||||||
`,
|
|
||||||
".review-bot/personas/invalid.yaml": "not valid yaml: [broken",
|
".review-bot/personas/invalid.yaml": "not valid yaml: [broken",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -193,7 +177,6 @@ identity: Valid persona
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
// Should have the valid one, skip the invalid
|
|
||||||
if len(personas) != 1 {
|
if len(personas) != 1 {
|
||||||
t.Fatalf("expected 1 persona (skipped invalid), got %d", len(personas))
|
t.Fatalf("expected 1 persona (skipped invalid), got %d", len(personas))
|
||||||
}
|
}
|
||||||
@@ -204,7 +187,7 @@ identity: Valid persona
|
|||||||
|
|
||||||
t.Run("skips non-yaml files", func(t *testing.T) {
|
t.Run("skips non-yaml files", func(t *testing.T) {
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {
|
RepoPersonaPath: {
|
||||||
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
|
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
|
||||||
{Name: "README.md", Path: ".review-bot/personas/README.md", Type: "file"},
|
{Name: "README.md", Path: ".review-bot/personas/README.md", Type: "file"},
|
||||||
@@ -212,10 +195,8 @@ identity: Valid persona
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
files: map[string]string{
|
files: map[string]string{
|
||||||
".review-bot/personas/persona.yaml": `name: test
|
".review-bot/personas/persona.yaml": "name: test\nidentity: Test persona\n",
|
||||||
identity: Test persona
|
".review-bot/personas/README.md": "# Personas\n\nPut your personas here.",
|
||||||
`,
|
|
||||||
".review-bot/personas/README.md": "# Personas\n\nPut your personas here.",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
||||||
@@ -229,16 +210,14 @@ identity: Test persona
|
|||||||
|
|
||||||
t.Run("skips subdirectories", func(t *testing.T) {
|
t.Run("skips subdirectories", func(t *testing.T) {
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {
|
RepoPersonaPath: {
|
||||||
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
|
{Name: "persona.yaml", Path: ".review-bot/personas/persona.yaml", Type: "file"},
|
||||||
{Name: "subdir", Path: ".review-bot/personas/subdir", Type: "dir"},
|
{Name: "subdir", Path: ".review-bot/personas/subdir", Type: "dir"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
files: map[string]string{
|
files: map[string]string{
|
||||||
".review-bot/personas/persona.yaml": `name: test
|
".review-bot/personas/persona.yaml": "name: test\nidentity: Test persona\n",
|
||||||
identity: Test persona
|
|
||||||
`,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
||||||
@@ -265,16 +244,14 @@ identity: Test persona
|
|||||||
|
|
||||||
t.Run("skips files that fail to fetch", func(t *testing.T) {
|
t.Run("skips files that fail to fetch", func(t *testing.T) {
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {
|
RepoPersonaPath: {
|
||||||
{Name: "good.yaml", Path: ".review-bot/personas/good.yaml", Type: "file"},
|
{Name: "good.yaml", Path: ".review-bot/personas/good.yaml", Type: "file"},
|
||||||
{Name: "bad.yaml", Path: ".review-bot/personas/bad.yaml", Type: "file"},
|
{Name: "bad.yaml", Path: ".review-bot/personas/bad.yaml", Type: "file"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
files: map[string]string{
|
files: map[string]string{
|
||||||
".review-bot/personas/good.yaml": `name: good
|
".review-bot/personas/good.yaml": "name: good\nidentity: Good persona\n",
|
||||||
identity: Good persona
|
|
||||||
`,
|
|
||||||
},
|
},
|
||||||
fileErr: map[string]error{
|
fileErr: map[string]error{
|
||||||
".review-bot/personas/bad.yaml": errors.New("HTTP 500: internal server error"),
|
".review-bot/personas/bad.yaml": errors.New("HTTP 500: internal server error"),
|
||||||
@@ -290,27 +267,23 @@ identity: Good persona
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("skips oversized files", func(t *testing.T) {
|
t.Run("skips oversized files", func(t *testing.T) {
|
||||||
// Create a content string that exceeds MaxPersonaFileSize (64KB)
|
|
||||||
oversizedContent := strings.Repeat("a", MaxPersonaFileSize+1)
|
oversizedContent := strings.Repeat("a", MaxPersonaFileSize+1)
|
||||||
client := &mockGiteaClient{
|
client := &mockGiteaClient{
|
||||||
contents: map[string][]ContentEntry{
|
contents: map[string][]vcs.ContentEntry{
|
||||||
RepoPersonaPath: {
|
RepoPersonaPath: {
|
||||||
{Name: "normal.yaml", Path: ".review-bot/personas/normal.yaml", Type: "file"},
|
{Name: "normal.yaml", Path: ".review-bot/personas/normal.yaml", Type: "file"},
|
||||||
{Name: "huge.yaml", Path: ".review-bot/personas/huge.yaml", Type: "file"},
|
{Name: "huge.yaml", Path: ".review-bot/personas/huge.yaml", Type: "file"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
files: map[string]string{
|
files: map[string]string{
|
||||||
".review-bot/personas/normal.yaml": `name: normal
|
".review-bot/personas/normal.yaml": "name: normal\nidentity: Normal sized persona\n",
|
||||||
identity: Normal sized persona
|
".review-bot/personas/huge.yaml": oversizedContent,
|
||||||
`,
|
|
||||||
".review-bot/personas/huge.yaml": oversizedContent,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
personas, err := LoadRepoPersonas(ctx, client, "owner", "repo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
// Should have the normal one, skip the oversized
|
|
||||||
if len(personas) != 1 {
|
if len(personas) != 1 {
|
||||||
t.Fatalf("expected 1 persona (skipped oversized), got %d", len(personas))
|
t.Fatalf("expected 1 persona (skipped oversized), got %d", len(personas))
|
||||||
}
|
}
|
||||||
@@ -370,7 +343,6 @@ func TestGetBuiltinPersonasMap(t *testing.T) {
|
|||||||
t.Fatal("expected at least one built-in persona")
|
t.Fatal("expected at least one built-in persona")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify expected personas exist
|
|
||||||
expected := []string{"security", "architect", "docs"}
|
expected := []string{"security", "architect", "docs"}
|
||||||
for _, name := range expected {
|
for _, name := range expected {
|
||||||
if personas[name] == nil {
|
if personas[name] == nil {
|
||||||
@@ -378,7 +350,6 @@ func TestGetBuiltinPersonasMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify personas are valid
|
|
||||||
for name, p := range personas {
|
for name, p := range personas {
|
||||||
if p.Name != name {
|
if p.Name != name {
|
||||||
t.Errorf("persona %q has mismatched name %q", name, p.Name)
|
t.Errorf("persona %q has mismatched name %q", name, p.Name)
|
||||||
@@ -422,8 +393,6 @@ func TestIsNotFoundError(t *testing.T) {
|
|||||||
{nil, false},
|
{nil, false},
|
||||||
{errors.New("HTTP 404: not found"), true},
|
{errors.New("HTTP 404: not found"), true},
|
||||||
{errors.New("HTTP 404"), true},
|
{errors.New("HTTP 404"), true},
|
||||||
// Intentionally false: generic "not found" could mask auth/transport errors.
|
|
||||||
// Only explicit HTTP 404 responses should be treated as "directory doesn't exist".
|
|
||||||
{errors.New("something not found"), false},
|
{errors.New("something not found"), false},
|
||||||
{errors.New("HTTP 401: unauthorized"), false},
|
{errors.New("HTTP 401: unauthorized"), false},
|
||||||
{errors.New("connection refused"), false},
|
{errors.New("connection refused"), false},
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
//go:build phase2
|
||||||
|
|
||||||
|
package vcs_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gitea.weiker.me/rodin/review-bot/gitea"
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compile-time assertion: documents the gap between gitea.Client and vcs.Client.
|
||||||
|
// Guarded by the "phase2" build tag — enable once the Gitea adapter bridges these gaps:
|
||||||
|
//
|
||||||
|
// 1. PostReview signature mismatch:
|
||||||
|
// gitea.Client: PostReview(ctx, owner, repo, number, event, body string, comments []gitea.ReviewComment)
|
||||||
|
// vcs.Reviewer: PostReview(ctx, owner, repo, number, req vcs.ReviewRequest)
|
||||||
|
//
|
||||||
|
// 2. GetFileContent signature mismatch:
|
||||||
|
// gitea.Client: GetFileContent(ctx, owner, repo, filepath string) [no ref; uses default branch]
|
||||||
|
// vcs.FileReader: GetFileContent(ctx, owner, repo, path, ref string)
|
||||||
|
// (gitea.Client has GetFileContentRef for the ref variant)
|
||||||
|
//
|
||||||
|
// 3. ReviewComment type mismatch:
|
||||||
|
// gitea.ReviewComment uses NewPosition int64 (Gitea line-number convention)
|
||||||
|
// vcs.ReviewComment uses Position int (GitHub diff-position convention)
|
||||||
|
//
|
||||||
|
// The Gitea adapter (Phase 2) will wrap gitea.Client to bridge these gaps.
|
||||||
|
var _ vcs.Client = (*gitea.Client)(nil)
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
// Package vcs defines the shared VCS client interface and supporting types.
|
||||||
|
// Platform adapters (gitea, github) implement these interfaces so the core
|
||||||
|
// review logic can work with any VCS platform without platform-specific code.
|
||||||
|
package vcs
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// PRReader can fetch pull request metadata, diffs, and changed files.
|
||||||
|
type PRReader interface {
|
||||||
|
GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error)
|
||||||
|
GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error)
|
||||||
|
GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error)
|
||||||
|
GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error)
|
||||||
|
GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileReader can fetch file contents and list directory entries.
|
||||||
|
type FileReader interface {
|
||||||
|
GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error)
|
||||||
|
ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reviewer can post, list, and delete pull request reviews.
|
||||||
|
type Reviewer interface {
|
||||||
|
PostReview(ctx context.Context, owner, repo string, number int, req ReviewRequest) (*Review, error)
|
||||||
|
ListReviews(ctx context.Context, owner, repo string, number int) ([]Review, error)
|
||||||
|
DeleteReview(ctx context.Context, owner, repo string, number int, reviewID int64) error
|
||||||
|
DismissReview(ctx context.Context, owner, repo string, number int, reviewID int64, message string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identity can report who the authenticated user is.
|
||||||
|
type Identity interface {
|
||||||
|
GetAuthenticatedUser(ctx context.Context) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client is the full VCS interface: PR reads, file reads, review management, and identity.
|
||||||
|
// Platform adapters (gitea, github) implement this interface.
|
||||||
|
type Client interface {
|
||||||
|
PRReader
|
||||||
|
FileReader
|
||||||
|
Reviewer
|
||||||
|
Identity
|
||||||
|
}
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package vcs
|
||||||
|
|
||||||
|
// ReviewEvent is the event type for a pull request review action.
|
||||||
|
// Adapters must translate these action constants to/from platform-native values.
|
||||||
|
// For example, Gitea uses "APPROVED" as both action and state, while GitHub
|
||||||
|
// uses "APPROVE" for the action and returns "approved" as the state.
|
||||||
|
type ReviewEvent string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ReviewEventApprove approves the pull request.
|
||||||
|
ReviewEventApprove ReviewEvent = "APPROVE"
|
||||||
|
// ReviewEventRequestChanges requests changes to the pull request.
|
||||||
|
ReviewEventRequestChanges ReviewEvent = "REQUEST_CHANGES"
|
||||||
|
// ReviewEventComment posts a review comment without approval or rejection.
|
||||||
|
ReviewEventComment ReviewEvent = "COMMENT"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BaseRef identifies the target branch of a pull request.
|
||||||
|
type BaseRef struct {
|
||||||
|
Ref string `json:"ref"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeadRef identifies the source branch and latest commit of a pull request.
|
||||||
|
type HeadRef struct {
|
||||||
|
SHA string `json:"sha"`
|
||||||
|
Ref string `json:"ref"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserInfo identifies a user by login name.
|
||||||
|
type UserInfo struct {
|
||||||
|
Login string `json:"login"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PullRequest holds relevant PR metadata.
|
||||||
|
type PullRequest struct {
|
||||||
|
Number int `json:"number"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
Head HeadRef `json:"head"`
|
||||||
|
Base BaseRef `json:"base"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangedFile represents a file modified in a PR.
|
||||||
|
type ChangedFile struct {
|
||||||
|
Filename string `json:"filename"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Patch string `json:"patch"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentEntry represents a file or directory entry from the contents API.
|
||||||
|
type ContentEntry struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Type string `json:"type"` // "file" or "dir"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommitStatus represents a single CI status entry for a commit.
|
||||||
|
type CommitStatus struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Context string `json:"context"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
TargetURL string `json:"target_url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Review represents a pull request review.
|
||||||
|
type Review struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
User UserInfo `json:"user"`
|
||||||
|
State string `json:"state"`
|
||||||
|
Stale bool `json:"stale"`
|
||||||
|
CommitID string `json:"commit_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReviewComment represents an inline comment in a review.
|
||||||
|
// All adapters use GitHub diff-position convention:
|
||||||
|
// - Position is a 1-indexed offset from the @@ hunk line in the unified diff.
|
||||||
|
// - CommitID identifies the commit the comment is anchored to.
|
||||||
|
// It is optional; omit (empty string) for review-level comments that are
|
||||||
|
// not attached to a specific commit.
|
||||||
|
//
|
||||||
|
// Adapters are responsible for translating to/from platform-native formats
|
||||||
|
// (e.g. Gitea uses line numbers; GitHub uses diff positions natively).
|
||||||
|
type ReviewComment struct {
|
||||||
|
Path string `json:"path"`
|
||||||
|
Position int `json:"position"` // diff-position: 1-indexed offset from @@ hunk line
|
||||||
|
CommitID string `json:"commit_id"`
|
||||||
|
Body string `json:"body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReviewRequest is the payload for posting a review.
|
||||||
|
type ReviewRequest struct {
|
||||||
|
// Body is the top-level review comment.
|
||||||
|
Body string `json:"body"`
|
||||||
|
// Event is the review action (approve, request changes, or comment).
|
||||||
|
Event ReviewEvent `json:"event"`
|
||||||
|
Comments []ReviewComment `json:"comments,omitempty"`
|
||||||
|
}
|
||||||
+193
@@ -0,0 +1,193 @@
|
|||||||
|
package vcs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxFilesInPath is the maximum number of files GetAllFilesInPath will fetch.
|
||||||
|
// Prevents unbounded resource consumption on very large directory trees.
|
||||||
|
maxFilesInPath = 10000
|
||||||
|
|
||||||
|
// maxTotalBytesInPath is the maximum total bytes GetAllFilesInPath will accumulate.
|
||||||
|
// Prevents memory exhaustion when fetching large repositories.
|
||||||
|
maxTotalBytesInPath = 100 * 1024 * 1024 // 100 MB
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllFilesInPath recursively fetches all file contents under a path using the
|
||||||
|
// provided FileReader. Returns a map of filepath -> content for all files found.
|
||||||
|
// If the path points to an empty directory, returns an empty map.
|
||||||
|
//
|
||||||
|
// This function uses fail-fast error handling: any error from ListContents or
|
||||||
|
// GetFileContent aborts the entire traversal and returns the error immediately.
|
||||||
|
// This differs from gitea.Client.GetAllFilesInPath, which logs errors and continues.
|
||||||
|
// The fail-fast contract ensures callers can trust that a nil error means all files
|
||||||
|
// were successfully fetched.
|
||||||
|
//
|
||||||
|
// Resource limits: the traversal is bounded by maxFilesInPath (file count) and
|
||||||
|
// maxTotalBytesInPath (total accumulated bytes). The context is checked before each
|
||||||
|
// recursive call and file fetch to respect cancellation.
|
||||||
|
func GetAllFilesInPath(ctx context.Context, client FileReader, owner, repo, path string) (map[string]string, error) {
|
||||||
|
results := make(map[string]string)
|
||||||
|
totalBytes := 0
|
||||||
|
|
||||||
|
var walk func(string) error
|
||||||
|
walk = func(dir string) error {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return fmt.Errorf("context canceled during traversal: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := client.ListContents(ctx, owner, repo, dir)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list contents %q: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return fmt.Errorf("context canceled during traversal: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch entry.Type {
|
||||||
|
case "file":
|
||||||
|
if len(results) >= maxFilesInPath {
|
||||||
|
return fmt.Errorf("exceeded max file count (%d) in path %q", maxFilesInPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := client.GetFileContent(ctx, owner, repo, entry.Path, "")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get file %q: %w", entry.Path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
totalBytes += len(content)
|
||||||
|
if totalBytes > maxTotalBytesInPath {
|
||||||
|
return fmt.Errorf("exceeded max total bytes (%d) in path %q", maxTotalBytesInPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
results[entry.Path] = content
|
||||||
|
case "dir":
|
||||||
|
if err := walk(entry.Path); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := walk(path); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildLineToPositionMap parses a unified diff and returns a map of
|
||||||
|
// filename -> (new line number -> diff position). The diff position is a
|
||||||
|
// 1-indexed offset from the @@ hunk header line for each file.
|
||||||
|
// Only lines that appear in the new file (context lines and additions) are mapped.
|
||||||
|
// Deletion-only lines are not included.
|
||||||
|
func BuildLineToPositionMap(diff string) map[string]map[int]int {
|
||||||
|
result := make(map[string]map[int]int)
|
||||||
|
|
||||||
|
lines := strings.Split(diff, "\n")
|
||||||
|
var currentFile string
|
||||||
|
var position int
|
||||||
|
var newLine int
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
// Detect new file in diff
|
||||||
|
if strings.HasPrefix(line, "+++ b/") {
|
||||||
|
currentFile = strings.TrimPrefix(line, "+++ b/")
|
||||||
|
position = 0
|
||||||
|
newLine = 0
|
||||||
|
if result[currentFile] == nil {
|
||||||
|
result[currentFile] = make(map[int]int)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip --- lines (old file header)
|
||||||
|
if strings.HasPrefix(line, "--- ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip diff --git lines
|
||||||
|
if strings.HasPrefix(line, "diff --git") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip index lines
|
||||||
|
if strings.HasPrefix(line, "index ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse hunk headers
|
||||||
|
if strings.HasPrefix(line, "@@") {
|
||||||
|
position++
|
||||||
|
// Extract new file start line from @@ -a,b +c,d @@
|
||||||
|
newLine = parseHunkNewStart(line)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need a current file to map lines
|
||||||
|
if currentFile == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip "\ No newline at end of file" markers — these are git diff
|
||||||
|
// metadata and not part of the file content.
|
||||||
|
if strings.HasPrefix(line, `\`) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process diff content lines
|
||||||
|
if strings.HasPrefix(line, "+") {
|
||||||
|
position++
|
||||||
|
result[currentFile][newLine] = position
|
||||||
|
newLine++
|
||||||
|
} else if strings.HasPrefix(line, "-") {
|
||||||
|
position++
|
||||||
|
// Deletion lines don't map to new line numbers
|
||||||
|
} else if strings.HasPrefix(line, " ") {
|
||||||
|
// Context line (space-prefixed).
|
||||||
|
// Only map if position > 0, which means we've seen a hunk header.
|
||||||
|
// Lines before the first hunk header (position == 0) are not part
|
||||||
|
// of any diff hunk and should be skipped.
|
||||||
|
if position > 0 {
|
||||||
|
position++
|
||||||
|
result[currentFile][newLine] = position
|
||||||
|
newLine++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHunkNewStart extracts the new-file starting line number from a hunk header.
|
||||||
|
// Format: @@ -old_start[,old_count] +new_start[,new_count] @@
|
||||||
|
func parseHunkNewStart(hunkLine string) int {
|
||||||
|
// Find the +N part
|
||||||
|
plusIdx := strings.Index(hunkLine, "+")
|
||||||
|
if plusIdx < 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
rest := hunkLine[plusIdx+1:]
|
||||||
|
|
||||||
|
// Find the end of the number (first non-digit after +)
|
||||||
|
endIdx := 0
|
||||||
|
for endIdx < len(rest) && rest[endIdx] >= '0' && rest[endIdx] <= '9' {
|
||||||
|
endIdx++
|
||||||
|
}
|
||||||
|
|
||||||
|
if endIdx == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := strconv.Atoi(rest[:endIdx])
|
||||||
|
if err != nil {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
@@ -0,0 +1,331 @@
|
|||||||
|
package vcs_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.weiker.me/rodin/review-bot/vcs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockFileReader implements vcs.FileReader for testing.
|
||||||
|
type mockFileReader struct {
|
||||||
|
contents map[string][]vcs.ContentEntry // path -> entries
|
||||||
|
files map[string]string // path -> content
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileReader) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) {
|
||||||
|
content, ok := m.files[path]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("HTTP 404: file not found: %s", path)
|
||||||
|
}
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockFileReader) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) {
|
||||||
|
entries, ok := m.contents[path]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("HTTP 404: path not found: %s", path)
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAllFilesInPath(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("empty directory", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"src": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 0 {
|
||||||
|
t.Errorf("expected empty map, got %d entries", len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("flat directory", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"src": {
|
||||||
|
{Name: "main.go", Path: "src/main.go", Type: "file"},
|
||||||
|
{Name: "util.go", Path: "src/util.go", Type: "file"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
files: map[string]string{
|
||||||
|
"src/main.go": "package main",
|
||||||
|
"src/util.go": "package main\n// util",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 files, got %d", len(result))
|
||||||
|
}
|
||||||
|
if result["src/main.go"] != "package main" {
|
||||||
|
t.Errorf("main.go content = %q", result["src/main.go"])
|
||||||
|
}
|
||||||
|
if result["src/util.go"] != "package main\n// util" {
|
||||||
|
t.Errorf("util.go content = %q", result["src/util.go"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested directories", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"src": {
|
||||||
|
{Name: "main.go", Path: "src/main.go", Type: "file"},
|
||||||
|
{Name: "pkg", Path: "src/pkg", Type: "dir"},
|
||||||
|
},
|
||||||
|
"src/pkg": {
|
||||||
|
{Name: "lib.go", Path: "src/pkg/lib.go", Type: "file"},
|
||||||
|
{Name: "sub", Path: "src/pkg/sub", Type: "dir"},
|
||||||
|
},
|
||||||
|
"src/pkg/sub": {
|
||||||
|
{Name: "deep.go", Path: "src/pkg/sub/deep.go", Type: "file"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
files: map[string]string{
|
||||||
|
"src/main.go": "package main",
|
||||||
|
"src/pkg/lib.go": "package pkg",
|
||||||
|
"src/pkg/sub/deep.go": "package sub",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 3 {
|
||||||
|
t.Fatalf("expected 3 files, got %d", len(result))
|
||||||
|
}
|
||||||
|
if result["src/main.go"] != "package main" {
|
||||||
|
t.Errorf("main.go content = %q", result["src/main.go"])
|
||||||
|
}
|
||||||
|
if result["src/pkg/lib.go"] != "package pkg" {
|
||||||
|
t.Errorf("lib.go content = %q", result["src/pkg/lib.go"])
|
||||||
|
}
|
||||||
|
if result["src/pkg/sub/deep.go"] != "package sub" {
|
||||||
|
t.Errorf("deep.go content = %q", result["src/pkg/sub/deep.go"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mixed files and dirs", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"root": {
|
||||||
|
{Name: "README.md", Path: "root/README.md", Type: "file"},
|
||||||
|
{Name: "docs", Path: "root/docs", Type: "dir"},
|
||||||
|
{Name: "config.yaml", Path: "root/config.yaml", Type: "file"},
|
||||||
|
},
|
||||||
|
"root/docs": {
|
||||||
|
{Name: "guide.md", Path: "root/docs/guide.md", Type: "file"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
files: map[string]string{
|
||||||
|
"root/README.md": "# Hello",
|
||||||
|
"root/config.yaml": "key: value",
|
||||||
|
"root/docs/guide.md": "## Guide",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "root")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(result) != 3 {
|
||||||
|
t.Fatalf("expected 3 files, got %d", len(result))
|
||||||
|
}
|
||||||
|
if result["root/README.md"] != "# Hello" {
|
||||||
|
t.Errorf("README content = %q", result["root/README.md"])
|
||||||
|
}
|
||||||
|
if result["root/docs/guide.md"] != "## Guide" {
|
||||||
|
t.Errorf("guide content = %q", result["root/docs/guide.md"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildLineToPositionMap(t *testing.T) {
|
||||||
|
t.Run("single hunk", func(t *testing.T) {
|
||||||
|
diff := "diff --git a/file.go b/file.go\nindex abc..def 100644\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n package main\n \n+// new comment\n func main() {}\n"
|
||||||
|
result := vcs.BuildLineToPositionMap(diff)
|
||||||
|
fileMap, ok := result["file.go"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected file.go in result")
|
||||||
|
}
|
||||||
|
// Hunk header @@ is position 1
|
||||||
|
// Line 1: " package main" -> position 2
|
||||||
|
if fileMap[1] != 2 {
|
||||||
|
t.Errorf("line 1 position = %d, want 2", fileMap[1])
|
||||||
|
}
|
||||||
|
// Line 2: " " (context) -> position 3
|
||||||
|
if fileMap[2] != 3 {
|
||||||
|
t.Errorf("line 2 position = %d, want 3", fileMap[2])
|
||||||
|
}
|
||||||
|
// Line 3: "+// new comment" -> position 4
|
||||||
|
if fileMap[3] != 4 {
|
||||||
|
t.Errorf("line 3 position = %d, want 4", fileMap[3])
|
||||||
|
}
|
||||||
|
// Line 4: " func main() {}" -> position 5
|
||||||
|
if fileMap[4] != 5 {
|
||||||
|
t.Errorf("line 4 position = %d, want 5", fileMap[4])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multi hunk", func(t *testing.T) {
|
||||||
|
diff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,3 @@\n package main\n \n-// old\n+// new\n@@ -10,3 +10,4 @@\n func foo() {\n+\t// added\n \treturn\n }\n"
|
||||||
|
result := vcs.BuildLineToPositionMap(diff)
|
||||||
|
fileMap, ok := result["file.go"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected file.go in result")
|
||||||
|
}
|
||||||
|
// First hunk: @@ is position 1
|
||||||
|
// Line 1: " package main" -> position 2
|
||||||
|
if fileMap[1] != 2 {
|
||||||
|
t.Errorf("line 1 position = %d, want 2", fileMap[1])
|
||||||
|
}
|
||||||
|
// Line 3: "+// new" -> position 5 (after " ", "-// old" at pos 3,4)
|
||||||
|
if fileMap[3] != 5 {
|
||||||
|
t.Errorf("line 3 position = %d, want 5", fileMap[3])
|
||||||
|
}
|
||||||
|
// Second hunk: @@ is position 6
|
||||||
|
// Line 10: " func foo() {" -> position 7
|
||||||
|
if fileMap[10] != 7 {
|
||||||
|
t.Errorf("line 10 position = %d, want 7", fileMap[10])
|
||||||
|
}
|
||||||
|
// Line 11: "+\t// added" -> position 8
|
||||||
|
if fileMap[11] != 8 {
|
||||||
|
t.Errorf("line 11 position = %d, want 8", fileMap[11])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deletion lines not in map", func(t *testing.T) {
|
||||||
|
diff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,4 +1,3 @@\n package main\n \n-// deleted line\n func main() {}\n"
|
||||||
|
result := vcs.BuildLineToPositionMap(diff)
|
||||||
|
fileMap, ok := result["file.go"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected file.go in result")
|
||||||
|
}
|
||||||
|
// Line 1: " package main" -> position 2
|
||||||
|
if fileMap[1] != 2 {
|
||||||
|
t.Errorf("line 1 position = %d, want 2", fileMap[1])
|
||||||
|
}
|
||||||
|
// Line 3 in new file: " func main() {}" -> position 5 (after deletion at pos 4)
|
||||||
|
if fileMap[3] != 5 {
|
||||||
|
t.Errorf("line 3 position = %d, want 5", fileMap[3])
|
||||||
|
}
|
||||||
|
// Should only have 3 entries (lines 1, 2, 3 of new file)
|
||||||
|
if len(fileMap) != 3 {
|
||||||
|
t.Errorf("expected 3 mapped lines, got %d: %v", len(fileMap), fileMap)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple files", func(t *testing.T) {
|
||||||
|
diff := "diff --git a/a.go b/a.go\n--- a/a.go\n+++ b/a.go\n@@ -1,2 +1,3 @@\n package a\n \n+// file a\ndiff --git a/b.go b/b.go\n--- a/b.go\n+++ b/b.go\n@@ -1,2 +1,3 @@\n package b\n \n+// file b\n"
|
||||||
|
result := vcs.BuildLineToPositionMap(diff)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Fatalf("expected 2 files, got %d", len(result))
|
||||||
|
}
|
||||||
|
aMap, ok := result["a.go"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected a.go in result")
|
||||||
|
}
|
||||||
|
bMap, ok := result["b.go"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected b.go in result")
|
||||||
|
}
|
||||||
|
// a.go line 3: "+// file a" -> position 4
|
||||||
|
if aMap[3] != 4 {
|
||||||
|
t.Errorf("a.go line 3 position = %d, want 4", aMap[3])
|
||||||
|
}
|
||||||
|
// b.go line 3: "+// file b" -> position 4
|
||||||
|
if bMap[3] != 4 {
|
||||||
|
t.Errorf("b.go line 3 position = %d, want 4", bMap[3])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAllFilesInPath_ErrorPropagation(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("ListContents error propagates", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
// "src" not in map, so ListContents will fail
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "list contents") {
|
||||||
|
t.Errorf("expected error about list contents, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetFileContent error propagates", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"src": {
|
||||||
|
{Name: "main.go", Path: "src/main.go", Type: "file"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
files: map[string]string{
|
||||||
|
// "src/main.go" not in files map, so GetFileContent will fail
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "get file") {
|
||||||
|
t.Errorf("expected error about get file, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested ListContents error propagates", func(t *testing.T) {
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"src": {
|
||||||
|
{Name: "pkg", Path: "src/pkg", Type: "dir"},
|
||||||
|
},
|
||||||
|
// "src/pkg" not in map, so recursive ListContents will fail
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "list contents") {
|
||||||
|
t.Errorf("expected error about list contents, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("canceled context propagates", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
client := &mockFileReader{
|
||||||
|
contents: map[string][]vcs.ContentEntry{
|
||||||
|
"src": {
|
||||||
|
{Name: "main.go", Path: "src/main.go", Type: "file"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
files: map[string]string{
|
||||||
|
"src/main.go": "package main",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := vcs.GetAllFilesInPath(ctx, client, "owner", "repo", "src")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from canceled context, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "context canceled") {
|
||||||
|
t.Errorf("expected context cancellation error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user