fix: context key error in sql read write mode (#5000)

This commit is contained in:
Kevin Wan
2025-07-12 06:58:08 +08:00
committed by GitHub
parent bae8d4f4c8
commit a71e56de52
5 changed files with 28 additions and 29 deletions

View File

@@ -7,12 +7,11 @@ import (
) )
var ( var (
fieldsContextKey contextKey
globalFields atomic.Value globalFields atomic.Value
globalFieldsLock sync.Mutex globalFieldsLock sync.Mutex
) )
type contextKey struct{} type fieldsKey struct{}
// AddGlobalFields adds global fields. // AddGlobalFields adds global fields.
func AddGlobalFields(fields ...LogField) { func AddGlobalFields(fields ...LogField) {
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
// ContextWithFields returns a new context with the given fields. // ContextWithFields returns a new context with the given fields.
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context { func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
if val := ctx.Value(fieldsContextKey); val != nil { if val := ctx.Value(fieldsKey{}); val != nil {
if arr, ok := val.([]LogField); ok { if arr, ok := val.([]LogField); ok {
allFields := make([]LogField, 0, len(arr)+len(fields)) allFields := make([]LogField, 0, len(arr)+len(fields))
allFields = append(allFields, arr...) allFields = append(allFields, arr...)
allFields = append(allFields, fields...) allFields = append(allFields, fields...)
return context.WithValue(ctx, fieldsContextKey, allFields) return context.WithValue(ctx, fieldsKey{}, allFields)
} }
} }
return context.WithValue(ctx, fieldsContextKey, fields) return context.WithValue(ctx, fieldsKey{}, fields)
} }
// WithFields returns a new logger with the given fields. // WithFields returns a new logger with the given fields.

View File

@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
func TestContextWithFields(t *testing.T) { func TestContextWithFields(t *testing.T) {
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2)) ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey) vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals) assert.NotNil(t, vals)
fields, ok := vals.([]LogField) fields, ok := vals.([]LogField)
assert.True(t, ok) assert.True(t, ok)
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
func TestWithFields(t *testing.T) { func TestWithFields(t *testing.T) {
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2)) ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
vals := ctx.Value(fieldsContextKey) vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals) assert.NotNil(t, vals)
fields, ok := vals.([]LogField) fields, ok := vals.([]LogField)
assert.True(t, ok) assert.True(t, ok)
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
ctx := context.WithValue(context.Background(), dummyKey, "dummy") ctx := context.WithValue(context.Background(), dummyKey, "dummy")
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2)) ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4)) ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
vals := ctx.Value(fieldsContextKey) vals := ctx.Value(fieldsKey{})
assert.NotNil(t, vals) assert.NotNil(t, vals)
fields, ok := vals.([]LogField) fields, ok := vals.([]LogField)
assert.True(t, ok) assert.True(t, ok)
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
ctxa := ContextWithFields(ctx, af) ctxa := ContextWithFields(ctx, af)
ctxb := ContextWithFields(ctx, bf) ctxb := ContextWithFields(ctx, bf)
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count]) assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count]) assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
} }
func BenchmarkAtomicValue(b *testing.B) { func BenchmarkAtomicValue(b *testing.B) {

View File

@@ -224,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
fields = append(fields, Field(spanKey, spanID)) fields = append(fields, Field(spanKey, spanID))
} }
val := l.ctx.Value(fieldsContextKey) val := l.ctx.Value(fieldsKey{})
if val != nil { if val != nil {
if arr, ok := val.([]LogField); ok { if arr, ok := val.([]LogField); ok {
fields = append(fields, arr...) fields = append(fields, arr...)

View File

@@ -27,21 +27,21 @@ const (
notSpecifiedMode readWriteMode = "" notSpecifiedMode readWriteMode = ""
) )
var readWriteModeKey struct{} type readWriteModeKey struct{}
// WithReadPrimary sets the context to read-primary mode. // WithReadPrimary sets the context to read-primary mode.
func WithReadPrimary(ctx context.Context) context.Context { func WithReadPrimary(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, readPrimaryMode) return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
} }
// WithReadReplica sets the context to read-replica mode. // WithReadReplica sets the context to read-replica mode.
func WithReadReplica(ctx context.Context) context.Context { func WithReadReplica(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, readReplicaMode) return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
} }
// WithWrite sets the context to write mode, indicating that the operation is a write operation. // WithWrite sets the context to write mode, indicating that the operation is a write operation.
func WithWrite(ctx context.Context) context.Context { func WithWrite(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, writeMode) return context.WithValue(ctx, readWriteModeKey{}, writeMode)
} }
type readWriteMode string type readWriteMode string
@@ -51,7 +51,7 @@ func (m readWriteMode) isValid() bool {
} }
func getReadWriteMode(ctx context.Context) readWriteMode { func getReadWriteMode(ctx context.Context) readWriteMode {
if mode := ctx.Value(readWriteModeKey); mode != nil { if mode := ctx.Value(readWriteModeKey{}); mode != nil {
if v, ok := mode.(readWriteMode); ok && v.isValid() { if v, ok := mode.(readWriteMode); ok && v.isValid() {
return v return v
} }

View File

@@ -57,11 +57,11 @@ func TestWithReadMode(t *testing.T) {
ctx := context.Background() ctx := context.Background()
readPrimaryCtx := WithReadPrimary(ctx) readPrimaryCtx := WithReadPrimary(ctx)
val := readPrimaryCtx.Value(readWriteModeKey) val := readPrimaryCtx.Value(readWriteModeKey{})
assert.Equal(t, readPrimaryMode, val) assert.Equal(t, readPrimaryMode, val)
readReplicaCtx := WithReadReplica(ctx) readReplicaCtx := WithReadReplica(ctx)
val = readReplicaCtx.Value(readWriteModeKey) val = readReplicaCtx.Value(readWriteModeKey{})
assert.Equal(t, readReplicaMode, val) assert.Equal(t, readReplicaMode, val)
} }
@@ -69,33 +69,33 @@ func TestWithWriteMode(t *testing.T) {
ctx := context.Background() ctx := context.Background()
writeCtx := WithWrite(ctx) writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey) val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val) assert.Equal(t, writeMode, val)
} }
func TestGetReadWriteMode(t *testing.T) { func TestGetReadWriteMode(t *testing.T) {
t.Run("valid read-primary mode", func(t *testing.T) { t.Run("valid read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx)) assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
}) })
t.Run("valid read-replica mode", func(t *testing.T) { t.Run("valid read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx)) assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
}) })
t.Run("valid write mode", func(t *testing.T) { t.Run("valid write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.Equal(t, writeMode, getReadWriteMode(ctx)) assert.Equal(t, writeMode, getReadWriteMode(ctx))
}) })
t.Run("invalid mode value (wrong type)", func(t *testing.T) { t.Run("invalid mode value (wrong type)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, "not-a-mode") ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
}) })
t.Run("invalid mode value (wrong value)", func(t *testing.T) { t.Run("invalid mode value (wrong value)", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("delete")) ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
}) })
@@ -107,22 +107,22 @@ func TestGetReadWriteMode(t *testing.T) {
func TestUsePrimary(t *testing.T) { func TestUsePrimary(t *testing.T) {
t.Run("context with read-replica mode", func(t *testing.T) { t.Run("context with read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
assert.False(t, usePrimary(ctx)) assert.False(t, usePrimary(ctx))
}) })
t.Run("context with read-primary mode", func(t *testing.T) { t.Run("context with read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
assert.True(t, usePrimary(ctx)) assert.True(t, usePrimary(ctx))
}) })
t.Run("context with write mode", func(t *testing.T) { t.Run("context with write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
assert.True(t, usePrimary(ctx)) assert.True(t, usePrimary(ctx))
}) })
t.Run("context with invalid mode", func(t *testing.T) { t.Run("context with invalid mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid")) ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
assert.True(t, usePrimary(ctx)) assert.True(t, usePrimary(ctx))
}) })
@@ -137,6 +137,6 @@ func TestWithModeTwice(t *testing.T) {
ctx = WithReadPrimary(ctx) ctx = WithReadPrimary(ctx)
writeCtx := WithWrite(ctx) writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey) val := writeCtx.Value(readWriteModeKey{})
assert.Equal(t, writeMode, val) assert.Equal(t, writeMode, val)
} }