diff --git a/.gitea/actions/review/action.yml b/.gitea/actions/review/action.yml index 12bdfdb..de5a9f6 100644 --- a/.gitea/actions/review/action.yml +++ b/.gitea/actions/review/action.yml @@ -9,7 +9,13 @@ # token exfiltration. API calls use github.api_url; downloads use # github.server_url. Tokens are never sent to user-supplied URLs. # - On Gitea (VCS_TYPE=gitea), inputs.vcs-url is validated (https scheme, -# no whitespace/newlines) before use. +# no whitespace/newlines, and DNS resolution to a public IP) before use. +# Python3 resolves the hostname and rejects RFC1918, RFC6598 (carrier-grade +# NAT), loopback, link-local, and other reserved addresses to prevent SSRF attacks. +# The installed review-bot binary additionally uses a safe HTTP transport +# (DialContext-level IP check) for all Gitea API calls at runtime. +# The binary also exposes a `validate-url` subcommand for use in any future +# shell steps that need to validate a URL before passing it to curl. # - action-repo is validated against owner/repo pattern. # - Tokens are passed via masked environment variables, not step outputs. # @@ -185,6 +191,36 @@ runs: echo "Error: SERVER_URL '${SERVER_URL}' must be an https:// URL with no whitespace" >&2 exit 1 fi + + # Additional IP-level SSRF defense: resolve the hostname and reject + # requests to RFC1918, RFC6598 (carrier-grade NAT), loopback, link-local, + # and other reserved addresses. + # python3 is required on ubuntu-* runners (see requirements comment above). + # Use printf to write the script to a temp file so the python lines are valid + # YAML (each indented line becomes a printf argument — no unindented code). + # SERVER_URL is passed via CHECK_URL env var, never interpolated into python code. + printf '%s\n' \ + 'import socket,ipaddress,sys,os' \ + 'from urllib.parse import urlparse' \ + 'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \ + 'if parsed.username or parsed.password:' \ + ' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \ + 'h=parsed.hostname' \ + '(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \ + 'try: rs=socket.getaddrinfo(h,None)' \ + 'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \ + 'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \ + 'for _,_,_,_,(a,*_) in rs:' \ + ' ip=ipaddress.ip_address(a)' \ + ' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \ + ' cgn=ipaddress.ip_network("100.64.0.0/10")' \ + ' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \ + ' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \ + > /tmp/_ssrf_check.py + CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check.py || { + echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2 + exit 1 + } fi # Determine auth token for release API requests @@ -305,6 +341,36 @@ runs: ACTION_TOKEN="${ACTION_TOKEN:-}" BINARY="review-bot-${OS}-${ARCH}" + # SECURITY: Re-validate SERVER_URL at the start of this step to mitigate DNS + # rebinding attacks. A DNS TTL expiry between "Determine version" and here + # could allow an attacker to change the resolved IP to a private/reserved + # address, causing curl to send ACTION_TOKEN to an internal host. + # Only needed on Gitea path (VCS_TYPE=gitea); GitHub/GHES uses platform-controlled URLs. + if [ "$VCS_TYPE" = "gitea" ]; then + printf '%s\n' \ + 'import socket,ipaddress,sys,os' \ + 'from urllib.parse import urlparse' \ + 'u=os.environ["CHECK_URL"]; parsed=urlparse(u)' \ + 'if parsed.username or parsed.password:' \ + ' print("Error: URL contains user-info — not allowed",file=sys.stderr); sys.exit(2)' \ + 'h=parsed.hostname' \ + '(print("Error: no hostname",file=sys.stderr) or sys.exit(2)) if not h else None' \ + 'try: rs=socket.getaddrinfo(h,None)' \ + 'except socket.gaierror as e: print(f"DNS error: {e}",file=sys.stderr); sys.exit(1)' \ + 'if not rs: print("Error: no addresses",file=sys.stderr); sys.exit(1)' \ + 'for _,_,_,_,(a,*_) in rs:' \ + ' ip=ipaddress.ip_address(a)' \ + ' if isinstance(ip,ipaddress.IPv6Address) and ip.ipv4_mapped: ip=ip.ipv4_mapped' \ + ' cgn=ipaddress.ip_network("100.64.0.0/10")' \ + ' if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved or ip in cgn:' \ + ' print(f"blocked: {a}",file=sys.stderr); sys.exit(1)' \ + > /tmp/_ssrf_check_install.py + CHECK_URL="${SERVER_URL}" python3 /tmp/_ssrf_check_install.py || { + echo "Error: SERVER_URL '${SERVER_URL}' resolves to a private/reserved IP address" >&2 + exit 1 + } + fi + if [ "$VCS_TYPE" = "github" ]; then # GitHub/GHES: Use REST API for release asset downloads. # Web release URLs ({server}/.../releases/download/{tag}/{asset}) redirect diff --git a/cmd/review-bot/main.go b/cmd/review-bot/main.go index 157980f..4db81b4 100644 --- a/cmd/review-bot/main.go +++ b/cmd/review-bot/main.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "io" "log/slog" "os" "path/filepath" @@ -19,6 +20,13 @@ import ( var version = "dev" +// outWriter and errWriter are the output and error writers for subcommands. +// They are variables so tests can capture output. +var ( + outWriter io.Writer = os.Stdout + errWriter io.Writer = os.Stderr +) + // setupLogger configures the global slog default logger based on format and verbosity. func setupLogger(format, verbosity string) { var level slog.Level @@ -49,6 +57,15 @@ func setupLogger(format, verbosity string) { } func main() { + // Dispatch subcommands before flag parsing so they get their own args. + // e.g. `review-bot validate-url ` + if len(os.Args) > 1 { + switch os.Args[1] { + case "validate-url": + os.Exit(runValidateURL(os.Args[2:])) + } + } + versionFlag := flag.Bool("version", false, "Print version and exit") // Logging flags logFormat := flag.String("log-format", envOrDefault("LOG_FORMAT", "text"), "Log output format: text or json") diff --git a/cmd/review-bot/validateurl.go b/cmd/review-bot/validateurl.go new file mode 100644 index 0000000..b235aa7 --- /dev/null +++ b/cmd/review-bot/validateurl.go @@ -0,0 +1,125 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" + + "gitea.weiker.me/rodin/review-bot/gitea" +) + +// runValidateURL implements the `review-bot validate-url ` subcommand. +// +// It resolves the given URL's hostname and checks that every returned IP is +// publicly routable (not RFC1918, loopback, link-local, or other reserved +// ranges). The exit code communicates the result to callers: +// +// 0 — URL is safe to use +// 1 — URL resolves to a blocked/private address +// 2 — URL is malformed, has an unsafe scheme, or DNS lookup failed +// +// This is intended for use from action.yml shell steps that need to validate +// a user-supplied URL before passing it to curl. +func runValidateURL(args []string) int { + if len(args) != 1 { + fmt.Fprintln(errWriter, "usage: review-bot validate-url ") + fmt.Fprintln(errWriter, "") + fmt.Fprintln(errWriter, "Resolves and verifies all resolved IPs are publicly routable.") + fmt.Fprintln(errWriter, "Exit 0=safe, 1=blocked, 2=error") + return 2 + } + rawURL := args[0] + + if err := validateURL(rawURL); err != nil { + fmt.Fprintf(errWriter, "Error: %v\n", err) + var ve *validateError + if isValidateError(err, &ve) { + return ve.code + } + return 2 + } + fmt.Fprintf(outWriter, "OK: %s is safe\n", rawURL) + return 0 +} + +// validateError carries an exit code alongside a message. +type validateError struct { + code int + message string +} + +func (e *validateError) Error() string { return e.message } + +// isValidateError checks if err is or wraps a *validateError and sets out. +// Uses errors.As so that wrapped *validateError values (e.g. from fmt.Errorf("...: %w", &validateError{...})) +// are also detected, making the function robust against future wrapping. +func isValidateError(err error, out **validateError) bool { + if err == nil { + return false + } + return errors.As(err, out) +} + +// validateURL checks that rawURL is safe for use as a Gitea server URL: +// - Must be https:// (not http://) +// - Must have no user-info (user:pass@host) +// - Must resolve to at least one IP, all of which are publicly routable +func validateURL(rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return &validateError{code: 2, message: fmt.Sprintf("malformed URL %q: %v", rawURL, err)} + } + + // Scheme check: only https is permitted. + if !strings.EqualFold(parsed.Scheme, "https") { + return &validateError{ + code: 2, + message: fmt.Sprintf("URL scheme must be https (got %q)", parsed.Scheme), + } + } + + // Reject user-info (user:password@host) to prevent credential embedding. + if parsed.User != nil { + return &validateError{ + code: 2, + message: "URL must not contain user-info (user:password@host)", + } + } + + host := parsed.Hostname() + if host == "" { + return &validateError{code: 2, message: fmt.Sprintf("URL has no host: %q", rawURL)} + } + + // Resolve the hostname with a short timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return &validateError{ + code: 2, + message: fmt.Sprintf("DNS lookup failed for %q: %v", host, err), + } + } + if len(addrs) == 0 { + return &validateError{ + code: 2, + message: fmt.Sprintf("DNS lookup returned no addresses for %q", host), + } + } + + for _, a := range addrs { + if gitea.IsBlockedIP(a.IP) { + return &validateError{ + code: 1, + message: fmt.Sprintf("blocked: %q resolves to private/reserved IP %s", host, a.IP), + } + } + } + return nil +} diff --git a/cmd/review-bot/validateurl_test.go b/cmd/review-bot/validateurl_test.go new file mode 100644 index 0000000..aca1cfb --- /dev/null +++ b/cmd/review-bot/validateurl_test.go @@ -0,0 +1,127 @@ +package main + +import ( + "bytes" + "strings" + "testing" +) + +func TestRunValidateURL_Usage(t *testing.T) { + var errBuf bytes.Buffer + origErr := errWriter + errWriter = &errBuf + defer func() { errWriter = origErr }() + + code := runValidateURL(nil) + if code != 2 { + t.Errorf("expected exit code 2 for no args, got %d", code) + } + if !strings.Contains(errBuf.String(), "usage") { + t.Errorf("expected usage in stderr, got %q", errBuf.String()) + } + + errBuf.Reset() + code = runValidateURL([]string{"arg1", "arg2"}) + if code != 2 { + t.Errorf("expected exit code 2 for too many args, got %d", code) + } +} + +func TestValidateURL_MalformedURL(t *testing.T) { + cases := []struct { + name string + url string + wantMsg string + }{ + {"empty", "", "must be https"}, + {"http scheme", "http://example.com/", "must be https"}, + {"ftp scheme", "ftp://example.com/", "must be https"}, + {"no scheme", "example.com", "must be https"}, + {"user info", "https://user:pass@example.com/", "user-info"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateURL(tc.url) + if err == nil { + t.Errorf("expected error for URL %q, got nil", tc.url) + return + } + if !strings.Contains(err.Error(), tc.wantMsg) { + t.Errorf("error %q does not contain %q", err.Error(), tc.wantMsg) + } + var ve *validateError + if !isValidateError(err, &ve) { + t.Fatalf("expected *validateError, got %T", err) + } + if ve.code != 2 { + t.Errorf("expected code 2, got %d", ve.code) + } + }) + } +} + +func TestValidateURL_BlockedPrivateIP(t *testing.T) { + // localhost always resolves to 127.0.0.1 (loopback). + err := validateURL("https://localhost/") + if err == nil { + t.Skip("localhost did not resolve (network unavailable in test environment)") + } + var ve *validateError + if !isValidateError(err, &ve) { + t.Fatalf("expected *validateError, got %T: %v", err, err) + } + if ve.code != 1 && ve.code != 2 { + t.Errorf("expected code 1 (blocked) or 2 (dns fail), got %d: %s", ve.code, ve.message) + } + // If it resolved (code 1), the message must say "blocked". + if ve.code == 1 && !strings.Contains(ve.message, "blocked") { + t.Errorf("expected 'blocked' in message, got %q", ve.message) + } +} + +func TestValidateURL_ExitCodes(t *testing.T) { + cases := []struct { + name string + url string + wantCode int + }{ + {"http scheme", "http://example.com/", 2}, + {"no scheme", "example.com", 2}, + {"user info", "https://admin:secret@example.com/", 2}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateURL(tc.url) + if err == nil { + t.Fatalf("expected error for %q", tc.url) + } + var ve *validateError + if !isValidateError(err, &ve) { + t.Fatalf("expected *validateError, got %T", err) + } + if ve.code != tc.wantCode { + t.Errorf("code = %d, want %d (url=%q, msg=%s)", ve.code, tc.wantCode, tc.url, ve.message) + } + }) + } +} + +func TestRunValidateURL_WithCapture(t *testing.T) { + var outBuf, errBuf bytes.Buffer + origOut, origErr := outWriter, errWriter + outWriter = &outBuf + errWriter = &errBuf + defer func() { + outWriter = origOut + errWriter = origErr + }() + + // http:// scheme should fail with code 2. + code := runValidateURL([]string{"http://example.com/"}) + if code != 2 { + t.Errorf("expected code 2 for http:// URL, got %d", code) + } + if !strings.Contains(errBuf.String(), "must be https") { + t.Errorf("expected error about https in stderr, got %q", errBuf.String()) + } +} diff --git a/gitea/client.go b/gitea/client.go index 9243b14..e1e4bf4 100644 --- a/gitea/client.go +++ b/gitea/client.go @@ -106,34 +106,113 @@ func defaultCheckRedirect(req *http.Request, via []*http.Request) error { return nil } +// safeDialContext is the default DialContext for NewClient. +// It resolves the hostname and checks every returned IP against the blocked +// CIDR list before establishing a connection. This prevents SSRF attacks +// where user-supplied URLs resolve to internal/private addresses. +// +// After validating all IPs, we dial the first resolved IP directly to avoid +// a second DNS lookup (which could return a different IP in a DNS rebinding +// attack). This narrows — but does not fully eliminate — the DNS rebinding +// window to the time between LookupIPAddr and DialContext. +// +// If the host is already an IP literal, LookupIPAddr returns it directly +// (no DNS query issued), so IP literals like https://127.0.0.1/ are blocked. +func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("safeDialContext: invalid address %q: %w", addr, err) + } + addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("safeDialContext: DNS lookup %q: %w", host, err) + } + if len(addrs) == 0 { + return nil, fmt.Errorf("safeDialContext: no addresses returned for %q", host) + } + for _, a := range addrs { + if IsBlockedIP(a.IP) { + return nil, fmt.Errorf("safeDialContext: blocked: %q resolves to private/reserved IP %s", host, a.IP) + } + } + // Try each resolved IP in order, returning the first successful connection. + // Fallback is important when a hostname resolves to multiple IPs and the first + // is temporarily unreachable. All IPs were already validated above, so dialing + // any of them is safe. + // + // Timeout: 10s per the design (PLAN.md); the outer http.Client has a 30s + // total timeout, but the per-dial timeout ensures a slow TCP connect on one IP + // doesn't consume the budget needed to try others. + d := &net.Dialer{Timeout: 10 * time.Second} + var lastErr error + for _, a := range addrs { + conn, err := d.DialContext(ctx, network, net.JoinHostPort(a.IP.String(), port)) + if err == nil { + return conn, nil + } + lastErr = err + } + return nil, fmt.Errorf("safeDialContext: all %d addresses for %q failed, last error: %w", len(addrs), host, lastErr) +} + +// newSafeHTTPClient returns an *http.Client with the SSRF-blocking safeDialContext +// transport and the cross-host redirect rejection policy. +// +// We clone http.DefaultTransport to preserve its production-ready defaults +// (ProxyFromEnvironment, TLSHandshakeTimeout, IdleConnTimeout, connection +// pooling, HTTP/2 support) and override only DialContext with safeDialContext. +func newSafeHTTPClient() *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DialContext = safeDialContext + return &http.Client{ + Timeout: 30 * time.Second, + Transport: transport, + CheckRedirect: defaultCheckRedirect, + } +} + // NewClient creates a new Gitea API client. +// +// The client uses a safe HTTP transport by default: DNS resolution is performed +// before connecting and any IP in a private/reserved range is rejected +// (RFC1918, loopback, link-local, ULA, etc.). Cross-host and HTTPS→HTTP +// redirects are also rejected. +// +// For tests that use httptest.NewServer (which listens on 127.0.0.1), call +// WithUnsafeDialer() to bypass the IP check. func NewClient(baseURL, token string) *Client { return &Client{ baseURL: strings.TrimRight(baseURL, "/"), token: token, - http: &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: defaultCheckRedirect, - }, + http: newSafeHTTPClient(), } } +// WithUnsafeDialer returns the client configured with a plain HTTP client that +// has no IP-level SSRF protection. It preserves the redirect-rejection policy. +// +// This MUST only be used in tests. Production code must never call this method. +func (c *Client) WithUnsafeDialer() *Client { + c.http = &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: defaultCheckRedirect, + } + return c +} + // SetHTTPClient sets the underlying HTTP client used for requests. // This is intended for test setup only to inject mock transports; it must be // called before any goroutines issue requests. // -// Passing nil restores the default client (30s timeout + redirect-rejecting -// CheckRedirect policy matching NewClient). +// Passing nil restores the default safe client (30s timeout, IP-blocking +// safeDialContext, and redirect-rejecting CheckRedirect policy matching NewClient). // // Callers providing a non-nil client are responsible for configuring a safe // CheckRedirect policy. Without one, the default net/http behavior will follow // redirects and may forward the Authorization header to untrusted hosts. func (c *Client) SetHTTPClient(hc *http.Client) { if hc == nil { - hc = &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: defaultCheckRedirect, - } + hc = newSafeHTTPClient() } c.http = hc } diff --git a/gitea/client_test.go b/gitea/client_test.go index 19ac55b..8de33ab 100644 --- a/gitea/client_test.go +++ b/gitea/client_test.go @@ -36,7 +36,7 @@ func TestGetPullRequest(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") got, err := client.GetPullRequest(context.Background(), "owner", "repo", 1) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -63,7 +63,7 @@ func TestGetPullRequestDiff(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 5) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -88,7 +88,7 @@ func TestGetCommitStatuses(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") got, err := client.GetCommitStatuses(context.Background(), "owner", "repo", "abc123") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -138,7 +138,7 @@ func TestPostReview(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") review, err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM", "abc123def", nil) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -158,7 +158,7 @@ func TestGetPullRequest_Non200(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.GetPullRequest(context.Background(), "owner", "repo", 999) if err == nil { t.Fatal("expected error for 404, got nil") @@ -171,7 +171,7 @@ func TestGetPullRequest_BadJSON(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.GetPullRequest(context.Background(), "owner", "repo", 1) if err == nil { t.Fatal("expected error for bad JSON, got nil") @@ -185,7 +185,7 @@ func TestPostReview_Non200(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test", "", nil) if err == nil { t.Fatal("expected error for 403, got nil") @@ -208,7 +208,7 @@ func TestPostReview_EmptyCommitID_OmittedFromPayload(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "ok", "", nil) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -226,7 +226,7 @@ func TestGetFileContent(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") got, err := client.GetFileContent(context.Background(), "owner", "repo", "CONVENTIONS.md") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -246,7 +246,7 @@ func TestGetPullRequestFiles(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") files, err := client.GetPullRequestFiles(context.Background(), "owner", "repo", 1) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -271,7 +271,7 @@ func TestGetFileContentRef(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") content, err := client.GetFileContentRef(context.Background(), "owner", "repo", "main.go", "feature-branch") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -291,7 +291,7 @@ func TestListContents(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") entries, err := client.ListContents(context.Background(), "owner", "repo", "docs") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -318,7 +318,7 @@ func TestListContents_DotPath(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") entries, err := client.ListContents(context.Background(), "owner", "repo", ".") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -343,7 +343,7 @@ func TestListContents_FilePath(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") entries, err := client.ListContents(context.Background(), "owner", "repo", "README.md") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -375,7 +375,7 @@ func TestGetAllFilesInPath_File(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -428,7 +428,7 @@ func TestListReviews(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -468,7 +468,7 @@ func TestListReviews_Pagination(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") reviews, err := client.ListReviews(context.Background(), "owner", "repo", 5) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -493,7 +493,7 @@ func TestDeleteReview(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -507,7 +507,7 @@ func TestDeleteReview_Forbidden(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.DeleteReview(context.Background(), "owner", "repo", 5, 10) if err == nil { t.Fatal("expected error for 403, got nil") @@ -536,7 +536,7 @@ func TestEditComment(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.EditComment(context.Background(), "owner", "repo", 42, "updated body") if err != nil { t.Fatalf("EditComment() error = %v", err) @@ -550,7 +550,7 @@ func TestEditComment_Forbidden(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.EditComment(context.Background(), "owner", "repo", 42, "new body") if err == nil { t.Fatal("expected error for 403 response") @@ -570,7 +570,7 @@ func TestGetTimelineReviewCommentID(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") id, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "") if err != nil { t.Fatalf("GetTimelineReviewCommentID() error = %v", err) @@ -586,7 +586,7 @@ func TestGetTimelineReviewCommentID_NotFound(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.GetTimelineReviewCommentID(context.Background(), "owner", "repo", 5, "") if err == nil { t.Fatal("expected error when sentinel not found") @@ -609,7 +609,7 @@ func TestGetAllFilesInPath_404FallsBackToFile(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md") if err != nil { t.Fatalf("expected fallback to file on 404, got error: %v", err) @@ -630,7 +630,7 @@ func TestGetAllFilesInPath_500Propagates(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "somepath") if err == nil { t.Fatal("expected error to propagate for 500, got nil") @@ -652,7 +652,7 @@ func TestGetAllFilesInPath_403Propagates(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "private/stuff") if err == nil { t.Fatal("expected error to propagate for 403, got nil") @@ -704,7 +704,7 @@ func TestGetAuthenticatedUser(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") login, err := client.GetAuthenticatedUser(context.Background()) if err != nil { t.Fatalf("GetAuthenticatedUser() error = %v", err) @@ -729,7 +729,7 @@ func TestRequestReviewer(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.RequestReviewer(context.Background(), "owner", "repo", 7, "bot-user") if err != nil { t.Fatalf("RequestReviewer() error = %v", err) @@ -745,7 +745,7 @@ func TestRequestReviewer_204(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user") if err != nil { t.Fatalf("RequestReviewer() should accept 204, got error = %v", err) @@ -759,7 +759,7 @@ func TestRequestReviewer_Error(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.RequestReviewer(context.Background(), "owner", "repo", 1, "user") if err == nil { t.Fatal("expected error for 403 response") @@ -779,7 +779,7 @@ func TestListReviewComments(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") comments, err := client.ListReviewComments(context.Background(), "owner", "repo", 1, 42) if err != nil { t.Fatalf("ListReviewComments() error = %v", err) @@ -807,7 +807,7 @@ func TestResolveComment(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.ResolveComment(context.Background(), "owner", "repo", 99) if err != nil { t.Fatalf("ResolveComment() error = %v", err) @@ -821,7 +821,7 @@ func TestResolveComment_Error(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") err := client.ResolveComment(context.Background(), "owner", "repo", 99) if err == nil { t.Fatal("expected error for 404 response") @@ -870,7 +870,7 @@ func TestDoGet_RetriesOn500(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") // Use short backoff for fast tests client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} @@ -895,7 +895,7 @@ func TestDoGet_FailsAfterMaxRetries(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") // Use short backoff for fast tests client.RetryBackoff = []time.Duration{1 * time.Millisecond, 1 * time.Millisecond} @@ -924,7 +924,7 @@ func TestDoGet_NoRetryOn4xx(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.doGet(context.Background(), server.URL+"/test") if err == nil { t.Fatal("expected error for 403") @@ -952,7 +952,7 @@ func TestDoGet_RespectsContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") // Use longer backoff to give us time to cancel during the wait client.RetryBackoff = []time.Duration{100 * time.Millisecond, 100 * time.Millisecond} @@ -1285,3 +1285,137 @@ func TestSetHTTPClient_NilRestoresDefault(t *testing.T) { t.Fatal("expected CheckRedirect policy after SetHTTPClient(nil)") } } + +// TestSafeDialContextBlocksPrivateIPs verifies that NewClient (which uses +// safeDialContext by default) refuses to connect to private/reserved IPs. +func TestSafeDialContextBlocksPrivateIPs(t *testing.T) { + // These servers listen on 127.0.0.1, so the safe dialer will block them. + // We use NewClient (NOT NewTestClient) to exercise the real safe dialer. + privateURLs := []struct { + name string + url string + }{ + {"loopback localhost", "http://localhost/"}, + {"loopback 127.0.0.1", "http://127.0.0.1/"}, + } + + for _, tc := range privateURLs { + t.Run(tc.name, func(t *testing.T) { + c := NewClient(tc.url, "token") + _, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err == nil { + t.Errorf("expected error connecting to %s, got nil", tc.url) + } + // Error must mention SSRF/blocked, not a random network error. + if !strings.Contains(err.Error(), "blocked") && + !strings.Contains(err.Error(), "private") && + !strings.Contains(err.Error(), "loopback") && + !strings.Contains(err.Error(), "reserved") { + t.Logf("error: %v", err) + // Allow other errors (connection refused, DNS) since the point + // is that we don't silently succeed — but prefer the explicit block message. + } + }) + } +} + +// TestWithUnsafeDialerAllowsLocalhost verifies that WithUnsafeDialer bypasses +// the IP check, allowing tests to connect to httptest.Server (127.0.0.1). +func TestWithUnsafeDialerAllowsLocalhost(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"title":"test","body":"","head":{"sha":"abc","ref":"main"}}`)) + })) + defer server.Close() + + // WithUnsafeDialer should allow connecting to 127.0.0.1. + c := NewClient(server.URL, "token").WithUnsafeDialer() + pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1) + if err != nil { + t.Fatalf("unexpected error with unsafe dialer: %v", err) + } + if pr.Title != "test" { + t.Errorf("expected title 'test', got %q", pr.Title) + } +} + +// TestNewClient_HasSafeTransport verifies that NewClient installs the +// SSRF-blocking transport (i.e. Transport is not nil and DialContext is set). +func TestNewClient_HasSafeTransport(t *testing.T) { + c := NewClient("https://gitea.example.com", "token") + if c.http.Transport == nil { + t.Fatal("expected Transport to be set on NewClient (safe dialer)") + } + transport, ok := c.http.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", c.http.Transport) + } + if transport.DialContext == nil { + t.Fatal("expected DialContext to be set on transport (safe dialer)") + } +} + +// TestSetHTTPClient_NilRestoresSafeTransport verifies that SetHTTPClient(nil) +// restores the safe transport (not just any client). +func TestSetHTTPClient_NilRestoresSafeTransport(t *testing.T) { + c := NewClient("https://gitea.example.com", "token") + c.SetHTTPClient(&http.Client{}) // replace with plain client + c.SetHTTPClient(nil) // restore + transport, ok := c.http.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport after SetHTTPClient(nil), got %T", c.http.Transport) + } + if transport.DialContext == nil { + t.Fatal("expected DialContext to be restored after SetHTTPClient(nil)") + } +} + +// TestNewSafeHTTPClient_PreservesDefaultTransportSettings verifies that +// newSafeHTTPClient clones http.DefaultTransport to retain proxy support, +// TLS handshake timeout, idle connection limits, and HTTP/2. +func TestNewSafeHTTPClient_PreservesDefaultTransportSettings(t *testing.T) { + c := NewClient("https://gitea.example.com", "token") + transport, ok := c.http.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", c.http.Transport) + } + + defaults := http.DefaultTransport.(*http.Transport) + + // TLSHandshakeTimeout must be inherited (non-zero), not the zero value + // that a bare &http.Transport{} would have. + if transport.TLSHandshakeTimeout == 0 { + t.Error("TLSHandshakeTimeout is 0; expected inherited value from DefaultTransport") + } + if transport.TLSHandshakeTimeout != defaults.TLSHandshakeTimeout { + t.Errorf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaults.TLSHandshakeTimeout) + } + + // IdleConnTimeout must be inherited. + if transport.IdleConnTimeout == 0 { + t.Error("IdleConnTimeout is 0; expected inherited value from DefaultTransport") + } + if transport.IdleConnTimeout != defaults.IdleConnTimeout { + t.Errorf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaults.IdleConnTimeout) + } + + // MaxIdleConns must be inherited. + if transport.MaxIdleConns == 0 { + t.Error("MaxIdleConns is 0; expected inherited value from DefaultTransport") + } + + // ForceAttemptHTTP2 must be inherited. + if !transport.ForceAttemptHTTP2 { + t.Error("ForceAttemptHTTP2 is false; expected true from DefaultTransport") + } + + // Proxy must be set (ProxyFromEnvironment). + if transport.Proxy == nil { + t.Error("Proxy is nil; expected ProxyFromEnvironment from DefaultTransport") + } + + // DialContext must be our safe dialer, not the default. + if transport.DialContext == nil { + t.Error("DialContext is nil; expected safeDialContext") + } +} diff --git a/gitea/diff_size_test.go b/gitea/diff_size_test.go index 005f87c..b01c056 100644 --- a/gitea/diff_size_test.go +++ b/gitea/diff_size_test.go @@ -70,7 +70,7 @@ func TestGetPullRequestDiff_SizeLimits(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") client.MaxDiffSize = tt.maxDiffSize client.RetryBackoff = []time.Duration{} diff --git a/gitea/export_test.go b/gitea/export_test.go new file mode 100644 index 0000000..e3a41df --- /dev/null +++ b/gitea/export_test.go @@ -0,0 +1,18 @@ +// Package gitea — export_test.go exposes test helpers to test files in this +// package. It uses `package gitea` (not `package gitea_test`) so it can access +// unexported identifiers; Go only compiles it into the test binary, never into +// the production binary. This is the idiomatic pattern for white-box testing +// in Go (see net/http/export_test.go in the stdlib for the same approach). +package gitea + +// NewTestClient creates a Gitea client configured for use in unit tests. +// It bypasses the IP-level SSRF protection so that tests can connect to +// httptest.Server instances (which listen on 127.0.0.1). +// +// Using the internal package gitea declaration (not gitea_test) means this +// symbol is available to all _test.go files in this package. It is ONLY +// compiled into the test binary; production binaries never include it. +// Production code must use NewClient, which enables the safe dialer. +func NewTestClient(baseURL, token string) *Client { + return NewClient(baseURL, token).WithUnsafeDialer() +} diff --git a/gitea/ipcheck.go b/gitea/ipcheck.go new file mode 100644 index 0000000..186b2ef --- /dev/null +++ b/gitea/ipcheck.go @@ -0,0 +1,91 @@ +// Package gitea provides a client for the Gitea API. +// ipcheck.go implements IP-level SSRF protection by checking resolved addresses +// against known blocked CIDR ranges (RFC1918, loopback, link-local, etc.). +package gitea + +import ( + "fmt" + "net" +) + +// blockedCIDRStrings is the canonical list of CIDR strings that should never +// be contacted by review-bot. See IsBlockedIP for the full list of covered +// address families. +// +// These are hard-coded literals: any parse failure is a programming error. +// Validity is verified by TestBlockedCIDRsValid in ipcheck_test.go. +var blockedCIDRStrings = []string{ + // IPv4 loopback + "127.0.0.0/8", + // IPv4 unspecified / "this network" + "0.0.0.0/8", + // RFC1918 private ranges + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + // IPv4 link-local (APIPA, also used by AWS instance metadata 169.254.169.254) + "169.254.0.0/16", + // IPv4 shared address space (RFC6598, carrier-grade NAT) + "100.64.0.0/10", + // IPv4 multicast + "224.0.0.0/4", + // IPv4 reserved / broadcast + "240.0.0.0/4", + // IPv6 loopback + "::1/128", + // IPv6 unspecified + "::/128", + // IPv6 link-local + "fe80::/10", + // IPv6 unique local (ULA) — RFC4193 + "fc00::/7", + // IPv6 multicast + "ff00::/8", +} + +// blockedCIDRs is the parsed form of blockedCIDRStrings. +// Any entry that fails to parse is recorded in blockedCIDRParseErrors instead +// of panicking; tests verify this slice is always empty via TestBlockedCIDRsValid. +var ( + blockedCIDRs []*net.IPNet + blockedCIDRParseErrors []string +) + +func init() { + blockedCIDRs = make([]*net.IPNet, 0, len(blockedCIDRStrings)) + for _, r := range blockedCIDRStrings { + _, cidr, err := net.ParseCIDR(r) + if err != nil { + // Record the error rather than panicking; TestBlockedCIDRsValid + // will catch this during tests, and the CI build will fail. + blockedCIDRParseErrors = append(blockedCIDRParseErrors, + fmt.Sprintf("ipcheck: invalid built-in CIDR %q: %v", r, err)) + continue + } + blockedCIDRs = append(blockedCIDRs, cidr) + } +} + +// IsBlockedIP reports whether ip is in a blocked address range. +// It is exported for use by the validate-url subcommand and tests outside +// this package. +// +// IPv6-mapped IPv4 addresses (e.g. ::ffff:192.168.1.1) are normalized to their +// IPv4 form before checking so that IPv4 CIDRs catch them. +// +// Based on: +// - RFC1918 private ranges +// - RFC5735 / RFC4193 special-use IPv4/IPv6 ranges +// - RFC4291 IPv6 link-local / loopback +func IsBlockedIP(ip net.IP) bool { + // Normalize IPv6-mapped IPv4 addresses (::ffff:x.x.x.x) to plain IPv4. + if v4 := ip.To4(); v4 != nil { + ip = v4 + } + for _, cidr := range blockedCIDRs { + if cidr.Contains(ip) { + return true + } + } + return false +} diff --git a/gitea/ipcheck_test.go b/gitea/ipcheck_test.go new file mode 100644 index 0000000..0d3d08e --- /dev/null +++ b/gitea/ipcheck_test.go @@ -0,0 +1,144 @@ +package gitea + +import ( + "net" + "testing" +) + +func TestIsBlockedIP(t *testing.T) { + blocked := []struct { + name string + ip string + }{ + // IPv4 loopback + {"loopback 127.0.0.1", "127.0.0.1"}, + {"loopback 127.0.0.2", "127.0.0.2"}, + {"loopback 127.255.255.255", "127.255.255.255"}, + // IPv4 unspecified + {"unspecified 0.0.0.0", "0.0.0.0"}, + {"unspecified 0.1.2.3", "0.1.2.3"}, + // RFC1918 + {"RFC1918 10.0.0.1", "10.0.0.1"}, + {"RFC1918 10.255.255.255", "10.255.255.255"}, + {"RFC1918 172.16.0.1", "172.16.0.1"}, + {"RFC1918 172.31.255.255", "172.31.255.255"}, + {"RFC1918 192.168.0.1", "192.168.0.1"}, + {"RFC1918 192.168.255.255", "192.168.255.255"}, + // Link-local (APIPA / AWS metadata) + {"link-local 169.254.0.1", "169.254.0.1"}, + {"link-local 169.254.169.254", "169.254.169.254"}, + // Shared address space (carrier-grade NAT) + {"CGN 100.64.0.1", "100.64.0.1"}, + {"CGN 100.127.255.255", "100.127.255.255"}, + // Multicast + {"multicast 224.0.0.1", "224.0.0.1"}, + {"multicast 239.255.255.255", "239.255.255.255"}, + // Reserved + {"reserved 240.0.0.1", "240.0.0.1"}, + {"broadcast 255.255.255.255", "255.255.255.255"}, + // IPv6 loopback + {"IPv6 loopback ::1", "::1"}, + // IPv6 unspecified + {"IPv6 unspecified ::", "::"}, + // IPv6 link-local + {"IPv6 link-local fe80::1", "fe80::1"}, + {"IPv6 link-local fe80::dead:beef", "fe80::dead:beef"}, + // IPv6 ULA + {"IPv6 ULA fc00::1", "fc00::1"}, + {"IPv6 ULA fd00::1", "fd00::1"}, + // IPv6 multicast + {"IPv6 multicast ff02::1", "ff02::1"}, + } + + for _, tc := range blocked { + t.Run(tc.name, func(t *testing.T) { + ip := net.ParseIP(tc.ip) + if ip == nil { + t.Fatalf("failed to parse IP %q", tc.ip) + } + if !IsBlockedIP(ip) { + t.Errorf("IsBlockedIP(%q) = false, want true", tc.ip) + } + }) + } + + allowed := []struct { + name string + ip string + }{ + {"public 8.8.8.8", "8.8.8.8"}, + {"public 1.1.1.1", "1.1.1.1"}, + {"public 198.51.100.1", "198.51.100.1"}, // RFC5737 TEST-NET-2 — a documentation-only range; + // not assigned to any real host, but intentionally left unblocked here because + // it has no special routing treatment (unlike RFC1918/loopback/link-local) and + // blocking it would require tracking every RFC5737 range without meaningful + // security benefit (no server should ever listen on a TEST-NET address). + {"public 151.101.1.1", "151.101.1.1"}, // Fastly + {"public IPv6 2001:4860:4860::8888", "2001:4860:4860::8888"}, // Google DNS + {"public IPv6 2606:4700:4700::1111", "2606:4700:4700::1111"}, // Cloudflare DNS + } + + for _, tc := range allowed { + t.Run(tc.name, func(t *testing.T) { + ip := net.ParseIP(tc.ip) + if ip == nil { + t.Fatalf("failed to parse IP %q", tc.ip) + } + if IsBlockedIP(ip) { + t.Errorf("IsBlockedIP(%q) = true, want false", tc.ip) + } + }) + } +} + +func TestIsBlockedIPv6MappedIPv4(t *testing.T) { + // ::ffff:192.168.1.1 is an IPv6-mapped IPv4 address — should be blocked as RFC1918. + // Construct it manually as a 16-byte IP. + mapped := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1} + if !IsBlockedIP(mapped) { + t.Errorf("IsBlockedIP(::ffff:192.168.1.1) = false, want true (IPv6-mapped IPv4 must be normalized)") + } + + // ::ffff:8.8.8.8 — IPv6-mapped public IP — should be allowed. + mappedPublic := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 8, 8, 8, 8} + if IsBlockedIP(mappedPublic) { + t.Errorf("IsBlockedIP(::ffff:8.8.8.8) = true, want false") + } +} + +func TestIsBlockedIPEdgeCases(t *testing.T) { + // The boundary between RFC1918 and public ranges. + // 172.15.255.255 is NOT private (just below 172.16.0.0/12). + notPrivate := net.ParseIP("172.15.255.255") + if IsBlockedIP(notPrivate) { + t.Errorf("IsBlockedIP(172.15.255.255) = true, want false (outside 172.16.0.0/12)") + } + // 172.32.0.0 is NOT private (just above 172.31.255.255). + notPrivate2 := net.ParseIP("172.32.0.0") + if IsBlockedIP(notPrivate2) { + t.Errorf("IsBlockedIP(172.32.0.0) = true, want false (outside 172.16.0.0/12)") + } + // CGN: 100.63.255.255 is NOT in 100.64.0.0/10. + notCGN := net.ParseIP("100.63.255.255") + if IsBlockedIP(notCGN) { + t.Errorf("IsBlockedIP(100.63.255.255) = true, want false (outside 100.64.0.0/10)") + } + // CGN: 100.128.0.0 is NOT in 100.64.0.0/10. + notCGN2 := net.ParseIP("100.128.0.0") + if IsBlockedIP(notCGN2) { + t.Errorf("IsBlockedIP(100.128.0.0) = true, want false (outside 100.64.0.0/10)") + } +} + +// TestBlockedCIDRsValid verifies that all entries in blockedCIDRStrings parse +// successfully. This catches programming errors in the CIDR list without +// requiring a startup panic. The init() function records parse failures in +// blockedCIDRParseErrors rather than panicking; this test makes those failures +// visible as test failures during CI. +func TestBlockedCIDRsValid(t *testing.T) { + if len(blockedCIDRParseErrors) > 0 { + for _, msg := range blockedCIDRParseErrors { + t.Errorf("CIDR parse error: %s", msg) + } + } +} diff --git a/gitea/post_review_comments_test.go b/gitea/post_review_comments_test.go index 7a8bf57..3792ece 100644 --- a/gitea/post_review_comments_test.go +++ b/gitea/post_review_comments_test.go @@ -31,7 +31,7 @@ func TestPostReview_WithComments(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") comments := []ReviewComment{ {Path: "main.go", NewPosition: 42, Body: "[MAJOR] Something bad"}, {Path: "util.go", NewPosition: 10, Body: "[MINOR] Style issue"}, @@ -71,7 +71,7 @@ func TestPostReview_NilComments(t *testing.T) { })) defer server.Close() - client := NewClient(server.URL, "test-token") + client := NewTestClient(server.URL, "test-token") _, err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "all good", "", nil) if err != nil { t.Fatalf("unexpected error: %v", err)