From 0bf5f6e4acdcb4f924c85c9cf9cc3b752d6e12ff Mon Sep 17 00:00:00 2001 From: qiankunli <qiankun.li@qq.com> Date: Thu, 11 Apr 2024 17:27:46 +0800 Subject: [PATCH] feat: support renew for lock (#31) Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com> --- lock/db/sql.go | 88 ++++++++++++++++++++++++++++++++++++++++++--- lock/db/sql_test.go | 26 ++++++++++++++ 2 files changed, 110 insertions(+), 4 deletions(-) diff --git a/lock/db/sql.go b/lock/db/sql.go index b1679fc..9657786 100644 --- a/lock/db/sql.go +++ b/lock/db/sql.go @@ -22,22 +22,57 @@ import ( "time" "github.com/avast/retry-go" + "github.com/go-logr/logr" "github.com/rs/xid" "gorm.io/gorm" "github.com/bytedance/dddfirework" + "github.com/bytedance/dddfirework/logger/stdr" ) +var defaultLogger = stdr.NewStdr("resource_lock") + +const ( + // 定时协程每隔renewInterval 去续期,renewInterval必须小于ttl,以确保在本次ttl到期前,续期定时协程能及时续上。 + renewInterval = 1 * time.Second +) + +type Options struct { + RenewInterval time.Duration + Retry bool + Logger logr.Logger +} + +type Option func(opt *Options) + type DBLock struct { - ttl time.Duration - db *gorm.DB + ttl time.Duration + db *gorm.DB + logger logr.Logger + opt Options } -func NewDBLock(db *gorm.DB, ttl time.Duration) *DBLock { - return &DBLock{db: db, ttl: ttl} +func NewDBLock(db *gorm.DB, ttl time.Duration, options ...Option) *DBLock { + opt := Options{ + RenewInterval: renewInterval, + Retry: true, + Logger: defaultLogger, + } + for _, o := range options { + o(&opt) + } + if ttl < opt.RenewInterval { + panic(fmt.Sprintf("ttl can not less than %f seconds", opt.RenewInterval.Seconds())) + } + return &DBLock{db: db, ttl: ttl, logger: opt.Logger, opt: opt} } func (r *DBLock) Lock(ctx context.Context, key string) (keyLock interface{}, err error) { + if !r.opt.Retry { + keyLock, err = r.lock(ctx, key) + return + } + // 加锁失败后重试 err = retry.Do( func() error { keyLock, err = r.lock(ctx, key) @@ -107,3 +142,48 @@ func (r *DBLock) UnLock(ctx context.Context, keyLock interface{}) error { } return nil } + +func (r *DBLock) update(ctx context.Context, keyLock interface{}) error { + l := keyLock.(*ResourceLock) + res := r.db.WithContext(ctx). + Model(&ResourceLock{}). + Where("`resource` = ? AND `locker_id` = ?", l.Resource, l.LockerID). // check if locker_id has been changed by others + UpdateColumns(ResourceLock{UpdatedAt: time.Now(), LockerID: l.LockerID}) + if res.Error != nil { + return fmt.Errorf("failed to update resource %s lock: %w", l.Resource, res.Error) + } + if res.RowsAffected == 0 { + return fmt.Errorf("resource %s updated by others", l.Resource) + } + return nil +} + +func (r *DBLock) Run(ctx context.Context, key string, fn func(ctx context.Context)) error { + locker, err := r.Lock(ctx, key) + if err != nil { + return err + } + defer func() { + if err = r.UnLock(ctx, locker); err != nil { + r.logger.Error(err, fmt.Sprintf("failed to unlock %s", key)) + } + }() // pass parent ctx in here, or defer will use sub ctx below + + ticker := time.NewTicker(r.opt.RenewInterval) + // 业务函数执行完成后,会停止renew 协程 + defer ticker.Stop() + subCtx, cancel := context.WithCancel(ctx) + go func() { + for range ticker.C { + if err = r.update(ctx, locker); err != nil { + // 续期失败,会通知业务函数停止执行 + cancel() + r.logger.Info(fmt.Sprintf("failed to renew lock %s, error: %v", key, err)) + break + } + } + }() + + fn(subCtx) + return nil +} diff --git a/lock/db/sql_test.go b/lock/db/sql_test.go index 5f2cfec..98e3a36 100644 --- a/lock/db/sql_test.go +++ b/lock/db/sql_test.go @@ -17,6 +17,8 @@ package db import ( "context" + "fmt" + "sync" "testing" "time" @@ -60,3 +62,27 @@ func TestUnLock(t *testing.T) { _, err = lock.Lock(context.Background(), "abc") assert.NoError(t, err) } + +func TestRun(t *testing.T) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + assert.NoError(t, err) + err = db.AutoMigrate(&ResourceLock{}) + assert.NoError(t, err) + var wg sync.WaitGroup + wg.Add(1) + lock := NewDBLock(db.Debug(), 5*time.Second, func(opt *Options) { + opt.Retry = false + }) + go func() { + err = lock.Run(context.Background(), "abc", func(ctx context.Context) { + defer wg.Done() + time.Sleep(10 * time.Second) + }) + }() + time.Sleep(2 * time.Second) + // 会加锁失败 + _, err = lock.Lock(context.Background(), "abc") + fmt.Println(err) + assert.Error(t, err) + wg.Wait() +} -- GitLab