From 8a606686f635ecce4e767528a666cf77338da65a Mon Sep 17 00:00:00 2001
From: kirinzhong <144225553+kirinzhong@users.noreply.github.com>
Date: Tue, 24 Oct 2023 19:15:36 +0800
Subject: [PATCH] feat: add ITimer interface and implement of mysql (#6)

---
 d_timer/sql/model.go       |  60 +++++++++++++++
 d_timer/sql/timer.go       | 152 +++++++++++++++++++++++++++++++++++++
 d_timer/sql/timer_test.go  | 148 ++++++++++++++++++++++++++++++++++++
 engine.go                  |  43 +++++++++++
 event.go                   |  12 ++-
 eventbus/mysql/eventbus.go |  13 ++--
 timer.go                   |  65 ++++++++++++++++
 7 files changed, 486 insertions(+), 7 deletions(-)
 create mode 100644 d_timer/sql/model.go
 create mode 100644 d_timer/sql/timer.go
 create mode 100644 d_timer/sql/timer_test.go
 create mode 100644 timer.go

diff --git a/d_timer/sql/model.go b/d_timer/sql/model.go
new file mode 100644
index 0000000..3ac8f65
--- /dev/null
+++ b/d_timer/sql/model.go
@@ -0,0 +1,60 @@
+package sql
+
+import (
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/robfig/cron"
+)
+
+var ErrTimerOverdue = fmt.Errorf("timer overdue")
+
+type TimerJob struct {
+	ID        int64       `gorm:"primaryKey;column:id;autoIncrement"`
+	Service   string      `gorm:"column:service;type:varchar(30)"`
+	Key       string      `gorm:"column:key;type:varchar(30);uniqueIndex;not null"`
+	Cron      string      `gorm:"column:cron;type:varchar(30);null"`
+	NextTime  time.Time   `gorm:"column:next_time;type:datetime;index;not null"`
+	Status    TimerStatus `gorm:"column:status;type:tinyint"`
+	Msg       string      `gorm:"column:msg;type:varchar(128)"`
+	Payload   []byte      `gorm:"column:payload;type:text"`
+	CreatedAt time.Time   `gorm:"index;type:datetime"`
+}
+
+func (t *TimerJob) TableName() string {
+	return "ddd_timer"
+}
+
+func (t *TimerJob) Next() error {
+	if t.Cron != "" {
+		return t.Reset()
+	} else {
+		t.Close(nil)
+		return nil
+	}
+}
+
+func (t *TimerJob) Reset() error {
+	if t.Cron == "" {
+		return nil
+	}
+	scheduler, err := cron.Parse(t.Cron)
+	if err != nil {
+		t.Close(err)
+		return err
+	}
+
+	t.NextTime = scheduler.Next(time.Now())
+	t.Status = TimerToRun
+	return nil
+}
+
+func (t *TimerJob) Close(err error) {
+	if err == nil || errors.Is(err, ErrTimerOverdue) {
+		t.Status = TimerFinished
+	} else {
+		t.Status = TimerFailed
+		t.Msg = err.Error()
+	}
+}
diff --git a/d_timer/sql/timer.go b/d_timer/sql/timer.go
new file mode 100644
index 0000000..792efd4
--- /dev/null
+++ b/d_timer/sql/timer.go
@@ -0,0 +1,152 @@
+package sql
+
+import (
+	"context"
+	stdlog "log"
+	"os"
+	"sync"
+	"time"
+
+	ddd "github.com/bytedance/dddfirework"
+	"github.com/go-logr/logr"
+	"github.com/go-logr/stdr"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+)
+
+const defaultInterval = time.Second
+
+var defaultLogger = stdr.New(stdlog.New(os.Stderr, "", stdlog.LstdFlags|stdlog.Lshortfile)).WithName("db_timer")
+
+type TimerStatus int
+
+const (
+	TimerToRun    TimerStatus = 1
+	TimerFinished TimerStatus = 2
+	TimerFailed   TimerStatus = 3
+)
+
+type Options struct {
+	RunInterval time.Duration
+	Logger      logr.Logger
+}
+
+type Option func(opt *Options)
+
+type DBTimer struct {
+	service string
+	db      *gorm.DB
+	cb      ddd.TimerHandler
+	opt     Options
+	logger  logr.Logger
+	once    sync.Once
+}
+
+func NewDBTimer(service string, db *gorm.DB, opts ...Option) *DBTimer {
+	if service == "" {
+		panic("service name is required")
+	}
+	opt := Options{
+		RunInterval: defaultInterval,
+		Logger:      defaultLogger,
+	}
+	for _, o := range opts {
+		o(&opt)
+	}
+	return &DBTimer{
+		service: service,
+		db:      db,
+		opt:     opt,
+		logger:  opt.Logger,
+		once:    sync.Once{},
+	}
+}
+
+func (t *DBTimer) RunCron(key, cronExp string, data []byte) error {
+	newTimer := TimerJob{
+		Service: t.service,
+		Key:     key,
+		Cron:    cronExp,
+		Payload: data,
+		Status:  TimerToRun,
+	}
+	return t.run(&newTimer)
+}
+
+func (t *DBTimer) RunOnce(key string, runTime time.Time, data []byte) error {
+	if runTime.Before(time.Now()) {
+		return ErrTimerOverdue
+	}
+
+	newTimer := TimerJob{
+		Service:  t.service,
+		Key:      key,
+		NextTime: runTime,
+		Payload:  data,
+		Status:   TimerToRun,
+	}
+	return t.run(&newTimer)
+}
+
+func (t *DBTimer) Cancel(key string) error {
+	return t.db.Unscoped().Where(TimerJob{Key: key}).Delete(&TimerJob{}).Error
+}
+
+func (t *DBTimer) run(job *TimerJob) error {
+	if err := job.Reset(); err != nil {
+		return err
+	}
+	return t.db.Where(TimerJob{
+		Service: t.service,
+		Key:     job.Key,
+	}).Attrs(job).FirstOrCreate(&TimerJob{}).Error
+}
+
+func (t *DBTimer) RegisterTimerHandler(cb ddd.TimerHandler) {
+	t.cb = cb
+}
+
+func (t *DBTimer) handleJobs(ctx context.Context) error {
+	return t.db.Transaction(func(tx *gorm.DB) error {
+		jobs := make([]*TimerJob, 0)
+		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where(
+			"service = ? and next_time <= ? and status = ?", t.service, time.Now(), TimerToRun,
+		).Find(&jobs).Error; err != nil {
+			return err
+		}
+
+		if len(jobs) == 0 {
+			return nil
+		}
+
+		for _, job := range jobs {
+			if err := t.cb(ctx, job.Key, job.Cron, job.Payload); err != nil {
+				t.logger.Error(err, "timer callback failed")
+			}
+			if err := job.Next(); err != nil {
+				job.Close(err)
+			}
+
+			if err := tx.Save(job).Error; err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+
+}
+
+func (t *DBTimer) Start(ctx context.Context) {
+	t.once.Do(func() {
+		run := func() {
+			// 定时触发event_handler
+			ticker := time.NewTicker(t.opt.RunInterval)
+			for range ticker.C {
+				if err := t.handleJobs(context.Background()); err != nil {
+					t.logger.Error(err, "handle job failed")
+				}
+			}
+		}
+		go run()
+	})
+}
diff --git a/d_timer/sql/timer_test.go b/d_timer/sql/timer_test.go
new file mode 100644
index 0000000..3dd0f15
--- /dev/null
+++ b/d_timer/sql/timer_test.go
@@ -0,0 +1,148 @@
+package sql
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/bytedance/dddfirework"
+	exec_mysql "github.com/bytedance/dddfirework/executor/mysql"
+	"github.com/bytedance/dddfirework/testsuit"
+	"github.com/stretchr/testify/assert"
+	"gorm.io/gorm"
+)
+
+func initModel(db *gorm.DB) {
+	if err := db.AutoMigrate(&TimerJob{}); err != nil {
+		panic(err)
+	}
+}
+
+func init() {
+	db := testsuit.InitMysql()
+	initModel(db)
+
+	db.Where("1 = 1").Delete(&TimerJob{})
+}
+
+func TestTimerConcurrent(t *testing.T) {
+	db := testsuit.InitMysql()
+
+	var output string
+	var times []time.Time
+	var mu sync.Mutex
+
+	testName := "test_concurrent"
+
+	callback := func(ctx context.Context, key, cron string, data []byte) error {
+		mu.Lock()
+		defer mu.Unlock()
+
+		times = append(times, time.Now())
+		output = string(data)
+		return nil
+	}
+	ctx := context.Background()
+	var timer *DBTimer
+	for i := 0; i < 5; i++ {
+		timer = NewDBTimer(testName, db, func(opt *Options) {
+			opt.RunInterval = time.Millisecond * 10
+		})
+		timer.RegisterTimerHandler(callback)
+		timer.Start(ctx)
+	}
+
+	err := timer.RunCron(testName, "0/1 * * * * ?", []byte(testName))
+	assert.NoError(t, err)
+
+	time.Sleep(time.Second * 5)
+
+	assert.Equal(t, testName, output)
+	assert.Greater(t, len(times), 2)
+	for i := 0; i < len(times)-1; i++ {
+		assert.GreaterOrEqual(t, times[i+1].Sub(times[i]), time.Millisecond*950)
+		assert.Less(t, times[i+1].Sub(times[i]), time.Millisecond*1050)
+	}
+}
+
+func TestTimerEngine(t *testing.T) {
+	db := testsuit.InitMysql()
+	ctx := context.Background()
+
+	testName := "test_engine"
+	timer := NewDBTimer(testName, db, func(opt *Options) {
+		opt.RunInterval = time.Second
+	})
+	timer.Start(ctx)
+
+	var k, c string
+	engine := dddfirework.NewEngine(nil, exec_mysql.NewExecutor(db), dddfirework.WithTimer(timer))
+	engine.RegisterCronTask(dddfirework.EventType(testName), "0/1 * * * * ?", func(key, cron string) {
+		k, c = key, cron
+	})
+
+	time.Sleep(time.Second * 2)
+
+	assert.Equal(t, testName, k)
+	assert.Equal(t, "0/1 * * * * ?", c)
+}
+
+func TestTimerCancel(t *testing.T) {
+	db := testsuit.InitMysql()
+
+	var output string
+	testName := "test_cancel"
+
+	ctx := context.Background()
+	var timer = NewDBTimer(testName, db, func(opt *Options) {
+		opt.RunInterval = time.Second
+	})
+	timer.RegisterTimerHandler(func(ctx context.Context, key, cron string, data []byte) error {
+		output = string(data)
+		return nil
+	})
+	timer.Start(ctx)
+
+	err := timer.RunCron(testName, "0/1 * * * * ?", []byte(testName))
+	assert.NoError(t, err)
+
+	time.Sleep(time.Second * 2)
+
+	assert.Equal(t, testName, output)
+
+	err = timer.Cancel(testName)
+	assert.NoError(t, err)
+
+	output = ""
+	time.Sleep(time.Second * 2)
+	assert.Empty(t, output)
+}
+
+func TestTimerOnce(t *testing.T) {
+	db := testsuit.InitMysql()
+
+	var output string
+	testName := "test_once"
+
+	ctx := context.Background()
+	var timer = NewDBTimer(testName, db, func(opt *Options) {
+		opt.RunInterval = time.Millisecond * 500
+	})
+	timer.RegisterTimerHandler(func(ctx context.Context, key, cron string, data []byte) error {
+		output = string(data)
+		return nil
+	})
+	timer.Start(ctx)
+
+	err := timer.RunOnce(testName, time.Now().Add(time.Millisecond*500), []byte(testName))
+	assert.NoError(t, err)
+
+	time.Sleep(time.Second * 1)
+
+	assert.Equal(t, testName, output)
+
+	output = ""
+	time.Sleep(time.Second * 2)
+	assert.Empty(t, output)
+}
diff --git a/engine.go b/engine.go
index 8acc328..f969d7f 100644
--- a/engine.go
+++ b/engine.go
@@ -184,6 +184,7 @@ type Options struct {
 	EventPersist    EventPersist // 是否保存领域事件到 DB
 	Logger          logr.Logger
 	EventBus        IEventBus
+	Timer           ITimer
 	IDGenerator     IIDGenerator
 	PostSaveHooks   []PostSaveFunc
 }
@@ -232,6 +233,18 @@ func WithEventBus(eventBus IEventBus) EventBusOption {
 	return EventBusOption{eventBus: eventBus}
 }
 
+type DTimerOption struct {
+	timer ITimer
+}
+
+func (t DTimerOption) ApplyToOptions(opts *Options) {
+	opts.Timer = t.timer
+}
+
+func WithTimer(timer ITimer) DTimerOption {
+	return DTimerOption{timer: timer}
+}
+
 type EventPersist func(event *DomainEvent) (IModel, error)
 
 type EventSaveOption EventPersist
@@ -271,6 +284,7 @@ type Engine struct {
 	executor    IExecutor
 	idGenerator IIDGenerator
 	eventbus    IEventBus
+	timer       ITimer
 	logger      logr.Logger
 	options     Options
 }
@@ -282,6 +296,7 @@ func NewEngine(l ILock, e IExecutor, opts ...Option) *Engine {
 		Logger:          defaultLogger,
 		IDGenerator:     &defaultIDGenerator{},
 		EventBus:        &noEventBus{},
+		Timer:           &noTimer{},
 	}
 	for _, opt := range opts {
 		opt.ApplyToOptions(&options)
@@ -291,10 +306,13 @@ func NewEngine(l ILock, e IExecutor, opts ...Option) *Engine {
 	if txEB, ok := eventBus.(ITransactionEventBus); ok {
 		txEB.RegisterEventTXChecker(onTXChecker)
 	}
+	timer := options.Timer
+	timer.RegisterTimerHandler(onTimer)
 	return &Engine{
 		locker:      l,
 		executor:    e,
 		eventbus:    eventBus,
+		timer:       timer,
 		options:     options,
 		logger:      options.Logger,
 		idGenerator: options.IDGenerator,
@@ -306,6 +324,7 @@ func (e *Engine) NewStage() *Stage {
 		locker:      e.locker,
 		executor:    e.executor,
 		eventBus:    e.eventbus,
+		timer:       e.timer,
 		idGenerator: e.idGenerator,
 		meta:        &EntityContainer{},
 		result:      &Result{},
@@ -379,6 +398,25 @@ func (e *Engine) RegisterEventHandler(eventType EventType, construct EventHandle
 	})
 }
 
+// RegisterCronTask 注册定时任务
+func (e *Engine) RegisterCronTask(key EventType, cron string, f func(key, cron string)) {
+	if e.timer == nil {
+		panic("No ITimer specified")
+	}
+	if hasEventHandler(key) {
+		panic("key has registered")
+	}
+
+	RegisterEventHandler(key, func(ctx context.Context, evt *TimerEvent) error {
+		f(evt.Key, evt.Cron)
+		return nil
+	})
+
+	if err := e.timer.RunCron(string(key), cron, nil); err != nil {
+		panic(err)
+	}
+}
+
 // Stage 取舞台的意思,表示单次运行
 type Stage struct {
 	lockKeys []string
@@ -388,6 +426,7 @@ type Stage struct {
 	locker      ILock
 	executor    IExecutor
 	eventBus    IEventBus
+	timer       ITimer
 	idGenerator IIDGenerator
 	logger      logr.Logger
 	options     Options
@@ -409,6 +448,10 @@ func (e *Stage) WithOption(opts ...Option) *Stage {
 		txEB.RegisterEventTXChecker(onTXChecker)
 	}
 	e.eventBus = eventBus
+
+	timer := e.options.Timer
+	timer.RegisterTimerHandler(onTimer)
+	e.timer = timer
 	e.logger = e.options.Logger
 	e.idGenerator = e.options.IDGenerator
 	return e
diff --git a/event.go b/event.go
index 065dd87..a46ab54 100644
--- a/event.go
+++ b/event.go
@@ -37,10 +37,12 @@ const (
 	SendTypeNormal      SendType = "normal"      // 普通事件
 	SendTypeFIFO        SendType = "FIFO"        // 保序事件,即事件以 Sender 的发送时间顺序被消费执行
 	SendTypeTransaction SendType = "transaction" // 事务事件
+	SendTypeDelay       SendType = "delay"       // 延时发送
 )
 
 type EventOption struct {
 	SendType SendType
+	SendTime time.Time // 设定发送时间
 }
 
 type EventOpt func(opt *EventOption)
@@ -114,6 +116,14 @@ func RegisterEventHandler(t EventType, handler EventHandler) {
 	})
 }
 
+func hasEventHandler(t EventType) bool {
+	eventBusMu.Lock()
+	defer eventBusMu.Unlock()
+
+	_, ok := eventRouter[t]
+	return ok
+}
+
 // RegisterEventTXChecker 注册事务反查接口
 func RegisterEventTXChecker(t EventType, checker EventTXChecker) {
 	eventBusMu.Lock()
@@ -152,6 +162,7 @@ func RegisterEventBus(eventBus IEventBus) {
 	}
 }
 
+// onEvent EventBus 的统一的回调入口
 func onEvent(ctx context.Context, evt *DomainEvent) error {
 	defaultLogger.Info("on event call", "event", evt)
 	eventBusMu.Lock()
@@ -247,7 +258,6 @@ func NewDomainEvent(event IEvent, opts ...EventOpt) *DomainEvent {
 
 type IEventBus interface {
 	// Dispatch 发送领域事件到 EventBus,该方法会在事务内被同步调用
-	// context 返回值会被传入 AfterDispatch 调用
 	// 对于每个事件,EventBus 必须要至少保证 at least once 送达
 	Dispatch(ctx context.Context, evt ...*DomainEvent) error
 
diff --git a/eventbus/mysql/eventbus.go b/eventbus/mysql/eventbus.go
index 794b496..8a53ac6 100644
--- a/eventbus/mysql/eventbus.go
+++ b/eventbus/mysql/eventbus.go
@@ -238,7 +238,7 @@ func (e *EventBus) getTX(ctx context.Context) *Transaction {
 	return nil
 }
 
-// Dispatch 框架模式下通过框架实现持久化,外部模式下手动存储
+// Dispatch ...
 func (e *EventBus) Dispatch(ctx context.Context, events ...*dddfirework.DomainEvent) error {
 	tx := e.getTX(ctx)
 	pos := make([]*EventPO, len(events))
@@ -318,11 +318,7 @@ func (e *EventBus) RegisterEventHandler(cb dddfirework.DomainEventHandler) {
 
 func (e *EventBus) initService() error {
 	service := &ServicePO{}
-	err := e.db.Where(ServicePO{Name: e.serviceName}).FirstOrCreate(service).Error
-	if err != nil {
-		return err
-	}
-	return nil
+	return e.db.Where(ServicePO{Name: e.serviceName}).FirstOrCreate(service).Error
 }
 
 func (e *EventBus) lockService(tx *gorm.DB) (*ServicePO, error) {
@@ -423,6 +419,11 @@ func (e *EventBus) dispatchEvents(ctx context.Context, eventPOs []*EventPO) (suc
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
+			defer func() {
+				if r := recover(); r != nil {
+
+				}
+			}()
 			for po := range events {
 				if err := e.cb(ctx, po.Event); err != nil {
 					failed = append(failed, po.ID)
diff --git a/timer.go b/timer.go
new file mode 100644
index 0000000..fdee7fd
--- /dev/null
+++ b/timer.go
@@ -0,0 +1,65 @@
+package dddfirework
+
+import (
+	"context"
+	"fmt"
+	"time"
+)
+
+var ErrNoEventTimerFound = fmt.Errorf("no event_timer found")
+
+type TimerHandler func(ctx context.Context, key, cron string, data []byte) error
+
+// ITimer 分布式定时器协议
+type ITimer interface {
+	// RegisterTimerHandler 注册定时任务,定时到来时候调用该回调函数
+	RegisterTimerHandler(cb TimerHandler)
+
+	// RunCron 按照 cron 语法设置定时,并在定时到达后作为参数调用定时任务回调
+	// key: 定时任务唯一标识,重复调用时不覆盖已有计时; cron: 定时配置; data: 透传数据,回调函数传入
+	RunCron(key, cron string, data []byte) error
+
+	// RunOnce 指定时间单次运行
+	// key: 定时任务唯一标识,重复调用时不覆盖已有计时; t: 执行时间; data: 透传数据,回调函数传入
+	RunOnce(key string, t time.Time, data []byte) error
+
+	// Cancel 删除某个定时
+	Cancel(key string) error
+}
+
+type noTimer struct {
+}
+
+func (d *noTimer) RunCron(key, cron string, data []byte) error {
+	return ErrNoEventTimerFound
+}
+
+func (d *noTimer) RunOnce(key string, t time.Time, data []byte) error {
+	return ErrNoEventTimerFound
+}
+
+func (d *noTimer) RegisterTimerHandler(cb TimerHandler) {
+}
+
+func (d *noTimer) Cancel(key string) error {
+	return nil
+}
+
+// TimerEvent 定时器专用的事件
+type TimerEvent struct {
+	Key     string
+	Cron    string
+	Payload []byte
+}
+
+func (e *TimerEvent) GetType() EventType {
+	return EventType(e.Key)
+}
+
+func (e *TimerEvent) GetSender() string {
+	return ""
+}
+
+func onTimer(ctx context.Context, key, cron string, data []byte) error {
+	return onEvent(ctx, NewDomainEvent(&TimerEvent{Key: key, Cron: cron, Payload: data}))
+}
-- 
GitLab