diff --git a/go.mod b/go.mod index 27e585835..0f57f9da1 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/stretchr/testify v1.5.1 github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect github.com/urfave/cli v1.22.4 + github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb // indirect go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698 go.uber.org/automaxprocs v1.3.0 diff --git a/go.sum b/go.sum index 2fc6e4e90..1501ac069 100644 --- a/go.sum +++ b/go.sum @@ -266,6 +266,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5 github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6 h1:YdYsPAZ2pC6Tow/nPZOPQ96O3hm/ToAkGsPLzedXERk= github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= +github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ= +github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb h1:ZkM6LRnq40pR1Ox0hTHlnpkcOTuFIDQpZ1IN8rKKhX0= github.com/yuin/gopher-lua v0.0.0-20191220021717-ab39c6098bdb/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 9a7aba991..41a868c8c 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -16,6 +16,7 @@ import ( "github.com/tal-tech/go-zero/tools/goctl/configgen" "github.com/tal-tech/go-zero/tools/goctl/docker" "github.com/tal-tech/go-zero/tools/goctl/feature" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/command" "github.com/urfave/cli" ) @@ -172,14 +173,19 @@ var ( Usage: "generate sql model", Flags: []cli.Flag{ cli.StringFlag{ - Name: "config, c", - Usage: "the file that contains main function", + Name: "src, s", + Usage: "the file path of sql", }, cli.StringFlag{ Name: "dir, d", Usage: "the target dir", }, + cli.BoolFlag{ + Name: "cache, c", + Usage: "generate code with cache", + }, }, + Action: command.Mysql, }, { Name: "config", diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go new file mode 100644 index 000000000..461ffbce4 --- /dev/null +++ b/tools/goctl/model/sql/command/command.go @@ -0,0 +1,14 @@ +package command + +import ( + "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" + "github.com/urfave/cli" +) + +func Mysql(ctx *cli.Context) error { + src := ctx.String("src") + dir := ctx.String("dir") + cache := ctx.Bool("cache") + generator := gen.NewDefaultGenerator(src, dir) + return generator.Start(cache) +} diff --git a/tools/goctl/model/sql/converter/types.go b/tools/goctl/model/sql/converter/types.go new file mode 100644 index 000000000..c6f931053 --- /dev/null +++ b/tools/goctl/model/sql/converter/types.go @@ -0,0 +1,46 @@ +package converter + +import ( + "fmt" + "strings" +) + +var ( + commonMysqlDataTypeMap = map[string]string{ + // For consistency, all integer types are converted to int64 + "tinyint": "int64", + "smallint": "int64", + "mediumint": "int64", + "int": "int64", + "integer": "int64", + "bigint": "int64", + "float": "float64", + "double": "float64", + "decimal": "float64", + "date": "time.Time", + "time": "string", + "year": "int64", + "datetime": "time.Time", + "timestamp": "time.Time", + "char": "string", + "varchar": "string", + "tinyblob": "string", + "tinytext": "string", + "blob": "string", + "text": "string", + "mediumblob": "string", + "mediumtext": "string", + "longblob": "string", + "longtext": "string", + } +) + +func ConvertDataType(dataBaseType string) (goDataType string, err error) { + tp, ok := commonMysqlDataTypeMap[strings.ToLower(dataBaseType)] + if !ok { + err = fmt.Errorf("unexpected database type: %s", dataBaseType) + return + } + goDataType = tp + return +} diff --git a/tools/goctl/model/sql/gen/convert.go b/tools/goctl/model/sql/gen/convert.go deleted file mode 100644 index 239656aba..000000000 --- a/tools/goctl/model/sql/gen/convert.go +++ /dev/null @@ -1,108 +0,0 @@ -package gen - -import ( - "errors" - "fmt" - "sort" - "strings" - - "github.com/tal-tech/go-zero/tools/goctl/model/sql/util" -) - -func TableConvert(outerTable OuterTable) (*InnerTable, error) { - var table InnerTable - table.CreateNotFound = outerTable.CreateNotFound - tableSnakeCase, tableUpperCamelCase, tableLowerCamelCase := util.FormatField(outerTable.Table) - table.SnakeCase = tableSnakeCase - table.UpperCamelCase = tableUpperCamelCase - table.LowerCamelCase = tableLowerCamelCase - fields := make([]*InnerField, 0) - var primaryField *InnerField - conflict := make(map[string]struct{}) - var containsCache bool - for _, field := range outerTable.Fields { - if field.Cache && !containsCache { - containsCache = true - } - fieldSnakeCase, fieldUpperCamelCase, fieldLowerCamelCase := util.FormatField(field.Name) - tag, err := genTag(fieldSnakeCase) - if err != nil { - return nil, err - } - var comment string - if field.Comment != "" { - comment = fmt.Sprintf("// %s", field.Comment) - } - withFields := make([]InnerWithField, 0) - unique := make([]string, 0) - unique = append(unique, fmt.Sprintf("%v", field.QueryType)) - unique = append(unique, field.Name) - - for _, item := range field.WithFields { - unique = append(unique, item.Name) - withFieldSnakeCase, withFieldUpperCamelCase, withFieldLowerCamelCase := util.FormatField(item.Name) - withFields = append(withFields, InnerWithField{ - Case: Case{ - SnakeCase: withFieldSnakeCase, - LowerCamelCase: withFieldLowerCamelCase, - UpperCamelCase: withFieldUpperCamelCase, - }, - DataType: commonMysqlDataTypeMap[item.DataBaseType], - }) - } - sort.Strings(unique) - uniqueKey := strings.Join(unique, "#") - if _, ok := conflict[uniqueKey]; ok { - return nil, ErrCircleQuery - } else { - conflict[uniqueKey] = struct{}{} - } - sortFields := make([]InnerSort, 0) - for _, sortField := range field.OuterSort { - sortSnake, sortUpperCamelCase, sortLowerCamelCase := util.FormatField(sortField.Field) - sortFields = append(sortFields, InnerSort{ - Field: Case{ - SnakeCase: sortSnake, - LowerCamelCase: sortUpperCamelCase, - UpperCamelCase: sortLowerCamelCase, - }, - Asc: sortField.Asc, - }) - } - innerField := &InnerField{ - IsPrimaryKey: field.IsPrimaryKey, - InnerWithField: InnerWithField{ - Case: Case{ - SnakeCase: fieldSnakeCase, - LowerCamelCase: fieldLowerCamelCase, - UpperCamelCase: fieldUpperCamelCase, - }, - DataType: commonMysqlDataTypeMap[field.DataBaseType], - }, - DataBaseType: field.DataBaseType, - Tag: tag, - Comment: comment, - Cache: field.Cache, - QueryType: field.QueryType, - WithFields: withFields, - Sort: sortFields, - } - if field.IsPrimaryKey { - primaryField = innerField - } - fields = append(fields, innerField) - } - if primaryField == nil { - return nil, errors.New("please ensure that primary exists") - } - table.ContainsCache = containsCache - primaryField.Cache = containsCache - table.PrimaryField = primaryField - table.Fields = fields - cacheKey, err := genCacheKeys(&table) - if err != nil { - return nil, err - } - table.CacheKey = cacheKey - return &table, nil -} diff --git a/tools/goctl/model/sql/gen/delete.go b/tools/goctl/model/sql/gen/delete.go index 23343936b..872a17945 100644 --- a/tools/goctl/model/sql/gen/delete.go +++ b/tools/goctl/model/sql/gen/delete.go @@ -1,51 +1,46 @@ package gen import ( - "bytes" "strings" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/core/collection" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genDelete(table *InnerTable) (string, error) { - t, err := template.New("delete").Parse(sqltemplate.Delete) - if err != nil { - return "", nil - } - deleteBuffer := new(bytes.Buffer) - keys := make([]string, 0) - keyValues := make([]string, 0) - for snake, key := range table.CacheKey { - if snake == table.PrimaryField.SnakeCase { - keys = append(keys, key.Key) - } else { - keys = append(keys, key.DataKey) +func genDelete(table Table, withCache bool) (string, error) { + keySet := collection.NewSet() + keyVariableSet := collection.NewSet() + for fieldName, key := range table.CacheKey { + keySet.AddStr(key.KeyExpression) + if fieldName != table.PrimaryKey.Name.Source() { + keySet.AddStr(key.DataKeyExpression) } - keyValues = append(keyValues, key.KeyVariable) + keyVariableSet.AddStr(key.Variable) } - var isOnlyPrimaryKeyCache = true + var containsIndexCache = false for _, item := range table.Fields { - if item.IsPrimaryKey { - continue - } - if item.Cache { - isOnlyPrimaryKeyCache = false + if item.IsKey { + containsIndexCache = true break } } - err = t.Execute(deleteBuffer, map[string]interface{}{ - "upperObject": table.UpperCamelCase, - "containsCache": table.ContainsCache, - "isNotPrimaryKey": !isOnlyPrimaryKeyCache, - "lowerPrimaryKey": table.PrimaryField.LowerCamelCase, - "dataType": table.PrimaryField.DataType, - "keys": strings.Join(keys, "\r\n"), - "snakePrimaryKey": table.PrimaryField.SnakeCase, - "keyValues": strings.Join(keyValues, ", "), - }) + camel := table.Name.Snake2Camel() + output, err := templatex.With("delete"). + Parse(template.Delete). + Execute(map[string]interface{}{ + "upperStartCamelObject": camel, + "withCache": withCache, + "containsIndexCache": containsIndexCache, + "lowerStartCamelPrimaryKey": stringx.From(camel).LowerStart(), + "dataType": table.PrimaryKey.DataType, + "keys": strings.Join(keySet.KeysStr(), "\n"), + "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "keyValues": strings.Join(keyVariableSet.KeysStr(), ", "), + }) if err != nil { return "", err } - return deleteBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/field.go b/tools/goctl/model/sql/gen/field.go index fe8413c5a..d49882221 100644 --- a/tools/goctl/model/sql/gen/field.go +++ b/tools/goctl/model/sql/gen/field.go @@ -1,15 +1,15 @@ package gen import ( - "bytes" "strings" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genFields(fields []*InnerField) (string, error) { - list := make([]string, 0) +func genFields(fields []parser.Field) (string, error) { + var list []string for _, field := range fields { result, err := genField(field) if err != nil { @@ -17,23 +17,20 @@ func genFields(fields []*InnerField) (string, error) { } list = append(list, result) } - return strings.Join(list, "\r\n"), nil + return strings.Join(list, "\n"), nil } -func genField(field *InnerField) (string, error) { - t, err := template.New("types").Parse(sqltemplate.Field) - if err != nil { - return "", nil - } - var typeBuffer = new(bytes.Buffer) - err = t.Execute(typeBuffer, map[string]string{ - "name": field.UpperCamelCase, - "type": field.DataType, - "tag": field.Tag, - "comment": field.Comment, - }) +func genField(field parser.Field) (string, error) { + output, err := templatex.With("types"). + Parse(template.Field). + Execute(map[string]string{ + "name": field.Name.Snake2Camel(), + "type": field.DataType, + "tag": field.Name.Source(), + "comment": field.Comment, + }) if err != nil { return "", err } - return typeBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/findallbyfield.go b/tools/goctl/model/sql/gen/findallbyfield.go deleted file mode 100644 index 9e1c6be2e..000000000 --- a/tools/goctl/model/sql/gen/findallbyfield.go +++ /dev/null @@ -1,55 +0,0 @@ -package gen - -import ( - "bytes" - "strings" - "text/template" - - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" -) - -func genFindAllByField(table *InnerTable) (string, error) { - t, err := template.New("findAllByField").Parse(sqltemplate.FindAllByField) - if err != nil { - return "", err - } - list := make([]string, 0) - for _, field := range table.Fields { - if field.IsPrimaryKey { - continue - } - if field.QueryType != QueryAll { - continue - } - fineOneByFieldBuffer := new(bytes.Buffer) - upperFields := make([]string, 0) - in := make([]string, 0) - expressionFields := make([]string, 0) - expressionValuesFields := make([]string, 0) - upperFields = append(upperFields, field.UpperCamelCase) - in = append(in, field.LowerCamelCase+" "+field.DataType) - expressionFields = append(expressionFields, field.SnakeCase+" = ?") - expressionValuesFields = append(expressionValuesFields, field.LowerCamelCase) - for _, withField := range field.WithFields { - upperFields = append(upperFields, withField.UpperCamelCase) - in = append(in, withField.LowerCamelCase+" "+withField.DataType) - expressionFields = append(expressionFields, withField.SnakeCase+" = ?") - expressionValuesFields = append(expressionValuesFields, withField.LowerCamelCase) - } - err = t.Execute(fineOneByFieldBuffer, map[string]interface{}{ - "in": strings.Join(in, ","), - "upperObject": table.UpperCamelCase, - "upperFields": strings.Join(upperFields, "And"), - "lowerObject": table.LowerCamelCase, - "snakePrimaryKey": field.SnakeCase, - "expression": strings.Join(expressionFields, " AND "), - "expressionValues": strings.Join(expressionValuesFields, ", "), - "containsCache": table.ContainsCache, - }) - if err != nil { - return "", err - } - list = append(list, fineOneByFieldBuffer.String()) - } - return strings.Join(list, ""), nil -} diff --git a/tools/goctl/model/sql/gen/findallbylimit.go b/tools/goctl/model/sql/gen/findallbylimit.go deleted file mode 100644 index de7e5dbf3..000000000 --- a/tools/goctl/model/sql/gen/findallbylimit.go +++ /dev/null @@ -1,63 +0,0 @@ -package gen - -import ( - "bytes" - "strings" - "text/template" - - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" -) - -func genFindLimitByField(table *InnerTable) (string, error) { - t, err := template.New("findLimitByField").Parse(sqltemplate.FindLimitByField) - if err != nil { - return "", err - } - list := make([]string, 0) - for _, field := range table.Fields { - if field.IsPrimaryKey { - continue - } - if field.QueryType != QueryLimit { - continue - } - fineOneByFieldBuffer := new(bytes.Buffer) - upperFields := make([]string, 0) - in := make([]string, 0) - expressionFields := make([]string, 0) - expressionValuesFields := make([]string, 0) - upperFields = append(upperFields, field.UpperCamelCase) - in = append(in, field.LowerCamelCase+" "+field.DataType) - expressionFields = append(expressionFields, field.SnakeCase+" = ?") - expressionValuesFields = append(expressionValuesFields, field.LowerCamelCase) - for _, withField := range field.WithFields { - upperFields = append(upperFields, withField.UpperCamelCase) - in = append(in, withField.LowerCamelCase+" "+withField.DataType) - expressionFields = append(expressionFields, withField.SnakeCase+" = ?") - expressionValuesFields = append(expressionValuesFields, withField.LowerCamelCase) - } - sortList := make([]string, 0) - for _, item := range field.Sort { - var sort = "ASC" - if !item.Asc { - sort = "DESC" - } - sortList = append(sortList, item.Field.SnakeCase+" "+sort) - } - err = t.Execute(fineOneByFieldBuffer, map[string]interface{}{ - "in": strings.Join(in, ","), - "upperObject": table.UpperCamelCase, - "upperFields": strings.Join(upperFields, "And"), - "lowerObject": table.LowerCamelCase, - "expression": strings.Join(expressionFields, " AND "), - "expressionValues": strings.Join(expressionValuesFields, ", "), - "sortExpression": strings.Join(sortList, ","), - "containsCache": table.ContainsCache, - }) - if err != nil { - return "", err - } - list = append(list, fineOneByFieldBuffer.String()) - } - return strings.Join(list, ""), nil -} diff --git a/tools/goctl/model/sql/gen/findone.go b/tools/goctl/model/sql/gen/findone.go index 080928b7e..b7b1a1b5c 100644 --- a/tools/goctl/model/sql/gen/findone.go +++ b/tools/goctl/model/sql/gen/findone.go @@ -1,30 +1,27 @@ package gen import ( - "bytes" - "text/template" - - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genFindOne(table *InnerTable) (string, error) { - t, err := template.New("findOne").Parse(sqltemplate.FindOne) +func genFindOne(table Table, withCache bool) (string, error) { + camel := table.Name.Snake2Camel() + output, err := templatex.With("findOne"). + Parse(template.FindOne). + Execute(map[string]interface{}{ + "withCache": withCache, + "upperStartCamelObject": camel, + "lowerStartCamelObject": stringx.From(camel).LowerStart(), + "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.Snake2Camel()).LowerStart(), + "dataType": table.PrimaryKey.DataType, + "cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression, + "cacheKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable, + }) if err != nil { return "", err } - fineOneBuffer := new(bytes.Buffer) - err = t.Execute(fineOneBuffer, map[string]interface{}{ - "withCache": table.PrimaryField.Cache, - "upperObject": table.UpperCamelCase, - "lowerObject": table.LowerCamelCase, - "snakePrimaryKey": table.PrimaryField.SnakeCase, - "lowerPrimaryKey": table.PrimaryField.LowerCamelCase, - "dataType": table.PrimaryField.DataType, - "cacheKey": table.CacheKey[table.PrimaryField.SnakeCase].Key, - "cacheKeyVariable": table.CacheKey[table.PrimaryField.SnakeCase].KeyVariable, - }) - if err != nil { - return "", err - } - return fineOneBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/fineonebyfield.go b/tools/goctl/model/sql/gen/fineonebyfield.go index 2fbc4f37c..06f547874 100644 --- a/tools/goctl/model/sql/gen/fineonebyfield.go +++ b/tools/goctl/model/sql/gen/fineonebyfield.go @@ -1,67 +1,41 @@ package gen import ( - "bytes" + "fmt" "strings" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genFineOneByField(table *InnerTable) (string, error) { - t, err := template.New("findOneByField").Parse(sqltemplate.FindOneByField) - if err != nil { - return "", err - } - list := make([]string, 0) +func genFineOneByField(table Table, withCache bool) (string, error) { + t := templatex.With("findOneByField").Parse(template.FindOneByField) + var list []string + camelTableName := table.Name.Snake2Camel() for _, field := range table.Fields { if field.IsPrimaryKey { continue } - if field.QueryType != QueryOne { - continue - } - fineOneByFieldBuffer := new(bytes.Buffer) - upperFields := make([]string, 0) - in := make([]string, 0) - expressionFields := make([]string, 0) - expressionValuesFields := make([]string, 0) - upperFields = append(upperFields, field.UpperCamelCase) - in = append(in, field.LowerCamelCase+" "+field.DataType) - expressionFields = append(expressionFields, field.SnakeCase+" = ?") - expressionValuesFields = append(expressionValuesFields, field.LowerCamelCase) - for _, withField := range field.WithFields { - upperFields = append(upperFields, withField.UpperCamelCase) - in = append(in, withField.LowerCamelCase+" "+withField.DataType) - expressionFields = append(expressionFields, withField.SnakeCase+" = ?") - expressionValuesFields = append(expressionValuesFields, withField.LowerCamelCase) - } - err = t.Execute(fineOneByFieldBuffer, map[string]interface{}{ - "in": strings.Join(in, ","), - "upperObject": table.UpperCamelCase, - "upperFields": strings.Join(upperFields, "And"), - "onlyOneFiled": len(field.WithFields) == 0, - "withCache": field.Cache, - "containsCache": table.ContainsCache, - "lowerObject": table.LowerCamelCase, - "lowerField": field.LowerCamelCase, - "snakeField": field.SnakeCase, - "lowerPrimaryKey": table.PrimaryField.LowerCamelCase, - "UpperPrimaryKey": table.PrimaryField.UpperCamelCase, - "primaryKeyDefine": table.CacheKey[table.PrimaryField.SnakeCase].Define, - "primarySnakeField": table.PrimaryField.SnakeCase, - "primaryDataType": table.PrimaryField.DataType, - "primaryDataTypeString": table.PrimaryField.DataType == "string", - "upperObjectKey": table.PrimaryField.UpperCamelCase, - "cacheKey": table.CacheKey[field.SnakeCase].Key, - "cacheKeyVariable": table.CacheKey[field.SnakeCase].KeyVariable, - "expression": strings.Join(expressionFields, " AND "), - "expressionValues": strings.Join(expressionValuesFields, ", "), + camelFieldName := field.Name.Snake2Camel() + output, err := t.Execute(map[string]interface{}{ + "upperStartCamelObject": camelTableName, + "upperField": camelFieldName, + "in": fmt.Sprintf("%s %s", stringx.From(camelFieldName).LowerStart(), field.DataType), + "withCache": withCache, + "cacheKey": table.CacheKey[field.Name.Source()].KeyExpression, + "cacheKeyVariable": table.CacheKey[field.Name.Source()].Variable, + "primaryKeyLeft": table.CacheKey[table.Name.Source()].Left, + "lowerStartCamelObject": stringx.From(camelTableName).LowerStart(), + "lowerStartCamelField": stringx.From(camelFieldName).LowerStart(), + "upperStartCamelPrimaryKey": table.PrimaryKey.Name.Snake2Camel(), + "originalField": field.Name.Source(), + "originalPrimaryField": table.PrimaryKey.Name.Source(), }) if err != nil { return "", err } - list = append(list, fineOneByFieldBuffer.String()) + list = append(list, output.String()) } - return strings.Join(list, ""), nil + return strings.Join(list, "\n"), nil } diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go new file mode 100644 index 000000000..cfca15cd4 --- /dev/null +++ b/tools/goctl/model/sql/gen/gen.go @@ -0,0 +1,160 @@ +package gen + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" + sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" +) + +const ( + pwd = "." + createTableFlag = `(?m)CREATE\s+TABLE` +) + +type ( + defaultGenerator struct { + source string + src string + dir string + } +) + +func NewDefaultGenerator(src, dir string) *defaultGenerator { + if src == "" { + src = pwd + } + if dir == "" { + dir = pwd + } + return &defaultGenerator{src: src, dir: dir} +} + +func (g *defaultGenerator) Start(withCache bool) error { + fileSrc, err := filepath.Abs(g.src) + if err != nil { + return err + } + dirAbs, err := filepath.Abs(g.dir) + if err != nil { + return err + } + err = util.MkdirIfNotExist(dirAbs) + if err != nil { + return err + } + data, err := ioutil.ReadFile(fileSrc) + if err != nil { + return err + } + g.source = string(data) + modelList, err := g.genFromDDL(withCache) + if err != nil { + return err + } + + for tableName, code := range modelList { + filename := filepath.Join(dirAbs, fmt.Sprintf("%smodel.go", stringx.From(tableName).Lower())) + err = ioutil.WriteFile(filename, []byte(code), os.ModePerm) + if err != nil { + return err + } + } + return nil +} + +// ret1: key-table name,value-code +func (g *defaultGenerator) genFromDDL(withCache bool) (map[string]string, error) { + ddlList := g.split() + m := make(map[string]string) + for _, ddl := range ddlList { + table, err := parser.Parse(ddl) + if err != nil { + return nil, err + } + code, err := g.genModel(*table, withCache) + if err != nil { + return nil, err + } + m[table.Name.Source()] = code + } + return m, nil +} + +type ( + Table struct { + parser.Table + CacheKey map[string]Key + } +) + +func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) { + t := templatex.With("model"). + Parse(sqltemplate.Model). + GoFmt(true) + + m, err := genCacheKeys(in) + if err != nil { + return "", err + } + importsCode := genImports() + var table Table + table.Table = in + table.CacheKey = m + + varsCode, err := genVars(table) + if err != nil { + return "", err + } + typesCode, err := genTypes(table) + if err != nil { + return "", err + } + newCode, err := genNew(table) + if err != nil { + return "", err + } + insertCode, err := genInsert(table) + if err != nil { + return "", err + } + var findCode = make([]string, 0) + findOneCode, err := genFindOne(table, withCache) + if err != nil { + return "", err + } + findOneByFieldCode, err := genFineOneByField(table, withCache) + if err != nil { + return "", err + } + findCode = append(findCode, findOneCode, findOneByFieldCode) + updateCode, err := genUpdate(table, withCache) + if err != nil { + return "", err + } + deleteCode, err := genDelete(table, withCache) + if err != nil { + return "", err + } + output, err := t.Execute(map[string]interface{}{ + "imports": importsCode, + "vars": varsCode, + "types": typesCode, + "new": newCode, + "insert": insertCode, + "find": strings.Join(findCode, "\r\n"), + "update": updateCode, + "delete": deleteCode, + }) + if err != nil { + return "", err + } + return output.String(), nil +} diff --git a/tools/goctl/model/sql/gen/imports.go b/tools/goctl/model/sql/gen/imports.go index dcf4f0ddd..85c253644 100644 --- a/tools/goctl/model/sql/gen/imports.go +++ b/tools/goctl/model/sql/gen/imports.go @@ -1,23 +1,9 @@ package gen import ( - "bytes" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" ) -func genImports(table *InnerTable) (string, error) { - t, err := template.New("imports").Parse(sqltemplate.Imports) - if err != nil { - return "", err - } - importBuffer := new(bytes.Buffer) - err = t.Execute(importBuffer, map[string]interface{}{ - "containsCache": table.ContainsCache, - }) - if err != nil { - return "", err - } - return importBuffer.String(), nil +func genImports() string { + return sqltemplate.Imports } diff --git a/tools/goctl/model/sql/gen/insert.go b/tools/goctl/model/sql/gen/insert.go index d246c81f3..2a8a14f17 100644 --- a/tools/goctl/model/sql/gen/insert.go +++ b/tools/goctl/model/sql/gen/insert.go @@ -1,37 +1,38 @@ package gen import ( - "bytes" "strings" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genInsert(table *InnerTable) (string, error) { - t, err := template.New("insert").Parse(sqltemplate.Insert) - if err != nil { - return "", nil - } - insertBuffer := new(bytes.Buffer) +func genInsert(table Table) (string, error) { expressions := make([]string, 0) expressionValues := make([]string, 0) for _, filed := range table.Fields { - if filed.SnakeCase == "create_time" || filed.SnakeCase == "update_time" || filed.IsPrimaryKey { + camel := filed.Name.Snake2Camel() + if camel == "CreateTime" || camel == "UpdateTime" { + continue + } + if filed.IsPrimaryKey && table.PrimaryKey.AutoIncrement { continue } expressions = append(expressions, "?") - expressionValues = append(expressionValues, "data."+filed.UpperCamelCase) + expressionValues = append(expressionValues, "data."+camel) } - err = t.Execute(insertBuffer, map[string]interface{}{ - "upperObject": table.UpperCamelCase, - "lowerObject": table.LowerCamelCase, - "expression": strings.Join(expressions, ", "), - "expressionValues": strings.Join(expressionValues, ", "), - "containsCache": table.ContainsCache, - }) + camel := table.Name.Snake2Camel() + output, err := templatex.With("insert"). + Parse(template.Insert). + Execute(map[string]interface{}{ + "upperStartCamelObject": camel, + "lowerStartCamelObject": stringx.From(camel).LowerStart(), + "expression": strings.Join(expressions, ", "), + "expressionValues": strings.Join(expressionValues, ", "), + }) if err != nil { return "", err } - return insertBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/keys.go b/tools/goctl/model/sql/gen/keys.go index 6fa2ce02a..5d9437b9f 100644 --- a/tools/goctl/model/sql/gen/keys.go +++ b/tools/goctl/model/sql/gen/keys.go @@ -1,105 +1,50 @@ package gen import ( - "bytes" - "strings" - "text/template" -) + "fmt" -var ( - cacheKeyExpressionTemplate = `cache{{.upperCamelTable}}{{.upperCamelField}}Prefix = "cache#{{.lowerCamelTable}}#{{.lowerCamelField}}#"` - keyTemplate = `{{.lowerCamelField}}Key := fmt.Sprintf("%s%v", {{.define}}, {{.lowerCamelField}})` - keyRespTemplate = `{{.lowerCamelField}}Key := fmt.Sprintf("%s%v", {{.define}}, resp.{{.upperCamelField}})` - keyDataTemplate = `{{.lowerCamelField}}Key := fmt.Sprintf("%s%v", {{.define}}, data.{{.upperCamelField}})` + "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" ) type ( + // tableName:user + // {{prefix}}=cache + // key:id Key struct { - Define string // cacheKey define,如:cacheUserUserIdPrefix - Value string // cacheKey value expression,如:cache#user#userId# - Expression string // cacheKey expression,如:cacheUserUserIdPrefix="cache#user#userId#" - KeyVariable string // cacheKey 声明变量,如:userIdKey - Key string // 缓存key的代码,如 userIdKey:=fmt.Sprintf("%s%v", cacheUserUserIdPrefix, userId) - DataKey string // 缓存key的代码,如 userIdKey:=fmt.Sprintf("%s%v", cacheUserUserIdPrefix, data.userId) - RespKey string // 缓存key的代码,如 userIdKey:=fmt.Sprintf("%s%v", cacheUserUserIdPrefix, resp.userId) + VarExpression string // cacheUserIdPrefix="cache#user#id#" + Left string // cacheUserIdPrefix + Right string // cache#user#id# + Variable string // userIdKey + KeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", userId) + DataKeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", data.userId) + RespKeyExpression string // userIdKey: = fmt.Sprintf("cache#user#id#%v", resp.userId) } ) -// key-数据库原始字段名,value-缓存key对象 -func genCacheKeys(table *InnerTable) (map[string]Key, error) { +// key-数据库原始字段名,value-缓存key相关数据 +func genCacheKeys(table parser.Table) (map[string]Key, error) { fields := table.Fields - var m = make(map[string]Key) - if !table.ContainsCache { - return m, nil - } + m := make(map[string]Key) + camelTableName := table.Name.Snake2Camel() + lowerStartCamelTableName := stringx.From(camelTableName).LowerStart() for _, field := range fields { - if !field.Cache && !field.IsPrimaryKey { + if !field.IsKey { continue } - t, err := template.New("keyExpression").Parse(cacheKeyExpressionTemplate) - if err != nil { - return nil, err - } - var expressionBuffer = new(bytes.Buffer) - err = t.Execute(expressionBuffer, map[string]string{ - "upperCamelTable": table.UpperCamelCase, - "lowerCamelTable": table.LowerCamelCase, - "upperCamelField": field.UpperCamelCase, - "lowerCamelField": field.LowerCamelCase, - }) - if err != nil { - return nil, err - } - expression := expressionBuffer.String() - expressionAr := strings.Split(expression, "=") - define := strings.TrimSpace(expressionAr[0]) - value := strings.TrimSpace(expressionAr[1]) - t, err = template.New("key").Parse(keyTemplate) - if err != nil { - return nil, err - } - var keyBuffer = new(bytes.Buffer) - err = t.Execute(keyBuffer, map[string]string{ - "lowerCamelField": field.LowerCamelCase, - "define": define, - }) - if err != nil { - return nil, err - } - t, err = template.New("keyData").Parse(keyDataTemplate) - if err != nil { - return nil, err - } - var keyDataBuffer = new(bytes.Buffer) - err = t.Execute(keyDataBuffer, map[string]string{ - "lowerCamelField": field.LowerCamelCase, - "upperCamelField": field.UpperCamelCase, - "define": define, - }) - if err != nil { - return nil, err - } - t, err = template.New("keyResp").Parse(keyRespTemplate) - if err != nil { - return nil, err - } - var keyRespBuffer = new(bytes.Buffer) - err = t.Execute(keyRespBuffer, map[string]string{ - "lowerCamelField": field.LowerCamelCase, - "upperCamelField": field.UpperCamelCase, - "define": define, - }) - if err != nil { - return nil, err - } - m[field.SnakeCase] = Key{ - Define: define, - Value: value, - Expression: expression, - KeyVariable: field.LowerCamelCase + "Key", - Key: keyBuffer.String(), - DataKey: keyDataBuffer.String(), - RespKey: keyRespBuffer.String(), + camelFieldName := field.Name.Snake2Camel() + lowerStartCamelFieldName := stringx.From(camelFieldName).LowerStart() + left := fmt.Sprintf("cache%s%sPrefix", camelTableName, camelFieldName) + right := fmt.Sprintf("cache#%s#%s#", lowerStartCamelTableName, lowerStartCamelFieldName) + variable := fmt.Sprintf("%s%sKey", lowerStartCamelTableName, camelFieldName) + m[field.Name.Source()] = Key{ + VarExpression: fmt.Sprintf(`%s = "%s"`, left, right), + Left: left, + Right: right, + Variable: variable, + KeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("cache#user#id#%s", %s)`, variable, "%s", lowerStartCamelFieldName), + DataKeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("cache#user#id#%s", data.%s)`, variable, "%s", lowerStartCamelFieldName), + RespKeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("cache#user#id#%s", resp.%s)`, variable, "%s", lowerStartCamelFieldName), } } return m, nil diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go deleted file mode 100644 index 1018f21ac..000000000 --- a/tools/goctl/model/sql/gen/keys_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package gen - -import ( - "log" - "testing" - - "github.com/tal-tech/go-zero/core/logx" -) - -func TestKeys(t *testing.T) { - var table = OuterTable{ - Table: "user_info", - CreateNotFound: true, - Fields: []*OuterFiled{ - { - IsPrimaryKey: true, - Name: "user_id", - DataBaseType: "bigint", - Comment: "主键id", - }, - { - Name: "campus_id", - DataBaseType: "bigint", - Comment: "整校id", - QueryType: QueryAll, - Cache: false, - }, - { - Name: "name", - DataBaseType: "varchar", - Comment: "用户姓名", - QueryType: QueryOne, - }, - { - Name: "id_number", - DataBaseType: "varchar", - Comment: "身份证", - Cache: false, - QueryType: QueryNone, - WithFields: []OuterWithField{ - { - Name: "name", - DataBaseType: "varchar", - }, - }, - }, - { - Name: "age", - DataBaseType: "int", - Comment: "年龄", - Cache: false, - QueryType: QueryNone, - }, - { - Name: "gender", - DataBaseType: "tinyint", - Comment: "性别,0-男,1-女,2-不限", - QueryType: QueryLimit, - WithFields: []OuterWithField{ - { - Name: "campus_id", - DataBaseType: "bigint", - }, - }, - OuterSort: []OuterSort{ - { - Field: "create_time", - Asc: false, - }, - }, - }, - { - Name: "mobile", - DataBaseType: "varchar", - Comment: "手机号", - QueryType: QueryOne, - Cache: true, - }, - { - Name: "create_time", - DataBaseType: "timestamp", - Comment: "创建时间", - }, - { - Name: "update_time", - DataBaseType: "timestamp", - Comment: "更新时间", - }, - }, - } - innerTable, err := TableConvert(table) - if err != nil { - log.Fatalln(err) - } - tp, err := GenModel(innerTable) - if err != nil { - log.Fatalln(err) - } - logx.Info(tp) -} diff --git a/tools/goctl/model/sql/gen/model.go b/tools/goctl/model/sql/gen/model.go deleted file mode 100644 index 1dcbe9b95..000000000 --- a/tools/goctl/model/sql/gen/model.go +++ /dev/null @@ -1,86 +0,0 @@ -package gen - -import ( - "bytes" - "go/format" - "strings" - "text/template" - - "github.com/tal-tech/go-zero/core/logx" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" -) - -func GenModel(table *InnerTable) (string, error) { - t, err := template.New("model").Parse(sqltemplate.Model) - if err != nil { - return "", nil - } - modelBuffer := new(bytes.Buffer) - importsCode, err := genImports(table) - if err != nil { - return "", err - } - varsCode, err := genVars(table) - if err != nil { - return "", err - } - typesCode, err := genTypes(table) - if err != nil { - return "", err - } - newCode, err := genNew(table) - if err != nil { - return "", err - } - insertCode, err := genInsert(table) - if err != nil { - return "", err - } - var findCode = make([]string, 0) - findOneCode, err := genFindOne(table) - if err != nil { - return "", err - } - findOneByFieldCode, err := genFineOneByField(table) - if err != nil { - return "", err - } - findAllCode, err := genFindAllByField(table) - if err != nil { - return "", err - } - findLimitCode, err := genFindLimitByField(table) - if err != nil { - return "", err - } - findCode = append(findCode, findOneCode, findOneByFieldCode, findAllCode, findLimitCode) - updateCode, err := genUpdate(table) - if err != nil { - return "", err - } - deleteCode, err := genDelete(table) - if err != nil { - return "", err - } - - err = t.Execute(modelBuffer, map[string]interface{}{ - "imports": importsCode, - "vars": varsCode, - "types": typesCode, - "new": newCode, - "insert": insertCode, - "find": strings.Join(findCode, "\r\n"), - "update": updateCode, - "delete": deleteCode, - }) - if err != nil { - return "", err - } - result := modelBuffer.String() - bts, err := format.Source([]byte(result)) - if err != nil { - logx.Errorf("%+v", err) - return "", err - } - return string(bts), nil -} diff --git a/tools/goctl/model/sql/gen/new.go b/tools/goctl/model/sql/gen/new.go index 47047b897..99fb835a6 100644 --- a/tools/goctl/model/sql/gen/new.go +++ b/tools/goctl/model/sql/gen/new.go @@ -1,24 +1,18 @@ package gen import ( - "bytes" - "text/template" - - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genNew(table *InnerTable) (string, error) { - t, err := template.New("new").Parse(sqltemplate.New) +func genNew(table Table) (string, error) { + output, err := templatex.With("new"). + Parse(template.New). + Execute(map[string]interface{}{ + "upperStartCamelObject": table.Name.Snake2Camel(), + }) if err != nil { return "", err } - newBuffer := new(bytes.Buffer) - err = t.Execute(newBuffer, map[string]interface{}{ - "containsCache": table.ContainsCache, - "upperObject": table.UpperCamelCase, - }) - if err != nil { - return "", err - } - return newBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/shared.go b/tools/goctl/model/sql/gen/shared.go deleted file mode 100644 index fdafab728..000000000 --- a/tools/goctl/model/sql/gen/shared.go +++ /dev/null @@ -1,99 +0,0 @@ -package gen - -var ( - commonMysqlDataTypeMap = map[string]string{ - "tinyint": "int64", - "smallint": "int64", - "mediumint": "int64", - "int": "int64", - "integer": "int64", - "bigint": "int64", - "float": "float64", - "double": "float64", - "decimal": "float64", - "date": "time.Time", - "time": "string", - "year": "int64", - "datetime": "time.Time", - "timestamp": "time.Time", - "char": "string", - "varchar": "string", - "tinyblob": "string", - "tinytext": "string", - "blob": "string", - "text": "string", - "mediumblob": "string", - "mediumtext": "string", - "longblob": "string", - "longtext": "string", - } -) - -const ( - QueryNone QueryType = 0 - QueryOne QueryType = 1 // 仅支持单个字段为查询条件 - QueryAll QueryType = 2 // 可支持多个字段为查询条件,且关系均为and - QueryLimit QueryType = 3 // 可支持多个字段为查询条件,且关系均为and -) - -type ( - QueryType int - - Case struct { - SnakeCase string - LowerCamelCase string - UpperCamelCase string - } - InnerWithField struct { - Case - DataType string - } - InnerTable struct { - Case - ContainsCache bool - CreateNotFound bool - PrimaryField *InnerField - Fields []*InnerField - CacheKey map[string]Key // key-数据库字段 - } - InnerField struct { - IsPrimaryKey bool - InnerWithField - DataBaseType string // 数据库中字段类型 - Tag string // 标签,格式:`db:"xxx"` - Comment string // 注释,以"// 开头" - Cache bool // 是否缓存模式 - QueryType QueryType - WithFields []InnerWithField - Sort []InnerSort - } - InnerSort struct { - Field Case - Asc bool - } - - OuterTable struct { - Table string `json:"table"` - CreateNotFound bool `json:"createNotFound,optional"` - Fields []*OuterFiled `json:"fields"` - } - OuterWithField struct { - Name string `json:"name"` - DataBaseType string `json:"dataBaseType"` - } - OuterSort struct { - Field string `json:"fields"` - Asc bool `json:"asc,optional"` - } - OuterFiled struct { - IsPrimaryKey bool `json:"isPrimaryKey,optional"` - Name string `json:"name"` - DataBaseType string `json:"dataBaseType"` - Comment string `json:"comment"` - Cache bool `json:"cache,optional"` - // if IsPrimaryKey==false下面字段有效 - QueryType QueryType `json:"queryType"` // 查找类型 - WithFields []OuterWithField `json:"withFields,optional"` // 其他字段联合组成条件的字段列表 - OuterSort []OuterSort `json:"sort,optional"` - } -) diff --git a/tools/goctl/model/sql/gen/split.go b/tools/goctl/model/sql/gen/split.go new file mode 100644 index 000000000..abc8e6335 --- /dev/null +++ b/tools/goctl/model/sql/gen/split.go @@ -0,0 +1,23 @@ +package gen + +import ( + "regexp" +) + +func (g *defaultGenerator) split() []string { + reg := regexp.MustCompile(createTableFlag) + index := reg.FindAllStringIndex(g.source, -1) + list := make([]string, 0) + source := g.source + for i := len(index) - 1; i >= 0; i-- { + subIndex := index[i] + if len(subIndex) == 0 { + continue + } + start := subIndex[0] + ddl := source[start:] + list = append(list, ddl) + source = source[:start] + } + return list +} diff --git a/tools/goctl/model/sql/gen/tag.go b/tools/goctl/model/sql/gen/tag.go index 6c5edf736..8414a102e 100644 --- a/tools/goctl/model/sql/gen/tag.go +++ b/tools/goctl/model/sql/gen/tag.go @@ -1,26 +1,21 @@ package gen import ( - "bytes" - "text/template" - - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) func genTag(in string) (string, error) { if in == "" { return in, nil } - t, err := template.New("tag").Parse(sqltemplate.Tag) + output, err := templatex.With("tag"). + Parse(template.Tag). + Execute(map[string]interface{}{ + "field": in, + }) if err != nil { return "", err } - var tagBuffer = new(bytes.Buffer) - err = t.Execute(tagBuffer, map[string]interface{}{ - "field": in, - }) - if err != nil { - return "", err - } - return tagBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/types.go b/tools/goctl/model/sql/gen/types.go index 9856f853b..eac7e0590 100644 --- a/tools/goctl/model/sql/gen/types.go +++ b/tools/goctl/model/sql/gen/types.go @@ -1,30 +1,24 @@ package gen import ( - "bytes" - "text/template" - - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genTypes(table *InnerTable) (string, error) { +func genTypes(table Table) (string, error) { fields := table.Fields - t, err := template.New("types").Parse(sqltemplate.Types) - if err != nil { - return "", nil - } - var typeBuffer = new(bytes.Buffer) fieldsString, err := genFields(fields) if err != nil { return "", err } - err = t.Execute(typeBuffer, map[string]interface{}{ - "upperObject": table.UpperCamelCase, - "containsCache": table.ContainsCache, - "fields": fieldsString, - }) + output, err := templatex.With("types"). + Parse(template.Types). + Execute(map[string]interface{}{ + "upperStartCamelObject": table.Name.Snake2Camel(), + "fields": fieldsString, + }) if err != nil { return "", err } - return typeBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/update.go b/tools/goctl/model/sql/gen/update.go index 27d6de6f4..3a034b630 100644 --- a/tools/goctl/model/sql/gen/update.go +++ b/tools/goctl/model/sql/gen/update.go @@ -1,38 +1,40 @@ package gen import ( - "bytes" "strings" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genUpdate(table *InnerTable) (string, error) { - t, err := template.New("update").Parse(sqltemplate.Update) +func genUpdate(table Table, withCache bool) (string, error) { + expressionValues := make([]string, 0) + for _, filed := range table.Fields { + camel := filed.Name.Snake2Camel() + if camel == "CreateTime" || camel == "UpdateTime" { + continue + } + if filed.IsPrimaryKey { + continue + } + expressionValues = append(expressionValues, "data."+camel) + } + expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.Snake2Camel()) + camelTableName := table.Name.Snake2Camel() + output, err := templatex.With("update"). + Parse(template.Update). + Execute(map[string]interface{}{ + "withCache": withCache, + "upperStartCamelObject": camelTableName, + "primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression, + "primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable, + "lowerStartCamelObject": stringx.From(camelTableName).LowerStart(), + "originalPrimaryKey": table.PrimaryKey.Name.Source(), + "expressionValues": strings.Join(expressionValues, ", "), + }) if err != nil { return "", nil } - updateBuffer := new(bytes.Buffer) - expressionValues := make([]string, 0) - for _, filed := range table.Fields { - if filed.SnakeCase == "create_time" || filed.SnakeCase == "update_time" || filed.IsPrimaryKey { - continue - } - expressionValues = append(expressionValues, "data."+filed.UpperCamelCase) - } - expressionValues = append(expressionValues, "data."+table.PrimaryField.UpperCamelCase) - err = t.Execute(updateBuffer, map[string]interface{}{ - "containsCache": table.ContainsCache, - "upperObject": table.UpperCamelCase, - "primaryCacheKey": table.CacheKey[table.PrimaryField.SnakeCase].DataKey, - "primaryKeyVariable": table.CacheKey[table.PrimaryField.SnakeCase].KeyVariable, - "lowerObject": table.LowerCamelCase, - "primarySnakeCase": table.PrimaryField.SnakeCase, - "expressionValues": strings.Join(expressionValues, ", "), - }) - if err != nil { - return "", err - } - return updateBuffer.String(), nil + return output.String(), nil } diff --git a/tools/goctl/model/sql/gen/vars.go b/tools/goctl/model/sql/gen/vars.go index f5ed6aeaa..df157c0cd 100644 --- a/tools/goctl/model/sql/gen/vars.go +++ b/tools/goctl/model/sql/gen/vars.go @@ -1,36 +1,32 @@ package gen import ( - "bytes" "strings" - "text/template" - sqltemplate "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/model/sql/template" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/tal-tech/go-zero/tools/goctl/util/templatex" ) -func genVars(table *InnerTable) (string, error) { - t, err := template.New("vars").Parse(sqltemplate.Vars) - if err != nil { - return "", err - } - varBuffer := new(bytes.Buffer) - m, err := genCacheKeys(table) - if err != nil { - return "", err - } +func genVars(table Table) (string, error) { keys := make([]string, 0) - for _, v := range m { - keys = append(keys, v.Expression) + for _, v := range table.CacheKey { + keys = append(keys, v.VarExpression) } - err = t.Execute(varBuffer, map[string]interface{}{ - "lowerObject": table.LowerCamelCase, - "upperObject": table.UpperCamelCase, - "createNotFound": table.CreateNotFound, - "keysDefine": strings.Join(keys, "\r\n"), - "snakePrimaryKey": table.PrimaryField.SnakeCase, - }) + camel := table.Name.Snake2Camel() + output, err := templatex.With("var"). + Parse(template.Vars). + GoFmt(true). + Execute(map[string]interface{}{ + "lowerStartCamelObject": stringx.From(camel).LowerStart(), + "upperStartCamelObject": camel, + "cacheKeys": strings.Join(keys, "\r\n"), + "autoIncrement": table.PrimaryKey.AutoIncrement, + "originalPrimaryKey": table.PrimaryKey.Name.Source(), + }) if err != nil { return "", err } - return varBuffer.String(), nil + + return output.String(), nil } diff --git a/tools/goctl/model/sql/parser/error.go b/tools/goctl/model/sql/parser/error.go new file mode 100644 index 000000000..57f56a0c9 --- /dev/null +++ b/tools/goctl/model/sql/parser/error.go @@ -0,0 +1,11 @@ +package parser + +import ( + "errors" +) + +var ( + unSupportDDL = errors.New("unexpected type") + tableBodyIsNotFound = errors.New("create table spec not found") + errPrimaryKey = errors.New("unexpected joint primary key") +) diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go new file mode 100644 index 000000000..dec0da9cb --- /dev/null +++ b/tools/goctl/model/sql/parser/parser.go @@ -0,0 +1,104 @@ +package parser + +import ( + "fmt" + + "github.com/tal-tech/go-zero/tools/goctl/model/sql/converter" + "github.com/tal-tech/go-zero/tools/goctl/util/stringx" + "github.com/xwb1989/sqlparser" +) + +type ( + Table struct { + Name stringx.String + PrimaryKey Primary + Fields []Field + } + Primary struct { + Field + AutoIncrement bool + } + Field struct { + Name stringx.String + DataBaseType string + DataType string + IsKey bool + IsPrimaryKey bool + Comment string + } +) + +func Parse(ddl string) (*Table, error) { + stmt, err := sqlparser.ParseStrictDDL(ddl) + if err != nil { + return nil, err + } + ddlStmt, ok := stmt.(*sqlparser.DDL) + if !ok { + return nil, unSupportDDL + } + action := ddlStmt.Action + if action != sqlparser.CreateStr { + return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action) + } + tableName := ddlStmt.NewName.Name.String() + tableSpec := ddlStmt.TableSpec + if tableSpec == nil { + return nil, tableBodyIsNotFound + } + + columns := tableSpec.Columns + indexes := tableSpec.Indexes + + for _, index := range indexes { + info := index.Info + if info == nil { + continue + } + if info.Primary { + if len(index.Columns) > 1 { + return nil, errPrimaryKey + } + break + } + } + var ( + fields []Field + primaryKey Primary + ) + for _, column := range columns { + if column == nil { + continue + } + var comment string + if column.Type.Comment != nil { + comment = string(column.Type.Comment.Val) + } + dataType, err := converter.ConvertDataType(column.Type.Type) + if err != nil { + return nil, err + } + var field Field + field.Name = stringx.From(column.Name.String()) + field.DataBaseType = column.Type.Type + field.DataType = dataType + field.Comment = comment + // see line 1194 https://github.com/xwb1989/sqlparser/blob/master/ast.go + field.IsKey = column.Type.KeyOpt != 0 + field.IsPrimaryKey = column.Type.KeyOpt == 1 + fields = append(fields, field) + // see line 1195 https://github.com/xwb1989/sqlparser/blob/master/ast.go + if column.Type.KeyOpt == 1 { + primaryKey.Field = field + if column.Type.Autoincrement { + primaryKey.AutoIncrement = true + } + } + } + return &Table{ + Name: stringx.From(tableName), + PrimaryKey: primaryKey, + Fields: fields, + }, nil + +} diff --git a/tools/goctl/model/sql/sqlctl.go b/tools/goctl/model/sql/sqlctl.go deleted file mode 100644 index 56459388d..000000000 --- a/tools/goctl/model/sql/sqlctl.go +++ /dev/null @@ -1,22 +0,0 @@ -package sqlgen - -import ( - "errors" - - "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" -) - -var ( - ErrCircleQuery = errors.New("circle query with other fields") -) - -func Gen(in gen.OuterTable) (string, error) { - t, err := gen.TableConvert(in) - if err != nil { - if err == gen.ErrCircleQuery { - return "", ErrCircleQuery - } - return "", err - } - return gen.GenModel(t) -} diff --git a/tools/goctl/model/sql/template/delete.go b/tools/goctl/model/sql/template/delete.go index 372aff45c..7c2f8b57c 100644 --- a/tools/goctl/model/sql/template/delete.go +++ b/tools/goctl/model/sql/template/delete.go @@ -1,17 +1,17 @@ -package sqltemplate +package template var Delete = ` -func (m *{{.upperObject}}Model) Delete({{.lowerPrimaryKey}} {{.dataType}}) error { - {{if .containsCache}}{{if .isNotPrimaryKey}}data,err:=m.FindOne({{.lowerPrimaryKey}}) +func (m *{{.upperStartCamelObject}}Model) Delete({{.lowerStartCamelPrimaryKey}} {{.dataType}}) error { + {{if .withCache}}{{if .containsIndexCache}}data,err:=m.FindOne({{.lowerStartCamelPrimaryKey}}) if err!=nil{ return err }{{end}} {{.keys}} - _, err {{if .isNotPrimaryKey}}={{else}}:={{end}} m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := ` + "`" + `delete from ` + "` +" + ` m.table + ` + " `" + ` where {{.snakePrimaryKey}} = ?` + "`" + ` - return conn.Exec(query, {{.lowerPrimaryKey}}) - }, {{.keyValues}}){{else}}query := ` + "`" + `delete from ` + "` +" + ` m.table + ` + " `" + ` where {{.snakePrimaryKey}} = ?` + "`" + ` - _,err:=m.ExecNoCache(query, {{.lowerPrimaryKey}}){{end}} + _, err {{if .containsIndexCache}}={{else}}:={{end}} m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { + query := ` + "`" + `delete from ` + "` +" + ` m.table + ` + " `" + ` where {{.originalPrimaryKey}} = ?` + "`" + ` + return conn.Exec(query, {{.lowerStartCamelPrimaryKey}}) + }, {{.keyValues}}){{else}}query := ` + "`" + `delete from ` + "` +" + ` m.table + ` + " `" + ` where {{.originalPrimaryKey}} = ?` + "`" + ` + _,err:=m.ExecNoCache(query, {{.lowerStartCamelPrimaryKey}}){{end}} return err } ` diff --git a/tools/goctl/model/sql/template/errors.go b/tools/goctl/model/sql/template/errors.go index 4df82d956..a24ff3fa6 100644 --- a/tools/goctl/model/sql/template/errors.go +++ b/tools/goctl/model/sql/template/errors.go @@ -1,4 +1,4 @@ -package sqltemplate +package template var Error = `package model diff --git a/tools/goctl/model/sql/template/field.go b/tools/goctl/model/sql/template/field.go index 8a76cb407..ec0b02843 100644 --- a/tools/goctl/model/sql/template/field.go +++ b/tools/goctl/model/sql/template/field.go @@ -1,3 +1,3 @@ -package sqltemplate +package template var Field = `{{.name}} {{.type}} {{.tag}} {{.comment}}` diff --git a/tools/goctl/model/sql/template/find.go b/tools/goctl/model/sql/template/find.go index 0d59f53ef..696e2d9d1 100644 --- a/tools/goctl/model/sql/template/find.go +++ b/tools/goctl/model/sql/template/find.go @@ -1,13 +1,13 @@ -package sqltemplate +package template // 通过id查询 var FindOne = ` -func (m *{{.upperObject}}Model) FindOne({{.lowerPrimaryKey}} {{.dataType}}) (*{{.upperObject}}, error) { +func (m *{{.upperStartCamelObject}}Model) FindOne({{.lowerStartCamelPrimaryKey}} {{.dataType}}) (*{{.upperStartCamelObject}}, error) { {{if .withCache}}{{.cacheKey}} - var resp {{.upperObject}} + var resp {{.upperStartCamelObject}} err := m.QueryRow(&resp, {{.cacheKeyVariable}}, func(conn sqlx.SqlConn, v interface{}) error { - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.snakePrimaryKey}} = ? limit 1` + "`" + ` - return conn.QueryRow(v, query, {{.lowerPrimaryKey}}) + query := ` + "`" + `select ` + "`" + ` + {{.lowerStartCamelObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.originalPrimaryKey}} = ? limit 1` + "`" + ` + return conn.QueryRow(v, query, {{.lowerStartCamelPrimaryKey}}) }) switch err { case nil: @@ -16,9 +16,9 @@ func (m *{{.upperObject}}Model) FindOne({{.lowerPrimaryKey}} {{.dataType}}) (*{{ return nil, ErrNotFound default: return nil, err - }{{else}}query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.snakePrimaryKey}} = ? limit 1` + "`" + ` - var resp {{.upperObject}} - err := m.QueryRowNoCache(&resp, query, {{.lowerPrimaryKey}}) + }{{else}}query := ` + "`" + `select ` + "`" + ` + {{.lowerStartCamelObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.originalPrimaryKey}} = ? limit 1` + "`" + ` + var resp {{.upperStartCamelObject}} + err := m.QueryRowNoCache(&resp, query, {{.lowerStartCamelPrimaryKey}}) switch err { case nil: return &resp, nil @@ -32,19 +32,19 @@ func (m *{{.upperObject}}Model) FindOne({{.lowerPrimaryKey}} {{.dataType}}) (*{{ // 通过指定字段查询 var FindOneByField = ` -func (m *{{.upperObject}}Model) FindOneBy{{.upperFields}}({{.in}}) (*{{.upperObject}}, error) { - {{if .onlyOneFiled}}{{if .withCache}}{{.cacheKey}} - var resp {{.upperObject}} +func (m *{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}({{.in}}) (*{{.upperStartCamelObject}}, error) { + {{if .withCache}}{{.cacheKey}} + var resp {{.upperStartCamelObject}} err := m.QueryRowIndex(&resp, {{.cacheKeyVariable}}, func(primary interface{}) string { - return fmt.Sprintf("%s%v", {{.primaryKeyDefine}}, primary) + return fmt.Sprintf("%s%v", {{.primaryKeyLeft}}, primary) }, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) { - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.snakeField}} = ? limit 1` + "`" + ` - if err := conn.QueryRow(&resp, query, {{.lowerField}}); err != nil { + query := ` + "`" + `select ` + "`" + ` + {{.lowerStartCamelObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.originalField}} = ? limit 1` + "`" + ` + if err := conn.QueryRow(&resp, query, {{.lowerStartCamelField}}); err != nil { return nil, err } - return resp.{{.UpperPrimaryKey}}, nil + return resp.{{.upperStartCamelPrimaryKey}}, nil }, func(conn sqlx.SqlConn, v, primary interface{}) error { - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.primarySnakeField}} = ? limit 1` + "`" + ` + query := ` + "`" + `select ` + "`" + ` + {{.lowerStartCamelObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.originalPrimaryField}} = ? limit 1` + "`" + ` return conn.QueryRow(v, query, primary) }) switch err { @@ -54,62 +54,6 @@ func (m *{{.upperObject}}Model) FindOneBy{{.upperFields}}({{.in}}) (*{{.upperObj return nil, ErrNotFound default: return nil, err - }{{else}}var resp {{.upperObject}} - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.expression}} limit 1` + "`" + ` - err := m.QueryRowNoCache(&resp, query, {{.expressionValues}}) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - }{{end}}{{else}}var resp {{.upperObject}} - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.expression}} limit 1` + "`" + ` - err := m.QueryRowNoCache(&resp, query, {{.expressionValues}}) - switch err { - case nil: - return &resp, nil - case sqlc.ErrNotFound: - return nil, ErrNotFound - default: - return nil, err - }{{end}} -} -` - -// 查询all -var FindAllByField = ` -func (m *{{.upperObject}}Model) FindAllBy{{.upperFields}}({{.in}}) ([]*{{.upperObject}}, error) { - var resp []*{{.upperObject}} - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + ` from ` + "` + " + `m.table ` + " + `" + ` where {{.expression}}` + "`" + ` - err := m.QueryRowsNoCache(&resp, query, {{.expressionValues}}) - if err != nil { - return nil, err - } - return resp, nil -} -` - -// limit分页查询 -var FindLimitByField = ` -func (m *{{.upperObject}}Model) FindLimitBy{{.upperFields}}({{.in}}, page, limit int) ([]*{{.upperObject}}, error) { - var resp []*{{.upperObject}} - query := ` + "`" + `select ` + "`" + ` + {{.lowerObject}}Rows + ` + "`" + `from ` + "` + " + `m.table ` + " + `" + ` where {{.expression}} order by {{.sortExpression}} limit ?,?` + "`" + ` - err := m.QueryRowsNoCache(&resp, query, {{.expressionValues}}, (page-1)*limit, limit) - if err != nil { - return nil, err - } - return resp, nil -} - -func (m *{{.upperObject}}Model) FindAllCountBy{{.upperFields}}({{.in}}) (int64, error) { - var count int64 - query := ` + "`" + `select count(1) from ` + "` + " + `m.table ` + " + `" + ` where {{.expression}} ` + "`" + ` - err := m.QueryRowNoCache(&count, query, {{.expressionValues}}) - if err != nil { - return 0, err - } - return count, nil + } } ` diff --git a/tools/goctl/model/sql/template/import.go b/tools/goctl/model/sql/template/import.go index dd2c15552..41f7103cc 100644 --- a/tools/goctl/model/sql/template/import.go +++ b/tools/goctl/model/sql/template/import.go @@ -1,13 +1,12 @@ -package sqltemplate +package template var Imports = ` import ( - {{if .containsCache}}"database/sql" - "fmt"{{end}} + "database/sql" "strings" "time" - "github.com/tal-tech/go-zero/core/stores/cache" + "github.com/tal-tech/go-zero/core/stores/cache" "github.com/tal-tech/go-zero/core/stores/sqlc" "github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/core/stringx" diff --git a/tools/goctl/model/sql/template/insert.go b/tools/goctl/model/sql/template/insert.go index c97a9a8e2..be71eabd8 100644 --- a/tools/goctl/model/sql/template/insert.go +++ b/tools/goctl/model/sql/template/insert.go @@ -1,8 +1,8 @@ -package sqltemplate +package template var Insert = ` -func (m *{{.upperObject}}Model) Insert(data {{.upperObject}}) error { - query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "`(` + " + `{{.lowerObject}}RowsExpectAutoSet` + " + `) value ({{.expression}})` " + ` +func (m *{{.upperStartCamelObject}}Model) Insert(data {{.upperStartCamelObject}}) error { + query := ` + "`" + `insert into ` + "`" + ` + m.table + ` + "`(` + " + `{{.lowerStartCamelObject}}RowsExpectAutoSet` + " + `) value ({{.expression}})` " + ` _, err := m.ExecNoCache(query, {{.expressionValues}}) return err } diff --git a/tools/goctl/model/sql/template/model.go b/tools/goctl/model/sql/template/model.go index 3fbe847f7..4509b2d66 100644 --- a/tools/goctl/model/sql/template/model.go +++ b/tools/goctl/model/sql/template/model.go @@ -1,4 +1,4 @@ -package sqltemplate +package template var Model = `package model {{.imports}} diff --git a/tools/goctl/model/sql/template/new.go b/tools/goctl/model/sql/template/new.go index 70c2b6cf1..acead0141 100644 --- a/tools/goctl/model/sql/template/new.go +++ b/tools/goctl/model/sql/template/new.go @@ -1,8 +1,8 @@ -package sqltemplate +package template var New = ` -func New{{.upperObject}}Model(conn sqlx.SqlConn, c cache.CacheConf, table string) *{{.upperObject}}Model { - return &{{.upperObject}}Model{ +func New{{.upperStartCamelObject}}Model(conn sqlx.SqlConn, c cache.CacheConf, table string) *{{.upperStartCamelObject}}Model { + return &{{.upperStartCamelObject}}Model{ CachedConn: sqlc.NewConn(conn, c), table: table, } diff --git a/tools/goctl/model/sql/template/tag.go b/tools/goctl/model/sql/template/tag.go index 74389ca4f..2491b23fd 100644 --- a/tools/goctl/model/sql/template/tag.go +++ b/tools/goctl/model/sql/template/tag.go @@ -1,3 +1,3 @@ -package sqltemplate +package template var Tag = "`db:\"{{.field}}\"`" diff --git a/tools/goctl/model/sql/template/types.go b/tools/goctl/model/sql/template/types.go index e3f11fe5c..27dd844bc 100644 --- a/tools/goctl/model/sql/template/types.go +++ b/tools/goctl/model/sql/template/types.go @@ -1,13 +1,13 @@ -package sqltemplate +package template var Types = ` type ( - {{.upperObject}}Model struct { + {{.upperStartCamelObject}}Model struct { sqlc.CachedConn table string } - {{.upperObject}} struct { + {{.upperStartCamelObject}} struct { {{.fields}} } ) diff --git a/tools/goctl/model/sql/template/update.go b/tools/goctl/model/sql/template/update.go index f2a3f6c20..27c295a81 100644 --- a/tools/goctl/model/sql/template/update.go +++ b/tools/goctl/model/sql/template/update.go @@ -1,12 +1,12 @@ -package sqltemplate +package template var Update = ` -func (m *{{.upperObject}}Model) Update(data {{.upperObject}}) error { - {{if .containsCache}}{{.primaryCacheKey}} +func (m *{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error { + {{if .withCache}}{{.primaryCacheKey}} _, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) { - query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerObject}}RowsWithPlaceHolder` + " + `" + ` where {{.primarySnakeCase}} = ?` + "`" + ` + query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerStartCamelObject}}RowsWithPlaceHolder` + " + `" + ` where {{.originalPrimaryKey}} = ?` + "`" + ` return conn.Exec(query, {{.expressionValues}}) - }, {{.primaryKeyVariable}}){{else}}query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerObject}}RowsWithPlaceHolder` + " + `" + ` where {{.primarySnakeCase}} = ?` + "`" + ` + }, {{.primaryKeyVariable}}){{else}}query := ` + "`" + `update ` + "` +" + `m.table +` + "` " + `set ` + "` +" + `{{.lowerStartCamelObject}}RowsWithPlaceHolder` + " + `" + ` where {{.originalPrimaryKey}} = ?` + "`" + ` _,err:=m.ExecNoCache(query, {{.expressionValues}}){{end}} return err } diff --git a/tools/goctl/model/sql/template/vars.go b/tools/goctl/model/sql/template/vars.go index 18b06b3bb..e883013cb 100644 --- a/tools/goctl/model/sql/template/vars.go +++ b/tools/goctl/model/sql/template/vars.go @@ -1,14 +1,14 @@ -package sqltemplate +package template var Vars = ` var ( - {{.lowerObject}}FieldNames = builderx.FieldNames(&{{.upperObject}}{}) - {{.lowerObject}}Rows = strings.Join({{.lowerObject}}FieldNames, ",") - {{.lowerObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerObject}}FieldNames, "{{.snakePrimaryKey}}", "create_time", "update_time"), ",") - {{.lowerObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerObject}}FieldNames, "{{.snakePrimaryKey}}", "create_time", "update_time"), "=?,") + "=?" + {{.lowerStartCamelObject}}FieldNames = builderx.FieldNames(&{{.upperStartCamelObject}}{}) + {{.lowerStartCamelObject}}Rows = strings.Join({{.lowerStartCamelObject}}FieldNames, ",") + {{.lowerStartCamelObject}}RowsExpectAutoSet = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, {{.if autoIncrement}}"{{.originalPrimaryKey}}",{{end}} "create_time", "update_time"), ",") + {{.lowerStartCamelObject}}RowsWithPlaceHolder = strings.Join(stringx.Remove({{.lowerStartCamelObject}}FieldNames, "{{.originalPrimaryKey}}", "create_time", "update_time"), "=?,") + "=?" - {{.keysDefine}} + {{.cacheKeys}} - {{if .createNotFound}}ErrNotFound = sqlx.ErrNotFound{{end}} + ErrNotFound = sqlx.ErrNotFound ) ` diff --git a/tools/goctl/model/sql/util/stringurtl_test.go b/tools/goctl/model/sql/util/stringurtl_test.go deleted file mode 100644 index e43f0b51e..000000000 --- a/tools/goctl/model/sql/util/stringurtl_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package util - -import ( - "fmt" - "testing" -) - -func TestFormatField(t *testing.T) { - var in = "go_java" - snakeCase, upperCase, lowerCase := FormatField(in) - fmt.Println(snakeCase, upperCase, lowerCase) -} diff --git a/tools/goctl/model/sql/util/stringutil.go b/tools/goctl/model/sql/util/stringutil.go deleted file mode 100644 index 55b8ae182..000000000 --- a/tools/goctl/model/sql/util/stringutil.go +++ /dev/null @@ -1,48 +0,0 @@ -package util - -import ( - "strings" - "unicode" -) - -func FormatField(field string) (snakeCase, upperCamelCase, lowerCamelCase string) { - snakeCase = field - list := strings.Split(field, "_") - upperCaseList := make([]string, 0) - lowerCaseList := make([]string, 0) - for index, word := range list { - upperStart := convertUpperStart(word) - lowerStart := convertLowerStart(word) - upperCaseList = append(upperCaseList, upperStart) - if index == 0 { - lowerCaseList = append(lowerCaseList, lowerStart) - } else { - lowerCaseList = append(lowerCaseList, upperStart) - } - } - upperCamelCase = strings.Join(upperCaseList, "") - lowerCamelCase = strings.Join(lowerCaseList, "") - return -} - -func convertLowerStart(in string) string { - var resp []rune - for index, r := range in { - if index == 0 { - r = unicode.ToLower(r) - } - resp = append(resp, r) - } - return string(resp) -} - -func convertUpperStart(in string) string { - var resp []rune - for index, r := range in { - if index == 0 { - r = unicode.ToUpper(r) - } - resp = append(resp, r) - } - return string(resp) -} diff --git a/tools/goctl/util/stringx/string.go b/tools/goctl/util/stringx/string.go new file mode 100644 index 000000000..88e46abb5 --- /dev/null +++ b/tools/goctl/util/stringx/string.go @@ -0,0 +1,145 @@ +package stringx + +import ( + "bytes" + "strings" + "unicode" +) + +const ( + emptyString = "" +) + +type ( + String struct { + source string + } +) + +func From(data string) String { + return String{source: data} +} +func (s String) IsEmptyOrSpace() bool { + if len(s.source) == 0 { + return true + } + if strings.TrimSpace(s.source) == "" { + return true + } + return false +} + +func (s String) Lower() string { + if s.IsEmptyOrSpace() { + return s.source + } + return strings.ToLower(s.source) +} +func (s String) Upper() string { + if s.IsEmptyOrSpace() { + return s.source + } + return strings.ToUpper(s.source) +} +func (s String) Title() string { + if s.IsEmptyOrSpace() { + return s.source + } + return strings.Title(s.source) +} + +// snake->camel(upper start) +func (s String) Snake2Camel() string { + if s.IsEmptyOrSpace() { + return s.source + } + list := s.splitBy(func(r rune) bool { + return r == '_' + }, true) + var target []string + for _, item := range list { + target = append(target, From(item).Title()) + } + return strings.Join(target, "") +} + +// camel->snake +func (s String) Camel2Snake() string { + if s.IsEmptyOrSpace() { + return s.source + } + list := s.splitBy(func(r rune) bool { + return unicode.IsUpper(r) + }, false) + var target []string + for _, item := range list { + target = append(target, From(item).Lower()) + } + return strings.Join(target, "_") +} + +// return original string if rune is not letter at index 0 +func (s String) LowerStart() string { + if s.IsEmptyOrSpace() { + return s.source + } + r := rune(s.source[0]) + if !unicode.IsUpper(r) && !unicode.IsLower(r) { + return s.source + } + return string(r) + s.source[1:] +} + +// it will not ignore spaces +func (s String) splitBy(fn func(r rune) bool, remove bool) []string { + if s.IsEmptyOrSpace() { + return nil + } + var list []string + buffer := new(bytes.Buffer) + for _, r := range s.source { + if fn(r) { + if buffer.Len() != 0 { + list = append(list, buffer.String()) + buffer.Reset() + } + if !remove { + buffer.WriteRune(r) + } + continue + } + buffer.WriteRune(r) + } + if buffer.Len() != 0 { + list = append(list, buffer.String()) + } + return list +} + +func (s String) Replace(old, new string) string { + return strings.ReplaceAll(s.source, old, new) +} + +func (s String) ReplaceAll(old, new string) string { + return strings.ReplaceAll(s.source, old, new) +} + +func (s String) Source() string { + return s.source +} + +func Title(s string) string { + if len(s) == 0 { + return s + } + + return strings.ToUpper(s[:1]) + s[1:] +} + +func Untitle(s string) string { + if len(s) == 0 { + return s + } + + return strings.ToLower(s[:1]) + s[1:] +} diff --git a/tools/goctl/util/stringx/string_test.go b/tools/goctl/util/stringx/string_test.go new file mode 100644 index 000000000..b0ad07ff6 --- /dev/null +++ b/tools/goctl/util/stringx/string_test.go @@ -0,0 +1,42 @@ +package stringx + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestString_IsEmptyOrSpace(t *testing.T) { + ret := From(" ").IsEmptyOrSpace() + assert.Equal(t, true, ret) + ret2 := From("ll??").IsEmptyOrSpace() + assert.Equal(t, false, ret2) + ret3 := From(` + `).IsEmptyOrSpace() + assert.Equal(t, true, ret3) +} + +func TestString_Snake2Camel(t *testing.T) { + ret := From("____this_is_snake").Snake2Camel() + assert.Equal(t, "ThisIsSnake", ret) + + ret2 := From("测试_test_Data").Snake2Camel() + assert.Equal(t, "测试TestData", ret2) + + ret3 := From("___").Snake2Camel() + assert.Equal(t, "", ret3) + + ret4 := From("testData_").Snake2Camel() + assert.Equal(t, "TestData", ret4) + + ret5 := From("testDataTestData").Snake2Camel() + assert.Equal(t, "TestDataTestData", ret5) +} + +func TestString_Camel2Snake(t *testing.T) { + ret := From("ThisIsCCCamel").Camel2Snake() + assert.Equal(t, "this_is_c_c_camel", ret) + + ret2 := From("测试Test_Data_test_data").Camel2Snake() + assert.Equal(t, "测试_test__data_test_data", ret2) +} diff --git a/tools/goctl/util/templatex/templatex.go b/tools/goctl/util/templatex/templatex.go new file mode 100644 index 000000000..e8b448db3 --- /dev/null +++ b/tools/goctl/util/templatex/templatex.go @@ -0,0 +1,67 @@ +package templatex + +import ( + "bytes" + goformat "go/format" + "io/ioutil" + "os" + "text/template" +) + +type ( + defaultTemplate struct { + name string + text string + goFmt bool + savePath string + } +) + +func With(name string) *defaultTemplate { + return &defaultTemplate{ + name: name, + } +} +func (t *defaultTemplate) Parse(text string) *defaultTemplate { + t.text = text + return t +} + +func (t *defaultTemplate) GoFmt(format bool) *defaultTemplate { + t.goFmt = format + return t +} + +func (t *defaultTemplate) SaveTo(data interface{}, path string) error { + output, err := t.execute(data) + if err != nil { + return err + } + return ioutil.WriteFile(path, output.Bytes(), os.ModePerm) +} + +func (t *defaultTemplate) Execute(data interface{}) (*bytes.Buffer, error) { + return t.execute(data) +} + +func (t *defaultTemplate) execute(data interface{}) (*bytes.Buffer, error) { + tem, err := template.New(t.name).Parse(t.text) + if err != nil { + return nil, err + } + buf := new(bytes.Buffer) + err = tem.Execute(buf, data) + if err != nil { + return nil, err + } + if !t.goFmt { + return buf, nil + } + formatOutput, err := goformat.Source(buf.Bytes()) + if err != nil { + return nil, err + } + buf.Reset() + buf.Write(formatOutput) + return buf, nil +}