Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 14a0c2a946 | |||
| ef3e6d5e87 | |||
| aade891129 | |||
| 7b42de67ca | |||
| dd2661fe14 | |||
| 98a4772f30 | |||
| fc23b6ebe9 |
@@ -34,6 +34,10 @@ inputs:
|
||||
llm-model:
|
||||
description: 'LLM model name'
|
||||
required: true
|
||||
llm-provider:
|
||||
description: 'LLM API provider: openai or anthropic (default openai)'
|
||||
required: false
|
||||
default: 'openai'
|
||||
conventions-file:
|
||||
description: 'Path to conventions file in the repo (e.g. CLAUDE.md)'
|
||||
required: false
|
||||
@@ -140,6 +144,7 @@ runs:
|
||||
PATTERNS_FILES: ${{ inputs.patterns-files }}
|
||||
LLM_TEMPERATURE: ${{ inputs.temperature }}
|
||||
LLM_TIMEOUT: ${{ inputs.timeout }}
|
||||
LLM_PROVIDER: ${{ inputs.llm-provider }}
|
||||
run: |
|
||||
ARGS=""
|
||||
if [ "${{ inputs.dry-run }}" = "true" ]; then
|
||||
|
||||
@@ -34,6 +34,7 @@ func main() {
|
||||
dryRun := flag.Bool("dry-run", false, "Print review to stdout instead of posting")
|
||||
llmTemp := flag.Float64("llm-temperature", envOrDefaultFloat("LLM_TEMPERATURE", 0), "LLM temperature (0 = server default)")
|
||||
llmTimeout := flag.Int("llm-timeout", envOrDefaultInt("LLM_TIMEOUT", 300), "LLM request timeout in seconds (default 300)")
|
||||
llmProvider := flag.String("llm-provider", envOrDefault("LLM_PROVIDER", "openai"), "LLM API provider: openai or anthropic")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@@ -74,6 +75,12 @@ func main() {
|
||||
if *llmTemp > 0 {
|
||||
llmClient.WithTemperature(*llmTemp)
|
||||
}
|
||||
switch llm.Provider(*llmProvider) {
|
||||
case llm.ProviderOpenAI, llm.ProviderAnthropic:
|
||||
llmClient.WithProvider(llm.Provider(*llmProvider))
|
||||
default:
|
||||
log.Fatalf("Invalid --llm-provider %q, must be openai or anthropic", *llmProvider)
|
||||
}
|
||||
if *llmTimeout > 0 {
|
||||
llmClient.WithTimeout(time.Duration(*llmTimeout) * time.Second)
|
||||
}
|
||||
|
||||
+38
-17
@@ -1,3 +1,6 @@
|
||||
// Package gitea provides a client for the Gitea API.
|
||||
// It supports pull request operations, file content retrieval,
|
||||
// and review submission.
|
||||
package gitea
|
||||
|
||||
import (
|
||||
@@ -56,8 +59,8 @@ type ChangedFile struct {
|
||||
|
||||
// GetPullRequest fetches PR metadata.
|
||||
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, url)
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR: %w", err)
|
||||
}
|
||||
@@ -70,8 +73,8 @@ func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number
|
||||
|
||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, url)
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
@@ -80,8 +83,8 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, num
|
||||
|
||||
// GetPullRequestFiles fetches the list of files changed in a PR.
|
||||
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/files", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, url)
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/files", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR files: %w", err)
|
||||
}
|
||||
@@ -94,8 +97,8 @@ func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, nu
|
||||
|
||||
// GetCommitStatuses fetches CI statuses for a commit SHA.
|
||||
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/commits/%s/statuses", c.baseURL, owner, repo, sha)
|
||||
body, err := c.doGet(ctx, url)
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/commits/%s/statuses", c.baseURL, owner, repo, sha)
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch commit statuses: %w", err)
|
||||
}
|
||||
@@ -108,8 +111,8 @@ func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string)
|
||||
|
||||
// GetFileContent fetches a file from the default branch of a repo.
|
||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s", c.baseURL, owner, repo, filepath)
|
||||
body, err := c.doGet(ctx, url)
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s", c.baseURL, owner, repo, escapePath(filepath))
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s: %w", filepath, err)
|
||||
}
|
||||
@@ -118,7 +121,7 @@ func (c *Client) GetFileContent(ctx context.Context, owner, repo, filepath strin
|
||||
|
||||
// GetFileContentRef fetches a file from a specific ref (branch/tag/sha) in a repo.
|
||||
func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s?ref=%s", c.baseURL, owner, repo, filepath, url.QueryEscape(ref))
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s?ref=%s", c.baseURL, owner, repo, escapePath(filepath), url.QueryEscape(ref))
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s@%s: %w", filepath, ref, err)
|
||||
@@ -129,7 +132,7 @@ func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, r
|
||||
// PostReview submits a review to a PR.
|
||||
// event should be "APPROVED" or "REQUEST_CHANGES".
|
||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body string) error {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.baseURL, owner, repo, number)
|
||||
reqURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.baseURL, owner, repo, number)
|
||||
|
||||
payload := struct {
|
||||
Body string `json:"body"`
|
||||
@@ -144,7 +147,7 @@ func (c *Client) PostReview(ctx context.Context, owner, repo string, number int,
|
||||
return fmt.Errorf("marshal review payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create review request: %w", err)
|
||||
}
|
||||
@@ -164,8 +167,8 @@ func (c *Client) PostReview(ctx context.Context, owner, repo string, number int,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
func (c *Client) doGet(ctx context.Context, reqURL string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -184,6 +187,18 @@ func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// escapePath escapes each segment of a relative file path for use in URLs.
|
||||
// Slashes are preserved as path separators; other special characters are escaped.
|
||||
// Input should be a relative path (no leading slash). Already-encoded segments
|
||||
// will be double-encoded, which is the desired behavior for user-provided paths.
|
||||
func escapePath(p string) string {
|
||||
parts := strings.Split(p, "/")
|
||||
for i, part := range parts {
|
||||
parts[i] = url.PathEscape(part)
|
||||
}
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
|
||||
// ContentEntry represents a file or directory entry from the contents API.
|
||||
type ContentEntry struct {
|
||||
Name string `json:"name"`
|
||||
@@ -192,9 +207,15 @@ type ContentEntry struct {
|
||||
}
|
||||
|
||||
// ListContents lists files and directories at a given path in a repo.
|
||||
// Pass an empty path to list the repository root.
|
||||
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/contents/%s", c.baseURL, owner, repo, path)
|
||||
body, err := c.doGet(ctx, url)
|
||||
var reqURL string
|
||||
if path == "" {
|
||||
reqURL = fmt.Sprintf("%s/api/v1/repos/%s/%s/contents", c.baseURL, owner, repo)
|
||||
} else {
|
||||
reqURL = fmt.Sprintf("%s/api/v1/repos/%s/%s/contents/%s", c.baseURL, owner, repo, escapePath(path))
|
||||
}
|
||||
body, err := c.doGet(ctx, reqURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
||||
}
|
||||
|
||||
@@ -294,3 +294,27 @@ func TestGetAllFilesInPath_File(t *testing.T) {
|
||||
t.Errorf("unexpected content: %q", files["README.md"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"simple", "src/main.go", "src/main.go"},
|
||||
{"spaces", "my dir/my file.go", "my%20dir/my%20file.go"},
|
||||
{"special chars", "path/file#1.txt", "path/file%231.txt"},
|
||||
{"empty", "", ""},
|
||||
{"single segment", "README.md", "README.md"},
|
||||
{"nested deep", "a/b/c/d.md", "a/b/c/d.md"},
|
||||
{"already encoded", "path/file%20name.go", "path/file%2520name.go"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := escapePath(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+144
-21
@@ -1,3 +1,6 @@
|
||||
// Package llm provides clients for LLM chat completion APIs.
|
||||
//
|
||||
// Supports OpenAI-compatible (default) and Anthropic Messages API providers.
|
||||
package llm
|
||||
|
||||
import (
|
||||
@@ -11,23 +14,36 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client calls an OpenAI-compatible chat completion API.
|
||||
// Provider identifies which API format to use.
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
// ProviderOpenAI uses the OpenAI-compatible chat/completions endpoint.
|
||||
ProviderOpenAI Provider = "openai"
|
||||
// ProviderAnthropic uses the Anthropic Messages API endpoint.
|
||||
ProviderAnthropic Provider = "anthropic"
|
||||
)
|
||||
|
||||
// Client calls an LLM chat completion API.
|
||||
// A Client is safe for concurrent use by multiple goroutines after construction.
|
||||
// WithTimeout and WithTemperature must be called during setup, before concurrent use.
|
||||
// WithTimeout, WithTemperature, and WithProvider must be called during setup,
|
||||
// before concurrent use.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
model string
|
||||
temperature float64
|
||||
provider Provider
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new LLM client.
|
||||
// NewClient creates a new LLM client. Default provider is OpenAI-compatible.
|
||||
func NewClient(baseURL, apiKey, model string) *Client {
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
apiKey: apiKey,
|
||||
model: model,
|
||||
provider: ProviderOpenAI,
|
||||
http: &http.Client{Timeout: 5 * time.Minute},
|
||||
}
|
||||
}
|
||||
@@ -44,20 +60,39 @@ func (c *Client) WithTemperature(t float64) *Client {
|
||||
return c
|
||||
}
|
||||
|
||||
// WithProvider sets the API provider format (openai or anthropic).
|
||||
func (c *Client) WithProvider(p Provider) *Client {
|
||||
c.provider = p
|
||||
return c
|
||||
}
|
||||
|
||||
// Message represents a chat message.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ChatRequest is the request payload.
|
||||
// Complete sends a chat completion request and returns the assistant's response content.
|
||||
// The first message with role "system" is treated as the system prompt.
|
||||
func (c *Client) Complete(ctx context.Context, messages []Message) (string, error) {
|
||||
switch c.provider {
|
||||
case ProviderAnthropic:
|
||||
return c.completeAnthropic(ctx, messages)
|
||||
default:
|
||||
return c.completeOpenAI(ctx, messages)
|
||||
}
|
||||
}
|
||||
|
||||
// --- OpenAI-compatible implementation ---
|
||||
|
||||
// ChatRequest is the OpenAI request payload.
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
// ChatResponse is the response from the API.
|
||||
// ChatResponse is the OpenAI response.
|
||||
type ChatResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
@@ -66,8 +101,7 @@ type ChatResponse struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
// Complete sends a chat completion request and returns the assistant's response content.
|
||||
func (c *Client) Complete(ctx context.Context, messages []Message) (string, error) {
|
||||
func (c *Client) completeOpenAI(ctx context.Context, messages []Message) (string, error) {
|
||||
reqBody := ChatRequest{
|
||||
Model: c.model,
|
||||
Temperature: c.temperature,
|
||||
@@ -80,37 +114,126 @@ func (c *Client) Complete(ctx context.Context, messages []Message) (string, erro
|
||||
}
|
||||
|
||||
url := c.baseURL + "/chat/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
return c.doRequest(req, func(body []byte) (string, error) {
|
||||
var resp ChatResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in LLM response")
|
||||
}
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --- Anthropic Messages API implementation ---
|
||||
|
||||
type anthropicRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
System string `json:"system,omitempty"`
|
||||
Messages []anthropicMsg `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type anthropicMsg struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type anthropicResponse struct {
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
}
|
||||
|
||||
func (c *Client) completeAnthropic(ctx context.Context, messages []Message) (string, error) {
|
||||
// Extract system message (first message with role "system")
|
||||
var system string
|
||||
var userMessages []anthropicMsg
|
||||
for _, m := range messages {
|
||||
if m.Role == "system" {
|
||||
system = m.Content
|
||||
} else {
|
||||
userMessages = append(userMessages, anthropicMsg{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := anthropicRequest{
|
||||
Model: c.model,
|
||||
MaxTokens: 8192,
|
||||
System: system,
|
||||
Messages: userMessages,
|
||||
}
|
||||
if c.temperature > 0 {
|
||||
reqBody.Temperature = c.temperature
|
||||
}
|
||||
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := c.baseURL + "/messages"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("x-api-key", c.apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
return c.doRequest(req, func(body []byte) (string, error) {
|
||||
var resp anthropicResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
if len(resp.Content) == 0 {
|
||||
return "", fmt.Errorf("no content in Anthropic response")
|
||||
}
|
||||
// Concatenate all text blocks
|
||||
var sb strings.Builder
|
||||
for _, block := range resp.Content {
|
||||
if block.Type == "text" {
|
||||
sb.WriteString(block.Text)
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
if result == "" {
|
||||
return "", fmt.Errorf("no text content in Anthropic response")
|
||||
}
|
||||
return result, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --- Shared HTTP execution ---
|
||||
|
||||
func (c *Client) doRequest(req *http.Request, parse func([]byte) (string, error)) (string, error) {
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("LLM request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("LLM API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var chatResp ChatResponse
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return "", fmt.Errorf("parse response: %w", err)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("LLM API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in LLM response")
|
||||
}
|
||||
|
||||
return chatResp.Choices[0].Message.Content, nil
|
||||
return parse(body)
|
||||
}
|
||||
|
||||
@@ -208,3 +208,90 @@ func TestWithTimeout(t *testing.T) {
|
||||
t.Error("expected timeout error with 50ms timeout and 200ms server delay")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestComplete_Anthropic_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/messages" {
|
||||
t.Errorf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
if r.Header.Get("x-api-key") != "test-key" {
|
||||
t.Errorf("expected x-api-key header, got %q", r.Header.Get("x-api-key"))
|
||||
}
|
||||
if r.Header.Get("anthropic-version") != "2023-06-01" {
|
||||
t.Errorf("expected anthropic-version header, got %q", r.Header.Get("anthropic-version"))
|
||||
}
|
||||
|
||||
var req map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
if req["system"] != "You are helpful" {
|
||||
t.Errorf("expected system prompt, got %v", req["system"])
|
||||
}
|
||||
msgs := req["messages"].([]interface{})
|
||||
if len(msgs) != 1 {
|
||||
t.Errorf("expected 1 user message, got %d", len(msgs))
|
||||
}
|
||||
if req["max_tokens"] != float64(8192) {
|
||||
t.Errorf("expected max_tokens 8192, got %v", req["max_tokens"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"content":[{"type":"text","text":"Hello from Claude!"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-key", "claude-sonnet").WithProvider(ProviderAnthropic)
|
||||
got, err := client.Complete(context.Background(), []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "Hello from Claude!" {
|
||||
t.Errorf("expected %q, got %q", "Hello from Claude!", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_Anthropic_NoContent(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"content":[]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-key", "claude-sonnet").WithProvider(ProviderAnthropic)
|
||||
_, err := client.Complete(context.Background(), []Message{{Role: "user", Content: "Hi"}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty content, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplete_Anthropic_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":{"message":"invalid request"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-key", "claude-sonnet").WithProvider(ProviderAnthropic)
|
||||
_, err := client.Complete(context.Background(), []Message{{Role: "user", Content: "Hi"}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 400, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithProvider(t *testing.T) {
|
||||
client := NewClient("http://example.com", "key", "model")
|
||||
if client.provider != ProviderOpenAI {
|
||||
t.Errorf("expected default provider openai, got %s", client.provider)
|
||||
}
|
||||
result := client.WithProvider(ProviderAnthropic)
|
||||
if result != client {
|
||||
t.Error("WithProvider should return the same client for chaining")
|
||||
}
|
||||
if client.provider != ProviderAnthropic {
|
||||
t.Errorf("expected provider anthropic, got %s", client.provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// Package review builds prompts for AI code review and parses LLM responses
|
||||
// into structured review results.
|
||||
package review
|
||||
|
||||
import (
|
||||
|
||||
Reference in New Issue
Block a user