From b6cdb51f837b22dc36d48e73ca61258019b47b39 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E4=B8=80=E6=B4=AA?= <fuyihong.1@bytedance.com>
Date: Tue, 26 Mar 2024 22:52:28 +0800
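Subject: [PATCH] fix(eventbus): pending retry events were lost when using a
 custom retry strategy

handleEvents overwrites service.Retry with the list returned by
doRetryStrategy, so retry entries whose RetryTime had not been reached
yet were not carried over and were lost.

getRetryEvents now also returns the IDs of the entries that are not due
yet (remainIDs), and doRetryStrategy puts those entries back into the
returned retry list before applying the retry strategy to the failed
IDs.

The mutex guarding the failed/panic ID slices in dispatchEvents is now
the package-level eventBusMu instead of a function-local mutex.

Adds TestEventBusRetryStrategy to cover a CustomRetry whose last
interval is still pending.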

---
 eventbus/mysql/eventbus.go      | 40 ++++++++++++---------
 eventbus/mysql/eventbus_test.go | 62 +++++++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+), 16 deletions(-)

diff --git a/eventbus/mysql/eventbus.go b/eventbus/mysql/eventbus.go
index f1adcda..7973b59 100644
--- a/eventbus/mysql/eventbus.go
+++ b/eventbus/mysql/eventbus.go
@@ -53,6 +53,7 @@ var ErrNoTransaction = fmt.Errorf("no transaction")
 var ErrServiceNotCreate = fmt.Errorf("service not create")
 
 var defaultLogger = stdr.NewStdr("mysql_eventbus")
+var eventBusMu sync.Mutex
 
 type IRetryStrategy interface {
 	// Next 获取下一次重试的策略,返回 nil 表示不再重试
@@ -361,26 +362,38 @@ func (e *EventBus) getScanEvents(db *gorm.DB, service *ServicePO) ([]*EventPO, e
 	return eventPOs, nil
 }
 
-func (e *EventBus) getRetryEvents(db *gorm.DB, service *ServicePO) ([]*EventPO, error) {
+func (e *EventBus) getRetryEvents(db *gorm.DB, service *ServicePO) ([]*EventPO, []int64, error) {
 	now := time.Now()
 	retryIDs := make([]int64, 0)
+	remainIDs := make([]int64, 0)
 	for _, info := range service.Retry {
 		if info.RetryTime.Before(now) {
 			retryIDs = append(retryIDs, info.ID)
+		} else {
+			remainIDs = append(remainIDs, info.ID)
 		}
 	}
 	if len(retryIDs) == 0 {
-		return nil, nil
+		return nil, remainIDs, nil
 	}
 
 	eventPOs := make([]*EventPO, 0)
 	if err := db.Where("id in ?", retryIDs).Find(&eventPOs).Error; err != nil {
-		return nil, err
+		return nil, nil, err
 	}
-	return eventPOs, nil
+	return eventPOs, remainIDs, nil
 }
 
-func (e *EventBus) doRetryStrategy(service *ServicePO, failedIDs []int64) (retry, failed []*RetryInfo) {
+func (e *EventBus) doRetryStrategy(service *ServicePO, remainIDs, failedIDs []int64) (retry, failed []*RetryInfo) {
+	retryInfos := make(map[int64]*RetryInfo)
+	for _, info := range service.Retry {
+		retryInfos[info.ID] = info
+	}
+
+	for _, id := range remainIDs {
+		retry = append(retry, retryInfos[id])
+	}
+
 	// 没有定义重试策略,默认不重试直接失败
 	if e.retryStrategy == nil {
 		for _, id := range failedIDs {
@@ -388,10 +401,6 @@ func (e *EventBus) doRetryStrategy(service *ServicePO, failedIDs []int64) (retry
 		}
 		return
 	}
-	retryInfos := make(map[int64]*RetryInfo)
-	for _, info := range service.Retry {
-		retryInfos[info.ID] = info
-	}
 
 	for _, id := range failedIDs {
 		info := retryInfos[id]
@@ -417,7 +426,6 @@ func (e *EventBus) dispatchEvents(ctx context.Context, eventPOs []*EventPO) (fai
 	close(events)
 
 	wg := sync.WaitGroup{}
-	mu := sync.Mutex{}
 	for i := 0; i < e.opt.ConsumeConcurrent; i++ {
 		wg.Add(1)
 		go func() {
@@ -427,14 +435,14 @@ func (e *EventBus) dispatchEvents(ctx context.Context, eventPOs []*EventPO) (fai
 				defer func() {
 					if r := recover(); r != nil {
 						e.logger.Error(fmt.Errorf("err: %v stack:%s", r, string(debug.Stack())), fmt.Sprintf("panic while handling event(%s)", po.EventID))
-						mu.Lock()
-						defer mu.Unlock()
+						eventBusMu.Lock()
+						defer eventBusMu.Unlock()
 						panics = append(panics, po.ID)
 					}
 				}()
 				if err := e.cb(ctx, po.Event); err != nil {
-					mu.Lock()
-					defer mu.Unlock()
+					eventBusMu.Lock()
+					defer eventBusMu.Unlock()
 					failed = append(failed, po.ID)
 				}
 			}
@@ -460,7 +468,7 @@ func (e *EventBus) handleEvents() error {
 		if err != nil {
 			return err
 		}
-		retryEvents, err := e.getRetryEvents(tx, service)
+		retryEvents, remainIDs, err := e.getRetryEvents(tx, service)
 		if err != nil {
 			return err
 		}
@@ -472,7 +480,7 @@ func (e *EventBus) handleEvents() error {
 		}
 
 		failedIDs, panicIDs := e.dispatchEvents(ctx, events)
-		retry, failed := e.doRetryStrategy(service, failedIDs)
+		retry, failed := e.doRetryStrategy(service, remainIDs, failedIDs)
 		service.Retry = retry
 		service.Failed = append(service.Failed, failed...)
 		for _, id := range panicIDs {
diff --git a/eventbus/mysql/eventbus_test.go b/eventbus/mysql/eventbus_test.go
index 3b249aa..d79aa1e 100644
--- a/eventbus/mysql/eventbus_test.go
+++ b/eventbus/mysql/eventbus_test.go
@@ -247,6 +247,68 @@ func TestEventBusRetry(t *testing.T) {
 	}
 }
 
+func TestEventBusRetryStrategy(t *testing.T) {
+	ctx := context.Background()
+	db := testsuit.InitMysql()
+
+	mu := sync.Mutex{}
+	counts := map[string]int{}
+	eventBus := NewEventBus("test_retry_strategy", db, func(opt *Options) {
+		opt.LimitPerRun = 100
+		opt.ConsumeConcurrent = 10
+		opt.RetryStrategy = &CustomRetry{
+			Intervals: []time.Duration{
+				10 * time.Millisecond,
+				10 * time.Millisecond,
+				1 * time.Hour,
+			},
+		}
+		offset := int64(0)
+		opt.DefaultOffset = &offset
+	})
+	eventBus.RegisterEventHandler(func(ctx context.Context, evt *dddfirework.DomainEvent) error {
+		mu.Lock()
+		counts[evt.ID] += 1
+		mu.Unlock()
+
+		if evt.ID == "0" {
+			return fmt.Errorf("retry")
+		}
+		return nil
+	})
+
+	for i := 0; i < 10; i++ {
+		if err := eventBus.Dispatch(ctx, dddfirework.NewDomainEvent(&testEvent{EType: "test_retry_strategy", Data: "retry"})); err != nil {
+			assert.NoError(t, err)
+		}
+		if err := eventBus.handleEvents(); err != nil {
+			assert.NoError(t, err)
+		}
+		time.Sleep(time.Millisecond * 10)
+	}
+
+	err := db.Transaction(func(tx *gorm.DB) error {
+		var eventCount int64
+		tx.Model(&EventPO{}).Count(&eventCount)
+		assert.Equal(t, eventCount, int64(len(counts)))
+		for id, count := range counts {
+			if id == "0" {
+				assert.Equal(t, 3, count)
+			} else {
+				assert.Equal(t, 1, count)
+			}
+		}
+
+		service := &ServicePO{}
+		err := tx.Where("name = ?", "test_retry_strategy").First(service).Error
+		assert.NoError(t, err)
+		assert.Len(t, service.Retry, 1)
+		return err
+	})
+
+	assert.NoError(t, err)
+}
+
 func TestEventBusFailed(t *testing.T) {
 	ctx := context.Background()
 	db := testsuit.InitMysql()
-- 
GitLab