Skip to content
Snippets Groups Projects
Unverified Commit 65748490 authored by kirinzhong's avatar kirinzhong Committed by GitHub
Browse files

feat: 修复通过 gorm tag 自定义主键名导致删除失败的bug && 优化一下 PO schema 方案 (#15)

parent 513c9d40
No related branches found
No related tags found
No related merge requests found
...@@ -18,32 +18,21 @@ package mysql ...@@ -18,32 +18,21 @@ package mysql
import ( import (
"reflect" "reflect"
"strings" "strings"
"sync"
"gorm.io/gorm/schema"
) )
var strategy = schema.NamingStrategy{IdentifierMaxLength: 64} var schemaCache = &sync.Map{}
func hasValue(tag, attr string) bool { func hasValue(tag, attr string) bool {
parts := strings.Split(tag, ";") parts := strings.Split(tag, ";")
for _, p := range parts { for _, p := range parts {
if strings.HasPrefix(p, attr) { if strings.HasPrefix(strings.ToUpper(p), strings.ToUpper(attr)) {
return true return true
} }
} }
return false return false
} }
func getTagValue(tag, attr string) string {
parts := strings.Split(tag, ";")
for _, p := range parts {
if strings.HasPrefix(p, attr+":") {
return strings.TrimSpace(strings.TrimPrefix(p, attr+":"))
}
}
return ""
}
func diffStruct(currVal, prevVal reflect.Value) []string { func diffStruct(currVal, prevVal reflect.Value) []string {
result := make([]string, 0) result := make([]string, 0)
poType := currVal.Type() poType := currVal.Type()
...@@ -53,21 +42,11 @@ func diffStruct(currVal, prevVal reflect.Value) []string { ...@@ -53,21 +42,11 @@ func diffStruct(currVal, prevVal reflect.Value) []string {
if !fieldVal.CanInterface() { if !fieldVal.CanInterface() {
continue continue
} }
// 默认使用 gorm 的名称规则 fieldName := field.Name
fieldName := strategy.ColumnName("", field.Name)
fieldTag := field.Tag.Get("gorm") fieldTag := field.Tag.Get("gorm")
if fieldTag != "" && hasValue(fieldTag, "column") {
fieldName = getTagValue(fieldTag, "column")
}
if fieldVal.Kind() == reflect.Struct && (field.Anonymous || hasValue(fieldTag, "embedded")) { if fieldVal.Kind() == reflect.Struct && (field.Anonymous || hasValue(fieldTag, "embedded")) {
prefix := ""
if fieldTag != "" && hasValue(fieldTag, "embeddedPrefix") {
prefix = getTagValue(fieldTag, "embeddedPrefix")
}
structDiff := diffStruct(fieldVal, prevFiledVal) structDiff := diffStruct(fieldVal, prevFiledVal)
for _, k := range structDiff { result = append(result, structDiff...)
result = append(result, prefix+k)
}
} else if fieldVal.Comparable() { } else if fieldVal.Comparable() {
if fieldVal.Equal(prevFiledVal) { if fieldVal.Equal(prevFiledVal) {
continue continue
...@@ -78,7 +57,6 @@ func diffStruct(currVal, prevVal reflect.Value) []string { ...@@ -78,7 +57,6 @@ func diffStruct(currVal, prevVal reflect.Value) []string {
result = append(result, fieldName) result = append(result, fieldName)
} }
} }
result = append(result, "fake")
return result return result
} }
......
...@@ -68,10 +68,10 @@ func TestDiffModel(t *testing.T) { ...@@ -68,10 +68,10 @@ func TestDiffModel(t *testing.T) {
result := DiffModel(p1, p2) result := DiffModel(p1, p2)
assert.NotEmpty(t, result) assert.NotEmpty(t, result)
assert.Contains(t, result, "id") assert.Contains(t, result, "ID")
assert.Contains(t, result, "item_price") assert.Contains(t, result, "ItemPrice")
assert.Contains(t, result, "diff_ptr_value") assert.Contains(t, result, "DiffPtrValue")
assert.Contains(t, result, "base_created_at") assert.Contains(t, result, "CreatedAt")
assert.Contains(t, result, "slice_value") assert.Contains(t, result, "SliceValue")
assert.NotContains(t, result, "same_ptr_value") assert.NotContains(t, result, "SamePtrValue")
} }
...@@ -21,9 +21,9 @@ import ( ...@@ -21,9 +21,9 @@ import (
"reflect" "reflect"
"time" "time"
"gorm.io/gorm"
ddd "github.com/bytedance/dddfirework" ddd "github.com/bytedance/dddfirework"
"gorm.io/gorm"
"gorm.io/gorm/schema"
) )
var ErrInvalidDB = fmt.Errorf("invalid db") var ErrInvalidDB = fmt.Errorf("invalid db")
...@@ -101,14 +101,32 @@ var opMap = map[ddd.OpType]execFunc{ ...@@ -101,14 +101,32 @@ var opMap = map[ddd.OpType]execFunc{
}, },
ddd.OpDelete: func(db *gorm.DB, a *ddd.Action) error { ddd.OpDelete: func(db *gorm.DB, a *ddd.Action) error {
po := a.Models[0] po := a.Models[0]
poType := reflect.Indirect(reflect.ValueOf(po)).Type() s, err := schema.Parse(po, schemaCache, db.NamingStrategy)
newPO := reflect.New(poType).Interface() if err != nil {
return err
}
ids := make([]string, 0) if len(s.PrimaryFields) == 1 {
// 单主键支持批量删除
pk := s.PrimaryFields[0].DBName
poType := reflect.Indirect(reflect.ValueOf(po)).Type()
newPO := reflect.New(poType).Interface()
if len(a.Models) == 1 {
return db.Where(pk+" = ?", a.Models[0].GetID()).Delete(newPO).Error
}
ids := make([]string, 0)
for _, m := range a.Models {
ids = append(ids, m.GetID())
}
return db.Where(pk+" in ?", ids).Delete(newPO).Error
}
// 复合主键的一个个删
for _, m := range a.Models { for _, m := range a.Models {
ids = append(ids, m.GetID()) if err := db.Delete(m).Error; err != nil {
return err
}
} }
return db.Where("id in ?", ids).Delete(newPO).Error return nil
}, },
ddd.OpQuery: func(db *gorm.DB, a *ddd.Action) error { ddd.OpQuery: func(db *gorm.DB, a *ddd.Action) error {
res := db.Where(a.Query).Find(a.QueryResult) res := db.Where(a.Query).Find(a.QueryResult)
......
...@@ -79,13 +79,22 @@ func (o *order) GetID() string { ...@@ -79,13 +79,22 @@ func (o *order) GetID() string {
return o.ID return o.ID
} }
type basePO struct {
ID string `gorm:"primaryKey;column:uid"`
}
type orderPO struct { type orderPO struct {
ID string `gorm:"primaryKey;column:id"` BasePO basePO `gorm:"embedded"`
Title string `gorm:"column:title"` Title string `gorm:"column:title"`
} }
func newOrderPO(id string) *orderPO {
return &orderPO{BasePO: basePO{ID: id}}
}
func (o *orderPO) GetID() string { func (o *orderPO) GetID() string {
return o.ID return o.BasePO.ID
} }
func (o *orderPO) TableName() string { func (o *orderPO) TableName() string {
...@@ -142,12 +151,12 @@ func initModel(db *gorm.DB) { ...@@ -142,12 +151,12 @@ func initModel(db *gorm.DB) {
RegisterEntity2Model(&order{}, func(entity, parent ddd.IEntity, op ddd.OpType) (IModel, error) { RegisterEntity2Model(&order{}, func(entity, parent ddd.IEntity, op ddd.OpType) (IModel, error) {
order := entity.(*order) order := entity.(*order)
return &orderPO{ return &orderPO{
ID: order.GetID(), BasePO: basePO{ID: order.GetID()},
Title: order.Title, Title: order.Title,
}, nil }, nil
}, func(po IModel, do ddd.IEntity) error { }, func(po IModel, do ddd.IEntity) error {
s, t := po.(*orderPO), do.(*order) s, t := po.(*orderPO), do.(*order)
t.SetID(s.ID) t.SetID(s.GetID())
t.Title = s.Title t.Title = s.Title
return nil return nil
}) })
...@@ -287,12 +296,12 @@ func TestExecutor(t *testing.T) { ...@@ -287,12 +296,12 @@ func TestExecutor(t *testing.T) {
res := engine.Create(ctx, testOrder, testDeleteOrder) res := engine.Create(ctx, testOrder, testDeleteOrder)
assert.NoError(t, res.Error) assert.NoError(t, res.Error)
assert.NoError(t, db.First(&orderPO{}, "id = ?", "order1").Error) assert.NoError(t, db.First(newOrderPO("order1")).Error)
assert.NoError(t, db.First(&productPO{}, "id = ?", "product1").Error) assert.NoError(t, db.First(&productPO{ID: "product1"}).Error)
assert.NoError(t, db.First(&productPO{}, "id = ?", "product2").Error) assert.NoError(t, db.First(&productPO{ID: "product2"}).Error)
assert.NoError(t, db.First(&orderPO{}, "id = ?", "order2").Error) assert.NoError(t, db.First(newOrderPO("order2")).Error)
res = engine.RunCommand(ctx, &Case{ res = engine.Run(ctx, &Case{
db: db, db: db,
TestOrderID: "order1", TestOrderID: "order1",
TestDeleteID: "order2", TestDeleteID: "order2",
...@@ -312,8 +321,8 @@ func TestExecutor(t *testing.T) { ...@@ -312,8 +321,8 @@ func TestExecutor(t *testing.T) {
assert.Len(t, p.Tags, 2) assert.Len(t, p.Tags, 2)
// 测试删除根实体、子实体 // 测试删除根实体、子实体
assert.ErrorIs(t, db.First(&orderPO{}, "id = ?", "order2").Error, gorm.ErrRecordNotFound) assert.ErrorIs(t, db.First(newOrderPO("order2")).Error, gorm.ErrRecordNotFound)
assert.ErrorIs(t, db.First(&productPO{}, "id = ?", "product2").Error, gorm.ErrRecordNotFound) assert.ErrorIs(t, db.First(&productPO{ID: "product2"}).Error, gorm.ErrRecordNotFound)
} }
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
...@@ -327,12 +336,12 @@ func TestDelete(t *testing.T) { ...@@ -327,12 +336,12 @@ func TestDelete(t *testing.T) {
engine := ddd.NewEngine(testsuit.NewMemLock(), NewExecutor(db)) engine := ddd.NewEngine(testsuit.NewMemLock(), NewExecutor(db))
res := engine.Create(ctx, o) res := engine.Create(ctx, o)
assert.NoError(t, res.Error) assert.NoError(t, res.Error)
assert.NoError(t, db.First(&orderPO{}, "id = ?", o.ID).Error) assert.NoError(t, db.First(newOrderPO(o.ID)).Error)
res = engine.Delete(ctx, o) res = engine.Delete(ctx, o)
assert.NoError(t, res.Error) assert.NoError(t, res.Error)
assert.ErrorIs(t, db.First(&orderPO{}, "id = ?", o.ID).Error, gorm.ErrRecordNotFound) assert.ErrorIs(t, db.First(newOrderPO(o.ID)).Error, gorm.ErrRecordNotFound)
} }
func TestRollback(t *testing.T) { func TestRollback(t *testing.T) {
...@@ -352,7 +361,7 @@ func TestRollback(t *testing.T) { ...@@ -352,7 +361,7 @@ func TestRollback(t *testing.T) {
return nil return nil
}) })
assert.NotNil(t, res.Error) assert.NotNil(t, res.Error)
assert.ErrorIs(t, db.First(&orderPO{}, "id = ?", "testrollback").Error, gorm.ErrRecordNotFound) assert.ErrorIs(t, db.First(newOrderPO("testrollback")).Error, gorm.ErrRecordNotFound)
} }
func initBenchmark(db *gorm.DB) { func initBenchmark(db *gorm.DB) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment