(goctl)feature/model config (#4062)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
kesonan
2024-04-10 23:01:59 +08:00
committed by GitHub
parent 682460c1c8
commit 2a7ada993b
12 changed files with 574 additions and 36 deletions

View File

@@ -5,6 +5,8 @@ import (
"strings"
"github.com/zeromicro/ddl-parser/parser"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
)
var unsignedTypeMap = map[string]string{
@@ -73,6 +75,63 @@ var commonMysqlDataTypeMapInt = map[int]string{
parser.Boolean: "bool",
}
var commonMysqlDataTypeMap = map[int]string{
// number
parser.Bit: "bit",
parser.TinyInt: "tinyint",
parser.SmallInt: "smallint",
parser.MediumInt: "mediumint",
parser.Int: "int",
parser.MiddleInt: "middleint",
parser.Int1: "int1",
parser.Int2: "int2",
parser.Int3: "int3",
parser.Int4: "int4",
parser.Int8: "int8",
parser.Integer: "integer",
parser.BigInt: "bigint",
parser.Float: "float",
parser.Float4: "float4",
parser.Float8: "float8",
parser.Double: "double",
parser.Decimal: "decimal",
parser.Dec: "dec",
parser.Fixed: "fixed",
parser.Numeric: "numeric",
parser.Real: "real",
// date&time
parser.Date: "date",
parser.DateTime: "datetime",
parser.Timestamp: "timestamp",
parser.Time: "time",
parser.Year: "year",
// string
parser.Char: "char",
parser.VarChar: "varchar",
parser.NVarChar: "nvarchar",
parser.NChar: "nchar",
parser.Character: "character",
parser.LongVarChar: "longvarchar",
parser.LineString: "linestring",
parser.MultiLineString: "multilinestring",
parser.Binary: "binary",
parser.VarBinary: "varbinary",
parser.TinyText: "tinytext",
parser.Text: "text",
parser.MediumText: "mediumtext",
parser.LongText: "longtext",
parser.Enum: "enum",
parser.Set: "set",
parser.Json: "json",
parser.Blob: "blob",
parser.LongBlob: "longblob",
parser.MediumBlob: "mediumblob",
parser.TinyBlob: "tinyblob",
// bool
parser.Bool: "bool",
parser.Boolean: "boolean",
}
var commonMysqlDataTypeMapString = map[string]string{
// For consistency, all integer types are converted to int64
// bool
@@ -144,28 +203,79 @@ var commonMysqlDataTypeMapString = map[string]string{
}
// ConvertDataType converts mysql column type into golang type
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, error) {
tp, ok := commonMysqlDataTypeMapInt[dataBaseType]
if !ok {
return "", fmt.Errorf("unsupported database type: %v", dataBaseType)
func ConvertDataType(dataBaseType int, isDefaultNull, unsigned, strict bool) (string, string, error) {
if env.UseExperimental() {
tp, ok := commonMysqlDataTypeMap[dataBaseType]
if !ok {
return "", "", fmt.Errorf("unsupported database type: %v", dataBaseType)
}
goType, thirdPkg, _, err := ConvertStringDataType(tp, isDefaultNull, unsigned, strict)
return goType, thirdPkg, err
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), nil
// the following are the old version compatibility code.
tp, ok := commonMysqlDataTypeMapInt[dataBaseType]
if !ok {
return "", "", fmt.Errorf("unsupported database type: %v", dataBaseType)
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), "", nil
}
// ConvertStringDataType converts mysql column type into golang type
func ConvertStringDataType(dataBaseType string, isDefaultNull, unsigned, strict bool) (
goType string, isPQArray bool, err error) {
goType string, thirdPkg string, isPQArray bool, err error) {
if env.UseExperimental() {
customTp, thirdImport := convertDatatypeWithConfig(dataBaseType, isDefaultNull, unsigned)
if len(customTp) != 0 {
return customTp, thirdImport, false, nil
}
tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)]
if !ok {
return "", "", false, fmt.Errorf("unsupported database type: %s", dataBaseType)
}
if strings.HasPrefix(dataBaseType, "_") {
return tp, "", true, nil
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), "", false, nil
}
// the following are the old version compatibility code.
tp, ok := commonMysqlDataTypeMapString[strings.ToLower(dataBaseType)]
if !ok {
return "", false, fmt.Errorf("unsupported database type: %s", dataBaseType)
return "", "", false, fmt.Errorf("unsupported database type: %s", dataBaseType)
}
if strings.HasPrefix(dataBaseType, "_") {
return tp, true, nil
return tp, "", true, nil
}
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), false, nil
return mayConvertNullType(tp, isDefaultNull, unsigned, strict), "", false, nil
}
func convertDatatypeWithConfig(dataBaseType string, isDefaultNull, unsigned bool) (string, string) {
if config.ExternalConfig == nil {
return "", ""
}
opt, ok := config.ExternalConfig.Model.TypesMap[strings.ToLower(dataBaseType)]
if !ok || (len(opt.Type) == 0 && len(opt.UnsignedType) == 0 && len(opt.NullType) == 0) {
return "", ""
}
if isDefaultNull {
if len(opt.NullType) != 0 {
return opt.NullType, opt.Pkg
}
} else if unsigned {
if len(opt.UnsignedType) != 0 {
return opt.UnsignedType, opt.Pkg
}
}
return opt.Type, opt.Pkg
}
func mayConvertNullType(goDataType string, isDefaultNull, unsigned, strict bool) string {

View File

@@ -8,23 +8,102 @@ import (
)
func TestConvertDataType(t *testing.T) {
v, err := ConvertDataType(parser.TinyInt, false, false, true)
v, _, err := ConvertDataType(parser.TinyInt, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "int64", v)
v, err = ConvertDataType(parser.TinyInt, false, true, true)
v, _, err = ConvertDataType(parser.TinyInt, false, true, true)
assert.Nil(t, err)
assert.Equal(t, "uint64", v)
v, err = ConvertDataType(parser.TinyInt, true, false, true)
v, _, err = ConvertDataType(parser.TinyInt, true, false, true)
assert.Nil(t, err)
assert.Equal(t, "sql.NullInt64", v)
v, err = ConvertDataType(parser.Timestamp, false, false, true)
v, _, err = ConvertDataType(parser.Timestamp, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "time.Time", v)
v, err = ConvertDataType(parser.Timestamp, true, false, true)
v, _, err = ConvertDataType(parser.Timestamp, true, false, true)
assert.Nil(t, err)
assert.Equal(t, "sql.NullTime", v)
v, _, err = ConvertDataType(parser.Decimal, false, false, true)
assert.Nil(t, err)
assert.Equal(t, "float64", v)
}
func TestConvertStringDataType(t *testing.T) {
type (
input struct {
dataType string
isDefaultNull bool
unsigned bool
strict bool
}
result struct {
goType string
thirdPkg string
isPQArray bool
}
)
var testData = []struct {
input input
want result
}{
{
input: input{
dataType: "bigint",
isDefaultNull: false,
unsigned: false,
strict: false,
},
want: result{
goType: "int64",
},
},
{
input: input{
dataType: "bigint",
isDefaultNull: true,
unsigned: false,
strict: false,
},
want: result{
goType: "sql.NullInt64",
},
},
{
input: input{
dataType: "bigint",
isDefaultNull: false,
unsigned: true,
strict: false,
},
want: result{
goType: "uint64",
},
},
{
input: input{
dataType: "_int2",
isDefaultNull: false,
unsigned: false,
strict: false,
},
want: result{
goType: "pq.Int64Array",
isPQArray: true,
},
},
}
for _, data := range testData {
tp, thirdPkg, isPQArray, err := ConvertStringDataType(data.input.dataType, data.input.isDefaultNull, data.input.unsigned, data.input.strict)
assert.NoError(t, err)
assert.Equal(t, data.want, result{
goType: tp,
thirdPkg: thirdPkg,
isPQArray: isPQArray,
})
}
}

View File

@@ -9,7 +9,7 @@ CREATE TABLE `user`
`mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',
`gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公\r开',
`nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',
`type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT '用户类型',
`type` tinyint(1) COLLATE utf8mb4_general_ci DEFAULT 0 COMMENT '用户类型',
`create_time` timestamp NULL,
`update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
@@ -22,12 +22,13 @@ CREATE TABLE `user`
CREATE TABLE `student`
(
`type` bigint NOT NULL,
`class` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
`name` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
`age` tinyint DEFAULT NULL,
`type` bigint NOT NULL,
`class` varchar(255) NOT NULL DEFAULT '',
`name` varchar(255) NOT NULL DEFAULT '',
`age` tinyint DEFAULT NULL,
`score` float(10, 0
) DEFAULT NULL,
`amount` decimal DEFAULT NULL,
`create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
`update_time` timestamp NULL DEFAULT NULL,
`delete_time` timestamp NULL DEFAULT NULL ON UPDATE CURRENT_TIMESTAMP,

View File

@@ -1,12 +1,27 @@
package gen
import (
"fmt"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/template"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
func genImports(table Table, withCache, timeImport bool) (string, error) {
var thirdImports []string
var m = map[string]struct{}{}
for _, c := range table.Fields {
if len(c.ThirdPkg) > 0 {
if _, ok := m[c.ThirdPkg]; ok {
continue
}
m[c.ThirdPkg] = struct{}{}
thirdImports = append(thirdImports, fmt.Sprintf("%q", c.ThirdPkg))
}
}
if withCache {
text, err := pathx.LoadTemplate(category, importsTemplateFile, template.Imports)
if err != nil {
@@ -17,6 +32,7 @@ func genImports(table Table, withCache, timeImport bool) (string, error) {
"time": timeImport,
"containsPQ": table.ContainsPQ,
"data": table,
"third": strings.Join(thirdImports, "\n"),
})
if err != nil {
return "", err
@@ -34,6 +50,7 @@ func genImports(table Table, withCache, timeImport bool) (string, error) {
"time": timeImport,
"containsPQ": table.ContainsPQ,
"data": table,
"third": strings.Join(thirdImports, "\n"),
})
if err != nil {
return "", err

View File

@@ -38,6 +38,7 @@ type (
Field struct {
NameOriginal string
Name stringx.String
ThirdPkg string
DataType string
Comment string
SeqInIndex int
@@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string, strict bool)
}
}
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
dataType, thirdPkg, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
if err != nil {
return Primary{}, nil, err
}
@@ -236,6 +237,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string, strict bool)
var field Field
field.Name = stringx.From(column.Name)
field.ThirdPkg = thirdPkg
field.DataType = dataType
field.Comment = util.TrimNewLine(comment)
@@ -267,7 +269,7 @@ func (t *Table) ContainsTime() bool {
func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
primaryDataType, containsPQ, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
primaryDataType, thirdPkg, containsPQ, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}
@@ -285,6 +287,7 @@ func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
reply.PrimaryKey = Primary{
Field: Field{
Name: stringx.From(table.PrimaryKey.Name),
ThirdPkg: thirdPkg,
DataType: primaryDataType,
Comment: table.PrimaryKey.Comment,
SeqInIndex: seqInIndex,
@@ -351,7 +354,7 @@ func getTableFields(table *model.Table, strict bool) (map[string]*Field, error)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
dt, containsPQ, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
dt, thirdPkg, containsPQ, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}
@@ -363,6 +366,7 @@ func getTableFields(table *model.Table, strict bool) (map[string]*Field, error)
field := &Field{
NameOriginal: each.Name,
Name: stringx.From(each.Name),
ThirdPkg: thirdPkg,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,

View File

@@ -9,4 +9,6 @@ import (
"github.com/zeromicro/go-zero/core/stores/builder"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/stringx"
{{.third}}
)

View File

@@ -11,4 +11,6 @@ import (
"github.com/zeromicro/go-zero/core/stores/sqlc"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/stringx"
{{.third}}
)