chore: refactor and coding style (#4144)

This commit is contained in:
Kevin Wan
2024-05-11 23:06:59 +08:00
committed by GitHub
parent 040fee5669
commit f10084a3f5
3 changed files with 104 additions and 75 deletions

View File

@@ -113,7 +113,29 @@ func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName) return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
} }
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error { func (u *Unmarshaler) fillJsonUnmarshalerStruct(fieldType reflect.Type,
value reflect.Value, targetValue string) error {
if !value.CanSet() {
return errValueNotSettable
}
baseType := Deref(fieldType)
target := reflect.New(baseType)
unmarshaler, ok := target.Interface().(json.Unmarshaler)
if !ok {
return errUnsupportedType
}
if err := unmarshaler.UnmarshalJSON([]byte(targetValue)); err != nil {
return err
}
value.Set(target)
return nil
}
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value,
mapValue any, fullName string) error {
if !value.CanSet() { if !value.CanSet() {
return errValueNotSettable return errValueNotSettable
} }
@@ -154,7 +176,8 @@ func (u *Unmarshaler) fillMapFromString(value reflect.Value, mapValue any) error
return nil return nil
} }
func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error { func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
mapValue any, fullName string) error {
if !value.CanSet() { if !value.CanSet() {
return errValueNotSettable return errValueNotSettable
} }
@@ -218,8 +241,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
} }
func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value,
mapValue any, fullName string, mapValue any, fullName string) error {
) error {
var slice []any var slice []any
switch v := mapValue.(type) { switch v := mapValue.(type) {
case fmt.Stringer: case fmt.Stringer:
@@ -249,8 +271,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
} }
func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
baseKind reflect.Kind, value any, fullName string, baseKind reflect.Kind, value any, fullName string) error {
) error {
if value == nil { if value == nil {
return errNilSliceElement return errNilSliceElement
} }
@@ -288,8 +309,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
} }
func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value, func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
defaultValue, fullName string, defaultValue, fullName string) error {
) error {
baseFieldType := Deref(derefedType.Elem()) baseFieldType := Deref(derefedType.Elem())
baseFieldKind := baseFieldType.Kind() baseFieldKind := baseFieldType.Kind()
defaultCacheLock.Lock() defaultCacheLock.Lock()
@@ -310,7 +330,8 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle
return u.fillSlice(derefedType, value, slice, fullName) return u.fillSlice(derefedType, value, slice, fullName)
} }
func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any, fullName string) (reflect.Value, error) { func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any,
fullName string) (reflect.Value, error) {
mapType := reflect.MapOf(keyType, elemType) mapType := reflect.MapOf(keyType, elemType)
valueType := reflect.TypeOf(mapValue) valueType := reflect.TypeOf(mapValue)
if mapType == valueType { if mapType == valueType {
@@ -403,8 +424,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any,
} }
func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Valuer, fullName string) ( func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Valuer, fullName string) (
string, *fieldOptionsWithContext, error, string, *fieldOptionsWithContext, error) {
) {
key, options, err := parseKeyAndOptions(u.key, field) key, options, err := parseKeyAndOptions(u.key, field)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@@ -445,8 +465,7 @@ func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Value
} }
func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string, m valuerWithParent, fullName string) error {
) error {
key, options, err := u.parseOptionsWithContext(field, m, fullName) key, options, err := u.parseOptionsWithContext(field, m, fullName)
if err != nil { if err != nil {
return err return err
@@ -464,8 +483,7 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
} }
func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
key string, m valuerWithParent, fullName string, key string, m valuerWithParent, fullName string) error {
) error {
derefedFieldType := Deref(field.Type) derefedFieldType := Deref(field.Type)
switch derefedFieldType.Kind() { switch derefedFieldType.Kind() {
@@ -477,8 +495,7 @@ func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, v
} }
func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string, m valuerWithParent, fullName string) error {
) error {
fieldType := field.Type fieldType := field.Type
maybeNewValue(fieldType, value) maybeNewValue(fieldType, value)
derefedFieldType := Deref(fieldType) derefedFieldType := Deref(fieldType)
@@ -502,8 +519,7 @@ func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, v
} }
func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type, func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type,
value reflect.Value, key string, m valuerWithParent, fullName string, value reflect.Value, key string, m valuerWithParent, fullName string) error {
) error {
var filled bool var filled bool
var required int var required int
var requiredFilled int var requiredFilled int
@@ -544,8 +560,7 @@ func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type
} }
func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string, m valuerWithParent, fullName string) error {
) error {
if usingDifferentKeys(u.key, field) { if usingDifferentKeys(u.key, field) {
return nil return nil
} }
@@ -558,16 +573,13 @@ func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Valu
} }
func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value reflect.Value,
vp valueWithParent, opts *fieldOptionsWithContext, fullName string, vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
) error {
derefedFieldType := Deref(fieldType) derefedFieldType := Deref(fieldType)
typeKind := derefedFieldType.Kind() typeKind := derefedFieldType.Kind()
mapValue := vp.value mapValue := vp.value
valueKind := reflect.TypeOf(mapValue).Kind() valueKind := reflect.TypeOf(mapValue).Kind()
switch { switch {
case valueKind == reflect.String && typeKind == reflect.Struct && fieldType.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()):
return u.fillCustomUnmarshalerStruct(fieldType, value, mapValue)
case valueKind == reflect.Map && typeKind == reflect.Struct: case valueKind == reflect.Map && typeKind == reflect.Struct:
mv, ok := mapValue.(map[string]any) mv, ok := mapValue.(map[string]any)
if !ok { if !ok {
@@ -588,29 +600,15 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
return u.fillSliceFromString(fieldType, value, mapValue, fullName) return u.fillSliceFromString(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && derefedFieldType == durationType: case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType, value, mapValue.(string)) return fillDurationValue(fieldType, value, mapValue.(string))
case valueKind == reflect.String && typeKind == reflect.Struct && implementsJsonUnmarshaler(fieldType):
return u.fillJsonUnmarshalerStruct(fieldType, value, mapValue.(string))
default: default:
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName) return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
} }
} }
func (u *Unmarshaler) fillCustomUnmarshalerStruct(fieldType reflect.Type, value reflect.Value, mapValue any) error {
if !value.CanSet() {
return errValueNotSettable
}
baseType := Deref(fieldType)
target := reflect.New(baseType)
params := make([]reflect.Value, 1)
params[0] = reflect.ValueOf([]byte(mapValue.(string)))
target.MethodByName("UnmarshalJSON").Call(params)
value.Set(target)
return nil
}
func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflect.Value,
mapValue any, opts *fieldOptionsWithContext, fullName string, mapValue any, opts *fieldOptionsWithContext, fullName string) error {
) error {
typeKind := Deref(fieldType).Kind() typeKind := Deref(fieldType).Kind()
valueKind := reflect.TypeOf(mapValue).Kind() valueKind := reflect.TypeOf(mapValue).Kind()
@@ -631,8 +629,7 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
} }
func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value,
v json.Number, opts *fieldOptionsWithContext, fullName string, v json.Number, opts *fieldOptionsWithContext, fullName string) error {
) error {
baseType := Deref(fieldType) baseType := Deref(fieldType)
typeKind := baseType.Kind() typeKind := baseType.Kind()
@@ -685,8 +682,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
} }
func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.Value,
m valuerWithParent, fullName string, m valuerWithParent, fullName string) error {
) error {
if fieldType.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Ptr {
baseType := Deref(fieldType) baseType := Deref(fieldType)
target := reflect.New(baseType).Elem() target := reflect.New(baseType).Elem()
@@ -703,8 +699,7 @@ func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.V
} }
func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value reflect.Value,
mapValue any, mapValue any) (bool, error) {
) (bool, error) {
var tval encoding.TextUnmarshaler var tval encoding.TextUnmarshaler
var ok bool var ok bool
@@ -732,8 +727,7 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value
} }
func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value reflect.Value,
envVal string, opts *fieldOptionsWithContext, fullName string, envVal string, opts *fieldOptionsWithContext, fullName string) error {
) error {
if err := validateValueInOptions(envVal, opts.options()); err != nil { if err := validateValueInOptions(envVal, opts.options()); err != nil {
return err return err
} }
@@ -763,8 +757,7 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
} }
func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value,
m valuerWithParent, fullName string, m valuerWithParent, fullName string) error {
) error {
if !field.IsExported() { if !field.IsExported() {
return nil return nil
} }
@@ -811,8 +804,7 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
} }
func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value reflect.Value,
vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string, vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string) error {
) error {
mapValue := vp.value mapValue := vp.value
if mapValue == nil { if mapValue == nil {
if opts.optional() { if opts.optional() {
@@ -847,8 +839,7 @@ func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value r
} }
func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value,
mapValue any, key string, opts *fieldOptionsWithContext, fullName string, mapValue any, key string, opts *fieldOptionsWithContext, fullName string) error {
) error {
valueKind := reflect.TypeOf(mapValue).Kind() valueKind := reflect.TypeOf(mapValue).Kind()
if valueKind != reflect.String { if valueKind != reflect.String {
return fmt.Errorf("the value in map is not string, but %s", valueKind) return fmt.Errorf("the value in map is not string, but %s", valueKind)
@@ -877,8 +868,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
} }
func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value,
opts *fieldOptionsWithContext, fullName string, opts *fieldOptionsWithContext, fullName string) error {
) error {
derefedType := Deref(fieldType) derefedType := Deref(fieldType)
fieldKind := derefedType.Kind() fieldKind := derefedType.Kind()
if defaultValue, ok := opts.getDefault(); ok { if defaultValue, ok := opts.getDefault(); ok {
@@ -1020,8 +1010,7 @@ func fillDurationValue(fieldType reflect.Type, value reflect.Value, dur string)
} }
func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any, func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any,
opts *fieldOptionsWithContext, fullName string, opts *fieldOptionsWithContext, fullName string) error {
) error {
if !value.CanSet() { if !value.CanSet() {
return errValueNotSettable return errValueNotSettable
} }
@@ -1050,8 +1039,7 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any,
} }
func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any, func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any,
opts *fieldOptionsWithContext, opts *fieldOptionsWithContext) error {
) error {
if !value.CanSet() { if !value.CanSet() {
return errValueNotSettable return errValueNotSettable
} }
@@ -1099,6 +1087,10 @@ func getValueWithChainedKeys(m valuerWithParent, keys []string) (any, bool) {
} }
} }
func implementsJsonUnmarshaler(t reflect.Type) bool {
return t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem())
}
func join(elem ...string) string { func join(elem ...string) string {
var builder strings.Builder var builder strings.Builder

View File

@@ -2,6 +2,7 @@ package mapping
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
@@ -5760,6 +5761,25 @@ func TestUnmarshalWithIgnoreFields(t *testing.T) {
} }
} }
func TestUnmarshal_JsonUnmarshaler(t *testing.T) {
t.Run("success", func(t *testing.T) {
v := struct {
Foo *mockUnmarshaler `json:"name"`
}{}
body := `{"name": "hello"}`
assert.NoError(t, UnmarshalJsonBytes([]byte(body), &v))
assert.Equal(t, "hello", v.Foo.Name)
})
t.Run("failure", func(t *testing.T) {
v := struct {
Foo *mockUnmarshalerWithError `json:"name"`
}{}
body := `{"name": "hello"}`
assert.Error(t, UnmarshalJsonBytes([]byte(body), &v))
})
}
func BenchmarkDefaultValue(b *testing.B) { func BenchmarkDefaultValue(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var a struct { var a struct {
@@ -5873,3 +5893,20 @@ func (m mockValuerWithParent) Value(_ string) (any, bool) {
func (m mockValuerWithParent) Parent() valuerWithParent { func (m mockValuerWithParent) Parent() valuerWithParent {
return m.parent return m.parent
} }
type mockUnmarshaler struct {
Name string
}
func (m *mockUnmarshaler) UnmarshalJSON(b []byte) error {
m.Name = string(b)
return nil
}
type mockUnmarshalerWithError struct {
Name string
}
func (m *mockUnmarshalerWithError) UnmarshalJSON(b []byte) error {
return errors.New("foo")
}

View File

@@ -442,6 +442,17 @@ func TestParseWithEscapedParams(t *testing.T) {
}) })
} }
func TestCustomUnmarshalerStructRequest(t *testing.T) {
reqBody := `{"name": "hello"}`
r := httptest.NewRequest(http.MethodPost, "/a", bytes.NewReader([]byte(reqBody)))
r.Header.Set(ContentType, JsonContentType)
v := struct {
Foo *mockUnmarshaler `json:"name"`
}{}
assert.Nil(t, Parse(r, &v))
assert.Equal(t, "hello", v.Foo.Name)
}
func BenchmarkParseRaw(b *testing.B) { func BenchmarkParseRaw(b *testing.B) {
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody) r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
if err != nil { if err != nil {
@@ -517,22 +528,11 @@ func (m mockRequest) Validate() error {
return nil return nil
} }
type mockCustomUnmarshalerStruct struct { type mockUnmarshaler struct {
Name string Name string
} }
func (m *mockCustomUnmarshalerStruct) UnmarshalJSON(b []byte) error { func (m *mockUnmarshaler) UnmarshalJSON(b []byte) error {
m.Name = string(b) m.Name = string(b)
return nil return nil
} }
func TestCustomUnmarshalerStructRequest(t *testing.T) {
reqBody := `{"name": "hello"}`
r := httptest.NewRequest(http.MethodPost, "/a", bytes.NewReader([]byte(reqBody)))
r.Header.Set("Content-Type", "application/json")
v := struct {
Foo *mockCustomUnmarshalerStruct `json:"name"`
}{}
assert.Nil(t, Parse(r, &v))
assert.Equal(t, "hello", v.Foo.Name)
}