mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-10 16:30:01 +08:00
sql read write support (#4976)
Co-authored-by: light.zhou <light.zhou@bkyo.io>
This commit is contained in:
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
@@ -52,9 +55,10 @@ type (
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept breaker.Acceptable
|
||||
index uint32
|
||||
}
|
||||
|
||||
connProvider func() (*sql.DB, error)
|
||||
connProvider func(ctx context.Context) (*sql.DB, error)
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
@@ -64,10 +68,31 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// NewConn returns a SqlConn with the given SqlConf.
|
||||
func NewConn(c SqlConf, opts ...SqlOption) SqlConn {
|
||||
if err := c.Validate(); err != nil {
|
||||
logx.Must(err)
|
||||
}
|
||||
|
||||
conn := &commonSqlConn{
|
||||
onError: func(ctx context.Context, err error) {
|
||||
logInstanceError(ctx, c.DataSource, err)
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(context.Context) (*sql.DB, error) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
},
|
||||
onError: func(ctx context.Context, err error) {
|
||||
@@ -87,7 +112,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
// Use it with caution; it's provided for other ORM to interact with.
|
||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||
return db, nil
|
||||
},
|
||||
onError: func(ctx context.Context, err error) {
|
||||
@@ -123,7 +148,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
|
||||
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
conn, err = db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -151,7 +176,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
||||
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
conn, err = db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -242,7 +267,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||
return db.connProv()
|
||||
return db.connProv(context.Background())
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
@@ -288,7 +313,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
q string, args ...any) (err error) {
|
||||
var scanFailed bool
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
conn, err := db.connProv()
|
||||
conn, err := db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -311,6 +336,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
return
|
||||
}
|
||||
|
||||
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
|
||||
return func(ctx context.Context) (*sql.DB, error) {
|
||||
replicaCount := len(replicas)
|
||||
|
||||
if replicaCount == 0 || !useReplica(ctx) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
|
||||
if replicaCount == 1 {
|
||||
dsn = replicas[0]
|
||||
} else {
|
||||
if len(policy) == 0 {
|
||||
policy = policyRoundRobin
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case policyRandom:
|
||||
dsn = replicas[rand.Intn(replicaCount)]
|
||||
case policyRoundRobin:
|
||||
index := atomic.AddUint32(&sc.index, 1) - 1
|
||||
dsn = replicas[index%uint32(replicaCount)]
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown policy: %s", policy)
|
||||
}
|
||||
}
|
||||
|
||||
return getSqlConn(driverName, dsn)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||
// acceptable is the func to check if the error can be accepted.
|
||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||
|
||||
Reference in New Issue
Block a user