package github

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// TestGetPullRequest_HappyPath checks that GetPullRequest hits the PR endpoint
// and decodes number, title, body, head (sha/ref) and base (ref).
func TestGetPullRequest_HappyPath(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/repos/owner/repo/pulls/42" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		json.NewEncoder(w).Encode(map[string]interface{}{
			"number": 42,
			"title":  "Test PR",
			"body":   "Description",
			"head":   map[string]string{"sha": "abc123", "ref": "feature-branch"},
			"base":   map[string]string{"ref": "main"},
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 42)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pr.Number != 42 {
		t.Errorf("expected number 42, got %d", pr.Number)
	}
	if pr.Title != "Test PR" {
		t.Errorf("expected title 'Test PR', got %q", pr.Title)
	}
	if pr.Body != "Description" {
		t.Errorf("expected body 'Description', got %q", pr.Body)
	}
	if pr.Head.SHA != "abc123" {
		t.Errorf("expected head SHA 'abc123', got %q", pr.Head.SHA)
	}
	if pr.Head.Ref != "feature-branch" {
		t.Errorf("expected head ref 'feature-branch', got %q", pr.Head.Ref)
	}
	if pr.Base.Ref != "main" {
		t.Errorf("expected base ref 'main', got %q", pr.Base.Ref)
	}
}

// TestGetPullRequest_404 checks that a 404 response yields an error that
// IsNotFound recognizes.
func TestGetPullRequest_404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`{"message":"Not Found"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequest(context.Background(), "owner", "repo", 999)
	if err == nil {
		t.Fatal("expected error for 404")
	}
	if !IsNotFound(err) {
		t.Errorf("expected IsNotFound=true, got error: %v", err)
	}
}

// TestGetPullRequest_401 checks that a 401 response yields an error that
// IsUnauthorized recognizes.
func TestGetPullRequest_401(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"message":"Bad credentials"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
	if err == nil {
		t.Fatal("expected error for 401")
	}
	if !IsUnauthorized(err) {
		t.Errorf("expected IsUnauthorized=true, got error: %v", err)
	}
}

// TestGetPullRequest_429Retry checks that a single 429 response is retried
// and that the second attempt's payload is returned.
func TestGetPullRequest_429Retry(t *testing.T) {
	// The handler runs on the server's goroutine while the test goroutine
	// reads the count afterwards, so the counter is accessed atomically.
	var attempts int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if atomic.AddInt32(&attempts, 1) == 1 {
			w.WriteHeader(http.StatusTooManyRequests)
			w.Write([]byte(`{"message":"rate limit"}`))
			return
		}
		json.NewEncoder(w).Encode(map[string]interface{}{
			"number": 1,
			"title":  "PR",
			"body":   "",
			"head":   map[string]string{"sha": "abc", "ref": "br"},
			"base":   map[string]string{"ref": "main"},
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())
	c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})

	pr, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if pr.Number != 1 {
		t.Errorf("expected number 1, got %d", pr.Number)
	}
	if got := atomic.LoadInt32(&attempts); got != 2 {
		t.Errorf("expected 2 attempts, got %d", got)
	}
}

// TestGetPullRequest_MalformedJSON checks that an unparseable 200 body yields
// an error mentioning "parse PR JSON".
func TestGetPullRequest_MalformedJSON(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{invalid json`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequest(context.Background(), "owner", "repo", 1)
	if err == nil {
		t.Fatal("expected error for malformed JSON")
	}
	if !strings.Contains(err.Error(), "parse PR JSON") {
		t.Errorf("expected parse error, got: %v", err)
	}
}

// TestGetPullRequestDiff_HappyPath checks that GetPullRequestDiff sends the
// diff media type in Accept and returns the raw body unchanged.
func TestGetPullRequestDiff_HappyPath(t *testing.T) {
	expectedDiff := "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1,3 +1,4 @@\n+// new line\n"
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Check the header inside the handler instead of copying it into a
		// variable shared with the test goroutine. If the handler is never
		// reached, the diff comparison below fails instead.
		if got := r.Header.Get("Accept"); got != "application/vnd.github.diff" {
			t.Errorf("expected diff Accept header, got %q", got)
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(expectedDiff))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	diff, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 42)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if diff != expectedDiff {
		t.Errorf("unexpected diff: %q", diff)
	}
}

// TestGetPullRequestDiff_404 checks that a 404 response is reported as an error.
func TestGetPullRequestDiff_404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`{"message":"Not Found"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 999)
	if err == nil {
		t.Fatal("expected error for 404")
	}
}

// TestGetPullRequestDiff_401 checks that a 401 response is reported as an error.
func TestGetPullRequestDiff_401(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"message":"Bad credentials"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequestDiff(context.Background(), "owner", "repo", 1)
	if err == nil {
		t.Fatal("expected error for 401")
	}
}

// TestGetPullRequestFiles_HappyPath checks decoding of filename, status and
// patch for a single page of files.
func TestGetPullRequestFiles_HappyPath(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode([]map[string]interface{}{
			{"filename": "main.go", "status": "modified", "patch": "@@ -1,3 +1,4 @@\n+line"},
			{"filename": "test.go", "status": "added", "patch": "@@ -0,0 +1,5 @@\n+new file"},
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(files) != 2 {
		t.Fatalf("expected 2 files, got %d", len(files))
	}
	if files[0].Filename != "main.go" {
		t.Errorf("expected filename 'main.go', got %q", files[0].Filename)
	}
	if files[0].Status != "modified" {
		t.Errorf("expected status 'modified', got %q", files[0].Status)
	}
	if files[0].Patch != "@@ -1,3 +1,4 @@\n+line" {
		t.Errorf("unexpected patch: %q", files[0].Patch)
	}
}

// TestGetPullRequestFiles_Pagination checks that a full first page (100 files)
// causes the client to fetch the next page and concatenate the results.
func TestGetPullRequestFiles_Pagination(t *testing.T) {
	// Simulate > 100 files requiring pagination
	page1Files := make([]map[string]string, 100)
	for i := range page1Files {
		page1Files[i] = map[string]string{
			"filename": fmt.Sprintf("file%d.go", i),
			"status":   "modified",
			"patch":    fmt.Sprintf("patch%d", i),
		}
	}
	page2Files := []map[string]string{
		{"filename": "file100.go", "status": "added", "patch": "patch100"},
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Query().Get("page") {
		case "", "1":
			json.NewEncoder(w).Encode(page1Files)
		case "2":
			json.NewEncoder(w).Encode(page2Files)
		default:
			// Any page past the last returns an empty list. A client that pages
			// until it sees an empty response then terminates instead of
			// re-reading page 2 forever.
			json.NewEncoder(w).Encode([]map[string]string{})
		}
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Fatal rather than Error: the checks below index files[100] and would
	// panic on a short result.
	if len(files) != 101 {
		t.Fatalf("expected 101 files (paginated), got %d", len(files))
	}
	if files[100].Filename != "file100.go" {
		t.Errorf("expected last file 'file100.go', got %q", files[100].Filename)
	}
	if files[100].Patch != "patch100" {
		t.Errorf("expected last patch 'patch100', got %q", files[100].Patch)
	}
}

// TestGetPullRequestFiles_BinaryFile_NoPatch checks that a file entry without
// a "patch" field decodes with an empty Patch.
func TestGetPullRequestFiles_BinaryFile_NoPatch(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Binary files have no patch field in GitHub response
		json.NewEncoder(w).Encode([]map[string]interface{}{
			{"filename": "image.png", "status": "added"},
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	files, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(files) != 1 {
		t.Fatalf("expected 1 file, got %d", len(files))
	}
	if files[0].Patch != "" {
		t.Errorf("expected empty patch for binary file, got %q", files[0].Patch)
	}
}

// TestGetPullRequestFiles_404 checks that a 404 response is reported as an error.
func TestGetPullRequestFiles_404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`{"message":"Not Found"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 999)
	if err == nil {
		t.Fatal("expected error for 404")
	}
}

// TestGetPullRequestFiles_MalformedJSON checks that an unparseable 200 body is
// reported as an error.
func TestGetPullRequestFiles_MalformedJSON(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`not json`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetPullRequestFiles(context.Background(), "owner", "repo", 1)
	if err == nil {
		t.Fatal("expected error for malformed JSON")
	}
}

// TestGetFileContentAtRef_HappyPath checks the contents path, the ref query
// parameter, and base64 decoding of the returned content.
func TestGetFileContentAtRef_HappyPath(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/repos/owner/repo/contents/path/to/file.go" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		if r.URL.Query().Get("ref") != "abc123" {
			t.Errorf("unexpected ref: %s", r.URL.Query().Get("ref"))
		}
		json.NewEncoder(w).Encode(map[string]string{
			"content":  "cGFja2FnZSBtYWlu", // "package main" in base64
			"encoding": "base64",
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "path/to/file.go", "abc123")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if content != "package main" {
		t.Errorf("expected 'package main', got %q", content)
	}
}

// TestGetFileContentAtRef_EmptyRef checks that an empty ref sends no ref
// query parameter.
func TestGetFileContentAtRef_EmptyRef(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Query().Get("ref") != "" {
			t.Errorf("expected no ref param, got %q", r.URL.Query().Get("ref"))
		}
		json.NewEncoder(w).Encode(map[string]string{
			"content":  "aGVsbG8=", // "hello" in base64
			"encoding": "base64",
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.txt", "")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if content != "hello" {
		t.Errorf("expected 'hello', got %q", content)
	}
}

// TestGetFileContentAtRef_404 checks that a 404 response is reported as an error.
func TestGetFileContentAtRef_404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`{"message":"Not Found"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "missing.go", "main")
	if err == nil {
		t.Fatal("expected error for 404")
	}
}

// TestGetFileContentAtRef_401 checks that a 401 response is reported as an error.
func TestGetFileContentAtRef_401(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"message":"Bad credentials"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
	if err == nil {
		t.Fatal("expected error for 401")
	}
}

// TestGetFileContentAtRef_MalformedJSON checks that an unparseable 200 body is
// reported as an error.
func TestGetFileContentAtRef_MalformedJSON(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`not valid json`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
	if err == nil {
		t.Fatal("expected error for malformed JSON")
	}
}

// TestGetFileContentAtRef_429Retry checks that a single 429 response is
// retried and that the second attempt's content is returned.
func TestGetFileContentAtRef_429Retry(t *testing.T) {
	// Shared between the server goroutine and the test goroutine, so the
	// counter is accessed atomically.
	var attempts int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if atomic.AddInt32(&attempts, 1) == 1 {
			w.WriteHeader(http.StatusTooManyRequests)
			w.Write([]byte(`{"message":"rate limit"}`))
			return
		}
		json.NewEncoder(w).Encode(map[string]string{
			"content":  "b2s=", // "ok" in base64
			"encoding": "base64",
		})
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())
	c.SetRetryBackoff([]time.Duration{1 * time.Millisecond})

	content, err := c.GetFileContentAtRef(context.Background(), "owner", "repo", "file.go", "main")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if content != "ok" {
		t.Errorf("expected 'ok', got %q", content)
	}
	if got := atomic.LoadInt32(&attempts); got != 2 {
		t.Errorf("expected 2 attempts, got %d", got)
	}
}

// TestGetCommitStatuses_HappyPath checks that commit statuses and check runs
// are merged, with commit statuses listed first.
func TestGetCommitStatuses_HappyPath(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "/status"):
			json.NewEncoder(w).Encode(map[string]interface{}{
				"state": "success",
				"statuses": []map[string]string{
					{
						"context":     "ci/build",
						"state":       "success",
						"description": "Build passed",
						"target_url":  "https://ci.example.com/1",
					},
				},
			})
		case strings.Contains(r.URL.Path, "/check-runs"):
			conclusion := "success"
			json.NewEncoder(w).Encode(map[string]interface{}{
				"total_count": 1,
				"check_runs": []map[string]interface{}{
					{
						"name":       "lint",
						"conclusion": &conclusion,
						"status":     "completed",
						"html_url":   "https://github.com/check/1",
					},
				},
			})
		default:
			t.Errorf("unexpected path: %s", r.URL.Path)
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(statuses) != 2 {
		t.Fatalf("expected 2 statuses, got %d", len(statuses))
	}

	// First should be from commit statuses
	if statuses[0].Context != "ci/build" {
		t.Errorf("expected context 'ci/build', got %q", statuses[0].Context)
	}
	if statuses[0].Status != "success" {
		t.Errorf("expected status 'success', got %q", statuses[0].Status)
	}

	// Second should be from check runs
	if statuses[1].Context != "lint" {
		t.Errorf("expected context 'lint', got %q", statuses[1].Context)
	}
	if statuses[1].Status != "success" {
		t.Errorf("expected status 'success', got %q", statuses[1].Status)
	}
}

// TestGetCommitStatuses_CheckRunConclusions pins how each check-run
// conclusion/status pair maps to the normalized Status string.
func TestGetCommitStatuses_CheckRunConclusions(t *testing.T) {
	tests := []struct {
		conclusion *string
		status     string
		want       string
	}{
		{stringPtr("success"), "completed", "success"},
		{stringPtr("failure"), "completed", "failure"},
		{stringPtr("action_required"), "completed", "failure"},
		{stringPtr("timed_out"), "completed", "failure"},
		{stringPtr("cancelled"), "completed", "success"},
		{stringPtr("skipped"), "completed", "success"},
		{stringPtr("neutral"), "completed", "success"},
		{nil, "in_progress", "pending"},
		{nil, "queued", "pending"},
	}

	for _, tt := range tests {
		// Parallel subtests run after the loop has advanced. Before Go 1.22
		// all of them would otherwise share (and see the last value of) the
		// single loop variable. This copy is a no-op on newer toolchains.
		tt := tt

		// Include the status so the two nil-conclusion rows get distinct,
		// readable names instead of "nil" and "nil#01".
		name := "nil"
		if tt.conclusion != nil {
			name = *tt.conclusion
		}
		name += "_" + tt.status

		t.Run(name, func(t *testing.T) {
			t.Parallel()

			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if strings.Contains(r.URL.Path, "/status") {
					json.NewEncoder(w).Encode(map[string]interface{}{
						"state":    "success",
						"statuses": []interface{}{},
					})
					return
				}
				json.NewEncoder(w).Encode(map[string]interface{}{
					"total_count": 1,
					"check_runs": []map[string]interface{}{
						{
							"name":       "check",
							"conclusion": tt.conclusion,
							"status":     tt.status,
							"html_url":   "https://github.com/check/1",
						},
					},
				})
			}))
			defer srv.Close()

			c := NewClient("token", srv.URL, AllowInsecureHTTP())
			c.SetHTTPClient(srv.Client())

			statuses, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha1")
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(statuses) != 1 {
				t.Fatalf("expected 1 status, got %d", len(statuses))
			}
			if statuses[0].Status != tt.want {
				t.Errorf("expected status %q, got %q", tt.want, statuses[0].Status)
			}
		})
	}
}

// TestGetCommitStatuses_404 checks that a 404 response is reported as an error.
func TestGetCommitStatuses_404(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte(`{"message":"Not Found"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "badsha")
	if err == nil {
		t.Fatal("expected error for 404")
	}
}

// TestGetCommitStatuses_401 checks that a 401 response is reported as an error.
func TestGetCommitStatuses_401(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`{"message":"Bad credentials"}`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
	if err == nil {
		t.Fatal("expected error for 401")
	}
}

// TestGetCommitStatuses_MalformedJSON checks that an unparseable 200 body is
// reported as an error.
func TestGetCommitStatuses_MalformedJSON(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`not json`))
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "sha")
	if err == nil {
		t.Fatal("expected error for malformed JSON")
	}
}

// TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed checks that a
// check-runs failure is surfaced (wrapped with "fetch check runs") even though
// the statuses request succeeded.
func TestGetCommitStatuses_CheckRunsErrorAfterStatusesSucceed(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case strings.Contains(r.URL.Path, "/status"):
			// Statuses succeed
			json.NewEncoder(w).Encode(map[string]interface{}{
				"state": "success",
				"statuses": []map[string]string{
					{
						"context":     "ci/build",
						"state":       "success",
						"description": "Build passed",
						"target_url":  "https://ci.example.com/1",
					},
				},
			})
		case strings.Contains(r.URL.Path, "/check-runs"):
			// Check runs fail with 500
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(`{"message":"Internal Server Error"}`))
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	defer srv.Close()

	c := NewClient("token", srv.URL, AllowInsecureHTTP())
	c.SetHTTPClient(srv.Client())

	_, err := c.GetCommitStatuses(context.Background(), "owner", "repo", "abc123")
	if err == nil {
		t.Fatal("expected error when check-runs endpoint fails after statuses succeed")
	}
	if !strings.Contains(err.Error(), "fetch check runs") {
		t.Errorf("expected check runs error, got: %v", err)
	}
}

// stringPtr returns a pointer to a copy of s, for building optional string
// fields in test tables.
func stringPtr(s string) *string { return &s }