mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-12 01:10:00 +08:00
Compare commits
57 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c9ff6a10d3 | ||
|
|
a71e56de52 | ||
|
|
bae8d4f4c8 | ||
|
|
8c6266f338 | ||
|
|
95d5b81f44 | ||
|
|
bca7bbc142 | ||
|
|
df9a52664b | ||
|
|
937cf0db96 | ||
|
|
75cebb65f8 | ||
|
|
410f56e73a | ||
|
|
017909a3ab | ||
|
|
0d31e6c375 | ||
|
|
0ba86b1849 | ||
|
|
4cacc4d9d3 | ||
|
|
a99c14da4a | ||
|
|
985582264a | ||
|
|
8364e341e1 | ||
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 |
@@ -8,16 +8,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
numHistoryReasons = 5
|
||||
timeFormat = "15:04:05"
|
||||
)
|
||||
const numHistoryReasons = 5
|
||||
|
||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
||||
|
||||
func (ew *errorWindow) add(reason string) {
|
||||
ew.lock.Lock()
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||
ew.index = (ew.index + 1) % numHistoryReasons
|
||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
||||
ew.count = min(ew.count+1, numHistoryReasons)
|
||||
ew.lock.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
watchers: map[watchKey]*watchValue{
|
||||
watchKey{
|
||||
{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}: {
|
||||
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
watchers: map[watchKey]*watchValue{
|
||||
watchKey{
|
||||
{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}: {
|
||||
|
||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
||||
|
||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||
index := 41
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
|
||||
ratio := float32(transferred) / float32(requestSize)
|
||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
||||
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||
index = 13
|
||||
ratio := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||
}
|
||||
|
||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||
prefix := "localhost:"
|
||||
index := 41
|
||||
index := 13
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||
for k, v := range keys {
|
||||
newV := newKeys[k]
|
||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
||||
return keys, newKeys
|
||||
}
|
||||
|
||||
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
return float32(transferred) / float32(requestSize)
|
||||
}
|
||||
|
||||
type mockNode struct {
|
||||
addr string
|
||||
id int
|
||||
|
||||
@@ -2,7 +2,7 @@ package hash
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/spaolacci/murmur3"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
||||
}
|
||||
|
||||
// Md5Hex returns the md5 hex string of data.
|
||||
// This function is optimized for better performance than fmt.Sprintf.
|
||||
func Md5Hex(data []byte) string {
|
||||
return fmt.Sprintf("%x", Md5(data))
|
||||
return hex.EncodeToString(Md5(data))
|
||||
}
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
fieldsContextKey contextKey
|
||||
globalFields atomic.Value
|
||||
globalFieldsLock sync.Mutex
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
type fieldsKey struct{}
|
||||
|
||||
// AddGlobalFields adds global fields.
|
||||
func AddGlobalFields(fields ...LogField) {
|
||||
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
|
||||
|
||||
// ContextWithFields returns a new context with the given fields.
|
||||
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
||||
if val := ctx.Value(fieldsContextKey); val != nil {
|
||||
if val := ctx.Value(fieldsKey{}); val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
allFields := make([]LogField, 0, len(arr)+len(fields))
|
||||
allFields = append(allFields, arr...)
|
||||
allFields = append(allFields, fields...)
|
||||
return context.WithValue(ctx, fieldsContextKey, allFields)
|
||||
return context.WithValue(ctx, fieldsKey{}, allFields)
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, fieldsContextKey, fields)
|
||||
return context.WithValue(ctx, fieldsKey{}, fields)
|
||||
}
|
||||
|
||||
// WithFields returns a new logger with the given fields.
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
|
||||
|
||||
func TestContextWithFields(t *testing.T) {
|
||||
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
|
||||
|
||||
func TestWithFields(t *testing.T) {
|
||||
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
||||
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
||||
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
|
||||
ctxa := ContextWithFields(ctx, af)
|
||||
ctxb := ContextWithFields(ctx, bf)
|
||||
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
|
||||
}
|
||||
|
||||
func BenchmarkAtomicValue(b *testing.B) {
|
||||
|
||||
@@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
fields = append(fields, Field(spanKey, spanID))
|
||||
}
|
||||
|
||||
val := l.ctx.Value(fieldsContextKey)
|
||||
val := l.ctx.Value(fieldsKey{})
|
||||
if val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
fields = append(fields, arr...)
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dateFormat = "2006-01-02"
|
||||
hoursPerDay = 24
|
||||
bufferSize = 100
|
||||
defaultDirMode = 0o755
|
||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||
buf.WriteString(r.filename)
|
||||
buf.WriteString(r.delimiter)
|
||||
buf.WriteString(boundary)
|
||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
||||
}
|
||||
|
||||
func getNowDate() string {
|
||||
return time.Now().Format(dateFormat)
|
||||
return time.Now().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
func getNowDateInRFC3339Format() string {
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f1.Close()
|
||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
|
||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||
assert.True(t, rule.ShallRotate(0))
|
||||
}
|
||||
|
||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("no backups", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -30,7 +30,9 @@ var (
|
||||
errValueNotSettable = errors.New("value is not settable")
|
||||
errValueNotStruct = errors.New("value type is not struct")
|
||||
keyUnmarshaler = NewUnmarshaler(defaultKeyName)
|
||||
boolType = reflect.TypeOf(false)
|
||||
durationType = reflect.TypeOf(time.Duration(0))
|
||||
stringType = reflect.TypeOf("")
|
||||
cacheKeys = make(map[string][]string)
|
||||
cacheKeysLock sync.Mutex
|
||||
defaultCache = make(map[string]any)
|
||||
@@ -622,9 +624,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
||||
|
||||
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
||||
v, err := convertToString(mapValue, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fillDurationValue(fieldType, value, v)
|
||||
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
||||
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
|
||||
v, err := convertToString(mapValue, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return u.fillUnmarshalerStruct(fieldType, value, v)
|
||||
default:
|
||||
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||
}
|
||||
@@ -755,24 +767,24 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
||||
return err
|
||||
}
|
||||
|
||||
fieldKind := fieldType.Kind()
|
||||
switch fieldKind {
|
||||
case reflect.Bool:
|
||||
derefType := Deref(fieldType)
|
||||
switch derefType {
|
||||
case boolType:
|
||||
val, err := strconv.ParseBool(envVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
value.SetBool(val)
|
||||
SetValue(fieldType, value, reflect.ValueOf(val))
|
||||
return nil
|
||||
case durationType.Kind():
|
||||
case durationType:
|
||||
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case reflect.String:
|
||||
value.SetString(envVal)
|
||||
case stringType:
|
||||
SetValue(fieldType, value, reflect.ValueOf(envVal))
|
||||
return nil
|
||||
default:
|
||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
||||
@@ -894,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
||||
valueKind.String())
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
if !slices.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
|
||||
@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
|
||||
type inner struct {
|
||||
Duration time.Duration `key:"duration"`
|
||||
}
|
||||
content := "{\"duration\": 1}"
|
||||
var m = map[string]any{}
|
||||
err := jsonx.Unmarshal([]byte(content), &m)
|
||||
assert.NoError(t, err)
|
||||
var in inner
|
||||
err = UnmarshalKey(m, &in)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expect string")
|
||||
}
|
||||
|
||||
func TestUnmarshalDurationDefault(t *testing.T) {
|
||||
type inner struct {
|
||||
Int int `key:"int"`
|
||||
@@ -4665,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvInt64(t *testing.T) {
|
||||
type Value struct {
|
||||
Age int64 `key:"age,env=TEST_NAME_INT64"`
|
||||
}
|
||||
|
||||
const (
|
||||
envName = "TEST_NAME_INT64"
|
||||
envVal = "88"
|
||||
)
|
||||
t.Setenv(envName, envVal)
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, int64(88), v.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
||||
type Value struct {
|
||||
Age int `key:"age,env=TEST_NAME_INT"`
|
||||
@@ -4770,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvDuration(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
const (
|
||||
envName = "TEST_NAME_DURATION"
|
||||
envVal = "1s"
|
||||
)
|
||||
t.Setenv(envName, envVal)
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, v.Duration)
|
||||
}
|
||||
t.Run("valid duration", func(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, v.Duration)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ptr of duration", func(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, *v.Duration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
||||
@@ -5995,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
|
||||
}, &v))
|
||||
assert.Nil(t, v.Foo)
|
||||
})
|
||||
|
||||
t.Run("json.Number", func(t *testing.T) {
|
||||
v := struct {
|
||||
Foo *mockUnmarshaler `json:"name"`
|
||||
}{}
|
||||
m := map[string]any{
|
||||
"name": json.Number("123"),
|
||||
}
|
||||
assert.Error(t, UnmarshalJsonMap(m, &v))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseJsonStringValue(t *testing.T) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -91,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertToString(val any, fullName string) (string, error) {
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
@@ -634,11 +644,11 @@ func validateValueInOptions(val any, options []string) error {
|
||||
if len(options) > 0 {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if !stringx.Contains(options, v) {
|
||||
if !slices.Contains(options, v) {
|
||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||
}
|
||||
default:
|
||||
if !stringx.Contains(options, Repr(v)) {
|
||||
if !slices.Contains(options, Repr(v)) {
|
||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
package mathx
|
||||
|
||||
// MaxInt returns the larger one of a and b.
|
||||
// Deprecated: use builtin max instead.
|
||||
func MaxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return max(a, b)
|
||||
}
|
||||
|
||||
// MinInt returns the smaller one of a and b.
|
||||
// Deprecated: use builtin min instead.
|
||||
func MinInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return min(a, b)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
@@ -28,46 +27,15 @@ type (
|
||||
|
||||
const flushInterval = 5 * time.Minute
|
||||
|
||||
var (
|
||||
pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
updated := func() bool {
|
||||
pc.lock.RLock()
|
||||
defer pc.lock.RUnlock()
|
||||
|
||||
slot, ok := pc.slots[name]
|
||||
if ok {
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
return ok
|
||||
}()
|
||||
|
||||
if !updated {
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
pc.slots[name] = &profileSlot{
|
||||
lifecount: 1,
|
||||
lastcount: 1,
|
||||
lifecycle: int64(duration),
|
||||
lastcycle: int64(duration),
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
once.Do(flushRepeatly)
|
||||
var pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
|
||||
func flushRepeatly() {
|
||||
func init() {
|
||||
flushRepeatedly()
|
||||
}
|
||||
|
||||
func flushRepeatedly() {
|
||||
threading.GoSafe(func() {
|
||||
for {
|
||||
time.Sleep(flushInterval)
|
||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
||||
})
|
||||
}
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
slot := loadOrStoreSlot(name, duration)
|
||||
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
|
||||
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||
pc.lock.RLock()
|
||||
slot, ok := pc.slots[name]
|
||||
pc.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
// double-check
|
||||
if slot, ok = pc.slots[name]; ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
slot = &profileSlot{}
|
||||
pc.slots[name] = slot
|
||||
return slot
|
||||
}
|
||||
|
||||
func generateReport() string {
|
||||
var buffer bytes.Buffer
|
||||
buffer.WriteString("Profiling report\n")
|
||||
var data [][]string
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Profiling report\n")
|
||||
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||
|
||||
calcFn := func(total, count int64) string {
|
||||
if count == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
return (time.Duration(total) / time.Duration(count)).String()
|
||||
}
|
||||
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
pc.lock.Lock()
|
||||
for key, slot := range pc.slots {
|
||||
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||
key,
|
||||
slot.lifecount,
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
slot.lastcount,
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
))
|
||||
|
||||
for key, slot := range pc.slots {
|
||||
data = append(data, []string{
|
||||
key,
|
||||
strconv.FormatInt(slot.lifecount, 10),
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
strconv.FormatInt(slot.lastcount, 10),
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
})
|
||||
// reset last cycle stats
|
||||
atomic.StoreInt64(&slot.lastcount, 0)
|
||||
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||
}
|
||||
pc.lock.Unlock()
|
||||
|
||||
// reset the data for last cycle
|
||||
slot.lastcount = 0
|
||||
slot.lastcycle = 0
|
||||
}
|
||||
}()
|
||||
|
||||
table := tablewriter.NewWriter(&buffer)
|
||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
||||
table.SetBorder(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return buffer.String()
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
once.Do(func() {})
|
||||
assert.NotContains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
assert.Contains(t, generateReport(), "foo")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"github.com/zeromicro/go-zero/internal/devserver"
|
||||
"github.com/zeromicro/go-zero/internal/profiling"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +39,8 @@ type (
|
||||
Telemetry trace.Config `json:",optional"`
|
||||
DevServer DevServerConfig `json:",optional"`
|
||||
Shutdown proc.ShutdownConf `json:",optional"`
|
||||
// Profiling is the configuration for continuous profiling.
|
||||
Profiling profiling.Config `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
|
||||
devserver.StartAgent(sc.DevServer)
|
||||
profiling.Start(sc.Profiling)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ type (
|
||||
// NewServiceGroup returns a ServiceGroup.
|
||||
func NewServiceGroup() *ServiceGroup {
|
||||
sg := new(ServiceGroup)
|
||||
sg.stopOnce = syncx.Once(sg.doStop)
|
||||
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||
return sg
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
const (
|
||||
clusterNameKey = "CLUSTER_NAME"
|
||||
testEnv = "test.v"
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
||||
if fn != nil {
|
||||
reported := lessExecutor.DoOrDiscard(func() {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||
if len(clusterName) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||
}
|
||||
|
||||
29
core/stores/sqlx/config.go
Normal file
29
core/stores/sqlx/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package sqlx
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
errEmptyDatasource = errors.New("empty datasource")
|
||||
errEmptyDriverName = errors.New("empty driver name")
|
||||
)
|
||||
|
||||
// SqlConf defines the configuration for sqlx.
|
||||
type SqlConf struct {
|
||||
DataSource string
|
||||
DriverName string `json:",default=mysql"`
|
||||
Replicas []string `json:",optional"`
|
||||
Policy string `json:",default=round-robin,options=round-robin|random"`
|
||||
}
|
||||
|
||||
// Validate validates the SqlxConf.
|
||||
func (sc SqlConf) Validate() error {
|
||||
if len(sc.DataSource) == 0 {
|
||||
return errEmptyDatasource
|
||||
}
|
||||
|
||||
if len(sc.DriverName) == 0 {
|
||||
return errEmptyDriverName
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
29
core/stores/sqlx/config_test.go
Normal file
29
core/stores/sqlx/config_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
|
||||
`)
|
||||
|
||||
var sc SqlConf
|
||||
err := conf.LoadFromYamlBytes(text, &sc)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "mysql", sc.DriverName)
|
||||
assert.Equal(t, policyRoundRobin, sc.Policy)
|
||||
assert.Nil(t, sc.Validate())
|
||||
|
||||
sc = SqlConf{}
|
||||
assert.Equal(t, errEmptyDatasource, sc.Validate())
|
||||
|
||||
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||
assert.Equal(t, errEmptyDriverName, sc.Validate())
|
||||
|
||||
sc.DriverName = "mysql"
|
||||
assert.Nil(t, sc.Validate())
|
||||
}
|
||||
65
core/stores/sqlx/rwstrategy.go
Normal file
65
core/stores/sqlx/rwstrategy.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package sqlx
|
||||
|
||||
import "context"
|
||||
|
||||
const (
|
||||
// policyRoundRobin round-robin policy for selecting replicas.
|
||||
policyRoundRobin = "round-robin"
|
||||
// policyRandom random policy for selecting replicas.
|
||||
policyRandom = "random"
|
||||
|
||||
// readPrimaryMode indicates that the operation is a read,
|
||||
// but should be performed on the primary database instance.
|
||||
//
|
||||
// This mode is used in scenarios where data freshness and consistency are critical,
|
||||
// such as immediately after writes or where replication lag may cause stale reads.
|
||||
readPrimaryMode readWriteMode = "read-primary"
|
||||
|
||||
// readReplicaMode indicates that the operation is a read from replicas.
|
||||
// This is suitable for scenarios where eventual consistency is acceptable,
|
||||
// and the goal is to offload traffic from the primary and improve read scalability.
|
||||
readReplicaMode readWriteMode = "read-replica"
|
||||
|
||||
// writeMode indicates that the operation is a write operation (to primary).
|
||||
writeMode readWriteMode = "write"
|
||||
|
||||
// notSpecifiedMode indicates that the read/write mode is not specified.
|
||||
notSpecifiedMode readWriteMode = ""
|
||||
)
|
||||
|
||||
type readWriteModeKey struct{}
|
||||
|
||||
// WithReadPrimary sets the context to read-primary mode.
|
||||
func WithReadPrimary(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
|
||||
}
|
||||
|
||||
// WithReadReplica sets the context to read-replica mode.
|
||||
func WithReadReplica(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
|
||||
}
|
||||
|
||||
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
|
||||
func WithWrite(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
|
||||
}
|
||||
|
||||
type readWriteMode string
|
||||
|
||||
func (m readWriteMode) isValid() bool {
|
||||
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
|
||||
}
|
||||
|
||||
func getReadWriteMode(ctx context.Context) readWriteMode {
|
||||
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
|
||||
if v, ok := mode.(readWriteMode); ok && v.isValid() {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return notSpecifiedMode
|
||||
}
|
||||
|
||||
func usePrimary(ctx context.Context) bool {
|
||||
return getReadWriteMode(ctx) != readReplicaMode
|
||||
}
|
||||
142
core/stores/sqlx/rwstrategy_test.go
Normal file
142
core/stores/sqlx/rwstrategy_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsValid(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mode readWriteMode
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid read-primary mode",
|
||||
mode: readPrimaryMode,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "valid read-replica mode",
|
||||
mode: readReplicaMode,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "valid write mode",
|
||||
mode: writeMode,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not specified mode (empty)",
|
||||
mode: notSpecifiedMode,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid custom string",
|
||||
mode: readWriteMode("delete"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "case sensitive check",
|
||||
mode: readWriteMode("READ"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := tc.mode.isValid()
|
||||
assert.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithReadMode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
readPrimaryCtx := WithReadPrimary(ctx)
|
||||
|
||||
val := readPrimaryCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, readPrimaryMode, val)
|
||||
|
||||
readReplicaCtx := WithReadReplica(ctx)
|
||||
val = readReplicaCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, readReplicaMode, val)
|
||||
}
|
||||
|
||||
func TestWithWriteMode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
writeCtx := WithWrite(ctx)
|
||||
|
||||
val := writeCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, writeMode, val)
|
||||
}
|
||||
|
||||
func TestGetReadWriteMode(t *testing.T) {
|
||||
t.Run("valid read-primary mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("valid read-replica mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("valid write mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||
assert.Equal(t, writeMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
|
||||
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
|
||||
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("no mode set", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsePrimary(t *testing.T) {
|
||||
t.Run("context with read-replica mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||
assert.False(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with read-primary mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with write mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with invalid mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with no mode set", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithModeTwice(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = WithReadPrimary(ctx)
|
||||
writeCtx := WithWrite(ctx)
|
||||
|
||||
val := writeCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, writeMode, val)
|
||||
}
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
@@ -52,9 +55,10 @@ type (
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept breaker.Acceptable
|
||||
index uint32
|
||||
}
|
||||
|
||||
connProvider func() (*sql.DB, error)
|
||||
connProvider func(ctx context.Context) (*sql.DB, error)
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
@@ -64,10 +68,41 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// MustNewConn returns a SqlConn with the given SqlConf.
|
||||
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
|
||||
conn, err := NewConn(c, opts...)
|
||||
if err != nil {
|
||||
logx.Must(err)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// NewConn returns a SqlConn with the given SqlConf.
|
||||
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := &commonSqlConn{
|
||||
onError: func(ctx context.Context, err error) {
|
||||
logInstanceError(ctx, c.DataSource, err)
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(context.Context) (*sql.DB, error) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
},
|
||||
onError: func(ctx context.Context, err error) {
|
||||
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
// Use it with caution; it's provided for other ORM to interact with.
|
||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||
return db, nil
|
||||
},
|
||||
onError: func(ctx context.Context, err error) {
|
||||
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
|
||||
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
conn, err = db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
||||
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
conn, err = db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||
return db.connProv()
|
||||
return db.connProv(context.Background())
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
q string, args ...any) (err error) {
|
||||
var scanFailed bool
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
conn, err := db.connProv()
|
||||
conn, err := db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
return
|
||||
}
|
||||
|
||||
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
|
||||
return func(ctx context.Context) (*sql.DB, error) {
|
||||
replicaCount := len(replicas)
|
||||
|
||||
if replicaCount == 0 || usePrimary(ctx) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
|
||||
if replicaCount == 1 {
|
||||
dsn = replicas[0]
|
||||
} else {
|
||||
if len(policy) == 0 {
|
||||
policy = policyRoundRobin
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case policyRandom:
|
||||
dsn = replicas[rand.Intn(replicaCount)]
|
||||
case policyRoundRobin:
|
||||
index := atomic.AddUint32(&sc.index, 1) - 1
|
||||
dsn = replicas[index%uint32(replicaCount)]
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown policy: %s", policy)
|
||||
}
|
||||
}
|
||||
|
||||
return getSqlConn(driverName, dsn)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||
// acceptable is the func to check if the error can be accepted.
|
||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
|
||||
func TestSqlConn_Errors(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
conn := NewSqlConnFromDB(db)
|
||||
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
|
||||
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
_, err := conn.Prepare("any")
|
||||
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigSqlConn(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
mock.ExpectExec("any")
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||
|
||||
_, err = conn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = conn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var val string
|
||||
assert.NotNil(t, conn.QueryRow(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRows(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
|
||||
}
|
||||
|
||||
func TestConfigSqlConnStatement(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
mock.ExpectPrepare("any")
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectPrepare("any")
|
||||
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(row)
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||
stmt, err := conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := stmt.Exec()
|
||||
assert.NoError(t, err)
|
||||
lastInsertID, err := res.LastInsertId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), lastInsertID)
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), rowsAffected)
|
||||
|
||||
stmt, err = conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var val string
|
||||
err = stmt.QueryRow(&val)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bar", val)
|
||||
|
||||
mock.ExpectPrepare("any")
|
||||
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
|
||||
stmt, err = conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var vals []string
|
||||
assert.NoError(t, stmt.QueryRowsPartial(&vals))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
}
|
||||
|
||||
func TestConfigSqlConnQuery(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
t.Run("QueryRow", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var val string
|
||||
assert.NoError(t, conn.QueryRow(&val, "any"))
|
||||
assert.Equal(t, "bar", val)
|
||||
})
|
||||
|
||||
t.Run("QueryRowPartial", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var val string
|
||||
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
|
||||
assert.Equal(t, "bar", val)
|
||||
})
|
||||
|
||||
t.Run("QueryRows", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var vals []string
|
||||
assert.NoError(t, conn.QueryRows(&vals, "any"))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
})
|
||||
|
||||
t.Run("QueryRowsPartial", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var vals []string
|
||||
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigSqlConnErr(t *testing.T) {
|
||||
t.Run("panic on empty config", func(t *testing.T) {
|
||||
original := logx.ExitOnFatal.True()
|
||||
logx.ExitOnFatal.Set(false)
|
||||
defer logx.ExitOnFatal.Set(original)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
MustNewConn(SqlConf{})
|
||||
})
|
||||
})
|
||||
t.Run("on error", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
_, err = conn.Prepare("any")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatement(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectPrepare("any").WillBeClosed()
|
||||
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
|
||||
assert.True(t, conn.accept(acceptableErr3))
|
||||
}
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
defer func() {
|
||||
_ = connManager.Close()
|
||||
}()
|
||||
|
||||
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||
replicasDSN := []string{
|
||||
"replica_one:pwd@tcp(localhost:3306)/replica_one",
|
||||
"replica_two:pwd@tcp(localhost:3306)/replica_two",
|
||||
"replica_three:pwd@tcp(localhost:3306)/replica_three",
|
||||
}
|
||||
|
||||
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, primaryDB)
|
||||
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaOneDB)
|
||||
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaTwoDB)
|
||||
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaThreeDB)
|
||||
|
||||
sc := &commonSqlConn{}
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
db, err := sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithWrite(ctx)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithReadPrimary(ctx)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
// no mode set, should return primary
|
||||
ctx = context.Background()
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithReadReplica(ctx)
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicaOneDB, db)
|
||||
|
||||
// default policy is round-robin
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicas[i], db)
|
||||
}
|
||||
|
||||
// random policy
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, replicas, db)
|
||||
}
|
||||
|
||||
// unknown policy
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
|
||||
_, err = sc.connProv(ctx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// empty policy transforms to round-robin
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicas[i], db)
|
||||
}
|
||||
}
|
||||
|
||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
var db *sql.DB
|
||||
|
||||
@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if driverName != mysqlDriverName {
|
||||
if driverName == mysqlDriverName {
|
||||
if cfg, e := mysql.ParseDSN(server); e != nil {
|
||||
// if cannot parse, don't collect the metrics
|
||||
logx.Error(e)
|
||||
|
||||
@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
|
||||
|
||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
||||
fn func(context.Context, Session) error) (err error) {
|
||||
conn, err := db.connProv()
|
||||
conn, err := db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
|
||||
@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("foo")
|
||||
},
|
||||
beginTx: begin,
|
||||
|
||||
@@ -2,6 +2,7 @@ package stringx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"unicode"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -15,14 +16,9 @@ var (
|
||||
)
|
||||
|
||||
// Contains checks if str is in list.
|
||||
// Deprecated: use slices.Contains instead.
|
||||
func Contains(list []string, str string) bool {
|
||||
for _, each := range list {
|
||||
if each == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return slices.Contains(list, str)
|
||||
}
|
||||
|
||||
// Filter filters chars from s with given filter function.
|
||||
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
|
||||
// Reverse reverses s.
|
||||
func Reverse(s string) string {
|
||||
runes := []rune(s)
|
||||
|
||||
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
|
||||
runes[from], runes[to] = runes[to], runes[from]
|
||||
}
|
||||
|
||||
slices.Reverse(runes)
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,28 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEmpty(t *testing.T) {
|
||||
cases := []struct {
|
||||
args []string
|
||||
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
|
||||
@@ -3,9 +3,7 @@ package syncx
|
||||
import "sync"
|
||||
|
||||
// Once returns a func that guarantees fn can only called once.
|
||||
// Deprecated: use sync.OnceFunc instead.
|
||||
func Once(fn func()) func() {
|
||||
once := new(sync.Once)
|
||||
return func() {
|
||||
once.Do(fn)
|
||||
}
|
||||
return sync.OnceFunc(fn)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
|
||||
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
||||
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
||||
ver1len, ver2len := len(ver1), len(ver2)
|
||||
shorter := mathx.MinInt(ver1len, ver2len)
|
||||
shorter := min(ver1len, ver2len)
|
||||
|
||||
for i := 0; i < shorter; i++ {
|
||||
if ver1[i] == ver2[i] {
|
||||
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if ver1len < ver2len {
|
||||
return -1
|
||||
} else if ver1len == ver2len {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
return cmp.Compare(ver1len, ver2len)
|
||||
}
|
||||
|
||||
func strsToInts(strs []string) []int64 {
|
||||
|
||||
10
go.mod
10
go.mod
@@ -4,7 +4,7 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/alicebob/miniredis/v2 v2.34.0
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/fullstorydev/grpcurl v1.9.3
|
||||
github.com/go-sql-driver/mysql v1.9.0
|
||||
@@ -12,17 +12,17 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/golang/protobuf v1.5.4
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/grafana/pyroscope-go v1.2.2
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/jhump/protoreflect v1.17.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pelletier/go-toml/v2 v2.2.2
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/redis/go-redis/v9 v9.8.0
|
||||
github.com/redis/go-redis/v9 v9.11.0
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.etcd.io/etcd/api/v3 v3.5.15
|
||||
go.etcd.io/etcd/client/v3 v3.5.15
|
||||
go.mongodb.org/mongo-driver v1.17.3
|
||||
go.mongodb.org/mongo-driver v1.17.4
|
||||
go.opentelemetry.io/otel v1.24.0
|
||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||
@@ -50,7 +50,6 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bufbuild/protocompile v0.14.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
@@ -73,6 +72,7 @@ require (
|
||||
github.com/google/gnostic-models v0.6.8 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/gofuzz v1.2.0 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
|
||||
21
go.sum
21
go.sum
@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
||||
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
||||
@@ -159,8 +158,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
|
||||
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
|
||||
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||
go.mongodb.org/mongo-driver v1.17.3 h1:TQyXhnsWfWtgAhMtOgtYHMTkZIfBTpMTsMnd9ZBeHxQ=
|
||||
go.mongodb.org/mongo-driver v1.17.3/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
|
||||
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||
|
||||
263
internal/profiling/profiling.go
Normal file
263
internal/profiling/profiling.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package profiling
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCheckInterval = time.Second * 10
|
||||
defaultProfilingDuration = time.Minute * 2
|
||||
defaultUploadRate = time.Second * 15
|
||||
)
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
// Name is the name of the application.
|
||||
Name string `json:",optional,inherit"`
|
||||
// ServerAddr is the address of the profiling server.
|
||||
ServerAddr string
|
||||
// AuthUser is the username for basic authentication.
|
||||
AuthUser string `json:",optional"`
|
||||
// AuthPassword is the password for basic authentication.
|
||||
AuthPassword string `json:",optional"`
|
||||
// UploadRate is the duration for which profiling data is uploaded.
|
||||
UploadRate time.Duration `json:",default=15s"`
|
||||
// CheckInterval is the interval to check if profiling should start.
|
||||
CheckInterval time.Duration `json:",default=10s"`
|
||||
// ProfilingDuration is the duration for which profiling data is collected.
|
||||
ProfilingDuration time.Duration `json:",default=2m"`
|
||||
// CpuThreshold the collection is allowed only when the current service cpu > CpuThreshold
|
||||
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
|
||||
|
||||
// ProfileType is the type of profiling to be performed.
|
||||
ProfileType ProfileType
|
||||
}
|
||||
|
||||
ProfileType struct {
|
||||
// Logger is a flag to enable or disable logging.
|
||||
Logger bool `json:",default=false"`
|
||||
// CPU is a flag to disable CPU profiling.
|
||||
CPU bool `json:",default=true"`
|
||||
// Goroutines is a flag to disable goroutine profiling.
|
||||
Goroutines bool `json:",default=true"`
|
||||
// Memory is a flag to disable memory profiling.
|
||||
Memory bool `json:",default=true"`
|
||||
// Mutex is a flag to disable mutex profiling.
|
||||
Mutex bool `json:",default=false"`
|
||||
// Block is a flag to disable block profiling.
|
||||
Block bool `json:",default=false"`
|
||||
}
|
||||
|
||||
profiler interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
pyroscopeProfiler struct {
|
||||
c Config
|
||||
profiler *pyroscope.Profiler
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
|
||||
newProfiler = func(c Config) profiler {
|
||||
return newPyroscopeProfiler(c)
|
||||
}
|
||||
)
|
||||
|
||||
// Start initializes the pyroscope profiler with the given configuration.
|
||||
func Start(c Config) {
|
||||
// check if the profiling is enabled
|
||||
if len(c.ServerAddr) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// set default values for the configuration
|
||||
if c.ProfilingDuration <= 0 {
|
||||
c.ProfilingDuration = defaultProfilingDuration
|
||||
}
|
||||
|
||||
// set default values for the configuration
|
||||
if c.CheckInterval <= 0 {
|
||||
c.CheckInterval = defaultCheckInterval
|
||||
}
|
||||
|
||||
if c.UploadRate <= 0 {
|
||||
c.UploadRate = defaultUploadRate
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
logx.Info("continuous profiling started")
|
||||
|
||||
threading.GoSafe(func() {
|
||||
startPyroscope(c, proc.Done())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// startPyroscope starts the pyroscope profiler with the given configuration.
|
||||
func startPyroscope(c Config, done <-chan struct{}) {
|
||||
var (
|
||||
pr profiler
|
||||
err error
|
||||
latestProfilingTime time.Time
|
||||
intervalTicker = time.NewTicker(c.CheckInterval)
|
||||
profilingTicker = time.NewTicker(c.ProfilingDuration)
|
||||
)
|
||||
|
||||
defer profilingTicker.Stop()
|
||||
defer intervalTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-intervalTicker.C:
|
||||
// Check if the machine is overloaded and if the profiler is not running
|
||||
if pr == nil && isCpuOverloaded(c) {
|
||||
pr = newProfiler(c)
|
||||
if err := pr.Start(); err != nil {
|
||||
logx.Errorf("failed to start profiler: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// record the latest profiling time
|
||||
latestProfilingTime = time.Now()
|
||||
logx.Infof("pyroscope profiler started.")
|
||||
}
|
||||
case <-profilingTicker.C:
|
||||
// check if the profiling duration has passed
|
||||
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
|
||||
continue
|
||||
}
|
||||
|
||||
// check if the profiler is already running, if so, skip
|
||||
if pr != nil {
|
||||
if err = pr.Stop(); err != nil {
|
||||
logx.Errorf("failed to stop profiler: %v", err)
|
||||
}
|
||||
logx.Infof("pyroscope profiler stopped.")
|
||||
pr = nil
|
||||
}
|
||||
case <-done:
|
||||
logx.Infof("continuous profiling stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// genPyroscopeConf generates the pyroscope configuration based on the given config.
|
||||
func genPyroscopeConf(c Config) pyroscope.Config {
|
||||
pConf := pyroscope.Config{
|
||||
UploadRate: c.UploadRate,
|
||||
ApplicationName: c.Name,
|
||||
BasicAuthUser: c.AuthUser, // http basic auth user
|
||||
BasicAuthPassword: c.AuthPassword, // http basic auth password
|
||||
ServerAddress: c.ServerAddr,
|
||||
Logger: nil,
|
||||
HTTPHeaders: map[string]string{},
|
||||
// you can provide static tags via a map:
|
||||
Tags: map[string]string{
|
||||
"name": c.Name,
|
||||
},
|
||||
}
|
||||
|
||||
if c.ProfileType.Logger {
|
||||
pConf.Logger = logx.WithCallerSkip(0)
|
||||
}
|
||||
|
||||
if c.ProfileType.CPU {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||
}
|
||||
if c.ProfileType.Goroutines {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||
}
|
||||
if c.ProfileType.Memory {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
|
||||
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
|
||||
}
|
||||
if c.ProfileType.Mutex {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
|
||||
}
|
||||
|
||||
logx.Infof("applicationName: %s", pConf.ApplicationName)
|
||||
|
||||
return pConf
|
||||
}
|
||||
|
||||
// isCpuOverloaded checks the machine performance based on the given configuration.
|
||||
func isCpuOverloaded(c Config) bool {
|
||||
currentValue := stat.CpuUsage()
|
||||
if currentValue >= c.CpuThreshold {
|
||||
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func newPyroscopeProfiler(c Config) profiler {
|
||||
return &pyroscopeProfiler{
|
||||
c: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pyroscopeProfiler) Start() error {
|
||||
pConf := genPyroscopeConf(p.c)
|
||||
// set mutex and block profile rate
|
||||
setFraction(p.c)
|
||||
prof, err := pyroscope.Start(pConf)
|
||||
if err != nil {
|
||||
resetFraction(p.c)
|
||||
return err
|
||||
}
|
||||
|
||||
p.profiler = prof
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pyroscopeProfiler) Stop() error {
|
||||
if p.profiler == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.profiler.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetFraction(p.c)
|
||||
p.profiler = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setFraction(c Config) {
|
||||
// These 2 lines are only required if you're using mutex or block profiling
|
||||
if c.ProfileType.Mutex {
|
||||
runtime.SetMutexProfileFraction(10) // 10/seconds
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
|
||||
}
|
||||
}
|
||||
|
||||
func resetFraction(c Config) {
|
||||
// These 2 lines are only required if you're using mutex or block profiling
|
||||
if c.ProfileType.Mutex {
|
||||
runtime.SetMutexProfileFraction(0)
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
runtime.SetBlockProfileRate(0)
|
||||
}
|
||||
}
|
||||
177
internal/profiling/profiling_test.go
Normal file
177
internal/profiling/profiling_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package profiling
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
t.Run("profiling", func(t *testing.T) {
|
||||
var c Config
|
||||
assert.NoError(t, conf.FillDefault(&c))
|
||||
c.Name = "test"
|
||||
p := newProfiler(c)
|
||||
assert.NotNil(t, p)
|
||||
assert.NoError(t, p.Start())
|
||||
assert.NoError(t, p.Stop())
|
||||
})
|
||||
|
||||
t.Run("invalid config", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
Start(Config{})
|
||||
|
||||
Start(Config{
|
||||
ServerAddr: "localhost:4040",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("test start profiler", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 0,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.True(t, mp.started.True())
|
||||
assert.True(t, mp.stopped.True())
|
||||
})
|
||||
|
||||
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 900,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.False(t, mp.started.True())
|
||||
})
|
||||
|
||||
t.Run("start/stop err", func(t *testing.T) {
|
||||
mp := &mockProfiler{
|
||||
err: assert.AnError,
|
||||
}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 0,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.False(t, mp.started.True())
|
||||
assert.False(t, mp.stopped.True())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenPyroscopeConf(t *testing.T) {
|
||||
c := Config{
|
||||
Name: "",
|
||||
ServerAddr: "localhost:4040",
|
||||
AuthUser: "user",
|
||||
AuthPassword: "password",
|
||||
ProfileType: ProfileType{
|
||||
Logger: true,
|
||||
CPU: true,
|
||||
Goroutines: true,
|
||||
Memory: true,
|
||||
Mutex: true,
|
||||
Block: true,
|
||||
},
|
||||
}
|
||||
|
||||
pyroscopeConf := genPyroscopeConf(c)
|
||||
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
|
||||
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
|
||||
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
|
||||
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
|
||||
|
||||
setFraction(c)
|
||||
resetFraction(c)
|
||||
|
||||
newPyroscopeProfiler(c)
|
||||
}
|
||||
|
||||
func TestNewPyroscopeProfiler(t *testing.T) {
|
||||
p := newPyroscopeProfiler(Config{})
|
||||
|
||||
assert.Error(t, p.Start())
|
||||
assert.NoError(t, p.Stop())
|
||||
}
|
||||
|
||||
type mockProfiler struct {
|
||||
mutex sync.Mutex
|
||||
started syncx.AtomicBool
|
||||
stopped syncx.AtomicBool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProfiler) Start() error {
|
||||
m.mutex.Lock()
|
||||
if m.err == nil {
|
||||
m.started.Set(true)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockProfiler) Stop() error {
|
||||
m.mutex.Lock()
|
||||
if m.err == nil {
|
||||
m.stopped.Set(true)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return m.err
|
||||
}
|
||||
@@ -173,17 +173,20 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// For notification methods (no ID), we don't send a response
|
||||
isNotification := req.ID == 0
|
||||
isNotification, err := req.isNotification()
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// Special handling for initialization sequence
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %d", req.ID)
|
||||
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||
s.processInitialize(r.Context(), client, req)
|
||||
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
|
||||
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
@@ -206,41 +209,41 @@ func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Process normal requests only after initialization
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %d", req.ID)
|
||||
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||
s.processToolCall(r.Context(), client, req)
|
||||
logx.Infof("Sent tools call response for ID: %d", req.ID)
|
||||
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %d", req.ID)
|
||||
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||
s.processListTools(r.Context(), client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %d", req.ID)
|
||||
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
|
||||
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||
s.processListPrompts(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
|
||||
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
|
||||
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||
s.processGetPrompt(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
|
||||
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||
s.processListResources(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %d", req.ID)
|
||||
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||
s.processResourcesRead(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %d", req.ID)
|
||||
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
|
||||
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||
s.processResourceSubscribe(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %d", req.ID)
|
||||
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||
s.processPing(r.Context(), client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %d", req.ID)
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||
s.processNotificationCancelled(r.Context(), client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s from client: %d", req.Method, req.ID)
|
||||
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
@@ -809,7 +812,7 @@ func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClie
|
||||
}
|
||||
|
||||
// Ensure MimeType is set if available from the resource definition
|
||||
if len(content.MimeType) == 0 && resource.MimeType != "" {
|
||||
if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
|
||||
content.MimeType = resource.MimeType
|
||||
}
|
||||
|
||||
@@ -880,10 +883,10 @@ func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req R
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
id int64, message string, code int) {
|
||||
id any, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
ID any `json:"id"`
|
||||
Error errorMessage `json:"error"`
|
||||
}{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
@@ -898,7 +901,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
jsonData, _ := json.Marshal(errorResponse)
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
|
||||
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
@@ -910,7 +913,7 @@ func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id int64, result any) {
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
@@ -925,13 +928,13 @@ func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id i
|
||||
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
|
||||
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending response with ID %d", client.id, id)
|
||||
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,6 +175,20 @@ func TestHandleRequest_badRequest(t *testing.T) {
|
||||
mock.server.handleRequest(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
})
|
||||
|
||||
t.Run("bad id", func(t *testing.T) {
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
addTestClient(mock.server, "test-session", true)
|
||||
|
||||
body := `{"jsonrpc": "2.0", "id": {}, "method": "tools.call", "params": {}}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/?session_id=test-session", bytes.NewReader([]byte(body)))
|
||||
w := httptest.NewRecorder()
|
||||
mock.server.handleRequest(w, r)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
assert.Contains(t, w.Body.String(), "Invalid request.ID")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterTool(t *testing.T) {
|
||||
|
||||
31
mcp/types.go
31
mcp/types.go
@@ -3,6 +3,7 @@ package mcp
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
@@ -15,11 +16,28 @@ type Cursor string
|
||||
type Request struct {
|
||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||
ID int64 `json:"id"` // Request identifier for matching responses
|
||||
ID any `json:"id"` // Request identifier for matching responses
|
||||
Method string `json:"method"` // Method name to invoke
|
||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||
}
|
||||
|
||||
func (r Request) isNotification() (bool, error) {
|
||||
switch val := r.ID.(type) {
|
||||
case int:
|
||||
return val == 0, nil
|
||||
case int64:
|
||||
return val == 0, nil
|
||||
case float64:
|
||||
return val == 0.0, nil
|
||||
case string:
|
||||
return len(val) == 0, nil
|
||||
case nil:
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("invalid type %T", val)
|
||||
}
|
||||
}
|
||||
|
||||
type PaginatedParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
@@ -116,13 +134,8 @@ type FileContent struct {
|
||||
|
||||
// EmbeddedResource represents a resource embedded in a message
|
||||
type EmbeddedResource struct {
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource struct {
|
||||
URI string `json:"uri"` // Resource URI
|
||||
MimeType string `json:"mimeType"` // MIME type of the resource
|
||||
Text string `json:"text,omitempty"` // Text content (if available)
|
||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||
} `json:"resource"` // The resource data
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource ResourceContent `json:"resource"` // The resource data
|
||||
}
|
||||
|
||||
// Annotations provides additional metadata for content
|
||||
@@ -249,7 +262,7 @@ type errorObj struct {
|
||||
// Response represents a JSON-RPC response
|
||||
type Response struct {
|
||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||
ID int64 `json:"id"` // Same as request ID
|
||||
ID any `json:"id"` // Same as request ID
|
||||
Result any `json:"result"` // Result object (null if error)
|
||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package mcp
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -55,7 +56,7 @@ func TestRequestUnmarshaling(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JsonRpc)
|
||||
assert.Equal(t, int64(789), req.ID)
|
||||
assert.Equal(t, float64(789), req.ID)
|
||||
assert.Equal(t, "test_method", req.Method)
|
||||
|
||||
// Check params unmarshaled correctly
|
||||
@@ -204,3 +205,67 @@ func TestCallToolResult(t *testing.T) {
|
||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
|
||||
func TestRequest_isNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
want bool
|
||||
wantErr error
|
||||
}{
|
||||
// integer test cases
|
||||
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||
|
||||
// floating point number test cases
|
||||
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||
|
||||
// string test cases
|
||||
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||
|
||||
// special cases
|
||||
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||
|
||||
// logical type test cases
|
||||
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := Request{
|
||||
SessionId: "test-session",
|
||||
JsonRpc: "2.0",
|
||||
ID: tt.id,
|
||||
Method: "testMethod",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
got, err := req.isNotification()
|
||||
|
||||
if (err != nil) != (tt.wantErr != nil) {
|
||||
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
[English](readme.md) | 简体中文
|
||||
|
||||
[](https://github.com/zeromicro/go-zero/actions)
|
||||
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
||||
[](https://goproxy.cn/stats/github.com/zeromicro/go-zero/badges/download-count.svg)
|
||||
[](https://codecov.io/gh/zeromicro/go-zero)
|
||||
@@ -304,6 +303,7 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
||||
>105. 昆仑万维科技股份有限公司
|
||||
>106. 无锡盛算信息技术有限公司
|
||||
>107. 深圳市聚货通信息科技有限公司
|
||||
>108. 浙江银盾云科技有限公司
|
||||
|
||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ go-zero is a web and rpc framework with lots of builtin engineering practices. I
|
||||
|
||||
<div align=center>
|
||||
|
||||
[](https://github.com/zeromicro/go-zero/actions)
|
||||
[](https://codecov.io/gh/zeromicro/go-zero)
|
||||
[](https://goreportcard.com/report/github.com/zeromicro/go-zero)
|
||||
[](https://github.com/zeromicro/go-zero)
|
||||
@@ -251,7 +250,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
|
||||
## Give a Star! ⭐
|
||||
|
||||
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
||||
|
||||
## Buy me a coffee
|
||||
|
||||
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
|
||||
|
||||
@@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||
type engine struct {
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
// timeout is the max timeout of all routes
|
||||
// timeout is the max timeout of all routes,
|
||||
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
|
||||
// this network timeout is used to avoid DoS attacks by sending data slowly
|
||||
// or receiving data slowly with many connections to exhaust server resources.
|
||||
timeout time.Duration
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
@@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
}
|
||||
ng.routes = append(ng.routes, r)
|
||||
|
||||
// need to guarantee the timeout is the max of all routes
|
||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||
if r.timeout > ng.timeout {
|
||||
ng.timeout = r.timeout
|
||||
}
|
||||
ng.mightUpdateTimeout(r)
|
||||
}
|
||||
|
||||
func buildSSERoutes(routes []Route) []Route {
|
||||
@@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
||||
return ng.conf.MaxBytes
|
||||
}
|
||||
|
||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout > 0 {
|
||||
return timeout
|
||||
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
|
||||
if timeout != nil {
|
||||
return *timeout
|
||||
}
|
||||
|
||||
// if timeout not set in featured routes, use global timeout
|
||||
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
||||
}
|
||||
|
||||
@@ -228,6 +228,32 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
||||
return ng.shedder
|
||||
}
|
||||
|
||||
func (ng *engine) hasTimeout() bool {
|
||||
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||
}
|
||||
|
||||
// mightUpdateTimeout checks if the route timeout is greater than the current,
|
||||
// and updates the engine's timeout accordingly.
|
||||
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
|
||||
// if global timeout is set to 0, it means no need to set read/write timeout
|
||||
// if route timeout is nil, no need to update ng.timeout
|
||||
if ng.timeout == 0 || r.timeout == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// if route timeout is 0 (means no timeout), cannot set read/write timeout
|
||||
if *r.timeout == 0 {
|
||||
ng.timeout = 0
|
||||
return
|
||||
}
|
||||
|
||||
// need to guarantee the timeout is the max of all routes
|
||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||
if *r.timeout > ng.timeout {
|
||||
ng.timeout = *r.timeout
|
||||
}
|
||||
}
|
||||
|
||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -329,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
|
||||
}
|
||||
|
||||
// make sure user defined options overwrite default options
|
||||
opts = append([]StartOption{ng.withTimeout()}, opts...)
|
||||
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
|
||||
|
||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
||||
@@ -352,18 +378,19 @@ func (ng *engine) use(middleware Middleware) {
|
||||
ng.middlewares = append(ng.middlewares, middleware)
|
||||
}
|
||||
|
||||
func (ng *engine) withTimeout() internal.StartOption {
|
||||
func (ng *engine) withNetworkTimeout() internal.StartOption {
|
||||
return func(svr *http.Server) {
|
||||
timeout := ng.timeout
|
||||
if timeout > 0 {
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * timeout / 10
|
||||
if !ng.hasTimeout() {
|
||||
return
|
||||
}
|
||||
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * ng.timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * ng.timeout / 10
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -73,7 +73,17 @@ Verbose: true
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: time.Minute,
|
||||
timeout: ptrOfDuration(time.Minute),
|
||||
},
|
||||
{
|
||||
jwt: jwtSetting{},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: ptrOfDuration(0),
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
@@ -84,7 +94,7 @@ Verbose: true
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: time.Second,
|
||||
timeout: ptrOfDuration(time.Second),
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
@@ -227,8 +237,12 @@ Verbose: true
|
||||
}))
|
||||
|
||||
timeout := time.Second * 3
|
||||
if route.timeout > timeout {
|
||||
timeout = route.timeout
|
||||
if route.timeout != nil {
|
||||
if *route.timeout == 0 {
|
||||
timeout = 0
|
||||
} else if *route.timeout > timeout {
|
||||
timeout = *route.timeout
|
||||
}
|
||||
}
|
||||
assert.Equal(t, timeout, ng.timeout)
|
||||
})
|
||||
@@ -236,10 +250,69 @@ Verbose: true
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEngine_unsignedCallback(t *testing.T) {
|
||||
priKeyfile, err := fs.TempFilenameWithText(priKey)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(priKeyfile)
|
||||
|
||||
yaml := `Name: foo
|
||||
Host: localhost
|
||||
Port: 0
|
||||
Middlewares:
|
||||
Log: false
|
||||
`
|
||||
route := featuredRoutes{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{
|
||||
enabled: true,
|
||||
SignatureConf: SignatureConf{
|
||||
Strict: true,
|
||||
PrivateKeys: []PrivateKeyConf{
|
||||
{
|
||||
Fingerprint: "a",
|
||||
KeyFile: priKeyfile,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
}
|
||||
|
||||
var index int32
|
||||
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
|
||||
ng := newEngine(cnf)
|
||||
if atomic.AddInt32(&index, 1)%2 == 0 {
|
||||
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
|
||||
next http.Handler, strict bool, code int) {
|
||||
})
|
||||
}
|
||||
ng.addRoutes(route)
|
||||
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
|
||||
}))
|
||||
|
||||
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_checkedTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
timeout *time.Duration
|
||||
expect time.Duration
|
||||
}{
|
||||
{
|
||||
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "less",
|
||||
timeout: time.Millisecond * 500,
|
||||
timeout: ptrOfDuration(time.Millisecond * 500),
|
||||
expect: time.Millisecond * 500,
|
||||
},
|
||||
{
|
||||
name: "equal",
|
||||
timeout: time.Second,
|
||||
timeout: ptrOfDuration(time.Second),
|
||||
expect: time.Second,
|
||||
},
|
||||
{
|
||||
name: "more",
|
||||
timeout: time.Millisecond * 1500,
|
||||
timeout: ptrOfDuration(time.Millisecond * 1500),
|
||||
expect: time.Millisecond * 1500,
|
||||
},
|
||||
}
|
||||
@@ -394,9 +467,14 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ng := newEngine(RestConf{Timeout: test.timeout})
|
||||
ng := newEngine(RestConf{
|
||||
Timeout: test.timeout,
|
||||
Middlewares: MiddlewaresConf{
|
||||
Timeout: true,
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
ng.withNetworkTimeout()(svr)
|
||||
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||
@@ -406,6 +484,62 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_ReadWriteTimeout(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout int64
|
||||
middleware bool
|
||||
}{
|
||||
{
|
||||
name: "0/false",
|
||||
timeout: 0,
|
||||
middleware: false,
|
||||
},
|
||||
{
|
||||
name: "0/true",
|
||||
timeout: 0,
|
||||
middleware: true,
|
||||
},
|
||||
{
|
||||
name: "set/false",
|
||||
timeout: 1000,
|
||||
middleware: false,
|
||||
},
|
||||
{
|
||||
name: "both set",
|
||||
timeout: 1000,
|
||||
middleware: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ng := newEngine(RestConf{
|
||||
Timeout: test.timeout,
|
||||
Middlewares: MiddlewaresConf{
|
||||
Timeout: test.middleware,
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withNetworkTimeout()(svr)
|
||||
|
||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||
|
||||
if test.timeout > 0 && test.middleware {
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
|
||||
} else {
|
||||
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_start(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
|
||||
@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
case <-ctx.Done():
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
// there isn't any user-defined middleware before TimoutHandler,
|
||||
// so we can guarantee that cancelation in biz related code won't come here.
|
||||
// there isn't any user-defined middleware before TimeoutHandler,
|
||||
// so we can guarantee that cancellation in biz related code won't come here.
|
||||
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
w.WriteHeader(statusClientClosedRequest)
|
||||
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Header returns the underline temporary http.Header.
|
||||
// Header returns the underlying temporary http.Header.
|
||||
func (tw *timeoutWriter) Header() http.Header {
|
||||
return tw.h
|
||||
}
|
||||
|
||||
@@ -2,15 +2,16 @@ package fileserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Middleware returns a middleware that serves files from the given file system.
|
||||
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||
func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||
fileServer := http.FileServer(fs)
|
||||
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
|
||||
canServe := createServeChecker(path, fs)
|
||||
pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
|
||||
canServe := createServeChecker(upath, fs)
|
||||
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
var lock sync.RWMutex
|
||||
fileChecker := make(map[string]bool)
|
||||
|
||||
return func(path string) bool {
|
||||
return func(upath string) bool {
|
||||
// Emulate http.Dir.Open’s path normalization for embed.FS.Open.
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path without the final "index.html".
|
||||
// So the path here may be empty or end with a "/".
|
||||
// http.Dir.Open uses this logic to clean the path,
|
||||
// correctly handling those two cases.
|
||||
// embed.FS doesn’t perform this normalization, so we apply the same logic here.
|
||||
upath = path.Clean("/" + upath)[1:]
|
||||
if len(upath) == 0 {
|
||||
// if the path is empty, we use "." to open the current directory
|
||||
upath = "."
|
||||
}
|
||||
|
||||
lock.RLock()
|
||||
exist, ok := fileChecker[path]
|
||||
exist, ok := fileChecker[upath]
|
||||
lock.RUnlock()
|
||||
if ok {
|
||||
return exist
|
||||
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
file, err := fs.Open(path)
|
||||
file, err := fs.Open(upath)
|
||||
exist = err == nil
|
||||
fileChecker[path] = exist
|
||||
fileChecker[upath] = exist
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
|
||||
pathWithTrailSlash := ensureTrailingSlash(path)
|
||||
func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
|
||||
pathWithTrailSlash := ensureTrailingSlash(upath)
|
||||
fileChecker := createFileChecker(fs)
|
||||
|
||||
return func(r *http.Request) bool {
|
||||
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
|
||||
}
|
||||
}
|
||||
|
||||
func ensureTrailingSlash(path string) string {
|
||||
if strings.HasSuffix(path, "/") {
|
||||
return path
|
||||
func ensureTrailingSlash(upath string) string {
|
||||
if strings.HasSuffix(upath, "/") {
|
||||
return upath
|
||||
}
|
||||
|
||||
return path + "/"
|
||||
return upath + "/"
|
||||
}
|
||||
|
||||
func ensureNoTrailingSlash(path string) string {
|
||||
if strings.HasSuffix(path, "/") {
|
||||
return path[:len(path)-1]
|
||||
func ensureNoTrailingSlash(upath string) string {
|
||||
if strings.HasSuffix(upath, "/") {
|
||||
return upath[:len(upath)-1]
|
||||
}
|
||||
|
||||
return path
|
||||
return upath
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package fileserver
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
|
||||
requestPath: "/ws",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path, without the final "index.html".
|
||||
{
|
||||
name: "Serve index.html",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html with path with trailing slash",
|
||||
path: "/static/",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html in a nested directory",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/nested/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Request index.html indirectly",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
{
|
||||
name: "Request index.html in a nested directory indirectly",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/nested/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed testdata
|
||||
testdataFS embed.FS
|
||||
)
|
||||
|
||||
func TestMiddleware_embedFS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
requestPath string
|
||||
expectedStatus int
|
||||
expectedContent string
|
||||
}{
|
||||
{
|
||||
name: "Serve static file",
|
||||
path: "/static",
|
||||
requestPath: "/static/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Path with trailing slash",
|
||||
path: "/static/",
|
||||
requestPath: "/static/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Root path",
|
||||
path: "/",
|
||||
requestPath: "/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Pass through non-matching path",
|
||||
path: "/static/",
|
||||
requestPath: "/other/path",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "Not exist file",
|
||||
path: "/assets",
|
||||
requestPath: "/assets/not-exist.txt",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "Not exist file in root",
|
||||
path: "/",
|
||||
requestPath: "/not-exist.txt",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "websocket request",
|
||||
path: "/",
|
||||
requestPath: "/ws",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path, without the final "index.html".
|
||||
{
|
||||
name: "Serve index.html",
|
||||
path: "/static",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html with path with trailing slash",
|
||||
path: "/static/",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html in a nested directory",
|
||||
path: "/static",
|
||||
requestPath: "/static/nested/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Request index.html indirectly",
|
||||
path: "/static",
|
||||
requestPath: "/static/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
{
|
||||
name: "Request index.html in a nested directory indirectly",
|
||||
path: "/static",
|
||||
requestPath: "/static/nested/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
subFS, err := fs.Sub(testdataFS, "testdata")
|
||||
assert.Nil(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
middleware := Middleware(tt.path, http.FS(subFS))
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAlreadyReported)
|
||||
})
|
||||
|
||||
handlerToTest := middleware(nextHandler)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||
if len(tt.expectedContent) > 0 {
|
||||
assert.Equal(t, tt.expectedContent, rr.Body.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureTrailingSlash(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
@@ -119,6 +119,16 @@ func (s *Server) Use(middleware Middleware) {
|
||||
s.ngin.use(middleware)
|
||||
}
|
||||
|
||||
// build builds the Server and binds the routes to the router.
|
||||
func (s *Server) build() error {
|
||||
return s.ngin.bindRoutes(s.router)
|
||||
}
|
||||
|
||||
// serve serves the HTTP requests using the Server's router.
|
||||
func (s *Server) serve(w http.ResponseWriter, r *http.Request) {
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// ToMiddleware converts the given handler to a Middleware.
|
||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||
@@ -283,14 +293,14 @@ func WithSignature(signature SignatureConf) RouteOption {
|
||||
func WithSSE() RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.sse = true
|
||||
r.timeout = 0
|
||||
r.timeout = ptrOfDuration(0)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout returns a RouteOption to set timeout with given value.
|
||||
func WithTimeout(timeout time.Duration) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.timeout = timeout
|
||||
r.timeout = &timeout
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,6 +335,10 @@ func handleError(err error) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||
return &d
|
||||
}
|
||||
|
||||
func validateSecret(secret string) {
|
||||
if len(secret) < 8 {
|
||||
panic("secret's length can't be less than 8")
|
||||
|
||||
@@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) {
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
var fr featuredRoutes
|
||||
WithTimeout(time.Hour)(&fr)
|
||||
assert.Equal(t, time.Hour, fr.timeout)
|
||||
assert.Equal(t, time.Hour, *fr.timeout)
|
||||
}
|
||||
|
||||
func TestWithTLSConfig(t *testing.T) {
|
||||
@@ -819,6 +819,6 @@ func TestServerEmbedFileSystem(t *testing.T) {
|
||||
// serve(server, w, r)
|
||||
// // verify the response
|
||||
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.ngin.bindRoutes(s.router)
|
||||
s.router.ServeHTTP(w, r)
|
||||
_ = s.build()
|
||||
s.serve(w, r)
|
||||
}
|
||||
|
||||
27
rest/serverless.go
Normal file
27
rest/serverless.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package rest
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Serverless is a wrapper around Server that allows it to be used in serverless environments.
|
||||
type Serverless struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
// NewServerless creates a new Serverless instance from the provided Server.
|
||||
func NewServerless(server *Server) (*Serverless, error) {
|
||||
// Ensure the server is built before using it in a serverless context.
|
||||
// Why not call server.build() when serving requests,
|
||||
// is because we need to ensure fail fast behavior.
|
||||
if err := server.build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Serverless{
|
||||
server: server,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve handles HTTP requests by delegating them to the underlying Server instance.
|
||||
func (s *Serverless) Serve(w http.ResponseWriter, r *http.Request) {
|
||||
s.server.serve(w, r)
|
||||
}
|
||||
67
rest/serverless_test.go
Normal file
67
rest/serverless_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
)
|
||||
|
||||
func TestNewServerless(t *testing.T) {
|
||||
logtest.Discard(t)
|
||||
|
||||
const configYaml = `
|
||||
Name: foo
|
||||
Host: localhost
|
||||
Port: 0
|
||||
`
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
|
||||
|
||||
svr, err := NewServer(cnf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
svr.AddRoute(Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello World"))
|
||||
},
|
||||
})
|
||||
|
||||
serverless, err := NewServerless(svr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
serverless.Serve(w, r)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "Hello World", w.Body.String())
|
||||
}
|
||||
|
||||
func TestNewServerlessWithError(t *testing.T) {
|
||||
logtest.Discard(t)
|
||||
|
||||
const configYaml = `
|
||||
Name: foo
|
||||
Host: localhost
|
||||
Port: 0
|
||||
`
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
|
||||
|
||||
svr, err := NewServer(cnf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
svr.AddRoute(Route{
|
||||
Method: http.MethodGet,
|
||||
Path: "notstartwith/",
|
||||
Handler: nil,
|
||||
})
|
||||
|
||||
_, err = NewServerless(svr)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -31,7 +31,7 @@ type (
|
||||
}
|
||||
|
||||
featuredRoutes struct {
|
||||
timeout time.Duration
|
||||
timeout *time.Duration
|
||||
priority bool
|
||||
jwt jwtSetting
|
||||
signature signatureSetting
|
||||
|
||||
1
tools/goctl/.gitignore
vendored
Normal file
1
tools/goctl/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
dist
|
||||
@@ -77,6 +77,7 @@ func init() {
|
||||
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
||||
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
||||
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
||||
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
|
||||
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
||||
|
||||
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
||||
|
||||
@@ -40,6 +40,8 @@ var (
|
||||
// VarStringStyle describes the style of output files.
|
||||
VarStringStyle string
|
||||
VarBoolWithTest bool
|
||||
// VarBoolTypeGroup describes whether to group types.
|
||||
VarBoolTypeGroup bool
|
||||
)
|
||||
|
||||
// GoCommand gen go project files from command line
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/format"
|
||||
)
|
||||
@@ -41,53 +41,116 @@ func BuildTypes(types []spec.Type) (string, error) {
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func removeTypeFromDefault(tp spec.Type, group string, groupTypes map[string]map[string]spec.Type) map[string]map[string]spec.Type {
|
||||
func getTypeName(tp spec.Type) string {
|
||||
if tp == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := tp.(type) {
|
||||
case spec.DefineStruct:
|
||||
typeName := util.Title(tp.Name())
|
||||
defaultGroups, ok := groupTypes[groupTypeDefault]
|
||||
if ok {
|
||||
delete(defaultGroups, typeName)
|
||||
types, ok := groupTypes[group]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = tp
|
||||
groupTypes[group] = types
|
||||
}
|
||||
groupTypes[groupTypeDefault] = defaultGroups
|
||||
return typeName
|
||||
case spec.PointerType:
|
||||
groupTypes = removeTypeFromDefault(val.Type, group, groupTypes)
|
||||
return getTypeName(val.Type)
|
||||
case spec.ArrayType:
|
||||
groupTypes = removeTypeFromDefault(val.Value, group, groupTypes)
|
||||
return getTypeName(val.Value)
|
||||
}
|
||||
return groupTypes
|
||||
return ""
|
||||
}
|
||||
|
||||
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
groupTypes := make(map[string]map[string]spec.Type)
|
||||
for _, v := range api.Types {
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[util.Title(v.Name())] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
}
|
||||
typesBelongToFiles := make(map[string]*collection.Set)
|
||||
|
||||
for _, v := range api.Service.Groups {
|
||||
group := v.GetAnnotation(groupProperty)
|
||||
if len(group) == 0 {
|
||||
group = groupTypeDefault
|
||||
}
|
||||
// convert filepath to Identifier name spec.
|
||||
group = strings.TrimPrefix(group, "/")
|
||||
group = strings.TrimSuffix(group, "/")
|
||||
group = util.SafeString(group)
|
||||
for _, v := range v.Routes {
|
||||
requestTypeName := getTypeName(v.RequestType)
|
||||
responseTypeName := getTypeName(v.ResponseType)
|
||||
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
|
||||
if !ok {
|
||||
requestTypeFileSet = collection.NewSet()
|
||||
}
|
||||
if len(requestTypeName) > 0 {
|
||||
requestTypeFileSet.AddStr(group)
|
||||
typesBelongToFiles[requestTypeName] = requestTypeFileSet
|
||||
}
|
||||
|
||||
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
|
||||
if !ok {
|
||||
responseTypeFileSet = collection.NewSet()
|
||||
}
|
||||
if len(responseTypeName) > 0 {
|
||||
responseTypeFileSet.AddStr(group)
|
||||
typesBelongToFiles[responseTypeName] = responseTypeFileSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typesInOneFile := make(map[string]*collection.Set)
|
||||
for typeName, fileSet := range typesBelongToFiles {
|
||||
count := fileSet.Count()
|
||||
switch {
|
||||
case count == 0: // it means there has no structure type or no request/response body
|
||||
continue
|
||||
case count == 1: // it means a structure type used in only one group.
|
||||
groupName := fileSet.KeysStr()[0]
|
||||
typeSet, ok := typesInOneFile[groupName]
|
||||
if !ok {
|
||||
typeSet = collection.NewSet()
|
||||
}
|
||||
typeSet.AddStr(typeName)
|
||||
typesInOneFile[groupName] = typeSet
|
||||
default: // it means this type is used in multiple groups.
|
||||
continue
|
||||
}
|
||||
for _, v := range v.Routes {
|
||||
if v.RequestType != nil {
|
||||
groupTypes = removeTypeFromDefault(v.RequestType, group, groupTypes)
|
||||
}
|
||||
if v.ResponseType != nil {
|
||||
groupTypes = removeTypeFromDefault(v.ResponseType, group, groupTypes)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range api.Types {
|
||||
typeName := util.Title(v.Name())
|
||||
groupSet, ok := typesBelongToFiles[typeName]
|
||||
var typeCount int
|
||||
if !ok {
|
||||
typeCount = 0
|
||||
} else {
|
||||
typeCount = groupSet.Count()
|
||||
}
|
||||
|
||||
if typeCount == 0 { // not belong to any group
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
continue
|
||||
}
|
||||
|
||||
if typeCount == 1 { // belong to one group
|
||||
groupName := groupSet.KeysStr()[0]
|
||||
types, ok := groupTypes[groupName]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupName] = types
|
||||
continue
|
||||
}
|
||||
|
||||
// belong to multiple groups
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
|
||||
}
|
||||
|
||||
for group, typeGroup := range groupTypes {
|
||||
@@ -142,7 +205,7 @@ func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type)
|
||||
}
|
||||
|
||||
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
if env.UseExperimental() {
|
||||
if VarBoolTypeGroup {
|
||||
return genTypesWithGroup(dir, cfg, api)
|
||||
}
|
||||
return writeTypes(dir, typesFile, cfg, api.Types)
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
||||
for _, item := range c.responseTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
superClassName = "HttpResponseData"
|
||||
if !stringx.Contains(c.imports, httpResponseData) {
|
||||
if !slices.Contains(c.imports, httpResponseData) {
|
||||
c.imports = append(c.imports, httpResponseData)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
|
||||
if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
|
||||
c.imports = append(c.imports, httpData)
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
|
||||
tyString := javaType
|
||||
decorator := ""
|
||||
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
||||
if !stringx.Contains(javaPrimitiveType, javaType) {
|
||||
if !slices.Contains(javaPrimitiveType, javaType) {
|
||||
if member.IsOptional() || member.IsOmitEmpty() {
|
||||
decorator = "@Nullable "
|
||||
} else {
|
||||
|
||||
@@ -3,9 +3,9 @@ package spec
|
||||
import (
|
||||
"errors"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ func (m Member) IsOptional() bool {
|
||||
tag := m.Tags()
|
||||
for _, item := range tag {
|
||||
if item.Key == bodyTagKey || item.Key == formTagKey {
|
||||
if stringx.Contains(item.Options, "optional") {
|
||||
if slices.Contains(item.Options, "optional") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
|
||||
tag := m.Tags()
|
||||
for _, item := range tag {
|
||||
if item.Key == bodyTagKey {
|
||||
if stringx.Contains(item.Options, "omitempty") {
|
||||
if slices.Contains(item.Options, "omitempty") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
|
||||
func (m Member) GetPropertyName() (string, error) {
|
||||
tags := m.Tags()
|
||||
for _, tag := range tags {
|
||||
if stringx.Contains(definedKeys, tag.Key) {
|
||||
if slices.Contains(definedKeys, tag.Key) {
|
||||
if tag.Name == "-" {
|
||||
return util.Untitle(m.Name), nil
|
||||
}
|
||||
|
||||
@@ -49,6 +49,9 @@ func Parse(tag string) (*Tags, error) {
|
||||
|
||||
// Get gets tag value by specified key
|
||||
func (t *Tags) Get(key string) (*Tag, error) {
|
||||
if t == nil {
|
||||
return nil, errTagNotExist
|
||||
}
|
||||
for _, tag := range t.tags {
|
||||
if tag.Key == key {
|
||||
return tag, nil
|
||||
@@ -60,6 +63,9 @@ func (t *Tags) Get(key string) (*Tag, error) {
|
||||
|
||||
// Keys returns all keys in Tags
|
||||
func (t *Tags) Keys() []string {
|
||||
if t == nil {
|
||||
return []string{}
|
||||
}
|
||||
var keys []string
|
||||
for _, tag := range t.tags {
|
||||
keys = append(keys, tag.Key)
|
||||
@@ -69,5 +75,8 @@ func (t *Tags) Keys() []string {
|
||||
|
||||
// Tags returns all tags in Tags
|
||||
func (t *Tags) Tags() []*Tag {
|
||||
if t == nil {
|
||||
return []*Tag{}
|
||||
}
|
||||
return t.tags
|
||||
}
|
||||
|
||||
68
tools/goctl/api/spec/tags_test.go
Normal file
68
tools/goctl/api/spec/tags_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package spec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTags_Get(t *testing.T) {
|
||||
tags := &Tags{
|
||||
tags: []*Tag{
|
||||
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
|
||||
{Key: "xml", Name: "bar", Options: nil},
|
||||
},
|
||||
}
|
||||
|
||||
tag, err := tags.Get("json")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tag)
|
||||
assert.Equal(t, "json", tag.Key)
|
||||
assert.Equal(t, "foo", tag.Name)
|
||||
|
||||
_, err = tags.Get("yaml")
|
||||
assert.Error(t, err)
|
||||
|
||||
var nilTags *Tags
|
||||
_, err = nilTags.Get("json")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTags_Keys(t *testing.T) {
|
||||
tags := &Tags{
|
||||
tags: []*Tag{
|
||||
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
|
||||
{Key: "xml", Name: "bar", Options: nil},
|
||||
},
|
||||
}
|
||||
|
||||
keys := tags.Keys()
|
||||
expected := []string{"json", "xml"}
|
||||
assert.Equal(t, expected, keys)
|
||||
|
||||
var nilTags *Tags
|
||||
nilKeys := nilTags.Keys()
|
||||
assert.Empty(t, nilKeys)
|
||||
}
|
||||
|
||||
func TestTags_Tags(t *testing.T) {
|
||||
tags := &Tags{
|
||||
tags: []*Tag{
|
||||
{Key: "json", Name: "foo", Options: []string{"omitempty"}},
|
||||
{Key: "xml", Name: "bar", Options: nil},
|
||||
},
|
||||
}
|
||||
|
||||
result := tags.Tags()
|
||||
assert.Len(t, result, 2)
|
||||
assert.Equal(t, "json", result[0].Key)
|
||||
assert.Equal(t, "foo", result[0].Name)
|
||||
assert.Equal(t, []string{"omitempty"}, result[0].Options)
|
||||
assert.Equal(t, "xml", result[1].Key)
|
||||
assert.Equal(t, "bar", result[1].Name)
|
||||
assert.Nil(t, result[1].Options)
|
||||
|
||||
var nilTags *Tags
|
||||
nilResult := nilTags.Tags()
|
||||
assert.Empty(t, nilResult)
|
||||
}
|
||||
@@ -7,15 +7,6 @@ import (
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func hasKey(properties map[string]string, key string) bool {
|
||||
if len(properties) == 0 {
|
||||
return false
|
||||
}
|
||||
md := metadata.New(properties)
|
||||
_, ok := md[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
|
||||
if len(properties) == 0 {
|
||||
return def
|
||||
|
||||
@@ -42,7 +42,7 @@ func Test_getListFromInfoOrDefault(t *testing.T) {
|
||||
"empty": `""`,
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"a", "b", "c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
|
||||
assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package swagger
|
||||
|
||||
const (
|
||||
tagHeader = "header"
|
||||
tagPath = "path"
|
||||
tagForm = "form"
|
||||
tagJson = "json"
|
||||
defFlag = "default="
|
||||
enumFlag = "options="
|
||||
rangeFlag = "range="
|
||||
exampleFlag = "example="
|
||||
tagHeader = "header"
|
||||
tagPath = "path"
|
||||
tagForm = "form"
|
||||
tagJson = "json"
|
||||
defFlag = "default="
|
||||
enumFlag = "options="
|
||||
rangeFlag = "range="
|
||||
exampleFlag = "example="
|
||||
optionalFlag = "optional"
|
||||
|
||||
paramsInHeader = "header"
|
||||
paramsInPath = "path"
|
||||
@@ -27,6 +28,38 @@ const (
|
||||
applicationJson = "application/json"
|
||||
applicationForm = "application/x-www-form-urlencoded"
|
||||
schemeHttps = "https"
|
||||
defaultHost = "127.0.0.1"
|
||||
defaultBasePath = "/"
|
||||
)
|
||||
|
||||
const (
|
||||
propertyKeyUseDefinitions = "useDefinitions"
|
||||
propertyKeyExternalDocsDescription = "externalDocsDescription"
|
||||
propertyKeyExternalDocsURL = "externalDocsURL"
|
||||
propertyKeyTitle = "title"
|
||||
propertyKeyTermsOfService = "termsOfService"
|
||||
propertyKeyDescription = "description"
|
||||
propertyKeyVersion = "version"
|
||||
propertyKeyContactName = "contactName"
|
||||
propertyKeyContactURL = "contactURL"
|
||||
propertyKeyContactEmail = "contactEmail"
|
||||
propertyKeyLicenseName = "licenseName"
|
||||
propertyKeyLicenseURL = "licenseURL"
|
||||
propertyKeyProduces = "produces"
|
||||
propertyKeyConsumes = "consumes"
|
||||
propertyKeySchemes = "schemes"
|
||||
propertyKeyTags = "tags"
|
||||
propertyKeySummary = "summary"
|
||||
propertyKeyGroup = "group"
|
||||
propertyKeyOperationId = "operationId"
|
||||
propertyKeyDeprecated = "deprecated"
|
||||
propertyKeyPrefix = "prefix"
|
||||
propertyKeyAuthType = "authType"
|
||||
propertyKeyHost = "host"
|
||||
propertyKeyBasePath = "basePath"
|
||||
propertyKeyWrapCodeMsg = "wrapCodeMsg"
|
||||
propertyKeyBizCodeEnumDescription = "bizCodeEnumDescription"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValueOfPropertyUseDefinition = false
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func consumesFromTypeOrDef(method string, tp spec.Type) []string {
|
||||
func consumesFromTypeOrDef(ctx Context, method string, tp spec.Type) []string {
|
||||
if strings.EqualFold(method, http.MethodGet) {
|
||||
return []string{}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ func consumesFromTypeOrDef(method string, tp spec.Type) []string {
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
if typeContainsTag(structType, tagJson) {
|
||||
if typeContainsTag(ctx, structType, tagJson) {
|
||||
return []string{applicationJson}
|
||||
}
|
||||
return []string{applicationForm}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestConsumesFromTypeOrDef(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := consumesFromTypeOrDef(tt.method, tt.tp)
|
||||
result := consumesFromTypeOrDef(testingContext(t), tt.method, tt.tp)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
28
tools/goctl/api/swagger/context.go
Normal file
28
tools/goctl/api/swagger/context.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
UseDefinitions bool
|
||||
WrapCodeMsg bool
|
||||
BizCodeEnumDescription string
|
||||
}
|
||||
|
||||
func testingContext(_ *testing.T) Context {
|
||||
return Context{}
|
||||
}
|
||||
|
||||
func contextFromApi(info spec.Info) Context {
|
||||
if len(info.Properties) == 0 {
|
||||
return Context{}
|
||||
}
|
||||
return Context{
|
||||
UseDefinitions: getBoolFromKVOrDefault(info.Properties, propertyKeyUseDefinitions, defaultValueOfPropertyUseDefinition),
|
||||
WrapCodeMsg: getBoolFromKVOrDefault(info.Properties, propertyKeyWrapCodeMsg, false),
|
||||
BizCodeEnumDescription: getStringFromKVOrDefault(info.Properties, propertyKeyBizCodeEnumDescription, "business code"),
|
||||
}
|
||||
}
|
||||
32
tools/goctl/api/swagger/definition.go
Normal file
32
tools/goctl/api/swagger/definition.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func definitionsFromTypes(ctx Context, types []apiSpec.Type) spec.Definitions {
|
||||
if !ctx.UseDefinitions {
|
||||
return nil
|
||||
}
|
||||
definitions := make(spec.Definitions)
|
||||
for _, tp := range types {
|
||||
typeName := tp.Name()
|
||||
definitions[typeName] = schemaFromType(ctx, tp)
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
func schemaFromType(ctx Context, tp apiSpec.Type) spec.Schema {
|
||||
p, r := propertiesFromType(ctx, tp)
|
||||
props := spec.SchemaProps{
|
||||
Type: typeFromGoType(ctx, tp),
|
||||
Properties: p,
|
||||
AdditionalProperties: mapFromGoType(ctx, tp),
|
||||
Items: itemsFromGoType(ctx, tp),
|
||||
Required: r,
|
||||
}
|
||||
return spec.Schema{
|
||||
SchemaProps: props,
|
||||
}
|
||||
}
|
||||
@@ -12,15 +12,16 @@ info (
|
||||
licenseURL: "https://github.com/zeromicro/go-zero" // licenseURL corresponding to Swagger
|
||||
consumes: "application/json" // consumes corresponding to Swagger,default value is `application/json`
|
||||
produces: "application/json" // produces corresponding to Swagger,default value is `application/json`
|
||||
schemes: "https" // schemes corresponding to Swagger,default value is `https``
|
||||
schemes: "http,https" // schemes corresponding to Swagger,default value is `https``
|
||||
host: "example.com" // host corresponding to Swagger,default value is `127.0.0.1`
|
||||
basePath: "/v1" // basePath corresponding to Swagger,default value is `/`
|
||||
wrapCodeMsg: "true" // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
|
||||
wrapCodeMsg: true // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
|
||||
bizCodeEnumDescription: "1001-User not login<br>1002-User permission denied" // enums of business error codes, in JSON format, with the key being the business error code and the value being the description of that error code. This only takes effect when wrapCodeMsg is set to true.
|
||||
// securityDefinitionsFromJson is a custom authentication configuration, and the JSON content will be directly inserted into the securityDefinitions of Swagger.
|
||||
// Format reference: https://swagger.io/specification/v2/#security-definitions-object
|
||||
// You can declare authType in the @server of the API to specify the authentication type used for its routes.
|
||||
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey type description","type":"apiKey","name":"x-api-key","in":"header"}}`
|
||||
useDefinitions: true // if set true, the definitions will be generated in the swagger.json for response body or json request body file, and the models will be referenced in the API.
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
4980
tools/goctl/api/swagger/example/example.swagger.json
Normal file
4980
tools/goctl/api/swagger/example/example.swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -12,15 +12,16 @@ info (
|
||||
licenseURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 licenseURL
|
||||
consumes: "application/json" // 对应 swagger 的 consumes,不填默认为 application/json
|
||||
produces: "application/json" // 对应 swagger 的 produces,不填默认为 application/json
|
||||
schemes: "https" // 对应 swagger 的 schemes,不填默认为 https
|
||||
schemes: "http,https" // 对应 swagger 的 schemes,不填默认为 https
|
||||
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
|
||||
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
|
||||
wrapCodeMsg: "true" // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
||||
wrapCodeMsg: true // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
||||
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 全局业务错误码枚举描述,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||
// securityDefinitionsFromJson 为自定义鉴权配置,json 内容将直接放入 swagger 的 securityDefinitions 中,
|
||||
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object
|
||||
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
|
||||
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey 类型鉴权自定义","type":"apiKey","name":"x-api-key","in":"header"}}`
|
||||
useDefinitions: true// 开启声明将生成models 进行关联,definitions 仅对响应体和 json 请求体生效
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -48,11 +49,12 @@ type (
|
||||
summary: "query 类型接口集合" // 对应 swagger 的 summary
|
||||
prefix: v1
|
||||
authType: apiKey // 指定该路由使用的鉴权类型,值为 securityDefinitionsFromJson 中定义的名称
|
||||
group:"demo"
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "query 接口"
|
||||
bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述,会覆盖全局的业务错误码,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||
bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述,会覆盖全局的业务错误码,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 且 useDefinitions 为 false 时生效
|
||||
)
|
||||
@handler query
|
||||
get /query (QueryReq) returns (QueryResp)
|
||||
|
||||
5608
tools/goctl/api/swagger/example/example_cn.swagger.json
Normal file
5608
tools/goctl/api/swagger/example/example_cn.swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,21 +81,21 @@ func enumsValueFromOptions(options []string) []any {
|
||||
return []any{}
|
||||
}
|
||||
|
||||
func defValueFromOptions(options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(apiType)
|
||||
return valueFromOptions(options, defFlag, tp)
|
||||
func defValueFromOptions(ctx Context, options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(ctx, apiType)
|
||||
return valueFromOptions(ctx, options, defFlag, tp)
|
||||
}
|
||||
|
||||
func exampleValueFromOptions(options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(apiType)
|
||||
val := valueFromOptions(options, exampleFlag, tp)
|
||||
func exampleValueFromOptions(ctx Context, options []string, apiType spec.Type) any {
|
||||
tp := sampleTypeFromGoType(ctx, apiType)
|
||||
val := valueFromOptions(ctx, options, exampleFlag, tp)
|
||||
if val != nil {
|
||||
return val
|
||||
}
|
||||
return defValueFromOptions(options, apiType)
|
||||
return defValueFromOptions(ctx, options, apiType)
|
||||
}
|
||||
|
||||
func valueFromOptions(options []string, key string, tp string) any {
|
||||
func valueFromOptions(_ Context, options []string, key string, tp string) any {
|
||||
if len(options) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -103,16 +103,18 @@ func valueFromOptions(options []string, key string, tp string) any {
|
||||
if strings.HasPrefix(option, key) {
|
||||
s := option[len(key):]
|
||||
switch tp {
|
||||
case "integer":
|
||||
case swaggerTypeInteger:
|
||||
val, _ := strconv.ParseInt(s, 10, 64)
|
||||
return val
|
||||
case "boolean":
|
||||
case swaggerTypeBoolean:
|
||||
val, _ := strconv.ParseBool(s)
|
||||
return val
|
||||
case "number":
|
||||
case swaggerTypeNumber:
|
||||
val, _ := strconv.ParseFloat(s, 64)
|
||||
return val
|
||||
case "string":
|
||||
case swaggerTypeArray:
|
||||
return s
|
||||
case swaggerTypeString:
|
||||
return s
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestDefValueFromOptions(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := defValueFromOptions(tt.options, tt.apiType)
|
||||
result := defValueFromOptions(testingContext(t), tt.options, tt.apiType)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -202,7 +202,7 @@ func TestExampleValueFromOptions(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
exampleValueFromOptions(tt.options, tt.apiType)
|
||||
exampleValueFromOptions(testingContext(t), tt.options, tt.apiType)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -247,7 +247,7 @@ func TestValueFromOptions(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := valueFromOptions(tt.options, tt.key, tt.tp)
|
||||
result := valueFromOptions(testingContext(t), tt.options, tt.key, tt.tp)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,7 +8,25 @@ import (
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
func isPostJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
|
||||
if strings.EqualFold(method, http.MethodPost) {
|
||||
return "", false
|
||||
}
|
||||
structType, ok := tp.(apiSpec.DefineStruct)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
var isPostJson bool
|
||||
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
jsonTag, _ := tag.Get(tagJson)
|
||||
if !isPostJson {
|
||||
isPostJson = jsonTag != nil
|
||||
}
|
||||
})
|
||||
return structType.RawName, isPostJson
|
||||
}
|
||||
|
||||
func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Parameter {
|
||||
if tp == nil {
|
||||
return []spec.Parameter{}
|
||||
}
|
||||
@@ -16,12 +34,13 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
if !ok {
|
||||
return []spec.Parameter{}
|
||||
}
|
||||
|
||||
var (
|
||||
resp []spec.Parameter
|
||||
properties = map[string]spec.Schema{}
|
||||
requiredFields []string
|
||||
)
|
||||
rangeMemberAndDo(structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
headerTag, _ := tag.Get(tagHeader)
|
||||
hasHeader := headerTag != nil
|
||||
|
||||
@@ -44,10 +63,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(headerTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(headerTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(headerTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, headerTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInHeader,
|
||||
@@ -68,10 +86,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(pathParameterTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(pathParameterTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(pathParameterTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, pathParameterTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInPath,
|
||||
@@ -93,10 +110,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(formTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(formTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInQuery,
|
||||
@@ -116,10 +132,9 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
Enum: enumsValueFromOptions(formTag.Options),
|
||||
},
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(formTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(member.Type),
|
||||
Type: sampleTypeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, formTag.Options, member.Type),
|
||||
Items: sampleItemsFromGoType(ctx, member.Type),
|
||||
},
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInForm,
|
||||
@@ -139,25 +154,25 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
}
|
||||
var schema = spec.Schema{
|
||||
SwaggerSchemaProps: spec.SwaggerSchemaProps{
|
||||
Example: exampleValueFromOptions(jsonTag.Options, member.Type),
|
||||
Example: exampleValueFromOptions(ctx, jsonTag.Options, member.Type),
|
||||
},
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Description: formatComment(member.Comment),
|
||||
Type: typeFromGoType(member.Type),
|
||||
Default: defValueFromOptions(jsonTag.Options, member.Type),
|
||||
Type: typeFromGoType(ctx, member.Type),
|
||||
Default: defValueFromOptions(ctx, jsonTag.Options, member.Type),
|
||||
Maximum: maximum,
|
||||
ExclusiveMaximum: exclusiveMaximum,
|
||||
Minimum: minimum,
|
||||
ExclusiveMinimum: exclusiveMinimum,
|
||||
Enum: enumsValueFromOptions(jsonTag.Options),
|
||||
AdditionalProperties: mapFromGoType(member.Type),
|
||||
AdditionalProperties: mapFromGoType(ctx, member.Type),
|
||||
},
|
||||
}
|
||||
switch sampleTypeFromGoType(member.Type) {
|
||||
switch sampleTypeFromGoType(ctx, member.Type) {
|
||||
case swaggerTypeArray:
|
||||
schema.Items = itemsFromGoType(member.Type)
|
||||
schema.Items = itemsFromGoType(ctx, member.Type)
|
||||
case swaggerTypeObject:
|
||||
p, r := propertiesFromType(member.Type)
|
||||
p, r := propertiesFromType(ctx, member.Type)
|
||||
schema.Properties = p
|
||||
schema.Required = r
|
||||
}
|
||||
@@ -165,20 +180,38 @@ func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
|
||||
}
|
||||
})
|
||||
if len(properties) > 0 {
|
||||
resp = append(resp, spec.Parameter{
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInBody,
|
||||
Name: paramsInBody,
|
||||
Required: true,
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(structType),
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
if ctx.UseDefinitions {
|
||||
structName, ok := isPostJson(ctx, method, tp)
|
||||
if ok {
|
||||
resp = append(resp, spec.Parameter{
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInBody,
|
||||
Name: paramsInBody,
|
||||
Required: true,
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Ref: spec.MustCreateRef(getRefName(structName)),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
resp = append(resp, spec.Parameter{
|
||||
ParamProps: spec.ParamProps{
|
||||
In: paramsInBody,
|
||||
Name: paramsInBody,
|
||||
Required: true,
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(ctx, structType),
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
@@ -7,20 +7,21 @@ import (
|
||||
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/stringx"
|
||||
)
|
||||
|
||||
func spec2Paths(info apiSpec.Info, srv apiSpec.Service) *spec.Paths {
|
||||
func spec2Paths(ctx Context, srv apiSpec.Service) *spec.Paths {
|
||||
paths := &spec.Paths{
|
||||
Paths: make(map[string]spec.PathItem),
|
||||
}
|
||||
for _, group := range srv.Groups {
|
||||
prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation("prefix"), "/"))
|
||||
prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation(propertyKeyPrefix), "/"))
|
||||
for _, route := range group.Routes {
|
||||
routPath := pathVariable2SwaggerVariable(route.Path)
|
||||
routPath := pathVariable2SwaggerVariable(ctx, route.Path)
|
||||
if len(prefix) > 0 && prefix != "." {
|
||||
routPath = "/" + path.Clean(prefix) + routPath
|
||||
}
|
||||
pathItem := spec2Path(info, group, route)
|
||||
pathItem := spec2Path(ctx, group, route)
|
||||
existPathItem, ok := paths.Paths[routPath]
|
||||
if !ok {
|
||||
paths.Paths[routPath] = pathItem
|
||||
@@ -60,8 +61,8 @@ func mergePathItem(old, new spec.PathItem) spec.PathItem {
|
||||
return old
|
||||
}
|
||||
|
||||
func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec.PathItem {
|
||||
authType := getStringFromKVOrDefault(group.Annotation.Properties, "authType", "")
|
||||
func spec2Path(ctx Context, group apiSpec.Group, route apiSpec.Route) spec.PathItem {
|
||||
authType := getStringFromKVOrDefault(group.Annotation.Properties, propertyKeyAuthType, "")
|
||||
var security []map[string][]string
|
||||
if len(authType) > 0 {
|
||||
security = []map[string][]string{
|
||||
@@ -70,22 +71,29 @@ func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec
|
||||
},
|
||||
}
|
||||
}
|
||||
groupName := getStringFromKVOrDefault(group.Annotation.Properties, propertyKeyGroup, "")
|
||||
operationId := route.Handler
|
||||
if len(groupName) > 0 {
|
||||
operationId = stringx.From(groupName + "_" + route.Handler).ToCamel()
|
||||
}
|
||||
operationId = stringx.From(operationId).Untitle()
|
||||
op := &spec.Operation{
|
||||
OperationProps: spec.OperationProps{
|
||||
Description: getStringFromKVOrDefault(route.AtDoc.Properties, "description", ""),
|
||||
Consumes: consumesFromTypeOrDef(route.Method, route.RequestType),
|
||||
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, "produces", []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, "schemes", []string{schemeHttps}),
|
||||
Tags: getListFromInfoOrDefault(group.Annotation.Properties, "tags", []string{""}),
|
||||
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, "summary", getFirstUsableString(route.AtDoc.Text, route.Handler)),
|
||||
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, "deprecated", false),
|
||||
Parameters: parametersFromType(route.Method, route.RequestType),
|
||||
Responses: jsonResponseFromType(info, route.AtDoc, route.ResponseType),
|
||||
Description: getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyDescription, ""),
|
||||
Consumes: consumesFromTypeOrDef(ctx, route.Method, route.RequestType),
|
||||
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, propertyKeyProduces, []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, propertyKeySchemes, []string{schemeHttps}),
|
||||
Tags: getListFromInfoOrDefault(group.Annotation.Properties, propertyKeyTags, getListFromInfoOrDefault(group.Annotation.Properties, propertyKeySummary, []string{})),
|
||||
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeySummary, getFirstUsableString(route.AtDoc.Text, route.Handler)),
|
||||
ID: operationId,
|
||||
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, propertyKeyDeprecated, false),
|
||||
Parameters: parametersFromType(ctx, route.Method, route.RequestType),
|
||||
Security: security,
|
||||
Responses: jsonResponseFromType(ctx, route.AtDoc, route.ResponseType),
|
||||
},
|
||||
}
|
||||
externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsDescription", "")
|
||||
externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsURL", "")
|
||||
externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyExternalDocsDescription, "")
|
||||
externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, propertyKeyExternalDocsURL, "")
|
||||
if len(externalDocsDescription) > 0 || len(externalDocsURL) > 0 {
|
||||
op.ExternalDocs = &spec.ExternalDocumentation{
|
||||
Description: externalDocsDescription,
|
||||
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
func propertiesFromType(ctx Context, tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
var (
|
||||
properties = map[string]spec.Schema{}
|
||||
requiredFields []string
|
||||
)
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PointerType:
|
||||
return propertiesFromType(val.Type)
|
||||
return propertiesFromType(ctx, val.Type)
|
||||
case apiSpec.ArrayType:
|
||||
return propertiesFromType(val.Value)
|
||||
return propertiesFromType(ctx, val.Value)
|
||||
case apiSpec.DefineStruct, apiSpec.NestedStruct:
|
||||
rangeMemberAndDo(val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
rangeMemberAndDo(ctx, val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
|
||||
var (
|
||||
jsonTagString = member.Name
|
||||
minimum, maximum *float64
|
||||
@@ -24,42 +24,63 @@ func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
example, defaultValue any
|
||||
enum []any
|
||||
)
|
||||
pathTag, _ := tag.Get(tagPath)
|
||||
if pathTag != nil {
|
||||
return
|
||||
}
|
||||
formTag, _ := tag.Get(tagForm)
|
||||
if formTag != nil {
|
||||
return
|
||||
}
|
||||
headerTag, _ := tag.Get(tagHeader)
|
||||
if headerTag != nil {
|
||||
return
|
||||
}
|
||||
|
||||
jsonTag, _ := tag.Get(tagJson)
|
||||
if jsonTag != nil {
|
||||
jsonTagString = jsonTag.Name
|
||||
minimum, maximum, exclusiveMinimum, exclusiveMaximum = rangeValueFromOptions(jsonTag.Options)
|
||||
example = exampleValueFromOptions(jsonTag.Options, member.Type)
|
||||
defaultValue = defValueFromOptions(jsonTag.Options, member.Type)
|
||||
example = exampleValueFromOptions(ctx, jsonTag.Options, member.Type)
|
||||
defaultValue = defValueFromOptions(ctx, jsonTag.Options, member.Type)
|
||||
enum = enumsValueFromOptions(jsonTag.Options)
|
||||
}
|
||||
|
||||
if required {
|
||||
requiredFields = append(requiredFields, jsonTagString)
|
||||
}
|
||||
var schema = spec.Schema{
|
||||
|
||||
schema := spec.Schema{
|
||||
SwaggerSchemaProps: spec.SwaggerSchemaProps{
|
||||
Example: example,
|
||||
},
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Description: formatComment(member.Comment),
|
||||
Type: typeFromGoType(member.Type),
|
||||
Type: typeFromGoType(ctx, member.Type),
|
||||
Default: defaultValue,
|
||||
Maximum: maximum,
|
||||
ExclusiveMaximum: exclusiveMaximum,
|
||||
Minimum: minimum,
|
||||
ExclusiveMinimum: exclusiveMinimum,
|
||||
Enum: enum,
|
||||
AdditionalProperties: mapFromGoType(member.Type),
|
||||
AdditionalProperties: mapFromGoType(ctx, member.Type),
|
||||
},
|
||||
}
|
||||
switch sampleTypeFromGoType(member.Type) {
|
||||
|
||||
switch sampleTypeFromGoType(ctx, member.Type) {
|
||||
case swaggerTypeArray:
|
||||
schema.Items = itemsFromGoType(member.Type)
|
||||
schema.Items = itemsFromGoType(ctx, member.Type)
|
||||
case swaggerTypeObject:
|
||||
p, r := propertiesFromType(member.Type)
|
||||
p, r := propertiesFromType(ctx, member.Type)
|
||||
schema.Properties = p
|
||||
schema.Required = r
|
||||
}
|
||||
if ctx.UseDefinitions {
|
||||
structName, containsStruct := containsStruct(member.Type)
|
||||
if containsStruct {
|
||||
schema.SchemaProps.Ref = spec.MustCreateRef(getRefName(structName))
|
||||
}
|
||||
}
|
||||
|
||||
properties[jsonTagString] = schema
|
||||
})
|
||||
@@ -67,3 +88,22 @@ func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
|
||||
|
||||
return properties, requiredFields
|
||||
}
|
||||
|
||||
func containsStruct(tp apiSpec.Type) (string, bool) {
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PointerType:
|
||||
return containsStruct(val.Type)
|
||||
case apiSpec.ArrayType:
|
||||
return containsStruct(val.Value)
|
||||
case apiSpec.DefineStruct:
|
||||
return val.RawName, true
|
||||
case apiSpec.MapType:
|
||||
return containsStruct(val.Value)
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func getRefName(typeName string) string {
|
||||
return "#/definitions/" + typeName
|
||||
}
|
||||
|
||||
@@ -1,25 +1,62 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func jsonResponseFromType(info apiSpec.Info, atDoc apiSpec.AtDoc, tp apiSpec.Type) *spec.Responses {
|
||||
p, _ := propertiesFromType(tp)
|
||||
func jsonResponseFromType(ctx Context, atDoc apiSpec.AtDoc, tp apiSpec.Type) *spec.Responses {
|
||||
if tp == nil {
|
||||
return &spec.Responses{
|
||||
ResponsesProps: spec.ResponsesProps{
|
||||
StatusCodeResponses: map[int]spec.Response{
|
||||
http.StatusOK: {
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Description: "",
|
||||
Schema: &spec.Schema{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
props := spec.SchemaProps{
|
||||
Type: typeFromGoType(tp),
|
||||
Properties: p,
|
||||
AdditionalProperties: mapFromGoType(tp),
|
||||
Items: itemsFromGoType(tp),
|
||||
AdditionalProperties: mapFromGoType(ctx, tp),
|
||||
Items: itemsFromGoType(ctx, tp),
|
||||
}
|
||||
if ctx.UseDefinitions {
|
||||
structName, ok := containsStruct(tp)
|
||||
if ok {
|
||||
props.Ref = spec.MustCreateRef(getRefName(structName))
|
||||
return &spec.Responses{
|
||||
ResponsesProps: spec.ResponsesProps{
|
||||
StatusCodeResponses: map[int]spec.Response{
|
||||
http.StatusOK: {
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: wrapCodeMsgProps(ctx, props, atDoc),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p, _ := propertiesFromType(ctx, tp)
|
||||
props.Type = typeFromGoType(ctx, tp)
|
||||
props.Properties = p
|
||||
return &spec.Responses{
|
||||
ResponsesProps: spec.ResponsesProps{
|
||||
Default: &spec.Response{
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: wrapCodeMsgProps(props, info, atDoc),
|
||||
StatusCodeResponses: map[int]spec.Response{
|
||||
http.StatusOK: {
|
||||
ResponseProps: spec.ResponseProps{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: wrapCodeMsgProps(ctx, props, atDoc),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -8,12 +8,11 @@ import (
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
|
||||
ctx := contextFromApi(api.Info)
|
||||
extensions, info := specExtensions(api.Info)
|
||||
|
||||
var securityDefinitions spec.SecurityDefinitions
|
||||
securityDefinitionsFromJson := getStringFromKVOrDefault(api.Info.Properties, "securityDefinitionsFromJson", `{}`)
|
||||
_ = json.Unmarshal([]byte(securityDefinitionsFromJson), &securityDefinitions)
|
||||
@@ -22,14 +21,15 @@ func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
|
||||
Extensions: extensions,
|
||||
},
|
||||
SwaggerProps: spec.SwaggerProps{
|
||||
Consumes: getListFromInfoOrDefault(api.Info.Properties, "consumes", []string{applicationJson}),
|
||||
Produces: getListFromInfoOrDefault(api.Info.Properties, "produces", []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(api.Info.Properties, "schemes", []string{schemeHttps}),
|
||||
Definitions: definitionsFromTypes(ctx, api.Types),
|
||||
Consumes: getListFromInfoOrDefault(api.Info.Properties, propertyKeyConsumes, []string{applicationJson}),
|
||||
Produces: getListFromInfoOrDefault(api.Info.Properties, propertyKeyProduces, []string{applicationJson}),
|
||||
Schemes: getListFromInfoOrDefault(api.Info.Properties, propertyKeySchemes, []string{schemeHttps}),
|
||||
Swagger: swaggerVersion,
|
||||
Info: info,
|
||||
Host: getStringFromKVOrDefault(api.Info.Properties, "host", defaultHost),
|
||||
BasePath: getStringFromKVOrDefault(api.Info.Properties, "basePath", defaultBasePath),
|
||||
Paths: spec2Paths(api.Info, api.Service),
|
||||
Host: getStringFromKVOrDefault(api.Info.Properties, propertyKeyHost, ""),
|
||||
BasePath: getStringFromKVOrDefault(api.Info.Properties, propertyKeyBasePath, defaultBasePath),
|
||||
Paths: spec2Paths(ctx, api.Service),
|
||||
SecurityDefinitions: securityDefinitions,
|
||||
},
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func formatComment(comment string) string {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
|
||||
func sampleItemsFromGoType(ctx Context, tp apiSpec.Type) *spec.Items {
|
||||
val, ok := tp.(apiSpec.ArrayType)
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -52,14 +52,14 @@ func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
|
||||
case apiSpec.PrimitiveType:
|
||||
return &spec.Items{
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(item),
|
||||
Type: sampleTypeFromGoType(ctx, item),
|
||||
},
|
||||
}
|
||||
case apiSpec.ArrayType:
|
||||
return &spec.Items{
|
||||
SimpleSchema: spec.SimpleSchema{
|
||||
Type: sampleTypeFromGoType(item),
|
||||
Items: sampleItemsFromGoType(item),
|
||||
Type: sampleTypeFromGoType(ctx, item),
|
||||
Items: sampleItemsFromGoType(ctx, item),
|
||||
},
|
||||
}
|
||||
default: // unsupported type
|
||||
@@ -68,30 +68,30 @@ func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
|
||||
}
|
||||
|
||||
// itemsFromGoType returns the schema or array of the type, just for non json body parameters.
|
||||
func itemsFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
func itemsFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
array, ok := tp.(apiSpec.ArrayType)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return itemFromGoType(array.Value)
|
||||
return itemFromGoType(ctx, array.Value)
|
||||
}
|
||||
|
||||
func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool {
|
||||
func mapFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrBool {
|
||||
mapType, ok := tp.(apiSpec.MapType)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var schema = &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(mapType.Value),
|
||||
AdditionalProperties: mapFromGoType(mapType.Value),
|
||||
Type: typeFromGoType(ctx, mapType.Value),
|
||||
AdditionalProperties: mapFromGoType(ctx, mapType.Value),
|
||||
},
|
||||
}
|
||||
switch sampleTypeFromGoType(mapType.Value) {
|
||||
switch sampleTypeFromGoType(ctx, mapType.Value) {
|
||||
case swaggerTypeArray:
|
||||
schema.Items = itemsFromGoType(mapType.Value)
|
||||
schema.Items = itemsFromGoType(ctx, mapType.Value)
|
||||
case swaggerTypeObject:
|
||||
p, r := propertiesFromType(mapType.Value)
|
||||
p, r := propertiesFromType(ctx, mapType.Value)
|
||||
schema.Properties = p
|
||||
schema.Required = r
|
||||
}
|
||||
@@ -102,37 +102,37 @@ func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool {
|
||||
}
|
||||
|
||||
// itemFromGoType returns the schema or array of the type, just for non json body parameters.
|
||||
func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
func itemFromGoType(ctx Context, tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
switch itemType := tp.(type) {
|
||||
case apiSpec.PrimitiveType:
|
||||
return &spec.SchemaOrArray{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(tp),
|
||||
Type: typeFromGoType(ctx, tp),
|
||||
},
|
||||
},
|
||||
}
|
||||
case apiSpec.DefineStruct, apiSpec.NestedStruct:
|
||||
properties, requiredFields := propertiesFromType(itemType)
|
||||
case apiSpec.DefineStruct, apiSpec.NestedStruct, apiSpec.MapType:
|
||||
properties, requiredFields := propertiesFromType(ctx, itemType)
|
||||
return &spec.SchemaOrArray{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(itemType),
|
||||
Items: itemsFromGoType(itemType),
|
||||
Type: typeFromGoType(ctx, itemType),
|
||||
Items: itemsFromGoType(ctx, itemType),
|
||||
Properties: properties,
|
||||
Required: requiredFields,
|
||||
AdditionalProperties: mapFromGoType(itemType),
|
||||
AdditionalProperties: mapFromGoType(ctx, itemType),
|
||||
},
|
||||
},
|
||||
}
|
||||
case apiSpec.PointerType:
|
||||
return itemFromGoType(itemType.Type)
|
||||
return itemFromGoType(ctx, itemType.Type)
|
||||
case apiSpec.ArrayType:
|
||||
return &spec.SchemaOrArray{
|
||||
Schema: &spec.Schema{
|
||||
SchemaProps: spec.SchemaProps{
|
||||
Type: typeFromGoType(itemType),
|
||||
Items: itemsFromGoType(itemType),
|
||||
Type: typeFromGoType(ctx, itemType),
|
||||
Items: itemsFromGoType(ctx, itemType),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
func typeFromGoType(tp apiSpec.Type) []string {
|
||||
func typeFromGoType(ctx Context, tp apiSpec.Type) []string {
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PrimitiveType:
|
||||
res, ok := tpMapper[val.RawName]
|
||||
@@ -152,12 +152,12 @@ func typeFromGoType(tp apiSpec.Type) []string {
|
||||
case apiSpec.DefineStruct, apiSpec.MapType:
|
||||
return []string{swaggerTypeObject}
|
||||
case apiSpec.PointerType:
|
||||
return typeFromGoType(val.Type)
|
||||
return typeFromGoType(ctx, val.Type)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sampleTypeFromGoType(tp apiSpec.Type) string {
|
||||
func sampleTypeFromGoType(ctx Context, tp apiSpec.Type) string {
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.PrimitiveType:
|
||||
return tpMapper[val.RawName]
|
||||
@@ -166,31 +166,30 @@ func sampleTypeFromGoType(tp apiSpec.Type) string {
|
||||
case apiSpec.DefineStruct, apiSpec.MapType, apiSpec.NestedStruct:
|
||||
return swaggerTypeObject
|
||||
case apiSpec.PointerType:
|
||||
return sampleTypeFromGoType(val.Type)
|
||||
return sampleTypeFromGoType(ctx, val.Type)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func typeContainsTag(structType apiSpec.DefineStruct, tag string) bool {
|
||||
for _, field := range structType.Members {
|
||||
tags, _ := apiSpec.Parse(field.Tag)
|
||||
for _, t := range tags.Tags() {
|
||||
if t.Key == tag {
|
||||
return true
|
||||
}
|
||||
func typeContainsTag(ctx Context, structType apiSpec.DefineStruct, tag string) bool {
|
||||
members := expandMembers(ctx, structType)
|
||||
for _, member := range members {
|
||||
tags, _ := apiSpec.Parse(member.Tag)
|
||||
if _, err := tags.Get(tag); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func expandMembers(tp apiSpec.Type) []apiSpec.Member {
|
||||
func expandMembers(ctx Context, tp apiSpec.Type) []apiSpec.Member {
|
||||
var members []apiSpec.Member
|
||||
switch val := tp.(type) {
|
||||
case apiSpec.DefineStruct:
|
||||
for _, v := range val.Members {
|
||||
if v.IsInline {
|
||||
members = append(members, expandMembers(v.Type)...)
|
||||
members = append(members, expandMembers(ctx, v.Type)...)
|
||||
continue
|
||||
}
|
||||
members = append(members, v)
|
||||
@@ -198,7 +197,7 @@ func expandMembers(tp apiSpec.Type) []apiSpec.Member {
|
||||
case apiSpec.NestedStruct:
|
||||
for _, v := range val.Members {
|
||||
if v.IsInline {
|
||||
members = append(members, expandMembers(v.Type)...)
|
||||
members = append(members, expandMembers(ctx, v.Type)...)
|
||||
continue
|
||||
}
|
||||
members = append(members, v)
|
||||
@@ -208,42 +207,42 @@ func expandMembers(tp apiSpec.Type) []apiSpec.Member {
|
||||
return members
|
||||
}
|
||||
|
||||
func rangeMemberAndDo(structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) {
|
||||
var members = expandMembers(structType)
|
||||
func rangeMemberAndDo(ctx Context, structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) {
|
||||
var members = expandMembers(ctx, structType)
|
||||
|
||||
for _, field := range members {
|
||||
tags, _ := apiSpec.Parse(field.Tag)
|
||||
required := isRequired(tags)
|
||||
required := isRequired(ctx, tags)
|
||||
do(tags, required, field)
|
||||
}
|
||||
}
|
||||
|
||||
func isRequired(tags *apiSpec.Tags) bool {
|
||||
func isRequired(ctx Context, tags *apiSpec.Tags) bool {
|
||||
tag, err := tags.Get(tagJson)
|
||||
if err == nil {
|
||||
return !isOptional(tag.Options)
|
||||
return !isOptional(ctx, tag.Options)
|
||||
}
|
||||
tag, err = tags.Get(tagForm)
|
||||
if err == nil {
|
||||
return !isOptional(tag.Options)
|
||||
return !isOptional(ctx, tag.Options)
|
||||
}
|
||||
tag, err = tags.Get(tagPath)
|
||||
if err == nil {
|
||||
return !isOptional(tag.Options)
|
||||
return !isOptional(ctx, tag.Options)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isOptional(options []string) bool {
|
||||
func isOptional(_ Context, options []string) bool {
|
||||
for _, option := range options {
|
||||
if option == "optional" {
|
||||
if option == optionalFlag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func pathVariable2SwaggerVariable(path string) string {
|
||||
func pathVariable2SwaggerVariable(_ Context, path string) string {
|
||||
pathItems := strings.FieldsFunc(path, slashRune)
|
||||
var resp []string
|
||||
for _, v := range pathItems {
|
||||
@@ -256,13 +255,12 @@ func pathVariable2SwaggerVariable(path string) string {
|
||||
return "/" + strings.Join(resp, "/")
|
||||
}
|
||||
|
||||
func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info, atDoc apiSpec.AtDoc) spec.SchemaProps {
|
||||
wrapCodeMsg := getBoolFromKVOrDefault(api.Properties, "wrapCodeMsg", false)
|
||||
if !wrapCodeMsg {
|
||||
func wrapCodeMsgProps(ctx Context, properties spec.SchemaProps, atDoc apiSpec.AtDoc) spec.SchemaProps {
|
||||
if !ctx.WrapCodeMsg {
|
||||
return properties
|
||||
}
|
||||
globalCodeDesc := getStringFromKVOrDefault(api.Properties, "bizCodeEnumDescription", "business code")
|
||||
methodCodeDesc := getStringFromKVOrDefault(atDoc.Properties, "bizCodeEnumDescription", globalCodeDesc)
|
||||
globalCodeDesc := ctx.BizCodeEnumDescription
|
||||
methodCodeDesc := getStringFromKVOrDefault(atDoc.Properties, propertyKeyBizCodeEnumDescription, globalCodeDesc)
|
||||
return spec.SchemaProps{
|
||||
Type: []string{swaggerTypeObject},
|
||||
Properties: spec.SchemaProperties{
|
||||
@@ -295,27 +293,27 @@ func specExtensions(api apiSpec.Info) (spec.Extensions, *spec.Info) {
|
||||
ext := spec.Extensions{}
|
||||
ext.Add("x-goctl-version", version.BuildVersion)
|
||||
ext.Add("x-description", "This is a goctl generated swagger file.")
|
||||
ext.Add("x-date", time.Now().Format("2006-01-02 15:04:05"))
|
||||
ext.Add("x-date", time.Now().Format(time.DateTime))
|
||||
ext.Add("x-github", "https://github.com/zeromicro/go-zero")
|
||||
ext.Add("x-go-zero-doc", "https://go-zero.dev/")
|
||||
|
||||
info := &spec.Info{}
|
||||
info.Description = util.Unquote(api.Properties["description"])
|
||||
info.Title = util.Unquote(api.Properties["title"])
|
||||
info.TermsOfService = util.Unquote(api.Properties["termsOfService"])
|
||||
info.Version = util.Unquote(api.Properties["version"])
|
||||
info.Title = getStringFromKVOrDefault(api.Properties, propertyKeyTitle, "")
|
||||
info.Description = getStringFromKVOrDefault(api.Properties, propertyKeyDescription, "")
|
||||
info.TermsOfService = getStringFromKVOrDefault(api.Properties, propertyKeyTermsOfService, "")
|
||||
info.Version = getStringFromKVOrDefault(api.Properties, propertyKeyVersion, "1.0")
|
||||
|
||||
contactInfo := spec.ContactInfo{}
|
||||
contactInfo.Name = util.Unquote(api.Properties["contactName"])
|
||||
contactInfo.URL = util.Unquote(api.Properties["contactURL"])
|
||||
contactInfo.Email = util.Unquote(api.Properties["contactEmail"])
|
||||
contactInfo.Name = getStringFromKVOrDefault(api.Properties, propertyKeyContactName, "")
|
||||
contactInfo.URL = getStringFromKVOrDefault(api.Properties, propertyKeyContactURL, "")
|
||||
contactInfo.Email = getStringFromKVOrDefault(api.Properties, propertyKeyContactEmail, "")
|
||||
if len(contactInfo.Name) > 0 || len(contactInfo.URL) > 0 || len(contactInfo.Email) > 0 {
|
||||
info.Contact = &contactInfo
|
||||
}
|
||||
|
||||
license := &spec.License{}
|
||||
license.Name = util.Unquote(api.Properties["licenseName"])
|
||||
license.URL = util.Unquote(api.Properties["licenseURL"])
|
||||
license.Name = getStringFromKVOrDefault(api.Properties, propertyKeyLicenseName, "")
|
||||
license.URL = getStringFromKVOrDefault(api.Properties, propertyKeyLicenseURL, "")
|
||||
if len(license.Name) > 0 || len(license.URL) > 0 {
|
||||
info.License = license
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ func Test_pathVariable2SwaggerVariable(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := pathVariable2SwaggerVariable(tc.input)
|
||||
result := pathVariable2SwaggerVariable(testingContext(t), tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
2
tools/goctl/build.env
Normal file
2
tools/goctl/build.env
Normal file
@@ -0,0 +1,2 @@
|
||||
APP_NAME=goctl
|
||||
APP_VERSION=1.8.4-beta
|
||||
50
tools/goctl/build.sh
Normal file
50
tools/goctl/build.sh
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
source build.env
|
||||
APP_NAME=$APP_NAME
|
||||
VERSION=$APP_VERSION
|
||||
BUILD_DIR="dist"
|
||||
ZIP_DIR="${BUILD_DIR}/zips"
|
||||
|
||||
PLATFORMS=(
|
||||
"linux/amd64"
|
||||
"linux/arm64"
|
||||
"darwin/amd64"
|
||||
"darwin/arm64"
|
||||
"windows/amd64"
|
||||
"windows/arm64"
|
||||
)
|
||||
|
||||
rm -rf "${BUILD_DIR}"
|
||||
mkdir -p "${ZIP_DIR}"
|
||||
|
||||
for PLATFORM in "${PLATFORMS[@]}"; do
|
||||
GOOS=${PLATFORM%/*}
|
||||
GOARCH=${PLATFORM#*/}
|
||||
|
||||
OUTPUT="${BUILD_DIR}/${APP_NAME}-${VERSION}-${GOOS}-${GOARCH}"
|
||||
|
||||
if [ "${GOOS}" = "windows" ]; then
|
||||
OUTPUT="${OUTPUT}.exe"
|
||||
fi
|
||||
|
||||
echo "Building for ${GOOS}/${GOARCH}..."
|
||||
|
||||
env GOOS="${GOOS}" GOARCH="${GOARCH}" go build -o "${OUTPUT}" goctl.go
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error building for ${GOOS}/${GOARCH}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ZIP_OUTPUT="${ZIP_DIR}/$(basename "${OUTPUT}")"
|
||||
if [ "${GOOS}" = "windows" ]; then
|
||||
zip -j "${ZIP_OUTPUT%.exe}.zip" "${OUTPUT}"
|
||||
else
|
||||
zip -j "${ZIP_OUTPUT}.zip" "${OUTPUT}"
|
||||
fi
|
||||
|
||||
echo "Created zip: ${ZIP_OUTPUT}.zip"
|
||||
done
|
||||
|
||||
echo "All builds completed successfully. Zip files are in ${ZIP_DIR}/"
|
||||
103
tools/goctl/change.md
Normal file
103
tools/goctl/change.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# 1.8.4-beta
|
||||
|
||||
## swagger
|
||||
- [features] Supported operation id for swagger
|
||||
## Other
|
||||
- Updated version to 1.8.4-beta
|
||||
|
||||
|
||||
# 1.8.4-alpha
|
||||
|
||||
## swagger
|
||||
1. [bug fix] remove example generation when request body are `query`, `path` and `header`
|
||||
- it not supported in api spec 2.0
|
||||
- it's will generate example when request body is json format.
|
||||
2. [features] swagger generation supported definitions
|
||||
- supported response definitions
|
||||
- supported json request body definitions
|
||||
- do not support query and form definitions, use parameters instead.
|
||||
|
||||
**How to use?**
|
||||
Use the `useDefinitions` keyword in the info code block of the API file to declare the enable. This value is a boolean value. When set to `true`, it will enable the generation of definitions. Otherwise, it will be generated according to properties, and the default is `false`, for example:
|
||||
|
||||
```go
|
||||
syntax = "v1"
|
||||
|
||||
info (
|
||||
...
|
||||
wrapCodeMsg: true
|
||||
useDefinitions: true
|
||||
)
|
||||
...
|
||||
```
|
||||
|
||||
the demo result of swagger.json
|
||||
|
||||
```json
|
||||
{
|
||||
...
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"description": "1001-User not login\u003cbr\u003e1002-User permission denied",
|
||||
"type": "integer",
|
||||
"example": 0
|
||||
},
|
||||
"data": {
|
||||
"$ref": "#/definitions/FormResp"
|
||||
},
|
||||
"msg": {
|
||||
"description": "business message",
|
||||
"type": "string",
|
||||
"example": "ok"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
For a complete API example, please refer to the `api/swagger/example/example.api` file in pr. For a complete swagger result example, please refer to the `api/swagger/example/example.swagger.json` file in pr.
|
||||
|
||||
## 2. `goctl api go` code generation
|
||||
- [bug-fix] Add flag `--type-group` to control the output of types(deprecated: experimental switch control type grouping is no longer used), if true, the types in only one group will separate by file.
|
||||
- example `goctl api go --api demo.api --dir demo --type-group`
|
||||
- use `group` keyword in @server block to define the group name in api file, for example
|
||||
```go
|
||||
@server(
|
||||
group: user
|
||||
)
|
||||
service demo{
|
||||
...
|
||||
}
|
||||
```
|
||||
the example of separated types by file
|
||||
```
|
||||
.
|
||||
└── types
|
||||
├── common.go
|
||||
├── gotoolexport.go
|
||||
├── importfile.go
|
||||
├── process.go
|
||||
└── types.go
|
||||
```
|
||||
|
||||
## 3 API Parser
|
||||
- supported identifier value for info key-value in api parser
|
||||
for example
|
||||
|
||||
```
|
||||
syntax = "v1"
|
||||
|
||||
info(
|
||||
enable: true
|
||||
disable: false
|
||||
)
|
||||
...
|
||||
```
|
||||
@@ -4,7 +4,7 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/emicklei/proto v1.14.1
|
||||
github.com/emicklei/proto v1.14.2
|
||||
github.com/fatih/structtag v1.2.0
|
||||
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
|
||||
github.com/go-sql-driver/mysql v1.9.0
|
||||
@@ -16,7 +16,7 @@ require (
|
||||
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
|
||||
github.com/zeromicro/antlr v0.0.1
|
||||
github.com/zeromicro/ddl-parser v1.0.5
|
||||
github.com/zeromicro/go-zero v1.8.2
|
||||
github.com/zeromicro/go-zero v1.8.4
|
||||
golang.org/x/text v0.22.0
|
||||
google.golang.org/grpc v1.65.0
|
||||
google.golang.org/protobuf v1.36.5
|
||||
@@ -25,8 +25,7 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
||||
github.com/alicebob/miniredis/v2 v2.34.0 // indirect
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 // indirect
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
@@ -49,6 +48,8 @@ require (
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/gofuzz v1.2.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/grafana/pyroscope-go v1.2.2 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@@ -72,7 +73,7 @@ require (
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/redis/go-redis/v9 v9.7.3 // indirect
|
||||
github.com/redis/go-redis/v9 v9.10.0 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
|
||||
@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec h1:EEyRvzmpEUZ+I8WmD5cw/vY8EqhambkOqy5iFr0908A=
|
||||
github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210521184019-c5ad59b459ec/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
@@ -32,8 +30,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g=
|
||||
github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
|
||||
github.com/emicklei/proto v1.14.1 h1:fFq+Bj70XXZWXWikcVRvYZxrMS4KIIiPAqdJ8vPrenY=
|
||||
github.com/emicklei/proto v1.14.1/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
|
||||
github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I=
|
||||
github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4=
|
||||
@@ -77,6 +75,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
@@ -146,8 +148,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
|
||||
github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
@@ -183,8 +185,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
|
||||
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
|
||||
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
|
||||
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
|
||||
github.com/zeromicro/go-zero v1.8.2 h1:AbJckBoojbr1lqCN1dkvURTIHOau7yvKReEd7ZmjuCk=
|
||||
github.com/zeromicro/go-zero v1.8.2/go.mod h1:G5dF+jzCEuq0t1j8qdrtVAy30QMgctGcKSfqFIGsvSg=
|
||||
github.com/zeromicro/go-zero v1.8.4 h1:3s7kOoThCnkDoqCafsqSX58Y9osYTBIa5QEmomw07TE=
|
||||
github.com/zeromicro/go-zero v1.8.4/go.mod h1:eM5f6If/RF+jG1wSCmlvfXD2h2l23vJwETI8oDpjYt4=
|
||||
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
|
||||
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=
|
||||
|
||||
@@ -38,7 +38,8 @@
|
||||
"remote": "{{.global.remote}}",
|
||||
"branch": "{{.global.branch}}",
|
||||
"style": "{{.global.style}}",
|
||||
"test": "Generate test files"
|
||||
"test": "Generate test files",
|
||||
"type-group": "Generate type group files"
|
||||
},
|
||||
"new": {
|
||||
"short": "Fast create api service",
|
||||
@@ -205,6 +206,7 @@
|
||||
"short": "Generate mongo model",
|
||||
"type": "Specified model type name",
|
||||
"cache": "Generate code with cache [optional]",
|
||||
"prefix": "Generate code with cache prefix [optional]",
|
||||
"easy": "Generate code with auto generated CollectionName for easy declare [optional]",
|
||||
"dir": "{{.goctl.model.dir}}",
|
||||
"style": "{{.global.style}}",
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// BuildVersion is the version of goctl.
|
||||
const BuildVersion = "1.8.3"
|
||||
const BuildVersion = "1.8.4"
|
||||
|
||||
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/rpc/execx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/console"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/ctx"
|
||||
@@ -37,7 +37,7 @@ func editMod(version string, verbose bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if !stringx.Contains(latest, version) {
|
||||
if !slices.Contains(latest, version) {
|
||||
return fmt.Errorf("release version %q is not found", version)
|
||||
}
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ func init() {
|
||||
|
||||
mongoCmdFlags.StringSliceVarP(&mongo.VarStringSliceType, "type", "t")
|
||||
mongoCmdFlags.BoolVarP(&mongo.VarBoolCache, "cache", "c")
|
||||
mongoCmdFlags.StringVarP(&mongo.VarStringPrefix, "prefix", "p")
|
||||
mongoCmdFlags.BoolVarP(&mongo.VarBoolEasy, "easy", "e")
|
||||
mongoCmdFlags.StringVarP(&mongo.VarStringDir, "dir", "d")
|
||||
mongoCmdFlags.StringVar(&mongo.VarStringStyle, "style")
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
type Context struct {
|
||||
Types []string
|
||||
Cache bool
|
||||
Prefix string
|
||||
Easy bool
|
||||
Output string
|
||||
Cfg *config.Config
|
||||
@@ -60,6 +61,7 @@ func generateModel(ctx *Context) error {
|
||||
"Type": stringx.From(t).Title(),
|
||||
"lowerType": stringx.From(t).Untitle(),
|
||||
"Cache": ctx.Cache,
|
||||
"Prefix": ctx.Prefix,
|
||||
"version": version.BuildVersion,
|
||||
}, output, true); err != nil {
|
||||
return err
|
||||
|
||||
@@ -19,6 +19,8 @@ var (
|
||||
VarStringDir string
|
||||
// VarBoolCache describes whether cache is enabled.
|
||||
VarBoolCache bool
|
||||
// VarStringPrefix string describes the prefix for the cache key.
|
||||
VarStringPrefix string
|
||||
// VarBoolEasy describes whether to generate Collection Name in the code for easy declare.
|
||||
VarBoolEasy bool
|
||||
// VarStringStyle describes the style.
|
||||
@@ -35,6 +37,7 @@ var (
|
||||
func Action(_ *cobra.Command, _ []string) error {
|
||||
tp := VarStringSliceType
|
||||
c := VarBoolCache
|
||||
p := VarStringPrefix
|
||||
easy := VarBoolEasy
|
||||
o := strings.TrimSpace(VarStringDir)
|
||||
s := VarStringStyle
|
||||
@@ -74,6 +77,7 @@ func Action(_ *cobra.Command, _ []string) error {
|
||||
return generate.Do(&generate.Context{
|
||||
Types: tp,
|
||||
Cache: c,
|
||||
Prefix: p,
|
||||
Easy: easy,
|
||||
Output: a,
|
||||
Cfg: cfg,
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache:{{.lowerType}}:"{{end}}
|
||||
{{if .Cache}}var prefix{{.Type}}CacheKey = "{{if .Prefix}}{{.Prefix}}:{{end}}cache:{{.lowerType}}:"{{end}}
|
||||
|
||||
type {{.lowerType}}Model interface{
|
||||
Insert(ctx context.Context,data *{{.Type}}) error
|
||||
|
||||
@@ -1356,7 +1356,7 @@ func (p *Parser) parseKVExpression() *ast.KVExpr {
|
||||
expr.Colon = p.curTokenNode()
|
||||
|
||||
// token STRING
|
||||
if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING) {
|
||||
if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING, token.IDENT) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -130,6 +130,8 @@ func TestParser_Parse_infoStmt(t *testing.T) {
|
||||
"author": `"type author here"`,
|
||||
"email": `"type email here"`,
|
||||
"version": `"type version here"`,
|
||||
"enable": `true`,
|
||||
"disable": `false`,
|
||||
}
|
||||
p := New("foo.api", infoTestAPI)
|
||||
result := p.Parse()
|
||||
|
||||
@@ -4,4 +4,6 @@ info(
|
||||
author: "type author here"
|
||||
email: "type email here"
|
||||
version: "type version here"
|
||||
enable: true
|
||||
disable: false
|
||||
)
|
||||
@@ -10,6 +10,8 @@ info ( // info stmt
|
||||
author: "type author here"
|
||||
email: "type email here"
|
||||
version: "type version here"
|
||||
enable: true
|
||||
disable: false
|
||||
)
|
||||
|
||||
type AliasInt int
|
||||
|
||||
@@ -192,3 +192,131 @@ func TestUntitle(t *testing.T) {
|
||||
assert.Equal(t, c.want, ret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsAny(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
runes []rune
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "runes is empty",
|
||||
args: args{
|
||||
s: "test",
|
||||
runes: []rune{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "s is empty and runes is not empty",
|
||||
args: args{
|
||||
s: "",
|
||||
runes: []rune{'a', 'b', 'c'},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s contains runes",
|
||||
args: args{
|
||||
s: "hello",
|
||||
runes: []rune{'e', 'f'},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "s does not contain runes",
|
||||
args: args{
|
||||
s: "hello",
|
||||
runes: []rune{'x', 'y'},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s and runes both have one matching character",
|
||||
args: args{
|
||||
s: "a",
|
||||
runes: []rune{'a'},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "s and runes both have one non-matching character",
|
||||
args: args{
|
||||
s: "a",
|
||||
runes: []rune{'b'},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, ContainsAny(tt.args.s, tt.args.runes...), "ContainsAny(%v, %v)", tt.args.s, tt.args.runes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsWhiteSpace(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "contains space",
|
||||
args: args{s: "hello world"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains newline",
|
||||
args: args{s: "hello\nworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains tab",
|
||||
args: args{s: "hello\tworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains form feed",
|
||||
args: args{s: "hello\fworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains vertical tab",
|
||||
args: args{s: "hello\vworld"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no whitespace",
|
||||
args: args{s: "helloworld"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
args: args{s: ""},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only whitespace",
|
||||
args: args{s: " \t\n\f\v"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "contains non-standard whitespace",
|
||||
args: args{s: "hello\u00A0world"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, ContainsWhiteSpace(tt.args.s), "ContainsWhiteSpace(%v)", tt.args.s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
@@ -47,7 +46,7 @@ func TestDirectBuilder_Build(t *testing.T) {
|
||||
}, cc, resolver.BuildOptions{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
size := mathx.MinInt(test, subsetSize)
|
||||
size := min(test, subsetSize)
|
||||
assert.Equal(t, size, len(cc.state.Addresses))
|
||||
m := make(map[string]lang.PlaceholderType)
|
||||
for _, each := range cc.state.Addresses {
|
||||
|
||||
Reference in New Issue
Block a user