diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index 7916507c1..66236cba0 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -16,6 +16,7 @@ func genFindOne(table Table, withCache, postgreSql bool) (string, string, error) output, err := util.With("findOne"). Parse(text). + AddFunc("hasField", hasField(table)). Execute(map[string]any{ "withCache": withCache, "upperStartCamelObject": camel, diff --git a/tools/goctl/model/sql/gen/findonebyfield.go b/tools/goctl/model/sql/gen/findonebyfield.go index e11284d0b..1581509bf 100644 --- a/tools/goctl/model/sql/gen/findonebyfield.go +++ b/tools/goctl/model/sql/gen/findonebyfield.go @@ -22,7 +22,7 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e return nil, err } - t := util.With("findOneByField").Parse(text) + t := util.With("findOneByField").Parse(text).AddFunc("hasField", hasField(table)) var list []string camelTableName := table.Name.ToCamel() for _, key := range table.UniqueCacheKey { @@ -54,7 +54,7 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e return nil, err } - t = util.With("findOneByFieldMethod").Parse(text) + t = util.With("findOneByFieldMethod").Parse(text).AddFunc("hasField", hasField(table)) var listMethod []string for _, key := range table.UniqueCacheKey { var inJoin, paramJoin Join @@ -88,7 +88,7 @@ func genFindOneByField(table Table, withCache, postgreSql bool) (*findOneCode, e return nil, err } - out, err := util.With("findOneByFieldExtraMethod").Parse(text).Execute(map[string]any{ + out, err := util.With("findOneByFieldExtraMethod").AddFunc("hasField", hasField(table)).Parse(text).Execute(map[string]any{ "upperStartCamelObject": camelTableName, "primaryKeyLeft": table.PrimaryCacheKey.VarLeft, "lowerStartCamelObject": stringx.From(camelTableName).Untitle(), diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 83056919e..0a95976fa 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -360,6 +360,7 @@ func (g *defaultGenerator) genModelCustom(in parser.Table, withCache bool) (stri t := util.With("model-custom"). Parse(text). + AddFunc("hasField", hasField(Table{Table: in})). GoFmt(true) output, err := t.Execute(map[string]any{ "pkg": g.pkg, @@ -381,6 +382,7 @@ func (g *defaultGenerator) executeModel(table Table, code *code) (*bytes.Buffer, } t := util.With("model"). Parse(text). + AddFunc("hasField", hasField(table)). GoFmt(true) output, err := t.Execute(map[string]any{ "pkg": g.pkg, diff --git a/tools/goctl/model/sql/gen/imports.go b/tools/goctl/model/sql/gen/imports.go index 81980fca5..af7c61c69 100644 --- a/tools/goctl/model/sql/gen/imports.go +++ b/tools/goctl/model/sql/gen/imports.go @@ -28,7 +28,7 @@ func genImports(table Table, withCache, timeImport bool) (string, error) { return "", err } - buffer, err := util.With("import").Parse(text).Execute(map[string]any{ + buffer, err := util.With("import").Parse(text).AddFunc("hasField", hasField(table)).Execute(map[string]any{ "time": timeImport, "containsPQ": table.ContainsPQ, "data": table, diff --git a/tools/goctl/model/sql/gen/template.go b/tools/goctl/model/sql/gen/template.go index 0f5c3e8db..a5a94b76f 100644 --- a/tools/goctl/model/sql/gen/template.go +++ b/tools/goctl/model/sql/gen/template.go @@ -94,3 +94,16 @@ func Update() error { return pathx.InitTemplates(category, templates) } + +// hasField returns a function that checks if a field exists in the table. +// It uses a pre-built map for O(1) lookup performance. +func hasField(table Table) func(string) bool { + fieldSet := make(map[string]struct{}, len(table.Fields)) + for _, field := range table.Fields { + fieldSet[field.NameOriginal] = struct{}{} + } + return func(f string) bool { + _, ok := fieldSet[f] + return ok + } +} diff --git a/tools/goctl/model/sql/gen/template_test.go b/tools/goctl/model/sql/gen/template_test.go index 206ed565e..463af3c5f 100644 --- a/tools/goctl/model/sql/gen/template_test.go +++ b/tools/goctl/model/sql/gen/template_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/tools/goctl/model/sql/parser" "github.com/zeromicro/go-zero/tools/goctl/model/sql/template" "github.com/zeromicro/go-zero/tools/goctl/util/pathx" + "github.com/zeromicro/go-zero/tools/goctl/util/stringx" ) func TestGenTemplates(t *testing.T) { @@ -91,3 +93,151 @@ func TestUpdate(t *testing.T) { assert.Nil(t, err) assert.Equal(t, template.New, string(data)) } + +func TestHasField(t *testing.T) { + tests := []struct { + name string + table Table + fieldName string + wantResult bool + }{ + { + name: "field exists", + table: Table{ + Table: parser.Table{ + Fields: []*parser.Field{ + {NameOriginal: "id"}, + {NameOriginal: "name"}, + {NameOriginal: "created_at"}, + }, + }, + }, + fieldName: "name", + wantResult: true, + }, + { + name: "field does not exist", + table: Table{ + Table: parser.Table{ + Fields: []*parser.Field{ + {NameOriginal: "id"}, + {NameOriginal: "name"}, + }, + }, + }, + fieldName: "email", + wantResult: false, + }, + { + name: "empty table", + table: Table{ + Table: parser.Table{ + Fields: []*parser.Field{}, + }, + }, + fieldName: "id", + wantResult: false, + }, + { + name: "case sensitive match", + table: Table{ + Table: parser.Table{ + Fields: []*parser.Field{ + {NameOriginal: "ID"}, + {NameOriginal: "Name"}, + }, + }, + }, + fieldName: "id", + wantResult: false, + }, + { + name: "exact match required", + table: Table{ + Table: parser.Table{ + Fields: []*parser.Field{ + {NameOriginal: "user_name"}, + }, + }, + }, + fieldName: "user_name", + wantResult: true, + }, + { + name: "partial match should fail", + table: Table{ + Table: parser.Table{ + Fields: []*parser.Field{ + {NameOriginal: "user_name"}, + }, + }, + }, + fieldName: "user", + wantResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn := hasField(tt.table) + result := fn(tt.fieldName) + assert.Equal(t, tt.wantResult, result) + }) + } +} + +func TestHasFieldWithRealTable(t *testing.T) { + // Create a realistic table structure + table := Table{ + Table: parser.Table{ + Name: stringx.From("users"), + Fields: []*parser.Field{ + {NameOriginal: "id", DataType: "int64"}, + {NameOriginal: "username", DataType: "string"}, + {NameOriginal: "email", DataType: "string"}, + {NameOriginal: "password", DataType: "string"}, + {NameOriginal: "created_at", DataType: "time.Time"}, + {NameOriginal: "updated_at", DataType: "time.Time"}, + }, + }, + } + + fn := hasField(table) + + // Test all existing fields + assert.True(t, fn("id")) + assert.True(t, fn("username")) + assert.True(t, fn("email")) + assert.True(t, fn("password")) + assert.True(t, fn("created_at")) + assert.True(t, fn("updated_at")) + + // Test non-existing fields + assert.False(t, fn("deleted_at")) + assert.False(t, fn("ID")) + assert.False(t, fn("Username")) + assert.False(t, fn("")) +} + +func TestHasFieldPerformance(t *testing.T) { + // Create a table with many fields to test performance optimization + var fields []*parser.Field + for i := 0; i < 1000; i++ { + fields = append(fields, &parser.Field{ + NameOriginal: "field_" + string(rune('0'+i%10)) + string(rune('a'+i%26)), + }) + } + + table := Table{ + Table: parser.Table{ + Fields: fields, + }, + } + + fn := hasField(table) + + // Verify the function works correctly + assert.True(t, fn(fields[0].NameOriginal)) + assert.True(t, fn(fields[999].NameOriginal)) + assert.False(t, fn("non_existent_field")) +} diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index f1cb0a1d5..c05efc7c9 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -61,7 +61,7 @@ func genUpdate(table Table, withCache, postgreSql bool) ( return "", "", err } - output, err := util.With("update").Parse(text).Execute( + output, err := util.With("update").Parse(text).AddFunc("hasField", hasField(table)).Execute( map[string]any{ "withCache": withCache, "containsIndexCache": table.ContainsUniqueCacheKey, @@ -94,7 +94,7 @@ func genUpdate(table Table, withCache, postgreSql bool) ( return "", "", err } - updateMethodOutput, err := util.With("updateMethod").Parse(text).Execute( + updateMethodOutput, err := util.With("updateMethod").Parse(text).AddFunc("hasField", hasField(table)).Execute( map[string]any{ "upperStartCamelObject": camelTableName, "data": table, diff --git a/tools/goctl/util/templatex.go b/tools/goctl/util/templatex.go index 726544c52..89f17573b 100644 --- a/tools/goctl/util/templatex.go +++ b/tools/goctl/util/templatex.go @@ -15,15 +15,17 @@ const regularPerm = 0o666 // DefaultTemplate is a tool to provides the text/template operations type DefaultTemplate struct { - name string - text string - goFmt bool + name string + text string + goFmt bool + funcMap template.FuncMap } // With returns an instance of DefaultTemplate func With(name string) *DefaultTemplate { return &DefaultTemplate{ - name: name, + name: name, + funcMap: make(template.FuncMap), } } @@ -55,7 +57,11 @@ func (t *DefaultTemplate) SaveTo(data any, path string, forceUpdate bool) error // Execute returns the codes after the template executed func (t *DefaultTemplate) Execute(data any) (*bytes.Buffer, error) { - tem, err := template.New(t.name).Parse(t.text) + tmp := template.New(t.name) + if len(t.funcMap) > 0 { + tmp.Funcs(t.funcMap) + } + tem, err := tmp.Parse(t.text) if err != nil { return nil, errorx.Wrap(err, "template parse error:", t.text) } @@ -79,6 +85,16 @@ func (t *DefaultTemplate) Execute(data any) (*bytes.Buffer, error) { return buf, nil } +// AddFunc adds a template function. It returns the template instance for chaining. +// If funcName is empty or function is nil, it returns the template without modification. +func (t *DefaultTemplate) AddFunc(funcName string, function any) *DefaultTemplate { + if funcName == "" || function == nil { + return t + } + t.funcMap[funcName] = function + return t +} + // IsTemplateVariable returns true if the text is a template variable. // The text must start with a dot and be a valid template. func IsTemplateVariable(text string) bool {