Compare commits

..

1 Commits

Author SHA1 Message Date
Rodin 14a0c2a946 feat: add Anthropic Messages API support (#18)
CI / test (pull_request) Successful in 13s
CI / review (gpt-5, sonnet, SONNET_REVIEW_TOKEN) (pull_request) Successful in 1m2s
CI / review (gpt-5-mini, gpt, GPT_REVIEW_TOKEN) (pull_request) Successful in 1m43s
Adds --llm-provider flag (openai|anthropic) to switch between API formats.

Anthropic implementation:
- POST /messages endpoint
- x-api-key + anthropic-version headers
- System prompt as top-level field (not a message)
- max_tokens: 8192 for response generation
- Parses content blocks [{type: "text", text: "..."}]

Changes:
- llm/client.go: Provider type, completeAnthropic(), doRequest() shared helper
- cmd/review-bot/main.go: --llm-provider / LLM_PROVIDER flag
- .gitea/actions/review/action.yml: llm-provider input + env
- llm/client_test.go: 4 new tests for Anthropic path

Backwards compatible — default provider is still openai.

Closes #18
2026-05-01 18:49:17 -07:00
9 changed files with 261 additions and 544 deletions
+5
View File
@@ -34,6 +34,10 @@ inputs:
llm-model: llm-model:
description: 'LLM model name' description: 'LLM model name'
required: true required: true
llm-provider:
description: 'LLM API provider: openai or anthropic (default openai)'
required: false
default: 'openai'
conventions-file: conventions-file:
description: 'Path to conventions file in the repo (e.g. CLAUDE.md)' description: 'Path to conventions file in the repo (e.g. CLAUDE.md)'
required: false required: false
@@ -140,6 +144,7 @@ runs:
PATTERNS_FILES: ${{ inputs.patterns-files }} PATTERNS_FILES: ${{ inputs.patterns-files }}
LLM_TEMPERATURE: ${{ inputs.temperature }} LLM_TEMPERATURE: ${{ inputs.temperature }}
LLM_TIMEOUT: ${{ inputs.timeout }} LLM_TIMEOUT: ${{ inputs.timeout }}
LLM_PROVIDER: ${{ inputs.llm-provider }}
run: | run: |
ARGS="" ARGS=""
if [ "${{ inputs.dry-run }}" = "true" ]; then if [ "${{ inputs.dry-run }}" = "true" ]; then
+2 -3
View File
@@ -31,7 +31,7 @@ jobs:
model: gpt-5 model: gpt-5
- name: gpt - name: gpt
token_secret: GPT_REVIEW_TOKEN token_secret: GPT_REVIEW_TOKEN
model: gpt-4.1 model: gpt-5-mini
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
@@ -49,6 +49,5 @@ jobs:
LLM_MODEL: ${{ matrix.model }} LLM_MODEL: ${{ matrix.model }}
CONVENTIONS_FILE: "CONVENTIONS.md" CONVENTIONS_FILE: "CONVENTIONS.md"
PATTERNS_REPO: "rodin/go-patterns" PATTERNS_REPO: "rodin/go-patterns"
PATTERNS_FILES: "README.md,patterns/" PATTERNS_FILES: "README.md,docs/"
LLM_TIMEOUT: "600"
run: ./review-bot run: ./review-bot
-226
View File
@@ -1,226 +0,0 @@
// Package budget manages LLM context window budgeting for review-bot.
//
// It estimates token usage and progressively trims context content to fit
// within model-specific limits. The trimming order (least important first):
// patterns → conventions → file context → diff truncation.
package budget
import (
"fmt"
"strings"
"unicode/utf8"
)
// modelLimit pairs a model name prefix with its context window size.
type modelLimit struct {
prefix string
limit int
}
// Known model context limits (in tokens), ordered longest-prefix-first
// for deterministic matching.
var modelLimits = []modelLimit{
{"claude-haiku-3.5-20241022", 200_000},
{"claude-sonnet-4-20250514", 200_000},
{"claude-opus-4-20250514", 200_000},
{"gpt-4.1-mini", 128_000},
{"gpt-5-mini", 200_000},
{"gpt-4.1", 128_000},
{"gpt-5", 200_000},
}
const defaultLimit = 128_000
// reserveTokens is headroom for the response generation.
const reserveTokens = 4_000
const diffTruncMarker = "\n\n... [diff truncated due to context limit] ..."
const diffTooLargeMarker = "... [diff too large for context window — review manually] ..."
const userMetaTruncMarker = "\n... [description truncated] ..."
// EstimateTokens estimates the number of tokens in a string.
// Uses the rough heuristic of ~4 bytes per token, which is
// conservative for English text and code.
func EstimateTokens(s string) int {
return len(s) / 4
}
// LimitForModel returns the context window size for the given model.
// Uses longest-prefix-first matching for deterministic results.
func LimitForModel(model string) int {
for _, ml := range modelLimits {
if model == ml.prefix || strings.HasPrefix(model, ml.prefix) {
return ml.limit
}
}
return defaultLimit
}
// Sections holds the prompt content sections in trim priority order.
// When the total exceeds the budget, sections are trimmed from least
// important (Patterns) to most important (Diff).
type Sections struct {
SystemBase string // Core instructions (never trimmed)
Patterns string // Language patterns (trimmed first)
Conventions string // Repo conventions (trimmed second)
FileContext string // Full file content (trimmed third)
Diff string // The actual diff (trimmed last, only truncated)
UserMeta string // PR title, description, CI status (truncated only if base exceeds budget)
}
// Result holds the trimmed content and metadata about what was dropped.
type Result struct {
SystemPrompt string
UserPrompt string
Trimmed []string // Human-readable descriptions of what was trimmed
EstTokens int // Estimated total tokens after trimming
}
// Fit trims sections to fit within the model's context limit.
// Returns the assembled prompts and a list of what was trimmed.
func Fit(model string, sections Sections) Result {
limit := LimitForModel(model) - reserveTokens
baseTokens := EstimateTokens(sections.SystemBase) + EstimateTokens(sections.UserMeta)
available := limit - baseTokens
if available < 0 {
// Base content alone exceeds budget. Truncate UserMeta (keep first ~1000 tokens).
if len(sections.UserMeta) > 4000 {
sections.UserMeta = truncateUTF8(sections.UserMeta, 4000) + userMetaTruncMarker
baseTokens = EstimateTokens(sections.SystemBase) + EstimateTokens(sections.UserMeta)
available = limit - baseTokens
}
if available < 0 {
available = 0
}
}
// Trimmable sections in priority order (first = dropped first)
type entry struct {
name string
content *string
}
entries := []entry{
{"patterns", &sections.Patterns},
{"conventions", &sections.Conventions},
{"file context", &sections.FileContext},
}
// Check if everything fits
totalTrimmable := EstimateTokens(sections.Diff)
for _, e := range entries {
totalTrimmable += EstimateTokens(*e.content)
}
var trimmed []string
if totalTrimmable > available {
// Trim from least important
for i := range entries {
tokens := EstimateTokens(*entries[i].content)
if tokens == 0 {
continue
}
trimmed = append(trimmed, fmt.Sprintf("%s (~%dK tokens)", entries[i].name, tokens/1000))
*entries[i].content = ""
// Recalculate
totalTrimmable = EstimateTokens(sections.Diff)
for _, e := range entries {
totalTrimmable += EstimateTokens(*e.content)
}
if totalTrimmable <= available {
break
}
}
}
// If still too large, truncate the diff
if totalTrimmable > available {
diffBudget := available
for _, e := range entries {
diffBudget -= EstimateTokens(*e.content)
}
if diffBudget < 0 {
diffBudget = 0
}
// Reserve space for truncation marker
markerBudget := EstimateTokens(diffTruncMarker)
effectiveBudget := diffBudget - markerBudget
if effectiveBudget < 0 {
effectiveBudget = 0
}
maxChars := effectiveBudget * 4
if maxChars < len(sections.Diff) {
removed := EstimateTokens(sections.Diff) - diffBudget
trimmed = append(trimmed, fmt.Sprintf("diff truncated (~%dK tokens removed)", removed/1000))
if maxChars > 0 {
if diffBudget >= markerBudget {
sections.Diff = truncateUTF8(sections.Diff, maxChars) + diffTruncMarker
} else {
sections.Diff = truncateUTF8(sections.Diff, maxChars)
}
} else {
sections.Diff = diffTooLargeMarker
}
}
}
finalTokens := baseTokens
for _, e := range entries {
finalTokens += EstimateTokens(*e.content)
}
finalTokens += EstimateTokens(sections.Diff)
return buildResult(sections, trimmed, finalTokens)
}
func buildResult(s Sections, trimmed []string, estTokens int) Result {
var sys strings.Builder
sys.WriteString(s.SystemBase)
if s.Patterns != "" {
sys.WriteString("\n\n## Language Patterns & Idioms\n\nUse the following patterns as review criteria. Code that violates these established patterns is a finding:\n\n")
sys.WriteString(s.Patterns)
}
if s.Conventions != "" {
sys.WriteString("\n\n## Repository Conventions\n\nThe repository has the following coding conventions that must be respected:\n\n")
sys.WriteString(s.Conventions)
}
var usr strings.Builder
usr.WriteString(s.UserMeta)
if s.FileContext != "" {
usr.WriteString("\n### Full File Context (modified files)\n\n")
usr.WriteString(s.FileContext)
usr.WriteString("\n")
}
if s.Diff != "" {
usr.WriteString("\n### Diff (changes to review)\n\n```diff\n")
usr.WriteString(s.Diff)
usr.WriteString("\n```\n")
}
if len(trimmed) > 0 {
usr.WriteString("\n⚠️ Note: Context was trimmed to fit model limits. Dropped: ")
usr.WriteString(strings.Join(trimmed, ", "))
usr.WriteString("\n")
}
return Result{
SystemPrompt: sys.String(),
UserPrompt: usr.String(),
Trimmed: trimmed,
EstTokens: estTokens,
}
}
// truncateUTF8 truncates s to at most maxBytes without splitting multi-byte
// UTF-8 characters. Returns a valid UTF-8 string of at most maxBytes bytes.
func truncateUTF8(s string, maxBytes int) string {
if len(s) <= maxBytes {
return s
}
for maxBytes > 0 && !utf8.RuneStart(s[maxBytes]) {
maxBytes--
}
return s[:maxBytes]
}
-203
View File
@@ -1,203 +0,0 @@
package budget
import (
"strings"
"testing"
)
func TestEstimateTokens(t *testing.T) {
tests := []struct {
input string
want int
}{
{"", 0},
{"abcd", 1},
{"12345678", 2},
{strings.Repeat("x", 400), 100},
}
for _, tt := range tests {
got := EstimateTokens(tt.input)
if got != tt.want {
t.Errorf("EstimateTokens(%d chars) = %d, want %d", len(tt.input), got, tt.want)
}
}
}
func TestLimitForModel(t *testing.T) {
tests := []struct {
model string
want int
}{
{"gpt-4.1", 128_000},
{"gpt-5", 200_000},
{"gpt-5-mini", 200_000},
{"unknown-model", defaultLimit},
{"gpt-4.1-2026-01-01", 128_000}, // prefix match
}
for _, tt := range tests {
got := LimitForModel(tt.model)
if got != tt.want {
t.Errorf("LimitForModel(%q) = %d, want %d", tt.model, got, tt.want)
}
}
}
func TestFit_AllFits(t *testing.T) {
s := Sections{
SystemBase: "system instructions",
Patterns: "some patterns",
Conventions: "some conventions",
FileContext: "file content",
Diff: "diff content",
UserMeta: "PR: title\n",
}
result := Fit("gpt-5", s)
if len(result.Trimmed) != 0 {
t.Errorf("expected no trimming, got %v", result.Trimmed)
}
if !strings.Contains(result.SystemPrompt, "some patterns") {
t.Error("expected patterns in system prompt")
}
if !strings.Contains(result.SystemPrompt, "some conventions") {
t.Error("expected conventions in system prompt")
}
if !strings.Contains(result.UserPrompt, "file content") {
t.Error("expected file context in user prompt")
}
}
func TestFit_TrimsPatterns(t *testing.T) {
// Create content that exceeds 128K token budget for gpt-4.1
// Budget ≈ 128K - 4K reserve = 124K tokens = ~496K chars
// Fill patterns with enough to push over
bigPatterns := strings.Repeat("x", 500_000) // ~125K tokens
s := Sections{
SystemBase: "base",
Patterns: bigPatterns,
Conventions: "conventions",
FileContext: "files",
Diff: "diff",
UserMeta: "meta",
}
result := Fit("gpt-4.1", s)
if len(result.Trimmed) == 0 {
t.Fatal("expected trimming")
}
if !strings.Contains(result.Trimmed[0], "patterns") {
t.Errorf("expected patterns to be trimmed first, got %v", result.Trimmed)
}
if strings.Contains(result.SystemPrompt, bigPatterns[:100]) {
t.Error("expected patterns to be removed from output")
}
// Conventions should survive
if !strings.Contains(result.SystemPrompt, "conventions") {
t.Error("expected conventions to survive after patterns trimmed")
}
}
func TestFit_TrimsConventions(t *testing.T) {
// Patterns + conventions + diff all exceed budget even after patterns removed
big := strings.Repeat("y", 520_000) // ~130K tokens each (exceeds 124K budget even alone)
s := Sections{
SystemBase: "base",
Patterns: big,
Conventions: big,
FileContext: "files",
Diff: "diff",
UserMeta: "meta",
}
result := Fit("gpt-4.1", s)
if len(result.Trimmed) < 2 {
t.Fatalf("expected at least 2 trimmed, got %v", result.Trimmed)
}
if !strings.Contains(result.Trimmed[0], "patterns") {
t.Errorf("expected patterns trimmed first, got %s", result.Trimmed[0])
}
if !strings.Contains(result.Trimmed[1], "conventions") {
t.Errorf("expected conventions trimmed second, got %s", result.Trimmed[1])
}
}
func TestFit_TruncatesDiff(t *testing.T) {
// Only diff is huge, no patterns/conventions
hugeDiff := strings.Repeat("z", 600_000) // ~150K tokens > 128K limit
s := Sections{
SystemBase: "base",
Diff: hugeDiff,
UserMeta: "meta",
}
result := Fit("gpt-4.1", s)
if len(result.Trimmed) == 0 {
t.Fatal("expected diff truncation")
}
if !strings.Contains(result.Trimmed[len(result.Trimmed)-1], "diff truncated") {
t.Errorf("expected diff truncation note, got %v", result.Trimmed)
}
if !strings.Contains(result.UserPrompt, "[diff truncated due to context limit]") {
t.Error("expected truncation marker in user prompt")
}
}
func TestFit_PreservesNoteInOutput(t *testing.T) {
big := strings.Repeat("w", 500_000)
s := Sections{
SystemBase: "base",
Patterns: big,
Diff: "small diff",
UserMeta: "meta",
}
result := Fit("gpt-4.1", s)
if !strings.Contains(result.UserPrompt, "⚠️ Note: Context was trimmed") {
t.Error("expected trimming note in user prompt")
}
}
func TestFit_HugeUserMeta(t *testing.T) {
// UserMeta so large that base alone exceeds limit
// Use a unique marker past the truncation point
hugeDesc := strings.Repeat("d", 5000) + "UNIQUE_MARKER_PAST_TRUNCATION" + strings.Repeat("d", 595_000)
s := Sections{
SystemBase: "base",
Diff: "small diff",
UserMeta: hugeDesc,
}
result := Fit("gpt-4.1", s)
limit := LimitForModel("gpt-4.1") - reserveTokens
if result.EstTokens > limit {
t.Errorf("EstTokens %d exceeds limit %d", result.EstTokens, limit)
}
// Content past truncation point should not be present
if strings.Contains(result.UserPrompt, "UNIQUE_MARKER_PAST_TRUNCATION") {
t.Error("expected UserMeta to be truncated but found content past truncation point")
}
// Truncation marker should be present
if !strings.Contains(result.UserPrompt, "[description truncated]") {
t.Error("expected truncation marker in output")
}
}
func TestFit_NeverExceedsLimit(t *testing.T) {
// All sections huge — verify final tokens never exceed limit
big := strings.Repeat("a", 200_000)
s := Sections{
SystemBase: strings.Repeat("s", 8000),
Patterns: big,
Conventions: big,
FileContext: big,
Diff: big,
UserMeta: strings.Repeat("m", 8000),
}
result := Fit("gpt-4.1", s)
limit := LimitForModel("gpt-4.1") - reserveTokens
if result.EstTokens > limit {
t.Errorf("EstTokens %d exceeds limit %d (trimmed: %v)", result.EstTokens, limit, result.Trimmed)
}
}
+15 -20
View File
@@ -10,7 +10,6 @@ import (
"strings" "strings"
"time" "time"
"gitea.weiker.me/rodin/review-bot/budget"
"gitea.weiker.me/rodin/review-bot/gitea" "gitea.weiker.me/rodin/review-bot/gitea"
"gitea.weiker.me/rodin/review-bot/llm" "gitea.weiker.me/rodin/review-bot/llm"
"gitea.weiker.me/rodin/review-bot/review" "gitea.weiker.me/rodin/review-bot/review"
@@ -35,6 +34,7 @@ func main() {
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting") dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)") llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)") llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
llmProvider := flag.String("llm-provider", envOrDefault("LLM_PROVIDER", "openai"), "LLM API provider: openai or anthropic")
flag.Parse() flag.Parse()
@@ -75,6 +75,12 @@ func main() {
if *llmTemp > 0 { if *llmTemp > 0 {
llmClient.WithTemperature(*llmTemp) llmClient.WithTemperature(*llmTemp)
} }
switch llm.Provider(*llmProvider) {
case llm.ProviderOpenAI, llm.ProviderAnthropic:
llmClient.WithProvider(llm.Provider(*llmProvider))
default:
log.Fatalf("Invalid --llm-provider %q, must be openai or anthropic", *llmProvider)
}
if *llmTimeout > 0 { if *llmTimeout > 0 {
llmClient.WithTimeout(time.Duration(*llmTimeout) * time.Second) llmClient.WithTimeout(time.Duration(*llmTimeout) * time.Second)
} }
@@ -142,26 +148,15 @@ func main() {
log.Printf("Loaded patterns from %s (%d bytes)", *patternsRepo, len(patterns)) log.Printf("Loaded patterns from %s (%d bytes)", *patternsRepo, len(patterns))
} }
// Step 7: Budget-aware prompt assembly // Step 7: Build prompts
sections := budget.Sections{ systemPrompt := review.BuildSystemPrompt(conventions, patterns)
SystemBase: review.BuildSystemBase(), userPrompt := review.BuildUserPrompt(pr.Title, pr.Body, diff, fileContext, ciPassed, ciDetails)
Patterns: patterns,
Conventions: conventions,
FileContext: fileContext,
Diff: diff,
UserMeta: review.BuildUserMeta(pr.Title, pr.Body, ciPassed, ciDetails),
}
budgetResult := budget.Fit(*llmModel, sections)
log.Printf("Token estimate: ~%dK (limit: %dK)", budgetResult.EstTokens/1000, budget.LimitForModel(*llmModel)/1000)
if len(budgetResult.Trimmed) > 0 {
log.Printf("Context trimmed: %v", budgetResult.Trimmed)
}
// Step 8: Call LLM // Step 8: Call LLM
log.Printf("Sending to LLM (%s)...", *llmModel) log.Printf("Sending to LLM (%s)...", *llmModel)
messages := []llm.Message{ messages := []llm.Message{
{Role: "system", Content: budgetResult.SystemPrompt}, {Role: "system", Content: systemPrompt},
{Role: "user", Content: budgetResult.UserPrompt}, {Role: "user", Content: userPrompt},
} }
response, err := llmClient.Complete(ctx, messages) response, err := llmClient.Complete(ctx, messages)
@@ -255,12 +250,12 @@ func fetchPatterns(ctx context.Context, client *gitea.Client, patternsRepo, patt
continue continue
} }
for filePath, content := range files { for filepath, content := range files {
// Only include markdown and text files as patterns // Only include markdown and text files as patterns
if !isPatternFile(filePath) { if !isPatternFile(filepath) {
continue continue
} }
sb.WriteString(fmt.Sprintf("### %s/%s\n\n%s\n\n", repoRef, filePath, content)) sb.WriteString(fmt.Sprintf("### %s/%s\n\n%s\n\n", repoRef, filepath, content))
} }
} }
} }
+144 -22
View File
@@ -1,4 +1,6 @@
// Package llm provides a client for OpenAI-compatible chat completion APIs. // Package llm provides clients for LLM chat completion APIs.
//
// Supports OpenAI-compatible (default) and Anthropic Messages API providers.
package llm package llm
import ( import (
@@ -12,23 +14,36 @@ import (
"time" "time"
) )
// Client calls an OpenAI-compatible chat completion API. // Provider identifies which API format to use.
type Provider string
const (
// ProviderOpenAI uses the OpenAI-compatible chat/completions endpoint.
ProviderOpenAI Provider = "openai"
// ProviderAnthropic uses the Anthropic Messages API endpoint.
ProviderAnthropic Provider = "anthropic"
)
// Client calls an LLM chat completion API.
// A Client is safe for concurrent use by multiple goroutines after construction. // A Client is safe for concurrent use by multiple goroutines after construction.
// WithTimeout and WithTemperature must be called during setup, before concurrent use. // WithTimeout, WithTemperature, and WithProvider must be called during setup,
// before concurrent use.
type Client struct { type Client struct {
baseURL string baseURL string
apiKey string apiKey string
model string model string
temperature float64 temperature float64
provider Provider
http *http.Client http *http.Client
} }
// NewClient creates a new LLM client. // NewClient creates a new LLM client. Default provider is OpenAI-compatible.
func NewClient(baseURL, apiKey, model string) *Client { func NewClient(baseURL, apiKey, model string) *Client {
return &Client{ return &Client{
baseURL: strings.TrimRight(baseURL, "/"), baseURL: strings.TrimRight(baseURL, "/"),
apiKey: apiKey, apiKey: apiKey,
model: model, model: model,
provider: ProviderOpenAI,
http: &http.Client{Timeout: 5 * time.Minute}, http: &http.Client{Timeout: 5 * time.Minute},
} }
} }
@@ -45,20 +60,39 @@ func (c *Client) WithTemperature(t float64) *Client {
return c return c
} }
// WithProvider sets the API provider format (openai or anthropic).
func (c *Client) WithProvider(p Provider) *Client {
c.provider = p
return c
}
// Message represents a chat message. // Message represents a chat message.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
} }
// ChatRequest is the request payload. // Complete sends a chat completion request and returns the assistant's response content.
// The first message with role "system" is treated as the system prompt.
func (c *Client) Complete(ctx context.Context, messages []Message) (string, error) {
switch c.provider {
case ProviderAnthropic:
return c.completeAnthropic(ctx, messages)
default:
return c.completeOpenAI(ctx, messages)
}
}
// --- OpenAI-compatible implementation ---
// ChatRequest is the OpenAI request payload.
type ChatRequest struct { type ChatRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
} }
// ChatResponse is the response from the API. // ChatResponse is the OpenAI response.
type ChatResponse struct { type ChatResponse struct {
Choices []struct { Choices []struct {
Message struct { Message struct {
@@ -67,8 +101,7 @@ type ChatResponse struct {
} `json:"choices"` } `json:"choices"`
} }
// Complete sends a chat completion request and returns the assistant's response content. func (c *Client) completeOpenAI(ctx context.Context, messages []Message) (string, error) {
func (c *Client) Complete(ctx context.Context, messages []Message) (string, error) {
reqBody := ChatRequest{ reqBody := ChatRequest{
Model: c.model, Model: c.model,
Temperature: c.temperature, Temperature: c.temperature,
@@ -81,37 +114,126 @@ func (c *Client) Complete(ctx context.Context, messages []Message) (string, erro
} }
url := c.baseURL + "/chat/completions" url := c.baseURL + "/chat/completions"
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
if err != nil { if err != nil {
return "", fmt.Errorf("create request: %w", err) return "", fmt.Errorf("create request: %w", err)
} }
req.Header.Set("Authorization", "Bearer "+c.apiKey) req.Header.Set("Authorization", "Bearer "+c.apiKey)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
return c.doRequest(req, func(body []byte) (string, error) {
var resp ChatResponse
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("parse response: %w", err)
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in LLM response")
}
return resp.Choices[0].Message.Content, nil
})
}
// --- Anthropic Messages API implementation ---
type anthropicRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens"`
System string `json:"system,omitempty"`
Messages []anthropicMsg `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
}
type anthropicMsg struct {
Role string `json:"role"`
Content string `json:"content"`
}
type anthropicResponse struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
} `json:"content"`
}
func (c *Client) completeAnthropic(ctx context.Context, messages []Message) (string, error) {
// Extract system message (first message with role "system")
var system string
var userMessages []anthropicMsg
for _, m := range messages {
if m.Role == "system" {
system = m.Content
} else {
userMessages = append(userMessages, anthropicMsg{
Role: m.Role,
Content: m.Content,
})
}
}
reqBody := anthropicRequest{
Model: c.model,
MaxTokens: 8192,
System: system,
Messages: userMessages,
}
if c.temperature > 0 {
reqBody.Temperature = c.temperature
}
data, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshal request: %w", err)
}
url := c.baseURL + "/messages"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("x-api-key", c.apiKey)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
return c.doRequest(req, func(body []byte) (string, error) {
var resp anthropicResponse
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("parse response: %w", err)
}
if len(resp.Content) == 0 {
return "", fmt.Errorf("no content in Anthropic response")
}
// Concatenate all text blocks
var sb strings.Builder
for _, block := range resp.Content {
if block.Type == "text" {
sb.WriteString(block.Text)
}
}
result := sb.String()
if result == "" {
return "", fmt.Errorf("no text content in Anthropic response")
}
return result, nil
})
}
// --- Shared HTTP execution ---
func (c *Client) doRequest(req *http.Request, parse func([]byte) (string, error)) (string, error) {
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("LLM request: %w", err) return "", fmt.Errorf("LLM request: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("LLM API error (status %d): %s", resp.StatusCode, string(body))
}
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("read response: %w", err) return "", fmt.Errorf("read response: %w", err)
} }
var chatResp ChatResponse if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if err := json.Unmarshal(body, &chatResp); err != nil { return "", fmt.Errorf("LLM API error (status %d): %s", resp.StatusCode, string(body))
return "", fmt.Errorf("parse response: %w", err)
} }
if len(chatResp.Choices) == 0 { return parse(body)
return "", fmt.Errorf("no choices in LLM response")
}
return chatResp.Choices[0].Message.Content, nil
} }
+87
View File
@@ -208,3 +208,90 @@ func TestWithTimeout(t *testing.T) {
t.Error("expected timeout error with 50ms timeout and 200ms server delay") t.Error("expected timeout error with 50ms timeout and 200ms server delay")
} }
} }
func TestComplete_Anthropic_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/messages" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
if r.Header.Get("x-api-key") != "test-key" {
t.Errorf("expected x-api-key header, got %q", r.Header.Get("x-api-key"))
}
if r.Header.Get("anthropic-version") != "2023-06-01" {
t.Errorf("expected anthropic-version header, got %q", r.Header.Get("anthropic-version"))
}
var req map[string]interface{}
json.NewDecoder(r.Body).Decode(&req)
if req["system"] != "You are helpful" {
t.Errorf("expected system prompt, got %v", req["system"])
}
msgs := req["messages"].([]interface{})
if len(msgs) != 1 {
t.Errorf("expected 1 user message, got %d", len(msgs))
}
if req["max_tokens"] != float64(8192) {
t.Errorf("expected max_tokens 8192, got %v", req["max_tokens"])
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"content":[{"type":"text","text":"Hello from Claude!"}]}`))
}))
defer server.Close()
client := NewClient(server.URL, "test-key", "claude-sonnet").WithProvider(ProviderAnthropic)
got, err := client.Complete(context.Background(), []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "Hello from Claude!" {
t.Errorf("expected %q, got %q", "Hello from Claude!", got)
}
}
func TestComplete_Anthropic_NoContent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"content":[]}`))
}))
defer server.Close()
client := NewClient(server.URL, "test-key", "claude-sonnet").WithProvider(ProviderAnthropic)
_, err := client.Complete(context.Background(), []Message{{Role: "user", Content: "Hi"}})
if err == nil {
t.Fatal("expected error for empty content, got nil")
}
}
func TestComplete_Anthropic_APIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":{"message":"invalid request"}}`))
}))
defer server.Close()
client := NewClient(server.URL, "test-key", "claude-sonnet").WithProvider(ProviderAnthropic)
_, err := client.Complete(context.Background(), []Message{{Role: "user", Content: "Hi"}})
if err == nil {
t.Fatal("expected error for 400, got nil")
}
}
func TestWithProvider(t *testing.T) {
client := NewClient("http://example.com", "key", "model")
if client.provider != ProviderOpenAI {
t.Errorf("expected default provider openai, got %s", client.provider)
}
result := client.WithProvider(ProviderAnthropic)
if result != client {
t.Error("WithProvider should return the same client for chaining")
}
if client.provider != ProviderAnthropic {
t.Errorf("expected provider anthropic, got %s", client.provider)
}
}
+4 -26
View File
@@ -7,10 +7,8 @@ import (
"strings" "strings"
) )
// BuildSystemBase returns the core system prompt instructions without // BuildSystemPrompt constructs the system prompt for the LLM reviewer.
// patterns or conventions. Used by the budget package to separate func BuildSystemPrompt(conventions, patterns string) string {
// trimmable from non-trimmable content.
func BuildSystemBase() string {
var sb strings.Builder var sb strings.Builder
sb.WriteString("You are an expert code reviewer. Review the provided pull request diff carefully.\n\n") sb.WriteString("You are an expert code reviewer. Review the provided pull request diff carefully.\n\n")
@@ -44,15 +42,6 @@ func BuildSystemBase() string {
sb.WriteString("- Line numbers should reference the new file line numbers from the diff headers.\n") sb.WriteString("- Line numbers should reference the new file line numbers from the diff headers.\n")
sb.WriteString("- If the diff is empty or trivial (only formatting/whitespace), APPROVE with no findings.\n") sb.WriteString("- If the diff is empty or trivial (only formatting/whitespace), APPROVE with no findings.\n")
return sb.String()
}
// BuildSystemPrompt constructs the full system prompt with patterns and conventions.
// Deprecated: Use BuildSystemBase with budget.Fit for context-aware assembly.
func BuildSystemPrompt(conventions, patterns string) string {
var sb strings.Builder
sb.WriteString(BuildSystemBase())
if patterns != "" { if patterns != "" {
sb.WriteString(fmt.Sprintf("\n\n## Language Patterns & Idioms\n\nUse the following patterns as review criteria. Code that violates these established patterns is a finding:\n\n%s\n", patterns)) sb.WriteString(fmt.Sprintf("\n\n## Language Patterns & Idioms\n\nUse the following patterns as review criteria. Code that violates these established patterns is a finding:\n\n%s\n", patterns))
} }
@@ -64,9 +53,8 @@ func BuildSystemPrompt(conventions, patterns string) string {
return sb.String() return sb.String()
} }
// BuildUserMeta returns the PR metadata header (title, description, CI status) // BuildUserPrompt constructs the user message with PR context.
// without the diff or file context. Used by the budget package. func BuildUserPrompt(title, description, diff, fileContext string, ciPassed bool, ciDetails string) string {
func BuildUserMeta(title, description string, ciPassed bool, ciDetails string) string {
var sb strings.Builder var sb strings.Builder
sb.WriteString(fmt.Sprintf("## Pull Request: %s\n\n", title)) sb.WriteString(fmt.Sprintf("## Pull Request: %s\n\n", title))
@@ -85,16 +73,6 @@ func BuildUserMeta(title, description string, ciPassed bool, ciDetails string) s
sb.WriteString(fmt.Sprintf("CI Details: %s\n", ciDetails)) sb.WriteString(fmt.Sprintf("CI Details: %s\n", ciDetails))
} }
return sb.String()
}
// BuildUserPrompt constructs the user message with PR context.
// Deprecated: Use BuildUserMeta with budget.Fit for context-aware assembly.
func BuildUserPrompt(title, description, diff, fileContext string, ciPassed bool, ciDetails string) string {
var sb strings.Builder
sb.WriteString(BuildUserMeta(title, description, ciPassed, ciDetails))
if fileContext != "" { if fileContext != "" {
sb.WriteString("\n### Full File Context (modified files)\n\n") sb.WriteString("\n### Full File Context (modified files)\n\n")
sb.WriteString(fileContext) sb.WriteString(fileContext)
-40
View File
@@ -116,43 +116,3 @@ func TestBuildUserPrompt_WithoutFileContext(t *testing.T) {
t.Error("should not include file context section when empty") t.Error("should not include file context section when empty")
} }
} }
func TestBuildSystemBase(t *testing.T) {
result := BuildSystemBase()
if result == "" {
t.Fatal("BuildSystemBase returned empty string")
}
if !strings.Contains(result, "expert code reviewer") {
t.Error("expected reviewer role in system base")
}
if !strings.Contains(result, "REQUEST_CHANGES") {
t.Error("expected verdict format in system base")
}
if !strings.Contains(result, "JSON") {
t.Error("expected JSON output instruction in system base")
}
}
func TestBuildUserMeta(t *testing.T) {
result := BuildUserMeta("Fix bug", "Some description", true, "all checks passed")
if !strings.Contains(result, "Fix bug") {
t.Error("expected title in user meta")
}
if !strings.Contains(result, "Some description") {
t.Error("expected description in user meta")
}
if !strings.Contains(result, "PASSED") {
t.Error("expected CI PASSED status")
}
}
func TestBuildUserMeta_CIFailed(t *testing.T) {
result := BuildUserMeta("Title", "", false, "test job failed")
if !strings.Contains(result, "FAILED") {
t.Error("expected CI FAILED status")
}
if strings.Contains(result, "Description") {
t.Error("expected no description section when empty")
}
}