mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-12 01:10:00 +08:00
Compare commits
72 Commits
copilot/fi
...
go1.16
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea0a843f8 | ||
|
|
9e0e01b2bc | ||
|
|
af50a80d01 | ||
|
|
703fb8d970 | ||
|
|
e964e530e1 | ||
|
|
52265087d1 | ||
|
|
b4c2677eb9 | ||
|
|
30296fb1ca | ||
|
|
356c80defd | ||
|
|
8c31525378 | ||
|
|
2cf09f3c36 | ||
|
|
d41e542c92 | ||
|
|
265a24ac6d | ||
|
|
7d88fc39dc | ||
|
|
6957b6a344 | ||
|
|
bca6a230c8 | ||
|
|
cc8413d683 | ||
|
|
3842283fa8 | ||
|
|
fe13a533f5 | ||
|
|
7a327ccda4 | ||
|
|
06e4507406 | ||
|
|
8794d5b753 | ||
|
|
9bfa63d995 | ||
|
|
a432b121fb | ||
|
|
b61c94bb66 | ||
|
|
93fcf899dc | ||
|
|
9f4b3bae92 | ||
|
|
805cb87d98 | ||
|
|
366131640e | ||
|
|
956884a3ff | ||
|
|
f571cb8af2 | ||
|
|
cc5acf3b90 | ||
|
|
e1aa665443 | ||
|
|
cd357d9484 | ||
|
|
6d4d7cbd6b | ||
|
|
c593b5b531 | ||
|
|
fd5b38b07c | ||
|
|
41efb48f55 | ||
|
|
0ef3626839 | ||
|
|
77a72b16e9 | ||
|
|
21566f1b7a | ||
|
|
b2646e228b | ||
|
|
588b883710 | ||
|
|
033910bbd8 | ||
|
|
530dd79e3f | ||
|
|
cd5263ac75 | ||
|
|
ea3302a468 | ||
|
|
abf15b373c | ||
|
|
a865e9ee29 | ||
|
|
f8292198cf | ||
|
|
016d965f56 | ||
|
|
95d7c73409 | ||
|
|
939ef2a181 | ||
|
|
f0b8dd45fe | ||
|
|
0ba9335b04 | ||
|
|
04f181f0b4 | ||
|
|
89f841c126 | ||
|
|
d785c8c377 | ||
|
|
687a1d15da | ||
|
|
aaa974e1ad | ||
|
|
2779568ccf | ||
|
|
f7d50ae626 | ||
|
|
33594ea350 | ||
|
|
ee2ec974c4 | ||
|
|
fd2f2f0f54 | ||
|
|
86a2429d7d | ||
|
|
e5fe5dcc50 | ||
|
|
b510e7c242 | ||
|
|
dfe92e709f | ||
|
|
cb649cf627 | ||
|
|
ce19a5ade6 | ||
|
|
6dc56de714 |
@@ -213,23 +213,23 @@ func (s *Set) validate(i interface{}) {
|
|||||||
switch i.(type) {
|
switch i.(type) {
|
||||||
case int:
|
case int:
|
||||||
if s.tp != intType {
|
if s.tp != intType {
|
||||||
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
|
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
|
||||||
}
|
}
|
||||||
case int64:
|
case int64:
|
||||||
if s.tp != int64Type {
|
if s.tp != int64Type {
|
||||||
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
|
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
|
||||||
}
|
}
|
||||||
case uint:
|
case uint:
|
||||||
if s.tp != uintType {
|
if s.tp != uintType {
|
||||||
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
|
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
|
||||||
}
|
}
|
||||||
case uint64:
|
case uint64:
|
||||||
if s.tp != uint64Type {
|
if s.tp != uint64Type {
|
||||||
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
|
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
if s.tp != stringType {
|
if s.tp != stringType {
|
||||||
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
|
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,17 +13,29 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/internal/encoding"
|
"github.com/zeromicro/go-zero/internal/encoding"
|
||||||
)
|
)
|
||||||
|
|
||||||
var loaders = map[string]func([]byte, interface{}) error{
|
const jsonTagKey = "json"
|
||||||
".json": LoadFromJsonBytes,
|
|
||||||
".toml": LoadFromTomlBytes,
|
var (
|
||||||
".yaml": LoadFromYamlBytes,
|
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
|
||||||
".yml": LoadFromYamlBytes,
|
loaders = map[string]func([]byte, interface{}) error{
|
||||||
|
".json": LoadFromJsonBytes,
|
||||||
|
".toml": LoadFromTomlBytes,
|
||||||
|
".yaml": LoadFromYamlBytes,
|
||||||
|
".yml": LoadFromYamlBytes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// children and mapField should not be both filled.
|
||||||
|
// named fields and map cannot be bound to the same field name.
|
||||||
|
type fieldInfo struct {
|
||||||
|
children map[string]*fieldInfo
|
||||||
|
mapField *fieldInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
type fieldInfo struct {
|
// FillDefault fills the default values for the given v,
|
||||||
name string
|
// and the premise is that the value of v must be guaranteed to be empty.
|
||||||
kind reflect.Kind
|
func FillDefault(v interface{}) error {
|
||||||
children map[string]fieldInfo
|
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
|
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
|
||||||
@@ -58,13 +70,17 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
|
|||||||
|
|
||||||
// LoadFromJsonBytes loads config into v from content json bytes.
|
// LoadFromJsonBytes loads config into v from content json bytes.
|
||||||
func LoadFromJsonBytes(content []byte, v interface{}) error {
|
func LoadFromJsonBytes(content []byte, v interface{}) error {
|
||||||
|
info, err := buildFieldsInfo(reflect.TypeOf(v))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var m map[string]interface{}
|
var m map[string]interface{}
|
||||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
finfo := buildFieldsInfo(reflect.TypeOf(v))
|
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
|
||||||
lowerCaseKeyMap := toLowerCaseKeyMap(m, finfo)
|
|
||||||
|
|
||||||
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
|
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
|
||||||
}
|
}
|
||||||
@@ -108,7 +124,63 @@ func MustLoad(path string, v interface{}, opts ...Option) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
|
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
|
||||||
|
if prev, ok := info.children[key]; ok {
|
||||||
|
if child.mapField != nil {
|
||||||
|
return newDupKeyError(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mergeFields(prev, key, child.children); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
info.children[key] = child
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||||
|
switch ft.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
fields, err := buildFieldsInfo(ft)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range fields.children {
|
||||||
|
if err = addOrMergeFields(info, k, v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := info.children[lowerCaseName]; ok {
|
||||||
|
return newDupKeyError(lowerCaseName)
|
||||||
|
}
|
||||||
|
|
||||||
|
info.children[lowerCaseName] = &fieldInfo{
|
||||||
|
children: make(map[string]*fieldInfo),
|
||||||
|
mapField: elemField,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := info.children[lowerCaseName]; ok {
|
||||||
|
return newDupKeyError(lowerCaseName)
|
||||||
|
}
|
||||||
|
|
||||||
|
info.children[lowerCaseName] = &fieldInfo{
|
||||||
|
children: make(map[string]*fieldInfo),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||||
tp = mapping.Deref(tp)
|
tp = mapping.Deref(tp)
|
||||||
|
|
||||||
switch tp.Kind() {
|
switch tp.Kind() {
|
||||||
@@ -116,61 +188,95 @@ func buildFieldsInfo(tp reflect.Type) map[string]fieldInfo {
|
|||||||
return buildStructFieldsInfo(tp)
|
return buildStructFieldsInfo(tp)
|
||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
return buildFieldsInfo(mapping.Deref(tp.Elem()))
|
return buildFieldsInfo(mapping.Deref(tp.Elem()))
|
||||||
|
case reflect.Chan, reflect.Func:
|
||||||
|
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
|
||||||
default:
|
default:
|
||||||
return nil
|
return &fieldInfo{
|
||||||
|
children: make(map[string]*fieldInfo),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildStructFieldsInfo(tp reflect.Type) map[string]fieldInfo {
|
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||||
info := make(map[string]fieldInfo)
|
var finfo *fieldInfo
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch ft.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
finfo, err = buildFieldsInfo(ft)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
finfo, err = buildFieldsInfo(ft.Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
finfo = &fieldInfo{
|
||||||
|
children: make(map[string]*fieldInfo),
|
||||||
|
mapField: elemInfo,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
finfo, err = buildFieldsInfo(ft)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addOrMergeFields(info, lowerCaseName, finfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||||
|
info := &fieldInfo{
|
||||||
|
children: make(map[string]*fieldInfo),
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < tp.NumField(); i++ {
|
for i := 0; i < tp.NumField(); i++ {
|
||||||
field := tp.Field(i)
|
field := tp.Field(i)
|
||||||
name := field.Name
|
name := field.Name
|
||||||
lowerCaseName := toLowerCase(name)
|
lowerCaseName := toLowerCase(name)
|
||||||
ft := mapping.Deref(field.Type)
|
ft := mapping.Deref(field.Type)
|
||||||
|
|
||||||
// flatten anonymous fields
|
// flatten anonymous fields
|
||||||
if field.Anonymous {
|
if field.Anonymous {
|
||||||
if ft.Kind() == reflect.Struct {
|
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||||
fields := buildFieldsInfo(ft)
|
return nil, err
|
||||||
for k, v := range fields {
|
|
||||||
info[k] = v
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
info[lowerCaseName] = fieldInfo{
|
|
||||||
name: name,
|
|
||||||
kind: ft.Kind(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
continue
|
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||||
}
|
return nil, err
|
||||||
|
|
||||||
var fields map[string]fieldInfo
|
|
||||||
switch ft.Kind() {
|
|
||||||
case reflect.Struct:
|
|
||||||
fields = buildFieldsInfo(ft)
|
|
||||||
case reflect.Array, reflect.Slice:
|
|
||||||
fields = buildFieldsInfo(ft.Elem())
|
|
||||||
case reflect.Map:
|
|
||||||
fields = buildFieldsInfo(ft.Elem())
|
|
||||||
}
|
|
||||||
|
|
||||||
info[lowerCaseName] = fieldInfo{
|
|
||||||
name: name,
|
|
||||||
kind: ft.Kind(),
|
|
||||||
children: fields,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return info
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
|
||||||
|
if len(prev.children) == 0 || len(children) == 0 {
|
||||||
|
return newDupKeyError(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge fields
|
||||||
|
for k, v := range children {
|
||||||
|
if _, ok := prev.children[k]; ok {
|
||||||
|
return newDupKeyError(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
prev.children[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toLowerCase(s string) string {
|
func toLowerCase(s string) string {
|
||||||
return strings.ToLower(s)
|
return strings.ToLower(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{} {
|
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
|
||||||
switch vv := v.(type) {
|
switch vv := v.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
return toLowerCaseKeyMap(vv, info)
|
return toLowerCaseKeyMap(vv, info)
|
||||||
@@ -185,19 +291,21 @@ func toLowerCaseInterface(v interface{}, info map[string]fieldInfo) interface{}
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[string]interface{} {
|
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
|
||||||
res := make(map[string]interface{})
|
res := make(map[string]interface{})
|
||||||
|
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
ti, ok := info[k]
|
ti, ok := info.children[k]
|
||||||
if ok {
|
if ok {
|
||||||
res[k] = toLowerCaseInterface(v, ti.children)
|
res[k] = toLowerCaseInterface(v, ti)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
lk := toLowerCase(k)
|
lk := toLowerCase(k)
|
||||||
if ti, ok = info[lk]; ok {
|
if ti, ok = info.children[lk]; ok {
|
||||||
res[lk] = toLowerCaseInterface(v, ti.children)
|
res[lk] = toLowerCaseInterface(v, ti)
|
||||||
|
} else if info.mapField != nil {
|
||||||
|
res[k] = toLowerCaseInterface(v, info.mapField)
|
||||||
} else {
|
} else {
|
||||||
res[k] = v
|
res[k] = v
|
||||||
}
|
}
|
||||||
@@ -205,3 +313,15 @@ func toLowerCaseKeyMap(m map[string]interface{}, info map[string]fieldInfo) map[
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dupKeyError struct {
|
||||||
|
key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDupKeyError(key string) dupKeyError {
|
||||||
|
return dupKeyError{key: key}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e dupKeyError) Error() string {
|
||||||
|
return fmt.Sprintf("duplicated key %s", e.key)
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/hash"
|
"github.com/zeromicro/go-zero/core/hash"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var dupErr dupKeyError
|
||||||
|
|
||||||
func TestLoadConfig_notExists(t *testing.T) {
|
func TestLoadConfig_notExists(t *testing.T) {
|
||||||
assert.NotNil(t, Load("not_a_file", nil))
|
assert.NotNil(t, Load("not_a_file", nil))
|
||||||
}
|
}
|
||||||
@@ -17,7 +19,7 @@ func TestLoadConfig_notRecogFile(t *testing.T) {
|
|||||||
filename, err := fs.TempFilenameWithText("hello")
|
filename, err := fs.TempFilenameWithText("hello")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
defer os.Remove(filename)
|
defer os.Remove(filename)
|
||||||
assert.NotNil(t, Load(filename, nil))
|
assert.NotNil(t, LoadConfig(filename, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigJson(t *testing.T) {
|
func TestConfigJson(t *testing.T) {
|
||||||
@@ -64,7 +66,7 @@ func TestLoadFromJsonBytesArray(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
assert.NoError(t, LoadConfigFromJsonBytes(input, &val))
|
||||||
var expect []string
|
var expect []string
|
||||||
for _, user := range val.Users {
|
for _, user := range val.Users {
|
||||||
expect = append(expect, user.Name)
|
expect = append(expect, user.Name)
|
||||||
@@ -172,7 +174,7 @@ B: bar`)
|
|||||||
A string
|
A string
|
||||||
B string
|
B string
|
||||||
}
|
}
|
||||||
assert.NoError(t, LoadFromYamlBytes(text, &val1))
|
assert.NoError(t, LoadConfigFromYamlBytes(text, &val1))
|
||||||
assert.Equal(t, "foo", val1.A)
|
assert.Equal(t, "foo", val1.A)
|
||||||
assert.Equal(t, "bar", val1.B)
|
assert.Equal(t, "bar", val1.B)
|
||||||
assert.NoError(t, LoadFromYamlBytes(text, &val2))
|
assert.NoError(t, LoadFromYamlBytes(text, &val2))
|
||||||
@@ -384,6 +386,102 @@ func TestLoadFromYamlBytesLayers(t *testing.T) {
|
|||||||
assert.Equal(t, "foo", val.Value)
|
assert.Equal(t, "foo", val.Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadFromYamlItemOverlay(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Redis struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
RedisKey struct {
|
||||||
|
Redis
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
Server struct {
|
||||||
|
Redis RedisKey
|
||||||
|
}
|
||||||
|
|
||||||
|
TestConfig struct {
|
||||||
|
Server
|
||||||
|
Redis Redis
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
input := []byte(`Redis:
|
||||||
|
Host: localhost
|
||||||
|
Port: 6379
|
||||||
|
Key: test
|
||||||
|
`)
|
||||||
|
|
||||||
|
var c TestConfig
|
||||||
|
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Redis struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
RedisKey struct {
|
||||||
|
Redis
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
Server struct {
|
||||||
|
Redis Redis
|
||||||
|
}
|
||||||
|
|
||||||
|
TestConfig struct {
|
||||||
|
Redis RedisKey
|
||||||
|
Server
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
input := []byte(`Redis:
|
||||||
|
Host: localhost
|
||||||
|
Port: 6379
|
||||||
|
Key: test
|
||||||
|
`)
|
||||||
|
|
||||||
|
var c TestConfig
|
||||||
|
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Redis struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
RedisKey struct {
|
||||||
|
Redis
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
Server struct {
|
||||||
|
Redis RedisKey
|
||||||
|
}
|
||||||
|
|
||||||
|
TestConfig struct {
|
||||||
|
Server
|
||||||
|
Redis map[string]interface{}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
input := []byte(`Redis:
|
||||||
|
Host: localhost
|
||||||
|
Port: 6379
|
||||||
|
Key: test
|
||||||
|
`)
|
||||||
|
|
||||||
|
var c TestConfig
|
||||||
|
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnmarshalJsonBytesMap(t *testing.T) {
|
func TestUnmarshalJsonBytesMap(t *testing.T) {
|
||||||
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
|
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
|
||||||
|
|
||||||
@@ -450,6 +548,480 @@ func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
|
|||||||
assert.Equal(t, Int(3), c.Int)
|
assert.Equal(t, Int(3), c.Int)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalJsonBytesWithMapValueOfStruct(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Value struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
Config struct {
|
||||||
|
Items map[string]Value
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var inputs = [][]byte{
|
||||||
|
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||||
|
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||||
|
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||||
|
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||||
|
}
|
||||||
|
for _, input := range inputs {
|
||||||
|
var c Config
|
||||||
|
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||||
|
assert.Equal(t, 1, len(c.Items))
|
||||||
|
for _, v := range c.Items {
|
||||||
|
assert.Equal(t, "foo", v.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Value struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
Map map[string]Value
|
||||||
|
|
||||||
|
Config struct {
|
||||||
|
Map
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var inputs = [][]byte{
|
||||||
|
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||||
|
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||||
|
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||||
|
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||||
|
}
|
||||||
|
for _, input := range inputs {
|
||||||
|
var c Config
|
||||||
|
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||||
|
assert.Equal(t, 1, len(c.Map))
|
||||||
|
for _, v := range c.Map {
|
||||||
|
assert.Equal(t, "foo", v.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_FieldOverwrite(t *testing.T) {
|
||||||
|
t.Run("normal", func(t *testing.T) {
|
||||||
|
type Base struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St1 struct {
|
||||||
|
Base
|
||||||
|
Name2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St2 struct {
|
||||||
|
Base
|
||||||
|
Name2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St3 struct {
|
||||||
|
*Base
|
||||||
|
Name2 string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St4 struct {
|
||||||
|
*Base
|
||||||
|
Name2 *string
|
||||||
|
}
|
||||||
|
|
||||||
|
validate := func(val interface{}) {
|
||||||
|
input := []byte(`{"Name": "hello", "Name2": "world"}`)
|
||||||
|
assert.NoError(t, LoadFromJsonBytes(input, val))
|
||||||
|
}
|
||||||
|
|
||||||
|
validate(&St1{})
|
||||||
|
validate(&St2{})
|
||||||
|
validate(&St3{})
|
||||||
|
validate(&St4{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inherit Override", func(t *testing.T) {
|
||||||
|
type Base struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St1 struct {
|
||||||
|
Base
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St2 struct {
|
||||||
|
Base
|
||||||
|
Name int
|
||||||
|
}
|
||||||
|
|
||||||
|
type St3 struct {
|
||||||
|
*Base
|
||||||
|
Name int
|
||||||
|
}
|
||||||
|
|
||||||
|
type St4 struct {
|
||||||
|
*Base
|
||||||
|
Name *string
|
||||||
|
}
|
||||||
|
|
||||||
|
validate := func(val interface{}) {
|
||||||
|
input := []byte(`{"Name": "hello"}`)
|
||||||
|
err := LoadFromJsonBytes(input, val)
|
||||||
|
assert.ErrorAs(t, err, &dupErr)
|
||||||
|
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
validate(&St1{})
|
||||||
|
validate(&St2{})
|
||||||
|
validate(&St3{})
|
||||||
|
validate(&St4{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inherit more", func(t *testing.T) {
|
||||||
|
type Base1 struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St0 struct {
|
||||||
|
Base1
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St1 struct {
|
||||||
|
St0
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type St2 struct {
|
||||||
|
St0
|
||||||
|
Name int
|
||||||
|
}
|
||||||
|
|
||||||
|
type St3 struct {
|
||||||
|
*St0
|
||||||
|
Name int
|
||||||
|
}
|
||||||
|
|
||||||
|
type St4 struct {
|
||||||
|
*St0
|
||||||
|
Name *int
|
||||||
|
}
|
||||||
|
|
||||||
|
validate := func(val interface{}) {
|
||||||
|
input := []byte(`{"Name": "hello"}`)
|
||||||
|
err := LoadFromJsonBytes(input, val)
|
||||||
|
assert.ErrorAs(t, err, &dupErr)
|
||||||
|
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
validate(&St0{})
|
||||||
|
validate(&St1{})
|
||||||
|
validate(&St2{})
|
||||||
|
validate(&St3{})
|
||||||
|
validate(&St4{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFieldOverwriteComplicated(t *testing.T) {
|
||||||
|
t.Run("double maps", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Base1 struct {
|
||||||
|
Values map[string]string
|
||||||
|
}
|
||||||
|
Base2 struct {
|
||||||
|
Values map[string]string
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Base1
|
||||||
|
Base2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Values": {"Key": "Value"}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merge children", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner1 struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
Inner2 struct {
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
Base1 struct {
|
||||||
|
Inner Inner1
|
||||||
|
}
|
||||||
|
Base2 struct {
|
||||||
|
Inner Inner2
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Base1
|
||||||
|
Base2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`)
|
||||||
|
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||||
|
assert.Equal(t, "foo", c.Base1.Inner.Name)
|
||||||
|
assert.Equal(t, 10, c.Base2.Inner.Age)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten maps", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner struct {
|
||||||
|
Map map[string]string
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Map map[string]string
|
||||||
|
Inner
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten nested maps", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner struct {
|
||||||
|
Map map[string]string
|
||||||
|
}
|
||||||
|
Middle1 struct {
|
||||||
|
Map map[string]string
|
||||||
|
Inner
|
||||||
|
}
|
||||||
|
Middle2 struct {
|
||||||
|
Map map[string]string
|
||||||
|
Inner
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Middle1
|
||||||
|
Middle2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten outer/inner maps", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner struct {
|
||||||
|
Map map[string]string
|
||||||
|
}
|
||||||
|
Middle struct {
|
||||||
|
Inner
|
||||||
|
Map map[string]string
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Middle
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten anonymous maps", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner struct {
|
||||||
|
Map map[string]string
|
||||||
|
}
|
||||||
|
Middle struct {
|
||||||
|
Inner
|
||||||
|
Map map[string]string
|
||||||
|
}
|
||||||
|
Elem map[string]Middle
|
||||||
|
Config struct {
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten primitive and map", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner struct {
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
Elem map[string]Inner
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten map and slice", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Inner struct {
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
Elem []Inner
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten map and string", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Elem string
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadNamedFieldOverwritten(t *testing.T) {
|
||||||
|
t.Run("overwritten named struct", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Elem string
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Base struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Val Base
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten named []struct", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Elem string
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Base struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Vals []Base
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten named map[string]struct", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Elem string
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Base struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Vals map[string]Base
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten named *struct", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Elem string
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Base struct {
|
||||||
|
Named
|
||||||
|
Elem
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Vals *Base
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten named struct", func(t *testing.T) {
|
||||||
|
type (
|
||||||
|
Named struct {
|
||||||
|
Elem string
|
||||||
|
}
|
||||||
|
Base struct {
|
||||||
|
Named
|
||||||
|
Elem Named
|
||||||
|
}
|
||||||
|
Config struct {
|
||||||
|
Val Base
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Val": {"Elem": "Value"}}`)
|
||||||
|
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwritten named struct", func(t *testing.T) {
|
||||||
|
type Config struct {
|
||||||
|
Val chan int
|
||||||
|
}
|
||||||
|
|
||||||
|
var c Config
|
||||||
|
input := []byte(`{"Val": 1}`)
|
||||||
|
assert.Error(t, LoadFromJsonBytes(input, &c))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func createTempFile(ext, text string) (string, error) {
|
func createTempFile(ext, text string) (string, error) {
|
||||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -467,3 +1039,55 @@ func createTempFile(ext, text string) (string, error) {
|
|||||||
|
|
||||||
return filename, nil
|
return filename, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
type St struct{}
|
||||||
|
err := FillDefault(St{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not nil", func(t *testing.T) {
|
||||||
|
type St struct{}
|
||||||
|
err := FillDefault(&St{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("default", func(t *testing.T) {
|
||||||
|
type St struct {
|
||||||
|
A string `json:",default=a"`
|
||||||
|
B string
|
||||||
|
}
|
||||||
|
var st St
|
||||||
|
err := FillDefault(&st)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, st.A, "a")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env", func(t *testing.T) {
|
||||||
|
type St struct {
|
||||||
|
A string `json:",default=a"`
|
||||||
|
B string
|
||||||
|
C string `json:",env=TEST_C"`
|
||||||
|
}
|
||||||
|
t.Setenv("TEST_C", "c")
|
||||||
|
|
||||||
|
var st St
|
||||||
|
err := FillDefault(&st)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, st.A, "a")
|
||||||
|
assert.Equal(t, st.C, "c")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("has vaue", func(t *testing.T) {
|
||||||
|
type St struct {
|
||||||
|
A string `json:",default=a"`
|
||||||
|
B string
|
||||||
|
}
|
||||||
|
var st = St{
|
||||||
|
A: "b",
|
||||||
|
}
|
||||||
|
err := FillDefault(&st)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
// PropertyError represents a configuration error message.
|
// PropertyError represents a configuration error message.
|
||||||
type PropertyError struct {
|
type PropertyError struct {
|
||||||
error
|
|
||||||
message string
|
message string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
```go
|
```go
|
||||||
type RestfulConf struct {
|
type RestfulConf struct {
|
||||||
|
ServiceName string `json:",env=SERVICE_NAME"` // read from env automatically
|
||||||
Host string `json:",default=0.0.0.0"`
|
Host string `json:",default=0.0.0.0"`
|
||||||
Port int
|
Port int
|
||||||
LogMode string `json:",options=[file,console]"`
|
LogMode string `json:",options=[file,console]"`
|
||||||
@@ -21,20 +22,20 @@ type RestfulConf struct {
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# most fields are optional or have default values
|
# most fields are optional or have default values
|
||||||
Port: 8080
|
port: 8080
|
||||||
LogMode: console
|
logMode: console
|
||||||
# you can use env settings
|
# you can use env settings
|
||||||
MaxBytes: ${MAX_BYTES}
|
maxBytes: ${MAX_BYTES}
|
||||||
```
|
```
|
||||||
|
|
||||||
- toml example
|
- toml example
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
# most fields are optional or have default values
|
# most fields are optional or have default values
|
||||||
Port = 8_080
|
port = 8_080
|
||||||
LogMode = "console"
|
logMode = "console"
|
||||||
# you can use env settings
|
# you can use env settings
|
||||||
MaxBytes = "${MAX_BYTES}"
|
maxBytes = "${MAX_BYTES}"
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Load the config from a file:
|
3. Load the config from a file:
|
||||||
|
|||||||
@@ -53,10 +53,11 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestChunkExecutorEmpty(t *testing.T) {
|
func TestChunkExecutorEmpty(t *testing.T) {
|
||||||
NewChunkExecutor(func(items []interface{}) {
|
executor := NewChunkExecutor(func(items []interface{}) {
|
||||||
assert.Fail(t, "should not called")
|
assert.Fail(t, "should not called")
|
||||||
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
|
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
|
||||||
time.Sleep(time.Millisecond * 100)
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
executor.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChunkExecutorFlush(t *testing.T) {
|
func TestChunkExecutorFlush(t *testing.T) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/timex"
|
"github.com/zeromicro/go-zero/core/timex"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -67,6 +68,7 @@ func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
|
|||||||
ticker.Tick()
|
ticker.Tick()
|
||||||
ticker.Wait(time.Millisecond * idleRound)
|
ticker.Wait(time.Millisecond * idleRound)
|
||||||
assert.Equal(t, routines, runtime.NumGoroutine())
|
assert.Equal(t, routines, runtime.NumGoroutine())
|
||||||
|
proc.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeriodicalExecutor_Bulk(t *testing.T) {
|
func TestPeriodicalExecutor_Bulk(t *testing.T) {
|
||||||
|
|||||||
@@ -27,6 +27,26 @@ func Close() error {
|
|||||||
return logx.Close()
|
return logx.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Debug writes v into access log.
|
||||||
|
func Debug(ctx context.Context, v ...interface{}) {
|
||||||
|
getLogger(ctx).Debug(v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugf writes v with format into access log.
|
||||||
|
func Debugf(ctx context.Context, format string, v ...interface{}) {
|
||||||
|
getLogger(ctx).Debugf(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugv writes v into access log with json content.
|
||||||
|
func Debugv(ctx context.Context, v interface{}) {
|
||||||
|
getLogger(ctx).Debugv(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debugw writes msg along with fields into access log.
|
||||||
|
func Debugw(ctx context.Context, msg string, fields ...LogField) {
|
||||||
|
getLogger(ctx).Debugw(msg, fields...)
|
||||||
|
}
|
||||||
|
|
||||||
// Error writes v into error log.
|
// Error writes v into error log.
|
||||||
func Error(ctx context.Context, v ...interface{}) {
|
func Error(ctx context.Context, v ...interface{}) {
|
||||||
getLogger(ctx).Error(v...)
|
getLogger(ctx).Error(v...)
|
||||||
|
|||||||
@@ -140,6 +140,54 @@ func TestInfow(t *testing.T) {
|
|||||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDebug(t *testing.T) {
|
||||||
|
var buf strings.Builder
|
||||||
|
writer := logx.NewWriter(&buf)
|
||||||
|
old := logx.Reset()
|
||||||
|
logx.SetWriter(writer)
|
||||||
|
defer logx.SetWriter(old)
|
||||||
|
|
||||||
|
file, line := getFileLine()
|
||||||
|
Debug(context.Background(), "foo")
|
||||||
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDebugf(t *testing.T) {
|
||||||
|
var buf strings.Builder
|
||||||
|
writer := logx.NewWriter(&buf)
|
||||||
|
old := logx.Reset()
|
||||||
|
logx.SetWriter(writer)
|
||||||
|
defer logx.SetWriter(old)
|
||||||
|
|
||||||
|
file, line := getFileLine()
|
||||||
|
Debugf(context.Background(), "foo %s", "bar")
|
||||||
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDebugv(t *testing.T) {
|
||||||
|
var buf strings.Builder
|
||||||
|
writer := logx.NewWriter(&buf)
|
||||||
|
old := logx.Reset()
|
||||||
|
logx.SetWriter(writer)
|
||||||
|
defer logx.SetWriter(old)
|
||||||
|
|
||||||
|
file, line := getFileLine()
|
||||||
|
Debugv(context.Background(), "foo")
|
||||||
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDebugw(t *testing.T) {
|
||||||
|
var buf strings.Builder
|
||||||
|
writer := logx.NewWriter(&buf)
|
||||||
|
old := logx.Reset()
|
||||||
|
logx.SetWriter(writer)
|
||||||
|
defer logx.SetWriter(old)
|
||||||
|
|
||||||
|
file, line := getFileLine()
|
||||||
|
Debugw(context.Background(), "foo", Field("a", "b"))
|
||||||
|
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||||
|
}
|
||||||
|
|
||||||
func TestMust(t *testing.T) {
|
func TestMust(t *testing.T) {
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
Must(nil)
|
Must(nil)
|
||||||
|
|||||||
@@ -2,17 +2,34 @@ package logx
|
|||||||
|
|
||||||
// A LogConf is a logging config.
|
// A LogConf is a logging config.
|
||||||
type LogConf struct {
|
type LogConf struct {
|
||||||
ServiceName string `json:",optional"`
|
// ServiceName represents the service name.
|
||||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
ServiceName string `json:",optional"`
|
||||||
Encoding string `json:",default=json,options=[json,plain]"`
|
// Mode represents the logging mode, default is `console`.
|
||||||
TimeFormat string `json:",optional"`
|
// console: log to console.
|
||||||
Path string `json:",default=logs"`
|
// file: log to file.
|
||||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
// volume: used in k8s, prepend the hostname to the log file name.
|
||||||
MaxContentLength uint32 `json:",optional"`
|
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||||
Compress bool `json:",optional"`
|
// Encoding represents the encoding type, default is `json`.
|
||||||
Stat bool `json:",default=true"`
|
// json: json encoding.
|
||||||
KeepDays int `json:",optional"`
|
// plain: plain text encoding, typically used in development.
|
||||||
StackCooldownMillis int `json:",default=100"`
|
Encoding string `json:",default=json,options=[json,plain]"`
|
||||||
|
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||||
|
TimeFormat string `json:",optional"`
|
||||||
|
// Path represents the log file path, default is `logs`.
|
||||||
|
Path string `json:",default=logs"`
|
||||||
|
// Level represents the log level, default is `info`.
|
||||||
|
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||||
|
// MaxContentLength represents the max content bytes, default is no limit.
|
||||||
|
MaxContentLength uint32 `json:",optional"`
|
||||||
|
// Compress represents whether to compress the log file, default is `false`.
|
||||||
|
Compress bool `json:",optional"`
|
||||||
|
// Stdout represents whether to log statistics, default is `true`.
|
||||||
|
Stat bool `json:",default=true"`
|
||||||
|
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||||
|
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||||
|
KeepDays int `json:",optional"`
|
||||||
|
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||||
|
StackCooldownMillis int `json:",default=100"`
|
||||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||||
// Only take effect when RotationRuleType is `size`.
|
// Only take effect when RotationRuleType is `size`.
|
||||||
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ func TestNopWriter(t *testing.T) {
|
|||||||
w.Stack("foo")
|
w.Stack("foo")
|
||||||
w.Stat("foo")
|
w.Stat("foo")
|
||||||
w.Slow("foo")
|
w.Slow("foo")
|
||||||
w.Close()
|
_ = w.Close()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type (
|
|||||||
UnmarshalOption func(*unmarshalOptions)
|
UnmarshalOption func(*unmarshalOptions)
|
||||||
|
|
||||||
unmarshalOptions struct {
|
unmarshalOptions struct {
|
||||||
|
fillDefault bool
|
||||||
fromString bool
|
fromString bool
|
||||||
canonicalKey func(key string) string
|
canonicalKey func(key string) string
|
||||||
}
|
}
|
||||||
@@ -710,7 +711,14 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
|||||||
|
|
||||||
valuer := createValuer(m, opts)
|
valuer := createValuer(m, opts)
|
||||||
mapValue, hasValue := getValue(valuer, canonicalKey)
|
mapValue, hasValue := getValue(valuer, canonicalKey)
|
||||||
if !hasValue {
|
|
||||||
|
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault,
|
||||||
|
if u.opts.fillDefault {
|
||||||
|
if !value.IsZero() {
|
||||||
|
return fmt.Errorf("set the default value, %s must be zero", fullName)
|
||||||
|
}
|
||||||
|
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||||
|
} else if !hasValue {
|
||||||
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
return u.processNamedFieldWithoutValue(field.Type, value, opts, fullName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -801,6 +809,10 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if u.opts.fillDefault {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
switch fieldKind {
|
switch fieldKind {
|
||||||
case reflect.Array, reflect.Map, reflect.Slice:
|
case reflect.Array, reflect.Map, reflect.Slice:
|
||||||
if !opts.optional() {
|
if !opts.optional() {
|
||||||
@@ -853,7 +865,12 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v interface{}, f
|
|||||||
|
|
||||||
numFields := baseType.NumField()
|
numFields := baseType.NumField()
|
||||||
for i := 0; i < numFields; i++ {
|
for i := 0; i < numFields; i++ {
|
||||||
if err := u.processField(baseType.Field(i), valElem.Field(i), m, fullName); err != nil {
|
field := baseType.Field(i)
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.processField(field, valElem.Field(i), m, fullName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -868,13 +885,20 @@ func WithStringValues() UnmarshalOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func
|
// WithCanonicalKeyFunc customizes an Unmarshaler with Canonical Key func.
|
||||||
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
|
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
|
||||||
return func(opt *unmarshalOptions) {
|
return func(opt *unmarshalOptions) {
|
||||||
opt.canonicalKey = f
|
opt.canonicalKey = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDefault customizes an Unmarshaler with fill default values.
|
||||||
|
func WithDefault() UnmarshalOption {
|
||||||
|
return func(opt *unmarshalOptions) {
|
||||||
|
opt.fillDefault = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
|
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
|
||||||
if opts.inherit() {
|
if opts.inherit() {
|
||||||
return recursiveValuer{
|
return recursiveValuer{
|
||||||
@@ -1004,7 +1028,7 @@ func newInitError(name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTypeMismatchError(name string) error {
|
func newTypeMismatchError(name string) error {
|
||||||
return fmt.Errorf("error: type mismatch for field %s", name)
|
return fmt.Errorf("type mismatch for field %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readKeys(key string) []string {
|
func readKeys(key string) []string {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -793,7 +794,9 @@ func TestUnmarshalStringMapFromNotSettableValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ast := assert.New(t)
|
ast := assert.New(t)
|
||||||
ast.Error(UnmarshalKey(m, &v))
|
ast.NoError(UnmarshalKey(m, &v))
|
||||||
|
assert.Empty(t, v.sort)
|
||||||
|
assert.Nil(t, v.psort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnmarshalStringMapFromString(t *testing.T) {
|
func TestUnmarshalStringMapFromString(t *testing.T) {
|
||||||
@@ -4265,6 +4268,24 @@ func TestUnmarshalStructPtrOfPtr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnmarshalOnlyPublicVariables(t *testing.T) {
|
||||||
|
type demo struct {
|
||||||
|
age int `key:"age"`
|
||||||
|
Name string `key:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
m := map[string]interface{}{
|
||||||
|
"age": 3,
|
||||||
|
"name": "go-zero",
|
||||||
|
}
|
||||||
|
|
||||||
|
var in demo
|
||||||
|
if assert.NoError(t, UnmarshalKey(m, &in)) {
|
||||||
|
assert.Equal(t, 0, in.age)
|
||||||
|
assert.Equal(t, "go-zero", in.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkDefaultValue(b *testing.B) {
|
func BenchmarkDefaultValue(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
var a struct {
|
var a struct {
|
||||||
@@ -4364,3 +4385,56 @@ func BenchmarkUnmarshal(b *testing.B) {
|
|||||||
UnmarshalKey(data, &an)
|
UnmarshalKey(data, &an)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||||
|
fillDefaultUnmarshal := NewUnmarshaler(jsonTagKey, WithDefault())
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
type St struct{}
|
||||||
|
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, St{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not nil", func(t *testing.T) {
|
||||||
|
type St struct{}
|
||||||
|
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &St{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("default", func(t *testing.T) {
|
||||||
|
type St struct {
|
||||||
|
A string `json:",default=a"`
|
||||||
|
B string
|
||||||
|
}
|
||||||
|
var st St
|
||||||
|
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, st.A, "a")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("env", func(t *testing.T) {
|
||||||
|
type St struct {
|
||||||
|
A string `json:",default=a"`
|
||||||
|
B string
|
||||||
|
C string `json:",env=TEST_C"`
|
||||||
|
}
|
||||||
|
t.Setenv("TEST_C", "c")
|
||||||
|
|
||||||
|
var st St
|
||||||
|
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, st.A, "a")
|
||||||
|
assert.Equal(t, st.C, "c")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("has value", func(t *testing.T) {
|
||||||
|
type St struct {
|
||||||
|
A string `json:",default=a"`
|
||||||
|
B string
|
||||||
|
}
|
||||||
|
var st = St{
|
||||||
|
A: "b",
|
||||||
|
}
|
||||||
|
err := fillDefaultUnmarshal.Unmarshal(map[string]interface{}{}, &st)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/prometheus"
|
"github.com/zeromicro/go-zero/core/prometheus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,6 +18,9 @@ func TestNewCounterVec(t *testing.T) {
|
|||||||
})
|
})
|
||||||
defer counterVec.close()
|
defer counterVec.close()
|
||||||
counterVecNil := NewCounterVec(nil)
|
counterVecNil := NewCounterVec(nil)
|
||||||
|
counterVec.Inc("path", "code")
|
||||||
|
counterVec.Add(1, "path", "code")
|
||||||
|
proc.Shutdown()
|
||||||
assert.NotNil(t, counterVec)
|
assert.NotNil(t, counterVec)
|
||||||
assert.Nil(t, counterVecNil)
|
assert.Nil(t, counterVecNil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewGaugeVec(t *testing.T) {
|
func TestNewGaugeVec(t *testing.T) {
|
||||||
@@ -18,6 +19,8 @@ func TestNewGaugeVec(t *testing.T) {
|
|||||||
gaugeVecNil := NewGaugeVec(nil)
|
gaugeVecNil := NewGaugeVec(nil)
|
||||||
assert.NotNil(t, gaugeVec)
|
assert.NotNil(t, gaugeVec)
|
||||||
assert.Nil(t, gaugeVecNil)
|
assert.Nil(t, gaugeVecNil)
|
||||||
|
|
||||||
|
proc.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGaugeInc(t *testing.T) {
|
func TestGaugeInc(t *testing.T) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewHistogramVec(t *testing.T) {
|
func TestNewHistogramVec(t *testing.T) {
|
||||||
@@ -47,4 +48,6 @@ func TestHistogramObserve(t *testing.T) {
|
|||||||
|
|
||||||
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
proc.Shutdown()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,5 +15,14 @@ func AddWrapUpListener(fn func()) func() {
|
|||||||
return fn
|
return fn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTimeToForceQuit does nothing on windows.
|
||||||
func SetTimeToForceQuit(duration time.Duration) {
|
func SetTimeToForceQuit(duration time.Duration) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shutdown does nothing on windows.
|
||||||
|
func Shutdown() {
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapUp does nothing on windows.
|
||||||
|
func WrapUp() {
|
||||||
|
}
|
||||||
|
|||||||
@@ -43,6 +43,16 @@ func SetTimeToForceQuit(duration time.Duration) {
|
|||||||
delayTimeBeforeForceQuit = duration
|
delayTimeBeforeForceQuit = duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shutdown calls the registered shutdown listeners, only for test purpose.
|
||||||
|
func Shutdown() {
|
||||||
|
shutdownListeners.notifyListeners()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapUp wraps up the process, only for test purpose.
|
||||||
|
func WrapUp() {
|
||||||
|
wrapUpListeners.notifyListeners()
|
||||||
|
}
|
||||||
|
|
||||||
func gracefulStop(signals chan os.Signal) {
|
func gracefulStop(signals chan os.Signal) {
|
||||||
signal.Stop(signals)
|
signal.Stop(signals)
|
||||||
|
|
||||||
|
|||||||
@@ -18,14 +18,14 @@ func TestShutdown(t *testing.T) {
|
|||||||
called := AddWrapUpListener(func() {
|
called := AddWrapUpListener(func() {
|
||||||
val++
|
val++
|
||||||
})
|
})
|
||||||
wrapUpListeners.notifyListeners()
|
WrapUp()
|
||||||
called()
|
called()
|
||||||
assert.Equal(t, 1, val)
|
assert.Equal(t, 1, val)
|
||||||
|
|
||||||
called = AddShutdownListener(func() {
|
called = AddShutdownListener(func() {
|
||||||
val += 2
|
val += 2
|
||||||
})
|
})
|
||||||
shutdownListeners.notifyListeners()
|
Shutdown()
|
||||||
called()
|
called()
|
||||||
assert.Equal(t, 3, val)
|
assert.Equal(t, 3, val)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,3 +17,15 @@ func TestServiceConf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
c.MustSetUp()
|
c.MustSetUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServiceConfWithMetricsUrl(t *testing.T) {
|
||||||
|
c := ServiceConf{
|
||||||
|
Name: "foo",
|
||||||
|
Log: logx.LogConf{
|
||||||
|
Mode: "volume",
|
||||||
|
},
|
||||||
|
Mode: "dev",
|
||||||
|
MetricsUrl: "http://localhost:8080",
|
||||||
|
}
|
||||||
|
assert.NoError(t, c.SetUp())
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -55,6 +56,7 @@ func TestServiceGroup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group.Stop()
|
group.Stop()
|
||||||
|
proc.Shutdown()
|
||||||
|
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
|
|||||||
5
core/stores/cache/cache.go
vendored
5
core/stores/cache/cache.go
vendored
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/errorx"
|
"github.com/zeromicro/go-zero/core/errorx"
|
||||||
"github.com/zeromicro/go-zero/core/hash"
|
"github.com/zeromicro/go-zero/core/hash"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
"github.com/zeromicro/go-zero/core/syncx"
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(c) == 1 {
|
if len(c) == 1 {
|
||||||
return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
|
return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatcher := hash.NewConsistentHash()
|
dispatcher := hash.NewConsistentHash()
|
||||||
for _, node := range c {
|
for _, node := range c {
|
||||||
cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...)
|
cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...)
|
||||||
dispatcher.AddWithWeight(cn, node.Weight)
|
dispatcher.AddWithWeight(cn, node.Weight)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
4
core/stores/cache/cache_test.go
vendored
4
core/stores/cache/cache_test.go
vendored
@@ -163,12 +163,10 @@ func TestCache_SetDel(t *testing.T) {
|
|||||||
r1, err := miniredis.Run()
|
r1, err := miniredis.Run()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer r1.Close()
|
defer r1.Close()
|
||||||
r1.SetError("mock error")
|
|
||||||
|
|
||||||
r2, err := miniredis.Run()
|
r2, err := miniredis.Run()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer r2.Close()
|
defer r2.Close()
|
||||||
r2.SetError("mock error")
|
|
||||||
|
|
||||||
conf := ClusterConf{
|
conf := ClusterConf{
|
||||||
{
|
{
|
||||||
@@ -187,6 +185,8 @@ func TestCache_SetDel(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||||
|
r1.SetError("mock error")
|
||||||
|
r2.SetError("mock error")
|
||||||
assert.NoError(t, c.Del("a", "b", "c"))
|
assert.NoError(t, c.Del("a", "b", "c"))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
3
core/stores/cache/cachenode.go
vendored
3
core/stores/cache/cachenode.go
vendored
@@ -277,5 +277,6 @@ func (c cacheNode) processCache(ctx context.Context, key, data string, v interfa
|
|||||||
|
|
||||||
func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
|
func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
|
||||||
seconds := int(math.Ceil(c.aroundDuration(c.notFoundExpiry).Seconds()))
|
seconds := int(math.Ceil(c.aroundDuration(c.notFoundExpiry).Seconds()))
|
||||||
return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, seconds)
|
_, err := c.rds.SetnxExCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
29
core/stores/cache/cachenode_test.go
vendored
29
core/stores/cache/cachenode_test.go
vendored
@@ -209,6 +209,35 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
|
|||||||
assert.Equal(t, errDummy, err)
|
assert.Equal(t, errDummy, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||||
|
store, clean, err := redistest.CreateRedis()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer clean()
|
||||||
|
|
||||||
|
cn := cacheNode{
|
||||||
|
rds: store,
|
||||||
|
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
|
barrier: syncx.NewSingleFlight(),
|
||||||
|
lock: new(sync.Mutex),
|
||||||
|
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||||
|
stat: NewStat("any"),
|
||||||
|
errNotFound: errTestNotFound,
|
||||||
|
}
|
||||||
|
|
||||||
|
var str string
|
||||||
|
err = cn.Take(&str, "any", func(v interface{}) error {
|
||||||
|
store.Set("any", "foo")
|
||||||
|
return errTestNotFound
|
||||||
|
})
|
||||||
|
assert.True(t, cn.IsNotFound(err))
|
||||||
|
|
||||||
|
val, err := store.Get("any")
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.Equal(t, "foo", val)
|
||||||
|
}
|
||||||
|
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||||
|
}
|
||||||
|
|
||||||
func TestCacheNode_TakeWithExpire(t *testing.T) {
|
func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||||
store, clean, err := redistest.CreateRedis()
|
store, clean, err := redistest.CreateRedis()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|||||||
2
core/stores/cache/cleaner_test.go
vendored
2
core/stores/cache/cleaner_test.go
vendored
@@ -5,6 +5,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNextDelay(t *testing.T) {
|
func TestNextDelay(t *testing.T) {
|
||||||
@@ -51,6 +52,7 @@ func TestNextDelay(t *testing.T) {
|
|||||||
next, ok := nextDelay(test.input)
|
next, ok := nextDelay(test.input)
|
||||||
assert.Equal(t, test.ok, ok)
|
assert.Equal(t, test.ok, ok)
|
||||||
assert.Equal(t, test.output, next)
|
assert.Equal(t, test.output, next)
|
||||||
|
proc.Shutdown()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func NewStore(c KvConf) Store {
|
|||||||
// because Store and redis.Redis has different methods.
|
// because Store and redis.Redis has different methods.
|
||||||
dispatcher := hash.NewConsistentHash()
|
dispatcher := hash.NewConsistentHash()
|
||||||
for _, node := range c {
|
for _, node := range c {
|
||||||
cn := node.NewRedis()
|
cn := redis.MustNewRedis(node.RedisConf)
|
||||||
dispatcher.AddWithWeight(cn, node.Weight)
|
dispatcher.AddWithWeight(cn, node.Weight)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/trace"
|
"github.com/zeromicro/go-zero/core/trace"
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/codes"
|
"go.opentelemetry.io/otel/codes"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
@@ -14,11 +13,8 @@ import (
|
|||||||
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
|
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
|
||||||
|
|
||||||
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
|
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
|
||||||
tracer := otel.Tracer(trace.TraceName)
|
tracer := trace.TracerFromContext(ctx)
|
||||||
ctx, span := tracer.Start(ctx,
|
ctx, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||||
spanName,
|
|
||||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
|
||||||
)
|
|
||||||
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
|
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
|
||||||
|
|
||||||
return ctx, span
|
return ctx, span
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ var (
|
|||||||
ErrEmptyType = errors.New("empty redis type")
|
ErrEmptyType = errors.New("empty redis type")
|
||||||
// ErrEmptyKey is an error that indicates no redis key is set.
|
// ErrEmptyKey is an error that indicates no redis key is set.
|
||||||
ErrEmptyKey = errors.New("empty redis key")
|
ErrEmptyKey = errors.New("empty redis key")
|
||||||
|
// ErrPing is an error that indicates ping failed.
|
||||||
|
ErrPing = errors.New("ping redis failed")
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@@ -28,6 +30,7 @@ type (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewRedis returns a Redis.
|
// NewRedis returns a Redis.
|
||||||
|
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||||
func (rc RedisConf) NewRedis() *Redis {
|
func (rc RedisConf) NewRedis() *Redis {
|
||||||
var opts []Option
|
var opts []Option
|
||||||
if rc.Type == ClusterType {
|
if rc.Type == ClusterType {
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/mapping"
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
"github.com/zeromicro/go-zero/core/timex"
|
"github.com/zeromicro/go-zero/core/timex"
|
||||||
"github.com/zeromicro/go-zero/core/trace"
|
"github.com/zeromicro/go-zero/core/trace"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/codes"
|
"go.opentelemetry.io/otel/codes"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
@@ -25,15 +24,13 @@ const spanName = "redis"
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
startTimeKey = contextKey("startTime")
|
startTimeKey = contextKey("startTime")
|
||||||
durationHook = hook{tracer: otel.Tracer(trace.TraceName)}
|
durationHook = hook{}
|
||||||
redisCmdsAttributeKey = attribute.Key("redis.cmds")
|
redisCmdsAttributeKey = attribute.Key("redis.cmds")
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
contextKey string
|
contextKey string
|
||||||
hook struct {
|
hook struct{}
|
||||||
tracer oteltrace.Tracer
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h hook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
|
func (h hook) BeforeProcess(ctx context.Context, cmd red.Cmder) (context.Context, error) {
|
||||||
@@ -155,7 +152,9 @@ func logDuration(ctx context.Context, cmds []red.Cmder, duration time.Duration)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) context.Context {
|
func (h hook) startSpan(ctx context.Context, cmds ...red.Cmder) context.Context {
|
||||||
ctx, span := h.tracer.Start(ctx,
|
tracer := trace.TracerFromContext(ctx)
|
||||||
|
|
||||||
|
ctx, span := tracer.Start(ctx,
|
||||||
spanName,
|
spanName,
|
||||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -86,7 +87,46 @@ type (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// New returns a Redis with given options.
|
// New returns a Redis with given options.
|
||||||
|
// Deprecated: use MustNewRedis or NewRedis instead.
|
||||||
func New(addr string, opts ...Option) *Redis {
|
func New(addr string, opts ...Option) *Redis {
|
||||||
|
return newRedis(addr, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustNewRedis returns a Redis with given options.
|
||||||
|
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
|
||||||
|
rds, err := NewRedis(conf, opts...)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rds
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRedis returns a Redis with given options.
|
||||||
|
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
|
||||||
|
if err := conf.Validate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if conf.Type == ClusterType {
|
||||||
|
opts = append([]Option{Cluster()}, opts...)
|
||||||
|
}
|
||||||
|
if len(conf.Pass) > 0 {
|
||||||
|
opts = append([]Option{WithPass(conf.Pass)}, opts...)
|
||||||
|
}
|
||||||
|
if conf.Tls {
|
||||||
|
opts = append([]Option{WithTLS()}, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
rds := newRedis(conf.Host, opts...)
|
||||||
|
if !rds.Ping() {
|
||||||
|
return nil, ErrPing
|
||||||
|
}
|
||||||
|
|
||||||
|
return rds, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRedis(addr string, opts ...Option) *Redis {
|
||||||
r := &Redis{
|
r := &Redis{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Type: NodeType,
|
Type: NodeType,
|
||||||
|
|||||||
@@ -16,6 +16,116 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/stringx"
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestNewRedis(t *testing.T) {
|
||||||
|
r1, err := miniredis.Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer r1.Close()
|
||||||
|
|
||||||
|
r2, err := miniredis.Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer r2.Close()
|
||||||
|
r2.SetError("mock")
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
RedisConf
|
||||||
|
ok bool
|
||||||
|
redisErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing host",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: "",
|
||||||
|
Type: NodeType,
|
||||||
|
Pass: "",
|
||||||
|
},
|
||||||
|
ok: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing type",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: "localhost:6379",
|
||||||
|
Type: "",
|
||||||
|
Pass: "",
|
||||||
|
},
|
||||||
|
ok: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: r1.Addr(),
|
||||||
|
Type: NodeType,
|
||||||
|
Pass: "",
|
||||||
|
},
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: r1.Addr(),
|
||||||
|
Type: ClusterType,
|
||||||
|
Pass: "",
|
||||||
|
},
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: r1.Addr(),
|
||||||
|
Type: NodeType,
|
||||||
|
Pass: "pw",
|
||||||
|
},
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tls",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: r1.Addr(),
|
||||||
|
Type: NodeType,
|
||||||
|
Tls: true,
|
||||||
|
},
|
||||||
|
ok: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "node error",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: r2.Addr(),
|
||||||
|
Type: NodeType,
|
||||||
|
Pass: "",
|
||||||
|
},
|
||||||
|
ok: true,
|
||||||
|
redisErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cluster error",
|
||||||
|
RedisConf: RedisConf{
|
||||||
|
Host: r2.Addr(),
|
||||||
|
Type: ClusterType,
|
||||||
|
Pass: "",
|
||||||
|
},
|
||||||
|
ok: true,
|
||||||
|
redisErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||||
|
rds, err := NewRedis(test.RedisConf)
|
||||||
|
if test.ok {
|
||||||
|
if test.redisErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, rds)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, rds)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRedis_Decr(t *testing.T) {
|
func TestRedis_Decr(t *testing.T) {
|
||||||
runOnRedis(t, func(client *Redis) {
|
runOnRedis(t, func(client *Redis) {
|
||||||
_, err := New(client.Addr, badType()).Decr("a")
|
_, err := New(client.Addr, badType()).Decr("a")
|
||||||
@@ -1651,42 +1761,17 @@ func TestRedis_WithPass(t *testing.T) {
|
|||||||
func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
||||||
logx.Disable()
|
logx.Disable()
|
||||||
|
|
||||||
s, err := miniredis.Run()
|
s := miniredis.RunT(t)
|
||||||
assert.Nil(t, err)
|
fn(MustNewRedis(RedisConf{
|
||||||
defer func() {
|
Host: s.Addr(),
|
||||||
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
|
Type: NodeType,
|
||||||
return nil, errors.New("should already exist")
|
}))
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
fn(New(s.Addr()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
|
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
|
||||||
logx.Disable()
|
logx.Disable()
|
||||||
|
|
||||||
s, err := miniredis.Run()
|
s := miniredis.RunT(t)
|
||||||
assert.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
|
|
||||||
return nil, errors.New("should already exist")
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if client != nil {
|
|
||||||
_ = client.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
s.SetError("mock error")
|
s.SetError("mock error")
|
||||||
fn(New(s.Addr()))
|
fn(New(s.Addr()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,14 +52,11 @@ func TestSqlConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
var err error
|
var err error
|
||||||
db, mock, err = sqlmock.New()
|
db, mock, err = sqlmock.New()
|
||||||
return &pingedDB{
|
return db, err
|
||||||
DB: db,
|
|
||||||
}, err
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package sqlx
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"io"
|
"io"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/syncx"
|
"github.com/zeromicro/go-zero/core/syncx"
|
||||||
@@ -17,43 +16,29 @@ const (
|
|||||||
|
|
||||||
var connManager = syncx.NewResourceManager()
|
var connManager = syncx.NewResourceManager()
|
||||||
|
|
||||||
type pingedDB struct {
|
func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
||||||
*sql.DB
|
|
||||||
once sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
|
|
||||||
val, err := connManager.GetResource(server, func() (io.Closer, error) {
|
val, err := connManager.GetResource(server, func() (io.Closer, error) {
|
||||||
conn, err := newDBConnection(driverName, server)
|
conn, err := newDBConnection(driverName, server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &pingedDB{
|
return conn, nil
|
||||||
DB: conn,
|
|
||||||
}, nil
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return val.(*pingedDB), nil
|
return val.(*sql.DB), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSqlConn(driverName, server string) (*sql.DB, error) {
|
func getSqlConn(driverName, server string) (*sql.DB, error) {
|
||||||
pdb, err := getCachedSqlConn(driverName, server)
|
conn, err := getCachedSqlConn(driverName, server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pdb.once.Do(func() {
|
return conn, nil
|
||||||
err = pdb.Ping()
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return pdb.DB, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||||
@@ -70,5 +55,10 @@ func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
|||||||
conn.SetMaxOpenConns(maxOpenConns)
|
conn.SetMaxOpenConns(maxOpenConns)
|
||||||
conn.SetConnMaxLifetime(maxLifetime)
|
conn.SetConnMaxLifetime(maxLifetime)
|
||||||
|
|
||||||
|
if err := conn.Ping(); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/trace"
|
"github.com/zeromicro/go-zero/core/trace"
|
||||||
"go.opentelemetry.io/otel"
|
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/codes"
|
"go.opentelemetry.io/otel/codes"
|
||||||
oteltrace "go.opentelemetry.io/otel/trace"
|
oteltrace "go.opentelemetry.io/otel/trace"
|
||||||
@@ -14,11 +13,8 @@ import (
|
|||||||
var sqlAttributeKey = attribute.Key("sql.method")
|
var sqlAttributeKey = attribute.Key("sql.method")
|
||||||
|
|
||||||
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
|
func startSpan(ctx context.Context, method string) (context.Context, oteltrace.Span) {
|
||||||
tracer := otel.Tracer(trace.TraceName)
|
tracer := trace.TracerFromContext(ctx)
|
||||||
start, span := tracer.Start(ctx,
|
start, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||||
spanName,
|
|
||||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
|
||||||
)
|
|
||||||
span.SetAttributes(sqlAttributeKey.String(method))
|
span.SetAttributes(sqlAttributeKey.String(method))
|
||||||
|
|
||||||
return start, span
|
return start, span
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
switch ch {
|
switch ch {
|
||||||
case '?':
|
case '?':
|
||||||
if argIndex >= numArgs {
|
if argIndex >= numArgs {
|
||||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
writeValue(&b, args[argIndex])
|
writeValue(&b, args[argIndex])
|
||||||
@@ -93,7 +93,7 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
|
|
||||||
index--
|
index--
|
||||||
if index < 0 || numArgs <= index {
|
if index < 0 || numArgs <= index {
|
||||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
return "", fmt.Errorf("wrong index %d in sql", index)
|
||||||
}
|
}
|
||||||
|
|
||||||
writeValue(&b, args[index])
|
writeValue(&b, args[index])
|
||||||
@@ -124,7 +124,7 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if argIndex < numArgs {
|
if argIndex < numArgs {
|
||||||
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
|
return "", fmt.Errorf("%d arguments provided, not matching sql", argIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.String(), nil
|
return b.String(), nil
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ func (n *node) add(word string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nd := n
|
nd := n
|
||||||
var depth int
|
|
||||||
for i, char := range chars {
|
for i, char := range chars {
|
||||||
if nd.children == nil {
|
if nd.children == nil {
|
||||||
child := new(node)
|
child := new(node)
|
||||||
@@ -23,7 +22,6 @@ func (n *node) add(word string) {
|
|||||||
nd = child
|
nd = child
|
||||||
} else if child, ok := nd.children[char]; ok {
|
} else if child, ok := nd.children[char]; ok {
|
||||||
nd = child
|
nd = child
|
||||||
depth++
|
|
||||||
} else {
|
} else {
|
||||||
child := new(node)
|
child := new(node)
|
||||||
child.depth = i + 1
|
child.depth = i + 1
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
package stringx
|
package stringx
|
||||||
|
|
||||||
import "strings"
|
import (
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// replace more than once to avoid overlapped keywords after replace.
|
||||||
|
// only try 2 times to avoid too many or infinite loops.
|
||||||
|
const replaceTimes = 2
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// Replacer interface wraps the Replace method.
|
// Replacer interface wraps the Replace method.
|
||||||
@@ -30,68 +37,48 @@ func NewReplacer(mapping map[string]string) Replacer {
|
|||||||
|
|
||||||
// Replace replaces text with given substitutes.
|
// Replace replaces text with given substitutes.
|
||||||
func (r *replacer) Replace(text string) string {
|
func (r *replacer) Replace(text string) string {
|
||||||
var builder strings.Builder
|
for i := 0; i < replaceTimes; i++ {
|
||||||
var start int
|
var replaced bool
|
||||||
chars := []rune(text)
|
if text, replaced = r.doReplace(text); !replaced {
|
||||||
size := len(chars)
|
return text
|
||||||
|
|
||||||
for start < size {
|
|
||||||
cur := r.node
|
|
||||||
|
|
||||||
if start > 0 {
|
|
||||||
builder.WriteString(string(chars[:start]))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := start; i < size; i++ {
|
|
||||||
child, ok := cur.children[chars[i]]
|
|
||||||
if ok {
|
|
||||||
cur = child
|
|
||||||
} else if cur == r.node {
|
|
||||||
builder.WriteRune(chars[i])
|
|
||||||
// cur already points to root, set start only
|
|
||||||
start = i + 1
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
curDepth := cur.depth
|
|
||||||
cur = cur.fail
|
|
||||||
child, ok = cur.children[chars[i]]
|
|
||||||
if !ok {
|
|
||||||
// write this path
|
|
||||||
builder.WriteString(string(chars[i-curDepth : i+1]))
|
|
||||||
// go to root
|
|
||||||
cur = r.node
|
|
||||||
start = i + 1
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
failDepth := cur.depth
|
|
||||||
// write path before jump
|
|
||||||
builder.WriteString(string(chars[start : start+curDepth-failDepth]))
|
|
||||||
start += curDepth - failDepth
|
|
||||||
cur = child
|
|
||||||
}
|
|
||||||
|
|
||||||
if cur.end {
|
|
||||||
val := string(chars[i+1-cur.depth : i+1])
|
|
||||||
builder.WriteString(r.mapping[val])
|
|
||||||
builder.WriteString(string(chars[i+1:]))
|
|
||||||
// only matching this path, all previous paths are done
|
|
||||||
if start >= i+1-cur.depth && i+1 >= size {
|
|
||||||
return builder.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
chars = []rune(builder.String())
|
|
||||||
size = len(chars)
|
|
||||||
builder.Reset()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cur.end {
|
|
||||||
builder.WriteString(string(chars[start:]))
|
|
||||||
return builder.String()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(chars)
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *replacer) doReplace(text string) (string, bool) {
|
||||||
|
chars := []rune(text)
|
||||||
|
scopes := r.find(chars)
|
||||||
|
if len(scopes) == 0 {
|
||||||
|
return text, false
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(scopes, func(i, j int) bool {
|
||||||
|
if scopes[i].start < scopes[j].start {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if scopes[i].start == scopes[j].start {
|
||||||
|
return scopes[i].stop > scopes[j].stop
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
var buf strings.Builder
|
||||||
|
var index int
|
||||||
|
for i := 0; i < len(scopes); i++ {
|
||||||
|
scp := &scopes[i]
|
||||||
|
if scp.start < index {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteString(string(chars[index:scp.start]))
|
||||||
|
buf.WriteString(r.mapping[string(chars[scp.start:scp.stop])])
|
||||||
|
index = scp.stop
|
||||||
|
}
|
||||||
|
if index < len(chars) {
|
||||||
|
buf.WriteString(string(chars[index:]))
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
//go:build go1.18
|
//go:build go1.18
|
||||||
// +build go1.18
|
|
||||||
|
|
||||||
package stringx
|
package stringx
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,15 @@ func TestReplacer_Replace(t *testing.T) {
|
|||||||
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
|
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceJumpMatch(t *testing.T) {
|
||||||
|
mapping := map[string]string{
|
||||||
|
"abcdeg": "ABCDEG",
|
||||||
|
"cdef": "CDEF",
|
||||||
|
"cde": "CDE",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "abCDEF", NewReplacer(mapping).Replace("abcdef"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestReplacer_ReplaceOverlap(t *testing.T) {
|
func TestReplacer_ReplaceOverlap(t *testing.T) {
|
||||||
mapping := map[string]string{
|
mapping := map[string]string{
|
||||||
"3d": "34",
|
"3d": "34",
|
||||||
@@ -44,6 +53,14 @@ func TestReplacer_ReplacePartialMatch(t *testing.T) {
|
|||||||
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
assert.Equal(t, "零一二三四五", NewReplacer(mapping).Replace("零一二三四五"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplacePartialMatchEnds(t *testing.T) {
|
||||||
|
mapping := map[string]string{
|
||||||
|
"二三四七": "2347",
|
||||||
|
"三四": "34",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "零一二34", NewReplacer(mapping).Replace("零一二三四"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
||||||
mapping := map[string]string{
|
mapping := map[string]string{
|
||||||
"二三": "23",
|
"二三": "23",
|
||||||
@@ -51,6 +68,54 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
|
|||||||
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
|
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceLongestMatching(t *testing.T) {
|
||||||
|
keywords := map[string]string{
|
||||||
|
"日本": "japan",
|
||||||
|
"日本的首都": "东京",
|
||||||
|
}
|
||||||
|
replacer := NewReplacer(keywords)
|
||||||
|
assert.Equal(t, "东京在japan", replacer.Replace("日本的首都在日本"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceSuffixMatch(t *testing.T) {
|
||||||
|
// case1
|
||||||
|
{
|
||||||
|
keywords := map[string]string{
|
||||||
|
"abcde": "ABCDE",
|
||||||
|
"bcde": "BCDE",
|
||||||
|
"bcd": "BCD",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "aBCDf", NewReplacer(keywords).Replace("abcdf"))
|
||||||
|
}
|
||||||
|
// case2
|
||||||
|
{
|
||||||
|
keywords := map[string]string{
|
||||||
|
"abcde": "ABCDE",
|
||||||
|
"bcde": "BCDE",
|
||||||
|
"cde": "CDE",
|
||||||
|
"c": "C",
|
||||||
|
"cd": "CD",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "abCDf", NewReplacer(keywords).Replace("abcdf"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceLongestOverlap(t *testing.T) {
|
||||||
|
keywords := map[string]string{
|
||||||
|
"456": "def",
|
||||||
|
"abcd": "1234",
|
||||||
|
}
|
||||||
|
replacer := NewReplacer(keywords)
|
||||||
|
assert.Equal(t, "123def7", replacer.Replace("abcd567"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceLongestLonger(t *testing.T) {
|
||||||
|
mapping := map[string]string{
|
||||||
|
"c": "3",
|
||||||
|
}
|
||||||
|
assert.Equal(t, "3d", NewReplacer(mapping).Replace("cd"))
|
||||||
|
}
|
||||||
|
|
||||||
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
|
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
|
||||||
mapping := map[string]string{
|
mapping := map[string]string{
|
||||||
"bcdf": "1235",
|
"bcdf": "1235",
|
||||||
@@ -146,3 +211,21 @@ func TestFuzzReplacerCase2(t *testing.T) {
|
|||||||
t.Errorf("result: %s, match: %v", val, keys)
|
t.Errorf("result: %s, match: %v", val, keys)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceLongestMatch(t *testing.T) {
|
||||||
|
replacer := NewReplacer(map[string]string{
|
||||||
|
"日本的首都": "东京",
|
||||||
|
"日本": "本日",
|
||||||
|
})
|
||||||
|
assert.Equal(t, "东京是东京", replacer.Replace("日本的首都是东京"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplacer_ReplaceIndefinitely(t *testing.T) {
|
||||||
|
mapping := map[string]string{
|
||||||
|
"日本的首都": "东京",
|
||||||
|
"东京": "日本的首都",
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
NewReplacer(mapping).Replace("日本的首都是东京")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package trace
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/lang"
|
"github.com/zeromicro/go-zero/core/lang"
|
||||||
@@ -57,6 +58,10 @@ func createExporter(c Config) (sdktrace.SpanExporter, error) {
|
|||||||
// Just support jaeger and zipkin now, more for later
|
// Just support jaeger and zipkin now, more for later
|
||||||
switch c.Batcher {
|
switch c.Batcher {
|
||||||
case kindJaeger:
|
case kindJaeger:
|
||||||
|
u, _ := url.Parse(c.Endpoint)
|
||||||
|
if u.Scheme == "udp" {
|
||||||
|
return jaeger.New(jaeger.WithAgentEndpoint(jaeger.WithAgentHost(u.Hostname()), jaeger.WithAgentPort(u.Port())))
|
||||||
|
}
|
||||||
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
|
return jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
|
||||||
case kindZipkin:
|
case kindZipkin:
|
||||||
return zipkin.New(c.Endpoint)
|
return zipkin.New(c.Endpoint)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func TestStartAgent(t *testing.T) {
|
|||||||
endpoint2 = "remotehost:1234"
|
endpoint2 = "remotehost:1234"
|
||||||
endpoint3 = "localhost:1235"
|
endpoint3 = "localhost:1235"
|
||||||
endpoint4 = "localhost:1236"
|
endpoint4 = "localhost:1236"
|
||||||
|
endpoint5 = "udp://localhost:6831"
|
||||||
)
|
)
|
||||||
c1 := Config{
|
c1 := Config{
|
||||||
Name: "foo",
|
Name: "foo",
|
||||||
@@ -44,6 +45,11 @@ func TestStartAgent(t *testing.T) {
|
|||||||
Endpoint: endpoint4,
|
Endpoint: endpoint4,
|
||||||
Batcher: kindOtlpHttp,
|
Batcher: kindOtlpHttp,
|
||||||
}
|
}
|
||||||
|
c7 := Config{
|
||||||
|
Name: "UDP",
|
||||||
|
Endpoint: endpoint5,
|
||||||
|
Batcher: kindJaeger,
|
||||||
|
}
|
||||||
|
|
||||||
StartAgent(c1)
|
StartAgent(c1)
|
||||||
StartAgent(c1)
|
StartAgent(c1)
|
||||||
@@ -52,16 +58,19 @@ func TestStartAgent(t *testing.T) {
|
|||||||
StartAgent(c4)
|
StartAgent(c4)
|
||||||
StartAgent(c5)
|
StartAgent(c5)
|
||||||
StartAgent(c6)
|
StartAgent(c6)
|
||||||
|
StartAgent(c7)
|
||||||
|
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
|
|
||||||
// because remotehost cannot be resolved
|
// because remotehost cannot be resolved
|
||||||
assert.Equal(t, 4, len(agents))
|
assert.Equal(t, 5, len(agents))
|
||||||
_, ok := agents[""]
|
_, ok := agents[""]
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
_, ok = agents[endpoint1]
|
_, ok = agents[endpoint1]
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
_, ok = agents[endpoint2]
|
_, ok = agents[endpoint2]
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
|
_, ok = agents[endpoint5]
|
||||||
|
assert.True(t, ok)
|
||||||
}
|
}
|
||||||
|
|||||||
73
core/trace/message_test.go
Normal file
73
core/trace/message_test.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package trace
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/codes"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
"google.golang.org/protobuf/types/dynamicpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMessageType_Event(t *testing.T) {
|
||||||
|
var span mockSpan
|
||||||
|
ctx := trace.ContextWithSpan(context.Background(), &span)
|
||||||
|
MessageReceived.Event(ctx, 1, "foo")
|
||||||
|
assert.Equal(t, messageEvent, span.name)
|
||||||
|
assert.NotEmpty(t, span.options)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageType_EventProtoMessage(t *testing.T) {
|
||||||
|
var span mockSpan
|
||||||
|
var message mockMessage
|
||||||
|
ctx := trace.ContextWithSpan(context.Background(), &span)
|
||||||
|
MessageReceived.Event(ctx, 1, message)
|
||||||
|
assert.Equal(t, messageEvent, span.name)
|
||||||
|
assert.NotEmpty(t, span.options)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockSpan struct {
|
||||||
|
name string
|
||||||
|
options []trace.EventOption
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) End(options ...trace.SpanEndOption) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) AddEvent(name string, options ...trace.EventOption) {
|
||||||
|
m.name = name
|
||||||
|
m.options = options
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) IsRecording() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) RecordError(err error, options ...trace.EventOption) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) SpanContext() trace.SpanContext {
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) SetStatus(code codes.Code, description string) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) SetName(name string) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) SetAttributes(kv ...attribute.KeyValue) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSpan) TracerProvider() trace.TracerProvider {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockMessage struct{}
|
||||||
|
|
||||||
|
func (m mockMessage) ProtoReflect() protoreflect.Message {
|
||||||
|
return new(dynamicpb.Message)
|
||||||
|
}
|
||||||
@@ -6,8 +6,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
ztrace "github.com/zeromicro/go-zero/internal/trace"
|
ztrace "github.com/zeromicro/go-zero/internal/trace"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,25 +22,6 @@ var (
|
|||||||
TraceIDFromContext = ztrace.TraceIDFromContext
|
TraceIDFromContext = ztrace.TraceIDFromContext
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerFromCtx returns the peer from ctx.
|
|
||||||
func PeerFromCtx(ctx context.Context) string {
|
|
||||||
p, ok := peer.FromContext(ctx)
|
|
||||||
if !ok || p == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.Addr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SpanInfo returns the span info.
|
|
||||||
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
|
|
||||||
attrs := []attribute.KeyValue{RPCSystemGRPC}
|
|
||||||
name, mAttrs := ParseFullMethod(fullMethod)
|
|
||||||
attrs = append(attrs, mAttrs...)
|
|
||||||
attrs = append(attrs, PeerAttr(peerAddress)...)
|
|
||||||
return name, attrs
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseFullMethod returns the method name and attributes.
|
// ParseFullMethod returns the method name and attributes.
|
||||||
func ParseFullMethod(fullMethod string) (string, []attribute.KeyValue) {
|
func ParseFullMethod(fullMethod string) (string, []attribute.KeyValue) {
|
||||||
name := strings.TrimLeft(fullMethod, "/")
|
name := strings.TrimLeft(fullMethod, "/")
|
||||||
@@ -75,3 +58,33 @@ func PeerAttr(addr string) []attribute.KeyValue {
|
|||||||
semconv.NetPeerPortKey.String(port),
|
semconv.NetPeerPortKey.String(port),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PeerFromCtx returns the peer from ctx.
|
||||||
|
func PeerFromCtx(ctx context.Context) string {
|
||||||
|
p, ok := peer.FromContext(ctx)
|
||||||
|
if !ok || p == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.Addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpanInfo returns the span info.
|
||||||
|
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
|
||||||
|
attrs := []attribute.KeyValue{RPCSystemGRPC}
|
||||||
|
name, mAttrs := ParseFullMethod(fullMethod)
|
||||||
|
attrs = append(attrs, mAttrs...)
|
||||||
|
attrs = append(attrs, PeerAttr(peerAddress)...)
|
||||||
|
return name, attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// TracerFromContext returns a tracer in ctx, otherwise returns a global tracer.
|
||||||
|
func TracerFromContext(ctx context.Context) (tracer trace.Tracer) {
|
||||||
|
if span := trace.SpanFromContext(ctx); span.SpanContext().IsValid() {
|
||||||
|
tracer = span.TracerProvider().Tracer(TraceName)
|
||||||
|
} else {
|
||||||
|
tracer = otel.Tracer(TraceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,8 +6,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/sdk/resource"
|
||||||
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,3 +155,50 @@ func TestPeerAttr(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTracerFromContext(t *testing.T) {
|
||||||
|
traceFn := func(ctx context.Context, hasTraceId bool) {
|
||||||
|
spanContext := trace.SpanContextFromContext(ctx)
|
||||||
|
assert.Equal(t, spanContext.IsValid(), hasTraceId)
|
||||||
|
parentTraceId := spanContext.TraceID().String()
|
||||||
|
|
||||||
|
tracer := TracerFromContext(ctx)
|
||||||
|
_, span := tracer.Start(ctx, "b")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
spanContext = span.SpanContext()
|
||||||
|
assert.True(t, spanContext.IsValid())
|
||||||
|
if hasTraceId {
|
||||||
|
assert.Equal(t, parentTraceId, spanContext.TraceID().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("context", func(t *testing.T) {
|
||||||
|
opts := []sdktrace.TracerProviderOption{
|
||||||
|
// Set the sampling rate based on the parent span to 100%
|
||||||
|
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
|
||||||
|
// Record information about this application in a Resource.
|
||||||
|
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
|
||||||
|
}
|
||||||
|
tp = sdktrace.NewTracerProvider(opts...)
|
||||||
|
otel.SetTracerProvider(tp)
|
||||||
|
ctx, span := tp.Tracer(TraceName).Start(context.Background(), "a")
|
||||||
|
|
||||||
|
defer span.End()
|
||||||
|
traceFn(ctx, true)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("global", func(t *testing.T) {
|
||||||
|
opts := []sdktrace.TracerProviderOption{
|
||||||
|
// Set the sampling rate based on the parent span to 100%
|
||||||
|
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(1))),
|
||||||
|
// Record information about this application in a Resource.
|
||||||
|
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String("test"))),
|
||||||
|
}
|
||||||
|
tp = sdktrace.NewTracerProvider(opts...)
|
||||||
|
otel.SetTracerProvider(tp)
|
||||||
|
|
||||||
|
traceFn(context.Background(), false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
15
readme-cn.md
15
readme-cn.md
@@ -20,9 +20,9 @@
|
|||||||
> ***注意:***
|
> ***注意:***
|
||||||
>
|
>
|
||||||
> 从 v1.3.0 之前版本升级请执行以下命令:
|
> 从 v1.3.0 之前版本升级请执行以下命令:
|
||||||
>
|
>
|
||||||
> `GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest`
|
> `GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest`
|
||||||
>
|
>
|
||||||
> `goctl migrate —verbose —version v1.4.3`
|
> `goctl migrate —verbose —version v1.4.3`
|
||||||
|
|
||||||
## 0. go-zero 介绍
|
## 0. go-zero 介绍
|
||||||
@@ -121,10 +121,10 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
```shell
|
```shell
|
||||||
# Go 1.15 及之前版本
|
# Go 1.15 及之前版本
|
||||||
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro/go-zero/tools/goctl@latest
|
GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro/go-zero/tools/goctl@latest
|
||||||
|
|
||||||
# Go 1.16 及以后版本
|
# Go 1.16 及以后版本
|
||||||
GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest
|
GOPROXY=https://goproxy.cn/,direct go install github.com/zeromicro/go-zero/tools/goctl@latest
|
||||||
|
|
||||||
# For Mac
|
# For Mac
|
||||||
brew install goctl
|
brew install goctl
|
||||||
|
|
||||||
@@ -200,7 +200,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
|
|||||||
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/zeromicro/zero-doc/blob/main/docs/zero/bookstore.md)
|
* [快速构建高并发微服务 - 多 RPC 版](https://github.com/zeromicro/zero-doc/blob/main/docs/zero/bookstore.md)
|
||||||
* [goctl 使用帮助](https://github.com/zeromicro/zero-doc/blob/main/doc/goctl.md)
|
* [goctl 使用帮助](https://github.com/zeromicro/zero-doc/blob/main/doc/goctl.md)
|
||||||
* [Examples](https://github.com/zeromicro/zero-examples)
|
* [Examples](https://github.com/zeromicro/zero-examples)
|
||||||
|
|
||||||
* 精选 `goctl` 插件
|
* 精选 `goctl` 插件
|
||||||
|
|
||||||
| 插件 | 用途 |
|
| 插件 | 用途 |
|
||||||
@@ -296,6 +296,11 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
|||||||
>81. 广州机智云物联网科技有限公司
|
>81. 广州机智云物联网科技有限公司
|
||||||
>82. 厦门亿联网络技术股份有限公司
|
>82. 厦门亿联网络技术股份有限公司
|
||||||
>83. 北京麦芽田网络科技有限公司
|
>83. 北京麦芽田网络科技有限公司
|
||||||
|
>84. 佛山市振联科技有限公司
|
||||||
|
>85. 苏州智言信息科技有限公司
|
||||||
|
>86. 中国移动上海产业研究院
|
||||||
|
>87. 天枢数链(浙江)科技有限公司
|
||||||
|
>88. 北京娱人共享智能科技有限公司
|
||||||
|
|
||||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||||
|
|
||||||
|
|||||||
12
readme.md
12
readme.md
@@ -129,10 +129,10 @@ goctl migrate —verbose —version v1.4.3
|
|||||||
```shell
|
```shell
|
||||||
# for Go 1.15 and earlier
|
# for Go 1.15 and earlier
|
||||||
GO111MODULE=on go get -u github.com/zeromicro/go-zero/tools/goctl@latest
|
GO111MODULE=on go get -u github.com/zeromicro/go-zero/tools/goctl@latest
|
||||||
|
|
||||||
# for Go 1.16 and later
|
# for Go 1.16 and later
|
||||||
go install github.com/zeromicro/go-zero/tools/goctl@latest
|
go install github.com/zeromicro/go-zero/tools/goctl@latest
|
||||||
|
|
||||||
# For Mac
|
# For Mac
|
||||||
brew install goctl
|
brew install goctl
|
||||||
|
|
||||||
@@ -156,24 +156,24 @@ goctl migrate —verbose —version v1.4.3
|
|||||||
Request {
|
Request {
|
||||||
Name string `path:"name,options=[you,me]"` // parameters are auto validated
|
Name string `path:"name,options=[you,me]"` // parameters are auto validated
|
||||||
}
|
}
|
||||||
|
|
||||||
Response {
|
Response {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
service greet-api {
|
service greet-api {
|
||||||
@handler GreetHandler
|
@handler GreetHandler
|
||||||
get /greet/from/:name(Request) returns (Response)
|
get /greet/from/:name(Request) returns (Response)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
the .api files also can be generated by goctl, like below:
|
the .api files also can be generated by goctl, like below:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
goctl api -o greet.api
|
goctl api -o greet.api
|
||||||
```
|
```
|
||||||
|
|
||||||
4. generate the go server-side code
|
4. generate the go server-side code
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
@@ -25,8 +25,10 @@ const topCpuUsage = 1000
|
|||||||
var ErrSignatureConfig = errors.New("bad config for Signature")
|
var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||||
|
|
||||||
type engine struct {
|
type engine struct {
|
||||||
conf RestConf
|
conf RestConf
|
||||||
routes []featuredRoutes
|
routes []featuredRoutes
|
||||||
|
// timeout is the max timeout of all routes
|
||||||
|
timeout time.Duration
|
||||||
unauthorizedCallback handler.UnauthorizedCallback
|
unauthorizedCallback handler.UnauthorizedCallback
|
||||||
unsignedCallback handler.UnsignedCallback
|
unsignedCallback handler.UnsignedCallback
|
||||||
chain chain.Chain
|
chain chain.Chain
|
||||||
@@ -38,8 +40,10 @@ type engine struct {
|
|||||||
|
|
||||||
func newEngine(c RestConf) *engine {
|
func newEngine(c RestConf) *engine {
|
||||||
svr := &engine{
|
svr := &engine{
|
||||||
conf: c,
|
conf: c,
|
||||||
|
timeout: time.Duration(c.Timeout) * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.CpuThreshold > 0 {
|
if c.CpuThreshold > 0 {
|
||||||
svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
svr.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
|
||||||
svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
svr.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
|
||||||
@@ -51,6 +55,12 @@ func newEngine(c RestConf) *engine {
|
|||||||
|
|
||||||
func (ng *engine) addRoutes(r featuredRoutes) {
|
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||||
ng.routes = append(ng.routes, r)
|
ng.routes = append(ng.routes, r)
|
||||||
|
|
||||||
|
// need to guarantee the timeout is the max of all routes
|
||||||
|
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||||
|
if r.timeout > ng.timeout {
|
||||||
|
ng.timeout = r.timeout
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||||
@@ -314,15 +324,15 @@ func (ng *engine) use(middleware Middleware) {
|
|||||||
|
|
||||||
func (ng *engine) withTimeout() internal.StartOption {
|
func (ng *engine) withTimeout() internal.StartOption {
|
||||||
return func(svr *http.Server) {
|
return func(svr *http.Server) {
|
||||||
timeout := ng.conf.Timeout
|
timeout := ng.timeout
|
||||||
if timeout > 0 {
|
if timeout > 0 {
|
||||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||||
// which triggers the circuit breaker.
|
// which triggers the circuit breaker.
|
||||||
svr.ReadTimeout = 4 * time.Duration(timeout) * time.Millisecond / 5
|
svr.ReadTimeout = 4 * timeout / 5
|
||||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||||
svr.WriteTimeout = 11 * time.Duration(timeout) * time.Millisecond / 10
|
svr.WriteTimeout = 11 * timeout / 10
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
|
timeout: time.Minute,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -53,6 +54,7 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
|
timeout: time.Second,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -159,6 +161,11 @@ Verbose: true
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
assert.NotNil(t, ng.start(mockedRouter{}))
|
assert.NotNil(t, ng.start(mockedRouter{}))
|
||||||
|
timeout := time.Second * 3
|
||||||
|
if route.timeout > timeout {
|
||||||
|
timeout = route.timeout
|
||||||
|
}
|
||||||
|
assert.Equal(t, timeout, ng.timeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -125,6 +127,14 @@ type timeoutWriter struct {
|
|||||||
|
|
||||||
var _ http.Pusher = (*timeoutWriter)(nil)
|
var _ http.Pusher = (*timeoutWriter)(nil)
|
||||||
|
|
||||||
|
func (tw *timeoutWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacked, ok := tw.w.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
// Header returns the underline temporary http.Header.
|
// Header returns the underline temporary http.Header.
|
||||||
func (tw *timeoutWriter) Header() http.Header { return tw.h }
|
func (tw *timeoutWriter) Header() http.Header { return tw.h }
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -134,6 +135,30 @@ func TestTimeoutClientClosed(t *testing.T) {
|
|||||||
assert.Equal(t, statusClientClosedRequest, resp.Code)
|
assert.Equal(t, statusClientClosedRequest, resp.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeoutHijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
|
||||||
|
writer := &timeoutWriter{
|
||||||
|
w: &response.WithCodeResponseWriter{
|
||||||
|
Writer: resp,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = &timeoutWriter{
|
||||||
|
w: &response.WithCodeResponseWriter{
|
||||||
|
Writer: mockedHijackable{resp},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTimeoutPusher(t *testing.T) {
|
func TestTimeoutPusher(t *testing.T) {
|
||||||
handler := &timeoutWriter{
|
handler := &timeoutWriter{
|
||||||
w: mockedPusher{},
|
w: mockedPusher{},
|
||||||
|
|||||||
@@ -156,12 +156,13 @@ func fillPath(u *nurl.URL, val map[string]interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func request(r *http.Request, cli client) (*http.Response, error) {
|
func request(r *http.Request, cli client) (*http.Response, error) {
|
||||||
tracer := otel.Tracer(trace.TraceName)
|
ctx := r.Context()
|
||||||
|
tracer := trace.TracerFromContext(ctx)
|
||||||
propagator := otel.GetTextMapPropagator()
|
propagator := otel.GetTextMapPropagator()
|
||||||
|
|
||||||
spanName := r.URL.Path
|
spanName := r.URL.Path
|
||||||
ctx, span := tracer.Start(
|
ctx, span := tracer.Start(
|
||||||
r.Context(),
|
ctx,
|
||||||
spanName,
|
spanName,
|
||||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||||
oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),
|
oteltrace.WithAttributes(semconv.HTTPClientAttributesFromHTTPRequest(r)...),
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/mapping"
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
"github.com/zeromicro/go-zero/rest/internal/encoding"
|
||||||
@@ -23,8 +24,15 @@ const (
|
|||||||
var (
|
var (
|
||||||
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues())
|
||||||
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues())
|
||||||
|
validator atomic.Value
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Validator defines the interface for validating the request.
|
||||||
|
type Validator interface {
|
||||||
|
// Validate validates the request and parsed data.
|
||||||
|
Validate(r *http.Request, data interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
// Parse parses the request.
|
// Parse parses the request.
|
||||||
func Parse(r *http.Request, v interface{}) error {
|
func Parse(r *http.Request, v interface{}) error {
|
||||||
if err := ParsePath(r, v); err != nil {
|
if err := ParsePath(r, v); err != nil {
|
||||||
@@ -39,7 +47,15 @@ func Parse(r *http.Request, v interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return ParseJsonBody(r, v)
|
if err := ParseJsonBody(r, v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if val := validator.Load(); val != nil {
|
||||||
|
return val.(Validator).Validate(r, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseHeaders parses the headers request.
|
// ParseHeaders parses the headers request.
|
||||||
@@ -101,6 +117,13 @@ func ParsePath(r *http.Request, v interface{}) error {
|
|||||||
return pathUnmarshaler.Unmarshal(m, v)
|
return pathUnmarshaler.Unmarshal(m, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetValidator sets the validator.
|
||||||
|
// The validator is used to validate the request, only called in Parse,
|
||||||
|
// not in ParseHeaders, ParseForm, ParseHeader, ParseJsonBody, ParsePath.
|
||||||
|
func SetValidator(val Validator) {
|
||||||
|
validator.Store(val)
|
||||||
|
}
|
||||||
|
|
||||||
func withJsonBody(r *http.Request) bool {
|
func withJsonBody(r *http.Request) bool {
|
||||||
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
|
return r.ContentLength > 0 && strings.Contains(r.Header.Get(header.ContentType), header.ApplicationJson)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package httpx
|
package httpx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -207,9 +209,23 @@ func TestParseJsonBody(t *testing.T) {
|
|||||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
r.Header.Set(ContentType, header.JsonContentType)
|
r.Header.Set(ContentType, header.JsonContentType)
|
||||||
|
|
||||||
assert.Nil(t, Parse(r, &v))
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
assert.Equal(t, "kevin", v.Name)
|
assert.Equal(t, "kevin", v.Name)
|
||||||
assert.Equal(t, 18, v.Age)
|
assert.Equal(t, 18, v.Age)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad body", func(t *testing.T) {
|
||||||
|
var v struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
body := `{"name":"kevin", "ag": 18}`
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
|
r.Header.Set(ContentType, header.JsonContentType)
|
||||||
|
|
||||||
|
assert.Error(t, Parse(r, &v))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("hasn't body", func(t *testing.T) {
|
t.Run("hasn't body", func(t *testing.T) {
|
||||||
@@ -308,6 +324,36 @@ func TestParseHeaders_Error(t *testing.T) {
|
|||||||
assert.NotNil(t, Parse(r, &v))
|
assert.NotNil(t, Parse(r, &v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseWithValidator(t *testing.T) {
|
||||||
|
SetValidator(mockValidator{})
|
||||||
|
var v struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
Age int `form:"age"`
|
||||||
|
Percent float64 `form:"percent,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/a?name=hello&age=18&percent=3.4", http.NoBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if assert.NoError(t, Parse(r, &v)) {
|
||||||
|
assert.Equal(t, "hello", v.Name)
|
||||||
|
assert.Equal(t, 18, v.Age)
|
||||||
|
assert.Equal(t, 3.4, v.Percent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseWithValidatorWithError(t *testing.T) {
|
||||||
|
SetValidator(mockValidator{})
|
||||||
|
var v struct {
|
||||||
|
Name string `form:"name"`
|
||||||
|
Age int `form:"age"`
|
||||||
|
Percent float64 `form:"percent,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := http.NewRequest(http.MethodGet, "/a?name=world&age=18&percent=3.4", http.NoBody)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Error(t, Parse(r, &v))
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkParseRaw(b *testing.B) {
|
func BenchmarkParseRaw(b *testing.B) {
|
||||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
|
r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -351,3 +397,16 @@ func BenchmarkParseAuto(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockValidator struct{}
|
||||||
|
|
||||||
|
func (m mockValidator) Validate(r *http.Request, data interface{}) error {
|
||||||
|
if r.URL.Path == "/a" {
|
||||||
|
val := reflect.ValueOf(data).Elem().FieldByName("Name").String()
|
||||||
|
if val != "hello" {
|
||||||
|
return errors.New("name is not hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStartHttp(t *testing.T) {
|
func TestStartHttp(t *testing.T) {
|
||||||
@@ -19,6 +20,7 @@ func TestStartHttp(t *testing.T) {
|
|||||||
svr.IdleTimeout = 0
|
svr.IdleTimeout = 0
|
||||||
})
|
})
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
|
proc.WrapUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStartHttps(t *testing.T) {
|
func TestStartHttps(t *testing.T) {
|
||||||
@@ -30,4 +32,5 @@ func TestStartHttps(t *testing.T) {
|
|||||||
svr.IdleTimeout = 0
|
svr.IdleTimeout = 0
|
||||||
})
|
})
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
|
proc.WrapUp()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/plugin"
|
"github.com/zeromicro/go-zero/tools/goctl/plugin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,7 +128,7 @@ func init() {
|
|||||||
"https://github.com/zeromicro/go-zero-template directory structure")
|
"https://github.com/zeromicro/go-zero-template directory structure")
|
||||||
goCmd.Flags().StringVar(&gogen.VarStringBranch, "branch", "", "The branch of "+
|
goCmd.Flags().StringVar(&gogen.VarStringBranch, "branch", "", "The branch of "+
|
||||||
"the remote repo, it does work with --remote")
|
"the remote repo, it does work with --remote")
|
||||||
goCmd.Flags().StringVar(&gogen.VarStringStyle, "style", "gozero", "The file naming format,"+
|
goCmd.Flags().StringVar(&gogen.VarStringStyle, "style", config.DefaultFormat, "The file naming format,"+
|
||||||
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
|
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
|
||||||
|
|
||||||
javaCmd.Flags().StringVar(&javagen.VarStringDir, "dir", "", "The target dir")
|
javaCmd.Flags().StringVar(&javagen.VarStringDir, "dir", "", "The target dir")
|
||||||
@@ -146,7 +147,7 @@ func init() {
|
|||||||
"https://github.com/zeromicro/go-zero-template directory structure")
|
"https://github.com/zeromicro/go-zero-template directory structure")
|
||||||
newCmd.Flags().StringVar(&new.VarStringBranch, "branch", "", "The branch of "+
|
newCmd.Flags().StringVar(&new.VarStringBranch, "branch", "", "The branch of "+
|
||||||
"the remote repo, it does work with --remote")
|
"the remote repo, it does work with --remote")
|
||||||
newCmd.Flags().StringVar(&new.VarStringStyle, "style", "gozero", "The file naming format,"+
|
newCmd.Flags().StringVar(&new.VarStringStyle, "style", config.DefaultFormat, "The file naming format,"+
|
||||||
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
|
" see [https://github.com/zeromicro/go-zero/blob/master/tools/goctl/config/readme.md]")
|
||||||
|
|
||||||
pluginCmd.Flags().StringVarP(&plugin.VarStringPlugin, "plugin", "p", "", "The plugin file")
|
pluginCmd.Flags().StringVarP(&plugin.VarStringPlugin, "plugin", "p", "", "The plugin file")
|
||||||
@@ -157,7 +158,6 @@ func init() {
|
|||||||
|
|
||||||
tsCmd.Flags().StringVar(&tsgen.VarStringDir, "dir", "", "The target dir")
|
tsCmd.Flags().StringVar(&tsgen.VarStringDir, "dir", "", "The target dir")
|
||||||
tsCmd.Flags().StringVar(&tsgen.VarStringAPI, "api", "", "The api file")
|
tsCmd.Flags().StringVar(&tsgen.VarStringAPI, "api", "", "The api file")
|
||||||
tsCmd.Flags().StringVar(&tsgen.VarStringWebAPI, "webapi", "", "The web api file path")
|
|
||||||
tsCmd.Flags().StringVar(&tsgen.VarStringCaller, "caller", "", "The web api caller")
|
tsCmd.Flags().StringVar(&tsgen.VarStringCaller, "caller", "", "The web api caller")
|
||||||
tsCmd.Flags().BoolVar(&tsgen.VarBoolUnWrap, "unwrap", false, "Unwrap the webapi caller for import")
|
tsCmd.Flags().BoolVar(&tsgen.VarBoolUnWrap, "unwrap", false, "Unwrap the webapi caller for import")
|
||||||
|
|
||||||
|
|||||||
@@ -30,19 +30,21 @@ Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.
|
|||||||
{{end}}`
|
{{end}}`
|
||||||
|
|
||||||
const apiTemplateV2 = `import 'api.dart';
|
const apiTemplateV2 = `import 'api.dart';
|
||||||
import '../data/{{with .Info}}{{getBaseName .Title}}{{end}}.dart';
|
import '../data/{{with .Service}}{{.Name}}{{end}}.dart';
|
||||||
{{with .Service}}
|
{{with .Service}}
|
||||||
/// {{.Name}}
|
/// {{.Name}}
|
||||||
{{range .Routes}}
|
{{range $i, $Route := .Routes}}
|
||||||
/// --{{.Path}}--
|
/// --{{.Path}}--
|
||||||
///
|
///
|
||||||
/// request: {{with .RequestType}}{{.Name}}{{end}}
|
/// request: {{with .RequestType}}{{.Name}}{{end}}
|
||||||
/// response: {{with .ResponseType}}{{.Name}}{{end}}
|
/// response: {{with .ResponseType}}{{.Name}}{{end}}
|
||||||
Future {{pathToFuncName .Path}}( {{if ne .Method "get"}}{{with .RequestType}}{{.Name}} request,{{end}}{{end}}
|
Future {{normalizeHandlerName .Handler}}(
|
||||||
|
{{if hasUrlPathParams $Route}}{{extractPositionalParamsFromPath $Route}},{{end}}
|
||||||
|
{{if ne .Method "get"}}{{with .RequestType}}{{.Name}} request,{{end}}{{end}}
|
||||||
{Function({{with .ResponseType}}{{.Name}}{{end}})? ok,
|
{Function({{with .ResponseType}}{{.Name}}{{end}})? ok,
|
||||||
Function(String)? fail,
|
Function(String)? fail,
|
||||||
Function? eventually}) async {
|
Function? eventually}) async {
|
||||||
await api{{if eq .Method "get"}}Get{{else}}Post{{end}}('{{.Path}}',{{if ne .Method "get"}}request,{{end}}
|
await api{{if eq .Method "get"}}Get{{else}}Post{{end}}({{makeDartRequestUrlPath $Route}},{{if ne .Method "get"}}request,{{end}}
|
||||||
ok: (data) {
|
ok: (data) {
|
||||||
if (ok != null) ok({{with .ResponseType}}{{.Name}}.fromJson(data){{end}});
|
if (ok != null) ok({{with .ResponseType}}{{.Name}}.fromJson(data){{end}});
|
||||||
}, fail: fail, eventually: eventually);
|
}, fail: fail, eventually: eventually);
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ Future<Tokens> getTokens() async {
|
|||||||
try {
|
try {
|
||||||
var sp = await SharedPreferences.getInstance();
|
var sp = await SharedPreferences.getInstance();
|
||||||
var str = sp.getString('tokens');
|
var str = sp.getString('tokens');
|
||||||
if (str.isEmpty) {
|
if (str == null || str.isEmpty) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return Tokens.fromJson(jsonDecode(str));
|
return Tokens.fromJson(jsonDecode(str));
|
||||||
@@ -65,7 +65,7 @@ Future<Tokens?> getTokens() async {
|
|||||||
try {
|
try {
|
||||||
var sp = await SharedPreferences.getInstance();
|
var sp = await SharedPreferences.getInstance();
|
||||||
var str = sp.getString('tokens');
|
var str = sp.getString('tokens');
|
||||||
if (str.isEmpty) {
|
if (str == null || str.isEmpty) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return Tokens.fromJson(jsonDecode(str));
|
return Tokens.fromJson(jsonDecode(str));
|
||||||
|
|||||||
@@ -11,6 +11,18 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/tools/goctl/api/util"
|
"github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
formTagKey = "form"
|
||||||
|
pathTagKey = "path"
|
||||||
|
headerTagKey = "header"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeHandlerName(handlerName string) string {
|
||||||
|
handler := strings.Replace(handlerName, "Handler", "", 1)
|
||||||
|
handler = lowCamelCase(handler)
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
func lowCamelCase(s string) string {
|
func lowCamelCase(s string) string {
|
||||||
if len(s) < 1 {
|
if len(s) < 1 {
|
||||||
return ""
|
return ""
|
||||||
@@ -20,21 +32,6 @@ func lowCamelCase(s string) string {
|
|||||||
return util.ToLower(s[:1]) + s[1:]
|
return util.ToLower(s[:1]) + s[1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func pathToFuncName(path string) string {
|
|
||||||
if !strings.HasPrefix(path, "/") {
|
|
||||||
path = "/" + path
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(path, "/api") {
|
|
||||||
path = "/api" + path
|
|
||||||
}
|
|
||||||
|
|
||||||
path = strings.Replace(path, "/", "_", -1)
|
|
||||||
path = strings.Replace(path, "-", "_", -1)
|
|
||||||
|
|
||||||
camel := util.ToCamelCase(path)
|
|
||||||
return util.ToLower(camel[:1]) + camel[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBaseName(str string) string {
|
func getBaseName(str string) string {
|
||||||
return path.Base(str)
|
return path.Base(str)
|
||||||
}
|
}
|
||||||
@@ -170,3 +167,46 @@ func primitiveType(tp string) (string, bool) {
|
|||||||
|
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasUrlPathParams(route spec.Route) bool {
|
||||||
|
ds, ok := route.RequestType.(spec.DefineStruct)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(pathTagKey)) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractPositionalParamsFromPath(route spec.Route) string {
|
||||||
|
ds, ok := route.RequestType.(spec.DefineStruct)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var params []string
|
||||||
|
for _, member := range ds.GetTagMembers(pathTagKey) {
|
||||||
|
dartType := member.Type.Name()
|
||||||
|
params = append(params, fmt.Sprintf("%s %s", dartType, getPropertyFromMember(member)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(params, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDartRequestUrlPath(route spec.Route) string {
|
||||||
|
path := route.Path
|
||||||
|
if route.RequestType == nil {
|
||||||
|
return `"` + path + `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
ds, ok := route.RequestType.(spec.DefineStruct)
|
||||||
|
if !ok {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, member := range ds.GetTagMembers(pathTagKey) {
|
||||||
|
paramName := member.Tags()[0].Name
|
||||||
|
path = strings.ReplaceAll(path, ":"+paramName, "${"+getPropertyFromMember(member)+"}")
|
||||||
|
}
|
||||||
|
|
||||||
|
return `"` + path + `"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,13 +3,16 @@ package dartgen
|
|||||||
import "text/template"
|
import "text/template"
|
||||||
|
|
||||||
var funcMap = template.FuncMap{
|
var funcMap = template.FuncMap{
|
||||||
"getBaseName": getBaseName,
|
"getBaseName": getBaseName,
|
||||||
"getPropertyFromMember": getPropertyFromMember,
|
"getPropertyFromMember": getPropertyFromMember,
|
||||||
"isDirectType": isDirectType,
|
"isDirectType": isDirectType,
|
||||||
"isClassListType": isClassListType,
|
"isClassListType": isClassListType,
|
||||||
"getCoreType": getCoreType,
|
"getCoreType": getCoreType,
|
||||||
"pathToFuncName": pathToFuncName,
|
"lowCamelCase": lowCamelCase,
|
||||||
"lowCamelCase": lowCamelCase,
|
"normalizeHandlerName": normalizeHandlerName,
|
||||||
|
"hasUrlPathParams": hasUrlPathParams,
|
||||||
|
"extractPositionalParamsFromPath": extractPositionalParamsFromPath,
|
||||||
|
"makeDartRequestUrlPath": makeDartRequestUrlPath,
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
@@ -36,7 +37,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
|
|||||||
`
|
`
|
||||||
routesAdditionTemplate = `
|
routesAdditionTemplate = `
|
||||||
server.AddRoutes(
|
server.AddRoutes(
|
||||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
|
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}} {{.maxBytes}}
|
||||||
)
|
)
|
||||||
`
|
`
|
||||||
timeoutThreshold = time.Millisecond
|
timeoutThreshold = time.Millisecond
|
||||||
@@ -64,6 +65,7 @@ type (
|
|||||||
middlewares []string
|
middlewares []string
|
||||||
prefix string
|
prefix string
|
||||||
jwtTrans string
|
jwtTrans string
|
||||||
|
maxBytes string
|
||||||
}
|
}
|
||||||
route struct {
|
route struct {
|
||||||
method string
|
method string
|
||||||
@@ -127,10 +129,20 @@ rest.WithPrefix("%s"),`, g.prefix)
|
|||||||
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
|
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
|
timeout = fmt.Sprintf("\n rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
|
||||||
hasTimeout = true
|
hasTimeout = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var maxBytes string
|
||||||
|
if len(g.maxBytes) > 0 {
|
||||||
|
_, err := strconv.ParseInt(g.maxBytes, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("maxBytes %s parse error,it is an invalid number", g.maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxBytes = fmt.Sprintf("\n rest.WithMaxBytes(%s),", g.maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
var routes string
|
var routes string
|
||||||
if len(g.middlewares) > 0 {
|
if len(g.middlewares) > 0 {
|
||||||
gbuilder.WriteString("\n}...,")
|
gbuilder.WriteString("\n}...,")
|
||||||
@@ -152,6 +164,7 @@ rest.WithPrefix("%s"),`, g.prefix)
|
|||||||
"signature": signature,
|
"signature": signature,
|
||||||
"prefix": prefix,
|
"prefix": prefix,
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
|
"maxBytes": maxBytes,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -230,6 +243,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
groupedRoutes.timeout = g.GetAnnotation("timeout")
|
groupedRoutes.timeout = g.GetAnnotation("timeout")
|
||||||
|
groupedRoutes.maxBytes = g.GetAnnotation("maxBytes")
|
||||||
|
|
||||||
jwt := g.GetAnnotation("jwt")
|
jwt := g.GetAnnotation("jwt")
|
||||||
if len(jwt) > 0 {
|
if len(jwt) > 0 {
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ func (v *ApiVisitor) VisitAtHandler(ctx *api.AtHandlerContext) interface{} {
|
|||||||
return &atHandler
|
return &atHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// serVisitRoute implements from api.BaseApiParserVisitor
|
// VisitRoute implements from api.BaseApiParserVisitor
|
||||||
func (v *ApiVisitor) VisitRoute(ctx *api.RouteContext) interface{} {
|
func (v *ApiVisitor) VisitRoute(ctx *api.RouteContext) interface{} {
|
||||||
var route Route
|
var route Route
|
||||||
path := ctx.Path()
|
path := ctx.Path()
|
||||||
|
|||||||
@@ -39,6 +39,10 @@ func TsCommand(_ *cobra.Command, _ []string) error {
|
|||||||
return errors.New("missing -dir")
|
return errors.New("missing -dir")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(webAPI) == 0 {
|
||||||
|
webAPI = "."
|
||||||
|
}
|
||||||
|
|
||||||
api, err := parser.Parse(apiFile)
|
api, err := parser.Parse(apiFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(aurora.Red("Failed"))
|
fmt.Println(aurora.Red("Failed"))
|
||||||
@@ -51,6 +55,7 @@ func TsCommand(_ *cobra.Command, _ []string) error {
|
|||||||
|
|
||||||
api.Service = api.Service.JoinPrefix()
|
api.Service = api.Service.JoinPrefix()
|
||||||
logx.Must(pathx.MkdirIfNotExist(dir))
|
logx.Must(pathx.MkdirIfNotExist(dir))
|
||||||
|
logx.Must(genRequest(dir))
|
||||||
logx.Must(genHandler(dir, webAPI, caller, api, unwrapAPI))
|
logx.Must(genHandler(dir, webAPI, caller, api, unwrapAPI))
|
||||||
logx.Must(genComponents(dir, api))
|
logx.Must(genComponents(dir, api))
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func genHandler(dir, webAPI, caller string, api *spec.ApiSpec, unwrapAPI bool) e
|
|||||||
importCaller = "{ " + importCaller + " }"
|
importCaller = "{ " + importCaller + " }"
|
||||||
}
|
}
|
||||||
if len(webAPI) > 0 {
|
if len(webAPI) > 0 {
|
||||||
imports += `import ` + importCaller + ` from ` + "\"" + webAPI + "\""
|
imports += `import ` + importCaller + ` from ` + `"./gocliRequest"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(api.Types) != 0 {
|
if len(api.Types) != 0 {
|
||||||
|
|||||||
26
tools/goctl/api/tsgen/genrequest.go
Normal file
26
tools/goctl/api/tsgen/genrequest.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package tsgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed request.ts
|
||||||
|
var requestTemplate string
|
||||||
|
|
||||||
|
func genRequest(dir string) error {
|
||||||
|
abs, err := filepath.Abs(dir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := filepath.Join(abs, "gocliRequest.ts")
|
||||||
|
if pathx.FileExists(filename) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(filename, []byte(requestTemplate), 0644)
|
||||||
|
}
|
||||||
126
tools/goctl/api/tsgen/request.ts
Normal file
126
tools/goctl/api/tsgen/request.ts
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
export type Method =
|
||||||
|
| 'get'
|
||||||
|
| 'GET'
|
||||||
|
| 'delete'
|
||||||
|
| 'DELETE'
|
||||||
|
| 'head'
|
||||||
|
| 'HEAD'
|
||||||
|
| 'options'
|
||||||
|
| 'OPTIONS'
|
||||||
|
| 'post'
|
||||||
|
| 'POST'
|
||||||
|
| 'put'
|
||||||
|
| 'PUT'
|
||||||
|
| 'patch'
|
||||||
|
| 'PATCH';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse route parameters for responseType
|
||||||
|
*/
|
||||||
|
const reg = /:[a-z|A-Z]+/g;
|
||||||
|
|
||||||
|
export function parseParams(url: string): Array<string> {
|
||||||
|
const ps = url.match(reg);
|
||||||
|
if (!ps) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
return ps.map((k) => k.replace(/:/, ''));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate url and parameters
|
||||||
|
* @param url
|
||||||
|
* @param params
|
||||||
|
*/
|
||||||
|
export function genUrl(url: string, params: unknown) {
|
||||||
|
if (!params) {
|
||||||
|
return url;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ps = parseParams(url);
|
||||||
|
ps.forEach((k) => {
|
||||||
|
const reg = new RegExp(`:${k}`);
|
||||||
|
url = url.replace(reg, params[k]);
|
||||||
|
});
|
||||||
|
|
||||||
|
const path: Array<string> = [];
|
||||||
|
for (const key of Object.keys(params)) {
|
||||||
|
if (!ps.find((k) => k === key)) {
|
||||||
|
path.push(`${key}=${params[key]}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return url + (path.length > 0 ? `?${path.join('&')}` : '');
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function request({
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
data,
|
||||||
|
config = {}
|
||||||
|
}: {
|
||||||
|
method: Method;
|
||||||
|
url: string;
|
||||||
|
data?: unknown;
|
||||||
|
config?: unknown;
|
||||||
|
}) {
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: method.toLocaleUpperCase(),
|
||||||
|
credentials: 'include',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: data ? JSON.stringify(data) : undefined,
|
||||||
|
// @ts-ignore
|
||||||
|
...config
|
||||||
|
});
|
||||||
|
|
||||||
|
return response.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
function api<T>(
|
||||||
|
method: Method = 'get',
|
||||||
|
url: string,
|
||||||
|
req: any,
|
||||||
|
config?: unknown
|
||||||
|
): Promise<T> {
|
||||||
|
if (url.match(/:/) || method.match(/get|delete/i)) {
|
||||||
|
url = genUrl(url, req.params || req.forms);
|
||||||
|
}
|
||||||
|
method = method.toLocaleLowerCase() as Method;
|
||||||
|
|
||||||
|
switch (method) {
|
||||||
|
case 'get':
|
||||||
|
return request({method: 'get', url, data: req, config});
|
||||||
|
case 'delete':
|
||||||
|
return request({method: 'delete', url, data: req, config});
|
||||||
|
case 'put':
|
||||||
|
return request({method: 'put', url, data: req, config});
|
||||||
|
case 'post':
|
||||||
|
return request({method: 'post', url, data: req, config});
|
||||||
|
case 'patch':
|
||||||
|
return request({method: 'patch', url, data: req, config});
|
||||||
|
default:
|
||||||
|
return request({method: 'post', url, data: req, config});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const webapi = {
|
||||||
|
get<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||||
|
return api<T>('get', url, req, config);
|
||||||
|
},
|
||||||
|
delete<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||||
|
return api<T>('delete', url, req, config);
|
||||||
|
},
|
||||||
|
put<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||||
|
return api<T>('get', url, req, config);
|
||||||
|
},
|
||||||
|
post<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||||
|
return api<T>('post', url, req, config);
|
||||||
|
},
|
||||||
|
patch<T>(url: string, req: unknown, config?: unknown): Promise<T> {
|
||||||
|
return api<T>('patch', url, req, config);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export default webapi
|
||||||
@@ -70,7 +70,7 @@ spec:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
apiVersion: autoscaling/v2beta1
|
apiVersion: autoscaling/v2beta2
|
||||||
kind: HorizontalPodAutoscaler
|
kind: HorizontalPodAutoscaler
|
||||||
metadata:
|
metadata:
|
||||||
name: {{.Name}}-hpa-c
|
name: {{.Name}}-hpa-c
|
||||||
@@ -88,11 +88,13 @@ spec:
|
|||||||
- type: Resource
|
- type: Resource
|
||||||
resource:
|
resource:
|
||||||
name: cpu
|
name: cpu
|
||||||
targetAverageUtilization: 80
|
target:
|
||||||
|
type: Utilization
|
||||||
|
averageUtilization: 80
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
apiVersion: autoscaling/v2beta1
|
apiVersion: autoscaling/v2beta2
|
||||||
kind: HorizontalPodAutoscaler
|
kind: HorizontalPodAutoscaler
|
||||||
metadata:
|
metadata:
|
||||||
name: {{.Name}}-hpa-m
|
name: {{.Name}}-hpa-m
|
||||||
@@ -110,4 +112,6 @@ spec:
|
|||||||
- type: Resource
|
- type: Resource
|
||||||
resource:
|
resource:
|
||||||
name: memory
|
name: memory
|
||||||
targetAverageUtilization: 80
|
target:
|
||||||
|
type: Utilization
|
||||||
|
averageUtilization: 80
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
for _, ch := range query {
|
for _, ch := range query {
|
||||||
if ch == '?' {
|
if ch == '?' {
|
||||||
if argIndex >= numArgs {
|
if argIndex >= numArgs {
|
||||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
arg := args[argIndex]
|
arg := args[argIndex]
|
||||||
@@ -79,7 +79,7 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if argIndex < numArgs {
|
if argIndex < numArgs {
|
||||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
return "", fmt.Errorf("%d ? in sql, but more arguments provided", argIndex)
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.String(), nil
|
return b.String(), nil
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package rpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/cli"
|
"github.com/zeromicro/go-zero/tools/goctl/rpc/cli"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,7 +54,7 @@ func init() {
|
|||||||
|
|
||||||
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoOpt, "go_opt", nil, "")
|
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoOpt, "go_opt", nil, "")
|
||||||
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoGRPCOpt, "go-grpc_opt", nil, "")
|
newCmd.Flags().StringSliceVar(&cli.VarStringSliceGoGRPCOpt, "go-grpc_opt", nil, "")
|
||||||
newCmd.Flags().StringVar(&cli.VarStringStyle, "style", "gozero", "The file "+
|
newCmd.Flags().StringVar(&cli.VarStringStyle, "style", config.DefaultFormat, "The file "+
|
||||||
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
|
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
|
||||||
newCmd.Flags().BoolVar(&cli.VarBoolIdea, "idea", false, "Whether the command "+
|
newCmd.Flags().BoolVar(&cli.VarBoolIdea, "idea", false, "Whether the command "+
|
||||||
"execution environment is from idea plugin.")
|
"execution environment is from idea plugin.")
|
||||||
@@ -79,7 +80,7 @@ func init() {
|
|||||||
protocCmd.Flags().StringSliceVar(&cli.VarStringSlicePlugin, "plugin", nil, "")
|
protocCmd.Flags().StringSliceVar(&cli.VarStringSlicePlugin, "plugin", nil, "")
|
||||||
protocCmd.Flags().StringSliceVarP(&cli.VarStringSliceProtoPath, "proto_path", "I", nil, "")
|
protocCmd.Flags().StringSliceVarP(&cli.VarStringSliceProtoPath, "proto_path", "I", nil, "")
|
||||||
protocCmd.Flags().StringVar(&cli.VarStringZRPCOut, "zrpc_out", "", "The zrpc output directory")
|
protocCmd.Flags().StringVar(&cli.VarStringZRPCOut, "zrpc_out", "", "The zrpc output directory")
|
||||||
protocCmd.Flags().StringVar(&cli.VarStringStyle, "style", "gozero", "The file "+
|
protocCmd.Flags().StringVar(&cli.VarStringStyle, "style", config.DefaultFormat, "The file "+
|
||||||
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
|
"naming format, see [https://github.com/zeromicro/go-zero/tree/master/tools/goctl/config/readme.md]")
|
||||||
protocCmd.Flags().StringVar(&cli.VarStringHome, "home", "", "The goctl home "+
|
protocCmd.Flags().StringVar(&cli.VarStringHome, "home", "", "The goctl home "+
|
||||||
"path of the template, --home and --remote cannot be set at the same time, if they are, "+
|
"path of the template, --home and --remote cannot be set at the same time, if they are, "+
|
||||||
|
|||||||
@@ -63,7 +63,24 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
|||||||
isCallPkgSameToPbPkg := childDir == ctx.GetProtoGo().Filename
|
isCallPkgSameToPbPkg := childDir == ctx.GetProtoGo().Filename
|
||||||
isCallPkgSameToGrpcPkg := childDir == ctx.GetProtoGo().Filename
|
isCallPkgSameToGrpcPkg := childDir == ctx.GetProtoGo().Filename
|
||||||
|
|
||||||
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
serviceName := stringx.From(service.Name).ToCamel()
|
||||||
|
alias := collection.NewSet()
|
||||||
|
var hasSameNameBetweenMessageAndService bool
|
||||||
|
for _, item := range proto.Message {
|
||||||
|
msgName := getMessageName(*item.Message)
|
||||||
|
if serviceName == msgName {
|
||||||
|
hasSameNameBetweenMessageAndService = true
|
||||||
|
}
|
||||||
|
if !isCallPkgSameToPbPkg {
|
||||||
|
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
||||||
|
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasSameNameBetweenMessageAndService {
|
||||||
|
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
|
||||||
|
}
|
||||||
|
|
||||||
|
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -78,15 +95,6 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
alias := collection.NewSet()
|
|
||||||
if !isCallPkgSameToPbPkg {
|
|
||||||
for _, item := range proto.Message {
|
|
||||||
msgName := getMessageName(*item.Message)
|
|
||||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
|
||||||
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
||||||
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
||||||
if isCallPkgSameToGrpcPkg {
|
if isCallPkgSameToGrpcPkg {
|
||||||
@@ -103,7 +111,7 @@ func (g *Generator) genCallGroup(ctx DirContext, proto parser.Proto, cfg *conf.C
|
|||||||
"filePackage": dir.Base,
|
"filePackage": dir.Base,
|
||||||
"pbPackage": pbPackage,
|
"pbPackage": pbPackage,
|
||||||
"protoGoPackage": protoGoPackage,
|
"protoGoPackage": protoGoPackage,
|
||||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
"serviceName": serviceName,
|
||||||
"functions": strings.Join(functions, pathx.NL),
|
"functions": strings.Join(functions, pathx.NL),
|
||||||
"interface": strings.Join(iFunctions, pathx.NL),
|
"interface": strings.Join(iFunctions, pathx.NL),
|
||||||
}, filename, true); err != nil {
|
}, filename, true); err != nil {
|
||||||
@@ -126,8 +134,26 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serviceName := stringx.From(service.Name).ToCamel()
|
||||||
|
alias := collection.NewSet()
|
||||||
|
var hasSameNameBetweenMessageAndService bool
|
||||||
|
for _, item := range proto.Message {
|
||||||
|
msgName := getMessageName(*item.Message)
|
||||||
|
if serviceName == msgName {
|
||||||
|
hasSameNameBetweenMessageAndService = true
|
||||||
|
}
|
||||||
|
if !isCallPkgSameToPbPkg {
|
||||||
|
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
||||||
|
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasSameNameBetweenMessageAndService {
|
||||||
|
serviceName = stringx.From(service.Name + "_zrpc_client").ToCamel()
|
||||||
|
}
|
||||||
|
|
||||||
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
|
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
|
||||||
functions, err := g.genFunction(proto.PbPackage, service, isCallPkgSameToGrpcPkg)
|
functions, err := g.genFunction(proto.PbPackage, serviceName, service, isCallPkgSameToGrpcPkg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -142,15 +168,6 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
alias := collection.NewSet()
|
|
||||||
if !isCallPkgSameToPbPkg {
|
|
||||||
for _, item := range proto.Message {
|
|
||||||
msgName := getMessageName(*item.Message)
|
|
||||||
alias.AddStr(fmt.Sprintf("%s = %s", parser.CamelCase(msgName),
|
|
||||||
fmt.Sprintf("%s.%s", proto.PbPackage, parser.CamelCase(msgName))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
pbPackage := fmt.Sprintf(`"%s"`, ctx.GetPb().Package)
|
||||||
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
protoGoPackage := fmt.Sprintf(`"%s"`, ctx.GetProtoGo().Package)
|
||||||
if isCallPkgSameToGrpcPkg {
|
if isCallPkgSameToGrpcPkg {
|
||||||
@@ -166,7 +183,7 @@ func (g *Generator) genCallInCompatibility(ctx DirContext, proto parser.Proto,
|
|||||||
"filePackage": dir.Base,
|
"filePackage": dir.Base,
|
||||||
"pbPackage": pbPackage,
|
"pbPackage": pbPackage,
|
||||||
"protoGoPackage": protoGoPackage,
|
"protoGoPackage": protoGoPackage,
|
||||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
"serviceName": serviceName,
|
||||||
"functions": strings.Join(functions, pathx.NL),
|
"functions": strings.Join(functions, pathx.NL),
|
||||||
"interface": strings.Join(iFunctions, pathx.NL),
|
"interface": strings.Join(iFunctions, pathx.NL),
|
||||||
}, filename, true)
|
}, filename, true)
|
||||||
@@ -194,7 +211,7 @@ func getMessageName(msg proto.Message) string {
|
|||||||
return strings.Join(list, "_")
|
return strings.Join(list, "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Generator) genFunction(goPackage string, service parser.Service,
|
func (g *Generator) genFunction(goPackage string, serviceName string, service parser.Service,
|
||||||
isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
isCallPkgSameToGrpcPkg bool) ([]string, error) {
|
||||||
functions := make([]string, 0)
|
functions := make([]string, 0)
|
||||||
|
|
||||||
@@ -212,7 +229,7 @@ func (g *Generator) genFunction(goPackage string, service parser.Service,
|
|||||||
parser.CamelCase(rpc.Name), "Client")
|
parser.CamelCase(rpc.Name), "Client")
|
||||||
}
|
}
|
||||||
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
||||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
"serviceName": serviceName,
|
||||||
"rpcServiceName": parser.CamelCase(service.Name),
|
"rpcServiceName": parser.CamelCase(service.Name),
|
||||||
"method": parser.CamelCase(rpc.Name),
|
"method": parser.CamelCase(rpc.Name),
|
||||||
"package": goPackage,
|
"package": goPackage,
|
||||||
|
|||||||
@@ -21,9 +21,11 @@ func NewGenerator(style string, verbose bool) *Generator {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
log := console.NewColorConsole(verbose)
|
|
||||||
|
colorLogger := console.NewColorConsole(verbose)
|
||||||
|
|
||||||
return &Generator{
|
return &Generator{
|
||||||
log: log,
|
log: colorLogger,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
verbose: verbose,
|
verbose: verbose,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,11 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/keepalive"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultClientKeepaliveTime = 20 * time.Second
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// WithDialOption is an alias of internal.WithDialOption.
|
// WithDialOption is an alias of internal.WithDialOption.
|
||||||
WithDialOption = internal.WithDialOption
|
WithDialOption = internal.WithDialOption
|
||||||
@@ -62,6 +65,11 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
|
|||||||
if c.Timeout > 0 {
|
if c.Timeout > 0 {
|
||||||
opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
|
opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond))
|
||||||
}
|
}
|
||||||
|
if c.KeepaliveTime > 0 {
|
||||||
|
opts = append(opts, WithDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
|
Time: c.KeepaliveTime,
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
opts = append(opts, options...)
|
opts = append(opts, options...)
|
||||||
|
|
||||||
@@ -90,6 +98,12 @@ func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) {
|
|||||||
Timeout: true,
|
Timeout: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
opts = append([]ClientOption{
|
||||||
|
WithDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
|
Time: defaultClientKeepaliveTime,
|
||||||
|
})),
|
||||||
|
}, opts...)
|
||||||
|
|
||||||
return internal.NewClient(target, middlewares, opts...)
|
return internal.NewClient(target, middlewares, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -113,10 +113,11 @@ func TestDepositServer_Deposit(t *testing.T) {
|
|||||||
)
|
)
|
||||||
tarConfClient := MustNewClient(
|
tarConfClient := MustNewClient(
|
||||||
RpcClientConf{
|
RpcClientConf{
|
||||||
Target: "foo",
|
Target: "foo",
|
||||||
App: "foo",
|
App: "foo",
|
||||||
Token: "bar",
|
Token: "bar",
|
||||||
Timeout: 1000,
|
Timeout: 1000,
|
||||||
|
KeepaliveTime: time.Second * 15,
|
||||||
Middlewares: ClientMiddlewaresConf{
|
Middlewares: ClientMiddlewaresConf{
|
||||||
Trace: true,
|
Trace: true,
|
||||||
Duration: true,
|
Duration: true,
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package zrpc
|
package zrpc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/discov"
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
"github.com/zeromicro/go-zero/core/service"
|
"github.com/zeromicro/go-zero/core/service"
|
||||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
@@ -14,6 +16,19 @@ type (
|
|||||||
// ServerMiddlewaresConf defines whether to use server middlewares.
|
// ServerMiddlewaresConf defines whether to use server middlewares.
|
||||||
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
|
ServerMiddlewaresConf = internal.ServerMiddlewaresConf
|
||||||
|
|
||||||
|
// A RpcClientConf is a rpc client config.
|
||||||
|
RpcClientConf struct {
|
||||||
|
Etcd discov.EtcdConf `json:",optional,inherit"`
|
||||||
|
Endpoints []string `json:",optional"`
|
||||||
|
Target string `json:",optional"`
|
||||||
|
App string `json:",optional"`
|
||||||
|
Token string `json:",optional"`
|
||||||
|
NonBlock bool `json:",optional"`
|
||||||
|
Timeout int64 `json:",default=2000"`
|
||||||
|
KeepaliveTime time.Duration `json:",default=20s"`
|
||||||
|
Middlewares ClientMiddlewaresConf
|
||||||
|
}
|
||||||
|
|
||||||
// A RpcServerConf is a rpc server config.
|
// A RpcServerConf is a rpc server config.
|
||||||
RpcServerConf struct {
|
RpcServerConf struct {
|
||||||
service.ServiceConf
|
service.ServiceConf
|
||||||
@@ -29,18 +44,6 @@ type (
|
|||||||
Health bool `json:",default=true"`
|
Health bool `json:",default=true"`
|
||||||
Middlewares ServerMiddlewaresConf
|
Middlewares ServerMiddlewaresConf
|
||||||
}
|
}
|
||||||
|
|
||||||
// A RpcClientConf is a rpc client config.
|
|
||||||
RpcClientConf struct {
|
|
||||||
Etcd discov.EtcdConf `json:",optional,inherit"`
|
|
||||||
Endpoints []string `json:",optional"`
|
|
||||||
Target string `json:",optional"`
|
|
||||||
App string `json:",optional"`
|
|
||||||
Token string `json:",optional"`
|
|
||||||
NonBlock bool `json:",optional"`
|
|
||||||
Timeout int64 `json:",default=2000"`
|
|
||||||
Middlewares ClientMiddlewaresConf
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewDirectClientConf returns a RpcClientConf.
|
// NewDirectClientConf returns a RpcClientConf.
|
||||||
|
|||||||
@@ -10,10 +10,35 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRpcClientConf(t *testing.T) {
|
func TestRpcClientConf(t *testing.T) {
|
||||||
conf := NewDirectClientConf([]string{"localhost:1234"}, "foo", "bar")
|
t.Run("direct", func(t *testing.T) {
|
||||||
assert.True(t, conf.HasCredential())
|
conf := NewDirectClientConf([]string{"localhost:1234"}, "foo", "bar")
|
||||||
conf = NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"}, "key", "foo", "bar")
|
assert.True(t, conf.HasCredential())
|
||||||
assert.True(t, conf.HasCredential())
|
})
|
||||||
|
|
||||||
|
t.Run("etcd", func(t *testing.T) {
|
||||||
|
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
|
||||||
|
"key", "foo", "bar")
|
||||||
|
assert.True(t, conf.HasCredential())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("etcd with account", func(t *testing.T) {
|
||||||
|
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
|
||||||
|
"key", "foo", "bar")
|
||||||
|
conf.Etcd.User = "user"
|
||||||
|
conf.Etcd.Pass = "pass"
|
||||||
|
_, err := conf.BuildTarget()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("etcd with tls", func(t *testing.T) {
|
||||||
|
conf := NewEtcdClientConf([]string{"localhost:1234", "localhost:5678"},
|
||||||
|
"key", "foo", "bar")
|
||||||
|
conf.Etcd.CertFile = "cert"
|
||||||
|
conf.Etcd.CertKeyFile = "key"
|
||||||
|
conf.Etcd.CACertFile = "ca"
|
||||||
|
_, err := conf.BuildTarget()
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRpcServerConf(t *testing.T) {
|
func TestRpcServerConf(t *testing.T) {
|
||||||
|
|||||||
@@ -4,9 +4,21 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
"github.com/zeromicro/go-zero/core/netx"
|
"github.com/zeromicro/go-zero/core/netx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestNewRpcPubServer(t *testing.T) {
|
||||||
|
s, err := NewRpcPubServer(discov.EtcdConf{
|
||||||
|
User: "user",
|
||||||
|
Pass: "pass",
|
||||||
|
}, "", ServerMiddlewaresConf{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
s.Start(nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestFigureOutListenOn(t *testing.T) {
|
func TestFigureOutListenOn(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input string
|
input string
|
||||||
@@ -24,6 +36,10 @@ func TestFigureOutListenOn(t *testing.T) {
|
|||||||
input: ":8080",
|
input: ":8080",
|
||||||
expect: netx.InternalIp() + ":8080",
|
expect: netx.InternalIp() + ":8080",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
input: "",
|
||||||
|
expect: netx.InternalIp(),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -59,12 +59,10 @@ func (s *rpcServer) Start(register RegisterFn) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
unaryInterceptors := s.buildUnaryInterceptors()
|
unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
|
||||||
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
|
streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)
|
||||||
streamInterceptors := s.buildStreamInterceptors()
|
|
||||||
streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
|
options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
|
||||||
options := append(s.options, grpc.ChainUnaryInterceptor(unaryInterceptors...),
|
|
||||||
grpc.ChainStreamInterceptor(streamInterceptors...))
|
|
||||||
server := grpc.NewServer(options...)
|
server := grpc.NewServer(options...)
|
||||||
register(server)
|
register(server)
|
||||||
|
|
||||||
@@ -102,7 +100,7 @@ func (s *rpcServer) buildStreamInterceptors() []grpc.StreamServerInterceptor {
|
|||||||
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
|
interceptors = append(interceptors, serverinterceptors.StreamBreakerInterceptor)
|
||||||
}
|
}
|
||||||
|
|
||||||
return interceptors
|
return append(interceptors, s.streamInterceptors...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
|
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
|
||||||
@@ -124,7 +122,7 @@ func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
|
|||||||
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
|
interceptors = append(interceptors, serverinterceptors.UnaryBreakerInterceptor)
|
||||||
}
|
}
|
||||||
|
|
||||||
return interceptors
|
return append(interceptors, s.unaryInterceptors...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMetrics returns a func that sets metrics to a Server.
|
// WithMetrics returns a func that sets metrics to a Server.
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/proc"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/mock"
|
"github.com/zeromicro/go-zero/zrpc/internal/mock"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -18,12 +20,13 @@ func TestRpcServer(t *testing.T) {
|
|||||||
Stat: true,
|
Stat: true,
|
||||||
Prometheus: true,
|
Prometheus: true,
|
||||||
Breaker: true,
|
Breaker: true,
|
||||||
}, WithMetrics(metrics))
|
}, WithMetrics(metrics), WithRpcHealth(true))
|
||||||
server.SetName("mock")
|
server.SetName("mock")
|
||||||
var wg sync.WaitGroup
|
var wg, wgDone sync.WaitGroup
|
||||||
var grpcServer *grpc.Server
|
var grpcServer *grpc.Server
|
||||||
var lock sync.Mutex
|
var lock sync.Mutex
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
wgDone.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
err := server.Start(func(server *grpc.Server) {
|
err := server.Start(func(server *grpc.Server) {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
@@ -33,12 +36,16 @@ func TestRpcServer(t *testing.T) {
|
|||||||
wg.Done()
|
wg.Done()
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
wgDone.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
grpcServer.GracefulStop()
|
grpcServer.GracefulStop()
|
||||||
lock.Unlock()
|
lock.Unlock()
|
||||||
|
|
||||||
|
proc.WrapUp()
|
||||||
|
wgDone.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRpcServer_WithBadAddress(t *testing.T) {
|
func TestRpcServer_WithBadAddress(t *testing.T) {
|
||||||
@@ -48,10 +55,124 @@ func TestRpcServer_WithBadAddress(t *testing.T) {
|
|||||||
Stat: true,
|
Stat: true,
|
||||||
Prometheus: true,
|
Prometheus: true,
|
||||||
Breaker: true,
|
Breaker: true,
|
||||||
})
|
}, WithRpcHealth(true))
|
||||||
server.SetName("mock")
|
server.SetName("mock")
|
||||||
err := server.Start(func(server *grpc.Server) {
|
err := server.Start(func(server *grpc.Server) {
|
||||||
mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
|
mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
|
||||||
})
|
})
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
|
|
||||||
|
proc.WrapUp()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRpcServer_buildUnaryInterceptor(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
r *rpcServer
|
||||||
|
len int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
r: &rpcServer{
|
||||||
|
baseRpcServer: &baseRpcServer{},
|
||||||
|
},
|
||||||
|
len: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom",
|
||||||
|
r: &rpcServer{
|
||||||
|
baseRpcServer: &baseRpcServer{
|
||||||
|
unaryInterceptors: []grpc.UnaryServerInterceptor{
|
||||||
|
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||||
|
handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
len: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "middleware",
|
||||||
|
r: &rpcServer{
|
||||||
|
baseRpcServer: &baseRpcServer{
|
||||||
|
unaryInterceptors: []grpc.UnaryServerInterceptor{
|
||||||
|
func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||||
|
handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
middlewares: ServerMiddlewaresConf{
|
||||||
|
Trace: true,
|
||||||
|
Recover: true,
|
||||||
|
Stat: true,
|
||||||
|
Prometheus: true,
|
||||||
|
Breaker: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
len: 6,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, test.len, len(test.r.buildUnaryInterceptors()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRpcServer_buildStreamInterceptor(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
r *rpcServer
|
||||||
|
len int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
r: &rpcServer{
|
||||||
|
baseRpcServer: &baseRpcServer{},
|
||||||
|
},
|
||||||
|
len: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom",
|
||||||
|
r: &rpcServer{
|
||||||
|
baseRpcServer: &baseRpcServer{
|
||||||
|
streamInterceptors: []grpc.StreamServerInterceptor{
|
||||||
|
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||||
|
handler grpc.StreamHandler) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
len: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "middleware",
|
||||||
|
r: &rpcServer{
|
||||||
|
baseRpcServer: &baseRpcServer{
|
||||||
|
streamInterceptors: []grpc.StreamServerInterceptor{
|
||||||
|
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||||
|
handler grpc.StreamHandler) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
middlewares: ServerMiddlewaresConf{
|
||||||
|
Trace: true,
|
||||||
|
Recover: true,
|
||||||
|
Breaker: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
len: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, test.len, len(test.r.buildStreamInterceptors()))
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,5 +48,6 @@ func ParseTarget(target resolver.Target) (Service, error) {
|
|||||||
} else {
|
} else {
|
||||||
service.Name = endpoints
|
service.Name = endpoints
|
||||||
}
|
}
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,15 @@ func TestKubeBuilder_Build(t *testing.T) {
|
|||||||
var b kubeBuilder
|
var b kubeBuilder
|
||||||
u, err := url.Parse(fmt.Sprintf("%s://%s", KubernetesScheme, "a,b"))
|
u, err := url.Parse(fmt.Sprintf("%s://%s", KubernetesScheme, "a,b"))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = b.Build(resolver.Target{
|
||||||
|
URL: *u,
|
||||||
|
}, nil, resolver.BuildOptions{})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
u, err = url.Parse(fmt.Sprintf("%s://%s:9100/a:b:c", KubernetesScheme, "a,b,c,d"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
_, err = b.Build(resolver.Target{
|
_, err = b.Build(resolver.Target{
|
||||||
URL: *u,
|
URL: *u,
|
||||||
}, nil, resolver.BuildOptions{})
|
}, nil, resolver.BuildOptions{})
|
||||||
|
|||||||
@@ -3,15 +3,19 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
"google.golang.org/grpc/serviceconfig"
|
"google.golang.org/grpc/serviceconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNopResolver(t *testing.T) {
|
func TestNopResolver(t *testing.T) {
|
||||||
// make sure ResolveNow & Close don't panic
|
assert.NotPanics(t, func() {
|
||||||
var r nopResolver
|
RegisterResolver()
|
||||||
r.ResolveNow(resolver.ResolveNowOptions{})
|
// make sure ResolveNow & Close don't panic
|
||||||
r.Close()
|
var r nopResolver
|
||||||
|
r.ResolveNow(resolver.ResolveNowOptions{})
|
||||||
|
r.Close()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockedClientConn struct {
|
type mockedClientConn struct {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/load"
|
"github.com/zeromicro/go-zero/core/load"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal"
|
"github.com/zeromicro/go-zero/zrpc/internal"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
"github.com/zeromicro/go-zero/zrpc/internal/auth"
|
||||||
"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
|
"github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors"
|
||||||
@@ -120,7 +121,12 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.Auth {
|
if c.Auth {
|
||||||
authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl)
|
rds, err := redis.NewRedis(c.Redis.RedisConf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/discov"
|
"github.com/zeromicro/go-zero/core/discov"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
@@ -16,12 +17,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_setupInterceptors(t *testing.T) {
|
func TestServer_setupInterceptors(t *testing.T) {
|
||||||
|
rds, err := miniredis.Run()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer rds.Close()
|
||||||
|
|
||||||
server := new(mockedServer)
|
server := new(mockedServer)
|
||||||
err := setupInterceptors(server, RpcServerConf{
|
conf := RpcServerConf{
|
||||||
Auth: true,
|
Auth: true,
|
||||||
Redis: redis.RedisKeyConf{
|
Redis: redis.RedisKeyConf{
|
||||||
RedisConf: redis.RedisConf{
|
RedisConf: redis.RedisConf{
|
||||||
Host: "any",
|
Host: rds.Addr(),
|
||||||
Type: redis.NodeType,
|
Type: redis.NodeType,
|
||||||
},
|
},
|
||||||
Key: "foo",
|
Key: "foo",
|
||||||
@@ -35,10 +40,15 @@ func TestServer_setupInterceptors(t *testing.T) {
|
|||||||
Prometheus: true,
|
Prometheus: true,
|
||||||
Breaker: true,
|
Breaker: true,
|
||||||
},
|
},
|
||||||
}, new(stat.Metrics))
|
}
|
||||||
|
err = setupInterceptors(server, conf, new(stat.Metrics))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, 3, len(server.unaryInterceptors))
|
assert.Equal(t, 3, len(server.unaryInterceptors))
|
||||||
assert.Equal(t, 1, len(server.streamInterceptors))
|
assert.Equal(t, 1, len(server.streamInterceptors))
|
||||||
|
|
||||||
|
rds.SetError("mock error")
|
||||||
|
err = setupInterceptors(server, conf, new(stat.Metrics))
|
||||||
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer(t *testing.T) {
|
func TestServer(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user