diff --git a/models/issues/milestone.go b/models/issues/milestone.go index 1dd8630276..bedabd0e4d 100644 --- a/models/issues/milestone.go +++ b/models/issues/milestone.go @@ -136,6 +136,21 @@ func GetMilestoneByRepoID(ctx context.Context, repoID, id int64) (*Milestone, er return m, nil } +// GetMilestoneByID returns the milestone identified by id, regardless of +// which repository it belongs to. Used by the milestone_events SSE +// publisher, which only has the milestone id and re-reads the fresh +// counters from a detached, process-lifetime context. +func GetMilestoneByID(ctx context.Context, id int64) (*Milestone, error) { + m := new(Milestone) + has, err := db.GetEngine(ctx).ID(id).Get(m) + if err != nil { + return nil, err + } else if !has { + return nil, ErrMilestoneNotExist{ID: id} + } + return m, nil +} + // GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo func GetMilestoneByRepoIDANDName(ctx context.Context, repoID int64, name string) (*Milestone, error) { var mile Milestone diff --git a/services/milestone_events/events.go b/services/milestone_events/events.go new file mode 100644 index 0000000000..e92963f087 --- /dev/null +++ b/services/milestone_events/events.go @@ -0,0 +1,216 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// Package milestone_events publishes milestone progress changes as +// Server-Sent Events so other browser tabs viewing the same repository's +// milestone list (or a single milestone's issue list) can update their +// progress bars in near real time. +// +// Each public Publish* helper marshals a typed payload to JSON, wraps it +// in an *eventsource.Event whose Name is "repo-milestones.{repo_id}", and +// fans the event out to every currently connected user that has read +// access to the repository's issues unit. All publish helpers are +// non-blocking: they spawn a goroutine so request handlers do not stall +// on slow consumers. +package milestone_events + +import ( + "context" + "strconv" + + issues_model "code.gitea.io/gitea/models/issues" + access_model "code.gitea.io/gitea/models/perm/access" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/models/unit" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/eventsource" + "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/sessiontag" +) + +// Event payload structs ------------------------------------------------------ + +// MilestoneProgress is emitted whenever a milestone's issue counters +// (and therefore its completeness percentage) change. It funnels every +// mutation that can move the bar: issue close/reopen, milestone +// (re)assignment, issue creation/deletion, milestone status change and +// milestone edit. +type MilestoneProgress struct { + RepoID int64 `json:"repo_id"` + MilestoneID int64 `json:"milestone_id"` + OpenIssues int `json:"open_issues"` + ClosedIssues int `json:"closed_issues"` + Completeness int `json:"completeness"` + SessionTag string `json:"session_tag,omitempty"` +} + +// MilestoneDeleted is emitted when a milestone is deleted so viewers can +// drop the card (or navigate away from a single-milestone view). +type MilestoneDeleted struct { + RepoID int64 `json:"repo_id"` + MilestoneID int64 `json:"milestone_id"` +} + +// Broadcast plumbing --------------------------------------------------------- + +// broadcastFn is the package-level seam used to send an event to a set of +// uids. Tests swap it out to capture calls without touching the real +// eventsource manager. +var broadcastFn = defaultBroadcast + +func defaultBroadcast(uids []int64, event *eventsource.Event) { + mgr := eventsource.GetManager() + for _, uid := range uids { + mgr.SendMessage(uid, event) + } +} + +// connectedUIDsLister returns the uid set the broadcast helpers should +// consider as candidate recipients. Tests override it to feed a +// deterministic list. +var connectedUIDsLister = func() []int64 { + return eventsource.GetManager().ConnectedUIDs() +} + +// milestoneLookup re-reads a milestone by id from the detached context. +// Stubbable in tests so PublishMilestoneProgress can be exercised +// without a database. +var milestoneLookup = issues_model.GetMilestoneByID + +// repoLookup loads a repository by id. Stubbable in tests so the +// access-filter logic can be exercised without spinning up a database. +var repoLookup = repo_model.GetRepositoryByID + +// repoAccessChecker decides whether the user identified by uid is allowed +// to read the given repository's issues. Tests stub this to bypass the +// real permission system. +var repoAccessChecker = canReadMilestones + +// connectedUIDsWithRepoIssueAccess returns the subset of currently +// connected uids that the access checker confirms can read the issues +// unit of repoID. +func connectedUIDsWithRepoIssueAccess(ctx context.Context, repoID int64) []int64 { + uids := connectedUIDsLister() + if len(uids) == 0 { + return nil + } + repo, err := repoLookup(ctx, repoID) + if err != nil { + log.Debug("milestone_events: GetRepositoryByID(%d) failed: %v", repoID, err) + return nil + } + allowed := make([]int64, 0, len(uids)) + for _, uid := range uids { + ok, err := repoAccessChecker(ctx, uid, repo) + if err != nil { + log.Debug("milestone_events: access check uid=%d repo=%d: %v", uid, repoID, err) + continue + } + if ok { + allowed = append(allowed, uid) + } + } + return allowed +} + +// canReadMilestones implements the real read-permission check used in +// production: a user may see milestone progress for a repo when they can +// read its issues unit. +func canReadMilestones(ctx context.Context, uid int64, repo *repo_model.Repository) (bool, error) { + user, err := user_model.GetUserByID(ctx, uid) + if err != nil { + return false, err + } + // AccessModeRead == 1; the literal mirrors project_events, where the + // perm_model typed constant would force another import alias and the + // meaning is well established here. + return access_model.HasAccessUnit(ctx, user, repo, unit.TypeIssues, 1) +} + +// publishEvent is the shared pipeline used by every Publish* helper. +// It marshals the payload, builds the SSE Event, looks up authorized +// recipients, and fans the event out via broadcastFn. +func publishEvent(ctx context.Context, repoID int64, payload any) { + data, err := json.Marshal(payload) + if err != nil { + log.Error("milestone_events: marshal payload for repo %d: %v", repoID, err) + return + } + event := &eventsource.Event{ + Name: eventName(repoID), + Data: data, + } + uids := connectedUIDsWithRepoIssueAccess(ctx, repoID) + if len(uids) == 0 { + return + } + broadcastFn(uids, event) +} + +// eventName returns the SSE event name for a given repo id. +func eventName(repoID int64) string { + return "repo-milestones." + strconv.FormatInt(repoID, 10) +} + +// Publishers ----------------------------------------------------------------- + +// PublishMilestoneProgress re-reads the milestone's fresh counters and +// fans a MilestoneProgress event out to everyone who can read the repo's +// issues. The session tag is resolved synchronously from the request +// context before the goroutine starts; the goroutine itself runs on a +// detached, process-lifetime context so the request-scoped DB session +// being returned to the pool cannot make the re-fetch/access checks fail. +func PublishMilestoneProgress(ctx context.Context, milestoneID int64) { + if milestoneID <= 0 { + return + } + tag := sessiontag.SessionTagFromContext(ctx) + go func() { + detachCtx := detach(ctx) + m, err := milestoneLookup(detachCtx, milestoneID) + if err != nil { + log.Debug("milestone_events: GetMilestoneByID(%d) failed: %v", milestoneID, err) + return + } + payload := MilestoneProgress{ + RepoID: m.RepoID, + MilestoneID: m.ID, + OpenIssues: m.NumOpenIssues, + ClosedIssues: m.NumClosedIssues, + Completeness: m.Completeness, + SessionTag: tag, + } + publishEvent(detachCtx, m.RepoID, payload) + }() +} + +// PublishMilestoneDeleted fans a MilestoneDeleted event out for the given +// repo/milestone. No re-fetch is needed since the milestone is gone. +func PublishMilestoneDeleted(ctx context.Context, repoID, milestoneID int64) { + if repoID <= 0 || milestoneID <= 0 { + return + } + go func() { + detachCtx := detach(ctx) + publishEvent(detachCtx, repoID, MilestoneDeleted{ + RepoID: repoID, + MilestoneID: milestoneID, + }) + }() +} + +// detach returns a context safe for use in the fire-and-forget publish +// goroutine. The request's context carries a request-scoped DB session +// that is returned to the pool once the HTTP handler completes; reusing +// it from the goroutine races with that teardown and makes subsequent +// queries (GetMilestoneByID, GetRepositoryByID, access checks) fail +// intermittently. The session tag is already resolved synchronously +// before the goroutine starts, so the goroutine needs no request-scoped +// values — only a clean, process-lifetime DB context. ShutdownContext is +// backed by the global engine, outlives any single request, and is +// cancelled on app shutdown so we don't leak goroutines past teardown. +func detach(_ context.Context) context.Context { + return graceful.GetManager().ShutdownContext() +} diff --git a/services/milestone_events/events_test.go b/services/milestone_events/events_test.go new file mode 100644 index 0000000000..d000448bf6 --- /dev/null +++ b/services/milestone_events/events_test.go @@ -0,0 +1,335 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package milestone_events + +import ( + "context" + "sync" + "testing" + "time" + + issues_model "code.gitea.io/gitea/models/issues" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/modules/eventsource" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/sessiontag" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// capturedCall is one observed broadcast: the recipient uid set plus the +// constructed Event. +type capturedCall struct { + uids []int64 + event *eventsource.Event +} + +// installFakes swaps every package-level seam used by the publishers for +// test doubles: a fake uid lister, a stubbed milestone lookup returning a +// synthetic milestone (no DB hit), a stubbed repo lookup, an "everyone +// passes" access checker, and a broadcaster that pushes calls onto a +// buffered channel. +// +// The returned restore func reverts every seam; defer it in the test. +func installFakes(t *testing.T, uids []int64, milestone *issues_model.Milestone) (<-chan capturedCall, func()) { + t.Helper() + + calls := make(chan capturedCall, 16) + + origBroadcast := broadcastFn + origLister := connectedUIDsLister + origChecker := repoAccessChecker + origRepoLookup := repoLookup + origMsLookup := milestoneLookup + + broadcastFn = func(uids []int64, event *eventsource.Event) { + calls <- capturedCall{uids: append([]int64(nil), uids...), event: event} + } + connectedUIDsLister = func() []int64 { + return append([]int64(nil), uids...) + } + milestoneLookup = func(_ context.Context, id int64) (*issues_model.Milestone, error) { + if milestone != nil { + return milestone, nil + } + return &issues_model.Milestone{ID: id, RepoID: 1}, nil + } + repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) { + return &repo_model.Repository{ID: id}, nil + } + repoAccessChecker = func(_ context.Context, _ int64, _ *repo_model.Repository) (bool, error) { + return true, nil + } + + return calls, func() { + broadcastFn = origBroadcast + connectedUIDsLister = origLister + repoAccessChecker = origChecker + repoLookup = origRepoLookup + milestoneLookup = origMsLookup + } +} + +// awaitCall blocks until one capturedCall arrives or the test deadline +// elapses. It fails the test on timeout. +func awaitCall(t *testing.T, ch <-chan capturedCall) capturedCall { + t.Helper() + select { + case c := <-ch: + return c + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for broadcast") + return capturedCall{} + } +} + +func TestEventNameFormat(t *testing.T) { + assert.Equal(t, "repo-milestones.42", eventName(42)) + assert.Equal(t, "repo-milestones.0", eventName(0)) +} + +func TestPublishMilestoneProgress_NameAndPayload(t *testing.T) { + ms := &issues_model.Milestone{ + ID: 7, + RepoID: 10, + NumIssues: 8, + NumClosedIssues: 6, + NumOpenIssues: 2, + Completeness: 75, + } + ch, restore := installFakes(t, []int64{1}, ms) + defer restore() + + PublishMilestoneProgress(context.Background(), 7) + + c := awaitCall(t, ch) + assert.Equal(t, "repo-milestones.10", c.event.Name) + + data, ok := c.event.Data.([]byte) + require.True(t, ok, "Event.Data should be []byte") + var got MilestoneProgress + require.NoError(t, json.Unmarshal(data, &got)) + assert.Equal(t, MilestoneProgress{ + RepoID: 10, MilestoneID: 7, OpenIssues: 2, ClosedIssues: 6, Completeness: 75, + }, got) +} + +func TestPublishMilestoneProgress_IgnoresNonPositiveID(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, nil) + defer restore() + + PublishMilestoneProgress(context.Background(), 0) + PublishMilestoneProgress(context.Background(), -3) + + select { + case <-ch: + t.Fatal("no broadcast expected for non-positive milestone id") + case <-time.After(200 * time.Millisecond): + } +} + +func TestPublishMilestoneProgress_LookupErrorIsSilent(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, nil) + defer restore() + milestoneLookup = func(_ context.Context, _ int64) (*issues_model.Milestone, error) { + return nil, issues_model.ErrMilestoneNotExist{ID: 99} + } + + PublishMilestoneProgress(context.Background(), 99) + + select { + case <-ch: + t.Fatal("no broadcast expected when the milestone re-fetch fails") + case <-time.After(200 * time.Millisecond): + } +} + +func TestPublishMilestoneDeleted_NameAndPayload(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, nil) + defer restore() + + PublishMilestoneDeleted(context.Background(), 12, 5) + + c := awaitCall(t, ch) + assert.Equal(t, "repo-milestones.12", c.event.Name) + var got MilestoneDeleted + require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &got)) + assert.Equal(t, MilestoneDeleted{RepoID: 12, MilestoneID: 5}, got) +} + +func TestPublishMilestoneDeleted_IgnoresNonPositiveIDs(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, nil) + defer restore() + + PublishMilestoneDeleted(context.Background(), 0, 5) + PublishMilestoneDeleted(context.Background(), 12, 0) + + select { + case <-ch: + t.Fatal("no broadcast expected for non-positive ids") + case <-time.After(200 * time.Millisecond): + } +} + +// TestSessionTagPropagation verifies that when a publish is invoked +// inside a context decorated by sessiontag.WithSessionTag, the emitted +// JSON payload carries the tag. +func TestSessionTagPropagation(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1}) + defer restore() + + ctx := sessiontag.WithSessionTag(context.Background(), "abc-123") + PublishMilestoneProgress(ctx, 3) + + c := awaitCall(t, ch) + var payload MilestoneProgress + require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload)) + assert.Equal(t, "abc-123", payload.SessionTag) +} + +// TestSessionTagAbsentWhenUnset verifies the omitempty tag stays empty +// when no session tag is on the context. +func TestSessionTagAbsentWhenUnset(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1}) + defer restore() + + PublishMilestoneProgress(context.Background(), 3) + + c := awaitCall(t, ch) + var payload MilestoneProgress + require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload)) + assert.Empty(t, payload.SessionTag) +} + +// TestSessionTagResolvedSynchronously ensures the tag is read from the +// request context before the goroutine starts, not from the detached +// context (which never carries request-scoped values). +func TestSessionTagResolvedSynchronously(t *testing.T) { + ch, restore := installFakes(t, []int64{1}, &issues_model.Milestone{ID: 3, RepoID: 1}) + defer restore() + + ctx := sessiontag.WithSessionTag(context.Background(), "sync-tag") + PublishMilestoneProgress(ctx, 3) + + c := awaitCall(t, ch) + var payload MilestoneProgress + require.NoError(t, json.Unmarshal(c.event.Data.([]byte), &payload)) + assert.Equal(t, "sync-tag", payload.SessionTag) +} + +// TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission ensures the +// helper drops uids the access checker rejects. +func TestConnectedUIDsWithRepoIssueAccess_FiltersByPermission(t *testing.T) { + origLister := connectedUIDsLister + origChecker := repoAccessChecker + origRepoLookup := repoLookup + defer func() { + connectedUIDsLister = origLister + repoAccessChecker = origChecker + repoLookup = origRepoLookup + }() + + connectedUIDsLister = func() []int64 { return []int64{1, 2, 3, 4} } + repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) { + return &repo_model.Repository{ID: id}, nil + } + allowed := map[int64]bool{1: true, 3: true} + repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) { + return allowed[uid], nil + } + + got := connectedUIDsWithRepoIssueAccess(context.Background(), 42) + assert.ElementsMatch(t, []int64{1, 3}, got) +} + +// TestConnectedUIDsWithRepoIssueAccess_NoConnections shortcuts when no +// users are connected; the repo lookup must not be called. +func TestConnectedUIDsWithRepoIssueAccess_NoConnections(t *testing.T) { + origLister := connectedUIDsLister + origRepoLookup := repoLookup + defer func() { + connectedUIDsLister = origLister + repoLookup = origRepoLookup + }() + + connectedUIDsLister = func() []int64 { return nil } + called := false + repoLookup = func(_ context.Context, _ int64) (*repo_model.Repository, error) { + called = true + return &repo_model.Repository{}, nil + } + + got := connectedUIDsWithRepoIssueAccess(context.Background(), 42) + assert.Empty(t, got) + assert.False(t, called, "repo lookup should be skipped when no uids are connected") +} + +// TestPublishEvent_BroadcastsToAllowedUIDs exercises publishEvent +// directly to verify the uid set computed by the access filter is what +// gets handed to broadcastFn. +func TestPublishEvent_BroadcastsToAllowedUIDs(t *testing.T) { + origBroadcast := broadcastFn + origLister := connectedUIDsLister + origChecker := repoAccessChecker + origRepoLookup := repoLookup + defer func() { + broadcastFn = origBroadcast + connectedUIDsLister = origLister + repoAccessChecker = origChecker + repoLookup = origRepoLookup + }() + + var mu sync.Mutex + var got []int64 + broadcastFn = func(uids []int64, _ *eventsource.Event) { + mu.Lock() + got = append([]int64(nil), uids...) + mu.Unlock() + } + connectedUIDsLister = func() []int64 { return []int64{10, 20, 30} } + repoLookup = func(_ context.Context, id int64) (*repo_model.Repository, error) { + return &repo_model.Repository{ID: id}, nil + } + repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) { + return uid != 20, nil + } + + publishEvent(context.Background(), 1, MilestoneDeleted{RepoID: 1, MilestoneID: 5}) + + mu.Lock() + defer mu.Unlock() + assert.ElementsMatch(t, []int64{10, 30}, got) +} + +// TestPublishMilestoneProgress_NoConnectionsNoBroadcast verifies the +// connected-uid shortcut: with nobody connected nothing is sent even +// though the milestone re-fetch succeeds. +func TestPublishMilestoneProgress_NoConnectionsNoBroadcast(t *testing.T) { + ch, restore := installFakes(t, nil, &issues_model.Milestone{ID: 3, RepoID: 1}) + defer restore() + + PublishMilestoneProgress(context.Background(), 3) + + select { + case <-ch: + t.Fatal("no broadcast expected when no users are connected") + case <-time.After(200 * time.Millisecond): + } +} + +// TestPublishMilestoneProgress_FanOutTargetList verifies the recipient +// list handed to broadcast is exactly the access-filtered set. +func TestPublishMilestoneProgress_FanOutTargetList(t *testing.T) { + ch, restore := installFakes(t, []int64{5, 6, 7}, &issues_model.Milestone{ID: 3, RepoID: 1}) + defer restore() + repoAccessChecker = func(_ context.Context, uid int64, _ *repo_model.Repository) (bool, error) { + return uid != 6, nil + } + + PublishMilestoneProgress(context.Background(), 3) + + c := awaitCall(t, ch) + assert.ElementsMatch(t, []int64{5, 7}, c.uids) +}