From 0bf5f6e4acdcb4f924c85c9cf9cc3b752d6e12ff Mon Sep 17 00:00:00 2001
From: qiankunli <qiankun.li@qq.com>
Date: Thu, 11 Apr 2024 17:27:46 +0800
Subject: [PATCH] feat: support renew for lock (#31)

Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>
---
 lock/db/sql.go      | 88 ++++++++++++++++++++++++++++++++++++++++++---
 lock/db/sql_test.go | 26 ++++++++++++++
 2 files changed, 110 insertions(+), 4 deletions(-)

diff --git a/lock/db/sql.go b/lock/db/sql.go
index b1679fc..9657786 100644
--- a/lock/db/sql.go
+++ b/lock/db/sql.go
@@ -22,22 +22,57 @@ import (
 	"time"
 
 	"github.com/avast/retry-go"
+	"github.com/go-logr/logr"
 	"github.com/rs/xid"
 	"gorm.io/gorm"
 
 	"github.com/bytedance/dddfirework"
+	"github.com/bytedance/dddfirework/logger/stdr"
 )
 
+var defaultLogger = stdr.NewStdr("resource_lock")
+
+const (
+	// 定时协程每隔renewInterval 去续期,renewInterval必须小于ttl,以确保在本次ttl到期前,续期定时协程能及时续上。
+	renewInterval = 1 * time.Second
+)
+
+type Options struct {
+	RenewInterval time.Duration
+	Retry         bool
+	Logger        logr.Logger
+}
+
+type Option func(opt *Options)
+
 type DBLock struct {
-	ttl time.Duration
-	db  *gorm.DB
+	ttl    time.Duration
+	db     *gorm.DB
+	logger logr.Logger
+	opt    Options
 }
 
-func NewDBLock(db *gorm.DB, ttl time.Duration) *DBLock {
-	return &DBLock{db: db, ttl: ttl}
+func NewDBLock(db *gorm.DB, ttl time.Duration, options ...Option) *DBLock {
+	opt := Options{
+		RenewInterval: renewInterval,
+		Retry:         true,
+		Logger:        defaultLogger,
+	}
+	for _, o := range options {
+		o(&opt)
+	}
+	if ttl < opt.RenewInterval {
+		panic(fmt.Sprintf("ttl can not less than %f seconds", opt.RenewInterval.Seconds()))
+	}
+	return &DBLock{db: db, ttl: ttl, logger: opt.Logger, opt: opt}
 }
 
 func (r *DBLock) Lock(ctx context.Context, key string) (keyLock interface{}, err error) {
+	if !r.opt.Retry {
+		keyLock, err = r.lock(ctx, key)
+		return
+	}
+	// 加锁失败后重试
 	err = retry.Do(
 		func() error {
 			keyLock, err = r.lock(ctx, key)
@@ -107,3 +142,48 @@ func (r *DBLock) UnLock(ctx context.Context, keyLock interface{}) error {
 	}
 	return nil
 }
+
+func (r *DBLock) update(ctx context.Context, keyLock interface{}) error {
+	l := keyLock.(*ResourceLock)
+	res := r.db.WithContext(ctx).
+		Model(&ResourceLock{}).
+		Where("`resource` = ? AND `locker_id` = ?", l.Resource, l.LockerID). // check if locker_id has been changed by others
+		UpdateColumns(ResourceLock{UpdatedAt: time.Now(), LockerID: l.LockerID})
+	if res.Error != nil {
+		return fmt.Errorf("failed to update resource %s lock: %w", l.Resource, res.Error)
+	}
+	if res.RowsAffected == 0 {
+		return fmt.Errorf("resource %s updated by others", l.Resource)
+	}
+	return nil
+}
+
+func (r *DBLock) Run(ctx context.Context, key string, fn func(ctx context.Context)) error {
+	locker, err := r.Lock(ctx, key)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		if err = r.UnLock(ctx, locker); err != nil {
+			r.logger.Error(err, fmt.Sprintf("failed to unlock %s", key))
+		}
+	}() // pass parent ctx in here, or defer will use sub ctx below
+
+	ticker := time.NewTicker(r.opt.RenewInterval)
+	// 业务函数执行完成后,会停止renew 协程
+	defer ticker.Stop()
+	subCtx, cancel := context.WithCancel(ctx)
+	go func() {
+		for range ticker.C {
+			if err = r.update(ctx, locker); err != nil {
+				// 续期失败,会通知业务函数停止执行
+				cancel()
+				r.logger.Info(fmt.Sprintf("failed to renew lock %s, error: %v", key, err))
+				break
+			}
+		}
+	}()
+
+	fn(subCtx)
+	return nil
+}
diff --git a/lock/db/sql_test.go b/lock/db/sql_test.go
index 5f2cfec..98e3a36 100644
--- a/lock/db/sql_test.go
+++ b/lock/db/sql_test.go
@@ -17,6 +17,8 @@ package db
 
 import (
 	"context"
+	"fmt"
+	"sync"
 	"testing"
 	"time"
 
@@ -60,3 +62,27 @@ func TestUnLock(t *testing.T) {
 	_, err = lock.Lock(context.Background(), "abc")
 	assert.NoError(t, err)
 }
+
+func TestRun(t *testing.T) {
+	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+	assert.NoError(t, err)
+	err = db.AutoMigrate(&ResourceLock{})
+	assert.NoError(t, err)
+	var wg sync.WaitGroup
+	wg.Add(1)
+	lock := NewDBLock(db.Debug(), 5*time.Second, func(opt *Options) {
+		opt.Retry = false
+	})
+	go func() {
+		err = lock.Run(context.Background(), "abc", func(ctx context.Context) {
+			defer wg.Done()
+			time.Sleep(10 * time.Second)
+		})
+	}()
+	time.Sleep(2 * time.Second)
+	// 会加锁失败
+	_, err = lock.Lock(context.Background(), "abc")
+	fmt.Println(err)
+	assert.Error(t, err)
+	wg.Wait()
+}
-- 
GitLab