mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-12 01:10:00 +08:00
Compare commits
121 Commits
v1.8.0
...
tools/goct
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8364e341e1 | ||
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 | ||
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 | ||
|
|
aeceb3cfbe | ||
|
|
15ea07aad1 | ||
|
|
98bebbc74f | ||
|
|
eafd11d949 | ||
|
|
b251ce346e | ||
|
|
812140ba36 | ||
|
|
44735e949c | ||
|
|
bf313c3c56 | ||
|
|
94e7753262 | ||
|
|
9c478626d2 | ||
|
|
801c283478 | ||
|
|
2a54faf997 | ||
|
|
ecd98f3653 | ||
|
|
61641581eb | ||
|
|
6f2730d5ae | ||
|
|
0eff777b62 | ||
|
|
cafbf535f7 | ||
|
|
6edfce63e3 | ||
|
|
cdb0098b18 | ||
|
|
620c7f9693 | ||
|
|
dba444a382 | ||
|
|
b24fb3ebf7 | ||
|
|
967f0926eb | ||
|
|
e68c683df9 | ||
|
|
247985a065 | ||
|
|
80573af0d8 | ||
|
|
c0394b631a | ||
|
|
68d1aba377 | ||
|
|
3315e60272 | ||
|
|
327ef73700 | ||
|
|
eb11521655 | ||
|
|
4c37545e55 | ||
|
|
2f47c1fba4 | ||
|
|
16d54d0ace | ||
|
|
9925bcbf99 | ||
|
|
38a5ecb796 | ||
|
|
af78fc7c5f | ||
|
|
790302b486 | ||
|
|
6a0672b801 | ||
|
|
560c61612c | ||
|
|
6a988dc4a9 | ||
|
|
15842c3c7a | ||
|
|
f2914a74df | ||
|
|
f113d512e8 | ||
|
|
7a4818da59 | ||
|
|
48d0709ca6 | ||
|
|
f747585518 | ||
|
|
507ff96546 | ||
|
|
651eabb4c6 | ||
|
|
e6b4372056 | ||
|
|
24073969a1 | ||
|
|
ca797ed22c | ||
|
|
e347d3f8f8 | ||
|
|
396393b336 | ||
|
|
1f0531b254 | ||
|
|
77fb271a06 | ||
|
|
af7cf79963 | ||
|
|
7926d396d7 | ||
|
|
080cd3df84 | ||
|
|
c4e1a6a2d8 | ||
|
|
4e71e95e44 | ||
|
|
84db9bcd15 | ||
|
|
b28f79ac11 | ||
|
|
e134e77b2b | ||
|
|
f669d84ce8 | ||
|
|
9213b8ac27 |
18
.github/workflows/issue-translator.yml
vendored
18
.github/workflows/issue-translator.yml
vendored
@@ -1,18 +0,0 @@
|
||||
name: 'issue-translator'
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: usthe/issues-translate-action@v2.7
|
||||
with:
|
||||
IS_MODIFY_TITLE: true
|
||||
# not require, default false, . Decide whether to modify the issue title
|
||||
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
|
||||
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
|
||||
# not require. Customize the translation robot prefix message.
|
||||
42
.github/workflows/version-check.yml
vendored
Normal file
42
.github/workflows/version-check.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Release Version Check
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'tools/goctl/v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
version-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.21'
|
||||
|
||||
- name: Extract tag version
|
||||
id: get_version
|
||||
run: |
|
||||
# Extract version from tools/goctl/v* format
|
||||
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "Extracted version: $VERSION"
|
||||
|
||||
- name: Check version in goctl source code
|
||||
run: |
|
||||
# Change to goctl directory
|
||||
cd tools/goctl
|
||||
|
||||
# Check version in BuildVersion constant
|
||||
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
|
||||
echo "Version in code: $VERSION_IN_CODE"
|
||||
echo "Expected version: $VERSION"
|
||||
|
||||
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
|
||||
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version check passed!"
|
||||
@@ -8,16 +8,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
numHistoryReasons = 5
|
||||
timeFormat = "15:04:05"
|
||||
)
|
||||
const numHistoryReasons = 5
|
||||
|
||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
||||
|
||||
func (ew *errorWindow) add(reason string) {
|
||||
ew.lock.Lock()
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||
ew.index = (ew.index + 1) % numHistoryReasons
|
||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
||||
ew.count = min(ew.count+1, numHistoryReasons)
|
||||
ew.lock.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
||||
|
||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||
index := 41
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
|
||||
ratio := float32(transferred) / float32(requestSize)
|
||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
||||
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||
index = 13
|
||||
ratio := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||
}
|
||||
|
||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||
prefix := "localhost:"
|
||||
index := 41
|
||||
index := 13
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||
for k, v := range keys {
|
||||
newV := newKeys[k]
|
||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
||||
return keys, newKeys
|
||||
}
|
||||
|
||||
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
return float32(transferred) / float32(requestSize)
|
||||
}
|
||||
|
||||
type mockNode struct {
|
||||
addr string
|
||||
id int
|
||||
|
||||
@@ -2,7 +2,7 @@ package hash
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/spaolacci/murmur3"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
||||
}
|
||||
|
||||
// Md5Hex returns the md5 hex string of data.
|
||||
// This function is optimized for better performance than fmt.Sprintf.
|
||||
func Md5Hex(data []byte) string {
|
||||
return fmt.Sprintf("%x", Md5(data))
|
||||
return hex.EncodeToString(Md5(data))
|
||||
}
|
||||
|
||||
@@ -560,7 +560,7 @@ func shallLogStat() bool {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeDebug(val any, fields ...LogField) {
|
||||
getWriter().Debug(val, addCaller(fields...)...)
|
||||
getWriter().Debug(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeError writes v into the error log.
|
||||
@@ -568,7 +568,7 @@ func writeDebug(val any, fields ...LogField) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeError(val any, fields ...LogField) {
|
||||
getWriter().Error(val, addCaller(fields...)...)
|
||||
getWriter().Error(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeInfo writes v into info log.
|
||||
@@ -576,7 +576,7 @@ func writeError(val any, fields ...LogField) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeInfo(val any, fields ...LogField) {
|
||||
getWriter().Info(val, addCaller(fields...)...)
|
||||
getWriter().Info(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeSevere writes v into severe log.
|
||||
@@ -592,7 +592,7 @@ func writeSevere(msg string) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeSlow(val any, fields ...LogField) {
|
||||
getWriter().Slow(val, addCaller(fields...)...)
|
||||
getWriter().Slow(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeStack writes v into stack log.
|
||||
@@ -608,5 +608,5 @@ func writeStack(msg string) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeStat(msg string) {
|
||||
getWriter().Stat(msg, addCaller()...)
|
||||
getWriter().Stat(msg, mergeGlobalFields(addCaller())...)
|
||||
}
|
||||
|
||||
@@ -206,7 +206,9 @@ func (l *richLogger) WithFields(fields ...LogField) Logger {
|
||||
|
||||
func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
fields = append(l.fields, fields...)
|
||||
// caller field should always appear together with global fields
|
||||
fields = append(fields, Field(callerKey, getCaller(callerDepth+l.callerSkip)))
|
||||
fields = mergeGlobalFields(fields)
|
||||
|
||||
if l.ctx == nil {
|
||||
return fields
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dateFormat = "2006-01-02"
|
||||
hoursPerDay = 24
|
||||
bufferSize = 100
|
||||
defaultDirMode = 0o755
|
||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||
buf.WriteString(r.filename)
|
||||
buf.WriteString(r.delimiter)
|
||||
buf.WriteString(boundary)
|
||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
||||
}
|
||||
|
||||
func getNowDate() string {
|
||||
return time.Now().Format(dateFormat)
|
||||
return time.Now().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
func getNowDateInRFC3339Format() string {
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f1.Close()
|
||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
|
||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||
assert.True(t, rule.ShallRotate(0))
|
||||
}
|
||||
|
||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("no backups", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
|
||||
@@ -17,15 +17,27 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// Writer is the interface for writing logs.
|
||||
// It's designed to let users customize their own log writer,
|
||||
// such as writing logs to a kafka, a database, or using third-party loggers.
|
||||
Writer interface {
|
||||
// Alert sends an alert message, if your writer implemented alerting functionality.
|
||||
Alert(v any)
|
||||
// Close closes the writer.
|
||||
Close() error
|
||||
// Debug logs a message at debug level.
|
||||
Debug(v any, fields ...LogField)
|
||||
// Error logs a message at error level.
|
||||
Error(v any, fields ...LogField)
|
||||
// Info logs a message at info level.
|
||||
Info(v any, fields ...LogField)
|
||||
// Severe logs a message at severe level.
|
||||
Severe(v any)
|
||||
// Slow logs a message at slow level.
|
||||
Slow(v any, fields ...LogField)
|
||||
// Stack logs a message at error level.
|
||||
Stack(v any)
|
||||
// Stat logs a message at stat level.
|
||||
Stat(v any, fields ...LogField)
|
||||
}
|
||||
|
||||
@@ -324,20 +336,6 @@ func buildPlainFields(fields logEntry) []string {
|
||||
return items
|
||||
}
|
||||
|
||||
func combineGlobalFields(fields []LogField) []LogField {
|
||||
globals := globalFields.Load()
|
||||
if globals == nil {
|
||||
return fields
|
||||
}
|
||||
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func marshalJson(t interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
encoder := json.NewEncoder(&buf)
|
||||
@@ -352,6 +350,20 @@ func marshalJson(t interface{}) ([]byte, error) {
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
func mergeGlobalFields(fields []LogField) []LogField {
|
||||
globals := globalFields.Load()
|
||||
if globals == nil {
|
||||
return fields
|
||||
}
|
||||
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
if v, ok := val.(string); ok {
|
||||
@@ -362,7 +374,6 @@ func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||
}
|
||||
}
|
||||
|
||||
fields = combineGlobalFields(fields)
|
||||
// +3 for timestamp, level and content
|
||||
entry := make(logEntry, len(fields)+3)
|
||||
for _, field := range fields {
|
||||
|
||||
@@ -13,6 +13,15 @@ const (
|
||||
|
||||
// Marshal marshals the given val and returns the map that contains the fields.
|
||||
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
||||
// support anonymous field, e.g.:
|
||||
//
|
||||
// type Foo struct {
|
||||
// Token string `header:"token"`
|
||||
// }
|
||||
// type FooB struct {
|
||||
// Foo
|
||||
// Bar string `json:"bar"`
|
||||
// }
|
||||
func Marshal(val any) (map[string]map[string]any, error) {
|
||||
ret := make(map[string]map[string]any)
|
||||
tp := reflect.TypeOf(val)
|
||||
@@ -44,6 +53,16 @@ func getTag(field reflect.StructField) (string, bool) {
|
||||
return strings.TrimSpace(tag), false
|
||||
}
|
||||
|
||||
func insertValue(collector map[string]map[string]any, tag string, key string, val any) {
|
||||
if m, ok := collector[tag]; ok {
|
||||
m[key] = val
|
||||
} else {
|
||||
collector[tag] = map[string]any{
|
||||
key: val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func processMember(field reflect.StructField, value reflect.Value,
|
||||
collector map[string]map[string]any) error {
|
||||
var key string
|
||||
@@ -69,15 +88,20 @@ func processMember(field reflect.StructField, value reflect.Value,
|
||||
val = fmt.Sprint(val)
|
||||
}
|
||||
|
||||
m, ok := collector[tag]
|
||||
if ok {
|
||||
m[key] = val
|
||||
} else {
|
||||
m = map[string]any{
|
||||
key: val,
|
||||
if field.Anonymous {
|
||||
anonCollector, err := Marshal(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for anonTag, anonMap := range anonCollector {
|
||||
for anonKey, anonVal := range anonMap {
|
||||
insertValue(collector, anonTag, anonKey, anonVal)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
insertValue(collector, tag, key, val)
|
||||
}
|
||||
collector[tag] = m
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -118,7 +142,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
||||
if value.IsNil() {
|
||||
return fmt.Errorf("field %q is nil", field.Name)
|
||||
}
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
case reflect.Slice, reflect.Map:
|
||||
if value.IsNil() || value.Len() == 0 {
|
||||
return fmt.Errorf("field %q is empty", field.Name)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,124 @@ func TestMarshal(t *testing.T) {
|
||||
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
||||
}
|
||||
|
||||
func TestMarshal_Anonymous(t *testing.T) {
|
||||
t.Run("anonymous", func(t *testing.T) {
|
||||
type BaseHeader struct {
|
||||
Token string `header:"token"`
|
||||
}
|
||||
v := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "token_xxx",
|
||||
},
|
||||
}
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "kevin", m["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m["json"]["address"])
|
||||
assert.Equal(t, 20, m["json"]["age"].(int))
|
||||
assert.Equal(t, "token_xxx", m["header"]["token"])
|
||||
|
||||
v1 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
}
|
||||
m1, err1 := Marshal(v1)
|
||||
assert.Nil(t, err1)
|
||||
assert.Equal(t, "kevin", m1["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m1["json"]["address"])
|
||||
assert.Equal(t, 20, m1["json"]["age"].(int))
|
||||
|
||||
type AnotherHeader struct {
|
||||
Version string `header:"version"`
|
||||
}
|
||||
v2 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
AnotherHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "token_xxx",
|
||||
},
|
||||
AnotherHeader: AnotherHeader{
|
||||
Version: "v1.0",
|
||||
},
|
||||
}
|
||||
m2, err2 := Marshal(v2)
|
||||
assert.Nil(t, err2)
|
||||
assert.Equal(t, "kevin", m2["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m2["json"]["address"])
|
||||
assert.Equal(t, 20, m2["json"]["age"].(int))
|
||||
assert.Equal(t, "token_xxx", m2["header"]["token"])
|
||||
assert.Equal(t, "v1.0", m2["header"]["version"])
|
||||
|
||||
type PointerHeader struct {
|
||||
Ref *string `header:"ref"`
|
||||
}
|
||||
ref := "reference"
|
||||
v3 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
PointerHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
PointerHeader: PointerHeader{
|
||||
Ref: &ref,
|
||||
},
|
||||
}
|
||||
m3, err3 := Marshal(v3)
|
||||
assert.Nil(t, err3)
|
||||
assert.Equal(t, "kevin", m3["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m3["json"]["address"])
|
||||
assert.Equal(t, 20, m3["json"]["age"].(int))
|
||||
assert.Equal(t, "reference", *m3["header"]["ref"].(*string))
|
||||
})
|
||||
|
||||
t.Run("bad anonymous", func(t *testing.T) {
|
||||
type BaseHeader struct {
|
||||
Token string `json:"token,options=[a,b]"`
|
||||
}
|
||||
|
||||
v := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "c",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := Marshal(v)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMarshal_Ptr(t *testing.T) {
|
||||
v := &struct {
|
||||
Name string `path:"name"`
|
||||
@@ -344,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||
}
|
||||
|
||||
func TestMarshal_Array(t *testing.T) {
|
||||
v := struct {
|
||||
H [1]int `json:"h,string"`
|
||||
}{
|
||||
H: [1]int{1},
|
||||
}
|
||||
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -15,11 +16,9 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
comma = ","
|
||||
defaultKeyName = "key"
|
||||
delimiter = '.'
|
||||
ignoreKey = "-"
|
||||
@@ -38,7 +37,6 @@ var (
|
||||
defaultCacheLock sync.Mutex
|
||||
emptyMap = map[string]any{}
|
||||
emptyValue = reflect.ValueOf(lang.Placeholder)
|
||||
stringSliceType = reflect.TypeOf([]string{})
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -152,10 +150,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.opts.fromArray {
|
||||
refValue = makeStringSlice(refValue)
|
||||
}
|
||||
|
||||
var valid bool
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||
|
||||
@@ -900,7 +894,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
||||
valueKind.String())
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
if !slices.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
@@ -1189,35 +1183,6 @@ func join(elem ...string) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func makeStringSlice(refValue reflect.Value) reflect.Value {
|
||||
if refValue.Len() != 1 {
|
||||
return refValue
|
||||
}
|
||||
|
||||
element := refValue.Index(0)
|
||||
if element.Kind() != reflect.String {
|
||||
return refValue
|
||||
}
|
||||
|
||||
val, ok := element.Interface().(string)
|
||||
if !ok {
|
||||
return refValue
|
||||
}
|
||||
|
||||
splits := strings.Split(val, comma)
|
||||
if len(splits) <= 1 {
|
||||
return refValue
|
||||
}
|
||||
|
||||
slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits))
|
||||
for i, split := range splits {
|
||||
// allow empty strings
|
||||
slice.Index(i).Set(reflect.ValueOf(split))
|
||||
}
|
||||
|
||||
return slice
|
||||
}
|
||||
|
||||
func newInitError(name string) error {
|
||||
return fmt.Errorf("field %q is not set", name)
|
||||
}
|
||||
|
||||
@@ -1462,9 +1462,7 @@ func TestUnmarshalIntSlice(t *testing.T) {
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]int{1, 2}, v.Ages)
|
||||
}
|
||||
ast.Error(unmarshaler.Unmarshal(m, &v))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1546,7 +1544,22 @@ func TestUnmarshalStringSliceFromString(t *testing.T) {
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{"", ""}, v.Names)
|
||||
ast.ElementsMatch([]string{","}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice from valid strings with comma", func(t *testing.T) {
|
||||
var v struct {
|
||||
Names []string `key:"names"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"names": []string{"aa,bb"},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{"aa,bb"}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -634,11 +635,11 @@ func validateValueInOptions(val any, options []string) error {
|
||||
if len(options) > 0 {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if !stringx.Contains(options, v) {
|
||||
if !slices.Contains(options, v) {
|
||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||
}
|
||||
default:
|
||||
if !stringx.Contains(options, Repr(v)) {
|
||||
if !slices.Contains(options, Repr(v)) {
|
||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
package mathx
|
||||
|
||||
// MaxInt returns the larger one of a and b.
|
||||
// Deprecated: use builtin max instead.
|
||||
func MaxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return max(a, b)
|
||||
}
|
||||
|
||||
// MinInt returns the smaller one of a and b.
|
||||
// Deprecated: use builtin min instead.
|
||||
func MinInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return min(a, b)
|
||||
}
|
||||
|
||||
@@ -142,89 +142,6 @@ func MapReduceChan[T, U, V any](source <-chan T, mapper MapperFunc[T, U], reduce
|
||||
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
|
||||
}
|
||||
|
||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||
options := buildOptions(opts...)
|
||||
// output is used to write the final result
|
||||
output := make(chan V)
|
||||
defer func() {
|
||||
// reducer can only write once, if more, panic
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
// collector is used to collect data from mapper, and consume in reducer
|
||||
collector := make(chan U, options.workers)
|
||||
// if done is closed, all mappers and reducer should stop processing
|
||||
done := make(chan struct{})
|
||||
writer := newGuardedWriter(options.ctx, output, done)
|
||||
var closeOnce sync.Once
|
||||
// use atomic type to avoid data race
|
||||
var retErr errorx.AtomicError
|
||||
finish := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(done)
|
||||
close(output)
|
||||
})
|
||||
}
|
||||
cancel := once(func(err error) {
|
||||
if err != nil {
|
||||
retErr.Set(err)
|
||||
} else {
|
||||
retErr.Set(ErrCancelWithNil)
|
||||
}
|
||||
|
||||
drain(source)
|
||||
finish()
|
||||
})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(r)
|
||||
}
|
||||
finish()
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
}()
|
||||
|
||||
go executeMappers(mapperContext[T, U]{
|
||||
ctx: options.ctx,
|
||||
mapper: func(item T, w Writer[U]) {
|
||||
mapper(item, w, cancel)
|
||||
},
|
||||
source: source,
|
||||
panicChan: panicChan,
|
||||
collector: collector,
|
||||
doneChan: done,
|
||||
workers: options.workers,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-options.ctx.Done():
|
||||
cancel(context.DeadlineExceeded)
|
||||
err = context.DeadlineExceeded
|
||||
case v := <-panicChan.channel:
|
||||
// drain output here, otherwise for loop panic in defer
|
||||
drain(output)
|
||||
panic(v)
|
||||
case v, ok := <-output:
|
||||
if e := retErr.Load(); e != nil {
|
||||
err = e
|
||||
} else if ok {
|
||||
val = v
|
||||
} else {
|
||||
err = ErrReduceNoOutput
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MapReduceVoid maps all elements generated from given generate,
|
||||
// and reduce the output elements with given reducer.
|
||||
func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
|
||||
@@ -330,6 +247,89 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
||||
}
|
||||
}
|
||||
|
||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||
options := buildOptions(opts...)
|
||||
// output is used to write the final result
|
||||
output := make(chan V)
|
||||
defer func() {
|
||||
// reducer can only write once, if more, panic
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
// collector is used to collect data from mapper, and consume in reducer
|
||||
collector := make(chan U, options.workers)
|
||||
// if done is closed, all mappers and reducer should stop processing
|
||||
done := make(chan struct{})
|
||||
writer := newGuardedWriter(options.ctx, output, done)
|
||||
var closeOnce sync.Once
|
||||
// use atomic type to avoid data race
|
||||
var retErr errorx.AtomicError
|
||||
finish := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(done)
|
||||
close(output)
|
||||
})
|
||||
}
|
||||
cancel := once(func(err error) {
|
||||
if err != nil {
|
||||
retErr.Set(err)
|
||||
} else {
|
||||
retErr.Set(ErrCancelWithNil)
|
||||
}
|
||||
|
||||
drain(source)
|
||||
finish()
|
||||
})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(r)
|
||||
}
|
||||
finish()
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
}()
|
||||
|
||||
go executeMappers(mapperContext[T, U]{
|
||||
ctx: options.ctx,
|
||||
mapper: func(item T, w Writer[U]) {
|
||||
mapper(item, w, cancel)
|
||||
},
|
||||
source: source,
|
||||
panicChan: panicChan,
|
||||
collector: collector,
|
||||
doneChan: done,
|
||||
workers: options.workers,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-options.ctx.Done():
|
||||
cancel(context.DeadlineExceeded)
|
||||
err = context.DeadlineExceeded
|
||||
case v := <-panicChan.channel:
|
||||
// drain output here, otherwise for loop panic in defer
|
||||
drain(output)
|
||||
panic(v)
|
||||
case v, ok := <-output:
|
||||
if e := retErr.Load(); e != nil {
|
||||
err = e
|
||||
} else if ok {
|
||||
val = v
|
||||
} else {
|
||||
err = ErrReduceNoOutput
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func newOptions() *mapReduceOptions {
|
||||
return &mapReduceOptions{
|
||||
ctx: context.Background(),
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
@@ -28,46 +27,15 @@ type (
|
||||
|
||||
const flushInterval = 5 * time.Minute
|
||||
|
||||
var (
|
||||
pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
updated := func() bool {
|
||||
pc.lock.RLock()
|
||||
defer pc.lock.RUnlock()
|
||||
|
||||
slot, ok := pc.slots[name]
|
||||
if ok {
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
return ok
|
||||
}()
|
||||
|
||||
if !updated {
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
pc.slots[name] = &profileSlot{
|
||||
lifecount: 1,
|
||||
lastcount: 1,
|
||||
lifecycle: int64(duration),
|
||||
lastcycle: int64(duration),
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
once.Do(flushRepeatly)
|
||||
var pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
|
||||
func flushRepeatly() {
|
||||
func init() {
|
||||
flushRepeatedly()
|
||||
}
|
||||
|
||||
func flushRepeatedly() {
|
||||
threading.GoSafe(func() {
|
||||
for {
|
||||
time.Sleep(flushInterval)
|
||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
||||
})
|
||||
}
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
slot := loadOrStoreSlot(name, duration)
|
||||
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
|
||||
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||
pc.lock.RLock()
|
||||
slot, ok := pc.slots[name]
|
||||
pc.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
// double-check
|
||||
if slot, ok = pc.slots[name]; ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
slot = &profileSlot{}
|
||||
pc.slots[name] = slot
|
||||
return slot
|
||||
}
|
||||
|
||||
func generateReport() string {
|
||||
var buffer bytes.Buffer
|
||||
buffer.WriteString("Profiling report\n")
|
||||
var data [][]string
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Profiling report\n")
|
||||
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||
|
||||
calcFn := func(total, count int64) string {
|
||||
if count == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
return (time.Duration(total) / time.Duration(count)).String()
|
||||
}
|
||||
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
pc.lock.Lock()
|
||||
for key, slot := range pc.slots {
|
||||
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||
key,
|
||||
slot.lifecount,
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
slot.lastcount,
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
))
|
||||
|
||||
for key, slot := range pc.slots {
|
||||
data = append(data, []string{
|
||||
key,
|
||||
strconv.FormatInt(slot.lifecount, 10),
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
strconv.FormatInt(slot.lastcount, 10),
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
})
|
||||
// reset last cycle stats
|
||||
atomic.StoreInt64(&slot.lastcount, 0)
|
||||
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||
}
|
||||
pc.lock.Unlock()
|
||||
|
||||
// reset the data for last cycle
|
||||
slot.lastcount = 0
|
||||
slot.lastcycle = 0
|
||||
}
|
||||
}()
|
||||
|
||||
table := tablewriter.NewWriter(&buffer)
|
||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
||||
table.SetBorder(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return buffer.String()
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
once.Do(func() {})
|
||||
assert.NotContains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
assert.Contains(t, generateReport(), "foo")
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"github.com/zeromicro/go-zero/internal/devserver"
|
||||
"github.com/zeromicro/go-zero/internal/profiling"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +39,8 @@ type (
|
||||
Telemetry trace.Config `json:",optional"`
|
||||
DevServer DevServerConfig `json:",optional"`
|
||||
Shutdown proc.ShutdownConf `json:",optional"`
|
||||
// Profiling is the configuration for continuous profiling.
|
||||
Profiling profiling.Config `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
|
||||
devserver.StartAgent(sc.DevServer)
|
||||
profiling.Start(sc.Profiling)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ type (
|
||||
// NewServiceGroup returns a ServiceGroup.
|
||||
func NewServiceGroup() *ServiceGroup {
|
||||
sg := new(ServiceGroup)
|
||||
sg.stopOnce = syncx.Once(sg.doStop)
|
||||
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||
return sg
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
const (
|
||||
clusterNameKey = "CLUSTER_NAME"
|
||||
testEnv = "test.v"
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
||||
if fn != nil {
|
||||
reported := lessExecutor.DoOrDiscard(func() {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||
if len(clusterName) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||
}
|
||||
|
||||
@@ -609,6 +609,28 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (int, e
|
||||
return int(v), nil
|
||||
}
|
||||
|
||||
// GetDel is the implementation of redis getdel command.
|
||||
// Available since: redis version 6.2.0
|
||||
func (s *Redis) GetDel(key string) (string, error) {
|
||||
return s.GetDelCtx(context.Background(), key)
|
||||
}
|
||||
|
||||
// GetDelCtx is the implementation of redis getdel command.
|
||||
// Available since: redis version 6.2.0
|
||||
func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
val, err := conn.GetDel(ctx, key).Result()
|
||||
if errors.Is(err, red.Nil) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
// GetSet is the implementation of redis getset command.
|
||||
func (s *Redis) GetSet(key, value string) (string, error) {
|
||||
return s.GetSetCtx(context.Background(), key, value)
|
||||
|
||||
@@ -1071,6 +1071,34 @@ func TestRedis_Set(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetDel(t *testing.T) {
|
||||
t.Run("get_del", func(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
val, err := newRedis(client.Addr).GetDel("hello")
|
||||
assert.Equal(t, "", val)
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.GetDel("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", val)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("get_del_with_error", func(t *testing.T) {
|
||||
runOnRedisWithError(t, func(client *Redis) {
|
||||
_, err := newRedis(client.Addr, badType()).GetDel("hello")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetSet(t *testing.T) {
|
||||
t.Run("set_get", func(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
|
||||
@@ -21,6 +21,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||
case NodeType:
|
||||
client := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
@@ -32,6 +33,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||
case ClusterType:
|
||||
client := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
|
||||
@@ -31,6 +31,7 @@ func getClient(r *Redis) (*red.Client, error) {
|
||||
}
|
||||
store := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
|
||||
@@ -28,6 +28,7 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
|
||||
}
|
||||
store := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
|
||||
@@ -267,6 +267,20 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Name string
|
||||
age int
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
@@ -310,6 +324,20 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value struct {
|
||||
Age *int `db:"age"`
|
||||
@@ -1307,25 +1335,25 @@ func TestAnonymousStructPr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAnonymousStructPrError(t *testing.T) {
|
||||
type Score struct {
|
||||
Discipline string `db:"discipline"`
|
||||
score uint `db:"score"`
|
||||
}
|
||||
type ClassType struct {
|
||||
Grade sql.NullString `db:"grade"`
|
||||
ClassName *string `db:"class_name"`
|
||||
}
|
||||
type Class struct {
|
||||
*ClassType
|
||||
Score
|
||||
}
|
||||
var value []*struct {
|
||||
Age int64 `db:"age"`
|
||||
Class
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
type Score struct {
|
||||
Discipline string `db:"discipline"`
|
||||
score uint `db:"score"`
|
||||
}
|
||||
type ClassType struct {
|
||||
Grade sql.NullString `db:"grade"`
|
||||
ClassName *string `db:"class_name"`
|
||||
}
|
||||
type Class struct {
|
||||
*ClassType
|
||||
Score
|
||||
}
|
||||
|
||||
var value []*struct {
|
||||
Age int64 `db:"age"`
|
||||
Class
|
||||
Name string `db:"name"`
|
||||
}
|
||||
rs := sqlmock.NewRows([]string{
|
||||
"name",
|
||||
"age",
|
||||
@@ -1338,10 +1366,50 @@ func TestAnonymousStructPrError(t *testing.T) {
|
||||
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||
"anyone"))
|
||||
"anyone"), ErrNotReadableValue)
|
||||
if len(value) > 0 {
|
||||
assert.Equal(t, value[0].score, 0)
|
||||
}
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
type Score struct {
|
||||
Discipline string
|
||||
score uint
|
||||
}
|
||||
type ClassType struct {
|
||||
Grade sql.NullString
|
||||
ClassName *string
|
||||
}
|
||||
type Class struct {
|
||||
*ClassType
|
||||
Score
|
||||
}
|
||||
|
||||
var value []*struct {
|
||||
Age int64
|
||||
Class
|
||||
Name string
|
||||
}
|
||||
rs := sqlmock.NewRows([]string{
|
||||
"name",
|
||||
"age",
|
||||
"grade",
|
||||
"discipline",
|
||||
"class_name",
|
||||
"score",
|
||||
}).
|
||||
AddRow("first", 2, nil, "math", "experimental class", 100).
|
||||
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||
"anyone"), ErrNotMatchDestination)
|
||||
if len(value) > 0 {
|
||||
assert.Equal(t, value[0].score, 0)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package stringx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"unicode"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -15,14 +16,9 @@ var (
|
||||
)
|
||||
|
||||
// Contains checks if str is in list.
|
||||
// Deprecated: use slices.Contains instead.
|
||||
func Contains(list []string, str string) bool {
|
||||
for _, each := range list {
|
||||
if each == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return slices.Contains(list, str)
|
||||
}
|
||||
|
||||
// Filter filters chars from s with given filter function.
|
||||
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
|
||||
// Reverse reverses s.
|
||||
func Reverse(s string) string {
|
||||
runes := []rune(s)
|
||||
|
||||
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
|
||||
runes[from], runes[to] = runes[to], runes[from]
|
||||
}
|
||||
|
||||
slices.Reverse(runes)
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,28 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEmpty(t *testing.T) {
|
||||
cases := []struct {
|
||||
args []string
|
||||
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
|
||||
@@ -3,9 +3,7 @@ package syncx
|
||||
import "sync"
|
||||
|
||||
// Once returns a func that guarantees fn can only called once.
|
||||
// Deprecated: use sync.OnceFunc instead.
|
||||
func Once(fn func()) func() {
|
||||
once := new(sync.Once)
|
||||
return func() {
|
||||
once.Do(fn)
|
||||
}
|
||||
return sync.OnceFunc(fn)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const factor = 10
|
||||
@@ -100,6 +101,6 @@ func (r *StableRunner[I, O]) Wait() {
|
||||
close(r.done)
|
||||
r.runner.Wait()
|
||||
for atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
|
||||
runtime.Gosched()
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
|
||||
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
||||
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
||||
ver1len, ver2len := len(ver1), len(ver2)
|
||||
shorter := mathx.MinInt(ver1len, ver2len)
|
||||
shorter := min(ver1len, ver2len)
|
||||
|
||||
for i := 0; i < shorter; i++ {
|
||||
if ver1[i] == ver2[i] {
|
||||
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if ver1len < ver2len {
|
||||
return -1
|
||||
} else if ver1len == ver2len {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
return cmp.Compare(ver1len, ver2len)
|
||||
}
|
||||
|
||||
func strsToInts(strs []string) []int64 {
|
||||
|
||||
@@ -185,6 +185,8 @@ func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
// set the timeout if it's configured, take effect only if it's greater than 0
|
||||
// and less than the deadline of the original request
|
||||
if target.Timeout > 0 {
|
||||
timeout := time.Duration(target.Timeout) * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
@@ -276,7 +278,7 @@ func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.R
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Request{
|
||||
newReq := &http.Request{
|
||||
Method: r.Method,
|
||||
URL: &u,
|
||||
Header: r.Header.Clone(),
|
||||
@@ -285,7 +287,10 @@ func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.R
|
||||
ProtoMinor: r.ProtoMinor,
|
||||
ContentLength: r.ContentLength,
|
||||
Body: io.NopCloser(r.Body),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// make sure the context is passed to the new request
|
||||
return newReq.WithContext(r.Context()), nil
|
||||
}
|
||||
|
||||
func createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.DescriptorSource, error) {
|
||||
|
||||
@@ -201,6 +201,13 @@ func TestHttpToHttp(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("method not allowed", func(t *testing.T) {
|
||||
resp, err := httpc.Do(context.Background(), http.MethodPost,
|
||||
"http://localhost:18882/api/ping", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHttpToHttpBadUpstream(t *testing.T) {
|
||||
|
||||
42
go.mod
42
go.mod
@@ -4,25 +4,25 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/alicebob/miniredis/v2 v2.34.0
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/fullstorydev/grpcurl v1.9.2
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
||||
github.com/fullstorydev/grpcurl v1.9.3
|
||||
github.com/go-sql-driver/mysql v1.9.0
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/golang/protobuf v1.5.4
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.2
|
||||
github.com/grafana/pyroscope-go v1.2.2
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/jhump/protoreflect v1.17.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/pelletier/go-toml/v2 v2.2.2
|
||||
github.com/prometheus/client_golang v1.20.5
|
||||
github.com/redis/go-redis/v9 v9.7.0
|
||||
github.com/prometheus/client_golang v1.21.1
|
||||
github.com/redis/go-redis/v9 v9.10.0
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
go.etcd.io/etcd/api/v3 v3.5.15
|
||||
go.etcd.io/etcd/client/v3 v3.5.15
|
||||
go.mongodb.org/mongo-driver v1.17.2
|
||||
go.mongodb.org/mongo-driver v1.17.4
|
||||
go.opentelemetry.io/otel v1.24.0
|
||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0
|
||||
@@ -33,12 +33,12 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.24.0
|
||||
go.uber.org/automaxprocs v1.6.0
|
||||
go.uber.org/goleak v1.3.0
|
||||
golang.org/x/net v0.34.0
|
||||
golang.org/x/sys v0.29.0
|
||||
golang.org/x/time v0.9.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/sys v0.30.0
|
||||
golang.org/x/time v0.10.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d
|
||||
google.golang.org/grpc v1.65.0
|
||||
google.golang.org/protobuf v1.36.4
|
||||
google.golang.org/protobuf v1.36.5
|
||||
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
||||
gopkg.in/h2non/gock.v1 v1.1.2
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
@@ -50,7 +50,6 @@ require (
|
||||
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bufbuild/protocompile v0.14.1 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
@@ -73,6 +72,7 @@ require (
|
||||
github.com/google/gnostic-models v0.6.8 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/gofuzz v1.2.0 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@@ -80,7 +80,7 @@ require (
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
@@ -93,7 +93,7 @@ require (
|
||||
github.com/openzipkin/zipkin-go v0.4.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.55.0 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
@@ -109,11 +109,11 @@ require (
|
||||
go.uber.org/atomic v1.10.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
go.uber.org/zap v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.32.0 // indirect
|
||||
golang.org/x/oauth2 v0.21.0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/term v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/crypto v0.33.0 // indirect
|
||||
golang.org/x/oauth2 v0.24.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/term v0.29.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
85
go.sum
85
go.sum
@@ -2,10 +2,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
||||
github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
@@ -40,8 +38,8 @@ github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU
|
||||
github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew=
|
||||
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||
github.com/fullstorydev/grpcurl v1.9.2 h1:ObqVQTZW7aFnhuqQoppUrvep2duMBanB0UYK2Mm8euo=
|
||||
github.com/fullstorydev/grpcurl v1.9.2/go.mod h1:jLfcF55HAz6TYIJY9xFFWgsl0D7o2HlxA5Z4lUG0Tdo=
|
||||
github.com/fullstorydev/grpcurl v1.9.3 h1:PC1Xi3w+JAvEE2Tg2Gf2RfVgPbf9+tbuQr1ZkyVU3jk=
|
||||
github.com/fullstorydev/grpcurl v1.9.3/go.mod h1:/b4Wxe8bG6ndAjlfSUjwseQReUDUvBJiFEB7UllOlUE=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
@@ -55,15 +53,15 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En
|
||||
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
||||
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
|
||||
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
|
||||
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
|
||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
@@ -82,6 +80,10 @@ github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJY
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/grafana/pyroscope-go v1.2.2 h1:uvKCyZMD724RkaCEMrSTC38Yn7AnFe8S2wiAIYdDPCE=
|
||||
github.com/grafana/pyroscope-go v1.2.2/go.mod h1:zzT9QXQAp2Iz2ZdS216UiV8y9uXJYQiGE1q8v1FyhqU=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8 h1:iwOtYXeeVSAeYefJNaxDytgjKtUuKQbJqgAIjlnicKg=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.8/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
@@ -90,8 +92,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
|
||||
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94=
|
||||
@@ -103,8 +105,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
|
||||
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
@@ -121,7 +123,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -135,8 +136,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
|
||||
github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
||||
github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg=
|
||||
@@ -151,16 +150,16 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||
github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y=
|
||||
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
|
||||
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc=
|
||||
github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E=
|
||||
github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw=
|
||||
github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
|
||||
github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
@@ -203,8 +202,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
|
||||
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
|
||||
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
|
||||
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
|
||||
go.mongodb.org/mongo-driver v1.17.2 h1:gvZyk8352qSfzyZ2UMWcpDpMSGEr1eqE4T793SqyhzM=
|
||||
go.mongodb.org/mongo-driver v1.17.2/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw=
|
||||
go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
|
||||
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
|
||||
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=
|
||||
@@ -241,8 +240,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
@@ -254,17 +253,17 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
|
||||
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE=
|
||||
golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -276,20 +275,20 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
|
||||
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
@@ -308,8 +307,8 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY=
|
||||
google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc=
|
||||
google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ=
|
||||
google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM=
|
||||
google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
|
||||
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
263
internal/profiling/profiling.go
Normal file
263
internal/profiling/profiling.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package profiling
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCheckInterval = time.Second * 10
|
||||
defaultProfilingDuration = time.Minute * 2
|
||||
defaultUploadRate = time.Second * 15
|
||||
)
|
||||
|
||||
type (
|
||||
Config struct {
|
||||
// Name is the name of the application.
|
||||
Name string `json:",optional,inherit"`
|
||||
// ServerAddr is the address of the profiling server.
|
||||
ServerAddr string
|
||||
// AuthUser is the username for basic authentication.
|
||||
AuthUser string `json:",optional"`
|
||||
// AuthPassword is the password for basic authentication.
|
||||
AuthPassword string `json:",optional"`
|
||||
// UploadRate is the duration for which profiling data is uploaded.
|
||||
UploadRate time.Duration `json:",default=15s"`
|
||||
// CheckInterval is the interval to check if profiling should start.
|
||||
CheckInterval time.Duration `json:",default=10s"`
|
||||
// ProfilingDuration is the duration for which profiling data is collected.
|
||||
ProfilingDuration time.Duration `json:",default=2m"`
|
||||
// CpuThreshold the collection is allowed only when the current service cpu < CpuThreshold
|
||||
CpuThreshold int64 `json:",default=700,range=[0:1000)"`
|
||||
|
||||
// ProfileType is the type of profiling to be performed.
|
||||
ProfileType ProfileType
|
||||
}
|
||||
|
||||
ProfileType struct {
|
||||
// Logger is a flag to enable or disable logging.
|
||||
Logger bool `json:",default=false"`
|
||||
// CPU is a flag to disable CPU profiling.
|
||||
CPU bool `json:",default=true"`
|
||||
// Goroutines is a flag to disable goroutine profiling.
|
||||
Goroutines bool `json:",default=true"`
|
||||
// Memory is a flag to disable memory profiling.
|
||||
Memory bool `json:",default=true"`
|
||||
// Mutex is a flag to disable mutex profiling.
|
||||
Mutex bool `json:",default=false"`
|
||||
// Block is a flag to disable block profiling.
|
||||
Block bool `json:",default=false"`
|
||||
}
|
||||
|
||||
profiler interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
pyroscopeProfiler struct {
|
||||
c Config
|
||||
profiler *pyroscope.Profiler
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
|
||||
newProfiler = func(c Config) profiler {
|
||||
return newPyroscopeProfiler(c)
|
||||
}
|
||||
)
|
||||
|
||||
// Start initializes the pyroscope profiler with the given configuration.
|
||||
func Start(c Config) {
|
||||
// check if the profiling is enabled
|
||||
if len(c.ServerAddr) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// set default values for the configuration
|
||||
if c.ProfilingDuration <= 0 {
|
||||
c.ProfilingDuration = defaultProfilingDuration
|
||||
}
|
||||
|
||||
// set default values for the configuration
|
||||
if c.CheckInterval <= 0 {
|
||||
c.CheckInterval = defaultCheckInterval
|
||||
}
|
||||
|
||||
if c.UploadRate <= 0 {
|
||||
c.UploadRate = defaultUploadRate
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
logx.Info("continuous profiling started")
|
||||
|
||||
threading.GoSafe(func() {
|
||||
startPyroscope(c, proc.Done())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// startPyroscope starts the pyroscope profiler with the given configuration.
|
||||
func startPyroscope(c Config, done <-chan struct{}) {
|
||||
var (
|
||||
pr profiler
|
||||
err error
|
||||
latestProfilingTime time.Time
|
||||
intervalTicker = time.NewTicker(c.CheckInterval)
|
||||
profilingTicker = time.NewTicker(c.ProfilingDuration)
|
||||
)
|
||||
|
||||
defer profilingTicker.Stop()
|
||||
defer intervalTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-intervalTicker.C:
|
||||
// Check if the machine is overloaded and if the profiler is not running
|
||||
if pr == nil && isCpuOverloaded(c) {
|
||||
pr = newProfiler(c)
|
||||
if err := pr.Start(); err != nil {
|
||||
logx.Errorf("failed to start profiler: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// record the latest profiling time
|
||||
latestProfilingTime = time.Now()
|
||||
logx.Infof("pyroscope profiler started.")
|
||||
}
|
||||
case <-profilingTicker.C:
|
||||
// check if the profiling duration has passed
|
||||
if !time.Now().After(latestProfilingTime.Add(c.ProfilingDuration)) {
|
||||
continue
|
||||
}
|
||||
|
||||
// check if the profiler is already running, if so, skip
|
||||
if pr != nil {
|
||||
if err = pr.Stop(); err != nil {
|
||||
logx.Errorf("failed to stop profiler: %v", err)
|
||||
}
|
||||
logx.Infof("pyroscope profiler stopped.")
|
||||
pr = nil
|
||||
}
|
||||
case <-done:
|
||||
logx.Infof("continuous profiling stopped.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// genPyroscopeConf generates the pyroscope configuration based on the given config.
|
||||
func genPyroscopeConf(c Config) pyroscope.Config {
|
||||
pConf := pyroscope.Config{
|
||||
UploadRate: c.UploadRate,
|
||||
ApplicationName: c.Name,
|
||||
BasicAuthUser: c.AuthUser, // http basic auth user
|
||||
BasicAuthPassword: c.AuthPassword, // http basic auth password
|
||||
ServerAddress: c.ServerAddr,
|
||||
Logger: nil,
|
||||
HTTPHeaders: map[string]string{},
|
||||
// you can provide static tags via a map:
|
||||
Tags: map[string]string{
|
||||
"name": c.Name,
|
||||
},
|
||||
}
|
||||
|
||||
if c.ProfileType.Logger {
|
||||
pConf.Logger = logx.WithCallerSkip(0)
|
||||
}
|
||||
|
||||
if c.ProfileType.CPU {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||
}
|
||||
if c.ProfileType.Goroutines {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||
}
|
||||
if c.ProfileType.Memory {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileAllocObjects, pyroscope.ProfileAllocSpace,
|
||||
pyroscope.ProfileInuseObjects, pyroscope.ProfileInuseSpace)
|
||||
}
|
||||
if c.ProfileType.Mutex {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileMutexCount, pyroscope.ProfileMutexDuration)
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
pConf.ProfileTypes = append(pConf.ProfileTypes, pyroscope.ProfileBlockCount, pyroscope.ProfileBlockDuration)
|
||||
}
|
||||
|
||||
logx.Infof("applicationName: %s", pConf.ApplicationName)
|
||||
|
||||
return pConf
|
||||
}
|
||||
|
||||
// isCpuOverloaded checks the machine performance based on the given configuration.
|
||||
func isCpuOverloaded(c Config) bool {
|
||||
currentValue := stat.CpuUsage()
|
||||
if currentValue >= c.CpuThreshold {
|
||||
logx.Infof("continuous profiling cpu overload, cpu: %d", currentValue)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func newPyroscopeProfiler(c Config) profiler {
|
||||
return &pyroscopeProfiler{
|
||||
c: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pyroscopeProfiler) Start() error {
|
||||
pConf := genPyroscopeConf(p.c)
|
||||
// set mutex and block profile rate
|
||||
setFraction(p.c)
|
||||
prof, err := pyroscope.Start(pConf)
|
||||
if err != nil {
|
||||
resetFraction(p.c)
|
||||
return err
|
||||
}
|
||||
|
||||
p.profiler = prof
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pyroscopeProfiler) Stop() error {
|
||||
if p.profiler == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.profiler.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resetFraction(p.c)
|
||||
p.profiler = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setFraction(c Config) {
|
||||
// These 2 lines are only required if you're using mutex or block profiling
|
||||
if c.ProfileType.Mutex {
|
||||
runtime.SetMutexProfileFraction(10) // 10/seconds
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
runtime.SetBlockProfileRate(1000 * 1000) // 1/millisecond
|
||||
}
|
||||
}
|
||||
|
||||
func resetFraction(c Config) {
|
||||
// These 2 lines are only required if you're using mutex or block profiling
|
||||
if c.ProfileType.Mutex {
|
||||
runtime.SetMutexProfileFraction(0)
|
||||
}
|
||||
if c.ProfileType.Block {
|
||||
runtime.SetBlockProfileRate(0)
|
||||
}
|
||||
}
|
||||
177
internal/profiling/profiling_test.go
Normal file
177
internal/profiling/profiling_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package profiling
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
t.Run("profiling", func(t *testing.T) {
|
||||
var c Config
|
||||
assert.NoError(t, conf.FillDefault(&c))
|
||||
c.Name = "test"
|
||||
p := newProfiler(c)
|
||||
assert.NotNil(t, p)
|
||||
assert.NoError(t, p.Start())
|
||||
assert.NoError(t, p.Stop())
|
||||
})
|
||||
|
||||
t.Run("invalid config", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
Start(Config{})
|
||||
|
||||
Start(Config{
|
||||
ServerAddr: "localhost:4040",
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("test start profiler", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 0,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.True(t, mp.started.True())
|
||||
assert.True(t, mp.stopped.True())
|
||||
})
|
||||
|
||||
t.Run("test start profiler with cpu overloaded", func(t *testing.T) {
|
||||
mp := &mockProfiler{}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 900,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.False(t, mp.started.True())
|
||||
})
|
||||
|
||||
t.Run("start/stop err", func(t *testing.T) {
|
||||
mp := &mockProfiler{
|
||||
err: assert.AnError,
|
||||
}
|
||||
newProfiler = func(c Config) profiler {
|
||||
return mp
|
||||
}
|
||||
|
||||
c := Config{
|
||||
Name: "test",
|
||||
ServerAddr: "localhost:4040",
|
||||
CheckInterval: time.Millisecond,
|
||||
ProfilingDuration: time.Millisecond * 10,
|
||||
CpuThreshold: 0,
|
||||
}
|
||||
var done = make(chan struct{})
|
||||
go startPyroscope(c, done)
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
close(done)
|
||||
|
||||
assert.False(t, mp.started.True())
|
||||
assert.False(t, mp.stopped.True())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenPyroscopeConf(t *testing.T) {
|
||||
c := Config{
|
||||
Name: "",
|
||||
ServerAddr: "localhost:4040",
|
||||
AuthUser: "user",
|
||||
AuthPassword: "password",
|
||||
ProfileType: ProfileType{
|
||||
Logger: true,
|
||||
CPU: true,
|
||||
Goroutines: true,
|
||||
Memory: true,
|
||||
Mutex: true,
|
||||
Block: true,
|
||||
},
|
||||
}
|
||||
|
||||
pyroscopeConf := genPyroscopeConf(c)
|
||||
assert.Equal(t, c.ServerAddr, pyroscopeConf.ServerAddress)
|
||||
assert.Equal(t, c.AuthUser, pyroscopeConf.BasicAuthUser)
|
||||
assert.Equal(t, c.AuthPassword, pyroscopeConf.BasicAuthPassword)
|
||||
assert.Equal(t, c.Name, pyroscopeConf.ApplicationName)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileCPU)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileGoroutines)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocObjects)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileAllocSpace)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseObjects)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileInuseSpace)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexCount)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileMutexDuration)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockCount)
|
||||
assert.Contains(t, pyroscopeConf.ProfileTypes, pyroscope.ProfileBlockDuration)
|
||||
|
||||
setFraction(c)
|
||||
resetFraction(c)
|
||||
|
||||
newPyroscopeProfiler(c)
|
||||
}
|
||||
|
||||
func TestNewPyroscopeProfiler(t *testing.T) {
|
||||
p := newPyroscopeProfiler(Config{})
|
||||
|
||||
assert.Error(t, p.Start())
|
||||
assert.NoError(t, p.Stop())
|
||||
}
|
||||
|
||||
type mockProfiler struct {
|
||||
mutex sync.Mutex
|
||||
started syncx.AtomicBool
|
||||
stopped syncx.AtomicBool
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockProfiler) Start() error {
|
||||
m.mutex.Lock()
|
||||
if m.err == nil {
|
||||
m.started.Set(true)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockProfiler) Stop() error {
|
||||
m.mutex.Lock()
|
||||
if m.err == nil {
|
||||
m.stopped.Set(true)
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return m.err
|
||||
}
|
||||
43
mcp/config.go
Normal file
43
mcp/config.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
// McpConf defines the configuration for an MCP server.
|
||||
// It embeds rest.RestConf for HTTP server settings
|
||||
// and adds MCP-specific configuration options.
|
||||
type McpConf struct {
|
||||
rest.RestConf
|
||||
Mcp struct {
|
||||
// Name is the server name reported in initialize responses
|
||||
Name string `json:",optional"`
|
||||
|
||||
// Version is the server version reported in initialize responses
|
||||
Version string `json:",default=1.0.0"`
|
||||
|
||||
// ProtocolVersion is the MCP protocol version implemented
|
||||
ProtocolVersion string `json:",default=2024-11-05"`
|
||||
|
||||
// BaseUrl is the base URL for the server, used in SSE endpoint messages
|
||||
// If not set, defaults to http://localhost:{Port}
|
||||
BaseUrl string `json:",optional"`
|
||||
|
||||
// SseEndpoint is the path for Server-Sent Events connections
|
||||
SseEndpoint string `json:",default=/sse"`
|
||||
|
||||
// MessageEndpoint is the path for JSON-RPC requests
|
||||
MessageEndpoint string `json:",default=/message"`
|
||||
|
||||
// Cors contains allowed CORS origins
|
||||
Cors []string `json:",optional"`
|
||||
|
||||
// SseTimeout is the maximum time allowed for SSE connections
|
||||
SseTimeout time.Duration `json:",default=24h"`
|
||||
|
||||
// MessageTimeout is the maximum time allowed for request execution
|
||||
MessageTimeout time.Duration `json:",default=30s"`
|
||||
}
|
||||
}
|
||||
63
mcp/config_test.go
Normal file
63
mcp/config_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
)
|
||||
|
||||
func TestMcpConfDefaults(t *testing.T) {
|
||||
// Test default values are set correctly when unmarshalled from JSON
|
||||
jsonConfig := `name: test-service
|
||||
port: 8080
|
||||
mcp:
|
||||
name: test-mcp-server
|
||||
version: 1.0.0
|
||||
`
|
||||
|
||||
var c McpConf
|
||||
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check default values
|
||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
|
||||
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
|
||||
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
|
||||
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
|
||||
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
|
||||
assert.Equal(t, 30*time.Second, c.Mcp.MessageTimeout, "Default message timeout should be 30s")
|
||||
}
|
||||
|
||||
func TestMcpConfCustomValues(t *testing.T) {
|
||||
// Test custom values can be set
|
||||
jsonConfig := `{
|
||||
"Name": "test-service",
|
||||
"Port": 8080,
|
||||
"Mcp": {
|
||||
"Name": "test-mcp-server",
|
||||
"Version": "2.0.0",
|
||||
"ProtocolVersion": "2025-01-01",
|
||||
"BaseUrl": "http://example.com",
|
||||
"SseEndpoint": "/custom-sse",
|
||||
"MessageEndpoint": "/custom-message",
|
||||
"Cors": ["http://localhost:3000", "http://example.com"],
|
||||
"MessageTimeout": "60s"
|
||||
}
|
||||
}`
|
||||
|
||||
var c McpConf
|
||||
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check custom values
|
||||
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
|
||||
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
|
||||
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
|
||||
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
|
||||
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
|
||||
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
|
||||
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
|
||||
assert.Equal(t, 60*time.Second, c.Mcp.MessageTimeout, "Tool timeout should be customizable")
|
||||
}
|
||||
443
mcp/integration_test.go
Normal file
443
mcp/integration_test.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
|
||||
type syncResponseRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Create a new synchronized response recorder
|
||||
func newSyncResponseRecorder() *syncResponseRecorder {
|
||||
return &syncResponseRecorder{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
}
|
||||
}
|
||||
|
||||
// Override Write method to synchronize access
|
||||
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
return srr.ResponseRecorder.Write(p)
|
||||
}
|
||||
|
||||
// Override WriteHeader method to synchronize access
|
||||
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
srr.ResponseRecorder.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Override Result method to synchronize access
|
||||
func (srr *syncResponseRecorder) Result() *http.Response {
|
||||
srr.mu.Lock()
|
||||
defer srr.mu.Unlock()
|
||||
return srr.ResponseRecorder.Result()
|
||||
}
|
||||
|
||||
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
|
||||
func TestHTTPHandlerIntegration(t *testing.T) {
|
||||
// Skip in short test mode
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Create a test configuration
|
||||
conf := McpConf{}
|
||||
conf.Mcp.Name = "test-integration"
|
||||
conf.Mcp.Version = "1.0.0-test"
|
||||
conf.Mcp.MessageTimeout = 1 * time.Second
|
||||
|
||||
// Create a mock server directly
|
||||
server := &sseMcpServer{
|
||||
conf: conf,
|
||||
clients: make(map[string]*mcpClient),
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Register a test tool
|
||||
err := server.RegisterTool(Tool{
|
||||
Name: "echo",
|
||||
Description: "Echo tool for testing",
|
||||
InputSchema: InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Message to echo",
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
if msg, ok := params["message"].(string); ok {
|
||||
return fmt.Sprintf("Echo: %s", msg), nil
|
||||
}
|
||||
return "Echo: no message provided", nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a test HTTP request to the SSE endpoint
|
||||
req := httptest.NewRequest("GET", "/sse", nil)
|
||||
w := newSyncResponseRecorder()
|
||||
|
||||
// Create a done channel to signal completion of test
|
||||
done := make(chan bool)
|
||||
|
||||
// Start the SSE handler in a goroutine
|
||||
go func() {
|
||||
// lock.Lock()
|
||||
server.handleSSE(w, req)
|
||||
// lock.Unlock()
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Allow time for the handler to process
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - handler would normally block indefinitely
|
||||
case <-done:
|
||||
// This shouldn't happen immediately - the handler should block
|
||||
t.Error("SSE handler returned unexpectedly")
|
||||
}
|
||||
|
||||
// Check the initial headers
|
||||
resp := w.Result()
|
||||
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
|
||||
resp.Body.Close()
|
||||
|
||||
// The handler creates a client and sends the endpoint message
|
||||
var sessionId string
|
||||
|
||||
// Give the handler time to set up the client
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Check that a client was created
|
||||
server.clientsLock.Lock()
|
||||
assert.Equal(t, 1, len(server.clients))
|
||||
for id := range server.clients {
|
||||
sessionId = id
|
||||
}
|
||||
server.clientsLock.Unlock()
|
||||
|
||||
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
|
||||
|
||||
// Now that we have a session ID, we can test the message endpoint
|
||||
messageBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodInitialize,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
// Create a message request
|
||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
|
||||
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
|
||||
msgW := newSyncResponseRecorder()
|
||||
|
||||
// Process the message
|
||||
server.handleRequest(msgW, msgReq)
|
||||
|
||||
// Check the response
|
||||
msgResp := msgW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
|
||||
msgResp.Body.Close() // Ensure response body is closed
|
||||
}
|
||||
|
||||
// TestHandlerResponseFlow tests the flow of a full request/response cycle
|
||||
func TestHandlerResponseFlow(t *testing.T) {
|
||||
// Create a mock server for testing
|
||||
server := &sseMcpServer{
|
||||
conf: McpConf{},
|
||||
clients: map[string]*mcpClient{
|
||||
"test-session": {
|
||||
id: "test-session",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
},
|
||||
},
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Register test resources
|
||||
server.RegisterTool(Tool{
|
||||
Name: "test.tool",
|
||||
Description: "Test tool",
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "tool result", nil
|
||||
},
|
||||
})
|
||||
|
||||
server.RegisterPrompt(Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "Test prompt",
|
||||
})
|
||||
|
||||
server.RegisterResource(Resource{
|
||||
Name: "test.resource",
|
||||
URI: "http://example.com",
|
||||
Description: "Test resource",
|
||||
})
|
||||
|
||||
// Create a request with session ID parameter
|
||||
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
|
||||
|
||||
// Test tools/list request
|
||||
toolsListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodToolsList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
|
||||
toolsW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(toolsW, toolsReq)
|
||||
|
||||
// Check the response code
|
||||
toolsResp := toolsW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
|
||||
toolsResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
client := server.clients["test-session"]
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tools/list response")
|
||||
}
|
||||
|
||||
// Test prompts/list request
|
||||
promptsListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 2,
|
||||
Method: methodPromptsList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
|
||||
promptsW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(promptsW, promptsReq)
|
||||
|
||||
// Check the response code
|
||||
promptsResp := promptsW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
|
||||
promptsResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompts/list response")
|
||||
}
|
||||
|
||||
// Test resources/list request
|
||||
resourcesListBody, _ := json.Marshal(Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 3,
|
||||
Method: methodResourcesList,
|
||||
Params: json.RawMessage(`{}`),
|
||||
})
|
||||
|
||||
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
|
||||
resourcesW := newSyncResponseRecorder()
|
||||
|
||||
// Process the request
|
||||
server.handleRequest(resourcesW, resourcesReq)
|
||||
|
||||
// Check the response code
|
||||
resourcesResp := resourcesW.Result()
|
||||
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
|
||||
resourcesResp.Body.Close()
|
||||
|
||||
// Check the channel message
|
||||
select {
|
||||
case message := <-client.channel:
|
||||
assert.Contains(t, message, `"name":"test.resource"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resources/list response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessListMethods tests the list processing methods with pagination
|
||||
func TestProcessListMethods(t *testing.T) {
|
||||
server := &sseMcpServer{
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Add some test data
|
||||
for i := 1; i <= 5; i++ {
|
||||
tool := Tool{
|
||||
Name: fmt.Sprintf("tool%d", i),
|
||||
Description: fmt.Sprintf("Tool %d", i),
|
||||
InputSchema: InputSchema{Type: "object"},
|
||||
}
|
||||
server.tools[tool.Name] = tool
|
||||
|
||||
prompt := Prompt{
|
||||
Name: fmt.Sprintf("prompt%d", i),
|
||||
Description: fmt.Sprintf("Prompt %d", i),
|
||||
}
|
||||
server.prompts[prompt.Name] = prompt
|
||||
|
||||
resource := Resource{
|
||||
Name: fmt.Sprintf("resource%d", i),
|
||||
URI: fmt.Sprintf("http://example.com/%d", i),
|
||||
Description: fmt.Sprintf("Resource %d", i),
|
||||
}
|
||||
server.resources[resource.Name] = resource
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := &mcpClient{
|
||||
id: "test-client",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
// Test processListTools
|
||||
req := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: methodToolsList,
|
||||
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
|
||||
}
|
||||
|
||||
server.processListTools(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"tools":`)
|
||||
assert.Contains(t, response, `"progressToken":"token1"`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for tools/list response")
|
||||
}
|
||||
|
||||
// Test processListPrompts
|
||||
req.ID = 2
|
||||
req.Method = methodPromptsList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListPrompts(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"prompts":`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for prompts/list response")
|
||||
}
|
||||
|
||||
// Test processListResources
|
||||
req.ID = 3
|
||||
req.Method = methodResourcesList
|
||||
req.Params = json.RawMessage(`{"cursor": "next"}`)
|
||||
server.processListResources(context.Background(), client, req)
|
||||
|
||||
// Read response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"resources":`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for resources/list response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorResponseHandling tests error handling in the server
|
||||
func TestErrorResponseHandling(t *testing.T) {
|
||||
server := &sseMcpServer{
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// Create a test client
|
||||
client := &mcpClient{
|
||||
id: "test-client",
|
||||
channel: make(chan string, 10),
|
||||
initialized: true,
|
||||
}
|
||||
|
||||
// Test invalid method
|
||||
req := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 1,
|
||||
Method: "invalid_method",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
// Mock handleRequest by directly calling error handler
|
||||
server.sendErrorResponse(context.Background(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
|
||||
// Test invalid tool
|
||||
toolReq := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 2,
|
||||
Method: methodToolsCall,
|
||||
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processToolCall(context.Background(), client, toolReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
|
||||
// Test invalid prompt
|
||||
promptReq := Request{
|
||||
JsonRpc: "2.0",
|
||||
ID: 3,
|
||||
Method: methodPromptsGet,
|
||||
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
|
||||
}
|
||||
|
||||
// Call process method directly
|
||||
server.processGetPrompt(context.Background(), client, promptReq)
|
||||
|
||||
// Check response
|
||||
select {
|
||||
case response := <-client.channel:
|
||||
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Timed out waiting for error response")
|
||||
}
|
||||
}
|
||||
23
mcp/parser.go
Normal file
23
mcp/parser.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
// ParseArguments parses the arguments and populates the request object
|
||||
func ParseArguments(args any, req any) error {
|
||||
switch arguments := args.(type) {
|
||||
case map[string]string:
|
||||
m := make(map[string]any, len(arguments))
|
||||
for k, v := range arguments {
|
||||
m[k] = v
|
||||
}
|
||||
return mapping.UnmarshalJsonMap(m, req, mapping.WithStringValues())
|
||||
case map[string]any:
|
||||
return mapping.UnmarshalJsonMap(arguments, req)
|
||||
default:
|
||||
return fmt.Errorf("unsupported argument type: %T", arguments)
|
||||
}
|
||||
}
|
||||
139
mcp/parser_test.go
Normal file
139
mcp/parser_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestParseArguments_MapStringString tests parsing map[string]string arguments
|
||||
func TestParseArguments_MapStringString(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// Create test arguments
|
||||
args := map[string]string{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": "42",
|
||||
"enabled": "true",
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]string without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed to int")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed to bool")
|
||||
}
|
||||
|
||||
// TestParseArguments_MapStringAny tests parsing map[string]any arguments
|
||||
func TestParseArguments_MapStringAny(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
Count int `json:"count"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Tags []string `json:"tags"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
}
|
||||
|
||||
// Create test arguments with mixed types
|
||||
args := map[string]any{
|
||||
"name": "test-name",
|
||||
"message": "hello world",
|
||||
"count": 42, // note: this is already an int
|
||||
"enabled": true, // note: this is already a bool
|
||||
"tags": []string{"tag1", "tag2"},
|
||||
"metadata": map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
},
|
||||
}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err, "Should parse map[string]any without error")
|
||||
assert.Equal(t, "test-name", req.Name, "Name should be correctly parsed")
|
||||
assert.Equal(t, "hello world", req.Message, "Message should be correctly parsed")
|
||||
assert.Equal(t, 42, req.Count, "Count should be correctly parsed")
|
||||
assert.True(t, req.Enabled, "Enabled should be correctly parsed")
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, req.Tags, "Tags should be correctly parsed")
|
||||
assert.Equal(t, map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}, req.Metadata, "Metadata should be correctly parsed")
|
||||
}
|
||||
|
||||
// TestParseArguments_UnsupportedType tests parsing with an unsupported type
|
||||
func TestParseArguments_UnsupportedType(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Use an unsupported argument type (slice)
|
||||
args := []string{"not", "a", "map"}
|
||||
|
||||
// Create a target object to populate
|
||||
var req requestStruct
|
||||
|
||||
// Parse the arguments
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
// Verify error is returned with correct message
|
||||
assert.Error(t, err, "Should return error for unsupported type")
|
||||
assert.Contains(t, err.Error(), "unsupported argument type", "Error should mention unsupported type")
|
||||
assert.Contains(t, err.Error(), "[]string", "Error should include the actual type")
|
||||
}
|
||||
|
||||
// TestParseArguments_EmptyMap tests parsing with empty maps
|
||||
func TestParseArguments_EmptyMap(t *testing.T) {
|
||||
// Sample request struct to populate
|
||||
type requestStruct struct {
|
||||
Name string `json:"name,optional"`
|
||||
Message string `json:"message,optional"`
|
||||
}
|
||||
|
||||
// Test empty map[string]string
|
||||
t.Run("EmptyMapStringString", func(t *testing.T) {
|
||||
args := map[string]string{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]string without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
|
||||
// Test empty map[string]any
|
||||
t.Run("EmptyMapStringAny", func(t *testing.T) {
|
||||
args := map[string]any{}
|
||||
var req requestStruct
|
||||
|
||||
err := ParseArguments(args, &req)
|
||||
|
||||
assert.NoError(t, err, "Should parse empty map[string]any without error")
|
||||
assert.Empty(t, req.Name, "Name should be empty string")
|
||||
assert.Empty(t, req.Message, "Message should be empty string")
|
||||
})
|
||||
}
|
||||
870
mcp/readme.md
Normal file
870
mcp/readme.md
Normal file
@@ -0,0 +1,870 @@
|
||||
# Model Context Protocol (MCP) Implementation
|
||||
|
||||
## Overview
|
||||
This package implements the Model Context Protocol (MCP) server specification in Go, providing a framework for real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation follows the standardized protocol for building AI-assisted applications with bidirectional communication capabilities.
|
||||
|
||||
## Core Components
|
||||
|
||||
### Server-Sent Events (SSE) Communication
|
||||
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
|
||||
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
|
||||
- **Event Handling**: Event types for tools, prompts, and resources changes
|
||||
|
||||
### JSON-RPC Implementation
|
||||
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
|
||||
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
|
||||
- **Error Handling**: Comprehensive error handling with appropriate error codes
|
||||
|
||||
### Tool Management
|
||||
- **Tool Registration**: System to register custom tools with handlers
|
||||
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
|
||||
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
|
||||
|
||||
### Prompt System
|
||||
- **Prompt Registration**: System for registering both static and dynamic prompts
|
||||
- **Argument Validation**: Validation for required arguments and default values for optional ones
|
||||
- **Message Generation**: Handlers that generate properly formatted conversation messages
|
||||
|
||||
### Resource Management
|
||||
- **Resource Registration**: System for managing and accessing external resources
|
||||
- **Content Delivery**: Handlers for delivering resource content to clients on demand
|
||||
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
|
||||
|
||||
### Protocol Features
|
||||
- **Initialization Sequence**: Proper handshaking with capability negotiation
|
||||
- **Notification Handling**: Support for both standard and client-specific notifications
|
||||
- **Message Routing**: Intelligent routing of requests to appropriate handlers
|
||||
|
||||
## Technical Highlights
|
||||
|
||||
### Configuration System
|
||||
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
|
||||
- **CORS Support**: Configurable CORS settings for cross-origin requests
|
||||
- **Server Information**: Proper server identification and versioning
|
||||
|
||||
### Client Session Management
|
||||
- **Session Tracking**: Client session tracking with unique identifiers
|
||||
- **Connection Health**: Ping/pong mechanism to maintain connection health
|
||||
- **Initialization State**: Client initialization state tracking
|
||||
|
||||
### Content Handling
|
||||
- **Multi-format Content**: Support for text, code, and binary content
|
||||
- **MIME Type Support**: Proper MIME type identification for various content types
|
||||
- **Audience Annotations**: Content audience annotations for user/assistant targeting
|
||||
|
||||
## Usage
|
||||
|
||||
### Setting Up an MCP Server
|
||||
|
||||
To create and start an MCP server:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration from YAML file
|
||||
var c mcp.McpConf
|
||||
conf.MustLoad("config.yaml", &c)
|
||||
|
||||
// Optional: Disable stats logging
|
||||
logx.DisableStat()
|
||||
|
||||
// Create MCP server
|
||||
server := mcp.NewMcpServer(c)
|
||||
|
||||
// Register tools, prompts, and resources (examples below)
|
||||
|
||||
// Start the server and ensure it's stopped on exit
|
||||
defer server.Stop()
|
||||
server.Start()
|
||||
}
|
||||
```
|
||||
|
||||
Sample configuration file (config.yaml):
|
||||
|
||||
```yaml
|
||||
name: mcp-server
|
||||
host: localhost
|
||||
port: 8080
|
||||
mcp:
|
||||
name: my-mcp-server
|
||||
messageTimeout: 30s # Timeout for tool calls
|
||||
cors:
|
||||
- http://localhost:3000 # Optional CORS configuration
|
||||
```
|
||||
|
||||
### Registering Tools
|
||||
|
||||
Tools allow AI models to execute custom code through the MCP protocol.
|
||||
|
||||
#### Basic Tool Example:
|
||||
|
||||
```go
|
||||
// Register a simple echo tool
|
||||
echoTool := mcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echoes back the message provided by the user",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to echo back",
|
||||
},
|
||||
"prefix": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional prefix to add to the echoed message",
|
||||
"default": "Echo: ",
|
||||
},
|
||||
},
|
||||
Required: []string{"message"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
Prefix string `json:"prefix,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
prefix := "Echo: "
|
||||
if len(req.Prefix) > 0 {
|
||||
prefix = req.Prefix
|
||||
}
|
||||
|
||||
return prefix + req.Message, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(echoTool)
|
||||
```
|
||||
|
||||
#### Tool with Different Response Types:
|
||||
|
||||
```go
|
||||
// Tool returning JSON data
|
||||
dataTool := mcp.Tool{
|
||||
Name: "data.generate",
|
||||
Description: "Generates sample data in various formats",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Format of data (json, text)",
|
||||
"enum": []string{"json", "text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
// Return structured data
|
||||
return map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": 1, "name": "Item 1"},
|
||||
{"id": 2, "name": "Item 2"},
|
||||
},
|
||||
"count": 2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Default to text
|
||||
return "Sample text data", nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(dataTool)
|
||||
```
|
||||
|
||||
#### Image Generation Tool Example:
|
||||
|
||||
```go
|
||||
// Tool returning image content
|
||||
imageTool := mcp.Tool{
|
||||
Name: "image.generate",
|
||||
Description: "Generates a simple image",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"type": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Type of image to generate",
|
||||
"default": "placeholder",
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
// Return image content directly
|
||||
return mcp.ImageContent{
|
||||
Data: "base64EncodedImageData...", // Base64 encoded image data
|
||||
MimeType: "image/png",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(imageTool)
|
||||
```
|
||||
|
||||
#### Using ToolResult for Custom Outputs:
|
||||
|
||||
```go
|
||||
// Tool that returns a custom ToolResult type
|
||||
customResultTool := mcp.Tool{
|
||||
Name: "custom.result",
|
||||
Description: "Returns a custom formatted result",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"resultType": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"text", "image"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
ResultType string `json:"resultType"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse params: %w", err)
|
||||
}
|
||||
|
||||
if req.ResultType == "image" {
|
||||
return mcp.ToolResult{
|
||||
Type: mcp.ContentTypeImage,
|
||||
Content: map[string]any{
|
||||
"data": "base64EncodedImageData...",
|
||||
"mimeType": "image/jpeg",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Default to text
|
||||
return mcp.ToolResult{
|
||||
Type: mcp.ContentTypeText,
|
||||
Content: "This is a text result from ToolResult",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server.RegisterTool(customResultTool)
|
||||
```
|
||||
|
||||
### Registering Prompts
|
||||
|
||||
Prompts are reusable conversation templates for AI models.
|
||||
|
||||
#### Static Prompt Example:
|
||||
|
||||
```go
|
||||
// Register a simple static prompt with placeholders
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "hello",
|
||||
Description: "A simple hello prompt",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "name",
|
||||
Description: "The name to greet",
|
||||
Required: false,
|
||||
},
|
||||
},
|
||||
Content: "Say hello to {{name}} and introduce yourself as an AI assistant.",
|
||||
})
|
||||
```
|
||||
|
||||
#### Dynamic Prompt with Handler Function:
|
||||
|
||||
```go
|
||||
// Register a prompt with a dynamic handler function
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "dynamic-prompt",
|
||||
Description: "A prompt that uses a handler to generate dynamic content",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "username",
|
||||
Description: "User's name for personalized greeting",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic of expertise",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Create a user message
|
||||
userMessage := mcp.PromptMessage{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||
},
|
||||
}
|
||||
|
||||
// Create an assistant response with current time
|
||||
currentTime := time.Now().Format(time.RFC1123)
|
||||
assistantMessage := mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||
req.Username, req.Topic, currentTime),
|
||||
},
|
||||
}
|
||||
|
||||
// Return both messages as a conversation
|
||||
return []mcp.PromptMessage{userMessage, assistantMessage}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Multi-Message Prompt with Code Examples:
|
||||
|
||||
```go
|
||||
// Register a prompt that provides code examples in different programming languages
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "code-example",
|
||||
Description: "Provides code examples in different programming languages",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "language",
|
||||
Description: "Programming language for the example",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "complexity",
|
||||
Description: "Complexity level (simple, medium, advanced)",
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Language string `json:"language"`
|
||||
Complexity string `json:"complexity,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Validate language
|
||||
supportedLanguages := map[string]bool{"go": true, "python": true, "javascript": true, "rust": true}
|
||||
if !supportedLanguages[req.Language] {
|
||||
return nil, fmt.Errorf("unsupported language: %s", req.Language)
|
||||
}
|
||||
|
||||
// Generate code example based on language and complexity
|
||||
var codeExample string
|
||||
|
||||
switch req.Language {
|
||||
case "go":
|
||||
if req.Complexity == "simple" {
|
||||
codeExample = `
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Println("Hello, World!")
|
||||
}`
|
||||
} else {
|
||||
codeExample = `
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
now := time.Now()
|
||||
fmt.Printf("Hello, World! Current time is %s\n", now.Format(time.RFC3339))
|
||||
}`
|
||||
}
|
||||
case "python":
|
||||
// Python example code
|
||||
if req.Complexity == "simple" {
|
||||
codeExample = `
|
||||
def greet(name):
|
||||
return f"Hello, {name}!"
|
||||
|
||||
print(greet("World"))`
|
||||
} else {
|
||||
codeExample = `
|
||||
import datetime
|
||||
|
||||
def greet(name, include_time=False):
|
||||
message = f"Hello, {name}!"
|
||||
if include_time:
|
||||
message += f" Current time is {datetime.datetime.now().isoformat()}"
|
||||
return message
|
||||
|
||||
print(greet("World", include_time=True))`
|
||||
}
|
||||
}
|
||||
|
||||
// Create messages array according to MCP spec
|
||||
messages := []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("You are a helpful coding assistant specialized in %s programming.", req.Language),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Show me a %s example of a Hello World program in %s.", req.Complexity, req.Language),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Here's a %s example in %s:\n\n```%s%s\n```\n\nHow can I help you implement this?",
|
||||
req.Complexity, req.Language, req.Language, codeExample),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Registering Resources
|
||||
|
||||
Resources provide access to external content such as files or generated data.
|
||||
|
||||
#### Basic Resource Example:
|
||||
|
||||
```go
|
||||
// Register a static resource
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-document",
|
||||
URI: "file:///example/document.txt",
|
||||
Description: "An example document",
|
||||
MimeType: "text/plain",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/document.txt",
|
||||
MimeType: "text/plain",
|
||||
Text: "This is an example document content.",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Dynamic Resource with Code Example:
|
||||
|
||||
```go
|
||||
// Register a Go code resource with dynamic handler
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "go-example",
|
||||
URI: "file:///project/src/main.go",
|
||||
Description: "A simple Go example with multiple files",
|
||||
MimeType: "text/x-go",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
// Return ResourceContent with all required fields
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///project/src/main.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register a companion file for the above example
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "go-greeting",
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
Description: "A greeting package for the Go example",
|
||||
MimeType: "text/x-go",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
#### Binary Resource Example:
|
||||
|
||||
```go
|
||||
// Register a binary resource (like an image)
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-image",
|
||||
URI: "file:///example/image.png",
|
||||
Description: "An example image",
|
||||
MimeType: "image/png",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
// Read image from file or generate it
|
||||
imageData := "base64EncodedImageData..." // Base64 encoded image data
|
||||
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/image.png",
|
||||
MimeType: "image/png",
|
||||
Blob: imageData, // For binary data
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Using Resources in Prompts
|
||||
|
||||
You can embed resources in prompt responses to create rich interactions with proper MCP-compliant structure:
|
||||
|
||||
```go
|
||||
// Register a prompt that embeds a resource
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "resource-example",
|
||||
Description: "A prompt that embeds a resource",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "file_type",
|
||||
Description: "Type of file to show (rust or go)",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
FileType string `json:"file_type"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
var resourceURI, mimeType, fileContent string
|
||||
if req.FileType == "rust" {
|
||||
resourceURI = "file:///project/src/main.rs"
|
||||
mimeType = "text/x-rust"
|
||||
fileContent = "fn main() {\n println!(\"Hello world!\");\n}"
|
||||
} else {
|
||||
resourceURI = "file:///project/src/main.go"
|
||||
mimeType = "text/x-go"
|
||||
fileContent = "package main\n\nimport \"fmt\"\n\nfunc main() {\n\tfmt.Println(\"Hello, world!\")\n}"
|
||||
}
|
||||
|
||||
// Create message with embedded resource using proper MCP format
|
||||
return []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Can you explain this %s code?", req.FileType),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: resourceURI,
|
||||
MimeType: mimeType,
|
||||
Text: fileContent,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Above is a simple Hello World example in %s. Let me explain how it works.", req.FileType),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Multiple File Resources Example
|
||||
|
||||
```go
|
||||
// Register a prompt that demonstrates embedding multiple resource files
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "go-code-example",
|
||||
Description: "A prompt that correctly embeds multiple resource files",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "format",
|
||||
Description: "How to format the code display",
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Format string `json:"format,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Get the Go code for multiple files
|
||||
var mainGoText string = "package main\n\nimport (\n\t\"fmt\"\n\t\"./greeting\"\n)\n\nfunc main() {\n\tfmt.Println(greeting.Hello(\"world\"))\n}"
|
||||
var greetingGoText string = "package greeting\n\nfunc Hello(name string) string {\n\treturn \"Hello, \" + name + \"!\"\n}"
|
||||
|
||||
// Create message with properly formatted embedded resource per MCP spec
|
||||
messages := []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: "Show me a simple Go example with proper imports.",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: "Here's a simple Go example project:",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: "file:///project/src/main.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: mainGoText,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add explanation and additional file if requested
|
||||
if req.Format == "with_explanation" {
|
||||
messages = append(messages, mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: "This example demonstrates a simple Go application with modular structure. The main.go file imports from a local 'greeting' package that provides the Hello function.",
|
||||
},
|
||||
})
|
||||
|
||||
// Also show the greeting.go file with correct resource format
|
||||
messages = append(messages, mcp.PromptMessage{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: mcp.ContentTypeResource,
|
||||
Resource: struct {
|
||||
URI string `json:"uri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Blob string `json:"blob,omitempty"`
|
||||
}{
|
||||
URI: "file:///project/src/greeting/greeting.go",
|
||||
MimeType: "text/x-go",
|
||||
Text: greetingGoText,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Complete Application Example
|
||||
|
||||
Here's a complete example demonstrating all the components:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/mcp"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Load configuration
|
||||
var c mcp.McpConf
|
||||
if err := conf.Load("config.yaml", &c); err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
logx.DisableStat()
|
||||
|
||||
// Create MCP server
|
||||
server := mcp.NewMcpServer(c)
|
||||
defer server.Stop()
|
||||
|
||||
// Register a simple echo tool
|
||||
echoTool := mcp.Tool{
|
||||
Name: "echo",
|
||||
Description: "Echoes back the message provided by the user",
|
||||
InputSchema: mcp.InputSchema{
|
||||
Properties: map[string]any{
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to echo back",
|
||||
},
|
||||
"prefix": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional prefix to add to the echoed message",
|
||||
"default": "Echo: ",
|
||||
},
|
||||
},
|
||||
Required: []string{"message"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
var req struct {
|
||||
Message string `json:"message"`
|
||||
Prefix string `json:"prefix,optional"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(params, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
prefix := "Echo: "
|
||||
if len(req.Prefix) > 0 {
|
||||
prefix = req.Prefix
|
||||
}
|
||||
|
||||
return prefix + req.Message, nil
|
||||
},
|
||||
}
|
||||
server.RegisterTool(echoTool)
|
||||
|
||||
// Register a static prompt
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "greeting",
|
||||
Description: "A simple greeting prompt",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "name",
|
||||
Description: "The name to greet",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Content: "Hello {{name}}! How can I assist you today?",
|
||||
})
|
||||
|
||||
// Register a dynamic prompt
|
||||
server.RegisterPrompt(mcp.Prompt{
|
||||
Name: "dynamic-prompt",
|
||||
Description: "A prompt that uses a handler to generate dynamic content",
|
||||
Arguments: []mcp.PromptArgument{
|
||||
{
|
||||
Name: "username",
|
||||
Description: "User's name for personalized greeting",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "topic",
|
||||
Description: "Topic of expertise",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, args map[string]string) ([]mcp.PromptMessage, error) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
if err := mcp.ParseArguments(args, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse args: %w", err)
|
||||
}
|
||||
|
||||
// Create messages with current time
|
||||
currentTime := time.Now().Format(time.RFC1123)
|
||||
return []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello, I'm %s and I'd like to learn about %s.", req.Username, req.Topic),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{
|
||||
Text: fmt.Sprintf("Hello %s! I'm an AI assistant and I'll help you learn about %s. The current time is %s.",
|
||||
req.Username, req.Topic, currentTime),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register a resource
|
||||
server.RegisterResource(mcp.Resource{
|
||||
Name: "example-doc",
|
||||
URI: "file:///example/doc.txt",
|
||||
Description: "An example document",
|
||||
MimeType: "text/plain",
|
||||
Handler: func(ctx context.Context) (mcp.ResourceContent, error) {
|
||||
return mcp.ResourceContent{
|
||||
URI: "file:///example/doc.txt",
|
||||
MimeType: "text/plain",
|
||||
Text: "This is the content of the example document.",
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
// Start the server
|
||||
fmt.Printf("Starting MCP server on %s:%d\n", c.Host, c.Port)
|
||||
server.Start()
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The MCP implementation provides comprehensive error handling:
|
||||
|
||||
- Tool execution errors are properly reported back to clients
|
||||
- Missing or invalid parameters are detected and reported with appropriate error codes
|
||||
- Resource and prompt lookup failures are handled gracefully
|
||||
- Timeout handling for long-running tool executions using context
|
||||
- Panic recovery to prevent server crashes
|
||||
|
||||
## Advanced Features
|
||||
|
||||
- **Annotations**: Add audience and priority metadata to content
|
||||
- **Content Types**: Support for text, images, audio, and other content formats
|
||||
- **Embedded Resources**: Include file resources directly in prompt responses
|
||||
- **Context Awareness**: All handlers receive context.Context for timeout and cancellation support
|
||||
- **Progress Tokens**: Support for tracking progress of long-running operations
|
||||
- **Customizable Timeouts**: Configure execution timeouts for tools and operations
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- Tool execution runs with configurable timeouts to prevent blocking
|
||||
- Efficient client tracking and cleanup to prevent resource leaks
|
||||
- Proper concurrency handling with mutex protection for shared resources
|
||||
- Buffered message channels to prevent blocking on client message delivery
|
||||
940
mcp/server.go
Normal file
940
mcp/server.go
Normal file
@@ -0,0 +1,940 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
func NewMcpServer(c McpConf) McpServer {
|
||||
var server *rest.Server
|
||||
if len(c.Mcp.Cors) == 0 {
|
||||
server = rest.MustNewServer(c.RestConf)
|
||||
} else {
|
||||
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
|
||||
}
|
||||
|
||||
if len(c.Mcp.Name) == 0 {
|
||||
c.Mcp.Name = c.Name
|
||||
}
|
||||
if len(c.Mcp.BaseUrl) == 0 {
|
||||
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
|
||||
}
|
||||
|
||||
s := &sseMcpServer{
|
||||
conf: c,
|
||||
server: server,
|
||||
clients: make(map[string]*mcpClient),
|
||||
tools: make(map[string]Tool),
|
||||
prompts: make(map[string]Prompt),
|
||||
resources: make(map[string]Resource),
|
||||
}
|
||||
|
||||
// SSE endpoint for real-time updates
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodGet,
|
||||
Path: s.conf.Mcp.SseEndpoint,
|
||||
Handler: s.handleSSE,
|
||||
}, rest.WithSSE(), rest.WithTimeout(c.Mcp.SseTimeout))
|
||||
|
||||
// JSON-RPC message endpoint for regular requests
|
||||
s.server.AddRoute(rest.Route{
|
||||
Method: http.MethodPost,
|
||||
Path: s.conf.Mcp.MessageEndpoint,
|
||||
Handler: s.handleRequest,
|
||||
}, rest.WithTimeout(c.Mcp.MessageTimeout))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// RegisterPrompt registers a new prompt with the server
|
||||
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
|
||||
s.promptsLock.Lock()
|
||||
s.prompts[prompt.Name] = prompt
|
||||
s.promptsLock.Unlock()
|
||||
// Notify clients about the new prompt
|
||||
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
|
||||
}
|
||||
|
||||
// RegisterResource registers a new resource with the server
|
||||
func (s *sseMcpServer) RegisterResource(resource Resource) {
|
||||
s.resourcesLock.Lock()
|
||||
s.resources[resource.URI] = resource
|
||||
s.resourcesLock.Unlock()
|
||||
// Notify clients about the new resource
|
||||
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
|
||||
}
|
||||
|
||||
// RegisterTool registers a new tool with the server
|
||||
func (s *sseMcpServer) RegisterTool(tool Tool) error {
|
||||
if tool.Handler == nil {
|
||||
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
|
||||
}
|
||||
|
||||
s.toolsLock.Lock()
|
||||
s.tools[tool.Name] = tool
|
||||
s.toolsLock.Unlock()
|
||||
// Notify clients about the new tool
|
||||
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start implements McpServer.
|
||||
func (s *sseMcpServer) Start() {
|
||||
s.server.Start()
|
||||
}
|
||||
|
||||
func (s *sseMcpServer) Stop() {
|
||||
s.server.Stop()
|
||||
}
|
||||
|
||||
// broadcast sends a message to all connected clients
|
||||
// It uses Server-Sent Events (SSE) format for real-time communication
|
||||
func (s *sseMcpServer) broadcast(event string, data any) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logx.Errorf("Failed to marshal broadcast data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Lock only while reading the clients map
|
||||
s.clientsLock.Lock()
|
||||
clients := make([]*mcpClient, 0, len(s.clients))
|
||||
for _, client := range s.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
s.clientsLock.Unlock()
|
||||
|
||||
clientCount := len(clients)
|
||||
if clientCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
|
||||
|
||||
// Use CRLF line endings as per SSE specification
|
||||
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
|
||||
|
||||
// Send messages without holding the lock
|
||||
for _, client := range clients {
|
||||
select {
|
||||
case client.channel <- message:
|
||||
// Message sent successfully
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupClient removes a client from the active clients map
|
||||
func (s *sseMcpServer) cleanupClient(sessionId string) {
|
||||
s.clientsLock.Lock()
|
||||
defer s.clientsLock.Unlock()
|
||||
|
||||
if client, exists := s.clients[sessionId]; exists {
|
||||
// Close the channel to signal any goroutines waiting on it
|
||||
close(client.channel)
|
||||
// Remove from active clients
|
||||
delete(s.clients, sessionId)
|
||||
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
|
||||
}
|
||||
}
|
||||
|
||||
// handleRequest handles MCP JSON-RPC requests
|
||||
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract sessionId from query parameters
|
||||
sessionId := r.URL.Query().Get(sessionIdKey)
|
||||
if len(sessionId) == 0 {
|
||||
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client with this sessionId exists
|
||||
s.clientsLock.Lock()
|
||||
client, exists := s.clients[sessionId]
|
||||
s.clientsLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// For notification methods (no ID), we don't send a response
|
||||
isNotification, err := req.isNotification()
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid request.ID", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
// Special handling for initialization sequence
|
||||
// Always allow initialize and notifications/initialized regardless of client state
|
||||
if req.Method == methodInitialize {
|
||||
logx.Infof("Processing initialize request with ID: %v", req.ID)
|
||||
s.processInitialize(r.Context(), client, req)
|
||||
logx.Infof("Sent initialize response for ID: %v, waiting for notifications/initialized", req.ID)
|
||||
return
|
||||
} else if req.Method == methodNotificationsInitialized {
|
||||
// Handle initialized notification
|
||||
logx.Info("Received notifications/initialized notification")
|
||||
if !isNotification {
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Method should be used as a notification", errCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
s.processNotificationInitialized(client)
|
||||
return
|
||||
} else if !client.initialized && req.Method != methodNotificationsCancelled {
|
||||
// Block most requests until client is initialized (except for cancellations)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID,
|
||||
"Client not fully initialized, waiting for notifications/initialized",
|
||||
errCodeClientNotInitialized)
|
||||
return
|
||||
}
|
||||
|
||||
// Process normal requests only after initialization
|
||||
switch req.Method {
|
||||
case methodToolsCall:
|
||||
logx.Infof("Received tools call request with ID: %v", req.ID)
|
||||
s.processToolCall(r.Context(), client, req)
|
||||
logx.Infof("Sent tools call response for ID: %v", req.ID)
|
||||
case methodToolsList:
|
||||
logx.Infof("Processing tools/list request with ID: %v", req.ID)
|
||||
s.processListTools(r.Context(), client, req)
|
||||
logx.Infof("Sent tools/list response for ID: %v", req.ID)
|
||||
case methodPromptsList:
|
||||
logx.Infof("Processing prompts/list request with ID: %v", req.ID)
|
||||
s.processListPrompts(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/list response for ID: %v", req.ID)
|
||||
case methodPromptsGet:
|
||||
logx.Infof("Processing prompts/get request with ID: %v", req.ID)
|
||||
s.processGetPrompt(r.Context(), client, req)
|
||||
logx.Infof("Sent prompts/get response for ID: %v", req.ID)
|
||||
case methodResourcesList:
|
||||
logx.Infof("Processing resources/list request with ID: %v", req.ID)
|
||||
s.processListResources(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/list response for ID: %v", req.ID)
|
||||
case methodResourcesRead:
|
||||
logx.Infof("Processing resources/read request with ID: %v", req.ID)
|
||||
s.processResourcesRead(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/read response for ID: %v", req.ID)
|
||||
case methodResourcesSubscribe:
|
||||
logx.Infof("Processing resources/subscribe request with ID: %v", req.ID)
|
||||
s.processResourceSubscribe(r.Context(), client, req)
|
||||
logx.Infof("Sent resources/subscribe response for ID: %v", req.ID)
|
||||
case methodPing:
|
||||
logx.Infof("Processing ping request with ID: %v", req.ID)
|
||||
s.processPing(r.Context(), client, req)
|
||||
case methodNotificationsCancelled:
|
||||
logx.Infof("Received notifications/cancelled notification: %v", req.ID)
|
||||
s.processNotificationCancelled(r.Context(), client, req)
|
||||
default:
|
||||
logx.Infof("Unknown method: %s from client: %v", req.Method, req.ID)
|
||||
s.sendErrorResponse(r.Context(), client, req.ID, "Method not found", errCodeMethodNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSSE handles Server-Sent Events connections
|
||||
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate a unique session ID for this client
|
||||
sessionId := uuid.New().String()
|
||||
|
||||
// Create new client with buffered channel to prevent blocking
|
||||
client := &mcpClient{
|
||||
id: sessionId,
|
||||
channel: make(chan string, eventChanSize),
|
||||
}
|
||||
|
||||
// Add client to active clients map
|
||||
s.clientsLock.Lock()
|
||||
s.clients[sessionId] = client
|
||||
activeClients := len(s.clients)
|
||||
s.clientsLock.Unlock()
|
||||
|
||||
logx.Infof("New SSE connection established for client %s (active clients: %d)",
|
||||
sessionId, activeClients)
|
||||
|
||||
// Set proper SSE headers
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
|
||||
// Enable streaming
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
|
||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Send the message endpoint URL to the client
|
||||
endpoint := fmt.Sprintf("%s%s?%s=%s",
|
||||
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
|
||||
|
||||
// Format and send the endpoint message
|
||||
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
|
||||
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
|
||||
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
|
||||
s.cleanupClient(sessionId)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Set up keep-alive ping and client cleanup
|
||||
ticker := time.NewTicker(pingInterval.Load())
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
s.cleanupClient(sessionId)
|
||||
logx.Infof("SSE connection closed for client %s", sessionId)
|
||||
}()
|
||||
|
||||
// Message processing loop
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-client.channel:
|
||||
if !ok {
|
||||
// Channel was closed, end connection
|
||||
logx.Infof("Client channel was closed for %s", sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
// Write message to the response
|
||||
if _, err := fmt.Fprint(w, message); err != nil {
|
||||
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-ticker.C:
|
||||
// Send keep-alive ping to maintain connection
|
||||
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
|
||||
pingMsg := formatSSEMessage("ping", []byte(ping))
|
||||
if _, err := fmt.Fprint(w, pingMsg); err != nil {
|
||||
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-r.Context().Done():
|
||||
// Client disconnected or request was canceled or timed out
|
||||
logx.Infof("Client %s disconnected: context done", sessionId)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processInitialize processes the initialize request
|
||||
func (s *sseMcpServer) processInitialize(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Create a proper JSON-RPC response that preserves the client's request ID
|
||||
result := initializationResponse{
|
||||
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
|
||||
Capabilities: capabilities{
|
||||
Prompts: struct {
|
||||
ListChanged bool `json:"listChanged"`
|
||||
}{
|
||||
ListChanged: true,
|
||||
},
|
||||
Resources: struct {
|
||||
Subscribe bool `json:"subscribe"`
|
||||
ListChanged bool `json:"listChanged"`
|
||||
}{
|
||||
Subscribe: true,
|
||||
ListChanged: true,
|
||||
},
|
||||
Tools: struct {
|
||||
ListChanged bool `json:"listChanged"`
|
||||
}{
|
||||
ListChanged: true,
|
||||
},
|
||||
},
|
||||
ServerInfo: serverInfo{
|
||||
Name: s.conf.Mcp.Name,
|
||||
Version: s.conf.Mcp.Version,
|
||||
},
|
||||
}
|
||||
|
||||
// Mark client as initialized
|
||||
client.initialized = true
|
||||
|
||||
// Send response with client's original request ID
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListTools processes the tools/list request
|
||||
func (s *sseMcpServer) processListTools(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
|
||||
// Extract meta data including progress token
|
||||
if req.Params != nil {
|
||||
var metaParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||
if len(metaParams.Cursor) > 0 {
|
||||
nextCursor = metaParams.Cursor
|
||||
}
|
||||
progressToken = metaParams.Meta.ProgressToken
|
||||
}
|
||||
}
|
||||
|
||||
var toolsList []Tool
|
||||
s.toolsLock.Lock()
|
||||
for _, tool := range s.tools {
|
||||
if len(tool.InputSchema.Type) == 0 {
|
||||
tool.InputSchema.Type = ContentTypeObject
|
||||
}
|
||||
toolsList = append(toolsList, tool)
|
||||
}
|
||||
s.toolsLock.Unlock()
|
||||
|
||||
result := ListToolsResult{
|
||||
PaginatedResult: PaginatedResult{
|
||||
Result: Result{},
|
||||
NextCursor: Cursor(nextCursor),
|
||||
},
|
||||
Tools: toolsList,
|
||||
}
|
||||
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListPrompts processes the prompts/list request
|
||||
func (s *sseMcpServer) processListPrompts(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
if req.Params != nil {
|
||||
var cursorParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
}
|
||||
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
|
||||
// If we have a valid cursor, we could use it for pagination
|
||||
// For now, we're not actually implementing pagination, so this is just
|
||||
// to show how it would be extracted from the request
|
||||
_ = cursorParams.Cursor
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare prompt list
|
||||
var promptsList []Prompt
|
||||
s.promptsLock.Lock()
|
||||
for _, prompt := range s.prompts {
|
||||
promptsList = append(promptsList, prompt)
|
||||
}
|
||||
s.promptsLock.Unlock()
|
||||
|
||||
// In a real implementation, you'd handle pagination here
|
||||
// For now, we'll return all prompts at once
|
||||
result := struct {
|
||||
Prompts []Prompt `json:"prompts"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
Meta *struct{} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Prompts: promptsList,
|
||||
NextCursor: nextCursor,
|
||||
}
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processListResources processes the resources/list request
|
||||
func (s *sseMcpServer) processListResources(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract pagination params if any
|
||||
var nextCursor string
|
||||
var progressToken any
|
||||
|
||||
// Extract meta information including progress token if available
|
||||
if req.Params != nil {
|
||||
var metaParams PaginatedParams
|
||||
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
|
||||
if len(metaParams.Cursor) > 0 {
|
||||
nextCursor = metaParams.Cursor
|
||||
}
|
||||
progressToken = metaParams.Meta.ProgressToken
|
||||
}
|
||||
}
|
||||
|
||||
var resourcesList []Resource
|
||||
s.resourcesLock.Lock()
|
||||
for _, resource := range s.resources {
|
||||
// Create a copy without the handler function which shouldn't be sent to clients
|
||||
resourceCopy := Resource{
|
||||
URI: resource.URI,
|
||||
Name: resource.Name,
|
||||
Description: resource.Description,
|
||||
MimeType: resource.MimeType,
|
||||
}
|
||||
resourcesList = append(resourcesList, resourceCopy)
|
||||
}
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
// Create proper ResourcesListResult according to MCP specification
|
||||
result := ResourcesListResult{
|
||||
PaginatedResult: PaginatedResult{
|
||||
Result: Result{},
|
||||
NextCursor: Cursor(nextCursor),
|
||||
},
|
||||
Resources: resourcesList,
|
||||
}
|
||||
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
result.Result.Meta = map[string]any{
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processGetPrompt processes the prompts/get request
|
||||
func (s *sseMcpServer) processGetPrompt(ctx context.Context, client *mcpClient, req Request) {
|
||||
type GetPromptParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
var params GetPromptParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if prompt exists
|
||||
s.promptsLock.Lock()
|
||||
prompt, exists := s.prompts[params.Name]
|
||||
s.promptsLock.Unlock()
|
||||
if !exists {
|
||||
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
|
||||
|
||||
// Validate required arguments
|
||||
missingArgs := validatePromptArguments(prompt, params.Arguments)
|
||||
if len(missingArgs) > 0 {
|
||||
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
|
||||
s.sendErrorResponse(ctx, client, req.ID, message, errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure arguments are initialized to an empty map if nil
|
||||
if params.Arguments == nil {
|
||||
params.Arguments = make(map[string]string)
|
||||
}
|
||||
args := params.Arguments
|
||||
|
||||
// Generate messages using handler or static content
|
||||
var messages []PromptMessage
|
||||
var err error
|
||||
|
||||
if prompt.Handler != nil {
|
||||
// Use dynamic handler to generate messages
|
||||
messages, err = prompt.Handler(ctx, args)
|
||||
if err != nil {
|
||||
logx.Errorf("Error from prompt handler: %v", err)
|
||||
s.sendErrorResponse(ctx, client, req.ID,
|
||||
fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// No handler, generate messages from static content
|
||||
var messageText string
|
||||
if len(prompt.Content) > 0 {
|
||||
messageText = prompt.Content
|
||||
|
||||
// Apply argument substitutions to static content
|
||||
for key, value := range args {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
messageText = strings.Replace(messageText, placeholder, value, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a single user message with the content
|
||||
messages = []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: messageText,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Construct the response according to MCP spec
|
||||
result := struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}{
|
||||
Description: prompt.Description,
|
||||
Messages: toTypedPromptMessages(messages),
|
||||
}
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processToolCall processes the tools/call request
|
||||
func (s *sseMcpServer) processToolCall(ctx context.Context, client *mcpClient, req Request) {
|
||||
var toolCallParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
// Handle different types of req.Params
|
||||
// If it's a RawMessage (JSON), unmarshal it
|
||||
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
|
||||
logx.Errorf("Failed to unmarshal tool call params: %v", err)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract progress token if available
|
||||
progressToken := toolCallParams.Meta.ProgressToken
|
||||
|
||||
// Find the requested tool
|
||||
s.toolsLock.Lock()
|
||||
tool, exists := s.tools[toolCallParams.Name]
|
||||
s.toolsLock.Unlock()
|
||||
if !exists {
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Tool '%s' not found",
|
||||
toolCallParams.Name), errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Log parameters before execution
|
||||
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
|
||||
|
||||
// Execute the tool handler with timeout handling
|
||||
var result any
|
||||
var err error
|
||||
|
||||
// Create a channel to receive the result
|
||||
// make sure to have 1 size buffer to avoid channel leak if timeout
|
||||
resultCh := make(chan struct {
|
||||
result any
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
// Execute the tool handler in a goroutine
|
||||
go func() {
|
||||
toolResult, toolErr := tool.Handler(ctx, toolCallParams.Arguments)
|
||||
resultCh <- struct {
|
||||
result any
|
||||
err error
|
||||
}{
|
||||
result: toolResult,
|
||||
err: toolErr,
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either the result or a timeout
|
||||
select {
|
||||
case res := <-resultCh:
|
||||
result = res.result
|
||||
err = res.err
|
||||
case <-ctx.Done():
|
||||
// Handle request timeout
|
||||
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.MessageTimeout, toolCallParams.Name)
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Tool execution timed out", errCodeTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
// Create the base result structure with metadata
|
||||
callToolResult := CallToolResult{
|
||||
Result: Result{},
|
||||
Content: []any{},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
// Add meta information if progress token was provided
|
||||
if progressToken != nil {
|
||||
callToolResult.Result.Meta = map[string]any{
|
||||
progressTokenKey: progressToken,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there was an error during tool execution
|
||||
if err != nil {
|
||||
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
|
||||
// we should report them inside the result with isError=true
|
||||
logx.Errorf("Tool execution reported error: %v", err)
|
||||
|
||||
callToolResult.Content = []any{
|
||||
TextContent{
|
||||
Text: fmt.Sprintf("Error: %v", err),
|
||||
},
|
||||
}
|
||||
callToolResult.IsError = true
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
return
|
||||
}
|
||||
|
||||
// Format the response according to the CallToolResult schema
|
||||
switch v := result.(type) {
|
||||
case string:
|
||||
// Simple string becomes text content
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Text: v,
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case map[string]any:
|
||||
// JSON-like object becomes formatted JSON text
|
||||
jsonStr, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
jsonStr = []byte(err.Error())
|
||||
}
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Text: string(jsonStr),
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case TextContent:
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case ImageContent:
|
||||
callToolResult.Content = append(callToolResult.Content, v)
|
||||
case []any:
|
||||
callToolResult.Content = v
|
||||
case ToolResult:
|
||||
// Handle legacy ToolResult type
|
||||
switch v.Type {
|
||||
case ContentTypeText:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
case ContentTypeImage:
|
||||
if imgData, ok := v.Content.(map[string]any); ok {
|
||||
callToolResult.Content = append(callToolResult.Content, ImageContent{
|
||||
Data: fmt.Sprintf("%v", imgData["data"]),
|
||||
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
|
||||
})
|
||||
}
|
||||
default:
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Text: fmt.Sprintf("%v", v.Content),
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
// For any other type, convert to string
|
||||
callToolResult.Content = append(callToolResult.Content, TextContent{
|
||||
Text: fmt.Sprintf("%v", v),
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
callToolResult.Content = toTypedContents(callToolResult.Content)
|
||||
logx.Infof("Tool call result: %#v", callToolResult)
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, callToolResult)
|
||||
}
|
||||
|
||||
// processResourcesRead processes the resources/read request
|
||||
func (s *sseMcpServer) processResourcesRead(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceReadParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Find resource that matches the URI
|
||||
s.resourcesLock.Lock()
|
||||
resource, exists := s.resources[params.URI]
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If no handler is provided, return an empty content array
|
||||
if resource.Handler == nil {
|
||||
result := ResourceReadResult{
|
||||
Contents: []ResourceContent{
|
||||
{
|
||||
URI: params.URI,
|
||||
MimeType: resource.MimeType,
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute the resource handler
|
||||
content, err := resource.Handler(ctx)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
|
||||
errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the URI is set if not already provided by the handler
|
||||
if len(content.URI) == 0 {
|
||||
content.URI = params.URI
|
||||
}
|
||||
|
||||
// Ensure MimeType is set if available from the resource definition
|
||||
if len(content.MimeType) == 0 && len(resource.MimeType) > 0 {
|
||||
content.MimeType = resource.MimeType
|
||||
}
|
||||
|
||||
// Create response with contents from the handler
|
||||
// The MCP specification requires a contents array
|
||||
result := ResourceReadResult{
|
||||
Contents: []ResourceContent{content},
|
||||
}
|
||||
|
||||
s.sendResponse(ctx, client, req.ID, result)
|
||||
}
|
||||
|
||||
// processResourceSubscribe processes the resources/subscribe request
|
||||
func (s *sseMcpServer) processResourceSubscribe(ctx context.Context, client *mcpClient, req Request) {
|
||||
var params ResourceSubscribeParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
s.sendErrorResponse(ctx, client, req.ID, "Invalid parameters", errCodeInvalidParams)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the resource exists
|
||||
s.resourcesLock.Lock()
|
||||
_, exists := s.resources[params.URI]
|
||||
s.resourcesLock.Unlock()
|
||||
|
||||
if !exists {
|
||||
s.sendErrorResponse(ctx, client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
|
||||
params.URI), errCodeResourceNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Send success response for the subscription
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// processNotificationCancelled processes the notifications/cancelled notification
|
||||
func (s *sseMcpServer) processNotificationCancelled(ctx context.Context, client *mcpClient, req Request) {
|
||||
// Extract the requestId that was canceled
|
||||
type CancelParams struct {
|
||||
RequestId int64 `json:"requestId"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
var params CancelParams
|
||||
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
||||
logx.Errorf("Failed to parse cancellation params: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
|
||||
}
|
||||
|
||||
// processNotificationInitialized processes the notifications/initialized notification
|
||||
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
|
||||
// Mark the client as properly initialized
|
||||
client.initialized = true
|
||||
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
|
||||
}
|
||||
|
||||
// processPing processes the ping request and responds immediately
|
||||
func (s *sseMcpServer) processPing(ctx context.Context, client *mcpClient, req Request) {
|
||||
// A ping request should simply respond with an empty result to confirm the server is alive
|
||||
logx.Infof("Received ping request with ID: %d", req.ID)
|
||||
|
||||
// Send an empty response with client's original request ID
|
||||
s.sendResponse(ctx, client, req.ID, struct{}{})
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response via the SSE channel
|
||||
func (s *sseMcpServer) sendErrorResponse(ctx context.Context, client *mcpClient,
|
||||
id any, message string, code int) {
|
||||
errorResponse := struct {
|
||||
JsonRpc string `json:"jsonrpc"`
|
||||
ID any `json:"id"`
|
||||
Error errorMessage `json:"error"`
|
||||
}{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
Error: errorMessage{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
|
||||
// all fields are primitive types, impossible to fail
|
||||
jsonData, _ := json.Marshal(errorResponse)
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending error for ID %v: %s", id, sseMessage)
|
||||
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending error response with ID %d", client.id, id)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a success response via the SSE channel
|
||||
func (s *sseMcpServer) sendResponse(ctx context.Context, client *mcpClient, id any, result any) {
|
||||
response := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: id,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
s.sendErrorResponse(ctx, client, id, "Failed to marshal response", errCodeInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
// Use CRLF line endings as requested
|
||||
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
|
||||
logx.Infof("Sending response for ID %v: %s", id, sseMessage)
|
||||
|
||||
// cannot receive from ctx.Done() because we're sending to the channel for SSE messages
|
||||
select {
|
||||
case client.channel <- sseMessage:
|
||||
default:
|
||||
// Channel buffer is full, log warning and continue
|
||||
logx.Infof("Client %s channel is full while sending response with ID %v", client.id, id)
|
||||
}
|
||||
}
|
||||
3451
mcp/server_test.go
Normal file
3451
mcp/server_test.go
Normal file
File diff suppressed because it is too large
Load Diff
317
mcp/types.go
Normal file
317
mcp/types.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest"
|
||||
)
|
||||
|
||||
// Cursor is an opaque token used for pagination
|
||||
type Cursor string
|
||||
|
||||
// Request represents a generic MCP request following JSON-RPC 2.0 specification
|
||||
type Request struct {
|
||||
SessionId string `form:"session_id"` // Session identifier for client tracking
|
||||
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
|
||||
ID any `json:"id"` // Request identifier for matching responses
|
||||
Method string `json:"method"` // Method name to invoke
|
||||
Params json.RawMessage `json:"params"` // Parameters for the method
|
||||
}
|
||||
|
||||
func (r Request) isNotification() (bool, error) {
|
||||
switch val := r.ID.(type) {
|
||||
case int:
|
||||
return val == 0, nil
|
||||
case int64:
|
||||
return val == 0, nil
|
||||
case float64:
|
||||
return val == 0.0, nil
|
||||
case string:
|
||||
return len(val) == 0, nil
|
||||
case nil:
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("invalid type %T", val)
|
||||
}
|
||||
}
|
||||
|
||||
type PaginatedParams struct {
|
||||
Cursor string `json:"cursor"`
|
||||
Meta struct {
|
||||
ProgressToken any `json:"progressToken"`
|
||||
} `json:"_meta"`
|
||||
}
|
||||
|
||||
// Result is the base interface for all results
|
||||
type Result struct {
|
||||
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
|
||||
}
|
||||
|
||||
// PaginatedResult is a base for results that support pagination
|
||||
type PaginatedResult struct {
|
||||
Result
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
|
||||
}
|
||||
|
||||
// ListToolsResult represents the response to a tools/list request
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []Tool `json:"tools"` // List of available tools
|
||||
}
|
||||
|
||||
// Message Content Types
|
||||
|
||||
// RoleType represents the sender or recipient of messages in a conversation
|
||||
type RoleType string
|
||||
|
||||
// PromptArgument defines a single argument that can be passed to a prompt
|
||||
type PromptArgument struct {
|
||||
Name string `json:"name"` // Argument name
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Required bool `json:"required,omitempty"` // Whether this argument is required
|
||||
}
|
||||
|
||||
// PromptHandler is a function that dynamically generates prompt content
|
||||
type PromptHandler func(ctx context.Context, args map[string]string) ([]PromptMessage, error)
|
||||
|
||||
// Prompt represents an MCP Prompt definition
|
||||
type Prompt struct {
|
||||
Name string `json:"name"` // Unique identifier for the prompt
|
||||
Description string `json:"description,omitempty"` // Human-readable description
|
||||
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
|
||||
Content string `json:"-"` // Static content (internal use only)
|
||||
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
|
||||
}
|
||||
|
||||
// PromptMessage represents a message in a conversation
|
||||
type PromptMessage struct {
|
||||
Role RoleType `json:"role"` // Message sender role
|
||||
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
|
||||
}
|
||||
|
||||
// TextContent represents text content in a message
|
||||
type TextContent struct {
|
||||
Text string `json:"text"` // The text content
|
||||
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
|
||||
}
|
||||
|
||||
type typedTextContent struct {
|
||||
Type string `json:"type"`
|
||||
TextContent
|
||||
}
|
||||
|
||||
// ImageContent represents image data in a message
|
||||
type ImageContent struct {
|
||||
Data string `json:"data"` // Base64-encoded image data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
|
||||
}
|
||||
|
||||
type typedImageContent struct {
|
||||
Type string `json:"type"`
|
||||
ImageContent
|
||||
}
|
||||
|
||||
// AudioContent represents audio data in a message
|
||||
type AudioContent struct {
|
||||
Data string `json:"data"` // Base64-encoded audio data
|
||||
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
|
||||
}
|
||||
|
||||
type typedAudioContent struct {
|
||||
Type string `json:"type"`
|
||||
AudioContent
|
||||
}
|
||||
|
||||
// FileContent represents file content
|
||||
type FileContent struct {
|
||||
URI string `json:"uri"` // URI identifying the file
|
||||
MimeType string `json:"mimeType"` // MIME type of the file
|
||||
Text string `json:"text"` // File content as text
|
||||
}
|
||||
|
||||
// EmbeddedResource represents a resource embedded in a message
|
||||
type EmbeddedResource struct {
|
||||
Type string `json:"type"` // Always "resource"
|
||||
Resource ResourceContent `json:"resource"` // The resource data
|
||||
}
|
||||
|
||||
// Annotations provides additional metadata for content
|
||||
type Annotations struct {
|
||||
Audience []RoleType `json:"audience,omitempty"` // Who should see this content
|
||||
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
|
||||
}
|
||||
|
||||
// Tool-related Types
|
||||
|
||||
// ToolHandler is a function that handles tool calls
|
||||
type ToolHandler func(ctx context.Context, params map[string]any) (any, error)
|
||||
|
||||
// Tool represents a Model Context Protocol Tool definition
|
||||
type Tool struct {
|
||||
Name string `json:"name"` // Unique identifier for the tool
|
||||
Description string `json:"description"` // Human-readable description
|
||||
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
|
||||
Handler ToolHandler `json:"-"` // Not sent to clients
|
||||
}
|
||||
|
||||
// InputSchema represents tool's input schema in JSON Schema format
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]any `json:"properties"` // Property definitions
|
||||
Required []string `json:"required,omitempty"` // List of required properties
|
||||
}
|
||||
|
||||
// CallToolResult represents a tool call result that conforms to the MCP schema
|
||||
type CallToolResult struct {
|
||||
Result
|
||||
Content []any `json:"content"` // Content items (text, images, etc.)
|
||||
IsError bool `json:"isError,omitempty"` // True if tool execution failed
|
||||
}
|
||||
|
||||
// Resource represents a Model Context Protocol Resource definition
|
||||
type Resource struct {
|
||||
URI string `json:"uri"` // Unique resource identifier (RFC3986)
|
||||
Name string `json:"name"` // Human-readable name
|
||||
Description string `json:"description,omitempty"` // Optional description
|
||||
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
|
||||
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
|
||||
}
|
||||
|
||||
// ResourceHandler is a function that handles resource read requests
|
||||
type ResourceHandler func(ctx context.Context) (ResourceContent, error)
|
||||
|
||||
// ResourceContent represents the content of a resource
|
||||
type ResourceContent struct {
|
||||
URI string `json:"uri"` // Resource URI (required)
|
||||
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
|
||||
Text string `json:"text,omitempty"` // Text content (if available)
|
||||
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
|
||||
}
|
||||
|
||||
// ResourcesListResult represents the response to a resources/list request
|
||||
type ResourcesListResult struct {
|
||||
PaginatedResult
|
||||
Resources []Resource `json:"resources"` // List of available resources
|
||||
}
|
||||
|
||||
// ResourceReadParams contains parameters for a resources/read request
|
||||
type ResourceReadParams struct {
|
||||
URI string `json:"uri"` // URI of the resource to read
|
||||
}
|
||||
|
||||
// ResourceReadResult contains the result of a resources/read request
|
||||
type ResourceReadResult struct {
|
||||
Result
|
||||
Contents []ResourceContent `json:"contents"` // Array of resource content
|
||||
}
|
||||
|
||||
// ResourceSubscribeParams contains parameters for a resources/subscribe request
|
||||
type ResourceSubscribeParams struct {
|
||||
URI string `json:"uri"` // URI of the resource to subscribe to
|
||||
}
|
||||
|
||||
// ResourceUpdateNotification represents a notification about a resource update
|
||||
type ResourceUpdateNotification struct {
|
||||
URI string `json:"uri"` // URI of the updated resource
|
||||
Content ResourceContent `json:"content"` // New resource content
|
||||
}
|
||||
|
||||
// Client and Server Types
|
||||
|
||||
// mcpClient represents an SSE client connection
|
||||
type mcpClient struct {
|
||||
id string // Unique client identifier
|
||||
channel chan string // Channel for sending SSE messages
|
||||
initialized bool // Tracks if client has sent notifications/initialized
|
||||
}
|
||||
|
||||
// McpServer defines the interface for Model Context Protocol servers
|
||||
type McpServer interface {
|
||||
Start()
|
||||
Stop()
|
||||
RegisterTool(tool Tool) error
|
||||
RegisterPrompt(prompt Prompt)
|
||||
RegisterResource(resource Resource)
|
||||
}
|
||||
|
||||
// sseMcpServer implements the McpServer interface using SSE
|
||||
type sseMcpServer struct {
|
||||
conf McpConf
|
||||
server *rest.Server
|
||||
clients map[string]*mcpClient
|
||||
clientsLock sync.Mutex
|
||||
tools map[string]Tool
|
||||
toolsLock sync.Mutex
|
||||
prompts map[string]Prompt
|
||||
promptsLock sync.Mutex
|
||||
resources map[string]Resource
|
||||
resourcesLock sync.Mutex
|
||||
}
|
||||
|
||||
// Response Types
|
||||
|
||||
// errorObj represents a JSON-RPC error object
|
||||
type errorObj struct {
|
||||
Code int `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// Response represents a JSON-RPC response
|
||||
type Response struct {
|
||||
JsonRpc string `json:"jsonrpc"` // Always "2.0"
|
||||
ID any `json:"id"` // Same as request ID
|
||||
Result any `json:"result"` // Result object (null if error)
|
||||
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
|
||||
}
|
||||
|
||||
// Server Information Types
|
||||
|
||||
// serverInfo provides information about the server
|
||||
type serverInfo struct {
|
||||
Name string `json:"name"` // Server name
|
||||
Version string `json:"version"` // Server version
|
||||
}
|
||||
|
||||
// capabilities describes the server's capabilities
|
||||
type capabilities struct {
|
||||
Logging struct{} `json:"logging"`
|
||||
Prompts struct {
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
|
||||
} `json:"prompts"`
|
||||
Resources struct {
|
||||
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
|
||||
} `json:"resources"`
|
||||
Tools struct {
|
||||
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
|
||||
} `json:"tools"`
|
||||
}
|
||||
|
||||
// initializationResponse is sent in response to an initialize request
|
||||
type initializationResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"` // Protocol version
|
||||
Capabilities capabilities `json:"capabilities"` // Server capabilities
|
||||
ServerInfo serverInfo `json:"serverInfo"` // Server information
|
||||
}
|
||||
|
||||
// ToolCallParams contains the parameters for a tool call
|
||||
type ToolCallParams struct {
|
||||
Name string `json:"name"` // Tool name
|
||||
Parameters map[string]any `json:"parameters"` // Tool parameters
|
||||
}
|
||||
|
||||
// ToolResult contains the result of a tool execution
|
||||
type ToolResult struct {
|
||||
Type string `json:"type"` // Content type (text, image, etc.)
|
||||
Content any `json:"content"` // Result content
|
||||
}
|
||||
|
||||
// errorMessage represents a detailed error message
|
||||
type errorMessage struct {
|
||||
Code int `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
Data any `json:",omitempty"` // Additional error data
|
||||
}
|
||||
271
mcp/types_test.go
Normal file
271
mcp/types_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestResponseMarshaling(t *testing.T) {
|
||||
// Test that the Response struct marshals correctly
|
||||
resp := Response{
|
||||
JsonRpc: "2.0",
|
||||
ID: 123,
|
||||
Result: map[string]string{
|
||||
"key": "value",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||
assert.Contains(t, string(data), `"id":123`)
|
||||
assert.Contains(t, string(data), `"result":{"key":"value"}`)
|
||||
|
||||
// Test response with error
|
||||
respWithError := Response{
|
||||
JsonRpc: "2.0",
|
||||
ID: 456,
|
||||
Error: &errorObj{
|
||||
Code: errCodeInvalidRequest,
|
||||
Message: "Invalid Request",
|
||||
},
|
||||
}
|
||||
|
||||
data, err = json.Marshal(respWithError)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
|
||||
assert.Contains(t, string(data), `"id":456`)
|
||||
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
|
||||
}
|
||||
|
||||
func TestRequestUnmarshaling(t *testing.T) {
|
||||
// Test that the Request struct unmarshals correctly
|
||||
jsonStr := `{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 789,
|
||||
"method": "test_method",
|
||||
"params": {"key": "value"}
|
||||
}`
|
||||
|
||||
var req Request
|
||||
err := json.Unmarshal([]byte(jsonStr), &req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.0", req.JsonRpc)
|
||||
assert.Equal(t, float64(789), req.ID)
|
||||
assert.Equal(t, "test_method", req.Method)
|
||||
|
||||
// Check params unmarshaled correctly
|
||||
var params map[string]string
|
||||
err = json.Unmarshal(req.Params, ¶ms)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "value", params["key"])
|
||||
}
|
||||
|
||||
func TestToolStructs(t *testing.T) {
|
||||
// Test Tool struct
|
||||
tool := Tool{
|
||||
Name: "test.tool",
|
||||
Description: "A test tool",
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Input parameter",
|
||||
},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
Handler: func(ctx context.Context, params map[string]any) (any, error) {
|
||||
return "result", nil
|
||||
},
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.tool", tool.Name)
|
||||
assert.Equal(t, "A test tool", tool.Description)
|
||||
assert.Equal(t, "object", tool.InputSchema.Type)
|
||||
assert.Contains(t, tool.InputSchema.Properties, "input")
|
||||
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
|
||||
assert.True(t, ok, "Property should be a map")
|
||||
assert.Equal(t, "string", propMap["type"])
|
||||
assert.NotNil(t, tool.Handler)
|
||||
|
||||
// Verify JSON marshalling (which should exclude Handler function)
|
||||
data, err := json.Marshal(tool)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.tool"`)
|
||||
assert.Contains(t, string(data), `"description":"A test tool"`)
|
||||
assert.Contains(t, string(data), `"inputSchema":`)
|
||||
assert.NotContains(t, string(data), `"Handler":`)
|
||||
}
|
||||
|
||||
func TestPromptStructs(t *testing.T) {
|
||||
// Test Prompt struct
|
||||
prompt := Prompt{
|
||||
Name: "test.prompt",
|
||||
Description: "A test prompt description",
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.prompt", prompt.Name)
|
||||
assert.Equal(t, "A test prompt description", prompt.Description)
|
||||
|
||||
// Verify JSON marshalling
|
||||
data, err := json.Marshal(prompt)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.prompt"`)
|
||||
assert.Contains(t, string(data), `"description":"A test prompt description"`)
|
||||
}
|
||||
|
||||
func TestResourceStructs(t *testing.T) {
|
||||
// Test Resource struct
|
||||
resource := Resource{
|
||||
Name: "test.resource",
|
||||
URI: "http://example.com/resource",
|
||||
Description: "A test resource",
|
||||
}
|
||||
|
||||
// Verify fields are correct
|
||||
assert.Equal(t, "test.resource", resource.Name)
|
||||
assert.Equal(t, "http://example.com/resource", resource.URI)
|
||||
assert.Equal(t, "A test resource", resource.Description)
|
||||
|
||||
// Verify JSON marshalling
|
||||
data, err := json.Marshal(resource)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"name":"test.resource"`)
|
||||
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
|
||||
assert.Contains(t, string(data), `"description":"A test resource"`)
|
||||
}
|
||||
|
||||
func TestContentTypes(t *testing.T) {
|
||||
// Test TextContent
|
||||
textContent := TextContent{
|
||||
Text: "Sample text",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(1.0),
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(textContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"text":"Sample text"`)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
|
||||
assert.Contains(t, string(data), `"priority":1`)
|
||||
|
||||
// Test ImageContent
|
||||
imageContent := ImageContent{
|
||||
Data: "base64data",
|
||||
MimeType: "image/png",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(imageContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"data":"base64data"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"image/png"`)
|
||||
|
||||
// Test AudioContent
|
||||
audioContent := AudioContent{
|
||||
Data: "base64audio",
|
||||
MimeType: "audio/mp3",
|
||||
}
|
||||
|
||||
data, err = json.Marshal(audioContent)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"data":"base64audio"`)
|
||||
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
|
||||
}
|
||||
|
||||
func TestCallToolResult(t *testing.T) {
|
||||
// Test CallToolResult
|
||||
result := CallToolResult{
|
||||
Result: Result{
|
||||
Meta: map[string]any{
|
||||
"progressToken": "token123",
|
||||
},
|
||||
},
|
||||
Content: []interface{}{
|
||||
TextContent{
|
||||
Text: "Sample result",
|
||||
},
|
||||
},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
|
||||
assert.Contains(t, string(data), `"content":[{"text":"Sample result"}]`)
|
||||
assert.NotContains(t, string(data), `"isError":`)
|
||||
}
|
||||
|
||||
func TestRequest_isNotification(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id any
|
||||
want bool
|
||||
wantErr error
|
||||
}{
|
||||
// integer test cases
|
||||
{name: "int zero", id: 0, want: true, wantErr: nil},
|
||||
{name: "int non-zero", id: 1, want: false, wantErr: nil},
|
||||
{name: "int64 zero", id: int64(0), want: true, wantErr: nil},
|
||||
{name: "int64 max", id: int64(9223372036854775807), want: false, wantErr: nil},
|
||||
|
||||
// floating point number test cases
|
||||
{name: "float64 zero", id: float64(0.0), want: true, wantErr: nil},
|
||||
{name: "float64 positive", id: float64(0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 negative", id: float64(-0.000001), want: false, wantErr: nil},
|
||||
{name: "float64 epsilon", id: float64(1e-300), want: false, wantErr: nil},
|
||||
|
||||
// string test cases
|
||||
{name: "empty string", id: "", want: true, wantErr: nil},
|
||||
{name: "non-empty string", id: "abc", want: false, wantErr: nil},
|
||||
{name: "space string", id: " ", want: false, wantErr: nil},
|
||||
{name: "unicode string", id: "こんにちは", want: false, wantErr: nil},
|
||||
|
||||
// special cases
|
||||
{name: "nil", id: nil, want: true, wantErr: nil},
|
||||
|
||||
// logical type test cases
|
||||
{name: "bool true", id: true, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "bool false", id: false, want: false, wantErr: errors.New("invalid type bool")},
|
||||
{name: "struct type", id: struct{}{}, want: false, wantErr: errors.New("invalid type struct {}")},
|
||||
{name: "slice type", id: []int{1, 2, 3}, want: false, wantErr: errors.New("invalid type []int")},
|
||||
{name: "map type", id: map[string]int{"a": 1}, want: false, wantErr: errors.New("invalid type map[string]int")},
|
||||
{name: "pointer type", id: new(int), want: false, wantErr: errors.New("invalid type *int")},
|
||||
{name: "func type", id: func() {}, want: false, wantErr: errors.New("invalid type func()")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := Request{
|
||||
SessionId: "test-session",
|
||||
JsonRpc: "2.0",
|
||||
ID: tt.id,
|
||||
Method: "testMethod",
|
||||
Params: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
got, err := req.isNotification()
|
||||
|
||||
if (err != nil) != (tt.wantErr != nil) {
|
||||
t.Fatalf("error presence mismatch: got error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && tt.wantErr != nil && err.Error() != tt.wantErr.Error() {
|
||||
t.Fatalf("error message mismatch:\ngot %q\nwant %q", err.Error(), tt.wantErr.Error())
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("isNotification() = %v, want %v for ID %v (%T)", got, tt.want, tt.id, tt.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
107
mcp/util.go
Normal file
107
mcp/util.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package mcp
|
||||
|
||||
import "fmt"
|
||||
|
||||
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
|
||||
func formatSSEMessage(event string, data []byte) string {
|
||||
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
|
||||
}
|
||||
|
||||
// ptr is a helper function to get a pointer to a value
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func toTypedContents(contents []any) []any {
|
||||
typedContents := make([]any, len(contents))
|
||||
|
||||
for i, content := range contents {
|
||||
switch v := content.(type) {
|
||||
case TextContent:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
}
|
||||
case ImageContent:
|
||||
typedContents[i] = typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
}
|
||||
case AudioContent:
|
||||
typedContents[i] = typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
}
|
||||
default:
|
||||
typedContents[i] = typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedContents
|
||||
}
|
||||
|
||||
func toTypedPromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
typedMessages := make([]PromptMessage, len(messages))
|
||||
|
||||
for i, msg := range messages {
|
||||
switch v := msg.Content.(type) {
|
||||
case TextContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: v,
|
||||
},
|
||||
}
|
||||
case ImageContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedImageContent{
|
||||
Type: ContentTypeImage,
|
||||
ImageContent: v,
|
||||
},
|
||||
}
|
||||
case AudioContent:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedAudioContent{
|
||||
Type: ContentTypeAudio,
|
||||
AudioContent: v,
|
||||
},
|
||||
}
|
||||
default:
|
||||
typedMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: typedTextContent{
|
||||
Type: ContentTypeText,
|
||||
TextContent: TextContent{
|
||||
Text: fmt.Sprintf("Unknown content type: %T", v),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return typedMessages
|
||||
}
|
||||
|
||||
// validatePromptArguments checks if all required arguments are provided
|
||||
// Returns a list of missing required arguments
|
||||
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
|
||||
var missingArgs []string
|
||||
|
||||
for _, arg := range prompt.Arguments {
|
||||
if arg.Required {
|
||||
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
|
||||
missingArgs = append(missingArgs, arg.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return missingArgs
|
||||
}
|
||||
274
mcp/util_test.go
Normal file
274
mcp/util_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Type string
|
||||
Data map[string]any
|
||||
}
|
||||
|
||||
func parseEvent(input string) (*Event, error) {
|
||||
var evt Event
|
||||
var dataStr string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(dataStr) > 0 {
|
||||
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &evt, nil
|
||||
}
|
||||
|
||||
// TestToTypedPromptMessages tests the toTypedPromptMessages function
|
||||
func TestToTypedPromptMessages(t *testing.T) {
|
||||
// Test with multiple message types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Hello, this is a text message",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.8),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleAssistant,
|
||||
Content: ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/jpeg",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/mp3",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "system",
|
||||
Content: "This is a simple string that should be handled as unknown type",
|
||||
},
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedPromptMessages(messages)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of messages")
|
||||
|
||||
// Validate first message (TextContent)
|
||||
msg := result[0]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion using reflection since Content is an interface
|
||||
typed, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text message", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.8, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second message (ImageContent)
|
||||
msg = result[1]
|
||||
assert.Equal(t, RoleAssistant, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for image content
|
||||
typedImg, ok := msg.Content.(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/jpeg", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third message (AudioContent)
|
||||
msg = result[2]
|
||||
assert.Equal(t, RoleUser, msg.Role, "Role should be preserved")
|
||||
|
||||
// Type assertion for audio content
|
||||
typedAudio, ok := msg.Content.(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/mp3", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth message (unknown type converted to TextContent)
|
||||
msg = result[3]
|
||||
assert.Equal(t, RoleType("system"), msg.Role, "Role should be preserved")
|
||||
|
||||
// Should be converted to a typedTextContent with error message
|
||||
typedUnknown, ok := msg.Content.(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
messages := []PromptMessage{}
|
||||
result := toTypedPromptMessages(messages)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
messages := []PromptMessage{
|
||||
{
|
||||
Role: RoleUser,
|
||||
Content: TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedPromptMessages(messages)
|
||||
require.Len(t, result, 1, "Should return one message")
|
||||
|
||||
typed, ok := result[0].Content.(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
}
|
||||
|
||||
// TestToTypedContents tests the toTypedContents function
|
||||
func TestToTypedContents(t *testing.T) {
|
||||
// Test with multiple content types in one test
|
||||
t.Run("MixedContentTypes", func(t *testing.T) {
|
||||
// Create test data with different content types
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Hello, this is a text content",
|
||||
Annotations: &Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
Priority: ptr(0.7),
|
||||
},
|
||||
},
|
||||
ImageContent{
|
||||
Data: "base64ImageData",
|
||||
MimeType: "image/png",
|
||||
},
|
||||
AudioContent{
|
||||
Data: "base64AudioData",
|
||||
MimeType: "audio/wav",
|
||||
},
|
||||
"This is a simple string that should be handled as unknown type",
|
||||
}
|
||||
|
||||
// Call the function
|
||||
result := toTypedContents(contents)
|
||||
|
||||
// Validate results
|
||||
require.Len(t, result, 4, "Should return the same number of contents")
|
||||
|
||||
// Validate first content (TextContent)
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Hello, this is a text content", typed.Text, "Text content should be preserved")
|
||||
require.NotNil(t, typed.Annotations, "Annotations should be preserved")
|
||||
assert.Equal(t, []RoleType{RoleUser, RoleAssistant}, typed.Annotations.Audience, "Audience should be preserved")
|
||||
require.NotNil(t, typed.Annotations.Priority, "Priority should be preserved")
|
||||
assert.Equal(t, 0.7, *typed.Annotations.Priority, "Priority value should be preserved")
|
||||
|
||||
// Validate second content (ImageContent)
|
||||
typedImg, ok := result[1].(typedImageContent)
|
||||
require.True(t, ok, "Should be typedImageContent")
|
||||
assert.Equal(t, ContentTypeImage, typedImg.Type, "Type should be image")
|
||||
assert.Equal(t, "base64ImageData", typedImg.Data, "Image data should be preserved")
|
||||
assert.Equal(t, "image/png", typedImg.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate third content (AudioContent)
|
||||
typedAudio, ok := result[2].(typedAudioContent)
|
||||
require.True(t, ok, "Should be typedAudioContent")
|
||||
assert.Equal(t, ContentTypeAudio, typedAudio.Type, "Type should be audio")
|
||||
assert.Equal(t, "base64AudioData", typedAudio.Data, "Audio data should be preserved")
|
||||
assert.Equal(t, "audio/wav", typedAudio.MimeType, "MimeType should be preserved")
|
||||
|
||||
// Validate fourth content (unknown type converted to TextContent)
|
||||
typedUnknown, ok := result[3].(typedTextContent)
|
||||
require.True(t, ok, "Unknown content should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typedUnknown.Type, "Type should be text")
|
||||
assert.Contains(t, typedUnknown.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typedUnknown.Text, "string", "Should mention the actual type")
|
||||
})
|
||||
|
||||
// Test empty input
|
||||
t.Run("EmptyInput", func(t *testing.T) {
|
||||
contents := []any{}
|
||||
result := toTypedContents(contents)
|
||||
assert.Empty(t, result, "Should return empty slice for empty input")
|
||||
})
|
||||
|
||||
// Test with nil annotations
|
||||
t.Run("NilAnnotations", func(t *testing.T) {
|
||||
contents := []any{
|
||||
TextContent{
|
||||
Text: "Text with nil annotations",
|
||||
Annotations: nil,
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Should be typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Equal(t, "Text with nil annotations", typed.Text, "Text content should be preserved")
|
||||
assert.Nil(t, typed.Annotations, "Nil annotations should be preserved as nil")
|
||||
})
|
||||
|
||||
// Test with custom struct (should be handled as unknown type)
|
||||
t.Run("CustomStruct", func(t *testing.T) {
|
||||
type CustomContent struct {
|
||||
Data string
|
||||
}
|
||||
|
||||
contents := []any{
|
||||
CustomContent{
|
||||
Data: "custom data",
|
||||
},
|
||||
}
|
||||
|
||||
result := toTypedContents(contents)
|
||||
require.Len(t, result, 1, "Should return one content")
|
||||
|
||||
typed, ok := result[0].(typedTextContent)
|
||||
require.True(t, ok, "Custom struct should be converted to typedTextContent")
|
||||
assert.Equal(t, ContentTypeText, typed.Type, "Type should be text")
|
||||
assert.Contains(t, typed.Text, "Unknown content type:", "Should contain error about unknown type")
|
||||
assert.Contains(t, typed.Text, "CustomContent", "Should mention the actual type")
|
||||
})
|
||||
}
|
||||
149
mcp/vars.go
Normal file
149
mcp/vars.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
// Protocol constants
|
||||
const (
|
||||
// JSON-RPC version as defined in the specification
|
||||
jsonRpcVersion = "2.0"
|
||||
|
||||
// Session identifier key used in request URLs
|
||||
sessionIdKey = "session_id"
|
||||
|
||||
// progressTokenKey is used to track progress of long-running tasks
|
||||
progressTokenKey = "progressToken"
|
||||
)
|
||||
|
||||
// Server-Sent Events (SSE) event types
|
||||
const (
|
||||
// Standard message event for JSON-RPC responses
|
||||
eventMessage = "message"
|
||||
|
||||
// Endpoint event for sending endpoint URL to clients
|
||||
eventEndpoint = "endpoint"
|
||||
)
|
||||
|
||||
// Content type identifiers
|
||||
const (
|
||||
// ContentTypeObject is object content type
|
||||
ContentTypeObject = "object"
|
||||
|
||||
// ContentTypeText is text content type
|
||||
ContentTypeText = "text"
|
||||
|
||||
// ContentTypeImage is image content type
|
||||
ContentTypeImage = "image"
|
||||
|
||||
// ContentTypeAudio is audio content type
|
||||
ContentTypeAudio = "audio"
|
||||
|
||||
// ContentTypeResource is resource content type
|
||||
ContentTypeResource = "resource"
|
||||
)
|
||||
|
||||
// Collection keys for broadcast events
|
||||
const (
|
||||
// Key for prompts collection
|
||||
keyPrompts = "prompts"
|
||||
|
||||
// Key for resources collection
|
||||
keyResources = "resources"
|
||||
|
||||
// Key for tools collection
|
||||
keyTools = "tools"
|
||||
)
|
||||
|
||||
// JSON-RPC error codes
|
||||
// Standard error codes from JSON-RPC 2.0 spec
|
||||
const (
|
||||
// Invalid JSON was received by the server
|
||||
errCodeInvalidRequest = -32600
|
||||
|
||||
// The method does not exist / is not available
|
||||
errCodeMethodNotFound = -32601
|
||||
|
||||
// Invalid method parameter(s)
|
||||
errCodeInvalidParams = -32602
|
||||
|
||||
// Internal JSON-RPC error
|
||||
errCodeInternalError = -32603
|
||||
|
||||
// Tool execution timed out
|
||||
errCodeTimeout = -32001
|
||||
|
||||
// Resource not found error
|
||||
errCodeResourceNotFound = -32002
|
||||
|
||||
// Client hasn't completed initialization
|
||||
errCodeClientNotInitialized = -32800
|
||||
)
|
||||
|
||||
// User and assistant role definitions
|
||||
const (
|
||||
// RoleUser is the "user" role - the entity asking questions
|
||||
RoleUser RoleType = "user"
|
||||
|
||||
// RoleAssistant is the "assistant" role - the entity providing responses
|
||||
RoleAssistant RoleType = "assistant"
|
||||
)
|
||||
|
||||
// Method names as defined in the MCP specification
|
||||
const (
|
||||
// Initialize the connection between client and server
|
||||
methodInitialize = "initialize"
|
||||
|
||||
// List available tools
|
||||
methodToolsList = "tools/list"
|
||||
|
||||
// Call a specific tool
|
||||
methodToolsCall = "tools/call"
|
||||
|
||||
// List available prompts
|
||||
methodPromptsList = "prompts/list"
|
||||
|
||||
// Get a specific prompt
|
||||
methodPromptsGet = "prompts/get"
|
||||
|
||||
// List available resources
|
||||
methodResourcesList = "resources/list"
|
||||
|
||||
// Read a specific resource
|
||||
methodResourcesRead = "resources/read"
|
||||
|
||||
// Subscribe to resource updates
|
||||
methodResourcesSubscribe = "resources/subscribe"
|
||||
|
||||
// Simple ping to check server availability
|
||||
methodPing = "ping"
|
||||
|
||||
// Notification that client is fully initialized
|
||||
methodNotificationsInitialized = "notifications/initialized"
|
||||
|
||||
// Notification that a request was canceled
|
||||
methodNotificationsCancelled = "notifications/cancelled"
|
||||
)
|
||||
|
||||
// Event names for Server-Sent Events (SSE)
|
||||
const (
|
||||
// Notification of tool list changes
|
||||
eventToolsListChanged = "tools/list_changed"
|
||||
|
||||
// Notification of prompt list changes
|
||||
eventPromptsListChanged = "prompts/list_changed"
|
||||
|
||||
// Notification of resource list changes
|
||||
eventResourcesListChanged = "resources/list_changed"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default channel size for events
|
||||
eventChanSize = 10
|
||||
|
||||
// Default ping interval for checking connection availability
|
||||
// use syncx.ForAtomicDuration to ensure atomicity in test race
|
||||
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
|
||||
)
|
||||
210
mcp/vars_test.go
Normal file
210
mcp/vars_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestErrorCodes ensures error codes are applied correctly in error responses
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "invalid request error",
|
||||
code: errCodeInvalidRequest,
|
||||
message: "Invalid request",
|
||||
expected: `"code":-32600`,
|
||||
},
|
||||
{
|
||||
name: "method not found error",
|
||||
code: errCodeMethodNotFound,
|
||||
message: "Method not found",
|
||||
expected: `"code":-32601`,
|
||||
},
|
||||
{
|
||||
name: "invalid params error",
|
||||
code: errCodeInvalidParams,
|
||||
message: "Invalid parameters",
|
||||
expected: `"code":-32602`,
|
||||
},
|
||||
{
|
||||
name: "internal error",
|
||||
code: errCodeInternalError,
|
||||
message: "Internal server error",
|
||||
expected: `"code":-32603`,
|
||||
},
|
||||
{
|
||||
name: "timeout error",
|
||||
code: errCodeTimeout,
|
||||
message: "Operation timed out",
|
||||
expected: `"code":-32001`,
|
||||
},
|
||||
{
|
||||
name: "resource not found error",
|
||||
code: errCodeResourceNotFound,
|
||||
message: "Resource not found",
|
||||
expected: `"code":-32002`,
|
||||
},
|
||||
{
|
||||
name: "client not initialized error",
|
||||
code: errCodeClientNotInitialized,
|
||||
message: "Client not initialized",
|
||||
expected: `"code":-32800`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Error: &errorObj{
|
||||
Code: tc.code,
|
||||
Message: tc.message,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
|
||||
assert.Contains(t, string(data), tc.message, "Error message should be included")
|
||||
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
|
||||
func TestJsonRpcVersion(t *testing.T) {
|
||||
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
|
||||
|
||||
// Test that it's used in responses
|
||||
resp := Response{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Result: "test",
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
|
||||
|
||||
// Test that it's expected in requests
|
||||
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
|
||||
var req Request
|
||||
err = json.Unmarshal([]byte(reqStr), &req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
|
||||
}
|
||||
|
||||
// TestSessionIdKey ensures session ID extraction works correctly
|
||||
func TestSessionIdKey(t *testing.T) {
|
||||
// Create a mock server implementation
|
||||
mock := newMockMcpServer(t)
|
||||
defer mock.shutdown()
|
||||
|
||||
// Verify the key constant
|
||||
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
|
||||
|
||||
// Test that session ID is extracted correctly
|
||||
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
|
||||
|
||||
// Since the mock server is using the same session key logic,
|
||||
// we can test this by accessing the request query parameters directly
|
||||
sessionID := mockR.URL.Query().Get(sessionIdKey)
|
||||
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
|
||||
}
|
||||
|
||||
// TestEventTypes ensures event types are set correctly in SSE responses
|
||||
func TestEventTypes(t *testing.T) {
|
||||
// Test message event
|
||||
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
|
||||
|
||||
// Test endpoint event
|
||||
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
|
||||
|
||||
// Verify them in an actual SSE format string
|
||||
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
|
||||
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
|
||||
|
||||
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
|
||||
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
|
||||
}
|
||||
|
||||
// TestCollectionKeys checks that collection keys are used correctly
|
||||
func TestCollectionKeys(t *testing.T) {
|
||||
// Verify collection key constants
|
||||
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
|
||||
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
|
||||
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
|
||||
}
|
||||
|
||||
// TestRoleTypes checks that role types are used correctly
|
||||
func TestRoleTypes(t *testing.T) {
|
||||
// Test in annotations
|
||||
annotations := Annotations{
|
||||
Audience: []RoleType{RoleUser, RoleAssistant},
|
||||
}
|
||||
data, err := json.Marshal(annotations)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
|
||||
}
|
||||
|
||||
// TestMethodNames checks that method names are used correctly
|
||||
func TestMethodNames(t *testing.T) {
|
||||
// Verify method name constants
|
||||
methods := map[string]string{
|
||||
"initialize": methodInitialize,
|
||||
"tools/list": methodToolsList,
|
||||
"tools/call": methodToolsCall,
|
||||
"prompts/list": methodPromptsList,
|
||||
"prompts/get": methodPromptsGet,
|
||||
"resources/list": methodResourcesList,
|
||||
"resources/read": methodResourcesRead,
|
||||
"resources/subscribe": methodResourcesSubscribe,
|
||||
"ping": methodPing,
|
||||
"notifications/initialized": methodNotificationsInitialized,
|
||||
"notifications/cancelled": methodNotificationsCancelled,
|
||||
}
|
||||
|
||||
for expected, actual := range methods {
|
||||
assert.Equal(t, expected, actual, "Method name should be "+expected)
|
||||
}
|
||||
|
||||
// Test in a request
|
||||
for methodName := range methods {
|
||||
req := Request{
|
||||
JsonRpc: jsonRpcVersion,
|
||||
ID: int64(1),
|
||||
Method: methodName,
|
||||
}
|
||||
data, err := json.Marshal(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEventNames checks that event names are used correctly
|
||||
func TestEventNames(t *testing.T) {
|
||||
// Verify event name constants
|
||||
events := map[string]string{
|
||||
"tools/list_changed": eventToolsListChanged,
|
||||
"prompts/list_changed": eventPromptsListChanged,
|
||||
"resources/list_changed": eventResourcesListChanged,
|
||||
}
|
||||
|
||||
for expected, actual := range events {
|
||||
assert.Equal(t, expected, actual, "Event name should be "+expected)
|
||||
}
|
||||
|
||||
// Test event names in SSE format
|
||||
for _, eventName := range events {
|
||||
sseEvent := "event: " + eventName + "\ndata: test\n\n"
|
||||
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
|
||||
}
|
||||
}
|
||||
@@ -301,6 +301,9 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
|
||||
>102. 深圳市兴海物联科技有限公司
|
||||
>103. 爱芯元智半导体股份有限公司
|
||||
>104. 杭州升恒科技有限公司
|
||||
>105. 昆仑万维科技股份有限公司
|
||||
>106. 无锡盛算信息技术有限公司
|
||||
>107. 深圳市聚货通信息科技有限公司
|
||||
|
||||
如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。
|
||||
|
||||
|
||||
@@ -251,7 +251,3 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/
|
||||
## Give a Star! ⭐
|
||||
|
||||
If you like this project or are using it to learn or start your own solution, give it a star to get updates on new releases. Your support matters!
|
||||
|
||||
## Buy me a coffee
|
||||
|
||||
<a href="https://www.buymeacoffee.com/kevwan" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 60px !important;width: 217px !important;" ></a>
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/rest/handler"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/rest/internal"
|
||||
"github.com/zeromicro/go-zero/rest/internal/header"
|
||||
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||
)
|
||||
|
||||
@@ -54,6 +55,9 @@ func newEngine(c RestConf) *engine {
|
||||
}
|
||||
|
||||
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
if r.sse {
|
||||
r.routes = buildSSERoutes(r.routes)
|
||||
}
|
||||
ng.routes = append(ng.routes, r)
|
||||
|
||||
// need to guarantee the timeout is the max of all routes
|
||||
@@ -63,6 +67,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildSSERoutes(routes []Route) []Route {
|
||||
for i, route := range routes {
|
||||
h := route.Handler
|
||||
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
|
||||
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
|
||||
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
|
||||
h(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
|
||||
verifier func(chain.Chain) chain.Chain) chain.Chain {
|
||||
if fr.jwt.enabled {
|
||||
@@ -210,6 +228,10 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
||||
return ng.shedder
|
||||
}
|
||||
|
||||
func (ng *engine) hasTimeout() bool {
|
||||
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||
}
|
||||
|
||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -336,16 +358,17 @@ func (ng *engine) use(middleware Middleware) {
|
||||
|
||||
func (ng *engine) withTimeout() internal.StartOption {
|
||||
return func(svr *http.Server) {
|
||||
timeout := ng.timeout
|
||||
if timeout > 0 {
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * timeout / 10
|
||||
if !ng.hasTimeout() {
|
||||
return
|
||||
}
|
||||
|
||||
// factor 0.8, to avoid clients send longer content-length than the actual content,
|
||||
// without this timeout setting, the server will time out and respond 503 Service Unavailable,
|
||||
// which triggers the circuit breaker.
|
||||
svr.ReadTimeout = 4 * ng.timeout / 5
|
||||
// factor 1.1, to avoid servers don't have enough time to write responses.
|
||||
// setting the factor less than 1.0 may lead clients not receiving the responses.
|
||||
svr.WriteTimeout = 11 * ng.timeout / 10
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -394,7 +394,12 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ng := newEngine(RestConf{Timeout: test.timeout})
|
||||
ng := newEngine(RestConf{
|
||||
Timeout: test.timeout,
|
||||
Middlewares: MiddlewaresConf{
|
||||
Timeout: true,
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
|
||||
@@ -406,6 +411,62 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_ReadWriteTimeout(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout int64
|
||||
middleware bool
|
||||
}{
|
||||
{
|
||||
name: "0/false",
|
||||
timeout: 0,
|
||||
middleware: false,
|
||||
},
|
||||
{
|
||||
name: "0/true",
|
||||
timeout: 0,
|
||||
middleware: true,
|
||||
},
|
||||
{
|
||||
name: "set/false",
|
||||
timeout: 1000,
|
||||
middleware: false,
|
||||
},
|
||||
{
|
||||
name: "both set",
|
||||
timeout: 1000,
|
||||
middleware: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ng := newEngine(RestConf{
|
||||
Timeout: test.timeout,
|
||||
Middlewares: MiddlewaresConf{
|
||||
Timeout: test.middleware,
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
|
||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||
|
||||
if test.timeout > 0 && test.middleware {
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*11/10, svr.WriteTimeout)
|
||||
} else {
|
||||
assert.Equal(t, time.Duration(0), svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.WriteTimeout)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_start(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
|
||||
@@ -106,8 +106,8 @@ func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
case <-ctx.Done():
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
// there isn't any user-defined middleware before TimoutHandler,
|
||||
// so we can guarantee that cancelation in biz related code won't come here.
|
||||
// there isn't any user-defined middleware before TimeoutHandler,
|
||||
// so we can guarantee that cancellation in biz related code won't come here.
|
||||
httpx.ErrorCtx(r.Context(), w, ctx.Err(), func(w http.ResponseWriter, err error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
w.WriteHeader(statusClientClosedRequest)
|
||||
@@ -151,7 +151,7 @@ func (tw *timeoutWriter) Flush() {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// Header returns the underline temporary http.Header.
|
||||
// Header returns the underlying temporary http.Header.
|
||||
func (tw *timeoutWriter) Header() http.Header {
|
||||
return tw.h
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ
|
||||
req.URL.RawQuery = buildFormQuery(u, val[formKey])
|
||||
fillHeader(req, val[headerKey])
|
||||
if hasJsonBody {
|
||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
||||
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) {
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
||||
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||
resp, err := DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestParse(t *testing.T) {
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "bar")
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
w.Write([]byte(`{"name":"kevin","value":100}`))
|
||||
}))
|
||||
defer svr.Close()
|
||||
@@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) {
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "bar")
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
@@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) {
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "bar")
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
@@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) {
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("foo", "0")
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
w.Write([]byte(`{"bar":0}`))
|
||||
}))
|
||||
defer svr.Close()
|
||||
@@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
|
||||
Bar int `json:"bar"`
|
||||
}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
w.Write([]byte(`{"bar":0}`))
|
||||
}))
|
||||
defer svr.Close()
|
||||
@@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
|
||||
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
||||
var val struct{}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
@@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
|
||||
func TestParseJsonBody_BodyError(t *testing.T) {
|
||||
var val struct{}
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set(header.ContentType, header.JsonContentType)
|
||||
w.Header().Set(header.ContentType, header.ContentTypeJson)
|
||||
}))
|
||||
defer svr.Close()
|
||||
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) {
|
||||
service := NewService("foo")
|
||||
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set(header.ContentType, header.JsonContentType)
|
||||
req.Header.Set(header.ContentType, header.ContentTypeJson)
|
||||
resp, err := service.DoRequest(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
|
||||
@@ -160,7 +160,7 @@ func TestParseFormArray(t *testing.T) {
|
||||
http.NoBody)
|
||||
assert.NoError(t, err)
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names)
|
||||
assert.ElementsMatch(t, []string{"1,2,3"}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -189,9 +189,7 @@ func TestParseFormArray(t *testing.T) {
|
||||
"/a?numbers=1,2,3",
|
||||
http.NoBody)
|
||||
assert.NoError(t, err)
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, v.Numbers)
|
||||
}
|
||||
assert.Error(t, Parse(r, &v))
|
||||
})
|
||||
|
||||
t.Run("slice with one value on array format brackets", func(t *testing.T) {
|
||||
@@ -268,6 +266,36 @@ func TestParseFormArray(t *testing.T) {
|
||||
assert.ElementsMatch(t, []float64{2}, v.Numbers)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice with one value", func(t *testing.T) {
|
||||
var v struct {
|
||||
Codes []string `form:"codes"`
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(
|
||||
http.MethodGet,
|
||||
"/a?codes=aaa,bbb,ccc",
|
||||
http.NoBody)
|
||||
assert.NoError(t, err)
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.ElementsMatch(t, []string{"aaa,bbb,ccc"}, v.Codes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice with multiple values", func(t *testing.T) {
|
||||
var v struct {
|
||||
Codes []string `form:"codes,arrayComma=false"`
|
||||
}
|
||||
|
||||
r, err := http.NewRequest(
|
||||
http.MethodGet,
|
||||
"/a?codes=aaa,bbb,ccc&codes=ccc,ddd,eee",
|
||||
http.NoBody)
|
||||
assert.NoError(t, err)
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.ElementsMatch(t, []string{"aaa,bbb,ccc", "ccc,ddd,eee"}, v.Codes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseForm_Error(t *testing.T) {
|
||||
@@ -448,7 +476,7 @@ func TestParseJsonBody(t *testing.T) {
|
||||
|
||||
body := `{"name":"kevin", "age": 18}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.Equal(t, "kevin", v.Name)
|
||||
@@ -464,7 +492,7 @@ func TestParseJsonBody(t *testing.T) {
|
||||
|
||||
body := `{"name":"kevin", "ag": 18}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||
|
||||
assert.Error(t, Parse(r, &v))
|
||||
})
|
||||
@@ -489,7 +517,7 @@ func TestParseJsonBody(t *testing.T) {
|
||||
|
||||
body := `[{"name":"kevin", "age": 18}]`
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||
|
||||
assert.NoError(t, Parse(r, &v))
|
||||
assert.Equal(t, 1, len(v))
|
||||
@@ -509,7 +537,7 @@ func TestParseJsonBody(t *testing.T) {
|
||||
|
||||
body := `[{"name":"apple", "age": 18}]`
|
||||
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||
|
||||
assert.NoError(t, Parse(r, &v))
|
||||
assert.Equal(t, 1, len(v))
|
||||
@@ -527,7 +555,7 @@ func TestParseJsonBody(t *testing.T) {
|
||||
body, _ := json.Marshal(v1)
|
||||
t.Logf("body:%s", string(body))
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||
var v2 v
|
||||
err := ParseJsonBody(r, &v2)
|
||||
if assert.NoError(t, err) {
|
||||
@@ -581,7 +609,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
request.Header.Add("addrs", "addr2")
|
||||
request.Header.Add("X-Forwarded-For", "10.0.10.11")
|
||||
request.Header.Add("x-real-ip", "10.0.11.10")
|
||||
request.Header.Add("Accept", header.JsonContentType)
|
||||
request.Header.Add("Accept", header.ContentTypeJson)
|
||||
err = ParseHeaders(request, &v)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -591,7 +619,7 @@ func TestParseHeaders(t *testing.T) {
|
||||
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
|
||||
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
|
||||
assert.Equal(t, "10.0.11.10", v.XRealIP)
|
||||
assert.Equal(t, header.JsonContentType, v.Accept)
|
||||
assert.Equal(t, header.ContentTypeJson, v.Accept)
|
||||
}
|
||||
|
||||
func TestParseHeaders_Error(t *testing.T) {
|
||||
@@ -683,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) {
|
||||
}
|
||||
body := `{"weightFloat32": 3.2}`
|
||||
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
r.Header.Set(ContentType, header.JsonContentType)
|
||||
r.Header.Set(ContentType, header.ContentTypeJson)
|
||||
|
||||
if assert.NoError(t, Parse(r, &v)) {
|
||||
assert.Equal(t, float32(3.2), *v.WeightFloat32)
|
||||
|
||||
@@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
|
||||
return fmt.Errorf("marshal json failed, error: %w", err)
|
||||
}
|
||||
|
||||
w.Header().Set(ContentType, header.JsonContentType)
|
||||
w.Header().Set(ContentType, header.ContentTypeJson)
|
||||
w.WriteHeader(code)
|
||||
|
||||
if n, err := w.Write(bs); err != nil {
|
||||
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
// ContentType means Content-Type.
|
||||
ContentType = header.ContentType
|
||||
// JsonContentType means application/json.
|
||||
JsonContentType = header.JsonContentType
|
||||
JsonContentType = header.ContentTypeJson
|
||||
// KeyField means key.
|
||||
KeyField = "key"
|
||||
// SecretField means secret.
|
||||
|
||||
@@ -2,15 +2,16 @@ package fileserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Middleware returns a middleware that serves files from the given file system.
|
||||
func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||
func Middleware(upath string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
|
||||
fileServer := http.FileServer(fs)
|
||||
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
|
||||
canServe := createServeChecker(path, fs)
|
||||
pathWithoutTrailSlash := ensureNoTrailingSlash(upath)
|
||||
canServe := createServeChecker(upath, fs)
|
||||
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -28,9 +29,22 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
var lock sync.RWMutex
|
||||
fileChecker := make(map[string]bool)
|
||||
|
||||
return func(path string) bool {
|
||||
return func(upath string) bool {
|
||||
// Emulate http.Dir.Open’s path normalization for embed.FS.Open.
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path without the final "index.html".
|
||||
// So the path here may be empty or end with a "/".
|
||||
// http.Dir.Open uses this logic to clean the path,
|
||||
// correctly handling those two cases.
|
||||
// embed.FS doesn’t perform this normalization, so we apply the same logic here.
|
||||
upath = path.Clean("/" + upath)[1:]
|
||||
if len(upath) == 0 {
|
||||
// if the path is empty, we use "." to open the current directory
|
||||
upath = "."
|
||||
}
|
||||
|
||||
lock.RLock()
|
||||
exist, ok := fileChecker[path]
|
||||
exist, ok := fileChecker[upath]
|
||||
lock.RUnlock()
|
||||
if ok {
|
||||
return exist
|
||||
@@ -39,9 +53,9 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
file, err := fs.Open(path)
|
||||
file, err := fs.Open(upath)
|
||||
exist = err == nil
|
||||
fileChecker[path] = exist
|
||||
fileChecker[upath] = exist
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -51,8 +65,8 @@ func createFileChecker(fs http.FileSystem) func(string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) bool {
|
||||
pathWithTrailSlash := ensureTrailingSlash(path)
|
||||
func createServeChecker(upath string, fs http.FileSystem) func(r *http.Request) bool {
|
||||
pathWithTrailSlash := ensureTrailingSlash(upath)
|
||||
fileChecker := createFileChecker(fs)
|
||||
|
||||
return func(r *http.Request) bool {
|
||||
@@ -62,18 +76,18 @@ func createServeChecker(path string, fs http.FileSystem) func(r *http.Request) b
|
||||
}
|
||||
}
|
||||
|
||||
func ensureTrailingSlash(path string) string {
|
||||
if strings.HasSuffix(path, "/") {
|
||||
return path
|
||||
func ensureTrailingSlash(upath string) string {
|
||||
if strings.HasSuffix(upath, "/") {
|
||||
return upath
|
||||
}
|
||||
|
||||
return path + "/"
|
||||
return upath + "/"
|
||||
}
|
||||
|
||||
func ensureNoTrailingSlash(path string) string {
|
||||
if strings.HasSuffix(path, "/") {
|
||||
return path[:len(path)-1]
|
||||
func ensureNoTrailingSlash(upath string) string {
|
||||
if strings.HasSuffix(upath, "/") {
|
||||
return upath[:len(upath)-1]
|
||||
}
|
||||
|
||||
return path
|
||||
return upath
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package fileserver
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -61,6 +63,46 @@ func TestMiddleware(t *testing.T) {
|
||||
requestPath: "/ws",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path, without the final "index.html".
|
||||
{
|
||||
name: "Serve index.html",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html with path with trailing slash",
|
||||
path: "/static/",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html in a nested directory",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/nested/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Request index.html indirectly",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
{
|
||||
name: "Request index.html in a nested directory indirectly",
|
||||
path: "/static",
|
||||
dir: "testdata",
|
||||
requestPath: "/static/nested/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -87,6 +129,128 @@ func TestMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed testdata
|
||||
testdataFS embed.FS
|
||||
)
|
||||
|
||||
func TestMiddleware_embedFS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
requestPath string
|
||||
expectedStatus int
|
||||
expectedContent string
|
||||
}{
|
||||
{
|
||||
name: "Serve static file",
|
||||
path: "/static",
|
||||
requestPath: "/static/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Path with trailing slash",
|
||||
path: "/static/",
|
||||
requestPath: "/static/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Root path",
|
||||
path: "/",
|
||||
requestPath: "/example.txt",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "1",
|
||||
},
|
||||
{
|
||||
name: "Pass through non-matching path",
|
||||
path: "/static/",
|
||||
requestPath: "/other/path",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "Not exist file",
|
||||
path: "/assets",
|
||||
requestPath: "/assets/not-exist.txt",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "Not exist file in root",
|
||||
path: "/",
|
||||
requestPath: "/not-exist.txt",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
{
|
||||
name: "websocket request",
|
||||
path: "/",
|
||||
requestPath: "/ws",
|
||||
expectedStatus: http.StatusAlreadyReported,
|
||||
},
|
||||
|
||||
// http.FileServer redirects any request ending in "/index.html"
|
||||
// to the same path, without the final "index.html".
|
||||
{
|
||||
name: "Serve index.html",
|
||||
path: "/static",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html with path with trailing slash",
|
||||
path: "/static/",
|
||||
requestPath: "/static/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Serve index.html in a nested directory",
|
||||
path: "/static",
|
||||
requestPath: "/static/nested/index.html",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
name: "Request index.html indirectly",
|
||||
path: "/static",
|
||||
requestPath: "/static/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
{
|
||||
name: "Request index.html in a nested directory indirectly",
|
||||
path: "/static",
|
||||
requestPath: "/static/nested/",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedContent: "hello",
|
||||
},
|
||||
}
|
||||
|
||||
subFS, err := fs.Sub(testdataFS, "testdata")
|
||||
assert.Nil(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
middleware := Middleware(tt.path, http.FS(subFS))
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAlreadyReported)
|
||||
})
|
||||
|
||||
handlerToTest := middleware(nextHandler)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||
if len(tt.expectedContent) > 0 {
|
||||
assert.Equal(t, tt.expectedContent, rr.Body.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureTrailingSlash(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
|
||||
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
1
rest/internal/fileserver/testdata/nested/index.html
vendored
Normal file
@@ -0,0 +1 @@
|
||||
hello
|
||||
@@ -3,8 +3,18 @@ package header
|
||||
const (
|
||||
// ApplicationJson stands for application/json.
|
||||
ApplicationJson = "application/json"
|
||||
// CacheControl is the header key for Cache-Control.
|
||||
CacheControl = "Cache-Control"
|
||||
// CacheControlNoCache is the value for Cache-Control: no-cache.
|
||||
CacheControlNoCache = "no-cache"
|
||||
// Connection is the header key for Connection.
|
||||
Connection = "Connection"
|
||||
// ConnectionKeepAlive is the value for Connection: keep-alive.
|
||||
ConnectionKeepAlive = "keep-alive"
|
||||
// ContentType is the header key for Content-Type.
|
||||
ContentType = "Content-Type"
|
||||
// JsonContentType is the content type for JSON.
|
||||
JsonContentType = "application/json; charset=utf-8"
|
||||
// ContentTypeJson is the content type for JSON.
|
||||
ContentTypeJson = "application/json; charset=utf-8"
|
||||
// ContentTypeEventStream is the content type for event stream.
|
||||
ContentTypeEventStream = "text/event-stream"
|
||||
)
|
||||
|
||||
@@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) {
|
||||
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
||||
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||
|
||||
type (
|
||||
Request struct {
|
||||
@@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
|
||||
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
||||
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||
|
||||
type (
|
||||
Request struct {
|
||||
@@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
||||
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
||||
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||
|
||||
router := NewRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
@@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
|
||||
func TestParseGetWithContentLengthHeader(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
||||
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||
r.Header.Set(contentLength, "1024")
|
||||
|
||||
router := NewRouter()
|
||||
@@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
|
||||
bytes.NewBufferString(`{"time": "20170912"}`))
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
||||
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||
|
||||
router := NewRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
@@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
|
||||
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
|
||||
bytes.NewBufferString(`{"time": 20170912}`))
|
||||
assert.Nil(t, err)
|
||||
r.Header.Set(httpx.ContentType, header.JsonContentType)
|
||||
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
|
||||
|
||||
router := NewRouter()
|
||||
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
|
||||
|
||||
@@ -63,6 +63,11 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// AddRoute adds given route into the Server.
|
||||
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||
s.AddRoutes([]Route{r}, opts...)
|
||||
}
|
||||
|
||||
// AddRoutes add given routes into the Server.
|
||||
func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
||||
r := featuredRoutes{
|
||||
@@ -74,11 +79,6 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
||||
s.ngin.addRoutes(r)
|
||||
}
|
||||
|
||||
// AddRoute adds given route into the Server.
|
||||
func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||
s.AddRoutes([]Route{r}, opts...)
|
||||
}
|
||||
|
||||
// PrintRoutes prints the added routes to stdout.
|
||||
func (s *Server) PrintRoutes() {
|
||||
s.ngin.print()
|
||||
@@ -95,25 +95,6 @@ func (s *Server) Routes() []Route {
|
||||
return routes
|
||||
}
|
||||
|
||||
// ServeHTTP is for test purpose, allow developer to do a unit test with
|
||||
// all defined router without starting an HTTP Server.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// server := MustNewServer(...)
|
||||
// server.addRoute(...) // router a
|
||||
// server.addRoute(...) // router b
|
||||
// server.addRoute(...) // router c
|
||||
//
|
||||
// r, _ := http.NewRequest(...)
|
||||
// w := httptest.NewRecorder(...)
|
||||
// server.ServeHTTP(w, r)
|
||||
// // verify the response
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.ngin.bindRoutes(s.router)
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Start starts the Server.
|
||||
// Graceful shutdown is enabled by default.
|
||||
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||
@@ -298,6 +279,14 @@ func WithSignature(signature SignatureConf) RouteOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSE returns a RouteOption to enable server-sent events.
|
||||
func WithSSE() RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.sse = true
|
||||
r.timeout = 0
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout returns a RouteOption to set timeout with given value.
|
||||
func WithTimeout(timeout time.Duration) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/rest/chain"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"github.com/zeromicro/go-zero/rest/internal/cors"
|
||||
"github.com/zeromicro/go-zero/rest/internal/header"
|
||||
"github.com/zeromicro/go-zero/rest/router"
|
||||
)
|
||||
|
||||
@@ -231,7 +232,7 @@ func TestWithFileServerMiddleware(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
server.ServeHTTP(rr, req)
|
||||
serve(server, rr, req)
|
||||
|
||||
assert.Equal(t, tt.expectedStatus, rr.Code)
|
||||
if len(tt.expectedContent) > 0 {
|
||||
@@ -458,7 +459,7 @@ Port: 54321
|
||||
// we would need to verify the behavior here. Since we don't have
|
||||
// direct access to headers, we'll mock newCorsRouter to capture it.
|
||||
w := httptest.NewRecorder()
|
||||
svr.ServeHTTP(w, httptest.NewRequest(http.MethodOptions, "/", nil))
|
||||
serve(svr, w, httptest.NewRequest(http.MethodOptions, "/", nil))
|
||||
|
||||
vals := w.Header().Values("Access-Control-Allow-Headers")
|
||||
respHeaders := make(map[string]struct{})
|
||||
@@ -748,12 +749,46 @@ Port: 54321
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", test.path, nil)
|
||||
svr.ServeHTTP(w, req)
|
||||
serve(svr, w, req)
|
||||
assert.Equal(t, test.code, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerEventStream(t *testing.T) {
|
||||
server := MustNewServer(RestConf{})
|
||||
server.AddRoutes([]Route{
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Path: "/foo",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("foo"))
|
||||
},
|
||||
},
|
||||
{
|
||||
Method: http.MethodGet,
|
||||
Path: "/bar",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("bar"))
|
||||
},
|
||||
},
|
||||
}, WithSSE())
|
||||
|
||||
check := func(val string) {
|
||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%s", val), http.NoBody)
|
||||
assert.Nil(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
serve(server, rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, header.ContentTypeEventStream, rr.Header().Get(header.ContentType))
|
||||
assert.Equal(t, header.CacheControlNoCache, rr.Header().Get(header.CacheControl))
|
||||
assert.Equal(t, header.ConnectionKeepAlive, rr.Header().Get(header.Connection))
|
||||
assert.Equal(t, val, rr.Body.String())
|
||||
}
|
||||
check("foo")
|
||||
check("bar")
|
||||
}
|
||||
|
||||
//go:embed testdata
|
||||
var content embed.FS
|
||||
|
||||
@@ -765,6 +800,25 @@ func TestServerEmbedFileSystem(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/assets/sample.txt", http.NoBody)
|
||||
assert.Nil(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
server.ServeHTTP(rr, req)
|
||||
serve(server, rr, req)
|
||||
assert.Equal(t, sampleContent, rr.Body.String())
|
||||
}
|
||||
|
||||
// serve is for test purpose, allow developer to do a unit test with
|
||||
// all defined routes without starting an HTTP Server.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// server := MustNewServer(...)
|
||||
// server.addRoute(...) // router a
|
||||
// server.addRoute(...) // router b
|
||||
// server.addRoute(...) // router c
|
||||
//
|
||||
// r, _ := http.NewRequest(...)
|
||||
// w := httptest.NewRecorder(...)
|
||||
// serve(server, w, r)
|
||||
// // verify the response
|
||||
func serve(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.ngin.bindRoutes(s.router)
|
||||
s.router.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ type (
|
||||
priority bool
|
||||
jwt jwtSetting
|
||||
signature signatureSetting
|
||||
sse bool
|
||||
routes []Route
|
||||
maxBytes int64
|
||||
}
|
||||
|
||||
1
tools/goctl/.gitignore
vendored
Normal file
1
tools/goctl/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
dist
|
||||
@@ -3,7 +3,8 @@ FROM golang:alpine AS builder
|
||||
LABEL stage=gobuilder
|
||||
|
||||
ENV CGO_ENABLED=0
|
||||
ENV GOPROXY=https://goproxy.cn,direct
|
||||
# if you are in China, you can use the following command to speed up the download
|
||||
# ENV GOPROXY=https://goproxy.cn,direct
|
||||
|
||||
RUN apk update --no-cache && apk add --no-cache tzdata
|
||||
RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@latest
|
||||
|
||||
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/new"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/swagger"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
@@ -31,6 +33,7 @@ var (
|
||||
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
|
||||
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
|
||||
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
|
||||
swaggerCmd = cobrax.NewCommand("swagger", cobrax.WithRunE(swagger.Command))
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -46,6 +49,7 @@ func init() {
|
||||
pluginCmdFlags = pluginCmd.Flags()
|
||||
tsCmdFlags = tsCmd.Flags()
|
||||
validateCmdFlags = validateCmd.Flags()
|
||||
swaggerCmdFlags = swaggerCmd.Flags()
|
||||
)
|
||||
|
||||
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
|
||||
@@ -73,6 +77,7 @@ func init() {
|
||||
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
|
||||
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
|
||||
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
|
||||
goCmdFlags.BoolVar(&gogen.VarBoolTypeGroup, "type-group")
|
||||
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)
|
||||
|
||||
javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
|
||||
@@ -97,8 +102,13 @@ func init() {
|
||||
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
|
||||
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
|
||||
|
||||
swaggerCmdFlags.StringVar(&swagger.VarStringAPI, "api")
|
||||
swaggerCmdFlags.StringVar(&swagger.VarStringDir, "dir")
|
||||
swaggerCmdFlags.StringVar(&swagger.VarStringFilename, "filename")
|
||||
swaggerCmdFlags.BoolVar(&swagger.VarBoolYaml, "yaml")
|
||||
|
||||
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
|
||||
|
||||
// Add sub-commands
|
||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
|
||||
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd, swaggerCmd)
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ const (
|
||||
import 'package:shared_preferences/shared_preferences.dart';
|
||||
import '../data/tokens.dart';
|
||||
|
||||
/// 保存tokens到本地
|
||||
/// store tokens to local
|
||||
///
|
||||
/// 传入null则删除本地tokens
|
||||
/// 返回:true:设置成功 false:设置失败
|
||||
/// pass null will clean local stored tokens
|
||||
/// returns true if success, otherwise false
|
||||
Future<bool> setTokens(Tokens tokens) async {
|
||||
var sp = await SharedPreferences.getInstance();
|
||||
if (tokens == null) {
|
||||
@@ -23,9 +23,9 @@ Future<bool> setTokens(Tokens tokens) async {
|
||||
return await sp.setString('tokens', jsonEncode(tokens.toJson()));
|
||||
}
|
||||
|
||||
/// 获取本地存储的tokens
|
||||
/// get local stored tokens
|
||||
///
|
||||
/// 如果没有,则返回null
|
||||
/// if no, returns null
|
||||
Future<Tokens> getTokens() async {
|
||||
try {
|
||||
var sp = await SharedPreferences.getInstance();
|
||||
@@ -82,7 +82,8 @@ func genVars(dir string, isLegacy bool, scheme string, hostname string) error {
|
||||
}
|
||||
|
||||
if !fileExists(dir + "vars.dart") {
|
||||
err = os.WriteFile(dir+"vars.dart", []byte(fmt.Sprintf(`const serverHost='%s://%s';`, scheme, hostname)), 0o644)
|
||||
err = os.WriteFile(dir+"vars.dart", []byte(fmt.Sprintf(`const serverHost='%s://%s';`,
|
||||
scheme, hostname)), 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -42,8 +42,19 @@ var (
|
||||
func GoFormatApi(_ *cobra.Command, _ []string) error {
|
||||
var be errorx.BatchError
|
||||
if VarBoolUseStdin {
|
||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||
be.Add(err)
|
||||
if env.UseExperimental() {
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
be.Add(err)
|
||||
} else {
|
||||
if err := apiF.Source(data, os.Stdout); err != nil {
|
||||
be.Add(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := apiFormatReader(os.Stdin, VarStringDir, VarBoolSkipCheckDeclare); err != nil {
|
||||
be.Add(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(VarStringDir) == 0 {
|
||||
|
||||
@@ -40,6 +40,8 @@ var (
|
||||
// VarStringStyle describes the style of output files.
|
||||
VarStringStyle string
|
||||
VarBoolWithTest bool
|
||||
// VarBoolTypeGroup describes whether to group types.
|
||||
VarBoolTypeGroup bool
|
||||
)
|
||||
|
||||
// GoCommand gen go project files from command line
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/config"
|
||||
@@ -39,20 +41,152 @@ func BuildTypes(types []spec.Type) (string, error) {
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
val, err := BuildTypes(api.Types)
|
||||
func getTypeName(tp spec.Type) string {
|
||||
if tp == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := tp.(type) {
|
||||
case spec.DefineStruct:
|
||||
typeName := util.Title(tp.Name())
|
||||
return typeName
|
||||
case spec.PointerType:
|
||||
return getTypeName(val.Type)
|
||||
case spec.ArrayType:
|
||||
return getTypeName(val.Value)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
groupTypes := make(map[string]map[string]spec.Type)
|
||||
typesBelongToFiles := make(map[string]*collection.Set)
|
||||
|
||||
for _, v := range api.Service.Groups {
|
||||
group := v.GetAnnotation(groupProperty)
|
||||
if len(group) == 0 {
|
||||
group = groupTypeDefault
|
||||
}
|
||||
// convert filepath to Identifier name spec.
|
||||
group = strings.TrimPrefix(group, "/")
|
||||
group = strings.TrimSuffix(group, "/")
|
||||
group = util.SafeString(group)
|
||||
for _, v := range v.Routes {
|
||||
requestTypeName := getTypeName(v.RequestType)
|
||||
responseTypeName := getTypeName(v.ResponseType)
|
||||
requestTypeFileSet, ok := typesBelongToFiles[requestTypeName]
|
||||
if !ok {
|
||||
requestTypeFileSet = collection.NewSet()
|
||||
}
|
||||
if len(requestTypeName) > 0 {
|
||||
requestTypeFileSet.AddStr(group)
|
||||
typesBelongToFiles[requestTypeName] = requestTypeFileSet
|
||||
}
|
||||
|
||||
responseTypeFileSet, ok := typesBelongToFiles[responseTypeName]
|
||||
if !ok {
|
||||
responseTypeFileSet = collection.NewSet()
|
||||
}
|
||||
if len(responseTypeName) > 0 {
|
||||
responseTypeFileSet.AddStr(group)
|
||||
typesBelongToFiles[responseTypeName] = responseTypeFileSet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typesInOneFile := make(map[string]*collection.Set)
|
||||
for typeName, fileSet := range typesBelongToFiles {
|
||||
count := fileSet.Count()
|
||||
switch {
|
||||
case count == 0: // it means there has no structure type or no request/response body
|
||||
continue
|
||||
case count == 1: // it means a structure type used in only one group.
|
||||
groupName := fileSet.KeysStr()[0]
|
||||
typeSet, ok := typesInOneFile[groupName]
|
||||
if !ok {
|
||||
typeSet = collection.NewSet()
|
||||
}
|
||||
typeSet.AddStr(typeName)
|
||||
typesInOneFile[groupName] = typeSet
|
||||
default: // it means this type is used in multiple groups.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range api.Types {
|
||||
typeName := util.Title(v.Name())
|
||||
groupSet, ok := typesBelongToFiles[typeName]
|
||||
var typeCount int
|
||||
if !ok {
|
||||
typeCount = 0
|
||||
} else {
|
||||
typeCount = groupSet.Count()
|
||||
}
|
||||
|
||||
if typeCount == 0 { // not belong to any group
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
continue
|
||||
}
|
||||
|
||||
if typeCount == 1 { // belong to one group
|
||||
groupName := groupSet.KeysStr()[0]
|
||||
types, ok := groupTypes[groupName]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupName] = types
|
||||
continue
|
||||
}
|
||||
|
||||
// belong to multiple groups
|
||||
types, ok := groupTypes[groupTypeDefault]
|
||||
if !ok {
|
||||
types = make(map[string]spec.Type)
|
||||
}
|
||||
types[typeName] = v
|
||||
groupTypes[groupTypeDefault] = types
|
||||
|
||||
}
|
||||
|
||||
for group, typeGroup := range groupTypes {
|
||||
var types []spec.Type
|
||||
for _, v := range typeGroup {
|
||||
types = append(types, v)
|
||||
}
|
||||
sort.Slice(types, func(i, j int) bool {
|
||||
return types[i].Name() < types[j].Name()
|
||||
})
|
||||
|
||||
if err := writeTypes(dir, group, cfg, types); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type) error {
|
||||
if len(types) == 0 {
|
||||
return nil
|
||||
}
|
||||
val, err := BuildTypes(types)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
|
||||
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, baseFilename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
typeFilename = typeFilename + ".go"
|
||||
filename := path.Join(dir, typesDir, typeFilename)
|
||||
os.Remove(filename)
|
||||
_ = os.Remove(filename)
|
||||
|
||||
return genFile(fileGenConfig{
|
||||
dir: dir,
|
||||
@@ -70,6 +204,13 @@ func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
})
|
||||
}
|
||||
|
||||
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
|
||||
if VarBoolTypeGroup {
|
||||
return genTypesWithGroup(dir, cfg, api)
|
||||
}
|
||||
return writeTypes(dir, typesFile, cfg, api.Types)
|
||||
}
|
||||
|
||||
func writeType(writer io.Writer, tp spec.Type) error {
|
||||
structType, ok := tp.(spec.DefineStruct)
|
||||
if !ok {
|
||||
|
||||
@@ -10,4 +10,6 @@ const (
|
||||
middlewareDir = internal + "middleware"
|
||||
typesDir = internal + typesPacket
|
||||
groupProperty = "group"
|
||||
|
||||
groupTypeDefault="types"
|
||||
)
|
||||
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
|
||||
@@ -96,13 +96,13 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
||||
for _, item := range c.responseTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
superClassName = "HttpResponseData"
|
||||
if !stringx.Contains(c.imports, httpResponseData) {
|
||||
if !slices.Contains(c.imports, httpResponseData) {
|
||||
c.imports = append(c.imports, httpResponseData)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if superClassName == "HttpData" && !stringx.Contains(c.imports, httpData) {
|
||||
if superClassName == "HttpData" && !slices.Contains(c.imports, httpData) {
|
||||
c.imports = append(c.imports, httpData)
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func (c *componentsContext) genGetSet(writer io.Writer, indent int) error {
|
||||
tyString := javaType
|
||||
decorator := ""
|
||||
javaPrimitiveType := []string{"int", "long", "boolean", "float", "double", "short"}
|
||||
if !stringx.Contains(javaPrimitiveType, javaType) {
|
||||
if !slices.Contains(javaPrimitiveType, javaType) {
|
||||
if member.IsOptional() || member.IsOmitEmpty() {
|
||||
decorator = "@Nullable "
|
||||
} else {
|
||||
|
||||
@@ -3,9 +3,9 @@ package spec
|
||||
import (
|
||||
"errors"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
@@ -57,14 +57,14 @@ func (m Member) Tags() []*Tag {
|
||||
|
||||
// IsOptional returns true if tag is optional
|
||||
func (m Member) IsOptional() bool {
|
||||
if !m.IsBodyMember() {
|
||||
if !m.IsBodyMember() && !m.IsFormMember() {
|
||||
return false
|
||||
}
|
||||
|
||||
tag := m.Tags()
|
||||
for _, item := range tag {
|
||||
if item.Key == bodyTagKey {
|
||||
if stringx.Contains(item.Options, "optional") {
|
||||
if item.Key == bodyTagKey || item.Key == formTagKey {
|
||||
if slices.Contains(item.Options, "optional") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (m Member) IsOmitEmpty() bool {
|
||||
tag := m.Tags()
|
||||
for _, item := range tag {
|
||||
if item.Key == bodyTagKey {
|
||||
if stringx.Contains(item.Options, "omitempty") {
|
||||
if slices.Contains(item.Options, "omitempty") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func (m Member) IsOmitEmpty() bool {
|
||||
func (m Member) GetPropertyName() (string, error) {
|
||||
tags := m.Tags()
|
||||
for _, tag := range tags {
|
||||
if stringx.Contains(definedKeys, tag.Key) {
|
||||
if slices.Contains(definedKeys, tag.Key) {
|
||||
if tag.Name == "-" {
|
||||
return util.Untitle(m.Name), nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type (
|
||||
|
||||
// ApiSpec describes an api file
|
||||
ApiSpec struct {
|
||||
Info Info // Deprecated: useless expression
|
||||
Info Info
|
||||
Syntax ApiSyntax // Deprecated: useless expression
|
||||
Imports []Import // Deprecated: useless expression
|
||||
Types []Type
|
||||
@@ -59,11 +59,11 @@ type (
|
||||
// Member describes the field of a structure
|
||||
Member struct {
|
||||
Name string
|
||||
// 数据类型字面值,如:string、map[int]string、[]int64、[]*User
|
||||
// data type, for example, string、map[int]string、[]int64、[]*User
|
||||
Type Type
|
||||
Tag string
|
||||
Comment string
|
||||
// 成员头顶注释说明
|
||||
// document for the field
|
||||
Docs Doc
|
||||
IsInline bool
|
||||
}
|
||||
|
||||
75
tools/goctl/api/swagger/annotation.go
Normal file
75
tools/goctl/api/swagger/annotation.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
|
||||
if len(properties) == 0 {
|
||||
return def
|
||||
}
|
||||
md := metadata.New(properties)
|
||||
val := md.Get(key)
|
||||
if len(val) == 0 {
|
||||
return def
|
||||
}
|
||||
str := util.Unquote(val[0])
|
||||
if len(str) == 0 {
|
||||
return def
|
||||
}
|
||||
res, _ := strconv.ParseBool(str)
|
||||
return res
|
||||
}
|
||||
|
||||
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
|
||||
if len(properties) == 0 {
|
||||
return def
|
||||
}
|
||||
md := metadata.New(properties)
|
||||
val := md.Get(key)
|
||||
if len(val) == 0 {
|
||||
return def
|
||||
}
|
||||
str := util.Unquote(val[0])
|
||||
if len(str) == 0 {
|
||||
return def
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
|
||||
if len(properties) == 0 {
|
||||
return def
|
||||
}
|
||||
md := metadata.New(properties)
|
||||
val := md.Get(key)
|
||||
if len(val) == 0 {
|
||||
return def
|
||||
}
|
||||
|
||||
str := util.Unquote(val[0])
|
||||
if len(str) == 0 {
|
||||
return def
|
||||
}
|
||||
resp := util.FieldsAndTrimSpace(str, commaRune)
|
||||
if len(resp) == 0 {
|
||||
return def
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func getFirstUsableString(def ...string) string {
|
||||
if len(def) == 0 {
|
||||
return ""
|
||||
}
|
||||
for _, val := range def {
|
||||
str := util.Unquote(val)
|
||||
if len(str) != 0 {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
53
tools/goctl/api/swagger/annotation_test.go
Normal file
53
tools/goctl/api/swagger/annotation_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getBoolFromKVOrDefault(t *testing.T) {
|
||||
properties := map[string]string{
|
||||
"enabled": `"true"`,
|
||||
"disabled": `"false"`,
|
||||
"invalid": `"notabool"`,
|
||||
"empty_value": `""`,
|
||||
}
|
||||
|
||||
assert.True(t, getBoolFromKVOrDefault(properties, "enabled", false))
|
||||
assert.False(t, getBoolFromKVOrDefault(properties, "disabled", true))
|
||||
assert.False(t, getBoolFromKVOrDefault(properties, "invalid", false))
|
||||
assert.True(t, getBoolFromKVOrDefault(properties, "missing", true))
|
||||
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
|
||||
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
|
||||
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
|
||||
}
|
||||
|
||||
func Test_getStringFromKVOrDefault(t *testing.T) {
|
||||
properties := map[string]string{
|
||||
"name": `"example"`,
|
||||
"empty": `""`,
|
||||
}
|
||||
|
||||
assert.Equal(t, "example", getStringFromKVOrDefault(properties, "name", "default"))
|
||||
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "empty", "default"))
|
||||
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
|
||||
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
|
||||
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
|
||||
}
|
||||
|
||||
func Test_getListFromInfoOrDefault(t *testing.T) {
|
||||
properties := map[string]string{
|
||||
"list": `"a, b, c"`,
|
||||
"empty": `""`,
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{}, "empty", []string{"default"}))
|
||||
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
|
||||
"foo": ",,",
|
||||
}, "foo", []string{"default"}))
|
||||
}
|
||||
138
tools/goctl/api/swagger/api.go
Normal file
138
tools/goctl/api/swagger/api.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package swagger
|
||||
|
||||
import "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
|
||||
func fillAllStructs(api *spec.ApiSpec) {
|
||||
var (
|
||||
tps []spec.Type
|
||||
structTypes = make(map[string]spec.DefineStruct)
|
||||
groups []spec.Group
|
||||
)
|
||||
for _, tp := range api.Types {
|
||||
structTypes[tp.Name()] = tp.(spec.DefineStruct)
|
||||
}
|
||||
|
||||
for _, tp := range api.Types {
|
||||
filledTP := fillStruct("", tp, structTypes)
|
||||
tps = append(tps, filledTP)
|
||||
structTypes[filledTP.Name()] = filledTP.(spec.DefineStruct)
|
||||
}
|
||||
|
||||
for _, group := range api.Service.Groups {
|
||||
var routes []spec.Route
|
||||
for _, route := range group.Routes {
|
||||
route.RequestType = fillStruct("", route.RequestType, structTypes)
|
||||
route.ResponseType = fillStruct("", route.ResponseType, structTypes)
|
||||
routes = append(routes, route)
|
||||
}
|
||||
group.Routes = routes
|
||||
groups = append(groups, group)
|
||||
}
|
||||
api.Service.Groups = groups
|
||||
api.Types = tps
|
||||
}
|
||||
|
||||
func fillStruct(parent string, tp spec.Type, allTypes map[string]spec.DefineStruct) spec.Type {
|
||||
switch val := tp.(type) {
|
||||
case spec.DefineStruct:
|
||||
var members []spec.Member
|
||||
for _, member := range val.Members {
|
||||
switch memberType := member.Type.(type) {
|
||||
case spec.PointerType:
|
||||
member.Type = spec.PointerType{
|
||||
RawName: memberType.RawName,
|
||||
Type: fillStruct(val.Name(), memberType.Type, allTypes),
|
||||
}
|
||||
case spec.ArrayType:
|
||||
member.Type = spec.ArrayType{
|
||||
RawName: memberType.RawName,
|
||||
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||
}
|
||||
case spec.MapType:
|
||||
member.Type = spec.MapType{
|
||||
RawName: memberType.RawName,
|
||||
Key: memberType.Key,
|
||||
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||
}
|
||||
case spec.DefineStruct:
|
||||
if parent != memberType.Name() { // avoid recursive struct
|
||||
if st, ok := allTypes[memberType.Name()]; ok {
|
||||
member.Type = fillStruct("", st, allTypes)
|
||||
}
|
||||
}
|
||||
case spec.NestedStruct:
|
||||
member.Type = fillStruct("", member.Type, allTypes)
|
||||
}
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
st, ok := allTypes[val.RawName]
|
||||
if ok {
|
||||
members = st.Members
|
||||
}
|
||||
}
|
||||
val.Members = members
|
||||
return val
|
||||
case spec.NestedStruct:
|
||||
var members []spec.Member
|
||||
for _, member := range val.Members {
|
||||
switch memberType := member.Type.(type) {
|
||||
case spec.PointerType:
|
||||
member.Type = spec.PointerType{
|
||||
RawName: memberType.RawName,
|
||||
Type: fillStruct(val.Name(), memberType.Type, allTypes),
|
||||
}
|
||||
case spec.ArrayType:
|
||||
member.Type = spec.ArrayType{
|
||||
RawName: memberType.RawName,
|
||||
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||
}
|
||||
case spec.MapType:
|
||||
member.Type = spec.MapType{
|
||||
RawName: memberType.RawName,
|
||||
Key: memberType.Key,
|
||||
Value: fillStruct(val.Name(), memberType.Value, allTypes),
|
||||
}
|
||||
case spec.DefineStruct:
|
||||
if parent != memberType.Name() { // avoid recursive struct
|
||||
if st, ok := allTypes[memberType.Name()]; ok {
|
||||
member.Type = fillStruct("", st, allTypes)
|
||||
}
|
||||
}
|
||||
case spec.NestedStruct:
|
||||
if parent != memberType.Name() {
|
||||
if st, ok := allTypes[memberType.Name()]; ok {
|
||||
member.Type = fillStruct("", st, allTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
members = append(members, member)
|
||||
}
|
||||
if len(members) == 0 {
|
||||
st, ok := allTypes[val.RawName]
|
||||
if ok {
|
||||
members = st.Members
|
||||
}
|
||||
}
|
||||
val.Members = members
|
||||
return val
|
||||
case spec.PointerType:
|
||||
return spec.PointerType{
|
||||
RawName: val.RawName,
|
||||
Type: fillStruct(parent, val.Type, allTypes),
|
||||
}
|
||||
case spec.ArrayType:
|
||||
return spec.ArrayType{
|
||||
RawName: val.RawName,
|
||||
Value: fillStruct(parent, val.Value, allTypes),
|
||||
}
|
||||
case spec.MapType:
|
||||
return spec.MapType{
|
||||
RawName: val.RawName,
|
||||
Key: val.Key,
|
||||
Value: fillStruct(parent, val.Value, allTypes),
|
||||
}
|
||||
default:
|
||||
return tp
|
||||
}
|
||||
}
|
||||
87
tools/goctl/api/swagger/command.go
Normal file
87
tools/goctl/api/swagger/command.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
// VarStringAPI specifies the API filename.
|
||||
VarStringAPI string
|
||||
|
||||
// VarStringDir specifies the directory to generate swagger file.
|
||||
VarStringDir string
|
||||
|
||||
// VarStringFilename specifies the generated swagger file name without the extension.
|
||||
VarStringFilename string
|
||||
|
||||
// VarBoolYaml specifies whether to generate a YAML file.
|
||||
VarBoolYaml bool
|
||||
)
|
||||
|
||||
func Command(_ *cobra.Command, _ []string) error {
|
||||
if len(VarStringAPI) == 0 {
|
||||
return errors.New("missing -api")
|
||||
}
|
||||
|
||||
if len(VarStringDir) == 0 {
|
||||
return errors.New("missing -dir")
|
||||
}
|
||||
|
||||
api, err := parser.Parse(VarStringAPI, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fillAllStructs(api)
|
||||
|
||||
if err := api.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
swagger, err := spec2Swagger(api)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.MarshalIndent(swagger, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = pathx.MkdirIfNotExist(VarStringDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := VarStringFilename
|
||||
if filename == "" {
|
||||
base := filepath.Base(VarStringAPI)
|
||||
filename = strings.TrimSuffix(base, filepath.Ext(base))
|
||||
}
|
||||
|
||||
if VarBoolYaml {
|
||||
filePath := filepath.Join(VarStringDir, filename+".yaml")
|
||||
|
||||
var jsonObj interface{}
|
||||
if err := yaml.Unmarshal(data, &jsonObj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(jsonObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filePath, data, 0644)
|
||||
}
|
||||
|
||||
// generate json swagger file
|
||||
filePath := filepath.Join(VarStringDir, filename+".json")
|
||||
return os.WriteFile(filePath, data, 0644)
|
||||
}
|
||||
65
tools/goctl/api/swagger/const.go
Normal file
65
tools/goctl/api/swagger/const.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package swagger
|
||||
|
||||
const (
|
||||
tagHeader = "header"
|
||||
tagPath = "path"
|
||||
tagForm = "form"
|
||||
tagJson = "json"
|
||||
defFlag = "default="
|
||||
enumFlag = "options="
|
||||
rangeFlag = "range="
|
||||
exampleFlag = "example="
|
||||
optionalFlag = "optional"
|
||||
|
||||
paramsInHeader = "header"
|
||||
paramsInPath = "path"
|
||||
paramsInQuery = "query"
|
||||
paramsInBody = "body"
|
||||
paramsInForm = "formData"
|
||||
|
||||
swaggerTypeInteger = "integer"
|
||||
swaggerTypeNumber = "number"
|
||||
swaggerTypeString = "string"
|
||||
swaggerTypeBoolean = "boolean"
|
||||
swaggerTypeArray = "array"
|
||||
swaggerTypeObject = "object"
|
||||
|
||||
swaggerVersion = "2.0"
|
||||
applicationJson = "application/json"
|
||||
applicationForm = "application/x-www-form-urlencoded"
|
||||
schemeHttps = "https"
|
||||
defaultBasePath = "/"
|
||||
)
|
||||
|
||||
const (
|
||||
propertyKeyUseDefinitions = "useDefinitions"
|
||||
propertyKeyExternalDocsDescription = "externalDocsDescription"
|
||||
propertyKeyExternalDocsURL = "externalDocsURL"
|
||||
propertyKeyTitle = "title"
|
||||
propertyKeyTermsOfService = "termsOfService"
|
||||
propertyKeyDescription = "description"
|
||||
propertyKeyVersion = "version"
|
||||
propertyKeyContactName = "contactName"
|
||||
propertyKeyContactURL = "contactURL"
|
||||
propertyKeyContactEmail = "contactEmail"
|
||||
propertyKeyLicenseName = "licenseName"
|
||||
propertyKeyLicenseURL = "licenseURL"
|
||||
propertyKeyProduces = "produces"
|
||||
propertyKeyConsumes = "consumes"
|
||||
propertyKeySchemes = "schemes"
|
||||
propertyKeyTags = "tags"
|
||||
propertyKeySummary = "summary"
|
||||
propertyKeyGroup = "group"
|
||||
propertyKeyOperationId = "operationId"
|
||||
propertyKeyDeprecated = "deprecated"
|
||||
propertyKeyPrefix = "prefix"
|
||||
propertyKeyAuthType = "authType"
|
||||
propertyKeyHost = "host"
|
||||
propertyKeyBasePath = "basePath"
|
||||
propertyKeyWrapCodeMsg = "wrapCodeMsg"
|
||||
propertyKeyBizCodeEnumDescription = "bizCodeEnumDescription"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultValueOfPropertyUseDefinition = false
|
||||
)
|
||||
25
tools/goctl/api/swagger/contenttype.go
Normal file
25
tools/goctl/api/swagger/contenttype.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func consumesFromTypeOrDef(ctx Context, method string, tp spec.Type) []string {
|
||||
if strings.EqualFold(method, http.MethodGet) {
|
||||
return []string{}
|
||||
}
|
||||
if tp == nil {
|
||||
return []string{}
|
||||
}
|
||||
structType, ok := tp.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return []string{}
|
||||
}
|
||||
if typeContainsTag(ctx, structType, tagJson) {
|
||||
return []string{applicationJson}
|
||||
}
|
||||
return []string{applicationForm}
|
||||
}
|
||||
68
tools/goctl/api/swagger/contenttype_test.go
Normal file
68
tools/goctl/api/swagger/contenttype_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func TestConsumesFromTypeOrDef(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
tp spec.Type
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "GET method with nil type",
|
||||
method: http.MethodGet,
|
||||
tp: nil,
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "post nil",
|
||||
method: http.MethodPost,
|
||||
tp: nil,
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "json tag",
|
||||
method: http.MethodPost,
|
||||
tp: spec.DefineStruct{
|
||||
Members: []spec.Member{
|
||||
{
|
||||
Tag: `json:"example"`,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{applicationJson},
|
||||
},
|
||||
{
|
||||
name: "form tag",
|
||||
method: http.MethodPost,
|
||||
tp: spec.DefineStruct{
|
||||
Members: []spec.Member{
|
||||
{
|
||||
Tag: `form:"example"`,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{applicationForm},
|
||||
},
|
||||
{
|
||||
name: "Non struct type",
|
||||
method: http.MethodPost,
|
||||
tp: spec.ArrayType{},
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := consumesFromTypeOrDef(testingContext(t), tt.method, tt.tp)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
28
tools/goctl/api/swagger/context.go
Normal file
28
tools/goctl/api/swagger/context.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
UseDefinitions bool
|
||||
WrapCodeMsg bool
|
||||
BizCodeEnumDescription string
|
||||
}
|
||||
|
||||
func testingContext(_ *testing.T) Context {
|
||||
return Context{}
|
||||
}
|
||||
|
||||
func contextFromApi(info spec.Info) Context {
|
||||
if len(info.Properties) == 0 {
|
||||
return Context{}
|
||||
}
|
||||
return Context{
|
||||
UseDefinitions: getBoolFromKVOrDefault(info.Properties, propertyKeyUseDefinitions, defaultValueOfPropertyUseDefinition),
|
||||
WrapCodeMsg: getBoolFromKVOrDefault(info.Properties, propertyKeyWrapCodeMsg, false),
|
||||
BizCodeEnumDescription: getStringFromKVOrDefault(info.Properties, propertyKeyBizCodeEnumDescription, "business code"),
|
||||
}
|
||||
}
|
||||
32
tools/goctl/api/swagger/definition.go
Normal file
32
tools/goctl/api/swagger/definition.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"github.com/go-openapi/spec"
|
||||
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
func definitionsFromTypes(ctx Context, types []apiSpec.Type) spec.Definitions {
|
||||
if !ctx.UseDefinitions {
|
||||
return nil
|
||||
}
|
||||
definitions := make(spec.Definitions)
|
||||
for _, tp := range types {
|
||||
typeName := tp.Name()
|
||||
definitions[typeName] = schemaFromType(ctx, tp)
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
func schemaFromType(ctx Context, tp apiSpec.Type) spec.Schema {
|
||||
p, r := propertiesFromType(ctx, tp)
|
||||
props := spec.SchemaProps{
|
||||
Type: typeFromGoType(ctx, tp),
|
||||
Properties: p,
|
||||
AdditionalProperties: mapFromGoType(ctx, tp),
|
||||
Items: itemsFromGoType(ctx, tp),
|
||||
Required: r,
|
||||
}
|
||||
return spec.Schema{
|
||||
SchemaProps: props,
|
||||
}
|
||||
}
|
||||
4
tools/goctl/api/swagger/example/.gitignore
vendored
Normal file
4
tools/goctl/api/swagger/example/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
*.json
|
||||
*.yaml
|
||||
bin
|
||||
output
|
||||
241
tools/goctl/api/swagger/example/example.api
Normal file
241
tools/goctl/api/swagger/example/example.api
Normal file
@@ -0,0 +1,241 @@
|
||||
syntax = "v1"
|
||||
|
||||
info (
|
||||
title: "Demo API" // title corresponding to Swagger
|
||||
description: "Generating Swagger files using the API demo." // description corresponding to Swagger
|
||||
version: "v1" // version corresponding to Swagger
|
||||
termsOfService: "https://github.com/zeromicro/go-zero" // termsOfService corresponding to Swagger
|
||||
contactName: "keson.an" // contactName corresponding to Swagger
|
||||
contactURL: "https://github.com/zeromicro/go-zero" // contactURL corresponding to Swagger
|
||||
contactEmail: "example@gmail.com" // contactEmail corresponding to Swagger
|
||||
licenseName: "MIT" // licenseName corresponding to Swagger
|
||||
licenseURL: "https://github.com/zeromicro/go-zero" // licenseURL corresponding to Swagger
|
||||
consumes: "application/json" // consumes corresponding to Swagger,default value is `application/json`
|
||||
produces: "application/json" // produces corresponding to Swagger,default value is `application/json`
|
||||
schemes: "http,https" // schemes corresponding to Swagger,default value is `https``
|
||||
host: "example.com" // host corresponding to Swagger,default value is `127.0.0.1`
|
||||
basePath: "/v1" // basePath corresponding to Swagger,default value is `/`
|
||||
wrapCodeMsg: true // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
|
||||
bizCodeEnumDescription: "1001-User not login<br>1002-User permission denied" // enums of business error codes, in JSON format, with the key being the business error code and the value being the description of that error code. This only takes effect when wrapCodeMsg is set to true.
|
||||
// securityDefinitionsFromJson is a custom authentication configuration, and the JSON content will be directly inserted into the securityDefinitions of Swagger.
|
||||
// Format reference: https://swagger.io/specification/v2/#security-definitions-object
|
||||
// You can declare authType in the @server of the API to specify the authentication type used for its routes.
|
||||
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey type description","type":"apiKey","name":"x-api-key","in":"header"}}`
|
||||
useDefinitions: true // if set true, the definitions will be generated in the swagger.json for response body or json request body file, and the models will be referenced in the API.
|
||||
)
|
||||
|
||||
type (
|
||||
QueryReq {
|
||||
Id int `form:"id,range=[1:10000],example=10"`
|
||||
Name string `form:"name,example=keson.an"`
|
||||
Avatar string `form:"avatar,optional,example=https://example.com/avatar.png"`
|
||||
}
|
||||
QueryResp {
|
||||
Id int `json:"id,example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
}
|
||||
PathQueryReq {
|
||||
Id int `path:"id,range=[1:10000],example=10"`
|
||||
Name string `form:"name,example=keson.an"`
|
||||
}
|
||||
PathQueryResp {
|
||||
Id int `json:"id,example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
}
|
||||
)
|
||||
|
||||
@server (
|
||||
tags: "query" // tags corresponding to Swagger
|
||||
summary: "query API set" // summary corresponding to Swagger
|
||||
prefix: v1
|
||||
authType: apiKey // Specifies the authentication type used for this route, which is the name defined in securityDefinitionsFromJson.
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "query demo"
|
||||
)
|
||||
@handler query
|
||||
get /query (QueryReq) returns (QueryResp)
|
||||
|
||||
@doc (
|
||||
description: "show path query demo"
|
||||
)
|
||||
@handler queryPath
|
||||
get /query/:id (PathQueryReq) returns (PathQueryResp)
|
||||
}
|
||||
|
||||
type (
|
||||
FormReq {
|
||||
Id int `form:"id,range=[1:10000],example=10"`
|
||||
Name string `form:"name,example=keson.an"`
|
||||
}
|
||||
FormResp {
|
||||
Id int `json:"id,example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
}
|
||||
)
|
||||
|
||||
@server (
|
||||
tags: "form" // tags corresponding to Swagger
|
||||
summary: "form API set" // summary corresponding to Swagger
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "form demo"
|
||||
)
|
||||
@handler form
|
||||
post /form (FormReq) returns (FormResp)
|
||||
}
|
||||
|
||||
type (
|
||||
JsonReq {
|
||||
Id int `json:"id,range=[1:10000],example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
Avatar string `json:"avatar,optional"`
|
||||
Language string `json:"language,options=golang|java|python|typescript|rust"`
|
||||
Gender string `json:"gender,default=male,options=male|female,example=male"`
|
||||
}
|
||||
JsonResp {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Avatar string `json:"avatar"`
|
||||
Language string `json:"language"`
|
||||
Gender string `json:"gender"`
|
||||
}
|
||||
ComplexJsonLevel2 {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
}
|
||||
ComplexJsonLevel1 {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
// Object
|
||||
Object ComplexJsonLevel2 `json:"object"`
|
||||
PointerObject *ComplexJsonLevel2 `json:"pointerObject"`
|
||||
}
|
||||
ComplexJsonReq {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
// basic array
|
||||
ArrayInteger []int `json:"arrayInteger"`
|
||||
ArrayNumber []float64 `json:"arrayNumber"`
|
||||
ArrayBoolean []bool `json:"arrayBoolean"`
|
||||
ArrayString []string `json:"arrayString"`
|
||||
// basic array array
|
||||
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
|
||||
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
|
||||
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
|
||||
ArrayArrayString [][]string `json:"arrayArrayString"`
|
||||
// basic map
|
||||
MapInteger map[string]int `json:"mapInteger"`
|
||||
MapNumber map[string]float64 `json:"mapNumber"`
|
||||
MapBoolean map[string]bool `json:"mapBoolean"`
|
||||
MapString map[string]string `json:"mapString"`
|
||||
// basic map array
|
||||
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
|
||||
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
|
||||
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
|
||||
MapArrayString map[string][]string `json:"mapArrayString"`
|
||||
// basic map map
|
||||
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
|
||||
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
|
||||
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
|
||||
MapMapString map[string]map[string]string `json:"mapMapString"`
|
||||
// Object
|
||||
Object ComplexJsonLevel1 `json:"object"`
|
||||
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
|
||||
// Object array
|
||||
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
|
||||
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
|
||||
// Object map
|
||||
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
|
||||
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
|
||||
// Object array array
|
||||
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
|
||||
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
|
||||
// Object array map
|
||||
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
|
||||
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
|
||||
// Object map array
|
||||
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
|
||||
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
|
||||
}
|
||||
ComplexJsonResp {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
// basic array
|
||||
ArrayInteger []int `json:"arrayInteger"`
|
||||
ArrayNumber []float64 `json:"arrayNumber"`
|
||||
ArrayBoolean []bool `json:"arrayBoolean"`
|
||||
ArrayString []string `json:"arrayString"`
|
||||
// basic array array
|
||||
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
|
||||
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
|
||||
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
|
||||
ArrayArrayString [][]string `json:"arrayArrayString"`
|
||||
// basic map
|
||||
MapInteger map[string]int `json:"mapInteger"`
|
||||
MapNumber map[string]float64 `json:"mapNumber"`
|
||||
MapBoolean map[string]bool `json:"mapBoolean"`
|
||||
MapString map[string]string `json:"mapString"`
|
||||
// basic map array
|
||||
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
|
||||
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
|
||||
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
|
||||
MapArrayString map[string][]string `json:"mapArrayString"`
|
||||
// basic map map
|
||||
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
|
||||
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
|
||||
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
|
||||
MapMapString map[string]map[string]string `json:"mapMapString"`
|
||||
// Object
|
||||
Object ComplexJsonLevel1 `json:"object"`
|
||||
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
|
||||
// Object array
|
||||
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
|
||||
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
|
||||
// Object map
|
||||
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
|
||||
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
|
||||
// Object array array
|
||||
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
|
||||
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
|
||||
// Object array map
|
||||
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
|
||||
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
|
||||
// Object map array
|
||||
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
|
||||
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
|
||||
}
|
||||
)
|
||||
|
||||
@server (
|
||||
tags: "postJson" // tags corresponding to Swagger
|
||||
summary: "json API set" // summary corresponding to Swagger
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "simple json request body API"
|
||||
)
|
||||
@handler jsonSimple
|
||||
post /json/simple (JsonReq) returns (JsonResp)
|
||||
|
||||
@doc (
|
||||
description: "complex json request body API"
|
||||
)
|
||||
@handler jsonComplex
|
||||
post /json/complex (ComplexJsonReq) returns (ComplexJsonResp)
|
||||
}
|
||||
|
||||
4980
tools/goctl/api/swagger/example/example.swagger.json
Normal file
4980
tools/goctl/api/swagger/example/example.swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
247
tools/goctl/api/swagger/example/example_cn.api
Normal file
247
tools/goctl/api/swagger/example/example_cn.api
Normal file
@@ -0,0 +1,247 @@
|
||||
syntax = "v1"
|
||||
|
||||
info (
|
||||
title: "演示 API" // 对应 swagger 的 title
|
||||
description: "演示 api 生成 swagger 文件的 api 完整写法" // 对应 swagger 的 description
|
||||
version: "v1" // 对应 swagger 的 version
|
||||
termsOfService: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 termsOfService
|
||||
contactName: "keson.an" // 对应 swagger 的 contactName
|
||||
contactURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 contactURL
|
||||
contactEmail: "example@gmail.com" // 对应 swagger 的 contactEmail
|
||||
licenseName: "MIT" // 对应 swagger 的 licenseName
|
||||
licenseURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 licenseURL
|
||||
consumes: "application/json" // 对应 swagger 的 consumes,不填默认为 application/json
|
||||
produces: "application/json" // 对应 swagger 的 produces,不填默认为 application/json
|
||||
schemes: "http,https" // 对应 swagger 的 schemes,不填默认为 https
|
||||
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
|
||||
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
|
||||
wrapCodeMsg: true // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
|
||||
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 全局业务错误码枚举描述,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
|
||||
// securityDefinitionsFromJson 为自定义鉴权配置,json 内容将直接放入 swagger 的 securityDefinitions 中,
|
||||
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object
|
||||
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
|
||||
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey 类型鉴权自定义","type":"apiKey","name":"x-api-key","in":"header"}}`
|
||||
useDefinitions: true// 开启声明将生成models 进行关联,definitions 仅对响应体和 json 请求体生效
|
||||
)
|
||||
|
||||
type (
|
||||
QueryReq {
|
||||
Id int `form:"id,range=[1:10000],example=10"`
|
||||
Name string `form:"name,example=keson.an"`
|
||||
Avatar string `form:"avatar,optional,example=https://example.com/avatar.png"`
|
||||
}
|
||||
QueryResp {
|
||||
Id int `json:"id,example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
}
|
||||
PathQueryReq {
|
||||
Id int `path:"id,range=[1:10000],example=10"`
|
||||
Name string `form:"name,example=keson.an"`
|
||||
}
|
||||
PathQueryResp {
|
||||
Id int `json:"id,example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
}
|
||||
)
|
||||
|
||||
@server (
|
||||
tags: "query 演示" // 对应 swagger 的 tags,可以对 swagger 中的 api 进行分组
|
||||
summary: "query 类型接口集合" // 对应 swagger 的 summary
|
||||
prefix: v1
|
||||
authType: apiKey // 指定该路由使用的鉴权类型,值为 securityDefinitionsFromJson 中定义的名称
|
||||
group:"demo"
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "query 接口"
|
||||
bizCodeEnumDescription: " 1003-用不存在<br>1004-非法操作" // 接口级别业务错误码枚举描述,会覆盖全局的业务错误码,json 格式,key 为业务错误码,value 为该错误码的描述,仅当 wrapCodeMsg 为 true 且 useDefinitions 为 false 时生效
|
||||
)
|
||||
@handler query
|
||||
get /query (QueryReq) returns (QueryResp)
|
||||
|
||||
@doc (
|
||||
description: "query path 中包含 id 字段接口"
|
||||
)
|
||||
@handler queryPath
|
||||
get /query/:id (PathQueryReq) returns (PathQueryResp)
|
||||
}
|
||||
|
||||
type (
|
||||
FormReq {
|
||||
Id int `form:"id,range=[1:10000],example=10"`
|
||||
Name string `form:"name,example=keson.an"`
|
||||
}
|
||||
FormResp {
|
||||
Id int `json:"id,example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
}
|
||||
)
|
||||
|
||||
@server (
|
||||
tags: "form 表单 api 演示" // 对应 swagger 的 tags,可以对 swagger 中的 api 进行分组
|
||||
summary: "form 表单类型接口集合" // 对应 swagger 的 summary
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "form 接口"
|
||||
)
|
||||
@handler form
|
||||
post /form (FormReq) returns (FormResp)
|
||||
}
|
||||
|
||||
type (
|
||||
JsonReq {
|
||||
Id int `json:"id,range=[1:10000],example=10"`
|
||||
Name string `json:"name,example=keson.an"`
|
||||
Avatar string `json:"avatar,optional"`
|
||||
Language string `json:"language,options=golang|java|python|typescript|rust"`
|
||||
Gender string `json:"gender,default=male,options=male|female,example=male"`
|
||||
}
|
||||
JsonResp {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Avatar string `json:"avatar"`
|
||||
Language string `json:"language"`
|
||||
Gender string `json:"gender"`
|
||||
}
|
||||
ComplexJsonLevel2 {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
}
|
||||
ComplexJsonLevel1 {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
// Object
|
||||
Object ComplexJsonLevel2 `json:"object"`
|
||||
PointerObject *ComplexJsonLevel2 `json:"pointerObject"`
|
||||
}
|
||||
ComplexJsonReq {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
// basic array
|
||||
ArrayInteger []int `json:"arrayInteger"`
|
||||
ArrayNumber []float64 `json:"arrayNumber"`
|
||||
ArrayBoolean []bool `json:"arrayBoolean"`
|
||||
ArrayString []string `json:"arrayString"`
|
||||
// basic array array
|
||||
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
|
||||
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
|
||||
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
|
||||
ArrayArrayString [][]string `json:"arrayArrayString"`
|
||||
// basic map
|
||||
MapInteger map[string]int `json:"mapInteger"`
|
||||
MapNumber map[string]float64 `json:"mapNumber"`
|
||||
MapBoolean map[string]bool `json:"mapBoolean"`
|
||||
MapString map[string]string `json:"mapString"`
|
||||
// basic map array
|
||||
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
|
||||
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
|
||||
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
|
||||
MapArrayString map[string][]string `json:"mapArrayString"`
|
||||
// basic map map
|
||||
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
|
||||
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
|
||||
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
|
||||
MapMapString map[string]map[string]string `json:"mapMapString"`
|
||||
MapMapObject map[string]map[string]ComplexJsonLevel1 `json:"mapMapObject"`
|
||||
MapMapPointerObject map[string]map[string]*ComplexJsonLevel1 `json:"mapMapPointerObject"`
|
||||
// Object
|
||||
Object ComplexJsonLevel1 `json:"object"`
|
||||
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
|
||||
// Object array
|
||||
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
|
||||
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
|
||||
// Object map
|
||||
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
|
||||
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
|
||||
// Object array array
|
||||
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
|
||||
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
|
||||
// Object array map
|
||||
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
|
||||
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
|
||||
// Object map array
|
||||
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
|
||||
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
|
||||
}
|
||||
ComplexJsonResp {
|
||||
// basic
|
||||
Integer int `json:"integer,example=1"`
|
||||
Number float64 `json:"number,example=1.1"`
|
||||
Boolean bool `json:"boolean,options=true|false,example=true"`
|
||||
String string `json:"string,example=some text"`
|
||||
// basic array
|
||||
ArrayInteger []int `json:"arrayInteger"`
|
||||
ArrayNumber []float64 `json:"arrayNumber"`
|
||||
ArrayBoolean []bool `json:"arrayBoolean"`
|
||||
ArrayString []string `json:"arrayString"`
|
||||
// basic array array
|
||||
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
|
||||
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
|
||||
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
|
||||
ArrayArrayString [][]string `json:"arrayArrayString"`
|
||||
// basic map
|
||||
MapInteger map[string]int `json:"mapInteger"`
|
||||
MapNumber map[string]float64 `json:"mapNumber"`
|
||||
MapBoolean map[string]bool `json:"mapBoolean"`
|
||||
MapString map[string]string `json:"mapString"`
|
||||
// basic map array
|
||||
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
|
||||
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
|
||||
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
|
||||
MapArrayString map[string][]string `json:"mapArrayString"`
|
||||
// basic map map
|
||||
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
|
||||
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
|
||||
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
|
||||
MapMapString map[string]map[string]string `json:"mapMapString"`
|
||||
MapMapObject map[string]map[string]ComplexJsonLevel1 `json:"mapMapObject"`
|
||||
MapMapPointerObject map[string]map[string]*ComplexJsonLevel1 `json:"mapMapPointerObject"`
|
||||
// Object
|
||||
Object ComplexJsonLevel1 `json:"object"`
|
||||
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
|
||||
// Object array
|
||||
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
|
||||
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
|
||||
// Object map
|
||||
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
|
||||
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
|
||||
// Object array array
|
||||
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
|
||||
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
|
||||
// Object array map
|
||||
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
|
||||
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
|
||||
// Object map array
|
||||
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
|
||||
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
|
||||
}
|
||||
)
|
||||
|
||||
@server (
|
||||
tags: "post json api 演示" // 对应 swagger 的 tags,可以对 swagger 中的 api 进行分组
|
||||
summary: "json 请求类型接口集合" // 对应 swagger 的 summary
|
||||
)
|
||||
service Swagger {
|
||||
@doc (
|
||||
description: "简单的 json 请求体接口"
|
||||
)
|
||||
@handler jsonSimple
|
||||
post /json/simple (JsonReq) returns (JsonResp)
|
||||
|
||||
@doc (
|
||||
description: "复杂的 json 请求体接口"
|
||||
)
|
||||
@handler jsonComplex
|
||||
post /json/complex (ComplexJsonReq) returns (ComplexJsonResp)
|
||||
}
|
||||
|
||||
5608
tools/goctl/api/swagger/example/example_cn.swagger.json
Normal file
5608
tools/goctl/api/swagger/example/example_cn.swagger.json
Normal file
File diff suppressed because it is too large
Load Diff
39
tools/goctl/api/swagger/example/go-swagger-cn.sh
Normal file
39
tools/goctl/api/swagger/example/go-swagger-cn.sh
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 1. 检查并安装 swagger
|
||||
if ! command -v swagger &> /dev/null; then
|
||||
echo "swagger 未安装,正在从 GitHub 安装..."
|
||||
# 这里使用 go-swagger 的安装方式
|
||||
go install github.com/go-swagger/go-swagger/cmd/swagger@latest
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "安装 swagger 失败"
|
||||
exit 1
|
||||
fi
|
||||
echo "swagger 安装成功"
|
||||
else
|
||||
echo "swagger 已安装"
|
||||
fi
|
||||
|
||||
mkdir bin output
|
||||
|
||||
export GOBIN=$(pwd)/bin
|
||||
|
||||
# 2. 安装最新版 goctl
|
||||
go install ../../..
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "安装 goctl 失败"
|
||||
exit 1
|
||||
fi
|
||||
echo "goctl 安装成功"
|
||||
|
||||
# 3. 生成 swagger 文件
|
||||
echo "正在生成 swagger 文件..."
|
||||
./bin/goctl api swagger --api example_cn.api --dir output
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "生成 swagger 文件失败"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 4. 启动 swagger 服务
|
||||
echo "启动 swagger 服务..."
|
||||
swagger serve ./output/example_cn.json
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user