From a71e56de52b6dc48df97bc3b59250e28b4d0569a Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 12 Jul 2025 06:58:08 +0800 Subject: [PATCH] fix: context key error in sql read write mode (#5000) --- core/logx/fields.go | 9 ++++----- core/logx/fields_test.go | 10 +++++----- core/logx/richlogger.go | 2 +- core/stores/sqlx/rwstrategy.go | 10 +++++----- core/stores/sqlx/rwstrategy_test.go | 26 +++++++++++++------------- 5 files changed, 28 insertions(+), 29 deletions(-) diff --git a/core/logx/fields.go b/core/logx/fields.go index 80d5d2619..20b458818 100644 --- a/core/logx/fields.go +++ b/core/logx/fields.go @@ -7,12 +7,11 @@ import ( ) var ( - fieldsContextKey contextKey globalFields atomic.Value globalFieldsLock sync.Mutex ) -type contextKey struct{} +type fieldsKey struct{} // AddGlobalFields adds global fields. func AddGlobalFields(fields ...LogField) { @@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) { // ContextWithFields returns a new context with the given fields. func ContextWithFields(ctx context.Context, fields ...LogField) context.Context { - if val := ctx.Value(fieldsContextKey); val != nil { + if val := ctx.Value(fieldsKey{}); val != nil { if arr, ok := val.([]LogField); ok { allFields := make([]LogField, 0, len(arr)+len(fields)) allFields = append(allFields, arr...) allFields = append(allFields, fields...) - return context.WithValue(ctx, fieldsContextKey, allFields) + return context.WithValue(ctx, fieldsKey{}, allFields) } } - return context.WithValue(ctx, fieldsContextKey, fields) + return context.WithValue(ctx, fieldsKey{}, fields) } // WithFields returns a new logger with the given fields. diff --git a/core/logx/fields_test.go b/core/logx/fields_test.go index c0727eb5a..f06cb26cf 100644 --- a/core/logx/fields_test.go +++ b/core/logx/fields_test.go @@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) { func TestContextWithFields(t *testing.T) { ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2)) - vals := ctx.Value(fieldsContextKey) + vals := ctx.Value(fieldsKey{}) assert.NotNil(t, vals) fields, ok := vals.([]LogField) assert.True(t, ok) @@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) { func TestWithFields(t *testing.T) { ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2)) - vals := ctx.Value(fieldsContextKey) + vals := ctx.Value(fieldsKey{}) assert.NotNil(t, vals) fields, ok := vals.([]LogField) assert.True(t, ok) @@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) { ctx := context.WithValue(context.Background(), dummyKey, "dummy") ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2)) ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4)) - vals := ctx.Value(fieldsContextKey) + vals := ctx.Value(fieldsKey{}) assert.NotNil(t, vals) fields, ok := vals.([]LogField) assert.True(t, ok) @@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) { ctxa := ContextWithFields(ctx, af) ctxb := ContextWithFields(ctx, bf) - assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count]) - assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count]) + assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count]) + assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count]) } func BenchmarkAtomicValue(b *testing.B) { diff --git a/core/logx/richlogger.go b/core/logx/richlogger.go index dc80ead22..95ae70c14 100644 --- a/core/logx/richlogger.go +++ b/core/logx/richlogger.go @@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField { fields = append(fields, Field(spanKey, spanID)) } - val := l.ctx.Value(fieldsContextKey) + val := l.ctx.Value(fieldsKey{}) if val != nil { if arr, ok := val.([]LogField); ok { fields = append(fields, arr...) diff --git a/core/stores/sqlx/rwstrategy.go b/core/stores/sqlx/rwstrategy.go index c4020a446..31fe3a035 100644 --- a/core/stores/sqlx/rwstrategy.go +++ b/core/stores/sqlx/rwstrategy.go @@ -27,21 +27,21 @@ const ( notSpecifiedMode readWriteMode = "" ) -var readWriteModeKey struct{} +type readWriteModeKey struct{} // WithReadPrimary sets the context to read-primary mode. func WithReadPrimary(ctx context.Context) context.Context { - return context.WithValue(ctx, readWriteModeKey, readPrimaryMode) + return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode) } // WithReadReplica sets the context to read-replica mode. func WithReadReplica(ctx context.Context) context.Context { - return context.WithValue(ctx, readWriteModeKey, readReplicaMode) + return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode) } // WithWrite sets the context to write mode, indicating that the operation is a write operation. func WithWrite(ctx context.Context) context.Context { - return context.WithValue(ctx, readWriteModeKey, writeMode) + return context.WithValue(ctx, readWriteModeKey{}, writeMode) } type readWriteMode string @@ -51,7 +51,7 @@ func (m readWriteMode) isValid() bool { } func getReadWriteMode(ctx context.Context) readWriteMode { - if mode := ctx.Value(readWriteModeKey); mode != nil { + if mode := ctx.Value(readWriteModeKey{}); mode != nil { if v, ok := mode.(readWriteMode); ok && v.isValid() { return v } diff --git a/core/stores/sqlx/rwstrategy_test.go b/core/stores/sqlx/rwstrategy_test.go index d90c1ddc2..b368c7cb4 100644 --- a/core/stores/sqlx/rwstrategy_test.go +++ b/core/stores/sqlx/rwstrategy_test.go @@ -57,11 +57,11 @@ func TestWithReadMode(t *testing.T) { ctx := context.Background() readPrimaryCtx := WithReadPrimary(ctx) - val := readPrimaryCtx.Value(readWriteModeKey) + val := readPrimaryCtx.Value(readWriteModeKey{}) assert.Equal(t, readPrimaryMode, val) readReplicaCtx := WithReadReplica(ctx) - val = readReplicaCtx.Value(readWriteModeKey) + val = readReplicaCtx.Value(readWriteModeKey{}) assert.Equal(t, readReplicaMode, val) } @@ -69,33 +69,33 @@ func TestWithWriteMode(t *testing.T) { ctx := context.Background() writeCtx := WithWrite(ctx) - val := writeCtx.Value(readWriteModeKey) + val := writeCtx.Value(readWriteModeKey{}) assert.Equal(t, writeMode, val) } func TestGetReadWriteMode(t *testing.T) { t.Run("valid read-primary mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode) assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx)) }) t.Run("valid read-replica mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode) assert.Equal(t, readReplicaMode, getReadWriteMode(ctx)) }) t.Run("valid write mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode) assert.Equal(t, writeMode, getReadWriteMode(ctx)) }) t.Run("invalid mode value (wrong type)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, "not-a-mode") + ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode") assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) }) t.Run("invalid mode value (wrong value)", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("delete")) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete")) assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) }) @@ -107,22 +107,22 @@ func TestGetReadWriteMode(t *testing.T) { func TestUsePrimary(t *testing.T) { t.Run("context with read-replica mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode) assert.False(t, usePrimary(ctx)) }) t.Run("context with read-primary mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode) assert.True(t, usePrimary(ctx)) }) t.Run("context with write mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode) assert.True(t, usePrimary(ctx)) }) t.Run("context with invalid mode", func(t *testing.T) { - ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid")) + ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid")) assert.True(t, usePrimary(ctx)) }) @@ -137,6 +137,6 @@ func TestWithModeTwice(t *testing.T) { ctx = WithReadPrimary(ctx) writeCtx := WithWrite(ctx) - val := writeCtx.Value(readWriteModeKey) + val := writeCtx.Value(readWriteModeKey{}) assert.Equal(t, writeMode, val) }