Compare commits

...

27 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
02191e0d99 Fix TypeScript generator to correctly handle inline/embedded structs
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-10-25 20:11:35 +08:00
copilot-swe-agent[bot]
28b12ad9cc Add comprehensive integration test for inline struct handling
- Add TestGenWithInlineStructs to verify end-to-end generation
- Test verifies correct parameter generation, flattening of inline structs
- Test verifies no incorrect Headers interfaces are generated
- Test verifies correct argument counts in function calls

Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-10-25 20:11:35 +08:00
copilot-swe-agent[bot]
87dd9671be Fix TypeScript generator to handle inline/embedded structs correctly
- Add hasActualTagMembers helper to recursively check for actual tag members
- Add hasActualBodyMembers helper to recursively check for actual body members
- Add hasActualNonBodyMembers helper to recursively check for actual non-body members
- Update genParamsTypesIfNeed to use hasActualNonBodyMembers and hasActualTagMembers
- Update hasRequestBody, hasRequestHeader, hasRequestPath, and pathHasParams to use new helpers
- Add comprehensive test coverage for new helper functions

Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-10-25 20:11:35 +08:00
Kevin Wan
4e52d77ad8 fix(trace): use sync.Once to prevent multiple trace initialization (#5244)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-10-25 20:10:15 +08:00
Kevin Wan
1fc2cfb859 fix: gateway trace headers 5248 (#5256)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-25 12:16:31 +08:00
Gregor Fischer
942cdae41d Fix typos in comments and messages (#5254) 2025-10-25 01:28:40 +00:00
dependabot[bot]
e9c3607bc6 chore(deps): bump github.com/redis/go-redis/v9 from 9.14.1 to 9.16.0 (#5255)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-24 21:19:49 +08:00
dependabot[bot]
d1603e9166 chore(deps): bump github.com/redis/go-redis/v9 from 9.14.0 to 9.14.1 (#5251)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-22 21:01:51 +08:00
zhoushuguang
e30317e9c4 feat: consistent hash balancer support (#5246)
Co-authored-by: 周曙光 <zsg@zhoushuguangdeMacBook-Pro.local>
2025-10-19 14:10:30 +00:00
stemlaud
568f9ce007 chore: remove extra spaces in the comment (#5245)
Signed-off-by: stemlaud <stemlaud@outlook.com>
2025-10-19 13:42:10 +00:00
dependabot[bot]
dcb309065a chore(deps): bump github/codeql-action from 3 to 4 (#5243)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-13 23:09:38 +00:00
Kevin Wan
bf8e17a686 test: add unit tests for goctl docker command (PR #4343) (#5241) 2025-10-13 21:55:25 +08:00
Jack001
b2ebbfce62 fix: ensure Dockerfile includes etc directory and correct CMD based on config (#4343)
Co-authored-by: 白少杰macpro <harrellharris68491@gmail.com>
2025-10-13 13:42:15 +00:00
Kevin Wan
2b10a6a223 fix: support PUT, PATCH, DELETE methods for request body definitions in swagger (#5239) 2025-10-12 18:24:11 +08:00
Kevin Wan
80c320b46e chore: remove unused code (#5238) 2025-10-12 11:55:57 +08:00
Kevin Wan
bea9d150a1 fix(goctl): restore API summaries in swagger generation (#5237) 2025-10-12 11:38:58 +08:00
Kevin Wan
3f756a2cbf chore: update goctl version (#5236) 2025-10-11 18:01:18 +08:00
Kevin Wan
bbe5bbb0c0 chore: update go-redis for the retracted versions (#5235) 2025-10-11 16:20:23 +08:00
dependabot[bot]
5ad2278a69 chore(deps): bump go.mongodb.org/mongo-driver/v2 from 2.3.0 to 2.3.1 (#5230)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-10 21:32:04 +08:00
ZH1995
77763fe748 Fix api doc url in readme-cn.md (#5233) 2025-10-10 13:21:46 +00:00
Kevin Wan
538c4fb5c7 fix: issue #5154, not using sse template files (#5220) 2025-10-08 11:56:52 +00:00
Kevin Wan
315fb2fe0a Add company to the list in readme-cn.md (#5222) 2025-10-07 21:04:09 +08:00
Copilot
e382887eb8 docs: Add comprehensive documentation for blocking Redis operations (XReadGroup, Blpop) (#5221)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2025-10-07 16:46:10 +08:00
Kevin Wan
cf21cb2b0b chore: refactor to remove duplicated code (#5216) 2025-10-06 22:24:44 +08:00
Copilot
61e8894c31 Fix swagger generation: info block and server tags not included (#5215)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-10-06 22:02:42 +08:00
Copilot
7a6c3c8129 Fix swagger path generation: remove trailing slash for root routes with prefix (#5212) 2025-10-05 12:12:03 +08:00
Rizky Ikwan
875fec3e1a chore: fix typos (#5210) 2025-10-04 02:56:07 +00:00
47 changed files with 2305 additions and 224 deletions

View File

@@ -39,7 +39,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v3
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -50,7 +50,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v3
uses: github/codeql-action/autobuild@v4
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
@@ -64,4 +64,4 @@ jobs:
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v3
uses: github/codeql-action/analyze@v4

View File

@@ -532,7 +532,7 @@ func createModel(t *testing.T, coll mon.Collection) *Model {
}
}
// mustNewTestModel returns a test Model with the given cache.
// mustNewTestModel returns a test Model with the given cache.
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
return &Model{
Model: &mon.Model{

View File

@@ -259,12 +259,34 @@ func (s *Redis) BitPosCtx(ctx context.Context, key string, bit, start, end int64
}
// Blpop uses passed in redis connection to execute blocking queries.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
// not share the regular connection pool.
//
// Example usage:
//
// node, err := redis.CreateBlockingNode(rds)
// if err != nil {
// // handle error
// }
// defer node.Close()
//
// value, err := rds.Blpop(node, "mylist")
// if err != nil {
// // handle error
// }
//
// Doesn't benefit from pooling redis connections of blocking queries
func (s *Redis) Blpop(node RedisNode, key string) (string, error) {
return s.BlpopCtx(context.Background(), node, key)
}
// BlpopCtx uses passed in redis connection to execute blocking queries.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
//
// Doesn't benefit from pooling redis connections of blocking queries
func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (string, error) {
return s.BlpopWithTimeoutCtx(ctx, node, blockingQueryTimeout, key)
@@ -272,12 +294,18 @@ func (s *Redis) BlpopCtx(ctx context.Context, node RedisNode, key string) (strin
// BlpopEx uses passed in redis connection to execute blpop command.
// The difference against Blpop is that this method returns a bool to indicate success.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopEx(node RedisNode, key string) (string, bool, error) {
return s.BlpopExCtx(context.Background(), node, key)
}
// BlpopExCtx uses passed in redis connection to execute blpop command.
// The difference against Blpop is that this method returns a bool to indicate success.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (string, bool, error) {
if node == nil {
return "", false, ErrNilNode
@@ -297,12 +325,18 @@ func (s *Redis) BlpopExCtx(ctx context.Context, node RedisNode, key string) (str
// BlpopWithTimeout uses passed in redis connection to execute blpop command.
// Control blocking query timeout
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopWithTimeout(node RedisNode, timeout time.Duration, key string) (string, error) {
return s.BlpopWithTimeoutCtx(context.Background(), node, timeout, key)
}
// BlpopWithTimeoutCtx uses passed in redis connection to execute blpop command.
// Control blocking query timeout
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode.
// See Blpop for usage examples.
func (s *Redis) BlpopWithTimeoutCtx(ctx context.Context, node RedisNode, timeout time.Duration,
key string) (string, error) {
if node == nil {
@@ -1840,6 +1874,29 @@ func (s *Redis) XInfoStreamCtx(ctx context.Context, stream string) (*red.XInfoSt
// XReadGroup reads messages from Redis streams as part of a consumer group.
// It allows for distributed processing of stream messages with automatic message delivery semantics.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
// exhausting the connection pool. Blocking commands hold connections for extended periods and should
// not share the regular connection pool.
//
// Example usage:
//
// node, err := redis.CreateBlockingNode(rds)
// if err != nil {
// // handle error
// }
// defer node.Close()
//
// streams, err := rds.XReadGroup(
// node, // RedisNode created with CreateBlockingNode
// "mygroup", // consumer group name
// "consumer1", // consumer ID
// 10, // max number of messages to read
// 5*time.Second, // block duration
// false, // noAck flag
// "mystream", // stream name
// )
//
// Doesn't benefit from pooling redis connections of blocking queries.
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
@@ -1847,6 +1904,10 @@ func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, coun
}
// XReadGroupCtx is the context-aware version of XReadGroup.
//
// For blocking operations, you must create a dedicated RedisNode using CreateBlockingNode to avoid
// exhausting the connection pool. See XReadGroup for usage examples.
//
// Doesn't benefit from pooling redis connections of blocking queries.
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {

View File

@@ -13,7 +13,37 @@ type ClosableNode interface {
Close()
}
// CreateBlockingNode returns a ClosableNode.
// CreateBlockingNode creates a dedicated RedisNode for blocking operations.
//
// Blocking Redis commands (like BLPOP, BRPOP, XREADGROUP with block parameter) hold connections
// for extended periods while waiting for data. Using them with the regular Redis connection pool
// can exhaust all available connections, causing other operations to fail or timeout.
//
// CreateBlockingNode creates a separate Redis client with a minimal connection pool (size 1) that
// is dedicated to blocking operations. This ensures blocking commands don't interfere with regular
// Redis operations.
//
// Example usage:
//
// rds := redis.MustNewRedis(redis.RedisConf{
// Host: "localhost:6379",
// Type: redis.NodeType,
// })
//
// // Create a dedicated node for blocking operations
// node, err := redis.CreateBlockingNode(rds)
// if err != nil {
// // handle error
// }
// defer node.Close() // Important: close the node when done
//
// // Use the node for blocking operations
// value, err := rds.Blpop(node, "mylist")
// if err != nil {
// // handle error
// }
//
// The returned ClosableNode must be closed when no longer needed to release resources.
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
timeout := readWriteTimeout + blockingQueryTimeout

View File

@@ -7,7 +7,6 @@ import (
"os"
"sync"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/logx"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/jaeger"
@@ -30,42 +29,36 @@ const (
)
var (
agents = make(map[string]lang.PlaceholderType)
lock sync.Mutex
tp *sdktrace.TracerProvider
once sync.Once
tp *sdktrace.TracerProvider
shutdownOnceFn = sync.OnceFunc(func() {
if tp != nil {
_ = tp.Shutdown(context.Background())
}
})
)
// StartAgent starts an opentelemetry agent.
// It uses sync.Once to ensure the agent is initialized only once,
// similar to prometheus.StartAgent and logx.SetUp.
// This prevents multiple ServiceConf.SetUp() calls from reinitializing
// the global tracer provider when running multiple servers (e.g., REST + RPC)
// in the same process.
func StartAgent(c Config) {
if c.Disabled {
return
}
lock.Lock()
defer lock.Unlock()
_, ok := agents[c.Endpoint]
if ok {
return
}
// if error happens, let later calls run.
if err := startAgent(c); err != nil {
return
}
agents[c.Endpoint] = lang.Placeholder
once.Do(func() {
if err := startAgent(c); err != nil {
logx.Error(err)
}
})
}
// StopAgent shuts down the span processors in the order they were registered.
func StopAgent() {
lock.Lock()
defer lock.Unlock()
if tp != nil {
_ = tp.Shutdown(context.Background())
tp = nil
}
shutdownOnceFn()
}
func createExporter(c Config) (sdktrace.SpanExporter, error) {

View File

@@ -1,10 +1,13 @@
package trace
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"go.opentelemetry.io/otel"
)
func TestStartAgent(t *testing.T) {
@@ -89,23 +92,305 @@ func TestStartAgent(t *testing.T) {
StartAgent(c10)
defer StopAgent()
lock.Lock()
defer lock.Unlock()
// because remotehost cannot be resolved
assert.Equal(t, 6, len(agents))
_, ok := agents[""]
assert.True(t, ok)
_, ok = agents[endpoint1]
assert.True(t, ok)
_, ok = agents[endpoint2]
assert.False(t, ok)
_, ok = agents[endpoint5]
assert.True(t, ok)
_, ok = agents[endpoint6]
assert.False(t, ok)
_, ok = agents[endpoint71]
assert.True(t, ok)
_, ok = agents[endpoint72]
assert.False(t, ok)
// With sync.Once, only the first non-disabled config (c1) takes effect.
// Subsequent calls are ignored, which is the desired behavior to prevent
// multiple servers (REST + RPC) from reinitializing the global tracer.
assert.NotNil(t, tp)
}
func TestCreateExporter_InvalidFilePath(t *testing.T) {
logx.Disable()
c := Config{
Name: "test-invalid-file",
Endpoint: "/non-existent-directory/trace.log",
Batcher: kindFile,
}
_, err := createExporter(c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "file exporter endpoint error")
}
func TestCreateExporter_UnknownBatcher(t *testing.T) {
logx.Disable()
c := Config{
Name: "test-unknown",
Endpoint: "localhost:1234",
Batcher: "unknown-batcher-type",
}
_, err := createExporter(c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "unknown exporter")
}
func TestCreateExporter_ValidExporters(t *testing.T) {
logx.Disable()
tests := []struct {
name string
config Config
wantErr bool
errMsg string
}{
{
name: "valid file exporter",
config: Config{
Name: "file-test",
Endpoint: "/tmp/trace-test.log",
Batcher: kindFile,
},
wantErr: false,
},
{
name: "invalid file path",
config: Config{
Name: "file-test-invalid",
Endpoint: "/invalid-path/that/does/not/exist/trace.log",
Batcher: kindFile,
},
wantErr: true,
errMsg: "file exporter endpoint error",
},
{
name: "unknown batcher",
config: Config{
Name: "unknown-test",
Endpoint: "localhost:1234",
Batcher: "invalid-batcher",
},
wantErr: true,
errMsg: "unknown exporter",
},
{
name: "jaeger http",
config: Config{
Name: "jaeger-http",
Endpoint: "http://localhost:14268/api/traces",
Batcher: kindJaeger,
},
wantErr: false,
},
{
name: "jaeger udp",
config: Config{
Name: "jaeger-udp",
Endpoint: "udp://localhost:6831",
Batcher: kindJaeger,
},
wantErr: false,
},
{
name: "zipkin",
config: Config{
Name: "zipkin",
Endpoint: "http://localhost:9411/api/v2/spans",
Batcher: kindZipkin,
},
wantErr: false,
},
{
name: "otlpgrpc",
config: Config{
Name: "otlpgrpc",
Endpoint: "localhost:4317",
Batcher: kindOtlpGrpc,
},
wantErr: false,
},
{
name: "otlpgrpc with headers",
config: Config{
Name: "otlpgrpc-headers",
Endpoint: "localhost:4317",
Batcher: kindOtlpGrpc,
OtlpHeaders: map[string]string{
"authorization": "Bearer token123",
"x-custom-key": "custom-value",
},
},
wantErr: false,
},
{
name: "otlphttp",
config: Config{
Name: "otlphttp",
Endpoint: "localhost:4318",
Batcher: kindOtlpHttp,
},
wantErr: false,
},
{
name: "otlphttp with headers",
config: Config{
Name: "otlphttp-headers",
Endpoint: "localhost:4318",
Batcher: kindOtlpHttp,
OtlpHeaders: map[string]string{
"authorization": "Bearer token456",
"x-api-key": "api-key-value",
},
},
wantErr: false,
},
{
name: "otlphttp with headers and path",
config: Config{
Name: "otlphttp-headers-path",
Endpoint: "localhost:4318",
Batcher: kindOtlpHttp,
OtlpHttpPath: "/v1/traces",
OtlpHeaders: map[string]string{
"authorization": "Bearer token789",
"x-custom-trace": "trace-id",
},
},
wantErr: false,
},
{
name: "otlphttp with secure connection",
config: Config{
Name: "otlphttp-secure",
Endpoint: "localhost:4318",
Batcher: kindOtlpHttp,
OtlpHttpSecure: true,
OtlpHeaders: map[string]string{
"authorization": "Bearer secure-token",
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
exporter, err := createExporter(tt.config)
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
assert.Nil(t, exporter)
} else {
assert.NoError(t, err)
assert.NotNil(t, exporter)
// Clean up the exporter
if exporter != nil {
_ = exporter.Shutdown(context.Background())
}
}
})
}
}
func TestStopAgent(t *testing.T) {
logx.Disable()
// StopAgent should be idempotent and safe to call multiple times
assert.NotPanics(t, func() {
StopAgent()
StopAgent()
StopAgent()
})
}
func TestStartAgent_WithEndpoint(t *testing.T) {
logx.Disable()
tests := []struct {
name string
config Config
wantErr bool
}{
{
name: "empty endpoint - no exporter created",
config: Config{
Name: "test-no-endpoint",
Sampler: 1.0,
},
wantErr: false,
},
{
name: "valid endpoint with file exporter",
config: Config{
Name: "test-with-endpoint",
Endpoint: "/tmp/test-trace.log",
Batcher: kindFile,
Sampler: 1.0,
},
wantErr: false,
},
{
name: "endpoint with invalid exporter type",
config: Config{
Name: "test-invalid-batcher",
Endpoint: "localhost:1234",
Batcher: "invalid-type",
Sampler: 1.0,
},
wantErr: true,
},
{
name: "endpoint with invalid file path",
config: Config{
Name: "test-invalid-path",
Endpoint: "/non/existent/path/trace.log",
Batcher: kindFile,
Sampler: 1.0,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset tp for each test
originalTp := tp
tp = nil
defer func() {
if tp != nil {
_ = tp.Shutdown(context.Background())
}
tp = originalTp
}()
err := startAgent(tt.config)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, tp, "TracerProvider should be created")
}
})
}
}
func TestStartAgent_ErrorHandler(t *testing.T) {
// Setup a tracer provider to test error handler
originalTp := tp
tp = nil
defer func() {
if tp != nil {
_ = tp.Shutdown(context.Background())
}
tp = originalTp
}()
// Call startAgent to set up the error handler
config := Config{
Name: "test-error-handler",
Sampler: 1.0,
}
err := startAgent(config)
assert.NoError(t, err)
assert.NotNil(t, tp)
// Verify the error handler was set and can be called without panicking
// We test this by calling otel.Handle which will invoke the registered error handler
testErr := errors.New("test otel error")
assert.NotPanics(t, func() {
otel.Handle(testErr)
}, "Error handler should handle errors without panicking")
}

View File

@@ -11,16 +11,40 @@ const (
metadataPrefix = "gateway-"
)
// OpenTelemetry trace propagation headers that need to be forwarded to gRPC metadata.
// These headers are used by the W3C Trace Context standard for distributed tracing.
var traceHeaders = map[string]bool{
"traceparent": true,
"tracestate": true,
"baggage": true,
}
// ProcessHeaders builds the headers for the gateway from HTTP headers.
// It forwards both custom metadata headers (with Grpc-Metadata- prefix)
// and OpenTelemetry trace propagation headers (traceparent, tracestate, baggage)
// to ensure distributed tracing works correctly across the gateway.
func ProcessHeaders(header http.Header) []string {
var headers []string
for k, v := range header {
// Forward OpenTelemetry trace propagation headers
// These must be lowercase per gRPC metadata conventions
if lowerKey := strings.ToLower(k); traceHeaders[lowerKey] {
for _, vv := range v {
headers = append(headers, lowerKey+":"+vv)
}
continue
}
// Forward custom metadata headers with Grpc-Metadata- prefix
if !strings.HasPrefix(k, metadataHeaderPrefix) {
continue
}
key := fmt.Sprintf("%s%s", metadataPrefix, strings.TrimPrefix(k, metadataHeaderPrefix))
// gRPC metadata keys are case-insensitive and stored as lowercase,
// so we lowercase the key to match gRPC conventions
trimmedKey := strings.TrimPrefix(k, metadataHeaderPrefix)
key := strings.ToLower(fmt.Sprintf("%s%s", metadataPrefix, trimmedKey))
for _, vv := range v {
headers = append(headers, key+":"+vv)
}

View File

@@ -18,5 +18,93 @@ func TestBuildHeadersWithValues(t *testing.T) {
req := httptest.NewRequest("GET", "/", http.NoBody)
req.Header.Add("grpc-metadata-a", "b")
req.Header.Add("grpc-metadata-b", "b")
assert.ElementsMatch(t, []string{"gateway-A:b", "gateway-B:b"}, ProcessHeaders(req.Header))
assert.ElementsMatch(t, []string{"gateway-a:b", "gateway-b:b"}, ProcessHeaders(req.Header))
}
func TestProcessHeadersWithTraceContext(t *testing.T) {
req := httptest.NewRequest("GET", "/", http.NoBody)
req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
req.Header.Set("tracestate", "key1=value1,key2=value2")
req.Header.Set("baggage", "userId=alice,serverNode=DF:28")
headers := ProcessHeaders(req.Header)
assert.Len(t, headers, 3)
assert.Contains(t, headers, "traceparent:00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
assert.Contains(t, headers, "tracestate:key1=value1,key2=value2")
assert.Contains(t, headers, "baggage:userId=alice,serverNode=DF:28")
}
func TestProcessHeadersWithMixedHeaders(t *testing.T) {
req := httptest.NewRequest("GET", "/", http.NoBody)
req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
req.Header.Set("grpc-metadata-custom", "value1")
req.Header.Set("content-type", "application/json")
req.Header.Set("tracestate", "key1=value1")
headers := ProcessHeaders(req.Header)
// Should include trace headers and grpc-metadata headers, but not regular headers
assert.Len(t, headers, 3)
assert.Contains(t, headers, "traceparent:00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
assert.Contains(t, headers, "tracestate:key1=value1")
assert.Contains(t, headers, "gateway-custom:value1")
}
func TestProcessHeadersTraceparentCaseInsensitive(t *testing.T) {
tests := []struct {
name string
headerKey string
headerVal string
expectedKey string
}{
{
name: "lowercase traceparent",
headerKey: "traceparent",
headerVal: "00-trace-span-01",
expectedKey: "traceparent",
},
{
name: "uppercase Traceparent",
headerKey: "Traceparent",
headerVal: "00-trace-span-01",
expectedKey: "traceparent",
},
{
name: "mixed case TraceParent",
headerKey: "TraceParent",
headerVal: "00-trace-span-01",
expectedKey: "traceparent",
},
{
name: "lowercase tracestate",
headerKey: "tracestate",
headerVal: "key=value",
expectedKey: "tracestate",
},
{
name: "mixed case TraceState",
headerKey: "TraceState",
headerVal: "key=value",
expectedKey: "tracestate",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/", http.NoBody)
req.Header.Set(tt.headerKey, tt.headerVal)
headers := ProcessHeaders(req.Header)
assert.Len(t, headers, 1)
assert.Contains(t, headers, tt.expectedKey+":"+tt.headerVal)
})
}
}
func TestProcessHeadersEmptyHeaders(t *testing.T) {
req := httptest.NewRequest("GET", "/", http.NoBody)
headers := ProcessHeaders(req.Header)
assert.Empty(t, headers)
}

4
go.mod
View File

@@ -16,12 +16,12 @@ require (
github.com/jhump/protoreflect v1.17.0
github.com/pelletier/go-toml/v2 v2.2.2
github.com/prometheus/client_golang v1.21.1
github.com/redis/go-redis/v9 v9.15.0
github.com/redis/go-redis/v9 v9.16.0
github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.11.1
go.etcd.io/etcd/api/v3 v3.5.15
go.etcd.io/etcd/client/v3 v3.5.15
go.mongodb.org/mongo-driver/v2 v2.3.0
go.mongodb.org/mongo-driver/v2 v2.3.1
go.opentelemetry.io/otel v1.24.0
go.opentelemetry.io/otel/exporters/jaeger v1.17.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0

8
go.sum
View File

@@ -154,8 +154,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.15.0 h1:2jdes0xJxer4h3NUZrZ4OGSntGlXp4WbXju2nOTRXto=
github.com/redis/go-redis/v9 v9.15.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
@@ -197,8 +197,8 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5
go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU=
go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4=
go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU=
go.mongodb.org/mongo-driver/v2 v2.3.0 h1:sh55yOXA2vUjW1QYw/2tRlHSQViwDyPnW61AwpZ4rtU=
go.mongodb.org/mongo-driver/v2 v2.3.0/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
go.mongodb.org/mongo-driver/v2 v2.3.1 h1:WrCgSzO7dh1/FrePud9dK5fKNZOE97q5EQimGkos7Wo=
go.mongodb.org/mongo-driver/v2 v2.3.1/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
go.opentelemetry.io/otel v1.24.0 h1:0LAOdjNmQeSTzGBzduGe/rU4tZhMwL5rWgtp9Ku5Jfo=
go.opentelemetry.io/otel v1.24.0/go.mod h1:W7b9Ozg4nkF5tWI5zsXkaKKDjdVjpD4oAt9Qi/MArHo=
go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4=

View File

@@ -175,7 +175,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/zeromicro
* API 文档
[https://go-zero.dev/cn/](https://go-zero.dev/cn/)
[https://go-zero.dev](https://go-zero.dev)
* awesome 系列(更多文章见『微服务实践』公众号)
@@ -305,6 +305,7 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电
>107. 深圳市聚货通信息科技有限公司
>108. 浙江银盾云科技有限公司
>109. 南京造世网络科技有限公司
>110. 温州飞儿云信息技术有限公司
如果贵公司也已使用 go-zero欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。

View File

@@ -92,7 +92,7 @@ Port: 0
Path: "/",
Handler: nil,
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
WithJwtTransition("preivous", "thenewone"))
WithJwtTransition("previous", "thenewone"))
func() {
defer func() {

View File

@@ -32,7 +32,7 @@ import '../vars/vars.dart';
/// Send GET request.
///
/// ok: the function that will be called on success.
/// failthe fuction that will be called on failure.
/// failthe function that will be called on failure.
/// eventuallythe function that will be called regardless of success or failure.
Future apiGet(String path,
{Map<String, String> header,
@@ -47,7 +47,7 @@ Future apiGet(String path,
///
/// data: the data to post, it will be marshaled to json automatically.
/// ok: the function that will be called on success.
/// failthe fuction that will be called on failure.
/// failthe function that will be called on failure.
/// eventuallythe function that will be called regardless of success or failure.
Future apiPost(String path, dynamic data,
{Map<String, String> header,

View File

@@ -38,9 +38,11 @@ func genHandler(dir, rootPkg, projectPkg string, cfg *config.Config, group spec.
}
var builtinTemplate = handlerTemplate
var templateFile = handlerTemplateFile
sse := group.GetAnnotation("sse")
if sse == "true" {
builtinTemplate = sseHandlerTemplate
templateFile = sseHandlerTemplateFile
}
return genFile(fileGenConfig{
@@ -49,7 +51,7 @@ func genHandler(dir, rootPkg, projectPkg string, cfg *config.Config, group spec.
filename: filename + ".go",
templateName: "handlerTemplate",
category: category,
templateFile: handlerTemplateFile,
templateFile: templateFile,
builtinTemplate: builtinTemplate,
data: map[string]any{
"PkgName": pkgName,

View File

@@ -61,9 +61,11 @@ func genLogicByRoute(dir, rootPkg, projectPkg string, cfg *config.Config, group
subDir := getLogicFolderPath(group, route)
builtinTemplate := logicTemplate
templateFile := logicTemplateFile
sse := group.GetAnnotation("sse")
if sse == "true" {
builtinTemplate = sseLogicTemplate
templateFile = sseLogicTemplateFile
responseString = "error"
returnString = "return nil"
resp := responseGoTypeName(route, typesPacket)
@@ -80,7 +82,7 @@ func genLogicByRoute(dir, rootPkg, projectPkg string, cfg *config.Config, group
filename: goFile + ".go",
templateName: "logicTemplate",
category: category,
templateFile: logicTemplateFile,
templateFile: templateFile,
builtinTemplate: builtinTemplate,
data: map[string]any{
"pkgName": subDir[strings.LastIndex(subDir, "/")+1:],

View File

@@ -0,0 +1,153 @@
package gogen
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSSEGeneration(t *testing.T) {
// Create a temporary directory for test
dir := t.TempDir()
// Create a test API file with SSE annotation
apiContent := `syntax = "v1"
type SseReq {
Message string ` + "`json:\"message\"`" + `
}
type SseResp {
Data string ` + "`json:\"data\"`" + `
}
@server (
sse: true
)
service Test {
@handler Sse
get /sse (SseReq) returns (SseResp)
}
`
apiFile := filepath.Join(dir, "test.api")
err := os.WriteFile(apiFile, []byte(apiContent), 0644)
assert.NoError(t, err)
// Generate code
err = DoGenProject(apiFile, dir, "gozero", false)
assert.NoError(t, err)
// Read generated handler file
handlerPath := filepath.Join(dir, "internal/handler/ssehandler.go")
handlerContent, err := os.ReadFile(handlerPath)
assert.NoError(t, err)
// Read generated logic file
logicPath := filepath.Join(dir, "internal/logic/sselogic.go")
logicContent, err := os.ReadFile(logicPath)
assert.NoError(t, err)
handlerStr := string(handlerContent)
logicStr := string(logicContent)
// Verify SSE-specific patterns in handler
// Handler should call: err := l.Sse(&req, client)
assert.Contains(t, handlerStr, "err := l.Sse(&req, client)",
"Handler should call logic with client channel parameter")
// Handler should NOT have the regular pattern: resp, err := l.Sse(&req)
assert.NotContains(t, handlerStr, "resp, err := l.Sse(&req)",
"Handler should not use regular pattern with resp return")
// Handler should use threading.GoSafeCtx
assert.Contains(t, handlerStr, "threading.GoSafeCtx",
"Handler should use threading.GoSafeCtx for SSE")
// Handler should create client channel
assert.Contains(t, handlerStr, "client := make(chan",
"Handler should create client channel")
// Verify SSE-specific patterns in logic
// Logic should have signature: Sse(req *types.SseReq, client chan<- *types.SseResp) error
assert.Contains(t, logicStr, "func (l *SseLogic) Sse(req *types.SseReq, client chan<- *types.SseResp) error",
"Logic should have SSE signature with client channel parameter")
// Logic should NOT have regular signature: Sse(req *types.SseReq) (resp *types.SseResp, err error)
assert.NotContains(t, logicStr, "(resp *types.SseResp, err error)",
"Logic should not have regular signature with resp return")
}
func TestNonSSEGeneration(t *testing.T) {
// Create a temporary directory for test
dir := t.TempDir()
// Create a test API file WITHOUT SSE annotation
apiContent := `syntax = "v1"
type SseReq {
Message string ` + "`json:\"message\"`" + `
}
type SseResp {
Data string ` + "`json:\"data\"`" + `
}
service Test {
@handler Sse
get /sse (SseReq) returns (SseResp)
}
`
apiFile := filepath.Join(dir, "test.api")
err := os.WriteFile(apiFile, []byte(apiContent), 0644)
assert.NoError(t, err)
// Generate code
err = DoGenProject(apiFile, dir, "gozero", false)
assert.NoError(t, err)
// Read generated handler file
handlerPath := filepath.Join(dir, "internal/handler/ssehandler.go")
handlerContent, err := os.ReadFile(handlerPath)
assert.NoError(t, err)
// Read generated logic file
logicPath := filepath.Join(dir, "internal/logic/sselogic.go")
logicContent, err := os.ReadFile(logicPath)
assert.NoError(t, err)
handlerStr := string(handlerContent)
logicStr := string(logicContent)
// Verify regular (non-SSE) patterns in handler
// Handler should call: resp, err := l.Sse(&req)
assert.Contains(t, handlerStr, "resp, err := l.Sse(&req)",
"Handler should use regular pattern with resp return")
// Handler should NOT have SSE pattern: err := l.Sse(&req, client)
assert.NotContains(t, handlerStr, "err := l.Sse(&req, client)",
"Handler should not use SSE pattern")
// Handler should NOT use threading.GoSafeCtx
assert.NotContains(t, handlerStr, "threading.GoSafeCtx",
"Handler should not use threading.GoSafeCtx for regular routes")
// Verify regular (non-SSE) patterns in logic
// Logic should have signature: Sse(req *types.SseReq) (resp *types.SseResp, err error)
assert.Contains(t, logicStr, "(resp *types.SseResp, err error)",
"Logic should have regular signature with resp return")
// Logic should NOT have SSE signature with client parameter
linesToCheck := strings.Split(logicStr, "\n")
hasSSESignature := false
for _, line := range linesToCheck {
if strings.Contains(line, "func (l *SseLogic) Sse") && strings.Contains(line, "client chan<-") {
hasSSESignature = true
break
}
}
assert.False(t, hasSSESignature,
"Logic should not have SSE signature with client channel parameter")
}

View File

@@ -1,17 +0,0 @@
type Request {
Name string `path:"name,options=you|me"`
}
type Response {
Message string `json:"message"`
}
@server(
jwt: Auth
jwtTransition: Trans
middleware: TokenValidate
)
service A-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}

View File

@@ -268,7 +268,7 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) any {
v.panic(lit.Expr(), fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit.Expr().Text()))
}
default:
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
v.panic(dt.Expr(), fmt.Sprintf("unsupported %s", dt.Expr().Text()))
}
case *Literal:
lit := dataType.Literal.Text()
@@ -276,7 +276,7 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) any {
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
}
default:
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
v.panic(dt.Expr(), fmt.Sprintf("unsupported %s", dt.Expr().Text()))
}
return &Body{

View File

@@ -190,7 +190,7 @@ func (v *ApiVisitor) VisitTypeBlockStruct(ctx *api.TypeBlockStructContext) any {
structExpr := v.newExprWithToken(ctx.GetStructToken())
structTokenText := ctx.GetStructToken().GetText()
if structTokenText != "struct" {
v.panic(structExpr, fmt.Sprintf("expecting 'struct', found imput '%s'", structTokenText))
v.panic(structExpr, fmt.Sprintf("expecting 'struct', found input '%s'", structTokenText))
}
if api.IsGolangKeyWord(structTokenText, "struct") {

View File

@@ -18,7 +18,7 @@ type parser struct {
}
// Parse parses the api file.
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
// it will be removed in the future.
func Parse(filename string) (*spec.ApiSpec, error) {
if env.UseExperimental() {
@@ -63,14 +63,14 @@ func parseContent(content string, skipCheckTypeDeclaration bool, filename ...str
return apiSpec, nil
}
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
// it will be removed in the future.
// ParseContent parses the api content
func ParseContent(content string, filename ...string) (*spec.ApiSpec, error) {
return parseContent(content, false, filename...)
}
// Depreacted: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
// Deprecated: use tools/goctl/pkg/parser/api/parser/parser.go:18 instead,
// it will be removed in the future.
// ParseContentWithParserSkipCheckTypeDeclaration parses the api content with skip check type declaration
func ParseContentWithParserSkipCheckTypeDeclaration(content string, filename ...string) (*spec.ApiSpec, error) {
@@ -227,7 +227,7 @@ func (p parser) astTypeToSpec(in ast.DataType) spec.Type {
return spec.PointerType{RawName: v.PointerExpr.Text(), Type: spec.DefineStruct{RawName: raw}}
}
panic(fmt.Sprintf("unspported type %+v", in))
panic(fmt.Sprintf("unsupported type %+v", in))
}
func (p parser) stringExprs(docs []ast.Expr) []string {

View File

@@ -8,70 +8,71 @@ import (
)
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
//I think this function and those below should handle error, but they didn't.
//Since a default value (def) is provided, any parsing errors will result in the default being returned.
str, err := strconv.Unquote(val[0])
if err != nil || len(str) == 0 {
return def
}
res, _ := strconv.ParseBool(str)
return res
}
return getOrDefault(properties, key, def, func(str string, def bool) bool {
res, err := strconv.ParseBool(str)
if err != nil {
return def
}
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
str, err := strconv.Unquote(val[0])
if err != nil || len(str) == 0 {
return def
}
return str
}
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
str, err := strconv.Unquote(val[0])
if err != nil || len(str) == 0 {
return def
}
resp := util.FieldsAndTrimSpace(str, commaRune)
if len(resp) == 0 {
return def
}
return resp
return res
})
}
func getFirstUsableString(def ...string) string {
if len(def) == 0 {
return ""
}
for _, val := range def {
str, err := strconv.Unquote(val)
if err == nil && len(str) != 0 {
// Try to unquote if it's a quoted string
if str, err := strconv.Unquote(val); err == nil && len(str) != 0 {
return str
}
// Otherwise, use the value as-is if it's not empty
if len(val) != 0 {
return val
}
}
return ""
}
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
return getOrDefault(properties, key, def, func(str string, def []string) []string {
resp := util.FieldsAndTrimSpace(str, commaRune)
if len(resp) == 0 {
return def
}
return resp
})
}
// getOrDefault abstracts the common logic for fetching, unquoting, and defaulting.
func getOrDefault[T any](properties map[string]string, key string, def T, convert func(string, T) T) T {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
str := val[0]
if unquoted, err := strconv.Unquote(str); err == nil {
str = unquoted
}
if len(str) == 0 {
return def
}
return convert(str, def)
}
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
return getOrDefault(properties, key, def, func(str string, def string) string {
return str
})
}

View File

@@ -21,6 +21,19 @@ func Test_getBoolFromKVOrDefault(t *testing.T) {
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
// Test with unquoted values (as stored by RawText())
unquotedProperties := map[string]string{
"enabled": "true",
"disabled": "false",
"invalid": "notabool",
"empty_value": "",
}
assert.True(t, getBoolFromKVOrDefault(unquotedProperties, "enabled", false))
assert.False(t, getBoolFromKVOrDefault(unquotedProperties, "disabled", true))
assert.False(t, getBoolFromKVOrDefault(unquotedProperties, "invalid", false))
assert.False(t, getBoolFromKVOrDefault(unquotedProperties, "empty_value", false))
}
func Test_getStringFromKVOrDefault(t *testing.T) {
@@ -34,6 +47,17 @@ func Test_getStringFromKVOrDefault(t *testing.T) {
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
// Test with unquoted values (as stored by RawText())
unquotedProperties := map[string]string{
"name": "example",
"title": "Demo API",
"empty": "",
}
assert.Equal(t, "example", getStringFromKVOrDefault(unquotedProperties, "name", "default"))
assert.Equal(t, "Demo API", getStringFromKVOrDefault(unquotedProperties, "title", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(unquotedProperties, "empty", "default"))
}
func Test_getListFromInfoOrDefault(t *testing.T) {
@@ -50,4 +74,123 @@ func Test_getListFromInfoOrDefault(t *testing.T) {
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
"foo": ",,",
}, "foo", []string{"default"}))
// Test with unquoted values (as stored by RawText())
unquotedProperties := map[string]string{
"list": "a, b, c",
"schemes": "http,https",
"tags": "query",
"empty": "",
}
// Note: FieldsAndTrimSpace doesn't actually trim the spaces from returned values
assert.Equal(t, []string{"a", " b", " c"}, getListFromInfoOrDefault(unquotedProperties, "list", []string{"default"}))
assert.Equal(t, []string{"http", "https"}, getListFromInfoOrDefault(unquotedProperties, "schemes", []string{"default"}))
assert.Equal(t, []string{"query"}, getListFromInfoOrDefault(unquotedProperties, "tags", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(unquotedProperties, "empty", []string{"default"}))
}
func Test_getFirstUsableString(t *testing.T) {
t.Run("empty input", func(t *testing.T) {
result := getFirstUsableString()
assert.Equal(t, "", result, "should return empty string for no arguments")
})
t.Run("single plain string", func(t *testing.T) {
result := getFirstUsableString("Check server health status.")
assert.Equal(t, "Check server health status.", result)
})
t.Run("single quoted string", func(t *testing.T) {
// This is how Go would represent a quoted string literal
result := getFirstUsableString(`"Check server health status."`)
assert.Equal(t, "Check server health status.", result, "should unquote quoted strings")
})
t.Run("multiple plain strings", func(t *testing.T) {
result := getFirstUsableString("", "second", "third")
assert.Equal(t, "second", result, "should return first non-empty string")
})
t.Run("handler name fallback", func(t *testing.T) {
// Simulates the real use case: @doc text, handler name
result := getFirstUsableString("", "HealthCheck")
assert.Equal(t, "HealthCheck", result, "should fallback to handler name")
})
t.Run("doc text over handler name", func(t *testing.T) {
// Simulates the real use case with @doc text
result := getFirstUsableString("Check server health status.", "HealthCheck")
assert.Equal(t, "Check server health status.", result, "should use doc text over handler name")
})
t.Run("empty strings before valid", func(t *testing.T) {
result := getFirstUsableString("", "", "valid")
assert.Equal(t, "valid", result, "should skip empty strings")
})
t.Run("all empty strings", func(t *testing.T) {
result := getFirstUsableString("", "", "")
assert.Equal(t, "", result, "should return empty if all are empty")
})
t.Run("quoted then plain", func(t *testing.T) {
result := getFirstUsableString(`"quoted"`, "plain")
assert.Equal(t, "quoted", result, "should unquote first quoted string")
})
t.Run("plain then quoted", func(t *testing.T) {
result := getFirstUsableString("plain", `"quoted"`)
assert.Equal(t, "plain", result, "should use first plain string")
})
t.Run("invalid quoted string", func(t *testing.T) {
// String that looks quoted but isn't valid Go syntax
result := getFirstUsableString(`"incomplete`, "fallback")
assert.Equal(t, `"incomplete`, result, "should use as-is if unquote fails but not empty")
})
t.Run("whitespace only", func(t *testing.T) {
result := getFirstUsableString(" ", "fallback")
assert.Equal(t, " ", result, "should not trim whitespace, return as-is")
})
t.Run("real world API doc scenario", func(t *testing.T) {
// This is the actual bug scenario from issue #5229
atDocText := "Check server health status."
handlerName := "HealthCheck"
result := getFirstUsableString(atDocText, handlerName)
assert.Equal(t, "Check server health status.", result,
"should use @doc text for API summary")
})
t.Run("real world with empty doc", func(t *testing.T) {
// When @doc is empty, should fall back to handler name
atDocText := ""
handlerName := "HealthCheck"
result := getFirstUsableString(atDocText, handlerName)
assert.Equal(t, "HealthCheck", result,
"should fallback to handler name when @doc is empty")
})
t.Run("complex summary with special characters", func(t *testing.T) {
result := getFirstUsableString("Get user by ID: /users/{id}")
assert.Equal(t, "Get user by ID: /users/{id}", result,
"should handle special characters in plain strings")
})
t.Run("multiline string", func(t *testing.T) {
result := getFirstUsableString("Line 1\nLine 2")
assert.Equal(t, "Line 1\nLine 2", result,
"should handle multiline strings")
})
t.Run("unicode characters", func(t *testing.T) {
result := getFirstUsableString("健康检查", "HealthCheck")
assert.Equal(t, "健康检查", result,
"should handle unicode characters")
})
}

View File

@@ -8,28 +8,37 @@ import (
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func isPostJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
if !strings.EqualFold(method, http.MethodPost) {
func isRequestBodyJson(ctx Context, method string, tp apiSpec.Type) (string, bool) {
// Support HTTP methods that commonly use request bodies with JSON
// POST, PUT, PATCH are standard methods with bodies
// DELETE can also have a body (though less common)
method = strings.ToUpper(method)
if method != http.MethodPost && method != http.MethodPut &&
method != http.MethodPatch && method != http.MethodDelete {
return "", false
}
structType, ok := tp.(apiSpec.DefineStruct)
if !ok {
return "", false
}
var isPostJson bool
var hasJsonField bool
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
jsonTag, _ := tag.Get(tagJson)
if !isPostJson {
isPostJson = jsonTag != nil
if !hasJsonField {
hasJsonField = jsonTag != nil
}
})
return structType.RawName, isPostJson
return structType.RawName, hasJsonField
}
func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Parameter {
if tp == nil {
return []spec.Parameter{}
}
structType, ok := tp.(apiSpec.DefineStruct)
if !ok {
return []spec.Parameter{}
@@ -43,15 +52,13 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
rangeMemberAndDo(ctx, structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
headerTag, _ := tag.Get(tagHeader)
hasHeader := headerTag != nil
pathParameterTag, _ := tag.Get(tagPath)
hasPathParameter := pathParameterTag != nil
formTag, _ := tag.Get(tagForm)
hasForm := formTag != nil
jsonTag, _ := tag.Get(tagJson)
hasJson := jsonTag != nil
if hasHeader {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(headerTag.Options)
resp = append(resp, spec.Parameter{
@@ -75,6 +82,7 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
},
})
}
if hasPathParameter {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(pathParameterTag.Options)
resp = append(resp, spec.Parameter{
@@ -98,6 +106,7 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
},
})
}
if hasForm {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(formTag.Options)
if strings.EqualFold(method, http.MethodGet) {
@@ -145,8 +154,8 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
},
})
}
}
if hasJson {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(jsonTag.Options)
if required {
@@ -179,9 +188,10 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
properties[jsonTag.Name] = schema
}
})
if len(properties) > 0 {
if ctx.UseDefinitions {
structName, ok := isPostJson(ctx, method, tp)
structName, ok := isRequestBodyJson(ctx, method, tp)
if ok {
resp = append(resp, spec.Parameter{
ParamProps: spec.ParamProps{
@@ -213,5 +223,6 @@ func parametersFromType(ctx Context, method string, tp apiSpec.Type) []spec.Para
})
}
}
return resp
}

View File

@@ -8,7 +8,7 @@ import (
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func TestIsPostJson(t *testing.T) {
func TestIsRequestBodyJson(t *testing.T) {
tests := []struct {
name string
method string
@@ -18,13 +18,18 @@ func TestIsPostJson(t *testing.T) {
{"POST with JSON", http.MethodPost, true, true},
{"POST without JSON", http.MethodPost, false, false},
{"GET with JSON", http.MethodGet, true, false},
{"PUT with JSON", http.MethodPut, true, false},
{"PUT with JSON", http.MethodPut, true, true},
{"PUT without JSON", http.MethodPut, false, false},
{"PATCH with JSON", http.MethodPatch, true, true},
{"PATCH without JSON", http.MethodPatch, false, false},
{"DELETE with JSON", http.MethodDelete, true, true},
{"DELETE without JSON", http.MethodDelete, false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testStruct := createTestStruct("TestStruct", tt.hasJson)
_, result := isPostJson(testingContext(t), tt.method, testStruct)
_, result := isRequestBodyJson(testingContext(t), tt.method, testStruct)
assert.Equal(t, tt.expected, result)
})
}
@@ -41,6 +46,12 @@ func TestParametersFromType(t *testing.T) {
}{
{"POST JSON with definitions", http.MethodPost, true, true, 1, true},
{"POST JSON without definitions", http.MethodPost, false, true, 1, true},
{"PUT JSON with definitions", http.MethodPut, true, true, 1, true},
{"PUT JSON without definitions", http.MethodPut, false, true, 1, true},
{"PATCH JSON with definitions", http.MethodPatch, true, true, 1, true},
{"PATCH JSON without definitions", http.MethodPatch, false, true, 1, true},
{"DELETE JSON with definitions", http.MethodDelete, true, true, 1, true},
{"DELETE JSON without definitions", http.MethodDelete, false, true, 1, true},
{"GET with form", http.MethodGet, false, false, 1, false},
{"POST with form", http.MethodPost, false, false, 1, false},
}

View File

@@ -19,7 +19,11 @@ func spec2Paths(ctx Context, srv apiSpec.Service) *spec.Paths {
for _, route := range group.Routes {
routPath := pathVariable2SwaggerVariable(ctx, route.Path)
if len(prefix) > 0 && prefix != "." {
routPath = "/" + path.Clean(prefix) + routPath
if routPath == "/" {
routPath = "/" + path.Clean(prefix)
} else {
routPath = "/" + path.Clean(prefix) + routPath
}
}
pathItem := spec2Path(ctx, group, route)
existPathItem, ok := paths.Paths[routPath]

View File

@@ -0,0 +1,90 @@
package swagger
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func TestSpec2PathsWithRootRoute(t *testing.T) {
tests := []struct {
name string
prefix string
routePath string
expectedPath string
}{
{
name: "prefix with root route",
prefix: "/api/v1/shoppings",
routePath: "/",
expectedPath: "/api/v1/shoppings",
},
{
name: "prefix with sub route",
prefix: "/api/v1/shoppings",
routePath: "/list",
expectedPath: "/api/v1/shoppings/list",
},
{
name: "empty prefix with root route",
prefix: "",
routePath: "/",
expectedPath: "/",
},
{
name: "empty prefix with sub route",
prefix: "",
routePath: "/list",
expectedPath: "/list",
},
{
name: "prefix with trailing slash and root route",
prefix: "/api/v1/shoppings/",
routePath: "/",
expectedPath: "/api/v1/shoppings",
},
{
name: "prefix without leading slash and root route",
prefix: "api/v1/shoppings",
routePath: "/",
expectedPath: "/api/v1/shoppings",
},
{
name: "single level prefix with root route",
prefix: "/api",
routePath: "/",
expectedPath: "/api",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := spec.Service{
Groups: []spec.Group{
{
Annotation: spec.Annotation{
Properties: map[string]string{
propertyKeyPrefix: tt.prefix,
},
},
Routes: []spec.Route{
{
Method: "get",
Path: tt.routePath,
Handler: "TestHandler",
},
},
},
},
}
ctx := testingContext(t)
paths := spec2Paths(ctx, srv)
assert.Contains(t, paths.Paths, tt.expectedPath,
"Expected path %s not found in generated paths. Got: %v",
tt.expectedPath, paths.Paths)
})
}
}

View File

@@ -0,0 +1,163 @@
package tsgen
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/api/parser"
)
func TestGenWithInlineStructs(t *testing.T) {
// Create a temporary directory for the test
tmpDir := t.TempDir()
apiFile := filepath.Join(tmpDir, "test.api")
// Write the test API file
apiContent := `syntax = "v1"
info (
title: "Test ts generator"
desc: "Test inline struct handling"
author: "test"
version: "v1"
)
// common pagination request
type PaginationReq {
PageNum int ` + "`form:\"pageNum\"`" + `
PageSize int ` + "`form:\"pageSize\"`" + `
}
// base response
type BaseResp {
Code int64 ` + "`json:\"code\"`" + `
Msg string ` + "`json:\"msg\"`" + `
}
// common req
type GetListCommonReq {
Sth string ` + "`form:\"sth\"`" + `
PageNum int ` + "`form:\"pageNum\"`" + `
PageSize int ` + "`form:\"pageSize\"`" + `
}
// bad req to ts - inline struct with form tags
type GetListBadReq {
Sth string ` + "`form:\"sth\"`" + `
PaginationReq
}
// bad req to ts 2 - only inline struct with form tags
type GetListBad2Req {
PaginationReq
}
// GetListResp - inline struct with json tags
type GetListResp {
BaseResp
}
service test-api {
@doc "common req"
@handler getListCommon
get /getListCommon (GetListCommonReq) returns (GetListResp)
@doc "bad req"
@handler getListBad
get /getListBad (GetListBadReq) returns (GetListResp)
@doc "bad req 2"
@handler getListBad2
get /getListBad2 (GetListBad2Req) returns (GetListResp)
@doc "no req"
@handler getListNoReq
get /getListNoReq returns (GetListResp)
}`
err := os.WriteFile(apiFile, []byte(apiContent), 0644)
assert.NoError(t, err)
// Parse the API file
api, err := parser.Parse(apiFile)
assert.NoError(t, err)
// Generate TypeScript files
outputDir := filepath.Join(tmpDir, "output")
err = os.MkdirAll(outputDir, 0755)
assert.NoError(t, err)
// Generate the files directly
api.Service = api.Service.JoinPrefix()
err = genRequest(outputDir)
assert.NoError(t, err)
err = genHandler(outputDir, ".", "webapi", api, false)
assert.NoError(t, err)
err = genComponents(outputDir, api)
assert.NoError(t, err)
// Read generated handler file
handlerFile := filepath.Join(outputDir, "test.ts")
handlerContent, err := os.ReadFile(handlerFile)
assert.NoError(t, err)
handler := string(handlerContent)
// Read generated components file
componentsFile := filepath.Join(outputDir, "testComponents.ts")
componentsContent, err := os.ReadFile(componentsFile)
assert.NoError(t, err)
components := string(componentsContent)
// Verify getListBad function signature and call
assert.Contains(t, handler, "export function getListBad(params: components.GetListBadReqParams)")
assert.Contains(t, handler, "return webapi.get<components.GetListResp>(`/getListBad`, params)")
// Should NOT contain 4 arguments
assert.NotContains(t, handler, "getListBad`, params, req, headers")
// Verify getListBad2 function signature and call
assert.Contains(t, handler, "export function getListBad2(params: components.GetListBad2ReqParams)")
assert.Contains(t, handler, "return webapi.get<components.GetListResp>(`/getListBad2`, params)")
// Should NOT reference non-existent headers
assert.NotContains(t, handler, "GetListBad2ReqHeaders")
// Verify getListCommon function signature and call
assert.Contains(t, handler, "export function getListCommon(params: components.GetListCommonReqParams)")
assert.Contains(t, handler, "return webapi.get<components.GetListResp>(`/getListCommon`, params)")
// Verify getListNoReq function signature and call
assert.Contains(t, handler, "export function getListNoReq()")
assert.Contains(t, handler, "return webapi.get<components.GetListResp>(`/getListNoReq`)")
// Verify GetListBadReqParams contains flattened fields
assert.Contains(t, components, "export interface GetListBadReqParams")
// Count occurrences of fields in GetListBadReqParams
paramsStart := strings.Index(components, "export interface GetListBadReqParams")
paramsEnd := strings.Index(components[paramsStart:], "}")
paramsSection := components[paramsStart : paramsStart+paramsEnd]
assert.Contains(t, paramsSection, "sth: string")
assert.Contains(t, paramsSection, "pageNum: number")
assert.Contains(t, paramsSection, "pageSize: number")
// Verify GetListBad2ReqParams contains flattened fields from inline PaginationReq
assert.Contains(t, components, "export interface GetListBad2ReqParams")
params2Start := strings.Index(components, "export interface GetListBad2ReqParams")
params2End := strings.Index(components[params2Start:], "}")
params2Section := components[params2Start : params2Start+params2End]
assert.Contains(t, params2Section, "pageNum: number")
assert.Contains(t, params2Section, "pageSize: number")
// Verify no empty Headers interfaces are generated
assert.NotContains(t, components, "GetListBadReqHeaders")
assert.NotContains(t, components, "GetListBad2ReqHeaders")
// Verify GetListResp contains flattened fields from BaseResp
assert.Contains(t, components, "export interface GetListResp")
respStart := strings.Index(components, "export interface GetListResp")
respEnd := strings.Index(components[respStart:], "}")
respSection := components[respStart : respStart+respEnd]
assert.Contains(t, respSection, "code: number")
assert.Contains(t, respSection, "msg: string")
}

View File

@@ -212,7 +212,7 @@ func pathHasParams(route spec.Route) bool {
return false
}
return len(ds.Members) != len(ds.GetBodyMembers())
return hasActualNonBodyMembers(ds)
}
func hasRequestBody(route spec.Route) bool {
@@ -221,7 +221,7 @@ func hasRequestBody(route spec.Route) bool {
return false
}
return len(route.RequestTypeName()) > 0 && len(ds.GetBodyMembers()) > 0
return len(route.RequestTypeName()) > 0 && hasActualBodyMembers(ds)
}
func hasRequestPath(route spec.Route) bool {
@@ -230,7 +230,7 @@ func hasRequestPath(route spec.Route) bool {
return false
}
return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(pathTagKey)) > 0
return len(route.RequestTypeName()) > 0 && hasActualTagMembers(ds, pathTagKey)
}
func hasRequestHeader(route spec.Route) bool {
@@ -239,5 +239,5 @@ func hasRequestHeader(route spec.Route) bool {
return false
}
return len(route.RequestTypeName()) > 0 && len(ds.GetTagMembers(headerTagKey)) > 0
return len(route.RequestTypeName()) > 0 && hasActualTagMembers(ds, headerTagKey)
}

View File

@@ -164,13 +164,13 @@ func writeType(writer io.Writer, tp spec.Type) error {
}
func genParamsTypesIfNeed(writer io.Writer, tp spec.Type) error {
definedType, ok := tp.(spec.DefineStruct)
_, ok := tp.(spec.DefineStruct)
if !ok {
return errors.New("no members of type " + tp.Name())
}
members := definedType.GetNonBodyMembers()
if len(members) == 0 {
// Check if there are actual non-body members (recursively through inline structs)
if !hasActualNonBodyMembers(tp) {
return nil
}
@@ -180,7 +180,7 @@ func genParamsTypesIfNeed(writer io.Writer, tp spec.Type) error {
}
fmt.Fprintf(writer, "}\n")
if len(definedType.GetTagMembers(headerTagKey)) > 0 {
if hasActualTagMembers(tp, headerTagKey) {
fmt.Fprintf(writer, "export interface %sHeaders {\n", util.Title(tp.Name()))
if err := writeTagMembers(writer, tp, headerTagKey); err != nil {
return err
@@ -247,3 +247,87 @@ func writeTagMembers(writer io.Writer, tp spec.Type, tagKey string) error {
}
return nil
}
// hasActualTagMembers checks if a type has actual members with the given tag,
// recursively checking inline/embedded structs
func hasActualTagMembers(tp spec.Type, tagKey string) bool {
definedType, ok := tp.(spec.DefineStruct)
if !ok {
pointType, ok := tp.(spec.PointerType)
if ok {
return hasActualTagMembers(pointType.Type, tagKey)
}
return false
}
for _, m := range definedType.Members {
if m.IsInline {
// Recursively check inline members
if hasActualTagMembers(m.Type, tagKey) {
return true
}
} else {
// Check non-inline members for the tag
if m.IsTagMember(tagKey) {
return true
}
}
}
return false
}
// hasActualBodyMembers checks if a type has actual body members (json tags),
// recursively checking inline/embedded structs
func hasActualBodyMembers(tp spec.Type) bool {
definedType, ok := tp.(spec.DefineStruct)
if !ok {
pointType, ok := tp.(spec.PointerType)
if ok {
return hasActualBodyMembers(pointType.Type)
}
return false
}
for _, m := range definedType.Members {
if m.IsInline {
// Recursively check inline members
if hasActualBodyMembers(m.Type) {
return true
}
} else {
// Check non-inline members for json tag
if m.IsBodyMember() {
return true
}
}
}
return false
}
// hasActualNonBodyMembers checks if a type has actual non-body members (form, path, header tags),
// recursively checking inline/embedded structs
func hasActualNonBodyMembers(tp spec.Type) bool {
definedType, ok := tp.(spec.DefineStruct)
if !ok {
pointType, ok := tp.(spec.PointerType)
if ok {
return hasActualNonBodyMembers(pointType.Type)
}
return false
}
for _, m := range definedType.Members {
if m.IsInline {
// Recursively check inline members
if hasActualNonBodyMembers(m.Type) {
return true
}
} else {
// Check non-inline members for non-body tags
if !m.IsBodyMember() {
return true
}
}
}
return false
}

View File

@@ -37,3 +37,268 @@ func TestGenTsType(t *testing.T) {
}
assert.Equal(t, `1 | 3 | 4 | 123`, ty)
}
func TestHasActualTagMembers(t *testing.T) {
// Test with no members
emptyStruct := spec.DefineStruct{
RawName: "Empty",
Members: []spec.Member{},
}
assert.False(t, hasActualTagMembers(emptyStruct, "form"))
assert.False(t, hasActualTagMembers(emptyStruct, "header"))
// Test with direct form members
directFormStruct := spec.DefineStruct{
RawName: "DirectForm",
Members: []spec.Member{
{
Name: "Field1",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `form:"field1"`,
},
},
}
assert.True(t, hasActualTagMembers(directFormStruct, "form"))
assert.False(t, hasActualTagMembers(directFormStruct, "header"))
// Test with inline struct containing form members
inlineFormStruct := spec.DefineStruct{
RawName: "PaginationReq",
Members: []spec.Member{
{
Name: "PageNum",
Type: spec.PrimitiveType{RawName: "int"},
Tag: `form:"pageNum"`,
},
{
Name: "PageSize",
Type: spec.PrimitiveType{RawName: "int"},
Tag: `form:"pageSize"`,
},
},
}
parentStruct := spec.DefineStruct{
RawName: "ParentReq",
Members: []spec.Member{
{
Name: "",
Type: inlineFormStruct,
IsInline: true,
},
},
}
assert.True(t, hasActualTagMembers(parentStruct, "form"))
assert.False(t, hasActualTagMembers(parentStruct, "header"))
// Test with both direct and inline members
mixedStruct := spec.DefineStruct{
RawName: "MixedReq",
Members: []spec.Member{
{
Name: "Sth",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `form:"sth"`,
},
{
Name: "",
Type: inlineFormStruct,
IsInline: true,
},
},
}
assert.True(t, hasActualTagMembers(mixedStruct, "form"))
assert.False(t, hasActualTagMembers(mixedStruct, "header"))
// Test with inline struct containing only json members (body members)
inlineJsonStruct := spec.DefineStruct{
RawName: "JsonStruct",
Members: []spec.Member{
{
Name: "Code",
Type: spec.PrimitiveType{RawName: "int64"},
Tag: `json:"code"`,
},
{
Name: "Msg",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `json:"msg"`,
},
},
}
parentJsonStruct := spec.DefineStruct{
RawName: "ParentResp",
Members: []spec.Member{
{
Name: "",
Type: inlineJsonStruct,
IsInline: true,
},
},
}
assert.False(t, hasActualTagMembers(parentJsonStruct, "form"))
assert.False(t, hasActualTagMembers(parentJsonStruct, "header"))
}
func TestHasActualBodyMembers(t *testing.T) {
// Test with no members
emptyStruct := spec.DefineStruct{
RawName: "Empty",
Members: []spec.Member{},
}
assert.False(t, hasActualBodyMembers(emptyStruct))
// Test with direct json members
directJsonStruct := spec.DefineStruct{
RawName: "DirectJson",
Members: []spec.Member{
{
Name: "Code",
Type: spec.PrimitiveType{RawName: "int64"},
Tag: `json:"code"`,
},
},
}
assert.True(t, hasActualBodyMembers(directJsonStruct))
// Test with inline struct containing json members
inlineJsonStruct := spec.DefineStruct{
RawName: "BaseResp",
Members: []spec.Member{
{
Name: "Code",
Type: spec.PrimitiveType{RawName: "int64"},
Tag: `json:"code"`,
},
{
Name: "Msg",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `json:"msg"`,
},
},
}
parentStruct := spec.DefineStruct{
RawName: "ParentResp",
Members: []spec.Member{
{
Name: "",
Type: inlineJsonStruct,
IsInline: true,
},
},
}
assert.True(t, hasActualBodyMembers(parentStruct))
// Test with inline struct containing only form members (not body members)
inlineFormStruct := spec.DefineStruct{
RawName: "PaginationReq",
Members: []spec.Member{
{
Name: "PageNum",
Type: spec.PrimitiveType{RawName: "int"},
Tag: `form:"pageNum"`,
},
},
}
parentFormStruct := spec.DefineStruct{
RawName: "ParentReq",
Members: []spec.Member{
{
Name: "",
Type: inlineFormStruct,
IsInline: true,
},
},
}
assert.False(t, hasActualBodyMembers(parentFormStruct))
}
func TestHasActualNonBodyMembers(t *testing.T) {
// Test with no members
emptyStruct := spec.DefineStruct{
RawName: "Empty",
Members: []spec.Member{},
}
assert.False(t, hasActualNonBodyMembers(emptyStruct))
// Test with direct form members
directFormStruct := spec.DefineStruct{
RawName: "DirectForm",
Members: []spec.Member{
{
Name: "Field1",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `form:"field1"`,
},
},
}
assert.True(t, hasActualNonBodyMembers(directFormStruct))
// Test with inline struct containing form members
inlineFormStruct := spec.DefineStruct{
RawName: "PaginationReq",
Members: []spec.Member{
{
Name: "PageNum",
Type: spec.PrimitiveType{RawName: "int"},
Tag: `form:"pageNum"`,
},
{
Name: "PageSize",
Type: spec.PrimitiveType{RawName: "int"},
Tag: `form:"pageSize"`,
},
},
}
parentStruct := spec.DefineStruct{
RawName: "ParentReq",
Members: []spec.Member{
{
Name: "",
Type: inlineFormStruct,
IsInline: true,
},
},
}
assert.True(t, hasActualNonBodyMembers(parentStruct))
// Test with inline struct containing only json members (body members)
inlineJsonStruct := spec.DefineStruct{
RawName: "BaseResp",
Members: []spec.Member{
{
Name: "Code",
Type: spec.PrimitiveType{RawName: "int64"},
Tag: `json:"code"`,
},
},
}
parentJsonStruct := spec.DefineStruct{
RawName: "ParentResp",
Members: []spec.Member{
{
Name: "",
Type: inlineJsonStruct,
IsInline: true,
},
},
}
assert.False(t, hasActualNonBodyMembers(parentJsonStruct))
// Test with both direct and inline non-body members
mixedStruct := spec.DefineStruct{
RawName: "MixedReq",
Members: []spec.Member{
{
Name: "Sth",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `form:"sth"`,
},
{
Name: "",
Type: inlineFormStruct,
IsInline: true,
},
},
}
assert.True(t, hasActualNonBodyMembers(mixedStruct))
}

View File

@@ -73,6 +73,7 @@ func dockerCommand(_ *cobra.Command, _ []string) (err error) {
base := varStringBase
port := varIntPort
etcDir := filepath.Join(filepath.Dir(goFile), etcDir)
if _, err := os.Stat(etcDir); os.IsNotExist(err) {
return generateDockerfile(goFile, base, port, version, timezone)
}
@@ -170,7 +171,7 @@ func generateDockerfile(goFile, base string, port int, version, timezone string,
t := template.Must(template.New("dockerfile").Parse(text))
return t.Execute(out, Docker{
Chinese: env.InChina(),
GoMainFrom: path.Join(projPath, goFile),
GoMainFrom: path.Join(projPath, filepath.Base(goFile)),
GoRelPath: projPath,
GoFile: goFile,
ExeFile: exeName,

View File

@@ -0,0 +1,376 @@
package docker
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDockerCommand_EtcDirResolution(t *testing.T) {
// Create a temporary project structure
tempDir := t.TempDir()
// Create project structure: project/service/api/
serviceDir := filepath.Join(tempDir, "service", "api")
etcDir := filepath.Join(serviceDir, "etc")
require.NoError(t, os.MkdirAll(etcDir, 0755))
// Create a Go file
goFile := filepath.Join(serviceDir, "api.go")
require.NoError(t, os.WriteFile(goFile, []byte("package main\n\nfunc main() {}"), 0644))
// Create a config file
configFile := filepath.Join(etcDir, "config.yaml")
require.NoError(t, os.WriteFile(configFile, []byte("Name: test\n"), 0644))
// Create go.mod at the root
goModFile := filepath.Join(tempDir, "go.mod")
require.NoError(t, os.WriteFile(goModFile, []byte("module test\n\ngo 1.21\n"), 0644))
// Test: etc directory should be found relative to Go file, not CWD
t.Run("etc directory resolved relative to go file", func(t *testing.T) {
// Save and restore original working directory
originalWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
require.NoError(t, os.Chdir(originalWd))
}()
// Change to temp directory (not service/api directory)
require.NoError(t, os.Chdir(tempDir))
// The relative path from tempDir to the go file
relGoFile := filepath.Join("service", "api", "api.go")
// Test the etc directory resolution logic
resolvedEtcDir := filepath.Join(filepath.Dir(relGoFile), "etc")
// Verify the resolved path exists
_, err = os.Stat(resolvedEtcDir)
assert.NoError(t, err, "etc directory should be found at service/api/etc")
// Verify it's the correct path (use EvalSymlinks to handle /private on macOS)
absResolvedEtc, err := filepath.Abs(resolvedEtcDir)
require.NoError(t, err)
absResolvedEtc, err = filepath.EvalSymlinks(absResolvedEtc)
require.NoError(t, err)
expectedEtc, err := filepath.EvalSymlinks(etcDir)
require.NoError(t, err)
assert.Equal(t, expectedEtc, absResolvedEtc)
})
t.Run("etc directory with empty goFile", func(t *testing.T) {
// When goFile is empty, should default to "./etc"
goFile := ""
resolvedEtcDir := filepath.Join(filepath.Dir(goFile), "etc")
// Should resolve to just "etc"
assert.Equal(t, "etc", resolvedEtcDir)
})
t.Run("etc directory with absolute path", func(t *testing.T) {
// When goFile is absolute path
absGoFile := filepath.Join(tempDir, "service", "api", "api.go")
resolvedEtcDir := filepath.Join(filepath.Dir(absGoFile), "etc")
// Should resolve correctly
_, err := os.Stat(resolvedEtcDir)
assert.NoError(t, err)
})
}
func TestGenerateDockerfile_GoMainFromPath(t *testing.T) {
tests := []struct {
name string
goFile string
projPath string
expectedPath string
}{
{
name: "relative path with subdirectory",
goFile: "service/api/api.go",
projPath: "service/api",
expectedPath: "service/api/api.go",
},
{
name: "simple filename",
goFile: "main.go",
projPath: ".",
expectedPath: "main.go",
},
{
name: "nested service path",
goFile: "internal/service/user/user.go",
projPath: "internal/service/user",
expectedPath: "internal/service/user/user.go",
},
{
name: "deep nested path",
goFile: "cmd/api/internal/handler/handler.go",
projPath: "cmd/api/internal/handler",
expectedPath: "cmd/api/internal/handler/handler.go",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate the fix: using filepath.Base instead of full path
goMainFrom := filepath.Join(tt.projPath, filepath.Base(tt.goFile))
assert.Equal(t, tt.expectedPath, goMainFrom,
"GoMainFrom should not duplicate path segments")
// Verify the old buggy behavior would have been wrong
if tt.goFile != filepath.Base(tt.goFile) {
buggyPath := filepath.Join(tt.projPath, tt.goFile)
assert.NotEqual(t, tt.expectedPath, buggyPath,
"Old implementation would have created incorrect path")
}
})
}
}
func TestGenerateDockerfile_PathJoinBehavior(t *testing.T) {
t.Run("demonstrates the bug and fix", func(t *testing.T) {
projPath := "service/api"
goFile := "service/api/api.go"
// OLD (buggy) behavior: path duplication
buggyPath := filepath.Join(projPath, goFile)
assert.Equal(t, "service/api/service/api/api.go", buggyPath,
"Bug: path segments are duplicated")
// NEW (fixed) behavior: correct path
fixedPath := filepath.Join(projPath, filepath.Base(goFile))
assert.Equal(t, "service/api/api.go", fixedPath,
"Fix: using filepath.Base prevents duplication")
})
}
func TestFindConfig(t *testing.T) {
tempDir := t.TempDir()
etcDir := filepath.Join(tempDir, "etc")
require.NoError(t, os.MkdirAll(etcDir, 0755))
t.Run("finds config matching go file name", func(t *testing.T) {
// Create config files
require.NoError(t, os.WriteFile(
filepath.Join(etcDir, "api.yaml"), []byte("test"), 0644))
require.NoError(t, os.WriteFile(
filepath.Join(etcDir, "other.yaml"), []byte("test"), 0644))
cfg, err := findConfig("api.go", etcDir)
assert.NoError(t, err)
assert.Equal(t, "api.yaml", cfg)
})
t.Run("returns first config when no match", func(t *testing.T) {
etcDir2 := filepath.Join(tempDir, "etc2")
require.NoError(t, os.MkdirAll(etcDir2, 0755))
require.NoError(t, os.WriteFile(
filepath.Join(etcDir2, "config.yaml"), []byte("test"), 0644))
cfg, err := findConfig("main.go", etcDir2)
assert.NoError(t, err)
assert.Equal(t, "config.yaml", cfg)
})
t.Run("returns error when no yaml files", func(t *testing.T) {
emptyDir := filepath.Join(tempDir, "empty")
require.NoError(t, os.MkdirAll(emptyDir, 0755))
_, err := findConfig("api.go", emptyDir)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no yaml file")
})
t.Run("handles path in go file name", func(t *testing.T) {
// Test with service/api/api.go - should extract just "api"
cfg, err := findConfig("service/api/api.go", etcDir)
assert.NoError(t, err)
assert.Equal(t, "api.yaml", cfg)
})
}
func TestGetFilePath(t *testing.T) {
// Create a temporary directory with go.mod
tempDir := t.TempDir()
require.NoError(t, os.WriteFile(
filepath.Join(tempDir, "go.mod"),
[]byte("module testproject\n\ngo 1.21\n"),
0644,
))
// Create subdirectories
serviceDir := filepath.Join(tempDir, "service", "api")
require.NoError(t, os.MkdirAll(serviceDir, 0755))
originalWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
require.NoError(t, os.Chdir(originalWd))
}()
t.Run("returns relative path from go.mod", func(t *testing.T) {
require.NoError(t, os.Chdir(tempDir))
path, err := getFilePath("service/api")
assert.NoError(t, err)
assert.Equal(t, "service/api", path)
})
t.Run("handles current directory", func(t *testing.T) {
require.NoError(t, os.Chdir(tempDir))
path, err := getFilePath(".")
assert.NoError(t, err)
// Current directory returns empty string when at go.mod root
assert.True(t, path == "." || path == "")
})
}
// Integration test to verify the complete fix
func TestDockerCommandIntegration(t *testing.T) {
// Create a complete project structure
tempDir := t.TempDir()
// Setup: project/service/api/
serviceDir := filepath.Join(tempDir, "service", "api")
etcDir := filepath.Join(serviceDir, "etc")
require.NoError(t, os.MkdirAll(etcDir, 0755))
// Create files
goFile := filepath.Join(serviceDir, "api.go")
require.NoError(t, os.WriteFile(goFile, []byte("package main\n\nfunc main() {}"), 0644))
configFile := filepath.Join(etcDir, "api.yaml")
require.NoError(t, os.WriteFile(configFile, []byte("Name: test-api\n"), 0644))
goModFile := filepath.Join(tempDir, "go.mod")
require.NoError(t, os.WriteFile(goModFile, []byte("module testproject\n\ngo 1.21\n"), 0644))
goSumFile := filepath.Join(tempDir, "go.sum")
require.NoError(t, os.WriteFile(goSumFile, []byte(""), 0644))
originalWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
require.NoError(t, os.Chdir(originalWd))
}()
t.Run("etc directory detected from different working directory", func(t *testing.T) {
// Change to project root (not service/api)
require.NoError(t, os.Chdir(tempDir))
// Relative path to Go file
relGoFile := filepath.Join("service", "api", "api.go")
// Apply the fix: resolve etc directory relative to go file
resolvedEtcDir := filepath.Join(filepath.Dir(relGoFile), "etc")
// Verify etc directory is found
stat, err := os.Stat(resolvedEtcDir)
assert.NoError(t, err)
assert.True(t, stat.IsDir())
// Verify config can be found
cfg, err := findConfig(relGoFile, resolvedEtcDir)
assert.NoError(t, err)
assert.Equal(t, "api.yaml", cfg)
})
t.Run("GoMainFrom path is correct", func(t *testing.T) {
require.NoError(t, os.Chdir(tempDir))
goFileRel := filepath.Join("service", "api", "api.go")
// Simulate getFilePath return value
projPath := "service/api"
// Apply the fix: use filepath.Base
goMainFrom := filepath.Join(projPath, filepath.Base(goFileRel))
assert.Equal(t, "service/api/api.go", goMainFrom)
// Verify no path duplication
assert.NotContains(t, goMainFrom, "service/api/service/api")
})
}
// Test that specifically validates the bug described in PR #4343
func TestPR4343_BugFixes(t *testing.T) {
t.Run("Bug 1: etc directory check uses correct base path", func(t *testing.T) {
// Setup: Create a project structure where etc is NOT in CWD but IS relative to Go file
tempDir := t.TempDir()
serviceDir := filepath.Join(tempDir, "service", "api")
etcDir := filepath.Join(serviceDir, "etc")
require.NoError(t, os.MkdirAll(etcDir, 0755))
// Create a config file
require.NoError(t, os.WriteFile(
filepath.Join(etcDir, "config.yaml"),
[]byte("Name: test\n"),
0644,
))
originalWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
require.NoError(t, os.Chdir(originalWd))
}()
// Change to project root (CWD = tempDir)
require.NoError(t, os.Chdir(tempDir))
goFile := filepath.Join("service", "api", "api.go")
// OLD (buggy) behavior: checks for "etc" in CWD
_, errOld := os.Stat("etc")
assert.Error(t, errOld, "Bug: etc not found in CWD")
// NEW (fixed) behavior: checks for "etc" relative to go file
etcDirResolved := filepath.Join(filepath.Dir(goFile), "etc")
stat, errNew := os.Stat(etcDirResolved)
assert.NoError(t, errNew, "Fix: etc found relative to go file")
assert.True(t, stat.IsDir())
// Verify config is accessible
cfg, err := findConfig(goFile, etcDirResolved)
assert.NoError(t, err)
assert.Equal(t, "config.yaml", cfg)
})
t.Run("Bug 2: GoMainFrom path not duplicated", func(t *testing.T) {
// Test case from PR description
projPath := "service/api"
goFile := "service/api/api.go"
// OLD (buggy) behavior: duplicates path
buggyPath := filepath.Join(projPath, goFile)
assert.Equal(t, "service/api/service/api/api.go", buggyPath,
"Bug: path duplication occurs with old implementation")
// NEW (fixed) behavior: correct path using filepath.Base
fixedPath := filepath.Join(projPath, filepath.Base(goFile))
assert.Equal(t, "service/api/api.go", fixedPath,
"Fix: using filepath.Base() prevents path duplication")
// Verify the fix works for various scenarios
testCases := []struct {
projPath string
goFile string
expected string
}{
{"service/api", "service/api/api.go", "service/api/api.go"},
{"cmd/server", "cmd/server/main.go", "cmd/server/main.go"},
{"internal/handler", "internal/handler/handler.go", "internal/handler/handler.go"},
{".", "main.go", "main.go"},
}
for _, tc := range testCases {
result := filepath.Join(tc.projPath, filepath.Base(tc.goFile))
assert.Equal(t, tc.expected, result,
"Fix should work for projPath=%s, goFile=%s", tc.projPath, tc.goFile)
}
})
}

View File

@@ -16,7 +16,7 @@ require (
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
github.com/zeromicro/antlr v0.0.1
github.com/zeromicro/ddl-parser v1.0.5
github.com/zeromicro/go-zero v1.9.1
github.com/zeromicro/go-zero v1.9.2
golang.org/x/text v0.22.0
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.36.5

View File

@@ -185,8 +185,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
github.com/zeromicro/go-zero v1.9.1 h1:GZCl4jun/ZgZHnSvX3SSNDHf+tEGmEQ8x2Z23xjHa9g=
github.com/zeromicro/go-zero v1.9.1/go.mod h1:bHOl7Xr7EV/iHZWEqsUNJwFc/9WgAMrPpPagYvOaMtY=
github.com/zeromicro/go-zero v1.9.2 h1:ZXOXBIcazZ1pWAMiHyVnDQ3Sxwy7DYPzjE89Qtj9vqM=
github.com/zeromicro/go-zero v1.9.2/go.mod h1:k8YBMEFZKjTd4q/qO5RCW+zDgUlNyAs5vue3P4/Kmn0=
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=

View File

@@ -6,7 +6,7 @@ import (
)
// BuildVersion is the version of goctl.
const BuildVersion = "1.9.1"
const BuildVersion = "1.9.2"
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-beta": 2, "beta": 3, "released": 4, "": 5}

View File

@@ -5,7 +5,7 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/util/format"
)
// BeforeCommands run before comamnd run to show some migration notes
// BeforeCommands run before command run to show some migration notes
func BeforeCommands(dir, style string) error {
if err := migrateBefore1_3_4(dir, style); err != nil {
return err

View File

@@ -99,12 +99,12 @@ func (conn *MockConn) RawDB() (*sql.DB, error) {
return conn.db, nil
}
// Transact is the implemention of sqlx.SqlConn, nothing to do
// Transact is the implementation of sqlx.SqlConn, nothing to do
func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
return nil
}
// TransactCtx is the implemention of sqlx.SqlConn, nothing to do
// TransactCtx is the implementation of sqlx.SqlConn, nothing to do
func (conn *MockConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
return nil
}

View File

@@ -43,7 +43,7 @@ func Install(cacheDir string) (string, error) {
case vars.OsLinux:
downloadUrl = url[fmt.Sprintf("%s_%d", vars.OsLinux, bit)]
default:
return "", fmt.Errorf("unsupport OS: %q", goos)
return "", fmt.Errorf("unsupported OS: %q", goos)
}
err := downloader.Download(downloadUrl, tempFile)

View File

@@ -2,7 +2,6 @@ package util
import (
"slices"
"strconv"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/util/console"
@@ -130,14 +129,3 @@ func FieldsAndTrimSpace(s string, f func(r rune) bool) []string {
}
return resp
}
//Deprecated: This function implementation is incomplete and does not properly handle exceptional input cases.
//We strongly recommend using the standard library's strconv.Unquote function instead,
//which provides robust error handling and comprehensive support for various input formats.
func Unquote(s string) string {
ns, err := strconv.Unquote(s)
if err != nil {
return ""
}
return ns
}

View File

@@ -76,40 +76,40 @@ func TestEscapeGoKeyword(t *testing.T) {
func TestFieldsAndTrimSpace(t *testing.T) {
testCases := []struct {
name string
input string
name string
input string
delimiter func(r rune) bool
expected []string
expected []string
}{
{
name: "Comma-separated values",
input: "a, b, c",
name: "Comma-separated values",
input: "a, b, c",
delimiter: func(r rune) bool { return r == ',' },
expected: []string{"a", " b", " c"},
expected: []string{"a", " b", " c"},
},
{
name: "Space-separated values",
input: "a b c",
name: "Space-separated values",
input: "a b c",
delimiter: unicode.IsSpace,
expected: []string{"a", "b", "c"},
expected: []string{"a", "b", "c"},
},
{
name: "Mixed whitespace",
input: "a\tb\nc",
name: "Mixed whitespace",
input: "a\tb\nc",
delimiter: unicode.IsSpace,
expected: []string{"a", "b", "c"},
expected: []string{"a", "b", "c"},
},
{
name: "Empty input",
input: "",
name: "Empty input",
input: "",
delimiter: unicode.IsSpace,
expected: []string(nil),
expected: []string(nil),
},
{
name: "Trailing and leading spaces",
input: " a , b , c ",
name: "Trailing and leading spaces",
input: " a , b , c ",
delimiter: func(r rune) bool { return r == ',' },
expected: []string{" a ", " b ", " c "},
expected: []string{" a ", " b ", " c "},
},
}
@@ -120,20 +120,3 @@ func TestFieldsAndTrimSpace(t *testing.T) {
})
}
}
func TestUnquote(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{input: `"hello"`, expected: `hello`},
{input: "`world`", expected: `world`},
{input: `"foo'bar"`, expected: `foo'bar`},
{input: "", expected: ""},
}
for _, tc := range testCases {
result := Unquote(tc.input)
assert.Equal(t, tc.expected, result)
}
}

View File

@@ -1,12 +1,16 @@
package zrpc
import (
"context"
"fmt"
"time"
"github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/zrpc/internal"
"github.com/zeromicro/go-zero/zrpc/internal/auth"
"github.com/zeromicro/go-zero/zrpc/internal/balancer/consistenthash"
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -67,6 +71,9 @@ func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) {
})))
}
svcCfg := makeLBServiceConfig(c.BalancerName)
opts = append(opts, WithDialOption(grpc.WithDefaultServiceConfig(svcCfg)))
opts = append(opts, options...)
target, err := c.BuildTarget()
@@ -111,7 +118,20 @@ func SetClientSlowThreshold(threshold time.Duration) {
clientinterceptors.SetSlowThreshold(threshold)
}
// SetHashKey sets the hash key into context.
func SetHashKey(ctx context.Context, key string) context.Context {
return consistenthash.SetHashKey(ctx, key)
}
// WithCallTimeout return a call option with given timeout to make a method call.
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
return clientinterceptors.WithCallTimeout(timeout)
}
func makeLBServiceConfig(balancerName string) string {
if len(balancerName) == 0 {
balancerName = p2c.Name
}
return fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, balancerName)
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/internal/mock"
"github.com/zeromicro/go-zero/zrpc/internal/balancer/consistenthash"
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
@@ -245,3 +247,42 @@ func TestNewClientWithTarget(t *testing.T) {
assert.NotNil(t, err)
}
func TestMakeLBServiceConfig(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty name uses default p2c",
input: "",
expected: fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name),
},
{
name: "custom balancer name",
input: "consistent_hash",
expected: `{"loadBalancingPolicy":"consistent_hash"}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := makeLBServiceConfig(tt.input)
if got != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, got)
}
})
}
}
func TestSetHashKey(t *testing.T) {
ctx := context.Background()
key := "abc123"
ctx = SetHashKey(ctx, key)
got := consistenthash.GetHashKey(ctx)
assert.Equal(t, key, got)
assert.Empty(t, consistenthash.GetHashKey(context.Background()))
}

View File

@@ -31,6 +31,7 @@ type (
Timeout int64 `json:",default=2000"`
KeepaliveTime time.Duration `json:",optional"`
Middlewares ClientMiddlewaresConf
BalancerName string `json:",default=p2c_ewma"`
}
// A RpcServerConf is a rpc server config.

View File

@@ -4,9 +4,11 @@ import (
"testing"
"github.com/stretchr/testify/assert"
zconf "github.com/zeromicro/go-zero/core/conf"
"github.com/zeromicro/go-zero/core/discov"
"github.com/zeromicro/go-zero/core/service"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
)
func TestRpcClientConf(t *testing.T) {
@@ -39,6 +41,13 @@ func TestRpcClientConf(t *testing.T) {
_, err := conf.BuildTarget()
assert.Error(t, err)
})
t.Run("default balancer name", func(t *testing.T) {
var conf RpcClientConf
err := zconf.FillDefault(&conf)
assert.NoError(t, err)
assert.Equal(t, p2c.Name, conf.BalancerName)
})
}
func TestRpcServerConf(t *testing.T) {

View File

@@ -0,0 +1,97 @@
package consistenthash
import (
"context"
"github.com/zeromicro/go-zero/core/hash"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
Name = "consistent_hash"
defaultReplicaCount = 100
)
var emptyPickResult balancer.PickResult
func init() {
balancer.Register(newBuilder())
}
type (
// hashKey is the key type for consistent hash in context.
hashKey struct{}
// pickerBuilder is a builder for picker.
pickerBuilder struct{}
// picker is a picker that uses consistent hash to pick a sub connection.
picker struct {
hashRing *hash.ConsistentHash
conns map[string]balancer.SubConn
}
)
func (b *pickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker {
readySCs := info.ReadySCs
if len(readySCs) == 0 {
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
conns := make(map[string]balancer.SubConn, len(readySCs))
hashRing := hash.NewCustomConsistentHash(defaultReplicaCount, hash.Hash)
for conn, connInfo := range readySCs {
addr := connInfo.Address.Addr
conns[addr] = conn
hashRing.Add(addr)
}
return &picker{
hashRing: hashRing,
conns: conns,
}
}
func newBuilder() balancer.Builder {
return base.NewBalancerBuilder(Name, &pickerBuilder{}, base.Config{HealthCheck: true})
}
func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
hashKey := GetHashKey(info.Ctx)
if len(hashKey) == 0 {
return emptyPickResult, status.Error(codes.InvalidArgument,
"[consistent_hash] missing hash key in context")
}
if addrAny, ok := p.hashRing.Get(hashKey); ok {
addr, ok := addrAny.(string)
if !ok {
return emptyPickResult, status.Error(codes.Internal,
"[consistent_hash] invalid addr type in consistent hash")
}
subConn, ok := p.conns[addr]
if !ok {
return emptyPickResult, status.Errorf(codes.Internal,
"[consistent_hash] no subConn for addr: %s", addr)
}
return balancer.PickResult{SubConn: subConn}, nil
}
return emptyPickResult, status.Errorf(codes.Unavailable,
"[consistent_hash] no matching conn for hashKey: %s", hashKey)
}
// SetHashKey sets the hash key into context.
func SetHashKey(ctx context.Context, key string) context.Context {
return context.WithValue(ctx, hashKey{}, key)
}
// GetHashKey gets the hash key from context.
func GetHashKey(ctx context.Context) string {
v, _ := ctx.Value(hashKey{}).(string)
return v
}

View File

@@ -0,0 +1,175 @@
package consistenthash
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/hash"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/resolver"
)
type fakeSubConn struct{ id int }
func (f *fakeSubConn) Connect() {}
func (f *fakeSubConn) UpdateAddresses(_ []resolver.Address) {}
func (f *fakeSubConn) Shutdown() {}
func (f *fakeSubConn) GetOrBuildProducer(b balancer.ProducerBuilder) (balancer.Producer, func()) {
return nil, func() {}
}
func TestPickerBuilder_EmptyReadySCs(t *testing.T) {
b := &pickerBuilder{}
p := b.Build(base.PickerBuildInfo{ReadySCs: map[balancer.SubConn]base.SubConnInfo{}})
_, err := p.Pick(balancer.PickInfo{})
assert.Equal(t, balancer.ErrNoSubConnAvailable, err)
}
func TestPickerBuilder_BuildAndRing(t *testing.T) {
subConn1 := &fakeSubConn{id: 1}
subConn2 := &fakeSubConn{id: 2}
addr1 := "127.0.0.1:8080"
addr2 := "127.0.0.1:8081"
b := &pickerBuilder{}
info := base.PickerBuildInfo{
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
subConn1: {Address: resolver.Address{Addr: addr1}},
subConn2: {Address: resolver.Address{Addr: addr2}},
},
}
p := b.Build(info).(*picker)
assert.NotNil(t, p.hashRing)
assert.Len(t, p.conns, 2)
}
func TestPicker_HashConsistency(t *testing.T) {
subConn1 := &fakeSubConn{id: 1}
subConn2 := &fakeSubConn{id: 2}
pb := &pickerBuilder{}
info := base.PickerBuildInfo{
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
subConn1: {Address: resolver.Address{Addr: "127.0.0.1:8080"}},
subConn2: {Address: resolver.Address{Addr: "127.0.0.1:8081"}},
},
}
p := pb.Build(info).(*picker)
ctx := SetHashKey(context.Background(), "user_123")
res1, err := p.Pick(balancer.PickInfo{Ctx: ctx})
assert.NoError(t, err)
assert.NotNil(t, res1.SubConn)
// Multiple requests with the same key remain consistent
for i := 0; i < 5; i++ {
resN, err := p.Pick(balancer.PickInfo{Ctx: ctx})
assert.NoError(t, err)
assert.Equal(t, res1.SubConn, resN.SubConn)
}
}
func TestPicker_MissingKey(t *testing.T) {
subConn := &fakeSubConn{id: 1}
pb := &pickerBuilder{}
info := base.PickerBuildInfo{
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
subConn: {Address: resolver.Address{Addr: "127.0.0.1:8080"}},
},
}
p := pb.Build(info).(*picker)
// No hash key in context
_, err := p.Pick(balancer.PickInfo{Ctx: context.Background()})
assert.Error(t, err)
assert.Contains(t, err.Error(), "[consistent_hash] missing hash key in context")
}
func TestPicker_NoMatchingConn(t *testing.T) {
emptyRing := newCustomRingForTest()
p := &picker{
hashRing: emptyRing,
conns: map[string]balancer.SubConn{},
}
_, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "someone")})
assert.Error(t, err)
assert.Contains(t, err.Error(), "[consistent_hash] no matching conn for hashKey: someone")
}
func TestPicker_InvalidAddrType(t *testing.T) {
ring := newCustomRingForTest()
ring.Add(12345)
subConn := &fakeSubConn{id: 1}
p := &picker{
hashRing: ring,
conns: map[string]balancer.SubConn{
"12345": subConn,
},
}
_, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "anykey")})
assert.Error(t, err)
assert.Contains(t, err.Error(), "[consistent_hash] invalid addr type in consistent hash")
}
func TestPicker_NoSubConnForAddr(t *testing.T) {
ring := newCustomRingForTest()
ring.Add("ghost:9999")
exist := &fakeSubConn{id: 1}
p := &picker{
hashRing: ring,
conns: map[string]balancer.SubConn{
"real:8080": exist,
},
}
_, err := p.Pick(balancer.PickInfo{Ctx: SetHashKey(context.Background(), "anykey")})
assert.Error(t, err)
assert.Contains(t, err.Error(), "[consistent_hash] no subConn for addr: ghost:9999")
}
func TestSetAndGetHashKey(t *testing.T) {
ctx := context.Background()
key := "abc123"
ctx = SetHashKey(ctx, key)
got := GetHashKey(ctx)
assert.Equal(t, key, got)
assert.Empty(t, GetHashKey(context.Background()))
}
func BenchmarkPicker_HashConsistency(b *testing.B) {
subConn1 := &fakeSubConn{id: 1}
subConn2 := &fakeSubConn{id: 2}
pb := &pickerBuilder{}
info := base.PickerBuildInfo{
ReadySCs: map[balancer.SubConn]base.SubConnInfo{
subConn1: {Address: resolver.Address{Addr: "127.0.0.1:8080"}},
subConn2: {Address: resolver.Address{Addr: "127.0.0.1:8081"}},
},
}
p := pb.Build(info).(*picker)
ctx := SetHashKey(context.Background(), "hot_user_123")
b.ResetTimer()
for i := 0; i < b.N; i++ {
res, err := p.Pick(balancer.PickInfo{Ctx: ctx})
if err != nil || res.SubConn == nil {
b.Fatalf("unexpected result: res=%v err=%v", res.SubConn, err)
}
}
}
func newCustomRingForTest() *hash.ConsistentHash {
return hash.NewCustomConsistentHash(defaultReplicaCount, hash.Hash)
}

View File

@@ -7,7 +7,6 @@ import (
"strings"
"time"
"github.com/zeromicro/go-zero/zrpc/internal/balancer/p2c"
"github.com/zeromicro/go-zero/zrpc/internal/clientinterceptors"
"github.com/zeromicro/go-zero/zrpc/resolver"
"google.golang.org/grpc"
@@ -53,9 +52,6 @@ func NewClient(target string, middlewares ClientMiddlewaresConf, opts ...ClientO
middlewares: middlewares,
}
svcCfg := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name)
balancerOpt := WithDialOption(grpc.WithDefaultServiceConfig(svcCfg))
opts = append([]ClientOption{balancerOpt}, opts...)
if err := cli.dial(target, opts...); err != nil {
return nil, err
}