From 642d1e129f4b56da286ca0035d494139a4d6b1e0 Mon Sep 17 00:00:00 2001 From: Aaron Weiker Date: Wed, 13 May 2026 04:13:00 +0000 Subject: [PATCH] feat(github): implement FileReader interface (#80) Implement FileReader conformance on the GitHub client: GetFileContent, ListContents, path helpers, base64 decode. Includes compile-time conformance checks for both PRReader and FileReader. Requires PR B (#102). Part 3 of 3 for #80. --- github/conformance_test.go | 12 +- github/files.go | 135 +++++++++++++++ github/files_test.go | 334 +++++++++++++++++++++++++++++++++++++ 3 files changed, 479 insertions(+), 2 deletions(-) create mode 100644 github/files.go create mode 100644 github/files_test.go diff --git a/github/conformance_test.go b/github/conformance_test.go index 4dfa195..ca13188 100644 --- a/github/conformance_test.go +++ b/github/conformance_test.go @@ -1,5 +1,13 @@ package github_test -import "gitea.weiker.me/rodin/review-bot/vcs" +import ( + "gitea.weiker.me/rodin/review-bot/github" + "gitea.weiker.me/rodin/review-bot/vcs" +) -var _ vcs.PRReader = (*Client)(nil) +// Compile-time interface conformance assertions. +// These verify github.Client satisfies vcs.PRReader and vcs.FileReader. +var ( + _ vcs.PRReader = (*github.Client)(nil) + _ vcs.FileReader = (*github.Client)(nil) +) diff --git a/github/files.go b/github/files.go new file mode 100644 index 0000000..9f04941 --- /dev/null +++ b/github/files.go @@ -0,0 +1,135 @@ +package github + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + + "gitea.weiker.me/rodin/review-bot/vcs" +) + +// GetFileContent fetches a file from a repo at the given ref. +// Delegates to GetFileContentAtRef with the provided ref. +func (c *Client) GetFileContent(ctx context.Context, owner, repo, path, ref string) (string, error) { + return c.GetFileContentAtRef(ctx, owner, repo, path, ref) +} + +// GetFileContentAtRef fetches a file at a specific ref from a repo. +// If ref is empty, the query parameter is omitted (uses default branch). +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". +func (c *Client) GetFileContentAtRef(ctx context.Context, owner, repo, path, ref string) (string, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + if ref != "" { + reqURL += "?ref=" + url.QueryEscape(ref) + } + body, err := c.doGet(ctx, reqURL) + if err != nil { + return "", fmt.Errorf("fetch file %s: %w", path, err) + } + var resp struct { + Content string `json:"content"` + Encoding string `json:"encoding"` + } + if err := json.Unmarshal(body, &resp); err != nil { + return "", fmt.Errorf("parse file content JSON: %w", err) + } + if resp.Encoding != "base64" { + return "", fmt.Errorf("unexpected encoding %q for file %s", resp.Encoding, path) + } + decoded, err := decodeBase64Content(resp.Content) + if err != nil { + return "", fmt.Errorf("decode base64 content for %s: %w", path, err) + } + return decoded, nil +} + +// ListContents lists files and directories at a given path in a repo. +// Returns the directory listing from the GitHub contents API. +// If the path points to a single file (not a directory), the API returns +// a JSON object instead of an array; this is handled by returning a +// single-element slice. +// +// Note: dot-segments ("." and "..") in the path are silently removed to +// prevent path traversal. This means a path like "foo/../bar" resolves +// to "foo/bar" rather than "bar". +func (c *Client) ListContents(ctx context.Context, owner, repo, path string) ([]vcs.ContentEntry, error) { + reqURL := fmt.Sprintf("%s/repos/%s/%s/contents/%s", + c.baseURL, url.PathEscape(owner), url.PathEscape(repo), escapePath(path)) + body, err := c.doGet(ctx, reqURL) + if err != nil { + return nil, fmt.Errorf("list contents %s: %w", path, err) + } + + type entry struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + } + + // The GitHub contents API returns an array for directories and an object + // for single files. Try array first (common case), then fall back to object. + // An empty array ([]) is valid — it represents an empty directory — and + // results in a zero-length slice returned without error. + var entries []entry + if err := json.Unmarshal(body, &entries); err != nil { + var single entry + if err2 := json.Unmarshal(body, &single); err2 != nil { + return nil, fmt.Errorf("parse contents JSON: as array: %v; as object: %w", err, err2) + } + // Guard against empty objects ({}) or unexpected shapes that + // unmarshal successfully but carry no useful data. + if single.Name == "" && single.Path == "" && single.Type == "" { + return nil, fmt.Errorf("parse contents JSON: unexpected response format") + } + entries = []entry{single} + } + + result := make([]vcs.ContentEntry, len(entries)) + for i, e := range entries { + result[i] = vcs.ContentEntry{ + Name: e.Name, + Path: e.Path, + Type: e.Type, + } + } + return result, nil +} + +// escapePath escapes each segment of a relative file path for use in URLs. +// Slashes are preserved as path separators; other special characters are escaped. +// Dot-segments ("." and "..") and empty segments (from consecutive slashes like +// "a//b") are silently removed to prevent path traversal and produce canonical +// paths. This is intentional: callers may receive a different path than requested +// without error. The function is package-private, and all callers +// (GetFileContentAtRef, ListContents) already handle missing-file errors from the +// API if the cleaned path doesn't match what the caller intended. +func escapePath(p string) string { + parts := strings.Split(p, "/") + var clean []string + for _, part := range parts { + if part == "." || part == ".." || part == "" { + continue + } + clean = append(clean, url.PathEscape(part)) + } + return strings.Join(clean, "/") +} + +// decodeBase64Content decodes base64-encoded content from the GitHub contents API. +// GitHub returns base64 content with line breaks for formatting; we strip \r and \n before decoding. +func decodeBase64Content(encoded string) (string, error) { + // GitHub inserts newlines in base64 content + cleaned := strings.NewReplacer("\n", "", "\r", "").Replace(encoded) + decoded, err := base64.StdEncoding.DecodeString(cleaned) + if err != nil { + return "", err + } + return string(decoded), nil +} diff --git a/github/files_test.go b/github/files_test.go new file mode 100644 index 0000000..eda64a8 --- /dev/null +++ b/github/files_test.go @@ -0,0 +1,334 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestGetFileContent_DelegatesToGetFileContentAtRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", // "test" in base64 + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + // Call with empty ref — should not include ref param + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "test" { + t.Errorf("expected 'test', got %q", content) + } + if gotRef != "" { + t.Errorf("expected empty ref, got %q", gotRef) + } +} + +func TestGetFileContent_WithRef(t *testing.T) { + var gotRef string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotRef = r.URL.Query().Get("ref") + json.NewEncoder(w).Encode(map[string]string{ + "content": "dGVzdA==", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "abc123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotRef != "abc123" { + t.Errorf("expected ref 'abc123', got %q", gotRef) + } +} + +func TestGetFileContent_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "missing.go", "") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestGetFileContent_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestGetFileContent_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode(map[string]string{ + "content": "b2s=", + "encoding": "base64", + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + content, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "ok" { + t.Errorf("expected 'ok', got %q", content) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestGetFileContent_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.GetFileContent(context.Background(), "owner", "repo", "file.go", "") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestListContents_HappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/repos/owner/repo/contents/src" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "main.go", "path": "src/main.go", "type": "file"}, + {"name": "lib", "path": "src/lib", "type": "dir"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + entries, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(entries)) + } + if entries[0].Name != "main.go" { + t.Errorf("expected name 'main.go', got %q", entries[0].Name) + } + if entries[0].Path != "src/main.go" { + t.Errorf("expected path 'src/main.go', got %q", entries[0].Path) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } + if entries[1].Name != "lib" { + t.Errorf("expected name 'lib', got %q", entries[1].Name) + } + if entries[1].Type != "dir" { + t.Errorf("expected type 'dir', got %q", entries[1].Type) + } +} + +func TestListContents_404(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + w.Write([]byte(`{"message":"Not Found"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "missing") + if err == nil { + t.Fatal("expected error for 404") + } +} + +func TestListContents_401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte(`{"message":"Bad credentials"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for 401") + } +} + +func TestListContents_429Retry(t *testing.T) { + attempts := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.WriteHeader(429) + w.Write([]byte(`{"message":"rate limit"}`)) + return + } + json.NewEncoder(w).Encode([]map[string]string{ + {"name": "file.go", "path": "file.go", "type": "file"}, + }) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + c.SetRetryBackoff([]time.Duration{1 * time.Millisecond}) + + entries, err := c.ListContents(context.Background(), "owner", "repo", ".") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if attempts != 2 { + t.Errorf("expected 2 attempts, got %d", attempts) + } +} + +func TestListContents_MalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`not json`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + + _, err := c.ListContents(context.Background(), "owner", "repo", "src") + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestDecodeBase64Content(t *testing.T) { + // Test with newlines (GitHub's format) + encoded := "cGFja2FnZSBt\nYWlu" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "package main" { + t.Errorf("expected 'package main', got %q", decoded) + } +} + +func TestDecodeBase64Content_Invalid(t *testing.T) { + _, err := decodeBase64Content("not!!!valid!!!base64") + if err == nil { + t.Fatal("expected error for invalid base64") + } +} + +func TestEscapePath_RejectsDotSegments(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"src/main.go", "src/main.go"}, + {"../etc/passwd", "etc/passwd"}, + {"./src/../main.go", "src/main.go"}, + {"a/b/c", "a/b/c"}, + {"file with spaces.go", "file%20with%20spaces.go"}, + {"a/./b/../c", "a/b/c"}, + } + for _, tt := range tests { + got := escapePath(tt.input) + if got != tt.want { + t.Errorf("escapePath(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestDecodeBase64Content_CRLF(t *testing.T) { + // Base64 of "hello world" with CRLF line breaks inserted + encoded := "aGVs\r\nbG8g\r\nd29y\r\nbGQ=" + decoded, err := decodeBase64Content(encoded) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decoded != "hello world" { + t.Errorf("expected 'hello world', got %q", decoded) + } +} + +func TestListContents_SingleFile(t *testing.T) { + // GitHub Contents API returns a JSON object (not array) for single-file paths + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(`{"name":"README.md","path":"README.md","type":"file"}`)) + })) + defer srv.Close() + + c := NewClient("token", srv.URL, AllowInsecureHTTP()) + c.SetHTTPClient(srv.Client()) + entries, err := c.ListContents(context.Background(), "owner", "repo", "README.md") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].Name != "README.md" { + t.Errorf("expected name 'README.md', got %q", entries[0].Name) + } + if entries[0].Type != "file" { + t.Errorf("expected type 'file', got %q", entries[0].Type) + } +}