From 040fee5669df43150411ad0368e5c45024f162cf Mon Sep 17 00:00:00 2001 From: Leo Date: Sat, 11 May 2024 22:25:10 +0800 Subject: [PATCH] feat: httpx.Parse supports parsing structures that implement the Unmarshaler interface (#4143) --- core/mapping/unmarshaler.go | 80 +++++++++++++++++++++++++++---------- rest/httpx/requests_test.go | 21 ++++++++++ 2 files changed, 80 insertions(+), 21 deletions(-) diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index 12cbe46bd..f6c2e9441 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -218,7 +218,8 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map } func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value, - mapValue any, fullName string) error { + mapValue any, fullName string, +) error { var slice []any switch v := mapValue.(type) { case fmt.Stringer: @@ -248,7 +249,8 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect. } func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, - baseKind reflect.Kind, value any, fullName string) error { + baseKind reflect.Kind, value any, fullName string, +) error { if value == nil { return errNilSliceElement } @@ -286,7 +288,8 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, } func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value, - defaultValue, fullName string) error { + defaultValue, fullName string, +) error { baseFieldType := Deref(derefedType.Elem()) baseFieldKind := baseFieldType.Kind() defaultCacheLock.Lock() @@ -400,7 +403,8 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any, } func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Valuer, fullName string) ( - string, *fieldOptionsWithContext, error) { + string, *fieldOptionsWithContext, error, +) { key, options, err := parseKeyAndOptions(u.key, field) if err != nil { return "", nil, err @@ -441,7 +445,8 @@ func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Value } func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value reflect.Value, - m valuerWithParent, fullName string) error { + m valuerWithParent, fullName string, +) error { key, options, err := u.parseOptionsWithContext(field, m, fullName) if err != nil { return err @@ -459,7 +464,8 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref } func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value, - key string, m valuerWithParent, fullName string) error { + key string, m valuerWithParent, fullName string, +) error { derefedFieldType := Deref(field.Type) switch derefedFieldType.Kind() { @@ -471,7 +477,8 @@ func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, v } func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, value reflect.Value, - m valuerWithParent, fullName string) error { + m valuerWithParent, fullName string, +) error { fieldType := field.Type maybeNewValue(fieldType, value) derefedFieldType := Deref(fieldType) @@ -495,7 +502,8 @@ func (u *Unmarshaler) processAnonymousFieldRequired(field reflect.StructField, v } func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type, - value reflect.Value, key string, m valuerWithParent, fullName string) error { + value reflect.Value, key string, m valuerWithParent, fullName string, +) error { var filled bool var required int var requiredFilled int @@ -536,7 +544,8 @@ func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type } func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Value, - m valuerWithParent, fullName string) error { + m valuerWithParent, fullName string, +) error { if usingDifferentKeys(u.key, field) { return nil } @@ -549,13 +558,16 @@ func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Valu } func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value reflect.Value, - vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error { + vp valueWithParent, opts *fieldOptionsWithContext, fullName string, +) error { derefedFieldType := Deref(fieldType) typeKind := derefedFieldType.Kind() mapValue := vp.value valueKind := reflect.TypeOf(mapValue).Kind() switch { + case valueKind == reflect.String && typeKind == reflect.Struct && fieldType.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()): + return u.fillCustomUnmarshalerStruct(fieldType, value, mapValue) case valueKind == reflect.Map && typeKind == reflect.Struct: mv, ok := mapValue.(map[string]any) if !ok { @@ -581,8 +593,24 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re } } +func (u *Unmarshaler) fillCustomUnmarshalerStruct(fieldType reflect.Type, value reflect.Value, mapValue any) error { + if !value.CanSet() { + return errValueNotSettable + } + baseType := Deref(fieldType) + target := reflect.New(baseType) + + params := make([]reflect.Value, 1) + params[0] = reflect.ValueOf([]byte(mapValue.(string))) + target.MethodByName("UnmarshalJSON").Call(params) + + value.Set(target) + return nil +} + func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflect.Value, - mapValue any, opts *fieldOptionsWithContext, fullName string) error { + mapValue any, opts *fieldOptionsWithContext, fullName string, +) error { typeKind := Deref(fieldType).Kind() valueKind := reflect.TypeOf(mapValue).Kind() @@ -603,7 +631,8 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec } func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type, value reflect.Value, - v json.Number, opts *fieldOptionsWithContext, fullName string) error { + v json.Number, opts *fieldOptionsWithContext, fullName string, +) error { baseType := Deref(fieldType) typeKind := baseType.Kind() @@ -656,7 +685,8 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type } func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.Value, - m valuerWithParent, fullName string) error { + m valuerWithParent, fullName string, +) error { if fieldType.Kind() == reflect.Ptr { baseType := Deref(fieldType) target := reflect.New(baseType).Elem() @@ -673,7 +703,8 @@ func (u *Unmarshaler) processFieldStruct(fieldType reflect.Type, value reflect.V } func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value reflect.Value, - mapValue any) (bool, error) { + mapValue any, +) (bool, error) { var tval encoding.TextUnmarshaler var ok bool @@ -701,7 +732,8 @@ func (u *Unmarshaler) processFieldTextUnmarshaler(fieldType reflect.Type, value } func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value reflect.Value, - envVal string, opts *fieldOptionsWithContext, fullName string) error { + envVal string, opts *fieldOptionsWithContext, fullName string, +) error { if err := validateValueInOptions(envVal, opts.options()); err != nil { return err } @@ -731,7 +763,8 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref } func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect.Value, - m valuerWithParent, fullName string) error { + m valuerWithParent, fullName string, +) error { if !field.IsExported() { return nil } @@ -778,7 +811,8 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect } func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value reflect.Value, - vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string) error { + vp valueWithParent, key string, opts *fieldOptionsWithContext, fullName string, +) error { mapValue := vp.value if mapValue == nil { if opts.optional() { @@ -813,7 +847,8 @@ func (u *Unmarshaler) processNamedFieldWithValue(fieldType reflect.Type, value r } func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Type, value reflect.Value, - mapValue any, key string, opts *fieldOptionsWithContext, fullName string) error { + mapValue any, key string, opts *fieldOptionsWithContext, fullName string, +) error { valueKind := reflect.TypeOf(mapValue).Kind() if valueKind != reflect.String { return fmt.Errorf("the value in map is not string, but %s", valueKind) @@ -842,7 +877,8 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ } func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, value reflect.Value, - opts *fieldOptionsWithContext, fullName string) error { + opts *fieldOptionsWithContext, fullName string, +) error { derefedType := Deref(fieldType) fieldKind := derefedType.Kind() if defaultValue, ok := opts.getDefault(); ok { @@ -984,7 +1020,8 @@ func fillDurationValue(fieldType reflect.Type, value reflect.Value, dur string) } func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any, - opts *fieldOptionsWithContext, fullName string) error { + opts *fieldOptionsWithContext, fullName string, +) error { if !value.CanSet() { return errValueNotSettable } @@ -1013,7 +1050,8 @@ func fillPrimitive(fieldType reflect.Type, value reflect.Value, mapValue any, } func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any, - opts *fieldOptionsWithContext) error { + opts *fieldOptionsWithContext, +) error { if !value.CanSet() { return errValueNotSettable } diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index 1a1c8a10e..85ec16f94 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -1,6 +1,7 @@ package httpx import ( + "bytes" "errors" "net/http" "net/http/httptest" @@ -515,3 +516,23 @@ func (m mockRequest) Validate() error { return nil } + +type mockCustomUnmarshalerStruct struct { + Name string +} + +func (m *mockCustomUnmarshalerStruct) UnmarshalJSON(b []byte) error { + m.Name = string(b) + return nil +} + +func TestCustomUnmarshalerStructRequest(t *testing.T) { + reqBody := `{"name": "hello"}` + r := httptest.NewRequest(http.MethodPost, "/a", bytes.NewReader([]byte(reqBody))) + r.Header.Set("Content-Type", "application/json") + v := struct { + Foo *mockCustomUnmarshalerStruct `json:"name"` + }{} + assert.Nil(t, Parse(r, &v)) + assert.Equal(t, "hello", v.Foo.Name) +}