From bae8d4f4c82cbc2b266b8a505566c9004adbabdb Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Fri, 11 Jul 2025 01:05:55 +0800 Subject: [PATCH] chore: refactoring sql read write mode (#4990) Signed-off-by: kevin --- core/stores/sqlx/rwstrategy.go | 38 ++++++++++++++--------------- core/stores/sqlx/rwstrategy_test.go | 27 +++++++++++++------- core/stores/sqlx/sqlconn.go | 20 +++++++++++---- core/stores/sqlx/sqlconn_test.go | 22 ++++++++--------- 4 files changed, 63 insertions(+), 44 deletions(-) diff --git a/core/stores/sqlx/rwstrategy.go b/core/stores/sqlx/rwstrategy.go index 3a5813fa0..c4020a446 100644 --- a/core/stores/sqlx/rwstrategy.go +++ b/core/stores/sqlx/rwstrategy.go @@ -27,10 +27,25 @@ const ( notSpecifiedMode readWriteMode = "" ) -type readWriteMode string - var readWriteModeKey struct{} +// WithReadPrimary sets the context to read-primary mode. +func WithReadPrimary(ctx context.Context) context.Context { + return context.WithValue(ctx, readWriteModeKey, readPrimaryMode) +} + +// WithReadReplica sets the context to read-replica mode. +func WithReadReplica(ctx context.Context) context.Context { + return context.WithValue(ctx, readWriteModeKey, readReplicaMode) +} + +// WithWrite sets the context to write mode, indicating that the operation is a write operation. +func WithWrite(ctx context.Context) context.Context { + return context.WithValue(ctx, readWriteModeKey, writeMode) +} + +type readWriteMode string + func (m readWriteMode) isValid() bool { return m == readPrimaryMode || m == readReplicaMode || m == writeMode } @@ -45,21 +60,6 @@ func getReadWriteMode(ctx context.Context) readWriteMode { return notSpecifiedMode } -func useReplica(ctx context.Context) bool { - return getReadWriteMode(ctx) == readReplicaMode -} - -// WithReadPrimaryMode sets the context to read-primary mode. -func WithReadPrimaryMode(ctx context.Context) context.Context { - return context.WithValue(ctx, readWriteModeKey, readPrimaryMode) -} - -// WithReadReplicaMode sets the context to read-replica mode. -func WithReadReplicaMode(ctx context.Context) context.Context { - return context.WithValue(ctx, readWriteModeKey, readReplicaMode) -} - -// WithWriteMode sets the context to write mode, indicating that the operation is a write operation. -func WithWriteMode(ctx context.Context) context.Context { - return context.WithValue(ctx, readWriteModeKey, writeMode) +func usePrimary(ctx context.Context) bool { + return getReadWriteMode(ctx) != readReplicaMode } diff --git a/core/stores/sqlx/rwstrategy_test.go b/core/stores/sqlx/rwstrategy_test.go index acc2369a3..d90c1ddc2 100644 --- a/core/stores/sqlx/rwstrategy_test.go +++ b/core/stores/sqlx/rwstrategy_test.go @@ -55,19 +55,19 @@ func TestIsValid(t *testing.T) { func TestWithReadMode(t *testing.T) { ctx := context.Background() - readPrimaryCtx := WithReadPrimaryMode(ctx) + readPrimaryCtx := WithReadPrimary(ctx) val := readPrimaryCtx.Value(readWriteModeKey) assert.Equal(t, readPrimaryMode, val) - readReplicaCtx := WithReadReplicaMode(ctx) + readReplicaCtx := WithReadReplica(ctx) val = readReplicaCtx.Value(readWriteModeKey) assert.Equal(t, readReplicaMode, val) } func TestWithWriteMode(t *testing.T) { ctx := context.Background() - writeCtx := WithWriteMode(ctx) + writeCtx := WithWrite(ctx) val := writeCtx.Value(readWriteModeKey) assert.Equal(t, writeMode, val) @@ -105,29 +105,38 @@ func TestGetReadWriteMode(t *testing.T) { }) } -func TestUuseReplica(t *testing.T) { +func TestUsePrimary(t *testing.T) { t.Run("context with read-replica mode", func(t *testing.T) { ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) - assert.True(t, useReplica(ctx)) + assert.False(t, usePrimary(ctx)) }) t.Run("context with read-primary mode", func(t *testing.T) { ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) - assert.False(t, useReplica(ctx)) + assert.True(t, usePrimary(ctx)) }) t.Run("context with write mode", func(t *testing.T) { ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) - assert.False(t, useReplica(ctx)) + assert.True(t, usePrimary(ctx)) }) t.Run("context with invalid mode", func(t *testing.T) { ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid")) - assert.False(t, useReplica(ctx)) + assert.True(t, usePrimary(ctx)) }) t.Run("context with no mode set", func(t *testing.T) { ctx := context.Background() - assert.False(t, useReplica(ctx)) + assert.True(t, usePrimary(ctx)) }) } + +func TestWithModeTwice(t *testing.T) { + ctx := context.Background() + ctx = WithReadPrimary(ctx) + writeCtx := WithWrite(ctx) + + val := writeCtx.Value(readWriteModeKey) + assert.Equal(t, writeMode, val) +} diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index cd13222d1..81fe9c9bf 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -68,12 +68,22 @@ type ( } ) -// NewConn returns a SqlConn with the given SqlConf. -func NewConn(c SqlConf, opts ...SqlOption) SqlConn { - if err := c.Validate(); err != nil { +// MustNewConn returns a SqlConn with the given SqlConf. +func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn { + conn, err := NewConn(c, opts...) + if err != nil { logx.Must(err) } + return conn +} + +// NewConn returns a SqlConn with the given SqlConf. +func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) { + if err := c.Validate(); err != nil { + return nil, err + } + conn := &commonSqlConn{ onError: func(ctx context.Context, err error) { logInstanceError(ctx, c.DataSource, err) @@ -86,7 +96,7 @@ func NewConn(c SqlConf, opts ...SqlOption) SqlConn { } conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas) - return conn + return conn, nil } // NewSqlConn returns a SqlConn with given driver name and datasource. @@ -340,7 +350,7 @@ func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, r return func(ctx context.Context) (*sql.DB, error) { replicaCount := len(replicas) - if replicaCount == 0 || !useReplica(ctx) { + if replicaCount == 0 || usePrimary(ctx) { return getSqlConn(driverName, datasource) } diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index 1f252c437..95e383ce1 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -149,7 +149,7 @@ func TestConfigSqlConn(t *testing.T) { mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"})) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf, withMysqlAcceptable()) + conn := MustNewConn(conf, withMysqlAcceptable()) _, err = conn.Exec("any", "value") assert.NotNil(t, err) @@ -177,7 +177,7 @@ func TestConfigSqlConnStatement(t *testing.T) { mock.ExpectQuery("any").WillReturnRows(row) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf, withMysqlAcceptable()) + conn := MustNewConn(conf, withMysqlAcceptable()) stmt, err := conn.Prepare("any") assert.NoError(t, err) @@ -220,7 +220,7 @@ func TestConfigSqlConnQuery(t *testing.T) { t.Run("QueryRow", func(t *testing.T) { mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar")) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf) + conn := MustNewConn(conf) var val string assert.NoError(t, conn.QueryRow(&val, "any")) assert.Equal(t, "bar", val) @@ -229,7 +229,7 @@ func TestConfigSqlConnQuery(t *testing.T) { t.Run("QueryRowPartial", func(t *testing.T) { mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar")) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf) + conn := MustNewConn(conf) var val string assert.NoError(t, conn.QueryRowPartial(&val, "any")) assert.Equal(t, "bar", val) @@ -238,7 +238,7 @@ func TestConfigSqlConnQuery(t *testing.T) { t.Run("QueryRows", func(t *testing.T) { mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf) + conn := MustNewConn(conf) var vals []string assert.NoError(t, conn.QueryRows(&vals, "any")) assert.ElementsMatch(t, []string{"foo", "bar"}, vals) @@ -247,7 +247,7 @@ func TestConfigSqlConnQuery(t *testing.T) { t.Run("QueryRowsPartial", func(t *testing.T) { mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf) + conn := MustNewConn(conf) var vals []string assert.NoError(t, conn.QueryRowsPartial(&vals, "any")) assert.ElementsMatch(t, []string{"foo", "bar"}, vals) @@ -261,7 +261,7 @@ func TestConfigSqlConnErr(t *testing.T) { defer logx.ExitOnFatal.Set(original) assert.Panics(t, func() { - NewConn(SqlConf{}) + MustNewConn(SqlConf{}) }) }) t.Run("on error", func(t *testing.T) { @@ -272,7 +272,7 @@ func TestConfigSqlConnErr(t *testing.T) { connManager.Inject(mockedDatasource, db) conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} - conn := NewConn(conf) + conn := MustNewConn(conf) conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) { return nil, errors.New("error") } @@ -479,12 +479,12 @@ func TestProvider(t *testing.T) { assert.Nil(t, err) assert.Equal(t, primaryDB, db) - ctx = WithWriteMode(ctx) + ctx = WithWrite(ctx) db, err = sc.connProv(ctx) assert.Nil(t, err) assert.Equal(t, primaryDB, db) - ctx = WithReadPrimaryMode(ctx) + ctx = WithReadPrimary(ctx) db, err = sc.connProv(ctx) assert.Nil(t, err) assert.Equal(t, primaryDB, db) @@ -496,7 +496,7 @@ func TestProvider(t *testing.T) { assert.Nil(t, err) assert.Equal(t, primaryDB, db) - ctx = WithReadReplicaMode(ctx) + ctx = WithReadReplica(ctx) sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]}) db, err = sc.connProv(ctx) assert.Nil(t, err)