feat: add context.Context + unexport client fields
REVIEW.md findings 1-4, 14: - All Gitea client methods now accept context.Context as first param - All LLM client methods now accept context.Context as first param - Use http.NewRequestWithContext for cancellation/timeout support - Main uses 3-minute timeout context for all operations - Unexport Client struct fields (baseURL, token, apiKey, etc.) - Use bytes.NewReader instead of strings.NewReader(string(...))
This commit is contained in:
+43
-41
@@ -1,6 +1,8 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,17 +12,17 @@ import (
|
||||
|
||||
// Client interacts with the Gitea API.
|
||||
type Client struct {
|
||||
BaseURL string
|
||||
Token string
|
||||
HTTP *http.Client
|
||||
baseURL string
|
||||
token string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new Gitea API client.
|
||||
func NewClient(baseURL, token string) *Client {
|
||||
return &Client{
|
||||
BaseURL: strings.TrimRight(baseURL, "/"),
|
||||
Token: token,
|
||||
HTTP: &http.Client{},
|
||||
baseURL: strings.TrimRight(baseURL, "/"),
|
||||
token: token,
|
||||
http: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,9 +51,9 @@ type ChangedFile struct {
|
||||
}
|
||||
|
||||
// GetPullRequest fetches PR metadata.
|
||||
func (c *Client) GetPullRequest(owner, repo string, number int) (*PullRequest, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.BaseURL, owner, repo, number)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) GetPullRequest(ctx context.Context, owner, repo string, number int) (*PullRequest, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR: %w", err)
|
||||
}
|
||||
@@ -63,9 +65,9 @@ func (c *Client) GetPullRequest(owner, repo string, number int) (*PullRequest, e
|
||||
}
|
||||
|
||||
// GetPullRequestDiff fetches the unified diff for a PR.
|
||||
func (c *Client) GetPullRequestDiff(owner, repo string, number int) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.BaseURL, owner, repo, number)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, number int) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d.diff", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch diff: %w", err)
|
||||
}
|
||||
@@ -73,9 +75,9 @@ func (c *Client) GetPullRequestDiff(owner, repo string, number int) (string, err
|
||||
}
|
||||
|
||||
// GetPullRequestFiles fetches the list of files changed in a PR.
|
||||
func (c *Client) GetPullRequestFiles(owner, repo string, number int) ([]ChangedFile, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/files", c.BaseURL, owner, repo, number)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) GetPullRequestFiles(ctx context.Context, owner, repo string, number int) ([]ChangedFile, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/files", c.baseURL, owner, repo, number)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch PR files: %w", err)
|
||||
}
|
||||
@@ -87,9 +89,9 @@ func (c *Client) GetPullRequestFiles(owner, repo string, number int) ([]ChangedF
|
||||
}
|
||||
|
||||
// GetCommitStatuses fetches CI statuses for a commit SHA.
|
||||
func (c *Client) GetCommitStatuses(owner, repo, sha string) ([]CommitStatus, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/commits/%s/statuses", c.BaseURL, owner, repo, sha)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) GetCommitStatuses(ctx context.Context, owner, repo, sha string) ([]CommitStatus, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/commits/%s/statuses", c.baseURL, owner, repo, sha)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch commit statuses: %w", err)
|
||||
}
|
||||
@@ -101,9 +103,9 @@ func (c *Client) GetCommitStatuses(owner, repo, sha string) ([]CommitStatus, err
|
||||
}
|
||||
|
||||
// GetFileContent fetches a file from the default branch of a repo.
|
||||
func (c *Client) GetFileContent(owner, repo, filepath string) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s", c.BaseURL, owner, repo, filepath)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) GetFileContent(ctx context.Context, owner, repo, filepath string) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s", c.baseURL, owner, repo, filepath)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s: %w", filepath, err)
|
||||
}
|
||||
@@ -111,9 +113,9 @@ func (c *Client) GetFileContent(owner, repo, filepath string) (string, error) {
|
||||
}
|
||||
|
||||
// GetFileContentRef fetches a file from a specific ref (branch/tag/sha) in a repo.
|
||||
func (c *Client) GetFileContentRef(owner, repo, filepath, ref string) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s?ref=%s", c.BaseURL, owner, repo, filepath, ref)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) GetFileContentRef(ctx context.Context, owner, repo, filepath, ref string) (string, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/raw/%s?ref=%s", c.baseURL, owner, repo, filepath, ref)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch file %s@%s: %w", filepath, ref, err)
|
||||
}
|
||||
@@ -122,8 +124,8 @@ func (c *Client) GetFileContentRef(owner, repo, filepath, ref string) (string, e
|
||||
|
||||
// PostReview submits a review to a PR.
|
||||
// event should be "APPROVED" or "REQUEST_CHANGES".
|
||||
func (c *Client) PostReview(owner, repo string, number int, event, body string) error {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.BaseURL, owner, repo, number)
|
||||
func (c *Client) PostReview(ctx context.Context, owner, repo string, number int, event, body string) error {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls/%d/reviews", c.baseURL, owner, repo, number)
|
||||
|
||||
payload := struct {
|
||||
Body string `json:"body"`
|
||||
@@ -138,14 +140,14 @@ func (c *Client) PostReview(owner, repo string, number int, event, body string)
|
||||
return fmt.Errorf("marshal review payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, strings.NewReader(string(data)))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create review request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "token "+c.Token)
|
||||
req.Header.Set("Authorization", "token "+c.token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.HTTP.Do(req)
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("post review: %w", err)
|
||||
}
|
||||
@@ -158,14 +160,14 @@ func (c *Client) PostReview(owner, repo string, number int, event, body string)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) doGet(url string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
func (c *Client) doGet(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "token "+c.Token)
|
||||
req.Header.Set("Authorization", "token "+c.token)
|
||||
|
||||
resp, err := c.HTTP.Do(req)
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,9 +188,9 @@ type ContentEntry struct {
|
||||
}
|
||||
|
||||
// ListContents lists files and directories at a given path in a repo.
|
||||
func (c *Client) ListContents(owner, repo, path string) ([]ContentEntry, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/contents/%s", c.BaseURL, owner, repo, path)
|
||||
body, err := c.doGet(url)
|
||||
func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]ContentEntry, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/contents/%s", c.baseURL, owner, repo, path)
|
||||
body, err := c.doGet(ctx, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list contents %s: %w", path, err)
|
||||
}
|
||||
@@ -202,14 +204,14 @@ func (c *Client) ListContents(owner, repo, path string) ([]ContentEntry, error)
|
||||
// GetAllFilesInPath recursively fetches all file contents under a path.
|
||||
// If the path is a file, returns just that file's content.
|
||||
// If the path is a directory, recursively fetches all files within it.
|
||||
func (c *Client) GetAllFilesInPath(owner, repo, path string) (map[string]string, error) {
|
||||
func (c *Client) GetAllFilesInPath(ctx context.Context, owner, repo, path string) (map[string]string, error) {
|
||||
results := make(map[string]string)
|
||||
|
||||
// Try listing as directory first
|
||||
entries, err := c.ListContents(owner, repo, path)
|
||||
entries, err := c.ListContents(ctx, owner, repo, path)
|
||||
if err != nil {
|
||||
// Might be a file, try fetching directly
|
||||
content, fileErr := c.GetFileContent(owner, repo, path)
|
||||
content, fileErr := c.GetFileContent(ctx, owner, repo, path)
|
||||
if fileErr != nil {
|
||||
return nil, fmt.Errorf("path %q is neither a file nor directory: %w", path, err)
|
||||
}
|
||||
@@ -220,13 +222,13 @@ func (c *Client) GetAllFilesInPath(owner, repo, path string) (map[string]string,
|
||||
for _, entry := range entries {
|
||||
switch entry.Type {
|
||||
case "file":
|
||||
content, err := c.GetFileContent(owner, repo, entry.Path)
|
||||
content, err := c.GetFileContent(ctx, owner, repo, entry.Path)
|
||||
if err != nil {
|
||||
continue // Skip files we can't read
|
||||
}
|
||||
results[entry.Path] = content
|
||||
case "dir":
|
||||
subResults, err := c.GetAllFilesInPath(owner, repo, entry.Path)
|
||||
subResults, err := c.GetAllFilesInPath(ctx, owner, repo, entry.Path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
+13
-12
@@ -1,6 +1,7 @@
|
||||
package gitea
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -28,7 +29,7 @@ func TestGetPullRequest(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
got, err := client.GetPullRequest("owner", "repo", 1)
|
||||
got, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -55,7 +56,7 @@ func TestGetPullRequestDiff(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
got, err := client.GetPullRequestDiff("owner", "repo", 5)
|
||||
got, err := client.GetPullRequestDiff(context.Background(), "owner", "repo", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -80,7 +81,7 @@ func TestGetCommitStatuses(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
got, err := client.GetCommitStatuses("owner", "repo", "abc123")
|
||||
got, err := client.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func TestPostReview(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
err := client.PostReview("owner", "repo", 3, "APPROVED", "LGTM")
|
||||
err := client.PostReview(context.Background(), "owner", "repo", 3, "APPROVED", "LGTM")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func TestGetPullRequest_Non200(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
_, err := client.GetPullRequest("owner", "repo", 999)
|
||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 999)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 404, got nil")
|
||||
}
|
||||
@@ -154,7 +155,7 @@ func TestGetPullRequest_BadJSON(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
_, err := client.GetPullRequest("owner", "repo", 1)
|
||||
_, err := client.GetPullRequest(context.Background(), "owner", "repo", 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for bad JSON, got nil")
|
||||
}
|
||||
@@ -168,7 +169,7 @@ func TestPostReview_Non200(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
err := client.PostReview("owner", "repo", 1, "APPROVED", "test")
|
||||
err := client.PostReview(context.Background(), "owner", "repo", 1, "APPROVED", "test")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 403, got nil")
|
||||
}
|
||||
@@ -186,7 +187,7 @@ func TestGetFileContent(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
got, err := client.GetFileContent("owner", "repo", "CONVENTIONS.md")
|
||||
got, err := client.GetFileContent(context.Background(), "owner", "repo", "CONVENTIONS.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -206,7 +207,7 @@ func TestGetPullRequestFiles(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
files, err := client.GetPullRequestFiles("owner", "repo", 1)
|
||||
files, err := client.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -231,7 +232,7 @@ func TestGetFileContentRef(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
content, err := client.GetFileContentRef("owner", "repo", "main.go", "feature-branch")
|
||||
content, err := client.GetFileContentRef(context.Background(), "owner", "repo", "main.go", "feature-branch")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -251,7 +252,7 @@ func TestListContents(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
entries, err := client.ListContents("owner", "repo", "docs")
|
||||
entries, err := client.ListContents(context.Background(), "owner", "repo", "docs")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -282,7 +283,7 @@ func TestGetAllFilesInPath_File(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient(server.URL, "test-token")
|
||||
files, err := client.GetAllFilesInPath("owner", "repo", "README.md")
|
||||
files, err := client.GetAllFilesInPath(context.Background(), "owner", "repo", "README.md")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user