// Copyright 2026 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
|
|
|
|
package milestone_events
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
issues_model "code.gitea.io/gitea/models/issues"
|
|
repo_model "code.gitea.io/gitea/models/repo"
|
|
"code.gitea.io/gitea/modules/eventsource"
|
|
"code.gitea.io/gitea/modules/json"
|
|
"code.gitea.io/gitea/modules/sessiontag"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// capturedCall is one observed broadcast: the recipient uid set plus the
|
|
// constructed Event.
|
|
type capturedCall struct {
|
|
uids []int64
|
|
event *eventsource.Event
|
|
}
|
|
|
|
// installFakes swaps every package-level seam used by the publishers for
|
|
// test doubles: a fake uid lister, a stubbed milestone lookup returning a
|
|
// synthetic milestone (no DB hit), a stubbed repo lookup, an "everyone
|
|
// passes" access checker, and a broadcaster that pushes calls onto a
|
|
// buffered channel.
|
|
//
|
|
// The returned restore func reverts every seam; defer it in the test.
|
|
func installFakes(t *testing.T, uids []int64, milestone *issues_model.Milestone) (<-chan capturedCall, func()) {
|
|
t.Helper()
|
|
|
|
calls := make(chan capturedCall, 16)
|
|
|
|
origBroadcast := broadcastFn
|
|
origLister := connectedUIDsLister
|
|
origChecker := repoAccessChecker
|
|
origRepoLookup := repoLookup
|
|
origMsLookup := milestoneLookup
|
|
|
|
broadcastFn = func(uids []int64, event *eventsource.Event) {
|
|
calls <- capturedCall{uids: append([]int64(nil), uids...), event: event}
|
|
}
|
|
connectedUIDsLister = func() []int64 {
|
|
return append([]int64(nil), uids...)
|
|
}
|
|
milestoneLookup = func(_ context.Context, id int64) (*issues_model.Milestone, error) {
|
|
if milestone != nil {
|
|
return milestone, nil
|
|
}
|
|
return &issues_model.Milestone{ID: id, RepoID: 1}, nil
|
|
}
|
|
repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) {
|
|
return &repo_model.Repository{ID: id}, nil
|
|
}
|
|
repoAccessChecker = func(_ context.Context, _ int64, _ *repo_model.Repository) (bool, error) {
|
|
return true, nil
|
|
}
|
|
|
|
return calls, func() {
|
|
broadcastFn = origBroadcast
|
|
connectedUIDsLister = origLister
|
|
repoAccessChecker = origChecker
|
|
repoLookup = origRepoLookup
|
|
milestoneLookup = origMsLookup
|
|
}
|
|
}
|
|
|
|
// awaitCall blocks until one capturedCall arrives or the test deadline
|
|
// elapses. It fails the test on timeout.
|
|
func awaitCall(t *testing.T, ch <-chan capturedCall) capturedCall {
|
|
t.Helper()
|
|
select {
|
|
case c := <-ch:
|
|
return c
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for broadcast")
|
|
return capturedCall{}
|
|
}
|
|
}
|
|
|
|
func TestEventNameFormat(t *testing.T) {
|
|
assert.Equal(t, "repo-milestones.42", eventName(42))
|
|
assert.Equal(t, "repo-milestones.0", eventName(0))
|
|
}
|
|
|
|
func TestPublishMilestoneProgress_NameAndPayload(t *testing.T) {
|
|
ms := &issues_model.Milestone{
|
|
ID: 7,
|
|
RepoID: 10,
|
|
NumIssues: 8,
|
|
NumClosedIssues: 6,
|
|
NumOpenIssues: 2,
|
|
Completeness: 75,
|
|
}
|
|
ch, restore := installFakes(t, []int64{1}, ms)
|
|
defer restore()
|
|
|
|
PublishMilestoneProgress(context.Background(), 7)
|
|
|
|
c := awaitCall(t, ch)
|
|
assert.Equal(t, "repo-milestones.10", c.event.Name)
|
|
|
|
data, ok := c.event.Data.([]byte)
|
|
require.True(t, ok, "Event.Data should be []byte")
|
|
var got MilestoneProgress
|
|
require.NoError(t, json.Unmarshal(data, &got))
|
|
assert.Equal(t, MilestoneProgress{
|
|
RepoID: 10, MilestoneID: 7, OpenIssues: 2, ClosedIssues: 6, Completeness: 75,
|
|
}, got)
|
|
}
|
|
|
|
func TestPublishMilestoneProgress_IgnoresNonPositiveID(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, nil)
|
|
defer restore()
|
|
|
|
PublishMilestoneProgress(context.Background(), 0)
|
|
PublishMilestoneProgress(context.Background(), -3)
|
|
|
|
select {
|
|
case <-ch:
|
|
t.Fatal("no broadcast expected for non-positive milestone id")
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestPublishMilestoneProgress_LookupErrorIsSilent(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, nil)
|
|
defer restore()
|
|
milestoneLookup = func(_ context.Context, _ int64) (*issues_model.Milestone, error) {
|
|
return nil, issues_model.ErrMilestoneNotExist{ID: 99}
|
|
}
|
|
|
|
PublishMilestoneProgress(context.Background(), 99)
|
|
|
|
select {
|
|
case <-ch:
|
|
t.Fatal("no broadcast expected when the milestone re-fetch fails")
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestPublishMilestoneDeleted_NameAndPayload(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, nil)
|
|
defer restore()
|
|
|
|
PublishMilestoneDeleted(context.Background(), 12, 5)
|
|
|
|
c := awaitCall(t, ch)
|
|
assert.Equal(t, "repo-milestones.12", c.event.Name)
|
|
var got MilestoneDeleted
|
|
require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &got))
|
|
assert.Equal(t, MilestoneDeleted{RepoID: 12, MilestoneID: 5}, got)
|
|
}
|
|
|
|
func TestPublishMilestoneDeleted_IgnoresNonPositiveIDs(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, nil)
|
|
defer restore()
|
|
|
|
PublishMilestoneDeleted(context.Background(), 0, 5)
|
|
PublishMilestoneDeleted(context.Background(), 12, 0)
|
|
|
|
select {
|
|
case <-ch:
|
|
t.Fatal("no broadcast expected for non-positive ids")
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
// TestSessionTagPropagation verifies that when a publish is invoked
|
|
// inside a context decorated by sessiontag.WithSessionTag, the emitted
|
|
// JSON payload carries the tag.
|
|
func TestSessionTagPropagation(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1})
|
|
defer restore()
|
|
|
|
ctx := sessiontag.WithSessionTag(context.Background(), "abc-123")
|
|
PublishMilestoneProgress(ctx, 3)
|
|
|
|
c := awaitCall(t, ch)
|
|
var payload MilestoneProgress
|
|
require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload))
|
|
assert.Equal(t, "abc-123", payload.SessionTag)
|
|
}
|
|
|
|
// TestSessionTagAbsentWhenUnset verifies the omitempty tag stays empty
|
|
// when no session tag is on the context.
|
|
func TestSessionTagAbsentWhenUnset(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1})
|
|
defer restore()
|
|
|
|
PublishMilestoneProgress(context.Background(), 3)
|
|
|
|
c := awaitCall(t, ch)
|
|
var payload MilestoneProgress
|
|
require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload))
|
|
assert.Empty(t, payload.SessionTag)
|
|
}
|
|
|
|
// TestSessionTagResolvedSynchronously ensures the tag is read from the
|
|
// request context before the goroutine starts, not from the detached
|
|
// context (which never carries request-scoped values).
|
|
func TestSessionTagResolvedSynchronously(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1})
|
|
defer restore()
|
|
|
|
ctx := sessiontag.WithSessionTag(context.Background(), "sync-tag")
|
|
PublishMilestoneProgress(ctx, 3)
|
|
|
|
c := awaitCall(t, ch)
|
|
var payload MilestoneProgress
|
|
require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload))
|
|
assert.Equal(t, "sync-tag", payload.SessionTag)
|
|
}
|
|
|
|
// TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission ensures the
|
|
// helper drops uids the access checker rejects.
|
|
func TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission(t *testing.T) {
|
|
origLister := connectedUIDsLister
|
|
origChecker := repoAccessChecker
|
|
origRepoLookup := repoLookup
|
|
defer func() {
|
|
connectedUIDsLister = origLister
|
|
repoAccessChecker = origChecker
|
|
repoLookup = origRepoLookup
|
|
}()
|
|
|
|
connectedUIDsLister = func() []int64 { return []int64{1, 2, 3, 4} }
|
|
repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) {
|
|
return &repo_model.Repository{ID: id}, nil
|
|
}
|
|
allowed := map[int64]bool{1: true, 3: true}
|
|
repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) {
|
|
return allowed[uid], nil
|
|
}
|
|
|
|
got := connectedUIDsWithRepoIssueAccess(context.Background(), 42)
|
|
assert.ElementsMatch(t, []int64{1, 3}, got)
|
|
}
|
|
|
|
// TestConnectedUIDsWithRepoIssueAccess_NoConnections shortcuts when no
|
|
// users are connected; the repo lookup must not be called.
|
|
func TestConnectedUIDsWithRepoIssueAccess_NoConnections(t *testing.T) {
|
|
origLister := connectedUIDsLister
|
|
origRepoLookup := repoLookup
|
|
defer func() {
|
|
connectedUIDsLister = origLister
|
|
repoLookup = origRepoLookup
|
|
}()
|
|
|
|
connectedUIDsLister = func() []int64 { return nil }
|
|
called := false
|
|
repoLookup = func(_ context.Context, _ int64) (*repo_model.Repository, error) {
|
|
called = true
|
|
return &repo_model.Repository{}, nil
|
|
}
|
|
|
|
got := connectedUIDsWithRepoIssueAccess(context.Background(), 42)
|
|
assert.Empty(t, got)
|
|
assert.False(t, called, "repo lookup should be skipped when no uids are connected")
|
|
}
|
|
|
|
// TestPublishEvent_BroadcastsToAllowedUIDs exercises publishEvent
|
|
// directly to verify the uid set computed by the access filter is what
|
|
// gets handed to broadcastFn.
|
|
func TestPublishEvent_BroadcastsToAllowedUIDs(t *testing.T) {
|
|
origBroadcast := broadcastFn
|
|
origLister := connectedUIDsLister
|
|
origChecker := repoAccessChecker
|
|
origRepoLookup := repoLookup
|
|
defer func() {
|
|
broadcastFn = origBroadcast
|
|
connectedUIDsLister = origLister
|
|
repoAccessChecker = origChecker
|
|
repoLookup = origRepoLookup
|
|
}()
|
|
|
|
var mu sync.Mutex
|
|
var got []int64
|
|
broadcastFn = func(uids []int64, _ *eventsource.Event) {
|
|
mu.Lock()
|
|
got = append([]int64(nil), uids...)
|
|
mu.Unlock()
|
|
}
|
|
connectedUIDsLister = func() []int64 { return []int64{10, 20, 30} }
|
|
repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) {
|
|
return &repo_model.Repository{ID: id}, nil
|
|
}
|
|
repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) {
|
|
return uid != 20, nil
|
|
}
|
|
|
|
publishEvent(context.Background(), 1, MilestoneDeleted{RepoID: 1, MilestoneID: 5})
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
assert.ElementsMatch(t, []int64{10, 30}, got)
|
|
}
|
|
|
|
// TestPublishMilestoneProgress_NoConnectionsNoBroadcast verifies the
|
|
// connected-uid shortcut: with nobody connected nothing is sent even
|
|
// though the milestone re-fetch succeeds.
|
|
func TestPublishMilestoneProgress_NoConnectionsNoBroadcast(t *testing.T) {
|
|
ch, restore := installFakes(t, nil, &issues_model.Milestone{ID: 3, RepoID: 1})
|
|
defer restore()
|
|
|
|
PublishMilestoneProgress(context.Background(), 3)
|
|
|
|
select {
|
|
case <-ch:
|
|
t.Fatal("no broadcast expected when no users are connected")
|
|
case <-time.After(200 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
// TestPublishMilestoneProgress_FanOutTargetList verifies the recipient
|
|
// list handed to broadcast is exactly the access-filtered set.
|
|
func TestPublishMilestoneProgress_FanOutTargetList(t *testing.T) {
|
|
ch, restore := installFakes(t, []int64{5, 6, 7}, &issues_model.Milestone{ID: 3, RepoID: 1})
|
|
defer restore()
|
|
repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) {
|
|
return uid != 6, nil
|
|
}
|
|
|
|
PublishMilestoneProgress(context.Background(), 3)
|
|
|
|
c := awaitCall(t, ch)
|
|
assert.ElementsMatch(t, []int64{5, 7}, c.uids)
|
|
}
|