diff --git a/core/stores/sqlx/config.go b/core/stores/sqlx/config.go new file mode 100644 index 000000000..700e6d00d --- /dev/null +++ b/core/stores/sqlx/config.go @@ -0,0 +1,29 @@ +package sqlx + +import "errors" + +var ( + errEmptyDatasource = errors.New("empty datasource") + errEmptyDriverName = errors.New("empty driver name") +) + +// SqlConf defines the configuration for sqlx. +type SqlConf struct { + DataSource string + DriverName string `json:",default=mysql"` + Replicas []string `json:",optional"` + Policy string `json:",default=round-robin,options=round-robin|random"` +} + +// Validate validates the SqlxConf. +func (sc SqlConf) Validate() error { + if len(sc.DataSource) == 0 { + return errEmptyDatasource + } + + if len(sc.DriverName) == 0 { + return errEmptyDriverName + } + + return nil +} diff --git a/core/stores/sqlx/config_test.go b/core/stores/sqlx/config_test.go new file mode 100644 index 000000000..5b350ba0f --- /dev/null +++ b/core/stores/sqlx/config_test.go @@ -0,0 +1,29 @@ +package sqlx + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/conf" +) + +func TestValidate(t *testing.T) { + text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db +`) + + var sc SqlConf + err := conf.LoadFromYamlBytes(text, &sc) + assert.Nil(t, err) + assert.Equal(t, "mysql", sc.DriverName) + assert.Equal(t, policyRoundRobin, sc.Policy) + assert.Nil(t, sc.Validate()) + + sc = SqlConf{} + assert.Equal(t, errEmptyDatasource, sc.Validate()) + + sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db" + assert.Equal(t, errEmptyDriverName, sc.Validate()) + + sc.DriverName = "mysql" + assert.Nil(t, sc.Validate()) +} diff --git a/core/stores/sqlx/rwstrategy.go b/core/stores/sqlx/rwstrategy.go new file mode 100644 index 000000000..3a5813fa0 --- /dev/null +++ b/core/stores/sqlx/rwstrategy.go @@ -0,0 +1,65 @@ +package sqlx + +import "context" + +const ( + // policyRoundRobin round-robin policy for selecting replicas. + policyRoundRobin = "round-robin" + // policyRandom random policy for selecting replicas. + policyRandom = "random" + + // readPrimaryMode indicates that the operation is a read, + // but should be performed on the primary database instance. + // + // This mode is used in scenarios where data freshness and consistency are critical, + // such as immediately after writes or where replication lag may cause stale reads. + readPrimaryMode readWriteMode = "read-primary" + + // readReplicaMode indicates that the operation is a read from replicas. + // This is suitable for scenarios where eventual consistency is acceptable, + // and the goal is to offload traffic from the primary and improve read scalability. + readReplicaMode readWriteMode = "read-replica" + + // writeMode indicates that the operation is a write operation (to primary). + writeMode readWriteMode = "write" + + // notSpecifiedMode indicates that the read/write mode is not specified. + notSpecifiedMode readWriteMode = "" +) + +type readWriteMode string + +var readWriteModeKey struct{} + +func (m readWriteMode) isValid() bool { + return m == readPrimaryMode || m == readReplicaMode || m == writeMode +} + +func getReadWriteMode(ctx context.Context) readWriteMode { + if mode := ctx.Value(readWriteModeKey); mode != nil { + if v, ok := mode.(readWriteMode); ok && v.isValid() { + return v + } + } + + return notSpecifiedMode +} + +func useReplica(ctx context.Context) bool { + return getReadWriteMode(ctx) == readReplicaMode +} + +// WithReadPrimaryMode sets the context to read-primary mode. +func WithReadPrimaryMode(ctx context.Context) context.Context { + return context.WithValue(ctx, readWriteModeKey, readPrimaryMode) +} + +// WithReadReplicaMode sets the context to read-replica mode. +func WithReadReplicaMode(ctx context.Context) context.Context { + return context.WithValue(ctx, readWriteModeKey, readReplicaMode) +} + +// WithWriteMode sets the context to write mode, indicating that the operation is a write operation. +func WithWriteMode(ctx context.Context) context.Context { + return context.WithValue(ctx, readWriteModeKey, writeMode) +} diff --git a/core/stores/sqlx/rwstrategy_test.go b/core/stores/sqlx/rwstrategy_test.go new file mode 100644 index 000000000..acc2369a3 --- /dev/null +++ b/core/stores/sqlx/rwstrategy_test.go @@ -0,0 +1,133 @@ +package sqlx + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsValid(t *testing.T) { + testCases := []struct { + name string + mode readWriteMode + expected bool + }{ + { + name: "valid read-primary mode", + mode: readPrimaryMode, + expected: true, + }, + { + name: "valid read-replica mode", + mode: readReplicaMode, + expected: true, + }, + { + name: "valid write mode", + mode: writeMode, + expected: true, + }, + { + name: "not specified mode (empty)", + mode: notSpecifiedMode, + expected: false, + }, + { + name: "invalid custom string", + mode: readWriteMode("delete"), + expected: false, + }, + { + name: "case sensitive check", + mode: readWriteMode("READ"), + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := tc.mode.isValid() + assert.Equal(t, tc.expected, actual) + }) + } +} + +func TestWithReadMode(t *testing.T) { + ctx := context.Background() + readPrimaryCtx := WithReadPrimaryMode(ctx) + + val := readPrimaryCtx.Value(readWriteModeKey) + assert.Equal(t, readPrimaryMode, val) + + readReplicaCtx := WithReadReplicaMode(ctx) + val = readReplicaCtx.Value(readWriteModeKey) + assert.Equal(t, readReplicaMode, val) +} + +func TestWithWriteMode(t *testing.T) { + ctx := context.Background() + writeCtx := WithWriteMode(ctx) + + val := writeCtx.Value(readWriteModeKey) + assert.Equal(t, writeMode, val) +} + +func TestGetReadWriteMode(t *testing.T) { + t.Run("valid read-primary mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) + assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx)) + }) + + t.Run("valid read-replica mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) + assert.Equal(t, readReplicaMode, getReadWriteMode(ctx)) + }) + + t.Run("valid write mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) + assert.Equal(t, writeMode, getReadWriteMode(ctx)) + }) + + t.Run("invalid mode value (wrong type)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, "not-a-mode") + assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) + }) + + t.Run("invalid mode value (wrong value)", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("delete")) + assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) + }) + + t.Run("no mode set", func(t *testing.T) { + ctx := context.Background() + assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx)) + }) +} + +func TestUuseReplica(t *testing.T) { + t.Run("context with read-replica mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, readReplicaMode) + assert.True(t, useReplica(ctx)) + }) + + t.Run("context with read-primary mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, readPrimaryMode) + assert.False(t, useReplica(ctx)) + }) + + t.Run("context with write mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, writeMode) + assert.False(t, useReplica(ctx)) + }) + + t.Run("context with invalid mode", func(t *testing.T) { + ctx := context.WithValue(context.Background(), readWriteModeKey, readWriteMode("invalid")) + assert.False(t, useReplica(ctx)) + }) + + t.Run("context with no mode set", func(t *testing.T) { + ctx := context.Background() + assert.False(t, useReplica(ctx)) + }) +} diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 34a4f386a..cd13222d1 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -4,6 +4,9 @@ import ( "context" "database/sql" "errors" + "fmt" + "math/rand" + "sync/atomic" "github.com/zeromicro/go-zero/core/breaker" "github.com/zeromicro/go-zero/core/errorx" @@ -52,9 +55,10 @@ type ( beginTx beginnable brk breaker.Breaker accept breaker.Acceptable + index uint32 } - connProvider func() (*sql.DB, error) + connProvider func(ctx context.Context) (*sql.DB, error) sessionConn interface { Exec(query string, args ...any) (sql.Result, error) @@ -64,10 +68,31 @@ type ( } ) +// NewConn returns a SqlConn with the given SqlConf. +func NewConn(c SqlConf, opts ...SqlOption) SqlConn { + if err := c.Validate(); err != nil { + logx.Must(err) + } + + conn := &commonSqlConn{ + onError: func(ctx context.Context, err error) { + logInstanceError(ctx, c.DataSource, err) + }, + beginTx: begin, + brk: breaker.NewBreaker(), + } + for _, opt := range opts { + opt(conn) + } + conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas) + + return conn +} + // NewSqlConn returns a SqlConn with given driver name and datasource. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ - connProv: func() (*sql.DB, error) { + connProv: func(context.Context) (*sql.DB, error) { return getSqlConn(driverName, datasource) }, onError: func(ctx context.Context, err error) { @@ -87,7 +112,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn { // Use it with caution; it's provided for other ORM to interact with. func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { conn := &commonSqlConn{ - connProv: func() (*sql.DB, error) { + connProv: func(ctx context.Context) (*sql.DB, error) { return db, nil }, onError: func(ctx context.Context, err error) { @@ -123,7 +148,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) ( err = db.brk.DoWithAcceptableCtx(ctx, func() error { var conn *sql.DB - conn, err = db.connProv() + conn, err = db.connProv(ctx) if err != nil { db.onError(ctx, err) return err @@ -151,7 +176,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm err = db.brk.DoWithAcceptableCtx(ctx, func() error { var conn *sql.DB - conn, err = db.connProv() + conn, err = db.connProv(ctx) if err != nil { db.onError(ctx, err) return err @@ -242,7 +267,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any, } func (db *commonSqlConn) RawDB() (*sql.DB, error) { - return db.connProv() + return db.connProv(context.Background()) } func (db *commonSqlConn) Transact(fn func(Session) error) error { @@ -288,7 +313,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) q string, args ...any) (err error) { var scanFailed bool err = db.brk.DoWithAcceptableCtx(ctx, func() error { - conn, err := db.connProv() + conn, err := db.connProv(ctx) if err != nil { db.onError(ctx, err) return err @@ -311,6 +336,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) return } +func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider { + return func(ctx context.Context) (*sql.DB, error) { + replicaCount := len(replicas) + + if replicaCount == 0 || !useReplica(ctx) { + return getSqlConn(driverName, datasource) + } + + var dsn string + + if replicaCount == 1 { + dsn = replicas[0] + } else { + if len(policy) == 0 { + policy = policyRoundRobin + } + + switch policy { + case policyRandom: + dsn = replicas[rand.Intn(replicaCount)] + case policyRoundRobin: + index := atomic.AddUint32(&sc.index, 1) - 1 + dsn = replicas[index%uint32(replicaCount)] + default: + return nil, fmt.Errorf("unknown policy: %s", policy) + } + } + + return getSqlConn(driverName, dsn) + } +} + // WithAcceptable returns a SqlOption that setting the acceptable function. // acceptable is the func to check if the error can be accepted. func WithAcceptable(acceptable func(err error) bool) SqlOption { diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index 53ea0d5ac..1f252c437 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "errors" "io" @@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) { func TestSqlConn_Errors(t *testing.T) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { conn := NewSqlConnFromDB(db) - conn.(*commonSqlConn).connProv = func() (*sql.DB, error) { + conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) { return nil, errors.New("error") } _, err := conn.Prepare("any") @@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) { }) } +func TestConfigSqlConn(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NotNil(t, db) + assert.NotNil(t, mock) + assert.Nil(t, err) + connManager.Inject(mockedDatasource, db) + mock.ExpectExec("any") + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"})) + + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf, withMysqlAcceptable()) + + _, err = conn.Exec("any", "value") + assert.NotNil(t, err) + _, err = conn.Prepare("any") + assert.NotNil(t, err) + + var val string + assert.NotNil(t, conn.QueryRow(&val, "any")) + assert.NotNil(t, conn.QueryRowPartial(&val, "any")) + assert.NotNil(t, conn.QueryRows(&val, "any")) + assert.NotNil(t, conn.QueryRowsPartial(&val, "any")) +} + +func TestConfigSqlConnStatement(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NotNil(t, db) + assert.NotNil(t, mock) + assert.Nil(t, err) + connManager.Inject(mockedDatasource, db) + + mock.ExpectPrepare("any") + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + mock.ExpectPrepare("any") + row := sqlmock.NewRows([]string{"foo"}).AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(row) + + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf, withMysqlAcceptable()) + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + + res, err := stmt.Exec() + assert.NoError(t, err) + lastInsertID, err := res.LastInsertId() + assert.NoError(t, err) + assert.Equal(t, int64(2), lastInsertID) + rowsAffected, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(3), rowsAffected) + + stmt, err = conn.Prepare("any") + assert.NoError(t, err) + + var val string + err = stmt.QueryRow(&val) + assert.NoError(t, err) + assert.Equal(t, "bar", val) + + mock.ExpectPrepare("any") + rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + + stmt, err = conn.Prepare("any") + assert.NoError(t, err) + + var vals []string + assert.NoError(t, stmt.QueryRowsPartial(&vals)) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) +} + +func TestConfigSqlConnQuery(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NotNil(t, db) + assert.NotNil(t, mock) + assert.Nil(t, err) + connManager.Inject(mockedDatasource, db) + + t.Run("QueryRow", func(t *testing.T) { + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar")) + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf) + var val string + assert.NoError(t, conn.QueryRow(&val, "any")) + assert.Equal(t, "bar", val) + }) + + t.Run("QueryRowPartial", func(t *testing.T) { + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar")) + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf) + var val string + assert.NoError(t, conn.QueryRowPartial(&val, "any")) + assert.Equal(t, "bar", val) + }) + + t.Run("QueryRows", func(t *testing.T) { + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")) + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf) + var vals []string + assert.NoError(t, conn.QueryRows(&vals, "any")) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) + }) + + t.Run("QueryRowsPartial", func(t *testing.T) { + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")) + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf) + var vals []string + assert.NoError(t, conn.QueryRowsPartial(&vals, "any")) + assert.ElementsMatch(t, []string{"foo", "bar"}, vals) + }) +} + +func TestConfigSqlConnErr(t *testing.T) { + t.Run("panic on empty config", func(t *testing.T) { + original := logx.ExitOnFatal.True() + logx.ExitOnFatal.Set(false) + defer logx.ExitOnFatal.Set(original) + + assert.Panics(t, func() { + NewConn(SqlConf{}) + }) + }) + t.Run("on error", func(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NotNil(t, db) + assert.NotNil(t, mock) + assert.Nil(t, err) + connManager.Inject(mockedDatasource, db) + + conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName} + conn := NewConn(conf) + conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) { + return nil, errors.New("error") + } + _, err = conn.Prepare("any") + assert.Error(t, err) + }) +} + func TestStatement(t *testing.T) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { mock.ExpectPrepare("any").WillBeClosed() @@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) { assert.True(t, conn.accept(acceptableErr3)) } +func TestProvider(t *testing.T) { + defer func() { + _ = connManager.Close() + }() + + primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db" + replicasDSN := []string{ + "replica_one:pwd@tcp(localhost:3306)/replica_one", + "replica_two:pwd@tcp(localhost:3306)/replica_two", + "replica_three:pwd@tcp(localhost:3306)/replica_three", + } + + primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) }) + assert.Nil(t, err) + assert.NotNil(t, primaryDB) + replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) }) + assert.Nil(t, err) + assert.NotNil(t, replicaOneDB) + replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) }) + assert.Nil(t, err) + assert.NotNil(t, replicaTwoDB) + replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) }) + assert.Nil(t, err) + assert.NotNil(t, replicaThreeDB) + + sc := &commonSqlConn{} + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil) + + ctx := context.Background() + db, err := sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, primaryDB, db) + + ctx = WithWriteMode(ctx) + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, primaryDB, db) + + ctx = WithReadPrimaryMode(ctx) + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, primaryDB, db) + + // no mode set, should return primary + ctx = context.Background() + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN) + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, primaryDB, db) + + ctx = WithReadReplicaMode(ctx) + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]}) + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, replicaOneDB, db) + + // default policy is round-robin + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN) + replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB} + for i := 0; i < len(replicasDSN); i++ { + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, replicas[i], db) + } + + // random policy + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN) + for i := 0; i < len(replicasDSN); i++ { + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Contains(t, replicas, db) + } + + // unknown policy + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN) + _, err = sc.connProv(ctx) + assert.NotNil(t, err) + + // empty policy transforms to round-robin + sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN) + for i := 0; i < len(replicasDSN); i++ { + db, err = sc.connProv(ctx) + assert.Nil(t, err) + assert.Equal(t, replicas[i], db) + } +} + func buildConn() (mock sqlmock.Sqlmock, err error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { var db *sql.DB diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go index ea2fd2bee..02370ca07 100644 --- a/core/stores/sqlx/tx.go +++ b/core/stores/sqlx/tx.go @@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) { func transact(ctx context.Context, db *commonSqlConn, b beginnable, fn func(context.Context, Session) error) (err error) { - conn, err := db.connProv() + conn, err := db.connProv(ctx) if err != nil { db.onError(ctx, err) return err diff --git a/core/stores/sqlx/tx_test.go b/core/stores/sqlx/tx_test.go index a7c87f1a4..3a2de138b 100644 --- a/core/stores/sqlx/tx_test.go +++ b/core/stores/sqlx/tx_test.go @@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { conn := &commonSqlConn{ - connProv: func() (*sql.DB, error) { + connProv: func(ctx context.Context) (*sql.DB, error) { return nil, errors.New("foo") }, beginTx: begin,