67d835909f
Adds a budget package that estimates token usage and progressively trims context to fit within model-specific limits.

Trim order (least important first):
1. Language patterns
2. Repository conventions
3. Full file context
4. Diff (truncated as last resort)

When content is trimmed, a note is appended to the user prompt so the LLM knows context was reduced.

- New budget package with Fit(), EstimateTokens(), LimitForModel()
- Model limit table (GPT-4.1: 128K, GPT-5: 200K, Claude: 200K)
- Refactored review/prompt.go: BuildSystemBase() and BuildUserMeta() extract non-trimmable content; old functions delegate to new ones
- main.go uses budget.Fit() instead of direct prompt assembly
- 7 unit tests covering all trim paths

Closes #19
329 lines
10 KiB
Go
329 lines
10 KiB
Go
package main
|
|
|
|
import (
	"context"
	"flag"
	"fmt"
	"log"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"

	"gitea.weiker.me/rodin/review-bot/budget"
	"gitea.weiker.me/rodin/review-bot/gitea"
	"gitea.weiker.me/rodin/review-bot/llm"
	"gitea.weiker.me/rodin/review-bot/review"
)
|
|
|
|
// version is the build identifier printed by --version and logged at startup.
// It defaults to "dev"; presumably release builds override it at link time
// (e.g. -ldflags "-X main.version=...") — TODO confirm against the build scripts.
var version = "dev"
|
|
|
|
func main() {
|
|
versionFlag := flag.Bool("version", false, "Print version and exit")
|
|
// CLI flags
|
|
giteaURL := flag.String("gitea-url", envOrDefault("GITEA_URL", ""), "Gitea instance URL")
|
|
repo := flag.String("repo", envOrDefault("GITEA_REPO", ""), "Repository (owner/name)")
|
|
prNum := flag.String("pr", envOrDefault("PR_NUMBER", ""), "Pull request number")
|
|
reviewerName := flag.String("reviewer-name", envOrDefault("REVIEWER_NAME", ""), "Reviewer display name")
|
|
reviewerToken := flag.String("reviewer-token", envOrDefault("REVIEWER_TOKEN", ""), "Gitea token for posting review")
|
|
llmBaseURL := flag.String("llm-base-url", envOrDefault("LLM_BASE_URL", ""), "LLM API base URL")
|
|
llmAPIKey := flag.String("llm-api-key", envOrDefault("LLM_API_KEY", ""), "LLM API key")
|
|
llmModel := flag.String("llm-model", envOrDefault("LLM_MODEL", ""), "LLM model name")
|
|
conventionsFile := flag.String("conventions-file", envOrDefault("CONVENTIONS_FILE", ""), "Conventions file path in repo (e.g. CLAUDE.md)")
|
|
patternsRepo := flag.String("patterns-repo", envOrDefault("PATTERNS_REPO", ""), "Repo with language patterns (e.g. rodin/elixir-patterns)")
|
|
patternsFiles := flag.String("patterns-files", envOrDefault("PATTERNS_FILES", "README.md"), "Comma-separated file paths to fetch from patterns repo")
|
|
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
|
|
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
|
|
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
|
|
|
|
flag.Parse()
|
|
|
|
if *versionFlag {
|
|
fmt.Printf("review-bot %s\n", version)
|
|
os.Exit(0)
|
|
}
|
|
|
|
log.Printf("review-bot %s", version)
|
|
|
|
// Validate required fields
|
|
if *giteaURL == "" || *repo == "" || *prNum == "" || *reviewerToken == "" ||
|
|
*llmBaseURL == "" || *llmAPIKey == "" || *llmModel == "" {
|
|
fmt.Fprintf(os.Stderr, "Error: missing required flags or environment variables\n\n")
|
|
fmt.Fprintf(os.Stderr, "Required: --gitea-url, --repo, --pr, --reviewer-token, --llm-base-url, --llm-api-key, --llm-model\n")
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Parse repo owner/name
|
|
parts := strings.SplitN(*repo, "/", 2)
|
|
if len(parts) != 2 {
|
|
log.Fatalf("Invalid repo format %q, expected owner/name", *repo)
|
|
}
|
|
owner, repoName := parts[0], parts[1]
|
|
|
|
// Parse PR number
|
|
prNumber, err := strconv.Atoi(*prNum)
|
|
if err != nil {
|
|
log.Fatalf("Invalid PR number %q: %v", *prNum, err)
|
|
}
|
|
|
|
// Initialize clients
|
|
giteaClient := gitea.NewClient(*giteaURL, *reviewerToken)
|
|
llmClient := llm.NewClient(*llmBaseURL, *llmAPIKey, *llmModel)
|
|
if *llmTemp < 0 || *llmTemp > 2 {
|
|
log.Fatal("--llm-temperature must be between 0 and 2")
|
|
}
|
|
if *llmTemp > 0 {
|
|
llmClient.WithTemperature(*llmTemp)
|
|
}
|
|
if *llmTimeout > 0 {
|
|
llmClient.WithTimeout(time.Duration(*llmTimeout) * time.Second)
|
|
}
|
|
|
|
// Create a top-level context. Timeout derived from LLM timeout + 1 min for other ops.
|
|
overallTimeout := time.Duration(*llmTimeout)*time.Second + time.Minute
|
|
ctx, cancel := context.WithTimeout(context.Background(), overallTimeout)
|
|
defer cancel()
|
|
|
|
log.Printf("Reviewing PR #%d on %s/%s", prNumber, owner, repoName)
|
|
|
|
// Step 1: Fetch PR metadata
|
|
pr, err := giteaClient.GetPullRequest(ctx, owner, repoName, prNumber)
|
|
if err != nil {
|
|
log.Fatalf("Failed to fetch PR: %v", err)
|
|
}
|
|
log.Printf("PR: %s", pr.Title)
|
|
|
|
// Step 2: Fetch diff
|
|
diff, err := giteaClient.GetPullRequestDiff(ctx, owner, repoName, prNumber)
|
|
if err != nil {
|
|
log.Fatalf("Failed to fetch diff: %v", err)
|
|
}
|
|
log.Printf("Diff size: %d bytes", len(diff))
|
|
|
|
// Step 3: Fetch full file content for modified files
|
|
fileContext := ""
|
|
files, err := giteaClient.GetPullRequestFiles(ctx, owner, repoName, prNumber)
|
|
if err != nil {
|
|
log.Printf("Warning: could not fetch PR files list: %v", err)
|
|
} else {
|
|
fileContext = fetchFileContext(ctx, giteaClient, owner, repoName, pr.Head.Ref, files)
|
|
log.Printf("Fetched full context for %d files", len(files))
|
|
}
|
|
|
|
// Step 4: Check CI status
|
|
ciPassed := true
|
|
ciDetails := ""
|
|
if pr.Head.Sha != "" {
|
|
statuses, err := giteaClient.GetCommitStatuses(ctx, owner, repoName, pr.Head.Sha)
|
|
if err != nil {
|
|
log.Printf("Warning: could not fetch CI status: %v", err)
|
|
} else {
|
|
ciPassed, ciDetails = evaluateCIStatus(statuses)
|
|
log.Printf("CI status: passed=%v", ciPassed)
|
|
}
|
|
}
|
|
|
|
// Step 5: Load conventions file if specified
|
|
conventions := ""
|
|
if *conventionsFile != "" {
|
|
content, err := giteaClient.GetFileContent(ctx, owner, repoName, *conventionsFile)
|
|
if err != nil {
|
|
log.Printf("Warning: could not load conventions file %q: %v", *conventionsFile, err)
|
|
} else {
|
|
conventions = content
|
|
log.Printf("Loaded conventions file: %s (%d bytes)", *conventionsFile, len(conventions))
|
|
}
|
|
}
|
|
|
|
// Step 6: Load patterns from external repo if specified
|
|
patterns := ""
|
|
if *patternsRepo != "" {
|
|
patterns = fetchPatterns(ctx, giteaClient, *patternsRepo, *patternsFiles)
|
|
log.Printf("Loaded patterns from %s (%d bytes)", *patternsRepo, len(patterns))
|
|
}
|
|
|
|
// Step 7: Budget-aware prompt assembly
|
|
sections := budget.Sections{
|
|
SystemBase: review.BuildSystemBase(),
|
|
Patterns: patterns,
|
|
Conventions: conventions,
|
|
FileContext: fileContext,
|
|
Diff: diff,
|
|
UserMeta: review.BuildUserMeta(pr.Title, pr.Body, ciPassed, ciDetails),
|
|
}
|
|
budgetResult := budget.Fit(*llmModel, sections)
|
|
log.Printf("Token estimate: ~%dK (limit: %dK)", budgetResult.EstTokens/1000, budget.LimitForModel(*llmModel)/1000)
|
|
if len(budgetResult.Trimmed) > 0 {
|
|
log.Printf("Context trimmed: %v", budgetResult.Trimmed)
|
|
}
|
|
|
|
// Step 8: Call LLM
|
|
log.Printf("Sending to LLM (%s)...", *llmModel)
|
|
messages := []llm.Message{
|
|
{Role: "system", Content: budgetResult.SystemPrompt},
|
|
{Role: "user", Content: budgetResult.UserPrompt},
|
|
}
|
|
|
|
response, err := llmClient.Complete(ctx, messages)
|
|
if err != nil {
|
|
log.Fatalf("LLM request failed: %v", err)
|
|
}
|
|
log.Printf("LLM response received (%d bytes)", len(response))
|
|
|
|
// Step 9: Parse response
|
|
result, err := review.ParseResponse(response)
|
|
if err != nil {
|
|
log.Fatalf("Failed to parse LLM response: %v", err)
|
|
}
|
|
log.Printf("Verdict: %s (%d findings)", result.Verdict, len(result.Findings))
|
|
|
|
// Step 10: Format and post review
|
|
reviewBody := review.FormatMarkdown(result, *reviewerName)
|
|
event := review.GiteaEvent(result.Verdict)
|
|
|
|
if *dryRun {
|
|
fmt.Println("--- DRY RUN ---")
|
|
fmt.Printf("Event: %s\n\n", event)
|
|
fmt.Println(reviewBody)
|
|
return
|
|
}
|
|
|
|
log.Printf("Posting review (event=%s)...", event)
|
|
if err := giteaClient.PostReview(ctx, owner, repoName, prNumber, event, reviewBody); err != nil {
|
|
log.Fatalf("Failed to post review: %v", err)
|
|
}
|
|
log.Printf("Review posted successfully!")
|
|
}
|
|
|
|
// fetchFileContext fetches the full content of modified files from the PR branch.
|
|
func fetchFileContext(ctx context.Context, client *gitea.Client, owner, repo, ref string, files []gitea.ChangedFile) string {
|
|
var sb strings.Builder
|
|
for _, f := range files {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
if f.Status == "removed" {
|
|
continue // Skip deleted files
|
|
}
|
|
content, err := client.GetFileContentRef(ctx, owner, repo, f.Filename, ref)
|
|
if err != nil {
|
|
log.Printf("Warning: could not fetch %s: %v", f.Filename, err)
|
|
continue
|
|
}
|
|
sb.WriteString(fmt.Sprintf("--- %s ---\n", f.Filename))
|
|
sb.WriteString("```\n")
|
|
sb.WriteString(content)
|
|
sb.WriteString("\n```\n\n")
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
// fetchPatterns fetches pattern files from one or more external repos.
|
|
// patternsRepo is comma-separated list of owner/name repos.
|
|
// patternsFiles is comma-separated list of file paths or directories.
|
|
// If a path ends with / or is a directory, all files within it are fetched recursively.
|
|
func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patternsFiles string) string {
|
|
var sb strings.Builder
|
|
|
|
repos := strings.Split(patternsRepo, ",")
|
|
paths := strings.Split(patternsFiles, ",")
|
|
|
|
for _, repoRef := range repos {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
repoRef = strings.TrimSpace(repoRef)
|
|
if repoRef == "" {
|
|
continue
|
|
}
|
|
parts := strings.SplitN(repoRef, "/", 2)
|
|
if len(parts) != 2 {
|
|
log.Printf("Warning: invalid patterns-repo format %q, expected owner/name", repoRef)
|
|
continue
|
|
}
|
|
owner, repo := parts[0], parts[1]
|
|
|
|
for _, path := range paths {
|
|
path = strings.TrimSpace(path)
|
|
if path == "" {
|
|
continue
|
|
}
|
|
|
|
files, err := client.GetAllFilesInPath(ctx, owner, repo, path)
|
|
if err != nil {
|
|
log.Printf("Warning: could not fetch %s from %s: %v", path, repoRef, err)
|
|
continue
|
|
}
|
|
|
|
for filepath, content := range files {
|
|
// Only include markdown and text files as patterns
|
|
if !isPatternFile(filepath) {
|
|
continue
|
|
}
|
|
sb.WriteString(fmt.Sprintf("### %s/%s\n\n%s\n\n", repoRef, filepath, content))
|
|
}
|
|
}
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
// isPatternFile reports whether path names a file that should be included as
// pattern content. The match is case-insensitive and purely suffix based:
// .md, .txt, .yml and .yaml are accepted.
func isPatternFile(path string) bool {
	name := strings.ToLower(path)
	for _, ext := range []string{".md", ".txt", ".yml", ".yaml"} {
		if strings.HasSuffix(name, ext) {
			return true
		}
	}
	return false
}
|
|
|
|
// evaluateCIStatus checks if all CI statuses indicate success.
|
|
func evaluateCIStatus(statuses []gitea.CommitStatus) (passed bool, details string) {
|
|
if len(statuses) == 0 {
|
|
return true, "no CI statuses found"
|
|
}
|
|
|
|
var failed []string
|
|
for _, s := range statuses {
|
|
switch s.Status {
|
|
case "success":
|
|
// good
|
|
case "pending":
|
|
// treat pending as not-failed
|
|
case "failure", "error":
|
|
failed = append(failed, fmt.Sprintf("%s: %s", s.Context, s.Description))
|
|
}
|
|
}
|
|
|
|
if len(failed) > 0 {
|
|
return false, strings.Join(failed, "; ")
|
|
}
|
|
return true, "all checks passed"
|
|
}
|
|
|
|
// envOrDefault returns the value of the environment variable key, or
// defaultVal when the variable is unset or set to the empty string.
func envOrDefault(key, defaultVal string) string {
	val := os.Getenv(key)
	if val == "" {
		return defaultVal
	}
	return val
}
|
|
|
|
// envOrDefaultFloat returns the environment variable key parsed as a float64.
//
// Surrounding whitespace is ignored. defaultVal is returned when the variable
// is unset, empty, or cannot be parsed; in the last case a warning is logged
// so a mistyped setting does not silently fall back to the default.
func envOrDefaultFloat(key string, defaultVal float64) float64 {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return defaultVal
	}
	f, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		log.Printf("Warning: ignoring invalid %s: %v", key, err)
		return defaultVal
	}
	return f
}
|
|
|
|
// envOrDefaultInt returns the environment variable key parsed as a base-10 int.
//
// Surrounding whitespace is ignored. defaultVal is returned when the variable
// is unset, empty, or cannot be parsed; in the last case a warning is logged
// so a mistyped setting does not silently fall back to the default.
func envOrDefaultInt(key string, defaultVal int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return defaultVal
	}
	i, err := strconv.Atoi(raw)
	if err != nil {
		log.Printf("Warning: ignoring invalid %s: %v", key, err)
		return defaultVal
	}
	return i
}
|