Files
gitea/services/milestone_events/events_test.go
T

336 lines
10 KiB
Go

// Copyright 2026 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package milestone_events
import (
"context"
"sync"
"testing"
"time"
issues_model "code.gitea.io/gitea/models/issues"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/eventsource"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/sessiontag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// capturedCall records a single invocation of the fake broadcaster so a test
// can inspect both who was targeted and what was sent.
type capturedCall struct {
	// uids holds the recipient user ids the broadcaster was called with.
	uids []int64
	// event is the Event value handed to the broadcaster.
	event *eventsource.Event
}
// installFakes replaces the package-level seams the publishers depend on with
// in-memory test doubles:
//
//   - broadcastFn pushes every call (with a copy of the uid slice) onto a
//     buffered channel;
//   - connectedUIDsLister returns a copy of uids;
//   - milestoneLookup returns milestone, or a synthetic milestone with the
//     requested ID and RepoID 1 when milestone is nil;
//   - repoLookup returns a Repository carrying only the requested ID;
//   - repoAccessChecker grants access to every user.
//
// It returns the channel of captured broadcasts together with a restore func
// that puts the original seams back; callers should defer the restore func.
func installFakes(t *testing.T, uids []int64, milestone *issues_model.Milestone) (<-chan capturedCall, func()) {
	t.Helper()

	// Snapshot the real seams before touching any of them.
	prevBroadcast, prevLister := broadcastFn, connectedUIDsLister
	prevChecker, prevRepoLookup, prevMilestoneLookup := repoAccessChecker, repoLookup, milestoneLookup
	restore := func() {
		broadcastFn = prevBroadcast
		connectedUIDsLister = prevLister
		repoAccessChecker = prevChecker
		repoLookup = prevRepoLookup
		milestoneLookup = prevMilestoneLookup
	}

	// Buffered so the fake broadcaster does not block on the first few sends.
	captured := make(chan capturedCall, 16)

	broadcastFn = func(recipients []int64, ev *eventsource.Event) {
		// Copy the slice so later mutation by the caller cannot alter the record.
		snapshot := append([]int64(nil), recipients...)
		captured <- capturedCall{uids: snapshot, event: ev}
	}
	connectedUIDsLister = func() []int64 {
		return append([]int64(nil), uids...)
	}
	milestoneLookup = func(_ context.Context, id int64) (*issues_model.Milestone, error) {
		if milestone == nil {
			return &issues_model.Milestone{ID: id, RepoID: 1}, nil
		}
		return milestone, nil
	}
	repoLookup = func(_ context.Context, repoID int64) (*repo_model.Repository, error) {
		return &repo_model.Repository{ID: repoID}, nil
	}
	repoAccessChecker = func(context.Context, int64, *repo_model.Repository) (bool, error) {
		return true, nil
	}

	return captured, restore
}
// awaitCall waits for the next capturedCall on ch and returns it. If nothing
// arrives within a fixed two-second timeout, the test is failed via t.Fatal.
func awaitCall(t *testing.T, ch <-chan capturedCall) capturedCall {
	t.Helper()

	// Use an explicit timer rather than time.After so it can be stopped as
	// soon as a call arrives instead of staying pending for the full timeout.
	timeout := time.NewTimer(2 * time.Second)
	defer timeout.Stop()

	select {
	case c := <-ch:
		return c
	case <-timeout.C:
		t.Fatal("timed out waiting for broadcast")
		// Not reached at runtime (t.Fatal exits the test goroutine); the
		// compiler still requires a return here.
		return capturedCall{}
	}
}
func TestEventNameFormat(t *testing.T) {
assert.Equal(t, "repo-milestones.42", eventName(42))
assert.Equal(t, "repo-milestones.0", eventName(0))
}
// TestPublishMilestoneProgress_NameAndPayload publishes progress for a stubbed
// milestone and checks both the broadcast event's name and its decoded JSON
// payload.
func TestPublishMilestoneProgress_NameAndPayload(t *testing.T) {
	stub := &issues_model.Milestone{
		ID:              7,
		RepoID:          10,
		NumIssues:       8,
		NumClosedIssues: 6,
		NumOpenIssues:   2,
		Completeness:    75,
	}
	want := MilestoneProgress{
		RepoID:       10,
		MilestoneID:  7,
		OpenIssues:   2,
		ClosedIssues: 6,
		Completeness: 75,
	}

	calls, restore := installFakes(t, []int64{1}, stub)
	defer restore()

	PublishMilestoneProgress(context.Background(), 7)
	call := awaitCall(t, calls)

	assert.Equal(t, "repo-milestones.10", call.event.Name)

	raw, isBytes := call.event.Data.([]byte)
	require.True(t, isBytes, "Event.Data should be []byte")

	var decoded MilestoneProgress
	require.NoError(t, json.Unmarshal(raw, &decoded))
	assert.Equal(t, want, decoded)
}
// TestPublishMilestoneProgress_IgnoresNonPositiveID checks that publishing
// with a zero or negative milestone id produces no broadcast within a 200ms
// grace period.
func TestPublishMilestoneProgress_IgnoresNonPositiveID(t *testing.T) {
	calls, restore := installFakes(t, []int64{1}, nil)
	defer restore()

	ctx := context.Background()
	PublishMilestoneProgress(ctx, 0)
	PublishMilestoneProgress(ctx, -3)

	quiet := time.After(200 * time.Millisecond)
	select {
	case <-quiet:
		// Nothing arrived inside the grace period, as expected.
	case <-calls:
		t.Fatal("no broadcast expected for non-positive milestone id")
	}
}
// TestPublishMilestoneProgress_LookupErrorIsSilent makes the milestone lookup
// fail with ErrMilestoneNotExist and checks that no broadcast is produced
// within a 200ms grace period.
func TestPublishMilestoneProgress_LookupErrorIsSilent(t *testing.T) {
	calls, restore := installFakes(t, []int64{1}, nil)
	defer restore()

	// Override the stub installed above; restore still reinstates the
	// original seam when the test ends.
	milestoneLookup = func(context.Context, int64) (*issues_model.Milestone, error) {
		return nil, issues_model.ErrMilestoneNotExist{ID: 99}
	}

	PublishMilestoneProgress(context.Background(), 99)

	quiet := time.After(200 * time.Millisecond)
	select {
	case <-quiet:
		// Silence is the expected outcome.
	case <-calls:
		t.Fatal("no broadcast expected when the milestone re-fetch fails")
	}
}
// TestPublishMilestoneDeleted_NameAndPayload publishes a deletion for repo 12,
// milestone 5 and checks both the broadcast event's name and its decoded JSON
// payload.
func TestPublishMilestoneDeleted_NameAndPayload(t *testing.T) {
	ch, restore := installFakes(t, []int64{1}, nil)
	defer restore()

	PublishMilestoneDeleted(context.Background(), 12, 5)
	c := awaitCall(t, ch)

	assert.Equal(t, "repo-milestones.12", c.event.Name)

	// Checked assertion: a payload of the wrong type should fail this test
	// with a clear message rather than panic.
	data, ok := c.event.Data.([]byte)
	require.True(t, ok, "Event.Data should be []byte")

	var got MilestoneDeleted
	require.NoError(t, json.Unmarshal(data, &got))
	assert.Equal(t, MilestoneDeleted{RepoID: 12, MilestoneID: 5}, got)
}
// TestPublishMilestoneDeleted_IgnoresNonPositiveIDs checks that a deletion
// publish with a zero in either id argument produces no broadcast within a
// 200ms grace period.
func TestPublishMilestoneDeleted_IgnoresNonPositiveIDs(t *testing.T) {
	calls, restore := installFakes(t, []int64{1}, nil)
	defer restore()

	ctx := context.Background()
	PublishMilestoneDeleted(ctx, 0, 5)
	PublishMilestoneDeleted(ctx, 12, 0)

	quiet := time.After(200 * time.Millisecond)
	select {
	case <-quiet:
		// Nothing arrived inside the grace period, as expected.
	case <-calls:
		t.Fatal("no broadcast expected for non-positive ids")
	}
}
// TestSessionTagPropagation verifies that when a publish is invoked with a
// context decorated by sessiontag.WithSessionTag, the emitted JSON payload
// decodes with that tag in its SessionTag field.
func TestSessionTagPropagation(t *testing.T) {
	ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1})
	defer restore()

	ctx := sessiontag.WithSessionTag(context.Background(), "abc-123")
	PublishMilestoneProgress(ctx, 3)
	c := awaitCall(t, ch)

	// Checked assertion: a payload of the wrong type should fail this test
	// with a clear message rather than panic.
	data, ok := c.event.Data.([]byte)
	require.True(t, ok, "Event.Data should be []byte")

	var payload MilestoneProgress
	require.NoError(t, json.Unmarshal(data, &payload))
	assert.Equal(t, "abc-123", payload.SessionTag)
}
// TestSessionTagAbsentWhenUnset verifies that the decoded payload's
// SessionTag is empty when the publishing context carries no session tag.
// It inspects the decoded field only; it does not check whether the key is
// omitted from the raw JSON.
func TestSessionTagAbsentWhenUnset(t *testing.T) {
	ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1})
	defer restore()

	PublishMilestoneProgress(context.Background(), 3)
	c := awaitCall(t, ch)

	// Checked assertion: a payload of the wrong type should fail this test
	// with a clear message rather than panic.
	data, ok := c.event.Data.([]byte)
	require.True(t, ok, "Event.Data should be []byte")

	var payload MilestoneProgress
	require.NoError(t, json.Unmarshal(data, &payload))
	assert.Empty(t, payload.SessionTag)
}
// TestSessionTagResolvedSynchronously is meant to guard against the tag being
// read from a detached context instead of the caller's request context: it
// publishes with a tagged context and checks the decoded payload carries
// that tag.
//
// NOTE(review): apart from the tag value, this body is the same as
// TestSessionTagPropagation; confirm whether a stronger check is wanted.
func TestSessionTagResolvedSynchronously(t *testing.T) {
	ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1})
	defer restore()

	ctx := sessiontag.WithSessionTag(context.Background(), "sync-tag")
	PublishMilestoneProgress(ctx, 3)
	c := awaitCall(t, ch)

	// Checked assertion: a payload of the wrong type should fail this test
	// with a clear message rather than panic.
	data, ok := c.event.Data.([]byte)
	require.True(t, ok, "Event.Data should be []byte")

	var payload MilestoneProgress
	require.NoError(t, json.Unmarshal(data, &payload))
	assert.Equal(t, "sync-tag", payload.SessionTag)
}
// TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission checks that, out of
// four connected users, only those the access checker accepts are returned.
func TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission(t *testing.T) {
	prevLister, prevChecker, prevRepoLookup := connectedUIDsLister, repoAccessChecker, repoLookup
	defer func() {
		connectedUIDsLister, repoAccessChecker, repoLookup = prevLister, prevChecker, prevRepoLookup
	}()

	connectedUIDsLister = func() []int64 { return []int64{1, 2, 3, 4} }
	repoLookup = func(_ context.Context, repoID int64) (*repo_model.Repository, error) {
		return &repo_model.Repository{ID: repoID}, nil
	}
	repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) {
		// Only users 1 and 3 are granted access.
		return uid == 1 || uid == 3, nil
	}

	filtered := connectedUIDsWithRepoIssueAccess(context.Background(), 42)
	assert.ElementsMatch(t, []int64{1, 3}, filtered)
}
// TestConnectedUIDsWithRepoIssueAccess_NoConnections checks the early exit
// taken when the lister reports no connected users: the result is empty and
// the repo lookup is never invoked.
func TestConnectedUIDsWithRepoIssueAccess_NoConnections(t *testing.T) {
	prevLister, prevRepoLookup := connectedUIDsLister, repoLookup
	defer func() {
		connectedUIDsLister, repoLookup = prevLister, prevRepoLookup
	}()

	lookups := 0
	connectedUIDsLister = func() []int64 { return nil }
	repoLookup = func(context.Context, int64) (*repo_model.Repository, error) {
		lookups++
		return &repo_model.Repository{}, nil
	}

	result := connectedUIDsWithRepoIssueAccess(context.Background(), 42)

	assert.Empty(t, result)
	assert.Zero(t, lookups, "repo lookup should be skipped when no uids are connected")
}
// TestPublishEvent_BroadcastsToAllowedUIDs calls publishEvent directly and
// checks that the uid list handed to broadcastFn is the connected set minus
// the user the access checker rejects.
func TestPublishEvent_BroadcastsToAllowedUIDs(t *testing.T) {
	prevBroadcast, prevLister := broadcastFn, connectedUIDsLister
	prevChecker, prevRepoLookup := repoAccessChecker, repoLookup
	defer func() {
		broadcastFn, connectedUIDsLister = prevBroadcast, prevLister
		repoAccessChecker, repoLookup = prevChecker, prevRepoLookup
	}()

	// delivered is guarded by mu because the fake broadcaster writes it and
	// the test body reads it.
	var (
		mu        sync.Mutex
		delivered []int64
	)
	broadcastFn = func(recipients []int64, _ *eventsource.Event) {
		mu.Lock()
		defer mu.Unlock()
		delivered = append([]int64(nil), recipients...)
	}
	connectedUIDsLister = func() []int64 { return []int64{10, 20, 30} }
	repoLookup = func(_ context.Context, repoID int64) (*repo_model.Repository, error) {
		return &repo_model.Repository{ID: repoID}, nil
	}
	repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) {
		// Everyone except user 20 may see the repo.
		return uid != 20, nil
	}

	publishEvent(context.Background(), 1, MilestoneDeleted{RepoID: 1, MilestoneID: 5})

	mu.Lock()
	defer mu.Unlock()
	assert.ElementsMatch(t, []int64{10, 30}, delivered)
}
// TestPublishMilestoneProgress_NoConnectionsNoBroadcast verifies the
// connected-uid shortcut: with nobody connected nothing is sent even
// though the milestone re-fetch succeeds.
func TestPublishMilestoneProgress_NoConnectionsNoBroadcast(t *testing.T) {
ch, restore := installFakes(t, nil, &issues_model.Milestone{ID: 3, RepoID: 1})
defer restore()
PublishMilestoneProgress(context.Background(), 3)
select {
case <-ch:
t.Fatal("no broadcast expected when no users are connected")
case <-time.After(200 * time.Millisecond):
}
}
// TestPublishMilestoneProgress_FanOutTargetList verifies the recipient
// list handed to broadcast is exactly the access-filtered set.
func TestPublishMilestoneProgress_FanOutTargetList(t *testing.T) {
ch, restore := installFakes(t, []int64{5, 6, 7}, &issues_model.Milestone{ID: 3, RepoID: 1})
defer restore()
repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) {
return uid != 6, nil
}
PublishMilestoneProgress(context.Background(), 3)
c := awaitCall(t, ch)
assert.ElementsMatch(t, []int64{5, 7}, c.uids)
}