sql read write support (#4976)

Co-authored-by: light.zhou <light.zhou@bkyo.io>
This commit is contained in:
zhoushuguang
2025-07-10 00:04:56 +08:00
committed by GitHub
parent 95d5b81f44
commit 8c6266f338
8 changed files with 553 additions and 10 deletions

View File

@@ -4,6 +4,9 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand"
"sync/atomic"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
@@ -52,9 +55,10 @@ type (
beginTx beginnable
brk breaker.Breaker
accept breaker.Acceptable
index uint32
}
connProvider func() (*sql.DB, error)
connProvider func(ctx context.Context) (*sql.DB, error)
sessionConn interface {
Exec(query string, args ...any) (sql.Result, error)
@@ -64,10 +68,31 @@ type (
}
)
// NewConn returns a SqlConn with the given SqlConf.
func NewConn(c SqlConf, opts ...SqlOption) SqlConn {
if err := c.Validate(); err != nil {
logx.Must(err)
}
conn := &commonSqlConn{
onError: func(ctx context.Context, err error) {
logInstanceError(ctx, c.DataSource, err)
},
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
return conn
}
// NewSqlConn returns a SqlConn with given driver name and datasource.
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(context.Context) (*sql.DB, error) {
return getSqlConn(driverName, datasource)
},
onError: func(ctx context.Context, err error) {
@@ -87,7 +112,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
// Use it with caution; it's provided for other ORM to interact with.
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
connProv: func(ctx context.Context) (*sql.DB, error) {
return db, nil
},
onError: func(ctx context.Context, err error) {
@@ -123,7 +148,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -151,7 +176,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
var conn *sql.DB
conn, err = db.connProv()
conn, err = db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -242,7 +267,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
}
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
return db.connProv()
return db.connProv(context.Background())
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
@@ -288,7 +313,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
q string, args ...any) (err error) {
var scanFailed bool
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
conn, err := db.connProv()
conn, err := db.connProv(ctx)
if err != nil {
db.onError(ctx, err)
return err
@@ -311,6 +336,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
return
}
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
return func(ctx context.Context) (*sql.DB, error) {
replicaCount := len(replicas)
if replicaCount == 0 || !useReplica(ctx) {
return getSqlConn(driverName, datasource)
}
var dsn string
if replicaCount == 1 {
dsn = replicas[0]
} else {
if len(policy) == 0 {
policy = policyRoundRobin
}
switch policy {
case policyRandom:
dsn = replicas[rand.Intn(replicaCount)]
case policyRoundRobin:
index := atomic.AddUint32(&sc.index, 1) - 1
dsn = replicas[index%uint32(replicaCount)]
default:
return nil, fmt.Errorf("unknown policy: %s", policy)
}
}
return getSqlConn(driverName, dsn)
}
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {