mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-10 16:30:01 +08:00
sql read write support (#4976)
Co-authored-by: light.zhou <light.zhou@bkyo.io>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
|
||||
func TestSqlConn_Errors(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
conn := NewSqlConnFromDB(db)
|
||||
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
|
||||
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
_, err := conn.Prepare("any")
|
||||
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigSqlConn(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
mock.ExpectExec("any")
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf, withMysqlAcceptable())
|
||||
|
||||
_, err = conn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = conn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var val string
|
||||
assert.NotNil(t, conn.QueryRow(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRows(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
|
||||
}
|
||||
|
||||
func TestConfigSqlConnStatement(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
mock.ExpectPrepare("any")
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectPrepare("any")
|
||||
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(row)
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf, withMysqlAcceptable())
|
||||
stmt, err := conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := stmt.Exec()
|
||||
assert.NoError(t, err)
|
||||
lastInsertID, err := res.LastInsertId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), lastInsertID)
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), rowsAffected)
|
||||
|
||||
stmt, err = conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var val string
|
||||
err = stmt.QueryRow(&val)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bar", val)
|
||||
|
||||
mock.ExpectPrepare("any")
|
||||
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
|
||||
stmt, err = conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var vals []string
|
||||
assert.NoError(t, stmt.QueryRowsPartial(&vals))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
}
|
||||
|
||||
func TestConfigSqlConnQuery(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
t.Run("QueryRow", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf)
|
||||
var val string
|
||||
assert.NoError(t, conn.QueryRow(&val, "any"))
|
||||
assert.Equal(t, "bar", val)
|
||||
})
|
||||
|
||||
t.Run("QueryRowPartial", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf)
|
||||
var val string
|
||||
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
|
||||
assert.Equal(t, "bar", val)
|
||||
})
|
||||
|
||||
t.Run("QueryRows", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf)
|
||||
var vals []string
|
||||
assert.NoError(t, conn.QueryRows(&vals, "any"))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
})
|
||||
|
||||
t.Run("QueryRowsPartial", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf)
|
||||
var vals []string
|
||||
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigSqlConnErr(t *testing.T) {
|
||||
t.Run("panic on empty config", func(t *testing.T) {
|
||||
original := logx.ExitOnFatal.True()
|
||||
logx.ExitOnFatal.Set(false)
|
||||
defer logx.ExitOnFatal.Set(original)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
NewConn(SqlConf{})
|
||||
})
|
||||
})
|
||||
t.Run("on error", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := NewConn(conf)
|
||||
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
_, err = conn.Prepare("any")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatement(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectPrepare("any").WillBeClosed()
|
||||
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
|
||||
assert.True(t, conn.accept(acceptableErr3))
|
||||
}
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
defer func() {
|
||||
_ = connManager.Close()
|
||||
}()
|
||||
|
||||
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||
replicasDSN := []string{
|
||||
"replica_one:pwd@tcp(localhost:3306)/replica_one",
|
||||
"replica_two:pwd@tcp(localhost:3306)/replica_two",
|
||||
"replica_three:pwd@tcp(localhost:3306)/replica_three",
|
||||
}
|
||||
|
||||
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, primaryDB)
|
||||
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaOneDB)
|
||||
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaTwoDB)
|
||||
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaThreeDB)
|
||||
|
||||
sc := &commonSqlConn{}
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
db, err := sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithWriteMode(ctx)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithReadPrimaryMode(ctx)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
// no mode set, should return primary
|
||||
ctx = context.Background()
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithReadReplicaMode(ctx)
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicaOneDB, db)
|
||||
|
||||
// default policy is round-robin
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicas[i], db)
|
||||
}
|
||||
|
||||
// random policy
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, replicas, db)
|
||||
}
|
||||
|
||||
// unknown policy
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
|
||||
_, err = sc.connProv(ctx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// empty policy transforms to round-robin
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicas[i], db)
|
||||
}
|
||||
}
|
||||
|
||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
var db *sql.DB
|
||||
|
||||
Reference in New Issue
Block a user