chore: refactoring sql read write mode (#4990)

Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
Kevin Wan
2025-07-11 01:05:55 +08:00
committed by GitHub
parent 8c6266f338
commit bae8d4f4c8
4 changed files with 63 additions and 44 deletions

View File

@@ -27,10 +27,25 @@ const (
notSpecifiedMode readWriteMode = "" notSpecifiedMode readWriteMode = ""
) )
type readWriteMode string
var readWriteModeKey struct{} var readWriteModeKey struct{}
// WithReadPrimary sets the context to read-primary mode.
func WithReadPrimary(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, readPrimaryMode)
}
// WithReadReplica sets the context to read-replica mode.
func WithReadReplica(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, readReplicaMode)
}
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
func WithWrite(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, writeMode)
}
type readWriteMode string
func (m readWriteMode) isValid() bool { func (m readWriteMode) isValid() bool {
return m == readPrimaryMode || m == readReplicaMode || m == writeMode return m == readPrimaryMode || m == readReplicaMode || m == writeMode
} }
@@ -45,21 +60,6 @@ func getReadWriteMode(ctx context.Context) readWriteMode {
return notSpecifiedMode return notSpecifiedMode
} }
func useReplica(ctx context.Context) bool { func usePrimary(ctx context.Context) bool {
return getReadWriteMode(ctx) == readReplicaMode return getReadWriteMode(ctx) != readReplicaMode
}
// WithReadPrimaryMode sets the context to read-primary mode.
func WithReadPrimaryMode(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, readPrimaryMode)
}
// WithReadReplicaMode sets the context to read-replica mode.
func WithReadReplicaMode(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, readReplicaMode)
}
// WithWriteMode sets the context to write mode, indicating that the operation is a write operation.
func WithWriteMode(ctx context.Context) context.Context {
return context.WithValue(ctx, readWriteModeKey, writeMode)
} }

View File

@@ -55,19 +55,19 @@ func TestIsValid(t *testing.T) {
func TestWithReadMode(t *testing.T) { func TestWithReadMode(t *testing.T) {
ctx := context.Background() ctx := context.Background()
readPrimaryCtx := WithReadPrimaryMode(ctx) readPrimaryCtx := WithReadPrimary(ctx)
val := readPrimaryCtx.Value(readWriteModeKey) val := readPrimaryCtx.Value(readWriteModeKey)
assert.Equal(t, readPrimaryMode, val) assert.Equal(t, readPrimaryMode, val)
readReplicaCtx := WithReadReplicaMode(ctx) readReplicaCtx := WithReadReplica(ctx)
val = readReplicaCtx.Value(readWriteModeKey) val = readReplicaCtx.Value(readWriteModeKey)
assert.Equal(t, readReplicaMode, val) assert.Equal(t, readReplicaMode, val)
} }
func TestWithWriteMode(t *testing.T) { func TestWithWriteMode(t *testing.T) {
ctx := context.Background() ctx := context.Background()
writeCtx := WithWriteMode(ctx) writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey) val := writeCtx.Value(readWriteModeKey)
assert.Equal(t, writeMode, val) assert.Equal(t, writeMode, val)
@@ -105,29 +105,38 @@ func TestGetReadWriteMode(t *testing.T) {
}) })
} }
func TestUuseReplica(t *testing.T) { func TestUsePrimary(t *testing.T) {
t.Run("context with read-replica mode", func(t *testing.T) { t.Run("context with read-replica mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode)
assert.True(t, useReplica(ctx)) assert.False(t, usePrimary(ctx))
}) })
t.Run("context with read-primary mode", func(t *testing.T) { t.Run("context with read-primary mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode)
assert.False(t, useReplica(ctx)) assert.True(t, usePrimary(ctx))
}) })
t.Run("context with write mode", func(t *testing.T) { t.Run("context with write mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode)
assert.False(t, useReplica(ctx)) assert.True(t, usePrimary(ctx))
}) })
t.Run("context with invalid mode", func(t *testing.T) { t.Run("context with invalid mode", func(t *testing.T) {
ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid")) ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid"))
assert.False(t, useReplica(ctx)) assert.True(t, usePrimary(ctx))
}) })
t.Run("context with no mode set", func(t *testing.T) { t.Run("context with no mode set", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
assert.False(t, useReplica(ctx)) assert.True(t, usePrimary(ctx))
}) })
} }
func TestWithModeTwice(t *testing.T) {
ctx := context.Background()
ctx = WithReadPrimary(ctx)
writeCtx := WithWrite(ctx)
val := writeCtx.Value(readWriteModeKey)
assert.Equal(t, writeMode, val)
}

View File

@@ -68,12 +68,22 @@ type (
} }
) )
// NewConn returns a SqlConn with the given SqlConf. // MustNewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) SqlConn { func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
if err := c.Validate(); err != nil { conn, err := NewConn(c, opts...)
if err != nil {
logx.Must(err) logx.Must(err)
} }
return conn
}
// NewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
if err := c.Validate(); err != nil {
return nil, err
}
conn := &commonSqlConn{ conn := &commonSqlConn{
onError: func(ctx context.Context, err error) { onError: func(ctx context.Context, err error) {
logInstanceError(ctx, c.DataSource, err) logInstanceError(ctx, c.DataSource, err)
@@ -86,7 +96,7 @@ func NewConn(c SqlConf, opts ...SqlOption) SqlConn {
} }
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas) conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
return conn return conn, nil
} }
// NewSqlConn returns a SqlConn with given driver name and datasource. // NewSqlConn returns a SqlConn with given driver name and datasource.
@@ -340,7 +350,7 @@ func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, r
return func(ctx context.Context) (*sql.DB, error) { return func(ctx context.Context) (*sql.DB, error) {
replicaCount := len(replicas) replicaCount := len(replicas)
if replicaCount == 0 || !useReplica(ctx) { if replicaCount == 0 || usePrimary(ctx) {
return getSqlConn(driverName, datasource) return getSqlConn(driverName, datasource)
} }

View File

@@ -149,7 +149,7 @@ func TestConfigSqlConn(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"})) mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf, withMysqlAcceptable()) conn := MustNewConn(conf, withMysqlAcceptable())
_, err = conn.Exec("any", "value") _, err = conn.Exec("any", "value")
assert.NotNil(t, err) assert.NotNil(t, err)
@@ -177,7 +177,7 @@ func TestConfigSqlConnStatement(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(row) mock.ExpectQuery("any").WillReturnRows(row)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf, withMysqlAcceptable()) conn := MustNewConn(conf, withMysqlAcceptable())
stmt, err := conn.Prepare("any") stmt, err := conn.Prepare("any")
assert.NoError(t, err) assert.NoError(t, err)
@@ -220,7 +220,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
t.Run("QueryRow", func(t *testing.T) { t.Run("QueryRow", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar")) mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf) conn := MustNewConn(conf)
var val string var val string
assert.NoError(t, conn.QueryRow(&val, "any")) assert.NoError(t, conn.QueryRow(&val, "any"))
assert.Equal(t, "bar", val) assert.Equal(t, "bar", val)
@@ -229,7 +229,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
t.Run("QueryRowPartial", func(t *testing.T) { t.Run("QueryRowPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar")) mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf) conn := MustNewConn(conf)
var val string var val string
assert.NoError(t, conn.QueryRowPartial(&val, "any")) assert.NoError(t, conn.QueryRowPartial(&val, "any"))
assert.Equal(t, "bar", val) assert.Equal(t, "bar", val)
@@ -238,7 +238,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
t.Run("QueryRows", func(t *testing.T) { t.Run("QueryRows", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")) mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf) conn := MustNewConn(conf)
var vals []string var vals []string
assert.NoError(t, conn.QueryRows(&vals, "any")) assert.NoError(t, conn.QueryRows(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals) assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
@@ -247,7 +247,7 @@ func TestConfigSqlConnQuery(t *testing.T) {
t.Run("QueryRowsPartial", func(t *testing.T) { t.Run("QueryRowsPartial", func(t *testing.T) {
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")) mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf) conn := MustNewConn(conf)
var vals []string var vals []string
assert.NoError(t, conn.QueryRowsPartial(&vals, "any")) assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals) assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
@@ -261,7 +261,7 @@ func TestConfigSqlConnErr(t *testing.T) {
defer logx.ExitOnFatal.Set(original) defer logx.ExitOnFatal.Set(original)
assert.Panics(t, func() { assert.Panics(t, func() {
NewConn(SqlConf{}) MustNewConn(SqlConf{})
}) })
}) })
t.Run("on error", func(t *testing.T) { t.Run("on error", func(t *testing.T) {
@@ -272,7 +272,7 @@ func TestConfigSqlConnErr(t *testing.T) {
connManager.Inject(mockedDatasource, db) connManager.Inject(mockedDatasource, db)
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
conn := NewConn(conf) conn := MustNewConn(conf)
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) { conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
return nil, errors.New("error") return nil, errors.New("error")
} }
@@ -479,12 +479,12 @@ func TestProvider(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, primaryDB, db) assert.Equal(t, primaryDB, db)
ctx = WithWriteMode(ctx) ctx = WithWrite(ctx)
db, err = sc.connProv(ctx) db, err = sc.connProv(ctx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, primaryDB, db) assert.Equal(t, primaryDB, db)
ctx = WithReadPrimaryMode(ctx) ctx = WithReadPrimary(ctx)
db, err = sc.connProv(ctx) db, err = sc.connProv(ctx)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, primaryDB, db) assert.Equal(t, primaryDB, db)
@@ -496,7 +496,7 @@ func TestProvider(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, primaryDB, db) assert.Equal(t, primaryDB, db)
ctx = WithReadReplicaMode(ctx) ctx = WithReadReplica(ctx)
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]}) sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
db, err = sc.connProv(ctx) db, err = sc.connProv(ctx)
assert.Nil(t, err) assert.Nil(t, err)