// Copyright 2026 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT package milestone_events import ( "context" "sync" "testing" "time" issues_model "code.gitea.io/gitea/models/issues" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/modules/eventsource" "code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/sessiontag" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // capturedCall is one observed broadcast: the recipient uid set plus the // constructed Event. type capturedCall struct { uids []int64 event *eventsource.Event } // installFakes swaps every package-level seam used by the publishers for // test doubles: a fake uid lister, a stubbed milestone lookup returning a // synthetic milestone (no DB hit), a stubbed repo lookup, an "everyone // passes" access checker, and a broadcaster that pushes calls onto a // buffered channel. // // The returned restore func reverts every seam; defer it in the test. func installFakes(t *testing.T, uids []int64, milestone *issues_model.Milestone) (<-chan capturedCall, func()) { t.Helper() calls := make(chan capturedCall, 16) origBroadcast := broadcastFn origLister := connectedUIDsLister origChecker := repoAccessChecker origRepoLookup := repoLookup origMsLookup := milestoneLookup broadcastFn = func(uids []int64, event *eventsource.Event) { calls <- capturedCall{uids: append([]int64(nil), uids...), event: event} } connectedUIDsLister = func() []int64 { return append([]int64(nil), uids...) } milestoneLookup = func(_ context.Context, id int64) (*issues_model.Milestone, error) { if milestone != nil { return milestone, nil } return &issues_model.Milestone{ID: id, RepoID: 1}, nil } repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) { return &repo_model.Repository{ID: id}, nil } repoAccessChecker = func(_ context.Context, _ int64, _ *repo_model.Repository) (bool, error) { return true, nil } return calls, func() { broadcastFn = origBroadcast connectedUIDsLister = origLister repoAccessChecker = origChecker repoLookup = origRepoLookup milestoneLookup = origMsLookup } } // awaitCall blocks until one capturedCall arrives or the test deadline // elapses. It fails the test on timeout. func awaitCall(t *testing.T, ch <-chan capturedCall) capturedCall { t.Helper() select { case c := <-ch: return c case <-time.After(2 * time.Second): t.Fatal("timed out waiting for broadcast") return capturedCall{} } } func TestEventNameFormat(t *testing.T) { assert.Equal(t, "repo-milestones.42", eventName(42)) assert.Equal(t, "repo-milestones.0", eventName(0)) } func TestPublishMilestoneProgress_NameAndPayload(t *testing.T) { ms := &issues_model.Milestone{ ID: 7, RepoID: 10, NumIssues: 8, NumClosedIssues: 6, NumOpenIssues: 2, Completeness: 75, } ch, restore := installFakes(t, []int64{1}, ms) defer restore() PublishMilestoneProgress(context.Background(), 7) c := awaitCall(t, ch) assert.Equal(t, "repo-milestones.10", c.event.Name) data, ok := c.event.Data.([]byte) require.True(t, ok, "Event.Data should be []byte") var got MilestoneProgress require.NoError(t, json.Unmarshal(data, &got)) assert.Equal(t, MilestoneProgress{ RepoID: 10, MilestoneID: 7, OpenIssues: 2, ClosedIssues: 6, Completeness: 75, }, got) } func TestPublishMilestoneProgress_IgnoresNonPositiveID(t *testing.T) { ch, restore := installFakes(t, []int64{1}, nil) defer restore() PublishMilestoneProgress(context.Background(), 0) PublishMilestoneProgress(context.Background(), -3) select { case <-ch: t.Fatal("no broadcast expected for non-positive milestone id") case <-time.After(200 * time.Millisecond): } } func TestPublishMilestoneProgress_LookupErrorIsSilent(t *testing.T) { ch, restore := installFakes(t, []int64{1}, nil) defer restore() milestoneLookup = func(_ context.Context, _ int64) (*issues_model.Milestone, error) { return nil, issues_model.ErrMilestoneNotExist{ID: 99} } PublishMilestoneProgress(context.Background(), 99) select { case <-ch: t.Fatal("no broadcast expected when the milestone re-fetch fails") case <-time.After(200 * time.Millisecond): } } func TestPublishMilestoneDeleted_NameAndPayload(t *testing.T) { ch, restore := installFakes(t, []int64{1}, nil) defer restore() PublishMilestoneDeleted(context.Background(), 12, 5) c := awaitCall(t, ch) assert.Equal(t, "repo-milestones.12", c.event.Name) var got MilestoneDeleted require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &got)) assert.Equal(t, MilestoneDeleted{RepoID: 12, MilestoneID: 5}, got) } func TestPublishMilestoneDeleted_IgnoresNonPositiveIDs(t *testing.T) { ch, restore := installFakes(t, []int64{1}, nil) defer restore() PublishMilestoneDeleted(context.Background(), 0, 5) PublishMilestoneDeleted(context.Background(), 12, 0) select { case <-ch: t.Fatal("no broadcast expected for non-positive ids") case <-time.After(200 * time.Millisecond): } } // TestSessionTagPropagation verifies that when a publish is invoked // inside a context decorated by sessiontag.WithSessionTag, the emitted // JSON payload carries the tag. func TestSessionTagPropagation(t *testing.T) { ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1}) defer restore() ctx := sessiontag.WithSessionTag(context.Background(), "abc-123") PublishMilestoneProgress(ctx, 3) c := awaitCall(t, ch) var payload MilestoneProgress require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload)) assert.Equal(t, "abc-123", payload.SessionTag) } // TestSessionTagAbsentWhenUnset verifies the omitempty tag stays empty // when no session tag is on the context. func TestSessionTagAbsentWhenUnset(t *testing.T) { ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1}) defer restore() PublishMilestoneProgress(context.Background(), 3) c := awaitCall(t, ch) var payload MilestoneProgress require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload)) assert.Empty(t, payload.SessionTag) } // TestSessionTagResolvedSynchronously ensures the tag is read from the // request context before the goroutine starts, not from the detached // context (which never carries request-scoped values). func TestSessionTagResolvedSynchronously(t *testing.T) { ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1}) defer restore() ctx := sessiontag.WithSessionTag(context.Background(), "sync-tag") PublishMilestoneProgress(ctx, 3) c := awaitCall(t, ch) var payload MilestoneProgress require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload)) assert.Equal(t, "sync-tag", payload.SessionTag) } // TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission ensures the // helper drops uids the access checker rejects. func TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission(t *testing.T) { origLister := connectedUIDsLister origChecker := repoAccessChecker origRepoLookup := repoLookup defer func() { connectedUIDsLister = origLister repoAccessChecker = origChecker repoLookup = origRepoLookup }() connectedUIDsLister = func() []int64 { return []int64{1, 2, 3, 4} } repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) { return &repo_model.Repository{ID: id}, nil } allowed := map[int64]bool{1: true, 3: true} repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) { return allowed[uid], nil } got := connectedUIDsWithRepoIssueAccess(context.Background(), 42) assert.ElementsMatch(t, []int64{1, 3}, got) } // TestConnectedUIDsWithRepoIssueAccess_NoConnections shortcuts when no // users are connected; the repo lookup must not be called. func TestConnectedUIDsWithRepoIssueAccess_NoConnections(t *testing.T) { origLister := connectedUIDsLister origRepoLookup := repoLookup defer func() { connectedUIDsLister = origLister repoLookup = origRepoLookup }() connectedUIDsLister = func() []int64 { return nil } called := false repoLookup = func(_ context.Context, _ int64) (*repo_model.Repository, error) { called = true return &repo_model.Repository{}, nil } got := connectedUIDsWithRepoIssueAccess(context.Background(), 42) assert.Empty(t, got) assert.False(t, called, "repo lookup should be skipped when no uids are connected") } // TestPublishEvent_BroadcastsToAllowedUIDs exercises publishEvent // directly to verify the uid set computed by the access filter is what // gets handed to broadcastFn. func TestPublishEvent_BroadcastsToAllowedUIDs(t *testing.T) { origBroadcast := broadcastFn origLister := connectedUIDsLister origChecker := repoAccessChecker origRepoLookup := repoLookup defer func() { broadcastFn = origBroadcast connectedUIDsLister = origLister repoAccessChecker = origChecker repoLookup = origRepoLookup }() var mu sync.Mutex var got []int64 broadcastFn = func(uids []int64, _ *eventsource.Event) { mu.Lock() got = append([]int64(nil), uids...) mu.Unlock() } connectedUIDsLister = func() []int64 { return []int64{10, 20, 30} } repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) { return &repo_model.Repository{ID: id}, nil } repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) { return uid != 20, nil } publishEvent(context.Background(), 1, MilestoneDeleted{RepoID: 1, MilestoneID: 5}) mu.Lock() defer mu.Unlock() assert.ElementsMatch(t, []int64{10, 30}, got) } // TestPublishMilestoneProgress_NoConnectionsNoBroadcast verifies the // connected-uid shortcut: with nobody connected nothing is sent even // though the milestone re-fetch succeeds. func TestPublishMilestoneProgress_NoConnectionsNoBroadcast(t *testing.T) { ch, restore := installFakes(t, nil, &issues_model.Milestone{ID: 3, RepoID: 1}) defer restore() PublishMilestoneProgress(context.Background(), 3) select { case <-ch: t.Fatal("no broadcast expected when no users are connected") case <-time.After(200 * time.Millisecond): } } // TestPublishMilestoneProgress_FanOutTargetList verifies the recipient // list handed to broadcast is exactly the access-filtered set. func TestPublishMilestoneProgress_FanOutTargetList(t *testing.T) { ch, restore := installFakes(t, []int64{5, 6, 7}, &issues_model.Milestone{ID: 3, RepoID: 1}) defer restore() repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) { return uid != 6, nil } PublishMilestoneProgress(context.Background(), 3) c := awaitCall(t, ch) assert.ElementsMatch(t, []int64{5, 7}, c.uids) }