diff --git a/core/stores/sqlx/orm.go b/core/stores/sqlx/orm.go index e03104cfa..5b308581d 100644 --- a/core/stores/sqlx/orm.go +++ b/core/stores/sqlx/orm.go @@ -72,7 +72,7 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) { func getValueInterface(value reflect.Value) (any, error) { switch value.Kind() { case reflect.Ptr: - if !value.CanInterface() { + if !value.CanAddr() || !value.Addr().CanInterface() { return nil, ErrNotReadableValue } @@ -81,7 +81,7 @@ func getValueInterface(value reflect.Value) (any, error) { value.Set(reflect.New(baseValueType)) } - return value.Interface(), nil + return value.Addr().Interface(), nil default: if !value.CanAddr() || !value.Addr().CanInterface() { return nil, ErrNotReadableValue diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index c35465473..c3e9e46ea 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "testing" + "time" "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" @@ -1575,6 +1576,643 @@ func TestAnonymousStructPrError(t *testing.T) { }) } +func TestUnmarshalRowsZeroValueStructPtr(t *testing.T) { + secondNamePtr := "second_ptr" + secondAgePtr := int64(30) + thirdNamePtr := "third_ptr" + thirdAgePtr := int64(0) + + expect := []struct { + Name string + NamePtr *string + Age int64 + AgePtr *int64 + }{ + { + Name: "first", + NamePtr: nil, + Age: 2, + AgePtr: nil, + }, + { + Name: "second", + NamePtr: &secondNamePtr, + Age: 3, + AgePtr: &secondAgePtr, + }, + { + Name: "", + NamePtr: &thirdNamePtr, + Age: 0, + AgePtr: &thirdAgePtr, + }, + } + + var value []struct { + Age int64 `db:"age"` + AgePtr *int64 `db:"age_ptr"` + Name string `db:"name"` + NamePtr *string `db:"name_ptr"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "name_ptr", "age", "age_ptr"}). + AddRow("first", nil, 2, nil). + AddRow("second", "second_ptr", 3, 30). + AddRow("", "third_ptr", 0, 0) + + mock.ExpectQuery("select (.+) from users where user=?"). + WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, name_ptr, age, age_ptr from users where user=?", "anyone")) + + assert.Equal(t, 3, len(value), "应该返回3行数据") + + for i, each := range expect { + + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + + if each.NamePtr == nil { + assert.Nil(t, value[i].NamePtr) + } else { + assert.NotNil(t, value[i].NamePtr) + assert.Equal(t, *each.NamePtr, *value[i].NamePtr) + } + + if each.AgePtr == nil { + assert.Nil(t, value[i].AgePtr) + } else { + assert.NotNil(t, value[i].AgePtr) + assert.Equal(t, *each.AgePtr, *value[i].AgePtr) + } + } + }) +} + +func TestUnmarshalRowsAllNullStructPtrFields(t *testing.T) { + expect := []struct { + NamePtr *string + AgePtr *int64 + }{ + { + NamePtr: nil, + AgePtr: nil, + }, + { + NamePtr: stringPtr("second"), + AgePtr: int64Ptr(30), + }, + { + NamePtr: nil, + AgePtr: nil, + }, + } + + var value []struct { + AgePtr *int64 `db:"age_ptr"` + NamePtr *string `db:"name_ptr"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name_ptr", "age_ptr"}). + AddRow(nil, nil). + AddRow("second", 30). + AddRow(nil, nil) + + mock.ExpectQuery("select (.+) from users where user=?"). + WithArgs("anyone").WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name_ptr, age_ptr from users where user=?", "anyone")) + + assert.Equal(t, 3, len(value)) + + for i, each := range expect { + if each.NamePtr == nil { + assert.Nil(t, value[i].NamePtr) + } else { + assert.NotNil(t, value[i].NamePtr) + assert.Equal(t, *each.NamePtr, *value[i].NamePtr) + } + + if each.AgePtr == nil { + assert.Nil(t, value[i].AgePtr) + } else { + assert.NotNil(t, value[i].AgePtr) + assert.Equal(t, *each.AgePtr, *value[i].AgePtr) + } + } + }) +} + +func TestUnmarshalRowsWithSqlNullTypes(t *testing.T) { + expect := []struct { + Name string + NullName sql.NullString + Age int64 + NullAge sql.NullInt64 + Score float64 + NullScore sql.NullFloat64 + Active bool + NullActive sql.NullBool + }{ + { + Name: "first", + NullName: sql.NullString{ + String: "", + Valid: false, + }, + Age: 20, + NullAge: sql.NullInt64{ + Int64: 0, + Valid: false, + }, + Score: 85.5, + NullScore: sql.NullFloat64{ + Float64: 0, + Valid: false, + }, + Active: true, + NullActive: sql.NullBool{ + Bool: false, + Valid: false, + }, + }, + { + Name: "second", + NullName: sql.NullString{ + String: "not_null_name", + Valid: true, + }, + Age: 25, + NullAge: sql.NullInt64{ + Int64: 30, + Valid: true, + }, + Score: 90.0, + NullScore: sql.NullFloat64{ + Float64: 95.5, + Valid: true, + }, + Active: false, + NullActive: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + { + Name: "third", + NullName: sql.NullString{ + String: "", + Valid: false, + }, + Age: 0, + NullAge: sql.NullInt64{ + Int64: 0, + Valid: false, + }, + Score: 0, + NullScore: sql.NullFloat64{ + Float64: 0, + Valid: false, + }, + Active: false, + NullActive: sql.NullBool{ + Bool: false, + Valid: false, + }, + }, + } + + var value []struct { + Name string `db:"name"` + NullName sql.NullString `db:"null_name"` + Age int64 `db:"age"` + NullAge sql.NullInt64 `db:"null_age"` + Score float64 `db:"score"` + NullScore sql.NullFloat64 `db:"null_score"` + Active bool `db:"active"` + NullActive sql.NullBool `db:"null_active"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{ + "name", "null_name", "age", "null_age", "score", "null_score", "active", "null_active", + }). + AddRow("first", nil, 20, nil, 85.5, nil, true, nil). + AddRow("second", "not_null_name", 25, 30, 90.0, 95.5, false, true). + AddRow("third", nil, 0, nil, 0, nil, false, nil) + + mock.ExpectQuery("select (.+) from users where type=?"). + WithArgs("test").WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, null_name, age, null_age, score, null_score, active, null_active from users where type=?", "test")) + + assert.Equal(t, 3, len(value)) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + assert.Equal(t, each.Score, value[i].Score) + assert.Equal(t, each.Active, value[i].Active) + + assert.Equal(t, each.NullName.Valid, value[i].NullName.Valid) + if each.NullName.Valid { + assert.Equal(t, each.NullName.String, value[i].NullName.String) + } + + assert.Equal(t, each.NullAge.Valid, value[i].NullAge.Valid) + if each.NullAge.Valid { + assert.Equal(t, each.NullAge.Int64, value[i].NullAge.Int64) + } + + assert.Equal(t, each.NullScore.Valid, value[i].NullScore.Valid) + if each.NullScore.Valid { + assert.Equal(t, each.NullScore.Float64, value[i].NullScore.Float64) + } + + assert.Equal(t, each.NullActive.Valid, value[i].NullActive.Valid) + if each.NullActive.Valid { + assert.Equal(t, each.NullActive.Bool, value[i].NullActive.Bool) + } + } + }) +} + +func TestUnmarshalRowsSqlNullWithMixedData(t *testing.T) { + expect := []struct { + Name string + NullName sql.NullString + Age int64 + NullAge sql.NullInt64 + IsStudent bool + NullActive sql.NullBool + }{ + { + Name: "student1", + NullName: sql.NullString{ + String: "", + Valid: false, + }, + Age: 18, + NullAge: sql.NullInt64{ + Int64: 0, + Valid: false, + }, + IsStudent: true, + NullActive: sql.NullBool{ + Bool: false, + Valid: false, + }, + }, + { + Name: "student2", + NullName: sql.NullString{ + String: "has_nickname", + Valid: true, + }, + Age: 20, + NullAge: sql.NullInt64{ + Int64: 22, + Valid: true, + }, + IsStudent: false, + NullActive: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + } + + var value []struct { + Name string `db:"name"` + NullName sql.NullString `db:"null_name"` + Age int64 `db:"age"` + NullAge sql.NullInt64 `db:"null_age"` + IsStudent bool `db:"is_student"` + NullActive sql.NullBool `db:"null_active"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "null_name", "age", "null_age", "is_student", "null_active"}). + AddRow("student1", nil, 18, nil, true, nil). + AddRow("student2", "has_nickname", 20, 22, false, true) + + mock.ExpectQuery("select (.+) from students where class=?"). + WithArgs("A").WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, null_name, age, null_age, is_student, null_active from students where class=?", "A")) + + assert.Equal(t, 2, len(value)) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + assert.Equal(t, each.Age, value[i].Age) + assert.Equal(t, each.IsStudent, value[i].IsStudent) + + assert.Equal(t, each.NullName.Valid, value[i].NullName.Valid) + if each.NullName.Valid { + assert.Equal(t, each.NullName.String, value[i].NullName.String) + } + + assert.Equal(t, each.NullAge.Valid, value[i].NullAge.Valid) + if each.NullAge.Valid { + assert.Equal(t, each.NullAge.Int64, value[i].NullAge.Int64) + } + + assert.Equal(t, each.NullActive.Valid, value[i].NullActive.Valid) + if each.NullActive.Valid { + assert.Equal(t, each.NullActive.Bool, value[i].NullActive.Bool) + } + } + }) +} + +func TestUnmarshalRowsSqlNullTime(t *testing.T) { + now := time.Now() + futureTime := now.AddDate(1, 0, 0) + + expect := []struct { + Name string + BirthDate sql.NullTime + LastLogin sql.NullTime + }{ + { + Name: "user1", + BirthDate: sql.NullTime{ + Time: time.Time{}, + Valid: false, + }, + LastLogin: sql.NullTime{ + Time: now, + Valid: true, + }, + }, + { + Name: "user2", + BirthDate: sql.NullTime{ + Time: futureTime, + Valid: true, + }, + LastLogin: sql.NullTime{ + Time: time.Time{}, + Valid: false, + }, + }, + } + + var value []struct { + Name string `db:"name"` + BirthDate sql.NullTime `db:"birth_date"` + LastLogin sql.NullTime `db:"last_login"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "birth_date", "last_login"}). + AddRow("user1", nil, now). + AddRow("user2", futureTime, nil) + + mock.ExpectQuery("select (.+) from users"). + WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, birth_date, last_login from users")) + + assert.Equal(t, 2, len(value)) + + for i, each := range expect { + assert.Equal(t, each.Name, value[i].Name) + + assert.Equal(t, each.BirthDate.Valid, value[i].BirthDate.Valid) + if each.BirthDate.Valid { + assert.WithinDuration(t, each.BirthDate.Time, value[i].BirthDate.Time, time.Second) + } + + assert.Equal(t, each.LastLogin.Valid, value[i].LastLogin.Valid) + if each.LastLogin.Valid { + assert.WithinDuration(t, each.LastLogin.Time, value[i].LastLogin.Time, time.Second) + } + } + }) +} + +func TestUnmarshalRowsSqlNullWithEmptyValues(t *testing.T) { + expect := []struct { + Name string + NullString sql.NullString + NullInt sql.NullInt64 + NullFloat sql.NullFloat64 + NullBool sql.NullBool + }{ + { + Name: "empty_values", + NullString: sql.NullString{ + String: "", + Valid: true, + }, + NullInt: sql.NullInt64{ + Int64: 0, + Valid: true, + }, + NullFloat: sql.NullFloat64{ + Float64: 0.0, + Valid: true, + }, + NullBool: sql.NullBool{ + Bool: false, + Valid: true, + }, + }, + { + Name: "null_values", + NullString: sql.NullString{ + String: "", + Valid: false, + }, + NullInt: sql.NullInt64{ + Int64: 0, + Valid: false, + }, + NullFloat: sql.NullFloat64{ + Float64: 0.0, + Valid: false, + }, + NullBool: sql.NullBool{ + Bool: false, + Valid: false, + }, + }, + { + Name: "mixed_values", + NullString: sql.NullString{ + String: "actual_value", + Valid: true, + }, + NullInt: sql.NullInt64{ + Int64: 0, + Valid: true, + }, + NullFloat: sql.NullFloat64{ + Float64: 0.0, + Valid: false, + }, + NullBool: sql.NullBool{ + Bool: true, + Valid: true, + }, + }, + } + + var value []struct { + Name string `db:"name"` + NullString sql.NullString `db:"null_string"` + NullInt sql.NullInt64 `db:"null_int"` + NullFloat sql.NullFloat64 `db:"null_float"` + NullBool sql.NullBool `db:"null_bool"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "null_string", "null_int", "null_float", "null_bool"}). + AddRow("empty_values", "", 0, 0.0, false). + AddRow("null_values", nil, nil, nil, nil). + AddRow("mixed_values", "actual_value", 0, nil, true) + + mock.ExpectQuery("select (.+) from test_table"). + WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, null_string, null_int, null_float, null_bool from test_table")) + + assert.Equal(t, 3, len(value)) + + for i, each := range expect { + + assert.Equal(t, each.Name, value[i].Name) + + assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid) + if each.NullString.Valid { + assert.Equal(t, each.NullString.String, value[i].NullString.String) + } else { + assert.Equal(t, "", value[i].NullString.String) + } + + assert.Equal(t, each.NullInt.Valid, value[i].NullInt.Valid) + if each.NullInt.Valid { + assert.Equal(t, each.NullInt.Int64, value[i].NullInt.Int64) + } else { + assert.Equal(t, int64(0), value[i].NullInt.Int64) + } + + assert.Equal(t, each.NullFloat.Valid, value[i].NullFloat.Valid) + if each.NullFloat.Valid { + assert.Equal(t, each.NullFloat.Float64, value[i].NullFloat.Float64) + } else { + assert.Equal(t, 0.0, value[i].NullFloat.Float64) + } + + assert.Equal(t, each.NullBool.Valid, value[i].NullBool.Valid) + if each.NullBool.Valid { + assert.Equal(t, each.NullBool.Bool, value[i].NullBool.Bool) + } else { + assert.Equal(t, false, value[i].NullBool.Bool) + } + } + }) +} + +func TestUnmarshalRowsSqlNullStringEmptyVsNull(t *testing.T) { + expect := []struct { + Name string + EmptyString sql.NullString + NullString sql.NullString + NormalString sql.NullString + }{ + { + Name: "row1", + EmptyString: sql.NullString{ + String: "", + Valid: true, + }, + NullString: sql.NullString{ + String: "", + Valid: false, + }, + NormalString: sql.NullString{ + String: "hello", + Valid: true, + }, + }, + { + Name: "row2", + EmptyString: sql.NullString{ + String: " ", + Valid: true, + }, + NullString: sql.NullString{ + String: "", + Valid: false, + }, + NormalString: sql.NullString{ + String: "", + Valid: true, + }, + }, + } + + var value []struct { + Name string `db:"name"` + EmptyString sql.NullString `db:"empty_string"` + NullString sql.NullString `db:"null_string"` + NormalString sql.NullString `db:"normal_string"` + } + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + rs := sqlmock.NewRows([]string{"name", "empty_string", "null_string", "normal_string"}). + AddRow("row1", "", nil, "hello"). + AddRow("row2", " ", nil, "") + + mock.ExpectQuery("select (.+) from string_test"). + WillReturnRows(rs) + + assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { + return unmarshalRows(&value, rows, true) + }, "select name, empty_string, null_string, normal_string from string_test")) + + assert.Equal(t, 2, len(value)) + + for i, each := range expect { + assert.True(t, value[i].EmptyString.Valid) + assert.Equal(t, each.EmptyString.String, value[i].EmptyString.String) + + assert.False(t, value[i].NullString.Valid) + assert.Equal(t, "", value[i].NullString.String) + + assert.Equal(t, each.NormalString.Valid, value[i].NormalString.Valid) + if each.NormalString.Valid { + assert.Equal(t, each.NormalString.String, value[i].NormalString.String) + } + } + }) +} + +func stringPtr(s string) *string { + return &s +} + +func int64Ptr(i int64) *int64 { + return &i +} + func BenchmarkIgnore(b *testing.B) { db, mock, err := sqlmock.New() if err != nil {