initial import

This commit is contained in:
kevin
2020-07-26 17:09:05 +08:00
commit 7e3a369a8f
647 changed files with 54754 additions and 0 deletions

5
core/stores/cache/cacheconf.go vendored Normal file
View File

@@ -0,0 +1,5 @@
package cache
import "zero/core/stores/internal"
type CacheConf = internal.ClusterConf

21
core/stores/cache/cacheopt.go vendored Normal file
View File

@@ -0,0 +1,21 @@
package cache
import (
"time"
"zero/core/stores/internal"
)
type Option = internal.Option
func WithExpiry(expiry time.Duration) Option {
return func(o *internal.Options) {
o.Expiry = expiry
}
}
func WithNotFoundExpiry(expiry time.Duration) Option {
return func(o *internal.Options) {
o.NotFoundExpiry = expiry
}
}

View File

@@ -0,0 +1,13 @@
package clickhouse
import (
"zero/core/stores/sqlx"
_ "github.com/kshvakov/clickhouse"
)
const clickHouseDriverName = "clickhouse"
func New(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
return sqlx.NewSqlConn(clickHouseDriverName, datasource, opts...)
}

View File

@@ -0,0 +1,129 @@
package internal
import (
"fmt"
"log"
"time"
"zero/core/errorx"
"zero/core/hash"
"zero/core/syncx"
)
type (
Cache interface {
DelCache(keys ...string) error
GetCache(key string, v interface{}) error
SetCache(key string, v interface{}) error
SetCacheWithExpire(key string, v interface{}, expire time.Duration) error
Take(v interface{}, key string, query func(v interface{}) error) error
TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error
}
cacheCluster struct {
dispatcher *hash.ConsistentHash
errNotFound error
}
)
func NewCache(c ClusterConf, barrier syncx.SharedCalls, st *CacheStat, errNotFound error,
opts ...Option) Cache {
if len(c) == 0 || TotalWeights(c) <= 0 {
log.Fatal("no cache nodes")
}
if len(c) == 1 {
return NewCacheNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
}
dispatcher := hash.NewConsistentHash()
for _, node := range c {
cn := NewCacheNode(node.NewRedis(), barrier, st, errNotFound, opts...)
dispatcher.AddWithWeight(cn, node.Weight)
}
return cacheCluster{
dispatcher: dispatcher,
errNotFound: errNotFound,
}
}
func (cc cacheCluster) DelCache(keys ...string) error {
switch len(keys) {
case 0:
return nil
case 1:
key := keys[0]
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}
return c.(Cache).DelCache(key)
default:
var be errorx.BatchError
nodes := make(map[interface{}][]string)
for _, key := range keys {
c, ok := cc.dispatcher.Get(key)
if !ok {
be.Add(fmt.Errorf("key %q not found", key))
continue
}
nodes[c] = append(nodes[c], key)
}
for c, ks := range nodes {
if err := c.(Cache).DelCache(ks...); err != nil {
be.Add(err)
}
}
return be.Err()
}
}
func (cc cacheCluster) GetCache(key string, v interface{}) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}
return c.(Cache).GetCache(key, v)
}
func (cc cacheCluster) SetCache(key string, v interface{}) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}
return c.(Cache).SetCache(key, v)
}
func (cc cacheCluster) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}
return c.(Cache).SetCacheWithExpire(key, v, expire)
}
func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}
return c.(Cache).Take(v, key, query)
}
func (cc cacheCluster) TakeWithExpire(v interface{}, key string,
query func(v interface{}, expire time.Duration) error) error {
c, ok := cc.dispatcher.Get(key)
if !ok {
return cc.errNotFound
}
return c.(Cache).TakeWithExpire(v, key, query)
}

View File

@@ -0,0 +1,201 @@
package internal
import (
"encoding/json"
"fmt"
"math"
"strconv"
"testing"
"time"
"zero/core/errorx"
"zero/core/hash"
"zero/core/stores/redis"
"zero/core/syncx"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
)
type mockedNode struct {
vals map[string][]byte
errNotFound error
}
func (mc *mockedNode) DelCache(keys ...string) error {
var be errorx.BatchError
for _, key := range keys {
if _, ok := mc.vals[key]; !ok {
be.Add(mc.errNotFound)
} else {
delete(mc.vals, key)
}
}
return be.Err()
}
func (mc *mockedNode) GetCache(key string, v interface{}) error {
bs, ok := mc.vals[key]
if ok {
return json.Unmarshal(bs, v)
}
return mc.errNotFound
}
func (mc *mockedNode) SetCache(key string, v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return err
}
mc.vals[key] = data
return nil
}
func (mc *mockedNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
return mc.SetCache(key, v)
}
func (mc *mockedNode) Take(v interface{}, key string, query func(v interface{}) error) error {
if _, ok := mc.vals[key]; ok {
return mc.GetCache(key, v)
}
if err := query(v); err != nil {
return err
}
return mc.SetCache(key, v)
}
func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
return mc.Take(v, key, func(v interface{}) error {
return query(v, 0)
})
}
func TestCache_SetDel(t *testing.T) {
const total = 1000
r1 := miniredis.NewMiniRedis()
assert.Nil(t, r1.Start())
defer r1.Close()
r2 := miniredis.NewMiniRedis()
assert.Nil(t, r2.Start())
defer r2.Close()
conf := ClusterConf{
{
RedisConf: redis.RedisConf{
Host: r1.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
{
RedisConf: redis.RedisConf{
Host: r2.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
}
c := NewCache(conf, syncx.NewSharedCalls(), NewCacheStat("mock"), errPlaceholder)
for i := 0; i < total; i++ {
if i%2 == 0 {
assert.Nil(t, c.SetCache(fmt.Sprintf("key/%d", i), i))
} else {
assert.Nil(t, c.SetCacheWithExpire(fmt.Sprintf("key/%d", i), i, 0))
}
}
for i := 0; i < total; i++ {
var v int
assert.Nil(t, c.GetCache(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, i, v)
}
for i := 0; i < total; i++ {
assert.Nil(t, c.DelCache(fmt.Sprintf("key/%d", i)))
}
for i := 0; i < total; i++ {
var v int
assert.Equal(t, errPlaceholder, c.GetCache(fmt.Sprintf("key/%d", i), &v))
assert.Equal(t, 0, v)
}
}
func TestCache_Balance(t *testing.T) {
const (
numNodes = 100
total = 10000
)
dispatcher := hash.NewConsistentHash()
maps := make([]map[string][]byte, numNodes)
for i := 0; i < numNodes; i++ {
maps[i] = map[string][]byte{
strconv.Itoa(i): []byte(strconv.Itoa(i)),
}
}
for i := 0; i < numNodes; i++ {
dispatcher.AddWithWeight(&mockedNode{
vals: maps[i],
errNotFound: errPlaceholder,
}, 100)
}
c := cacheCluster{
dispatcher: dispatcher,
errNotFound: errPlaceholder,
}
for i := 0; i < total; i++ {
assert.Nil(t, c.SetCache(strconv.Itoa(i), i))
}
counts := make(map[int]int)
for i, m := range maps {
counts[i] = len(m)
}
entropy := calcEntropy(counts, total)
assert.True(t, len(counts) > 1)
assert.True(t, entropy > .95, fmt.Sprintf("entropy should be greater than 0.95, but got %.2f", entropy))
for i := 0; i < total; i++ {
var v int
assert.Nil(t, c.GetCache(strconv.Itoa(i), &v))
assert.Equal(t, i, v)
}
for i := 0; i < total/10; i++ {
assert.Nil(t, c.DelCache(strconv.Itoa(i*10), strconv.Itoa(i*10+1), strconv.Itoa(i*10+2)))
assert.Nil(t, c.DelCache(strconv.Itoa(i*10+9)))
}
var count int
for i := 0; i < total/10; i++ {
var val int
if i%2 == 0 {
assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(v interface{}) error {
*v.(*int) = i
count++
return nil
}))
} else {
assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(v interface{}, expire time.Duration) error {
*v.(*int) = i
count++
return nil
}))
}
assert.Equal(t, i, val)
}
assert.Equal(t, total/10, count)
}
func calcEntropy(m map[int]int, total int) float64 {
var entropy float64
for _, v := range m {
proba := float64(v) / float64(total)
entropy -= proba * math.Log2(proba)
}
return entropy / math.Log2(float64(len(m)))
}

View File

@@ -0,0 +1,208 @@
package internal
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"zero/core/logx"
"zero/core/mathx"
"zero/core/stat"
"zero/core/stores/redis"
"zero/core/syncx"
)
const (
notFoundPlaceholder = "*"
// make the expiry unstable to avoid lots of cached items expire at the same time
// make the unstable expiry to be [0.95, 1.05] * seconds
expiryDeviation = 0.05
)
// indicates there is no such value associate with the key
var errPlaceholder = errors.New("placeholder")
type cacheNode struct {
rds *redis.Redis
expiry time.Duration
notFoundExpiry time.Duration
barrier syncx.SharedCalls
r *rand.Rand
lock *sync.Mutex
unstableExpiry mathx.Unstable
stat *CacheStat
errNotFound error
}
func NewCacheNode(rds *redis.Redis, barrier syncx.SharedCalls, st *CacheStat,
errNotFound error, opts ...Option) Cache {
o := newOptions(opts...)
return cacheNode{
rds: rds,
expiry: o.Expiry,
notFoundExpiry: o.NotFoundExpiry,
barrier: barrier,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: st,
errNotFound: errNotFound,
}
}
func (c cacheNode) DelCache(keys ...string) error {
if len(keys) == 0 {
return nil
}
if _, err := c.rds.Del(keys...); err != nil {
logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
c.asyncRetryDelCache(keys...)
}
return nil
}
func (c cacheNode) GetCache(key string, v interface{}) error {
if err := c.doGetCache(key, v); err == errPlaceholder {
return c.errNotFound
} else {
return err
}
}
func (c cacheNode) SetCache(key string, v interface{}) error {
return c.SetCacheWithExpire(key, v, c.aroundDuration(c.expiry))
}
func (c cacheNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
data, err := json.Marshal(v)
if err != nil {
return err
}
return c.rds.Setex(key, string(data), int(expire.Seconds()))
}
func (c cacheNode) String() string {
return c.rds.Addr
}
func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error {
return c.doTake(v, key, query, func(v interface{}) error {
return c.SetCache(key, v)
})
}
func (c cacheNode) TakeWithExpire(v interface{}, key string,
query func(v interface{}, expire time.Duration) error) error {
expire := c.aroundDuration(c.expiry)
return c.doTake(v, key, func(v interface{}) error {
return query(v, expire)
}, func(v interface{}) error {
return c.SetCacheWithExpire(key, v, expire)
})
}
func (c cacheNode) aroundDuration(duration time.Duration) time.Duration {
return c.unstableExpiry.AroundDuration(duration)
}
func (c cacheNode) asyncRetryDelCache(keys ...string) {
AddCleanTask(func() error {
_, err := c.rds.Del(keys...)
return err
}, keys...)
}
func (c cacheNode) doGetCache(key string, v interface{}) error {
c.stat.IncrementTotal()
data, err := c.rds.Get(key)
if err != nil {
c.stat.IncrementMiss()
return err
}
if len(data) == 0 {
c.stat.IncrementMiss()
return c.errNotFound
}
c.stat.IncrementHit()
if data == notFoundPlaceholder {
return errPlaceholder
}
return c.processCache(key, data, v)
}
func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) error,
cacheVal func(v interface{}) error) error {
val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) {
if err := c.doGetCache(key, v); err != nil {
if err == errPlaceholder {
return nil, c.errNotFound
} else if err != c.errNotFound {
// why we just return the error instead of query from db,
// because we don't allow the disaster pass to the dbs.
// fail fast, in case we bring down the dbs.
return nil, err
}
if err = query(v); err == c.errNotFound {
if err = c.setCacheWithNotFound(key); err != nil {
logx.Error(err)
}
return nil, c.errNotFound
} else if err != nil {
c.stat.IncrementDbFails()
return nil, err
}
if err = cacheVal(v); err != nil {
logx.Error(err)
}
}
return json.Marshal(v)
})
if err != nil {
return err
}
if fresh {
return nil
} else {
// got the result from previous ongoing query
c.stat.IncrementTotal()
c.stat.IncrementHit()
}
return json.Unmarshal(val.([]byte), v)
}
func (c cacheNode) processCache(key string, data string, v interface{}) error {
err := json.Unmarshal([]byte(data), v)
if err == nil {
return nil
}
report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v",
c.rds.Addr, key, data, err)
logx.Error(report)
stat.Report(report)
if _, e := c.rds.Del(key); e != nil {
logx.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v",
c.rds.Addr, key, data, e)
}
// returns errNotFound to reload the value by the given queryFn
return c.errNotFound
}
func (c cacheNode) setCacheWithNotFound(key string) error {
return c.rds.Setex(key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds()))
}

View File

@@ -0,0 +1,66 @@
package internal
import (
"errors"
"math/rand"
"sync"
"testing"
"time"
"zero/core/logx"
"zero/core/mathx"
"zero/core/stat"
"zero/core/stores/redis"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
)
func init() {
logx.Disable()
stat.SetReporter(nil)
}
func TestCacheNode_DelCache(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
assert.Nil(t, cn.DelCache())
assert.Nil(t, cn.DelCache([]string{}...))
assert.Nil(t, cn.DelCache(make([]string, 0)...))
cn.SetCache("first", "one")
assert.Nil(t, cn.DelCache("first"))
cn.SetCache("first", "one")
cn.SetCache("second", "two")
assert.Nil(t, cn.DelCache("first", "second"))
}
func TestCacheNode_InvalidCache(t *testing.T) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer s.Close()
cn := cacheNode{
rds: redis.NewRedis(s.Addr(), redis.NodeType),
r: rand.New(rand.NewSource(time.Now().UnixNano())),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewCacheStat("any"),
errNotFound: errors.New("any"),
}
s.Set("any", "value")
var str string
assert.NotNil(t, cn.GetCache("any", &str))
assert.Equal(t, "", str)
_, err = s.Get("any")
assert.Equal(t, miniredis.ErrKeyNotFound, err)
}

View File

@@ -0,0 +1,33 @@
package internal
import "time"
const (
defaultExpiry = time.Hour * 24 * 7
defaultNotFoundExpiry = time.Minute
)
type (
Options struct {
Expiry time.Duration
NotFoundExpiry time.Duration
}
Option func(o *Options)
)
func newOptions(opts ...Option) Options {
var o Options
for _, opt := range opts {
opt(&o)
}
if o.Expiry <= 0 {
o.Expiry = defaultExpiry
}
if o.NotFoundExpiry <= 0 {
o.NotFoundExpiry = defaultNotFoundExpiry
}
return o
}

View File

@@ -0,0 +1,67 @@
package internal
import (
"sync/atomic"
"time"
"zero/core/logx"
)
const statInterval = time.Minute
type CacheStat struct {
name string
// export the fields to let the unit tests working,
// reside in internal package, doesn't matter.
Total uint64
Hit uint64
Miss uint64
DbFails uint64
}
func NewCacheStat(name string) *CacheStat {
ret := &CacheStat{
name: name,
}
go ret.statLoop()
return ret
}
func (cs *CacheStat) IncrementTotal() {
atomic.AddUint64(&cs.Total, 1)
}
func (cs *CacheStat) IncrementHit() {
atomic.AddUint64(&cs.Hit, 1)
}
func (cs *CacheStat) IncrementMiss() {
atomic.AddUint64(&cs.Miss, 1)
}
func (cs *CacheStat) IncrementDbFails() {
atomic.AddUint64(&cs.DbFails, 1)
}
func (cs *CacheStat) statLoop() {
ticker := time.NewTicker(statInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
total := atomic.SwapUint64(&cs.Total, 0)
if total == 0 {
continue
}
hit := atomic.SwapUint64(&cs.Hit, 0)
percent := 100 * float32(hit) / float32(total)
miss := atomic.SwapUint64(&cs.Miss, 0)
dbf := atomic.SwapUint64(&cs.DbFails, 0)
logx.Statf("dbcache(%s) - qpm: %d, hit_ratio: %.1f%%, hit: %d, miss: %d, db_fails: %d",
cs.name, total, percent, hit, miss, dbf)
}
}
}

View File

@@ -0,0 +1,85 @@
package internal
import (
"fmt"
"time"
"zero/core/collection"
"zero/core/lang"
"zero/core/logx"
"zero/core/proc"
"zero/core/stat"
"zero/core/stringx"
"zero/core/threading"
)
const (
timingWheelSlots = 300
cleanWorkers = 5
taskKeyLen = 8
)
var (
timingWheel *collection.TimingWheel
taskRunner = threading.NewTaskRunner(cleanWorkers)
)
type delayTask struct {
delay time.Duration
task func() error
keys []string
}
func init() {
var err error
timingWheel, err = collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
lang.Must(err)
proc.AddShutdownListener(func() {
timingWheel.Drain(clean)
})
}
func AddCleanTask(task func() error, keys ...string) {
timingWheel.SetTimer(stringx.Randn(taskKeyLen), delayTask{
delay: time.Second,
task: task,
keys: keys,
}, time.Second)
}
func clean(key, value interface{}) {
taskRunner.Schedule(func() {
dt := value.(delayTask)
err := dt.task()
if err == nil {
return
}
next, ok := nextDelay(dt.delay)
if ok {
dt.delay = next
timingWheel.SetTimer(key, dt, next)
} else {
msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v",
formatKeys(dt.keys), err)
logx.Error(msg)
stat.Report(msg)
}
})
}
func nextDelay(delay time.Duration) (time.Duration, bool) {
switch delay {
case time.Second:
return time.Second * 5, true
case time.Second * 5:
return time.Minute, true
case time.Minute:
return time.Minute * 5, true
case time.Minute * 5:
return time.Hour, true
default:
return 0, false
}
}

View File

@@ -0,0 +1,12 @@
package internal
import "zero/core/stores/redis"
type (
ClusterConf []NodeConf
NodeConf struct {
redis.RedisConf
Weight int `json:",default=100"`
}
)

View File

@@ -0,0 +1,22 @@
package internal
import "strings"
const keySeparator = ","
func TotalWeights(c []NodeConf) int {
var weights int
for _, node := range c {
if node.Weight < 0 {
node.Weight = 0
}
weights += node.Weight
}
return weights
}
func formatKeys(keys []string) string {
return strings.Join(keys, keySeparator)
}

5
core/stores/kv/config.go Normal file
View File

@@ -0,0 +1,5 @@
package kv
import "zero/core/stores/internal"
type KvConf = internal.ClusterConf

653
core/stores/kv/store.go Normal file
View File

@@ -0,0 +1,653 @@
package kv
import (
"errors"
"log"
"zero/core/errorx"
"zero/core/hash"
"zero/core/stores/internal"
"zero/core/stores/redis"
)
var ErrNoRedisNode = errors.New("no redis node")
type (
Store interface {
Del(keys ...string) (int, error)
Eval(script string, key string, args ...interface{}) (interface{}, error)
Exists(key string) (bool, error)
Expire(key string, seconds int) error
Expireat(key string, expireTime int64) error
Get(key string) (string, error)
Hdel(key, field string) (bool, error)
Hexists(key, field string) (bool, error)
Hget(key, field string) (string, error)
Hgetall(key string) (map[string]string, error)
Hincrby(key, field string, increment int) (int, error)
Hkeys(key string) ([]string, error)
Hlen(key string) (int, error)
Hmget(key string, fields ...string) ([]string, error)
Hset(key, field, value string) error
Hsetnx(key, field, value string) (bool, error)
Hmset(key string, fieldsAndValues map[string]string) error
Hvals(key string) ([]string, error)
Incr(key string) (int64, error)
Incrby(key string, increment int64) (int64, error)
Llen(key string) (int, error)
Lpop(key string) (string, error)
Lpush(key string, values ...interface{}) (int, error)
Lrange(key string, start int, stop int) ([]string, error)
Lrem(key string, count int, value string) (int, error)
Persist(key string) (bool, error)
Pfadd(key string, values ...interface{}) (bool, error)
Pfcount(key string) (int64, error)
Rpush(key string, values ...interface{}) (int, error)
Sadd(key string, values ...interface{}) (int, error)
Scard(key string) (int64, error)
Set(key string, value string) error
Setex(key, value string, seconds int) error
Setnx(key, value string) (bool, error)
SetnxEx(key, value string, seconds int) (bool, error)
Sismember(key string, value interface{}) (bool, error)
Smembers(key string) ([]string, error)
Spop(key string) (string, error)
Srandmember(key string, count int) ([]string, error)
Srem(key string, values ...interface{}) (int, error)
Sscan(key string, cursor uint64, match string, count int64) (keys []string, cur uint64, err error)
Ttl(key string) (int, error)
Zadd(key string, score int64, value string) (bool, error)
Zadds(key string, ps ...redis.Pair) (int64, error)
Zcard(key string) (int, error)
Zcount(key string, start, stop int64) (int, error)
Zincrby(key string, increment int64, field string) (int64, error)
Zrange(key string, start, stop int64) ([]string, error)
ZrangeWithScores(key string, start, stop int64) ([]redis.Pair, error)
ZrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error)
ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error)
Zrank(key, field string) (int64, error)
Zrem(key string, values ...interface{}) (int, error)
Zremrangebyrank(key string, start, stop int64) (int, error)
Zremrangebyscore(key string, start, stop int64) (int, error)
Zrevrange(key string, start, stop int64) ([]string, error)
ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error)
ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error)
Zscore(key string, value string) (int64, error)
}
clusterStore struct {
dispatcher *hash.ConsistentHash
}
)
func NewStore(c KvConf) Store {
if len(c) == 0 || internal.TotalWeights(c) <= 0 {
log.Fatal("no cache nodes")
}
// even if only one node, we chose to use consistent hash,
// because Store and redis.Redis has different methods.
dispatcher := hash.NewConsistentHash()
for _, node := range c {
cn := node.NewRedis()
dispatcher.AddWithWeight(cn, node.Weight)
}
return clusterStore{
dispatcher: dispatcher,
}
}
func (cs clusterStore) Del(keys ...string) (int, error) {
var val int
var be errorx.BatchError
for _, key := range keys {
node, e := cs.getRedis(key)
if e != nil {
be.Add(e)
continue
}
if v, e := node.Del(key); e != nil {
be.Add(e)
} else {
val += v
}
}
return val, be.Err()
}
func (cs clusterStore) Eval(script string, key string, args ...interface{}) (interface{}, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Eval(script, []string{key}, args...)
}
func (cs clusterStore) Exists(key string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Exists(key)
}
func (cs clusterStore) Expire(key string, seconds int) error {
node, err := cs.getRedis(key)
if err != nil {
return err
}
return node.Expire(key, seconds)
}
func (cs clusterStore) Expireat(key string, expireTime int64) error {
node, err := cs.getRedis(key)
if err != nil {
return err
}
return node.Expireat(key, expireTime)
}
func (cs clusterStore) Get(key string) (string, error) {
node, err := cs.getRedis(key)
if err != nil {
return "", err
}
return node.Get(key)
}
func (cs clusterStore) Hdel(key, field string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Hdel(key, field)
}
func (cs clusterStore) Hexists(key, field string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Hexists(key, field)
}
func (cs clusterStore) Hget(key, field string) (string, error) {
node, err := cs.getRedis(key)
if err != nil {
return "", err
}
return node.Hget(key, field)
}
func (cs clusterStore) Hgetall(key string) (map[string]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Hgetall(key)
}
func (cs clusterStore) Hincrby(key, field string, increment int) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Hincrby(key, field, increment)
}
func (cs clusterStore) Hkeys(key string) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Hkeys(key)
}
func (cs clusterStore) Hlen(key string) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Hlen(key)
}
func (cs clusterStore) Hmget(key string, fields ...string) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Hmget(key, fields...)
}
func (cs clusterStore) Hset(key, field, value string) error {
node, err := cs.getRedis(key)
if err != nil {
return err
}
return node.Hset(key, field, value)
}
func (cs clusterStore) Hsetnx(key, field, value string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Hsetnx(key, field, value)
}
func (cs clusterStore) Hmset(key string, fieldsAndValues map[string]string) error {
node, err := cs.getRedis(key)
if err != nil {
return err
}
return node.Hmset(key, fieldsAndValues)
}
func (cs clusterStore) Hvals(key string) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Hvals(key)
}
func (cs clusterStore) Incr(key string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Incr(key)
}
func (cs clusterStore) Incrby(key string, increment int64) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Incrby(key, increment)
}
func (cs clusterStore) Llen(key string) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Llen(key)
}
func (cs clusterStore) Lpop(key string) (string, error) {
node, err := cs.getRedis(key)
if err != nil {
return "", err
}
return node.Lpop(key)
}
func (cs clusterStore) Lpush(key string, values ...interface{}) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Lpush(key, values...)
}
func (cs clusterStore) Lrange(key string, start int, stop int) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Lrange(key, start, stop)
}
func (cs clusterStore) Lrem(key string, count int, value string) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Lrem(key, count, value)
}
func (cs clusterStore) Persist(key string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Persist(key)
}
func (cs clusterStore) Pfadd(key string, values ...interface{}) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Pfadd(key, values...)
}
func (cs clusterStore) Pfcount(key string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Pfcount(key)
}
func (cs clusterStore) Rpush(key string, values ...interface{}) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Rpush(key, values...)
}
func (cs clusterStore) Sadd(key string, values ...interface{}) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Sadd(key, values...)
}
func (cs clusterStore) Scard(key string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Scard(key)
}
func (cs clusterStore) Set(key string, value string) error {
node, err := cs.getRedis(key)
if err != nil {
return err
}
return node.Set(key, value)
}
func (cs clusterStore) Setex(key, value string, seconds int) error {
node, err := cs.getRedis(key)
if err != nil {
return err
}
return node.Setex(key, value, seconds)
}
func (cs clusterStore) Setnx(key, value string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Setnx(key, value)
}
func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.SetnxEx(key, value, seconds)
}
func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Sismember(key, value)
}
func (cs clusterStore) Smembers(key string) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Smembers(key)
}
func (cs clusterStore) Spop(key string) (string, error) {
node, err := cs.getRedis(key)
if err != nil {
return "", err
}
return node.Spop(key)
}
func (cs clusterStore) Srandmember(key string, count int) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Srandmember(key, count)
}
func (cs clusterStore) Srem(key string, values ...interface{}) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Srem(key, values...)
}
func (cs clusterStore) Sscan(key string, cursor uint64, match string, count int64) (
keys []string, cur uint64, err error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, 0, err
}
return node.Sscan(key, cursor, match, count)
}
func (cs clusterStore) Ttl(key string) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Ttl(key)
}
func (cs clusterStore) Zadd(key string, score int64, value string) (bool, error) {
node, err := cs.getRedis(key)
if err != nil {
return false, err
}
return node.Zadd(key, score, value)
}
func (cs clusterStore) Zadds(key string, ps ...redis.Pair) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zadds(key, ps...)
}
func (cs clusterStore) Zcard(key string) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zcard(key)
}
func (cs clusterStore) Zcount(key string, start, stop int64) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zcount(key, start, stop)
}
func (cs clusterStore) Zincrby(key string, increment int64, field string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zincrby(key, increment, field)
}
func (cs clusterStore) Zrank(key, field string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zrank(key, field)
}
func (cs clusterStore) Zrange(key string, start, stop int64) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Zrange(key, start, stop)
}
func (cs clusterStore) ZrangeWithScores(key string, start, stop int64) ([]redis.Pair, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.ZrangeWithScores(key, start, stop)
}
func (cs clusterStore) ZrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.ZrangebyscoreWithScores(key, start, stop)
}
func (cs clusterStore) ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) (
[]redis.Pair, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.ZrangebyscoreWithScoresAndLimit(key, start, stop, page, size)
}
func (cs clusterStore) Zrem(key string, values ...interface{}) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zrem(key, values...)
}
func (cs clusterStore) Zremrangebyrank(key string, start, stop int64) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zremrangebyrank(key, start, stop)
}
func (cs clusterStore) Zremrangebyscore(key string, start, stop int64) (int, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zremrangebyscore(key, start, stop)
}
func (cs clusterStore) Zrevrange(key string, start, stop int64) ([]string, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.Zrevrange(key, start, stop)
}
func (cs clusterStore) ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.ZrevrangebyscoreWithScores(key, start, stop)
}
func (cs clusterStore) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) (
[]redis.Pair, error) {
node, err := cs.getRedis(key)
if err != nil {
return nil, err
}
return node.ZrevrangebyscoreWithScoresAndLimit(key, start, stop, page, size)
}
func (cs clusterStore) Zscore(key string, value string) (int64, error) {
node, err := cs.getRedis(key)
if err != nil {
return 0, err
}
return node.Zscore(key, value)
}
func (cs clusterStore) getRedis(key string) (*redis.Redis, error) {
if val, ok := cs.dispatcher.Get(key); !ok {
return nil, ErrNoRedisNode
} else {
return val.(*redis.Redis), nil
}
}

View File

@@ -0,0 +1,498 @@
package kv
import (
"testing"
"time"
"zero/core/stores/internal"
"zero/core/stores/redis"
"zero/core/stringx"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
)
var s1, _ = miniredis.Run()
var s2, _ = miniredis.Run()
func TestRedis_Exists(t *testing.T) {
runOnCluster(t, func(client Store) {
ok, err := client.Exists("a")
assert.Nil(t, err)
assert.False(t, ok)
assert.Nil(t, client.Set("a", "b"))
ok, err = client.Exists("a")
assert.Nil(t, err)
assert.True(t, ok)
})
}
func TestRedis_Eval(t *testing.T) {
runOnCluster(t, func(client Store) {
_, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, "notexist")
assert.Equal(t, redis.Nil, err)
err = client.Set("key1", "value1")
assert.Nil(t, err)
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, "key1")
assert.Equal(t, redis.Nil, err)
val, err := client.Eval(`return redis.call("EXISTS", KEYS[1])`, "key1")
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
})
}
func TestRedis_Hgetall(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hgetall("a")
assert.Nil(t, err)
assert.EqualValues(t, map[string]string{
"aa": "aaa",
"bb": "bbb",
}, vals)
})
}
func TestRedis_Hvals(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hvals("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals)
})
}
func TestRedis_Hsetnx(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
ok, err := client.Hsetnx("a", "bb", "ccc")
assert.Nil(t, err)
assert.False(t, ok)
ok, err = client.Hsetnx("a", "dd", "ddd")
assert.Nil(t, err)
assert.True(t, ok)
vals, err := client.Hvals("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aaa", "bbb", "ddd"}, vals)
})
}
func TestRedis_HdelHlen(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
num, err := client.Hlen("a")
assert.Nil(t, err)
assert.Equal(t, 2, num)
val, err := client.Hdel("a", "aa")
assert.Nil(t, err)
assert.True(t, val)
vals, err := client.Hvals("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"bbb"}, vals)
})
}
func TestRedis_HIncrBy(t *testing.T) {
runOnCluster(t, func(client Store) {
val, err := client.Hincrby("key", "field", 2)
assert.Nil(t, err)
assert.Equal(t, 2, val)
val, err = client.Hincrby("key", "field", 3)
assert.Nil(t, err)
assert.Equal(t, 5, val)
})
}
func TestRedis_Hkeys(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hkeys("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aa", "bb"}, vals)
})
}
func TestRedis_Hmget(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hmget("a", "aa", "bb")
assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
vals, err = client.Hmget("a", "aa", "no", "bb")
assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "", "bbb"}, vals)
})
}
func TestRedis_Hmset(t *testing.T) {
runOnCluster(t, func(client Store) {
assert.Nil(t, client.Hmset("a", map[string]string{
"aa": "aaa",
"bb": "bbb",
}))
vals, err := client.Hmget("a", "aa", "bb")
assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
})
}
func TestRedis_Incr(t *testing.T) {
runOnCluster(t, func(client Store) {
val, err := client.Incr("a")
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
val, err = client.Incr("a")
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
})
}
func TestRedis_IncrBy(t *testing.T) {
runOnCluster(t, func(client Store) {
val, err := client.Incrby("a", 2)
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
val, err = client.Incrby("a", 3)
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
})
}
func TestRedis_List(t *testing.T) {
runOnCluster(t, func(client Store) {
val, err := client.Lpush("key", "value1", "value2")
assert.Nil(t, err)
assert.Equal(t, 2, val)
val, err = client.Rpush("key", "value3", "value4")
assert.Nil(t, err)
assert.Equal(t, 4, val)
val, err = client.Llen("key")
assert.Nil(t, err)
assert.Equal(t, 4, val)
vals, err := client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
v, err := client.Lpop("key")
assert.Nil(t, err)
assert.Equal(t, "value2", v)
val, err = client.Lpush("key", "value1", "value2")
assert.Nil(t, err)
assert.Equal(t, 5, val)
val, err = client.Rpush("key", "value3", "value3")
assert.Nil(t, err)
assert.Equal(t, 7, val)
n, err := client.Lrem("key", 2, "value1")
assert.Nil(t, err)
assert.Equal(t, 2, n)
vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals)
n, err = client.Lrem("key", -2, "value3")
assert.Nil(t, err)
assert.Equal(t, 2, n)
vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals)
})
}
func TestRedis_Persist(t *testing.T) {
runOnCluster(t, func(client Store) {
ok, err := client.Persist("key")
assert.Nil(t, err)
assert.False(t, ok)
err = client.Set("key", "value")
assert.Nil(t, err)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.False(t, ok)
err = client.Expire("key", 5)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.True(t, ok)
err = client.Expireat("key", time.Now().Unix()+5)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.True(t, ok)
})
}
func TestRedis_Sscan(t *testing.T) {
runOnCluster(t, func(client Store) {
key := "list"
var list []string
for i := 0; i < 1550; i++ {
list = append(list, stringx.Randn(i))
}
lens, err := client.Sadd(key, list)
assert.Nil(t, err)
assert.Equal(t, lens, 1550)
var cursor uint64 = 0
sum := 0
for {
keys, next, err := client.Sscan(key, cursor, "", 100)
assert.Nil(t, err)
sum += len(keys)
if next == 0 {
break
}
cursor = next
}
assert.Equal(t, sum, 1550)
_, err = client.Del(key)
assert.Nil(t, err)
})
}
func TestRedis_Set(t *testing.T) {
runOnCluster(t, func(client Store) {
num, err := client.Sadd("key", 1, 2, 3, 4)
assert.Nil(t, err)
assert.Equal(t, 4, num)
val, err := client.Scard("key")
assert.Nil(t, err)
assert.Equal(t, int64(4), val)
ok, err := client.Sismember("key", 2)
assert.Nil(t, err)
assert.True(t, ok)
num, err = client.Srem("key", 3, 4)
assert.Nil(t, err)
assert.Equal(t, 2, num)
vals, err := client.Smembers("key")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2"}, vals)
members, err := client.Srandmember("key", 1)
assert.Nil(t, err)
assert.Len(t, members, 1)
assert.Contains(t, []string{"1", "2"}, members[0])
member, err := client.Spop("key")
assert.Nil(t, err)
assert.Contains(t, []string{"1", "2"}, member)
vals, err = client.Smembers("key")
assert.Nil(t, err)
assert.NotContains(t, vals, member)
num, err = client.Sadd("key1", 1, 2, 3, 4)
assert.Nil(t, err)
assert.Equal(t, 4, num)
num, err = client.Sadd("key2", 2, 3, 4, 5)
assert.Nil(t, err)
assert.Equal(t, 4, num)
})
}
func TestRedis_SetGetDel(t *testing.T) {
runOnCluster(t, func(client Store) {
err := client.Set("hello", "world")
assert.Nil(t, err)
val, err := client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
ret, err := client.Del("hello")
assert.Nil(t, err)
assert.Equal(t, 1, ret)
})
}
func TestRedis_SetExNx(t *testing.T) {
runOnCluster(t, func(client Store) {
err := client.Setex("hello", "world", 5)
assert.Nil(t, err)
ok, err := client.Setnx("hello", "newworld")
assert.Nil(t, err)
assert.False(t, ok)
ok, err = client.Setnx("newhello", "newworld")
assert.Nil(t, err)
assert.True(t, ok)
val, err := client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("newhello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
ttl, err := client.Ttl("hello")
assert.Nil(t, err)
assert.True(t, ttl > 0)
ok, err = client.SetnxEx("newhello", "newworld", 5)
assert.Nil(t, err)
assert.False(t, ok)
num, err := client.Del("newhello")
assert.Nil(t, err)
assert.Equal(t, 1, num)
ok, err = client.SetnxEx("newhello", "newworld", 5)
assert.Nil(t, err)
assert.True(t, ok)
val, err = client.Get("newhello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
})
}
func TestRedis_SetGetDelHashField(t *testing.T) {
runOnCluster(t, func(client Store) {
err := client.Hset("key", "field", "value")
assert.Nil(t, err)
val, err := client.Hget("key", "field")
assert.Nil(t, err)
assert.Equal(t, "value", val)
ok, err := client.Hexists("key", "field")
assert.Nil(t, err)
assert.True(t, ok)
ret, err := client.Hdel("key", "field")
assert.Nil(t, err)
assert.True(t, ret)
ok, err = client.Hexists("key", "field")
assert.Nil(t, err)
assert.False(t, ok)
})
}
func TestRedis_SortedSet(t *testing.T) {
runOnCluster(t, func(client Store) {
ok, err := client.Zadd("key", 1, "value1")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 2, "value1")
assert.Nil(t, err)
assert.False(t, ok)
val, err := client.Zscore("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
val, err = client.Zincrby("key", 3, "value1")
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
val, err = client.Zscore("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
rank, err := client.Zrank("key", "value2")
assert.Nil(t, err)
assert.Equal(t, int64(1), rank)
rank, err = client.Zrank("key", "value4")
assert.Equal(t, redis.Nil, err)
num, err := client.Zrem("key", "value2", "value3")
assert.Nil(t, err)
assert.Equal(t, 2, num)
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 8, "value4")
assert.Nil(t, err)
assert.True(t, ok)
num, err = client.Zremrangebyscore("key", 6, 7)
assert.Nil(t, err)
assert.Equal(t, 2, num)
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
num, err = client.Zcount("key", 6, 7)
assert.Nil(t, err)
assert.Equal(t, 2, num)
num, err = client.Zremrangebyrank("key", 1, 2)
assert.Nil(t, err)
assert.Equal(t, 2, num)
card, err := client.Zcard("key")
assert.Nil(t, err)
assert.Equal(t, 2, card)
vals, err := client.Zrange("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value1", "value4"}, vals)
vals, err = client.Zrevrange("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value4", "value1"}, vals)
pairs, err := client.ZrangeWithScores("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []redis.Pair{
{
Key: "value1",
Score: 5,
},
{
Key: "value4",
Score: 8,
},
}, pairs)
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err)
assert.EqualValues(t, []redis.Pair{
{
Key: "value1",
Score: 5,
},
{
Key: "value4",
Score: 8,
},
}, pairs)
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.Nil(t, err)
assert.EqualValues(t, []redis.Pair{
{
Key: "value4",
Score: 8,
},
}, pairs)
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err)
assert.EqualValues(t, []redis.Pair{
{
Key: "value4",
Score: 8,
},
{
Key: "value1",
Score: 5,
},
}, pairs)
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.Nil(t, err)
assert.EqualValues(t, []redis.Pair{
{
Key: "value1",
Score: 5,
},
}, pairs)
})
}
func runOnCluster(t *testing.T, fn func(cluster Store)) {
s1.FlushAll()
s2.FlushAll()
store := NewStore([]internal.NodeConf{
{
RedisConf: redis.RedisConf{
Host: s1.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
{
RedisConf: redis.RedisConf{
Host: s2.Addr(),
Type: redis.NodeType,
},
Weight: 100,
},
})
fn(store)
}

View File

@@ -0,0 +1,87 @@
package mongo
import (
"time"
"zero/core/executors"
"zero/core/logx"
"github.com/globalsign/mgo"
)
const (
flushInterval = time.Second
maxBulkRows = 1000
)
type (
ResultHandler func(*mgo.BulkResult, error)
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter
}
)
func NewBulkInserter(session *mgo.Session, dbName string, collectionNamer func() string) *BulkInserter {
inserter := &dbInserter{
session: session,
dbName: dbName,
collectionNamer: collectionNamer,
}
return &BulkInserter{
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
inserter: inserter,
}
}
func (bi *BulkInserter) Flush() {
bi.executor.Flush()
}
func (bi *BulkInserter) Insert(doc interface{}) {
bi.executor.Add(doc)
}
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
bi.executor.Sync(func() {
bi.inserter.resultHandler = handler
})
}
type dbInserter struct {
session *mgo.Session
dbName string
collectionNamer func() string
documents []interface{}
resultHandler ResultHandler
}
func (in *dbInserter) AddTask(doc interface{}) bool {
in.documents = append(in.documents, doc)
return len(in.documents) >= maxBulkRows
}
func (in *dbInserter) Execute(objs interface{}) {
docs := objs.([]interface{})
if len(docs) == 0 {
return
}
bulk := in.session.DB(in.dbName).C(in.collectionNamer()).Bulk()
bulk.Insert(docs...)
bulk.Unordered()
result, err := bulk.Run()
if in.resultHandler != nil {
in.resultHandler(result, err)
} else if err != nil {
logx.Error(err)
}
}
func (in *dbInserter) RemoveAll() interface{} {
documents := in.documents
in.documents = nil
return documents
}

View File

@@ -0,0 +1,238 @@
package mongo
import (
"encoding/json"
"time"
"zero/core/breaker"
"zero/core/logx"
"zero/core/timex"
"github.com/globalsign/mgo"
)
const slowThreshold = time.Millisecond * 500
var ErrNotFound = mgo.ErrNotFound
type (
Collection interface {
Find(query interface{}) Query
FindId(id interface{}) Query
Insert(docs ...interface{}) error
Pipe(pipeline interface{}) Pipe
Remove(selector interface{}) error
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
RemoveId(id interface{}) error
Update(selector, update interface{}) error
UpdateId(id, update interface{}) error
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
}
decoratedCollection struct {
*mgo.Collection
brk breaker.Breaker
}
keepablePromise struct {
promise breaker.Promise
log func(error)
}
)
func newCollection(collection *mgo.Collection) Collection {
return &decoratedCollection{
Collection: collection,
brk: breaker.NewBreaker(),
}
}
func (c *decoratedCollection) Find(query interface{}) Query {
promise, err := c.brk.Allow()
if err != nil {
return rejectedQuery{}
}
startTime := timex.Now()
return promisedQuery{
Query: c.Collection.Find(query),
promise: keepablePromise{
promise: promise,
log: func(err error) {
duration := timex.Since(startTime)
c.logDuration("find", duration, err, query)
},
},
}
}
func (c *decoratedCollection) FindId(id interface{}) Query {
promise, err := c.brk.Allow()
if err != nil {
return rejectedQuery{}
}
startTime := timex.Now()
return promisedQuery{
Query: c.Collection.FindId(id),
promise: keepablePromise{
promise: promise,
log: func(err error) {
duration := timex.Since(startTime)
c.logDuration("findId", duration, err, id)
},
},
}
}
func (c *decoratedCollection) Insert(docs ...interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("insert", duration, err, docs...)
}()
return c.Collection.Insert(docs...)
}, acceptable)
}
func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe {
promise, err := c.brk.Allow()
if err != nil {
return rejectedPipe{}
}
startTime := timex.Now()
return promisedPipe{
Pipe: c.Collection.Pipe(pipeline),
promise: keepablePromise{
promise: promise,
log: func(err error) {
duration := timex.Since(startTime)
c.logDuration("pipe", duration, err, pipeline)
},
},
}
}
func (c *decoratedCollection) Remove(selector interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("remove", duration, err, selector)
}()
return c.Collection.Remove(selector)
}, acceptable)
}
func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeInfo, err error) {
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("removeAll", duration, err, selector)
}()
info, err = c.Collection.RemoveAll(selector)
return err
}, acceptable)
return
}
func (c *decoratedCollection) RemoveId(id interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("removeId", duration, err, id)
}()
return c.Collection.RemoveId(id)
}, acceptable)
}
func (c *decoratedCollection) Update(selector, update interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("update", duration, err, selector, update)
}()
return c.Collection.Update(selector, update)
}, acceptable)
}
func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) {
return c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("updateId", duration, err, id, update)
}()
return c.Collection.UpdateId(id, update)
}, acceptable)
}
func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.ChangeInfo, err error) {
err = c.brk.DoWithAcceptable(func() error {
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
c.logDuration("upsert", duration, err, selector, update)
}()
info, err = c.Collection.Upsert(selector, update)
return err
}, acceptable)
return
}
func (c *decoratedCollection) logDuration(method string, duration time.Duration, err error, docs ...interface{}) {
content, e := json.Marshal(docs)
if e != nil {
logx.Error(err)
} else if err != nil {
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
c.FullName, method, err.Error(), string(content))
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
c.FullName, method, err.Error(), string(content))
}
} else {
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
c.FullName, method, string(content))
} else {
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.FullName, method, string(content))
}
}
}
func (p keepablePromise) accept(err error) error {
p.promise.Accept()
p.log(err)
return err
}
func (p keepablePromise) keep(err error) error {
if acceptable(err) {
p.promise.Accept()
} else {
p.promise.Reject(err.Error())
}
p.log(err)
return err
}
func acceptable(err error) bool {
return err == nil || err == mgo.ErrNotFound
}

View File

@@ -0,0 +1,71 @@
package mongo
import (
"errors"
"testing"
"github.com/globalsign/mgo"
"github.com/stretchr/testify/assert"
"zero/core/stringx"
)
func TestKeepPromise_accept(t *testing.T) {
p := new(mockPromise)
kp := keepablePromise{
promise: p,
log: func(error) {},
}
assert.Nil(t, kp.accept(nil))
assert.Equal(t, mgo.ErrNotFound, kp.accept(mgo.ErrNotFound))
}
func TestKeepPromise_keep(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: mgo.ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
p := new(mockPromise)
kp := keepablePromise{
promise: p,
log: func(error) {},
}
assert.Equal(t, test.err, kp.keep(test.err))
assert.Equal(t, test.accepted, p.accepted)
assert.Equal(t, test.reason, p.reason)
})
}
}
type mockPromise struct {
accepted bool
reason string
}
func (p *mockPromise) Accept() {
p.accepted = true
}
func (p *mockPromise) Reject(reason string) {
p.reason = reason
}

96
core/stores/mongo/iter.go Normal file
View File

@@ -0,0 +1,96 @@
//go:generate mockgen -package mongo -destination iter_mock.go -source iter.go Iter
package mongo
import (
"zero/core/breaker"
"github.com/globalsign/mgo/bson"
)
type (
Iter interface {
All(result interface{}) error
Close() error
Done() bool
Err() error
For(result interface{}, f func() error) error
Next(result interface{}) bool
State() (int64, []bson.Raw)
Timeout() bool
}
ClosableIter struct {
Iter
Cleanup func()
}
promisedIter struct {
Iter
promise keepablePromise
}
rejectedIter struct{}
)
func (i promisedIter) All(result interface{}) error {
return i.promise.keep(i.Iter.All(result))
}
func (i promisedIter) Close() error {
return i.promise.keep(i.Iter.Close())
}
func (i promisedIter) Err() error {
return i.Iter.Err()
}
func (i promisedIter) For(result interface{}, f func() error) error {
var ferr error
err := i.Iter.For(result, func() error {
ferr = f()
return ferr
})
if ferr == err {
return i.promise.accept(err)
}
return i.promise.keep(err)
}
func (it *ClosableIter) Close() error {
err := it.Iter.Close()
it.Cleanup()
return err
}
func (i rejectedIter) All(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) Close() error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) Done() bool {
return false
}
func (i rejectedIter) Err() error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) For(result interface{}, f func() error) error {
return breaker.ErrServiceUnavailable
}
func (i rejectedIter) Next(result interface{}) bool {
return false
}
func (i rejectedIter) State() (int64, []bson.Raw) {
return 0, nil
}
func (i rejectedIter) Timeout() bool {
return false
}

View File

@@ -0,0 +1,147 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: iter.go
// Package mongo is a generated GoMock package.
package mongo
import (
bson "github.com/globalsign/mgo/bson"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockIter is a mock of Iter interface
type MockIter struct {
ctrl *gomock.Controller
recorder *MockIterMockRecorder
}
// MockIterMockRecorder is the mock recorder for MockIter
type MockIterMockRecorder struct {
mock *MockIter
}
// NewMockIter creates a new mock instance
func NewMockIter(ctrl *gomock.Controller) *MockIter {
mock := &MockIter{ctrl: ctrl}
mock.recorder = &MockIterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockIter) EXPECT() *MockIterMockRecorder {
return m.recorder
}
// All mocks base method
func (m *MockIter) All(result interface{}) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "All", result)
ret0, _ := ret[0].(error)
return ret0
}
// All indicates an expected call of All
func (mr *MockIterMockRecorder) All(result interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "All", reflect.TypeOf((*MockIter)(nil).All), result)
}
// Close mocks base method
func (m *MockIter) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockIterMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIter)(nil).Close))
}
// Done mocks base method
func (m *MockIter) Done() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Done")
ret0, _ := ret[0].(bool)
return ret0
}
// Done indicates an expected call of Done
func (mr *MockIterMockRecorder) Done() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockIter)(nil).Done))
}
// Err mocks base method
func (m *MockIter) Err() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Err")
ret0, _ := ret[0].(error)
return ret0
}
// Err indicates an expected call of Err
func (mr *MockIterMockRecorder) Err() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockIter)(nil).Err))
}
// For mocks base method
func (m *MockIter) For(result interface{}, f func() error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "For", result, f)
ret0, _ := ret[0].(error)
return ret0
}
// For indicates an expected call of For
func (mr *MockIterMockRecorder) For(result, f interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockIter)(nil).For), result, f)
}
// Next mocks base method
func (m *MockIter) Next(result interface{}) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Next", result)
ret0, _ := ret[0].(bool)
return ret0
}
// Next indicates an expected call of Next
func (mr *MockIterMockRecorder) Next(result interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIter)(nil).Next), result)
}
// State mocks base method
func (m *MockIter) State() (int64, []bson.Raw) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "State")
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].([]bson.Raw)
return ret0, ret1
}
// State indicates an expected call of State
func (mr *MockIterMockRecorder) State() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockIter)(nil).State))
}
// Timeout mocks base method
func (m *MockIter) Timeout() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Timeout")
ret0, _ := ret[0].(bool)
return ret0
}
// Timeout indicates an expected call of Timeout
func (mr *MockIterMockRecorder) Timeout() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockIter)(nil).Timeout))
}

View File

@@ -0,0 +1,265 @@
package mongo
import (
"errors"
"testing"
"zero/core/breaker"
"zero/core/stringx"
"zero/core/syncx"
"github.com/globalsign/mgo"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)
func TestClosableIter_Close(t *testing.T) {
errs := []error{
nil,
mgo.ErrNotFound,
}
for _, err := range errs {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
cleaned := syncx.NewAtomicBool()
iter := NewMockIter(ctrl)
iter.EXPECT().Close().Return(err)
ci := ClosableIter{
Iter: iter,
Cleanup: func() {
cleaned.Set(true)
},
}
assert.Equal(t, err, ci.Close())
assert.True(t, cleaned.True())
})
}
}
func TestPromisedIter_AllAndClose(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: mgo.ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().All(gomock.Any()).Return(test.err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, test.err, pi.All(nil))
assert.Equal(t, test.accepted, promise.accepted)
assert.Equal(t, test.reason, promise.reason)
})
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Close().Return(test.err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, test.err, pi.Close())
assert.Equal(t, test.accepted, promise.accepted)
assert.Equal(t, test.reason, promise.reason)
})
}
}
func TestPromisedIter_Err(t *testing.T) {
errs := []error{
nil,
mgo.ErrNotFound,
}
for _, err := range errs {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Err().Return(err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, err, pi.Err())
})
}
}
func TestPromisedIter_For(t *testing.T) {
tests := []struct {
err error
accepted bool
reason string
}{
{
err: nil,
accepted: true,
reason: "",
},
{
err: mgo.ErrNotFound,
accepted: true,
reason: "",
},
{
err: errors.New("any"),
accepted: false,
reason: "any",
},
}
for _, test := range tests {
t.Run(stringx.RandId(), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().For(gomock.Any(), gomock.Any()).Return(test.err)
promise := new(mockPromise)
pi := promisedIter{
Iter: iter,
promise: keepablePromise{
promise: promise,
log: func(error) {},
},
}
assert.Equal(t, test.err, pi.For(nil, nil))
assert.Equal(t, test.accepted, promise.accepted)
assert.Equal(t, test.reason, promise.reason)
})
}
}
func TestRejectedIter_All(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).All(nil))
}
func TestRejectedIter_Close(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Close())
}
func TestRejectedIter_Done(t *testing.T) {
assert.False(t, new(rejectedIter).Done())
}
func TestRejectedIter_Err(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Err())
}
func TestRejectedIter_For(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).For(nil, nil))
}
func TestRejectedIter_Next(t *testing.T) {
assert.False(t, new(rejectedIter).Next(nil))
}
func TestRejectedIter_State(t *testing.T) {
n, raw := new(rejectedIter).State()
assert.Equal(t, int64(0), n)
assert.Nil(t, raw)
}
func TestRejectedIter_Timeout(t *testing.T) {
assert.False(t, new(rejectedIter).Timeout())
}
func TestIter_Done(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Done().Return(true)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
assert.True(t, ci.Done())
}
func TestIter_Next(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Next(gomock.Any()).Return(true)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
assert.True(t, ci.Next(nil))
}
func TestIter_State(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().State().Return(int64(1), nil)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
n, raw := ci.State()
assert.Equal(t, int64(1), n)
assert.Nil(t, raw)
}
func TestIter_Timeout(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
iter := NewMockIter(ctrl)
iter.EXPECT().Timeout().Return(true)
ci := ClosableIter{
Iter: iter,
Cleanup: nil,
}
assert.True(t, ci.Timeout())
}

164
core/stores/mongo/model.go Normal file
View File

@@ -0,0 +1,164 @@
package mongo
import (
"log"
"time"
"github.com/globalsign/mgo"
)
type (
options struct {
timeout time.Duration
}
Option func(opts *options)
Model struct {
session *concurrentSession
db *mgo.Database
collection string
opts []Option
}
)
func MustNewModel(url, database, collection string, opts ...Option) *Model {
model, err := NewModel(url, database, collection, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
func NewModel(url, database, collection string, opts ...Option) (*Model, error) {
session, err := getConcurrentSession(url)
if err != nil {
return nil, err
}
return &Model{
session: session,
db: session.DB(database),
collection: collection,
opts: opts,
}, nil
}
func (mm *Model) Find(query interface{}) (Query, error) {
return mm.query(func(c Collection) Query {
return c.Find(query)
})
}
func (mm *Model) FindId(id interface{}) (Query, error) {
return mm.query(func(c Collection) Query {
return c.FindId(id)
})
}
func (mm *Model) GetCollection(session *mgo.Session) Collection {
return newCollection(mm.db.C(mm.collection).With(session))
}
func (mm *Model) Insert(docs ...interface{}) error {
return mm.execute(func(c Collection) error {
return c.Insert(docs...)
})
}
func (mm *Model) Pipe(pipeline interface{}) (Pipe, error) {
return mm.pipe(func(c Collection) Pipe {
return c.Pipe(pipeline)
})
}
func (mm *Model) PutSession(session *mgo.Session) {
mm.session.putSession(session)
}
func (mm *Model) Remove(selector interface{}) error {
return mm.execute(func(c Collection) error {
return c.Remove(selector)
})
}
func (mm *Model) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
return c.RemoveAll(selector)
})
}
func (mm *Model) RemoveId(id interface{}) error {
return mm.execute(func(c Collection) error {
return c.RemoveId(id)
})
}
func (mm *Model) TakeSession() (*mgo.Session, error) {
return mm.session.takeSession(mm.opts...)
}
func (mm *Model) Update(selector, update interface{}) error {
return mm.execute(func(c Collection) error {
return c.Update(selector, update)
})
}
func (mm *Model) UpdateId(id, update interface{}) error {
return mm.execute(func(c Collection) error {
return c.UpdateId(id, update)
})
}
func (mm *Model) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
return c.Upsert(selector, update)
})
}
func (mm *Model) change(fn func(c Collection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) execute(fn func(c Collection) error) error {
session, err := mm.TakeSession()
if err != nil {
return err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) pipe(fn func(c Collection) Pipe) (Pipe, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session)), nil
}
func (mm *Model) query(fn func(c Collection) Query) (Query, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session)), nil
}
func WithTimeout(timeout time.Duration) Option {
return func(opts *options) {
opts.timeout = timeout
}
}

100
core/stores/mongo/pipe.go Normal file
View File

@@ -0,0 +1,100 @@
package mongo
import (
"time"
"zero/core/breaker"
"github.com/globalsign/mgo"
)
type (
Pipe interface {
All(result interface{}) error
AllowDiskUse() Pipe
Batch(n int) Pipe
Collation(collation *mgo.Collation) Pipe
Explain(result interface{}) error
Iter() Iter
One(result interface{}) error
SetMaxTime(d time.Duration) Pipe
}
promisedPipe struct {
*mgo.Pipe
promise keepablePromise
}
rejectedPipe struct{}
)
func (p promisedPipe) All(result interface{}) error {
return p.promise.keep(p.Pipe.All(result))
}
func (p promisedPipe) AllowDiskUse() Pipe {
p.Pipe.AllowDiskUse()
return p
}
func (p promisedPipe) Batch(n int) Pipe {
p.Pipe.Batch(n)
return p
}
func (p promisedPipe) Collation(collation *mgo.Collation) Pipe {
p.Pipe.Collation(collation)
return p
}
func (p promisedPipe) Explain(result interface{}) error {
return p.promise.keep(p.Pipe.Explain(result))
}
func (p promisedPipe) Iter() Iter {
return promisedIter{
Iter: p.Pipe.Iter(),
promise: p.promise,
}
}
func (p promisedPipe) One(result interface{}) error {
return p.promise.keep(p.Pipe.One(result))
}
func (p promisedPipe) SetMaxTime(d time.Duration) Pipe {
p.Pipe.SetMaxTime(d)
return p
}
func (p rejectedPipe) All(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (p rejectedPipe) AllowDiskUse() Pipe {
return p
}
func (p rejectedPipe) Batch(n int) Pipe {
return p
}
func (p rejectedPipe) Collation(collation *mgo.Collation) Pipe {
return p
}
func (p rejectedPipe) Explain(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (p rejectedPipe) Iter() Iter {
return rejectedIter{}
}
func (p rejectedPipe) One(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (p rejectedPipe) SetMaxTime(d time.Duration) Pipe {
return p
}

View File

@@ -0,0 +1,45 @@
package mongo
import (
"testing"
"zero/core/breaker"
"github.com/stretchr/testify/assert"
)
func TestRejectedPipe_All(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).All(nil))
}
func TestRejectedPipe_AllowDiskUse(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.AllowDiskUse())
}
func TestRejectedPipe_Batch(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.Batch(1))
}
func TestRejectedPipe_Collation(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.Collation(nil))
}
func TestRejectedPipe_Explain(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).Explain(nil))
}
func TestRejectedPipe_Iter(t *testing.T) {
assert.EqualValues(t, rejectedIter{}, new(rejectedPipe).Iter())
}
func TestRejectedPipe_One(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).One(nil))
}
func TestRejectedPipe_SetMaxTime(t *testing.T) {
var p rejectedPipe
assert.Equal(t, p, p.SetMaxTime(0))
}

285
core/stores/mongo/query.go Normal file
View File

@@ -0,0 +1,285 @@
package mongo
import (
"time"
"zero/core/breaker"
"github.com/globalsign/mgo"
)
type (
Query interface {
All(result interface{}) error
Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error)
Batch(n int) Query
Collation(collation *mgo.Collation) Query
Comment(comment string) Query
Count() (int, error)
Distinct(key string, result interface{}) error
Explain(result interface{}) error
For(result interface{}, f func() error) error
Hint(indexKey ...string) Query
Iter() Iter
Limit(n int) Query
LogReplay() Query
MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error)
One(result interface{}) error
Prefetch(p float64) Query
Select(selector interface{}) Query
SetMaxScan(n int) Query
SetMaxTime(d time.Duration) Query
Skip(n int) Query
Snapshot() Query
Sort(fields ...string) Query
Tail(timeout time.Duration) Iter
}
promisedQuery struct {
*mgo.Query
promise keepablePromise
}
rejectedQuery struct{}
)
func (q promisedQuery) All(result interface{}) error {
return q.promise.keep(q.Query.All(result))
}
func (q promisedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
info, err := q.Query.Apply(change, result)
return info, q.promise.keep(err)
}
func (q promisedQuery) Batch(n int) Query {
return promisedQuery{
Query: q.Query.Batch(n),
promise: q.promise,
}
}
func (q promisedQuery) Collation(collation *mgo.Collation) Query {
return promisedQuery{
Query: q.Query.Collation(collation),
promise: q.promise,
}
}
func (q promisedQuery) Comment(comment string) Query {
return promisedQuery{
Query: q.Query.Comment(comment),
promise: q.promise,
}
}
func (q promisedQuery) Count() (int, error) {
v, err := q.Query.Count()
return v, q.promise.keep(err)
}
func (q promisedQuery) Distinct(key string, result interface{}) error {
return q.promise.keep(q.Query.Distinct(key, result))
}
func (q promisedQuery) Explain(result interface{}) error {
return q.promise.keep(q.Query.Explain(result))
}
func (q promisedQuery) For(result interface{}, f func() error) error {
var ferr error
err := q.Query.For(result, func() error {
ferr = f()
return ferr
})
if ferr == err {
return q.promise.accept(err)
}
return q.promise.keep(err)
}
func (q promisedQuery) Hint(indexKey ...string) Query {
return promisedQuery{
Query: q.Query.Hint(indexKey...),
promise: q.promise,
}
}
func (q promisedQuery) Iter() Iter {
return promisedIter{
Iter: q.Query.Iter(),
promise: q.promise,
}
}
func (q promisedQuery) Limit(n int) Query {
return promisedQuery{
Query: q.Query.Limit(n),
promise: q.promise,
}
}
func (q promisedQuery) LogReplay() Query {
return promisedQuery{
Query: q.Query.LogReplay(),
promise: q.promise,
}
}
func (q promisedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
info, err := q.Query.MapReduce(job, result)
return info, q.promise.keep(err)
}
func (q promisedQuery) One(result interface{}) error {
return q.promise.keep(q.Query.One(result))
}
func (q promisedQuery) Prefetch(p float64) Query {
return promisedQuery{
Query: q.Query.Prefetch(p),
promise: q.promise,
}
}
func (q promisedQuery) Select(selector interface{}) Query {
return promisedQuery{
Query: q.Query.Select(selector),
promise: q.promise,
}
}
func (q promisedQuery) SetMaxScan(n int) Query {
return promisedQuery{
Query: q.Query.SetMaxScan(n),
promise: q.promise,
}
}
func (q promisedQuery) SetMaxTime(d time.Duration) Query {
return promisedQuery{
Query: q.Query.SetMaxTime(d),
promise: q.promise,
}
}
func (q promisedQuery) Skip(n int) Query {
return promisedQuery{
Query: q.Query.Skip(n),
promise: q.promise,
}
}
func (q promisedQuery) Snapshot() Query {
return promisedQuery{
Query: q.Query.Snapshot(),
promise: q.promise,
}
}
func (q promisedQuery) Sort(fields ...string) Query {
return promisedQuery{
Query: q.Query.Sort(fields...),
promise: q.promise,
}
}
func (q promisedQuery) Tail(timeout time.Duration) Iter {
return promisedIter{
Iter: q.Query.Tail(timeout),
promise: q.promise,
}
}
func (q rejectedQuery) All(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
return nil, breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Batch(n int) Query {
return q
}
func (q rejectedQuery) Collation(collation *mgo.Collation) Query {
return q
}
func (q rejectedQuery) Comment(comment string) Query {
return q
}
func (q rejectedQuery) Count() (int, error) {
return 0, breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Distinct(key string, result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Explain(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) For(result interface{}, f func() error) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Hint(indexKey ...string) Query {
return q
}
func (q rejectedQuery) Iter() Iter {
return rejectedIter{}
}
func (q rejectedQuery) Limit(n int) Query {
return q
}
func (q rejectedQuery) LogReplay() Query {
return q
}
func (q rejectedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
return nil, breaker.ErrServiceUnavailable
}
func (q rejectedQuery) One(result interface{}) error {
return breaker.ErrServiceUnavailable
}
func (q rejectedQuery) Prefetch(p float64) Query {
return q
}
func (q rejectedQuery) Select(selector interface{}) Query {
return q
}
func (q rejectedQuery) SetMaxScan(n int) Query {
return q
}
func (q rejectedQuery) SetMaxTime(d time.Duration) Query {
return q
}
func (q rejectedQuery) Skip(n int) Query {
return q
}
func (q rejectedQuery) Snapshot() Query {
return q
}
func (q rejectedQuery) Sort(fields ...string) Query {
return q
}
func (q rejectedQuery) Tail(timeout time.Duration) Iter {
return rejectedIter{}
}

View File

@@ -0,0 +1,121 @@
package mongo
import (
"testing"
"zero/core/breaker"
"github.com/globalsign/mgo"
"github.com/stretchr/testify/assert"
)
func Test_rejectedQuery_All(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).All(nil))
}
func Test_rejectedQuery_Apply(t *testing.T) {
info, err := new(rejectedQuery).Apply(mgo.Change{}, nil)
assert.Equal(t, breaker.ErrServiceUnavailable, err)
assert.Nil(t, info)
}
func Test_rejectedQuery_Batch(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Batch(1))
}
func Test_rejectedQuery_Collation(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Collation(nil))
}
func Test_rejectedQuery_Comment(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Comment(""))
}
func Test_rejectedQuery_Count(t *testing.T) {
n, err := new(rejectedQuery).Count()
assert.Equal(t, breaker.ErrServiceUnavailable, err)
assert.Equal(t, 0, n)
}
func Test_rejectedQuery_Distinct(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Distinct("", nil))
}
func Test_rejectedQuery_Explain(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Explain(nil))
}
func Test_rejectedQuery_For(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).For(nil, nil))
}
func Test_rejectedQuery_Hint(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Hint())
}
func Test_rejectedQuery_Iter(t *testing.T) {
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Iter())
}
func Test_rejectedQuery_Limit(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Limit(1))
}
func Test_rejectedQuery_LogReplay(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.LogReplay())
}
func Test_rejectedQuery_MapReduce(t *testing.T) {
info, err := new(rejectedQuery).MapReduce(nil, nil)
assert.Equal(t, breaker.ErrServiceUnavailable, err)
assert.Nil(t, info)
}
func Test_rejectedQuery_One(t *testing.T) {
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).One(nil))
}
func Test_rejectedQuery_Prefetch(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Prefetch(1))
}
func Test_rejectedQuery_Select(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Select(nil))
}
func Test_rejectedQuery_SetMaxScan(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.SetMaxScan(0))
}
func Test_rejectedQuery_SetMaxTime(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.SetMaxTime(0))
}
func Test_rejectedQuery_Skip(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Skip(0))
}
func Test_rejectedQuery_Snapshot(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Snapshot())
}
func Test_rejectedQuery_Sort(t *testing.T) {
var q rejectedQuery
assert.Equal(t, q, q.Sort())
}
func Test_rejectedQuery_Tail(t *testing.T) {
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Tail(0))
}

View File

@@ -0,0 +1,73 @@
package mongo
import (
"io"
"time"
"zero/core/logx"
"zero/core/syncx"
"github.com/globalsign/mgo"
)
const (
defaultConcurrency = 50
defaultTimeout = time.Second
)
var sessionManager = syncx.NewResourceManager()
type concurrentSession struct {
*mgo.Session
limit syncx.TimeoutLimit
}
func (cs *concurrentSession) Close() error {
cs.Session.Close()
return nil
}
func getConcurrentSession(url string) (*concurrentSession, error) {
val, err := sessionManager.GetResource(url, func() (io.Closer, error) {
mgoSession, err := mgo.Dial(url)
if err != nil {
return nil, err
}
concurrentSess := &concurrentSession{
Session: mgoSession,
limit: syncx.NewTimeoutLimit(defaultConcurrency),
}
return concurrentSess, nil
})
if err != nil {
return nil, err
}
return val.(*concurrentSession), nil
}
func (cs *concurrentSession) putSession(session *mgo.Session) {
if err := cs.limit.Return(); err != nil {
logx.Error(err)
}
// anyway, we need to close the session
session.Close()
}
func (cs *concurrentSession) takeSession(opts ...Option) (*mgo.Session, error) {
o := &options{
timeout: defaultTimeout,
}
for _, opt := range opts {
opt(o)
}
if err := cs.limit.Borrow(o.timeout); err != nil {
return nil, err
} else {
return cs.Copy(), nil
}
}

View File

@@ -0,0 +1,9 @@
package mongo
import "strings"
const mongoAddrSep = ","
func FormatAddr(hosts []string) string {
return strings.Join(hosts, mongoAddrSep)
}

View File

@@ -0,0 +1,35 @@
package mongo
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatAddrs(t *testing.T) {
tests := []struct {
addrs []string
expect string
}{
{
addrs: []string{"a", "b"},
expect: "a,b",
},
{
addrs: []string{"a", "b", "c"},
expect: "a,b,c",
},
{
addrs: []string{},
expect: "",
},
{
addrs: nil,
expect: "",
},
}
for _, test := range tests {
assert.Equal(t, test.expect, FormatAddr(test.addrs))
}
}

View File

@@ -0,0 +1,171 @@
package mongoc
import (
"zero/core/stores/internal"
"zero/core/stores/mongo"
"zero/core/syncx"
"github.com/globalsign/mgo"
)
var (
ErrNotFound = mgo.ErrNotFound
// can't use one SharedCalls per conn, because multiple conns may share the same cache key.
sharedCalls = syncx.NewSharedCalls()
stats = internal.NewCacheStat("mongoc")
)
type (
QueryOption func(query mongo.Query) mongo.Query
cachedCollection struct {
collection mongo.Collection
cache internal.Cache
}
)
func newCollection(collection mongo.Collection, c internal.Cache) *cachedCollection {
return &cachedCollection{
collection: collection,
cache: c,
}
}
func (c *cachedCollection) Count(query interface{}) (int, error) {
return c.collection.Find(query).Count()
}
func (c *cachedCollection) DelCache(keys ...string) error {
return c.cache.DelCache(keys...)
}
func (c *cachedCollection) GetCache(key string, v interface{}) error {
return c.cache.GetCache(key, v)
}
func (c *cachedCollection) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error {
q := c.collection.Find(query)
for _, opt := range opts {
q = opt(q)
}
return q.All(v)
}
func (c *cachedCollection) FindOne(v interface{}, key string, query interface{}) error {
return c.cache.Take(v, key, func(v interface{}) error {
q := c.collection.Find(query)
return q.One(v)
})
}
func (c *cachedCollection) FindOneNoCache(v interface{}, query interface{}) error {
q := c.collection.Find(query)
return q.One(v)
}
func (c *cachedCollection) FindOneId(v interface{}, key string, id interface{}) error {
return c.cache.Take(v, key, func(v interface{}) error {
q := c.collection.FindId(id)
return q.One(v)
})
}
func (c *cachedCollection) FindOneIdNoCache(v interface{}, id interface{}) error {
q := c.collection.FindId(id)
return q.One(v)
}
func (c *cachedCollection) Insert(docs ...interface{}) error {
return c.collection.Insert(docs...)
}
func (c *cachedCollection) Pipe(pipeline interface{}) mongo.Pipe {
return c.collection.Pipe(pipeline)
}
func (c *cachedCollection) Remove(selector interface{}, keys ...string) error {
if err := c.RemoveNoCache(selector); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) RemoveNoCache(selector interface{}) error {
return c.collection.Remove(selector)
}
func (c *cachedCollection) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
info, err := c.RemoveAllNoCache(selector)
if err != nil {
return nil, err
}
if err := c.DelCache(keys...); err != nil {
return nil, err
}
return info, nil
}
func (c *cachedCollection) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
return c.collection.RemoveAll(selector)
}
func (c *cachedCollection) RemoveId(id interface{}, keys ...string) error {
if err := c.RemoveIdNoCache(id); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) RemoveIdNoCache(id interface{}) error {
return c.collection.RemoveId(id)
}
func (c *cachedCollection) SetCache(key string, v interface{}) error {
return c.cache.SetCache(key, v)
}
func (c *cachedCollection) Update(selector, update interface{}, keys ...string) error {
if err := c.UpdateNoCache(selector, update); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) UpdateNoCache(selector, update interface{}) error {
return c.collection.Update(selector, update)
}
func (c *cachedCollection) UpdateId(id, update interface{}, keys ...string) error {
if err := c.UpdateIdNoCache(id, update); err != nil {
return err
}
return c.DelCache(keys...)
}
func (c *cachedCollection) UpdateIdNoCache(id, update interface{}) error {
return c.collection.UpdateId(id, update)
}
func (c *cachedCollection) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
info, err := c.UpsertNoCache(selector, update)
if err != nil {
return nil, err
}
if err := c.DelCache(keys...); err != nil {
return nil, err
}
return info, nil
}
func (c *cachedCollection) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
return c.collection.Upsert(selector, update)
}

View File

@@ -0,0 +1,300 @@
package mongoc
import (
"errors"
"io/ioutil"
"log"
"os"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"zero/core/stat"
"zero/core/stores/internal"
"zero/core/stores/mongo"
"zero/core/stores/redis"
"github.com/alicebob/miniredis"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"github.com/stretchr/testify/assert"
)
func init() {
stat.SetReporter(nil)
}
func TestStat(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach)
for i := 0; i < 10; i++ {
var str string
if err = c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
}
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func TestStatCacheFails(t *testing.T) {
resetStats()
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stdout)
r := redis.NewRedis("localhost:59999", redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach)
for i := 0; i < 20; i++ {
var str string
err := c.FindOne(&str, "name", bson.M{})
assert.NotNil(t, err)
}
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
}
func TestStatDbFails(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach)
for i := 0; i < 20; i++ {
var str string
err := c.cache.Take(&str, "name", func(v interface{}) error {
return errors.New("db failed")
})
assert.NotNil(t, err)
}
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
}
func TestStatFromMemory(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
c := newCollection(dummyConn{}, cach)
var all sync.WaitGroup
var wait sync.WaitGroup
all.Add(10)
wait.Add(4)
go func() {
var str string
if err := c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
wait.Wait()
runtime.Gosched()
all.Done()
}()
for i := 0; i < 4; i++ {
go func() {
var str string
wait.Done()
if err := c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
all.Done()
}()
}
for i := 0; i < 5; i++ {
go func() {
var str string
if err := c.cache.Take(&str, "name", func(v interface{}) error {
*v.(*string) = "zero"
return nil
}); err != nil {
t.Error(err)
}
all.Done()
}()
}
all.Wait()
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func resetStats() {
atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0)
atomic.StoreUint64(&stats.Miss, 0)
atomic.StoreUint64(&stats.DbFails, 0)
}
type dummyConn struct {
}
func (c dummyConn) Find(query interface{}) mongo.Query {
return dummyQuery{}
}
func (c dummyConn) FindId(id interface{}) mongo.Query {
return dummyQuery{}
}
func (c dummyConn) Insert(docs ...interface{}) error {
return nil
}
func (c dummyConn) Remove(selector interface{}) error {
return nil
}
func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe {
return nil
}
func (c dummyConn) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
return nil, nil
}
func (c dummyConn) RemoveId(id interface{}) error {
return nil
}
func (c dummyConn) Update(selector, update interface{}) error {
return nil
}
func (c dummyConn) UpdateId(id, update interface{}) error {
return nil
}
func (c dummyConn) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
return nil, nil
}
type dummyQuery struct {
}
func (d dummyQuery) All(result interface{}) error {
return nil
}
func (d dummyQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
return nil, nil
}
func (d dummyQuery) Count() (int, error) {
return 0, nil
}
func (d dummyQuery) Distinct(key string, result interface{}) error {
return nil
}
func (d dummyQuery) Explain(result interface{}) error {
return nil
}
func (d dummyQuery) For(result interface{}, f func() error) error {
return nil
}
func (d dummyQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
return nil, nil
}
func (d dummyQuery) One(result interface{}) error {
return nil
}
func (d dummyQuery) Batch(n int) mongo.Query {
return d
}
func (d dummyQuery) Collation(collation *mgo.Collation) mongo.Query {
return d
}
func (d dummyQuery) Comment(comment string) mongo.Query {
return d
}
func (d dummyQuery) Hint(indexKey ...string) mongo.Query {
return d
}
func (d dummyQuery) Iter() mongo.Iter {
return &mgo.Iter{}
}
func (d dummyQuery) Limit(n int) mongo.Query {
return d
}
func (d dummyQuery) LogReplay() mongo.Query {
return d
}
func (d dummyQuery) Prefetch(p float64) mongo.Query {
return d
}
func (d dummyQuery) Select(selector interface{}) mongo.Query {
return d
}
func (d dummyQuery) SetMaxScan(n int) mongo.Query {
return d
}
func (d dummyQuery) SetMaxTime(duration time.Duration) mongo.Query {
return d
}
func (d dummyQuery) Skip(n int) mongo.Query {
return d
}
func (d dummyQuery) Snapshot() mongo.Query {
return d
}
func (d dummyQuery) Sort(fields ...string) mongo.Query {
return d
}
func (d dummyQuery) Tail(timeout time.Duration) mongo.Iter {
return &mgo.Iter{}
}

View File

@@ -0,0 +1,243 @@
package mongoc
import (
"log"
"zero/core/stores/cache"
"zero/core/stores/internal"
"zero/core/stores/mongo"
"zero/core/stores/redis"
"github.com/globalsign/mgo"
)
type Model struct {
*mongo.Model
cache internal.Cache
generateCollection func(*mgo.Session) *cachedCollection
}
func MustNewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
model, err := NewNodeModel(url, database, collection, rds, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
func MustNewModel(url, database, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
model, err := NewModel(url, database, collection, c, opts...)
if err != nil {
log.Fatal(err)
}
return model
}
func NewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
c := internal.NewCacheNode(rds, sharedCalls, stats, mgo.ErrNotFound, opts...)
return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection {
return newCollection(collection, c)
})
}
func NewModel(url, database, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
c := internal.NewCache(conf, sharedCalls, stats, mgo.ErrNotFound, opts...)
return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection {
return newCollection(collection, c)
})
}
func (mm *Model) Count(query interface{}) (int, error) {
return mm.executeInt(func(c *cachedCollection) (int, error) {
return c.Count(query)
})
}
func (mm *Model) DelCache(keys ...string) error {
return mm.cache.DelCache(keys...)
}
func (mm *Model) GetCache(key string, v interface{}) error {
return mm.cache.GetCache(key, v)
}
func (mm *Model) GetCollection(session *mgo.Session) *cachedCollection {
return mm.generateCollection(session)
}
func (mm *Model) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error {
return mm.execute(func(c *cachedCollection) error {
return c.FindAllNoCache(v, query, opts...)
})
}
func (mm *Model) FindOne(v interface{}, key string, query interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.FindOne(v, key, query)
})
}
func (mm *Model) FindOneNoCache(v interface{}, query interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.FindOneNoCache(v, query)
})
}
func (mm *Model) FindOneId(v interface{}, key string, id interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.FindOneId(v, key, id)
})
}
func (mm *Model) FindOneIdNoCache(v interface{}, id interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.FindOneIdNoCache(v, id)
})
}
func (mm *Model) Insert(docs ...interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.Insert(docs...)
})
}
func (mm *Model) Pipe(pipeline interface{}) (mongo.Pipe, error) {
return mm.pipe(func(c *cachedCollection) mongo.Pipe {
return c.Pipe(pipeline)
})
}
func (mm *Model) Remove(selector interface{}, keys ...string) error {
return mm.execute(func(c *cachedCollection) error {
return c.Remove(selector, keys...)
})
}
func (mm *Model) RemoveNoCache(selector interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.RemoveNoCache(selector)
})
}
func (mm *Model) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
return c.RemoveAll(selector, keys...)
})
}
func (mm *Model) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
return c.RemoveAllNoCache(selector)
})
}
func (mm *Model) RemoveId(id interface{}, keys ...string) error {
return mm.execute(func(c *cachedCollection) error {
return c.RemoveId(id, keys...)
})
}
func (mm *Model) RemoveIdNoCache(id interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.RemoveIdNoCache(id)
})
}
func (mm *Model) SetCache(key string, v interface{}) error {
return mm.cache.SetCache(key, v)
}
func (mm *Model) Update(selector, update interface{}, keys ...string) error {
return mm.execute(func(c *cachedCollection) error {
return c.Update(selector, update, keys...)
})
}
func (mm *Model) UpdateNoCache(selector, update interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.UpdateNoCache(selector, update)
})
}
func (mm *Model) UpdateId(id, update interface{}, keys ...string) error {
return mm.execute(func(c *cachedCollection) error {
return c.UpdateId(id, update, keys...)
})
}
func (mm *Model) UpdateIdNoCache(id, update interface{}) error {
return mm.execute(func(c *cachedCollection) error {
return c.UpdateIdNoCache(id, update)
})
}
func (mm *Model) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
return c.Upsert(selector, update, keys...)
})
}
func (mm *Model) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
return c.UpsertNoCache(selector, update)
})
}
func (mm *Model) change(fn func(c *cachedCollection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) execute(fn func(c *cachedCollection) error) error {
session, err := mm.TakeSession()
if err != nil {
return err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) executeInt(fn func(c *cachedCollection) (int, error)) (int, error) {
session, err := mm.TakeSession()
if err != nil {
return 0, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session))
}
func (mm *Model) pipe(fn func(c *cachedCollection) mongo.Pipe) (mongo.Pipe, error) {
session, err := mm.TakeSession()
if err != nil {
return nil, err
}
defer mm.PutSession(session)
return fn(mm.GetCollection(session)), nil
}
func createModel(url, database, collection string, c internal.Cache,
create func(mongo.Collection) *cachedCollection) (*Model, error) {
model, err := mongo.NewModel(url, database, collection)
if err != nil {
return nil, err
}
return &Model{
Model: model,
cache: c,
generateCollection: func(session *mgo.Session) *cachedCollection {
collection := model.GetCollection(session)
return create(collection)
},
}, nil
}

View File

@@ -0,0 +1,13 @@
package postgres
import (
"zero/core/stores/sqlx"
_ "github.com/lib/pq"
)
const postgreDriverName = "postgres"
func NewPostgre(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
return sqlx.NewSqlConn(postgreDriverName, datasource, opts...)
}

50
core/stores/redis/conf.go Normal file
View File

@@ -0,0 +1,50 @@
package redis
import "errors"
var (
ErrEmptyHost = errors.New("empty redis host")
ErrEmptyType = errors.New("empty redis type")
ErrEmptyKey = errors.New("empty redis key")
)
type (
RedisConf struct {
Host string
Type string `json:",default=node,options=node|cluster"`
Pass string `json:",optional"`
}
RedisKeyConf struct {
RedisConf
Key string `json:",optional"`
}
)
func (rc RedisConf) NewRedis() *Redis {
return NewRedis(rc.Host, rc.Type, rc.Pass)
}
func (rc RedisConf) Validate() error {
if len(rc.Host) == 0 {
return ErrEmptyHost
}
if len(rc.Type) == 0 {
return ErrEmptyType
}
return nil
}
func (rkc RedisKeyConf) Validate() error {
if err := rkc.RedisConf.Validate(); err != nil {
return err
}
if len(rkc.Key) == 0 {
return ErrEmptyKey
}
return nil
}

View File

@@ -0,0 +1,33 @@
package redis
import (
"strings"
"zero/core/logx"
"zero/core/mapping"
"zero/core/timex"
red "github.com/go-redis/redis"
)
func process(proc func(red.Cmder) error) func(red.Cmder) error {
return func(cmd red.Cmder) error {
start := timex.Now()
defer func() {
duration := timex.Since(start)
if duration > slowThreshold {
var buf strings.Builder
for i, arg := range cmd.Args() {
if i > 0 {
buf.WriteByte(' ')
}
buf.WriteString(mapping.Repr(arg))
}
logx.WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
}
}()
return proc(cmd)
}
}

1339
core/stores/redis/redis.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,580 @@
package redis
import (
"errors"
"io"
"testing"
"time"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
)
func TestRedis_Exists(t *testing.T) {
runOnRedis(t, func(client *Redis) {
ok, err := client.Exists("a")
assert.Nil(t, err)
assert.False(t, ok)
assert.Nil(t, client.Set("a", "b"))
ok, err = client.Exists("a")
assert.Nil(t, err)
assert.True(t, ok)
})
}
func TestRedis_Eval(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
assert.Equal(t, Nil, err)
err = client.Set("key1", "value1")
assert.Nil(t, err)
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"key1"})
assert.Equal(t, Nil, err)
val, err := client.Eval(`return redis.call("EXISTS", KEYS[1])`, []string{"key1"})
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
})
}
func TestRedis_Hgetall(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hgetall("a")
assert.Nil(t, err)
assert.EqualValues(t, map[string]string{
"aa": "aaa",
"bb": "bbb",
}, vals)
})
}
func TestRedis_Hvals(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hvals("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals)
})
}
func TestRedis_Hsetnx(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
ok, err := client.Hsetnx("a", "bb", "ccc")
assert.Nil(t, err)
assert.False(t, ok)
ok, err = client.Hsetnx("a", "dd", "ddd")
assert.Nil(t, err)
assert.True(t, ok)
vals, err := client.Hvals("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aaa", "bbb", "ddd"}, vals)
})
}
func TestRedis_HdelHlen(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
num, err := client.Hlen("a")
assert.Nil(t, err)
assert.Equal(t, 2, num)
val, err := client.Hdel("a", "aa")
assert.Nil(t, err)
assert.True(t, val)
vals, err := client.Hvals("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"bbb"}, vals)
})
}
func TestRedis_HIncrBy(t *testing.T) {
runOnRedis(t, func(client *Redis) {
val, err := client.Hincrby("key", "field", 2)
assert.Nil(t, err)
assert.Equal(t, 2, val)
val, err = client.Hincrby("key", "field", 3)
assert.Nil(t, err)
assert.Equal(t, 5, val)
})
}
func TestRedis_Hkeys(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hkeys("a")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"aa", "bb"}, vals)
})
}
func TestRedis_Hmget(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hset("a", "aa", "aaa"))
assert.Nil(t, client.Hset("a", "bb", "bbb"))
vals, err := client.Hmget("a", "aa", "bb")
assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
vals, err = client.Hmget("a", "aa", "no", "bb")
assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "", "bbb"}, vals)
})
}
func TestRedis_Hmset(t *testing.T) {
runOnRedis(t, func(client *Redis) {
assert.Nil(t, client.Hmset("a", map[string]string{
"aa": "aaa",
"bb": "bbb",
}))
vals, err := client.Hmget("a", "aa", "bb")
assert.Nil(t, err)
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
})
}
func TestRedis_Incr(t *testing.T) {
runOnRedis(t, func(client *Redis) {
val, err := client.Incr("a")
assert.Nil(t, err)
assert.Equal(t, int64(1), val)
val, err = client.Incr("a")
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
})
}
func TestRedis_IncrBy(t *testing.T) {
runOnRedis(t, func(client *Redis) {
val, err := client.Incrby("a", 2)
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
val, err = client.Incrby("a", 3)
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
})
}
func TestRedis_Keys(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Set("key1", "value1")
assert.Nil(t, err)
err = client.Set("key2", "value2")
assert.Nil(t, err)
keys, err := client.Keys("*")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"key1", "key2"}, keys)
})
}
func TestRedis_List(t *testing.T) {
runOnRedis(t, func(client *Redis) {
val, err := client.Lpush("key", "value1", "value2")
assert.Nil(t, err)
assert.Equal(t, 2, val)
val, err = client.Rpush("key", "value3", "value4")
assert.Nil(t, err)
assert.Equal(t, 4, val)
val, err = client.Llen("key")
assert.Nil(t, err)
assert.Equal(t, 4, val)
vals, err := client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
v, err := client.Lpop("key")
assert.Nil(t, err)
assert.Equal(t, "value2", v)
val, err = client.Lpush("key", "value1", "value2")
assert.Nil(t, err)
assert.Equal(t, 5, val)
val, err = client.Rpush("key", "value3", "value3")
assert.Nil(t, err)
assert.Equal(t, 7, val)
n, err := client.Lrem("key", 2, "value1")
assert.Nil(t, err)
assert.Equal(t, 2, n)
vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals)
n, err = client.Lrem("key", -2, "value3")
assert.Nil(t, err)
assert.Equal(t, 2, n)
vals, err = client.Lrange("key", 0, 10)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals)
})
}
func TestRedis_Mget(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Set("key1", "value1")
assert.Nil(t, err)
err = client.Set("key2", "value2")
assert.Nil(t, err)
vals, err := client.Mget("key1", "key0", "key2", "key3")
assert.Nil(t, err)
assert.EqualValues(t, []string{"value1", "", "value2", ""}, vals)
})
}
func TestRedis_SetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 1, 1)
assert.Nil(t, err)
})
}
func TestRedis_GetBit(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.SetBit("key", 2, 1)
assert.Nil(t, err)
val, err := client.GetBit("key", 2)
assert.Nil(t, err)
assert.Equal(t, 1, val)
})
}
func TestRedis_Persist(t *testing.T) {
runOnRedis(t, func(client *Redis) {
ok, err := client.Persist("key")
assert.Nil(t, err)
assert.False(t, ok)
err = client.Set("key", "value")
assert.Nil(t, err)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.False(t, ok)
err = client.Expire("key", 5)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.True(t, ok)
err = client.Expireat("key", time.Now().Unix()+5)
ok, err = client.Persist("key")
assert.Nil(t, err)
assert.True(t, ok)
})
}
func TestRedis_Ping(t *testing.T) {
runOnRedis(t, func(client *Redis) {
ok := client.Ping()
assert.True(t, ok)
})
}
func TestRedis_Scan(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Set("key1", "value1")
assert.Nil(t, err)
err = client.Set("key2", "value2")
assert.Nil(t, err)
keys, _, err := client.Scan(0, "*", 100)
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"key1", "key2"}, keys)
})
}
func TestRedis_Sscan(t *testing.T) {
runOnRedis(t, func(client *Redis) {
key := "list"
var list []string
for i := 0; i < 1550; i++ {
list = append(list, randomStr(i))
}
lens, err := client.Sadd(key, list)
assert.Nil(t, err)
assert.Equal(t, lens, 1550)
var cursor uint64 = 0
sum := 0
for {
keys, next, err := client.Sscan(key, cursor, "", 100)
assert.Nil(t, err)
sum += len(keys)
if next == 0 {
break
}
cursor = next
}
assert.Equal(t, sum, 1550)
_, err = client.Del(key)
assert.Nil(t, err)
})
}
func TestRedis_Set(t *testing.T) {
runOnRedis(t, func(client *Redis) {
num, err := client.Sadd("key", 1, 2, 3, 4)
assert.Nil(t, err)
assert.Equal(t, 4, num)
val, err := client.Scard("key")
assert.Nil(t, err)
assert.Equal(t, int64(4), val)
ok, err := client.Sismember("key", 2)
assert.Nil(t, err)
assert.True(t, ok)
num, err = client.Srem("key", 3, 4)
assert.Nil(t, err)
assert.Equal(t, 2, num)
vals, err := client.Smembers("key")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2"}, vals)
members, err := client.Srandmember("key", 1)
assert.Nil(t, err)
assert.Len(t, members, 1)
assert.Contains(t, []string{"1", "2"}, members[0])
member, err := client.Spop("key")
assert.Nil(t, err)
assert.Contains(t, []string{"1", "2"}, member)
vals, err = client.Smembers("key")
assert.Nil(t, err)
assert.NotContains(t, vals, member)
num, err = client.Sadd("key1", 1, 2, 3, 4)
assert.Nil(t, err)
assert.Equal(t, 4, num)
num, err = client.Sadd("key2", 2, 3, 4, 5)
assert.Nil(t, err)
assert.Equal(t, 4, num)
vals, err = client.Sunion("key1", "key2")
assert.Nil(t, err)
assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals)
num, err = client.Sunionstore("key3", "key1", "key2")
assert.Nil(t, err)
assert.Equal(t, 5, num)
vals, err = client.Sdiff("key1", "key2")
assert.Nil(t, err)
assert.EqualValues(t, []string{"1"}, vals)
num, err = client.Sdiffstore("key4", "key1", "key2")
assert.Nil(t, err)
assert.Equal(t, 1, num)
})
}
func TestRedis_SetGetDel(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Set("hello", "world")
assert.Nil(t, err)
val, err := client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
ret, err := client.Del("hello")
assert.Nil(t, err)
assert.Equal(t, 1, ret)
})
}
func TestRedis_SetExNx(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Setex("hello", "world", 5)
assert.Nil(t, err)
ok, err := client.Setnx("hello", "newworld")
assert.Nil(t, err)
assert.False(t, ok)
ok, err = client.Setnx("newhello", "newworld")
assert.Nil(t, err)
assert.True(t, ok)
val, err := client.Get("hello")
assert.Nil(t, err)
assert.Equal(t, "world", val)
val, err = client.Get("newhello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
ttl, err := client.Ttl("hello")
assert.Nil(t, err)
assert.True(t, ttl > 0)
ok, err = client.SetnxEx("newhello", "newworld", 5)
assert.Nil(t, err)
assert.False(t, ok)
num, err := client.Del("newhello")
assert.Nil(t, err)
assert.Equal(t, 1, num)
ok, err = client.SetnxEx("newhello", "newworld", 5)
assert.Nil(t, err)
assert.True(t, ok)
val, err = client.Get("newhello")
assert.Nil(t, err)
assert.Equal(t, "newworld", val)
})
}
func TestRedis_SetGetDelHashField(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Hset("key", "field", "value")
assert.Nil(t, err)
val, err := client.Hget("key", "field")
assert.Nil(t, err)
assert.Equal(t, "value", val)
ok, err := client.Hexists("key", "field")
assert.Nil(t, err)
assert.True(t, ok)
ret, err := client.Hdel("key", "field")
assert.Nil(t, err)
assert.True(t, ret)
ok, err = client.Hexists("key", "field")
assert.Nil(t, err)
assert.False(t, ok)
})
}
func TestRedis_SortedSet(t *testing.T) {
runOnRedis(t, func(client *Redis) {
ok, err := client.Zadd("key", 1, "value1")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 2, "value1")
assert.Nil(t, err)
assert.False(t, ok)
val, err := client.Zscore("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(2), val)
val, err = client.Zincrby("key", 3, "value1")
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
val, err = client.Zscore("key", "value1")
assert.Nil(t, err)
assert.Equal(t, int64(5), val)
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
rank, err := client.Zrank("key", "value2")
assert.Nil(t, err)
assert.Equal(t, int64(1), rank)
rank, err = client.Zrank("key", "value4")
assert.Equal(t, Nil, err)
num, err := client.Zrem("key", "value2", "value3")
assert.Nil(t, err)
assert.Equal(t, 2, num)
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 8, "value4")
assert.Nil(t, err)
assert.True(t, ok)
num, err = client.Zremrangebyscore("key", 6, 7)
assert.Nil(t, err)
assert.Equal(t, 2, num)
ok, err = client.Zadd("key", 6, "value2")
assert.Nil(t, err)
assert.True(t, ok)
ok, err = client.Zadd("key", 7, "value3")
assert.Nil(t, err)
assert.True(t, ok)
num, err = client.Zcount("key", 6, 7)
assert.Nil(t, err)
assert.Equal(t, 2, num)
num, err = client.Zremrangebyrank("key", 1, 2)
assert.Nil(t, err)
assert.Equal(t, 2, num)
card, err := client.Zcard("key")
assert.Nil(t, err)
assert.Equal(t, 2, card)
vals, err := client.Zrange("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value1", "value4"}, vals)
vals, err = client.Zrevrange("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []string{"value4", "value1"}, vals)
pairs, err := client.ZrangeWithScores("key", 0, -1)
assert.Nil(t, err)
assert.EqualValues(t, []Pair{
{
Key: "value1",
Score: 5,
},
{
Key: "value4",
Score: 8,
},
}, pairs)
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err)
assert.EqualValues(t, []Pair{
{
Key: "value1",
Score: 5,
},
{
Key: "value4",
Score: 8,
},
}, pairs)
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.Nil(t, err)
assert.EqualValues(t, []Pair{
{
Key: "value4",
Score: 8,
},
}, pairs)
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
assert.Nil(t, err)
assert.EqualValues(t, []Pair{
{
Key: "value4",
Score: 8,
},
{
Key: "value1",
Score: 5,
},
}, pairs)
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
assert.Nil(t, err)
assert.EqualValues(t, []Pair{
{
Key: "value1",
Score: 5,
},
}, pairs)
})
}
func TestRedis_Pipelined(t *testing.T) {
runOnRedis(t, func(client *Redis) {
err := client.Pipelined(
func(pipe Pipeliner) error {
pipe.Incr("pipelined_counter")
pipe.Expire("pipelined_counter", time.Hour)
pipe.ZAdd("zadd", Z{Score: 12, Member: "zadd"})
return nil
},
)
assert.Nil(t, err)
ttl, err := client.Ttl("pipelined_counter")
assert.Nil(t, err)
assert.Equal(t, 3600, ttl)
value, err := client.Get("pipelined_counter")
assert.Nil(t, err)
assert.Equal(t, "1", value)
score, err := client.Zscore("zadd", "zadd")
assert.Equal(t, int64(12), score)
})
}
func runOnRedis(t *testing.T, fn func(client *Redis)) {
s, err := miniredis.Run()
assert.Nil(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
client.Close()
}()
fn(NewRedis(s.Addr(), NodeType))
}

View File

@@ -0,0 +1,66 @@
package redis
import (
"fmt"
"zero/core/logx"
red "github.com/go-redis/redis"
)
type ClosableNode interface {
RedisNode
Close()
}
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
timeout := readWriteTimeout + blockingQueryTimeout
switch r.Type {
case NodeType:
client := red.NewClient(&red.Options{
Addr: r.Addr,
Password: r.Pass,
DB: defaultDatabase,
MaxRetries: maxRetries,
PoolSize: 1,
MinIdleConns: 1,
ReadTimeout: timeout,
})
return &clientBridge{client}, nil
case ClusterType:
client := red.NewClusterClient(&red.ClusterOptions{
Addrs: []string{r.Addr},
Password: r.Pass,
MaxRetries: maxRetries,
PoolSize: 1,
MinIdleConns: 1,
ReadTimeout: timeout,
})
return &clusterBridge{client}, nil
default:
return nil, fmt.Errorf("unknown redis type: %s", r.Type)
}
}
type (
clientBridge struct {
*red.Client
}
clusterBridge struct {
*red.ClusterClient
}
)
func (bridge *clientBridge) Close() {
if err := bridge.Client.Close(); err != nil {
logx.Errorf("Error occurred on close redis client: %s", err)
}
}
func (bridge *clusterBridge) Close() {
if err := bridge.ClusterClient.Close(); err != nil {
logx.Errorf("Error occurred on close redis cluster: %s", err)
}
}

View File

@@ -0,0 +1,36 @@
package redis
import (
"io"
"zero/core/syncx"
red "github.com/go-redis/redis"
)
const (
defaultDatabase = 0
maxRetries = 3
idleConns = 8
)
var clientManager = syncx.NewResourceManager()
func getClient(server, pass string) (*red.Client, error) {
val, err := clientManager.GetResource(server, func() (io.Closer, error) {
store := red.NewClient(&red.Options{
Addr: server,
Password: pass,
DB: defaultDatabase,
MaxRetries: maxRetries,
MinIdleConns: idleConns,
})
store.WrapProcess(process)
return store, nil
})
if err != nil {
return nil, err
}
return val.(*red.Client), nil
}

View File

@@ -0,0 +1,30 @@
package redis
import (
"io"
"zero/core/syncx"
red "github.com/go-redis/redis"
)
var clusterManager = syncx.NewResourceManager()
func getCluster(server, pass string) (*red.ClusterClient, error) {
val, err := clusterManager.GetResource(server, func() (io.Closer, error) {
store := red.NewClusterClient(&red.ClusterOptions{
Addrs: []string{server},
Password: pass,
MaxRetries: maxRetries,
MinIdleConns: idleConns,
})
store.WrapProcess(process)
return store, nil
})
if err != nil {
return nil, err
}
return val.(*red.ClusterClient), nil
}

View File

@@ -0,0 +1,96 @@
package redis
import (
"math/rand"
"strconv"
"sync/atomic"
"time"
"zero/core/logx"
red "github.com/go-redis/redis"
)
const (
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
return "OK"
else
return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
end`
delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end`
randomLen = 16
tolerance = 500 // milliseconds
millisPerSecond = 1000
)
type RedisLock struct {
store *Redis
seconds uint32
key string
id string
}
func init() {
rand.Seed(time.Now().UnixNano())
}
func NewRedisLock(store *Redis, key string) *RedisLock {
return &RedisLock{
store: store,
key: key,
id: randomStr(randomLen),
}
}
func (rl *RedisLock) Acquire() (bool, error) {
seconds := atomic.LoadUint32(&rl.seconds)
resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance)})
if err == red.Nil {
return false, nil
} else if err != nil {
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
return false, err
} else if resp == nil {
return false, nil
}
reply, ok := resp.(string)
if ok && reply == "OK" {
return true, nil
} else {
logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp)
return false, nil
}
}
func (rl *RedisLock) Release() (bool, error) {
resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id})
if err != nil {
return false, err
}
if reply, ok := resp.(int64); !ok {
return false, nil
} else {
return reply == 1, nil
}
}
func (rl *RedisLock) SetExpire(seconds int) {
atomic.StoreUint32(&rl.seconds, uint32(seconds))
}
func randomStr(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}

View File

@@ -0,0 +1,34 @@
package redis
import (
"testing"
"zero/core/stringx"
"github.com/stretchr/testify/assert"
)
func TestRedisLock(t *testing.T) {
runOnRedis(t, func(client *Redis) {
key := stringx.Rand()
firstLock := NewRedisLock(client, key)
firstLock.SetExpire(5)
firstAcquire, err := firstLock.Acquire()
assert.Nil(t, err)
assert.True(t, firstAcquire)
secondLock := NewRedisLock(client, key)
secondLock.SetExpire(5)
againAcquire, err := secondLock.Acquire()
assert.Nil(t, err)
assert.False(t, againAcquire)
release, err := firstLock.Release()
assert.Nil(t, err)
assert.True(t, release)
endAcquire, err := secondLock.Acquire()
assert.Nil(t, err)
assert.True(t, endAcquire)
})
}

View File

@@ -0,0 +1,48 @@
package redis
import (
"sync"
"sync/atomic"
)
var (
once sync.Once
lock sync.Mutex
instance *ScriptCache
)
type (
Map map[string]string
ScriptCache struct {
atomic.Value
}
)
func GetScriptCache() *ScriptCache {
once.Do(func() {
instance = &ScriptCache{}
instance.Store(make(Map))
})
return instance
}
func (sc *ScriptCache) GetSha(script string) (string, bool) {
cache := sc.Load().(Map)
ret, ok := cache[script]
return ret, ok
}
func (sc *ScriptCache) SetSha(script, sha string) {
lock.Lock()
defer lock.Unlock()
cache := sc.Load().(Map)
newCache := make(Map)
for k, v := range cache {
newCache[k] = v
}
newCache[script] = sha
sc.Store(newCache)
}

View File

@@ -0,0 +1,122 @@
package sqlc
import (
"database/sql"
"time"
"zero/core/stores/cache"
"zero/core/stores/internal"
"zero/core/stores/redis"
"zero/core/stores/sqlx"
"zero/core/syncx"
)
// see doc/sql-cache.md
const cacheSafeGapBetweenIndexAndPrimary = time.Second * 5
var (
ErrNotFound = sqlx.ErrNotFound
// can't use one SharedCalls per conn, because multiple conns may share the same cache key.
exclusiveCalls = syncx.NewSharedCalls()
stats = internal.NewCacheStat("sqlc")
)
type (
ExecFn func(conn sqlx.SqlConn) (sql.Result, error)
IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error)
PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error
QueryFn func(conn sqlx.SqlConn, v interface{}) error
CachedConn struct {
db sqlx.SqlConn
cache internal.Cache
}
)
func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
return CachedConn{
db: db,
cache: internal.NewCacheNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...),
}
}
func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
return CachedConn{
db: db,
cache: internal.NewCache(c, exclusiveCalls, stats, sql.ErrNoRows, opts...),
}
}
func (cc CachedConn) DelCache(keys ...string) error {
return cc.cache.DelCache(keys...)
}
func (cc CachedConn) GetCache(key string, v interface{}) error {
return cc.cache.GetCache(key, v)
}
func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
res, err := exec(cc.db)
if err != nil {
return nil, err
}
if err := cc.DelCache(keys...); err != nil {
return nil, err
}
return res, nil
}
func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) {
return cc.db.Exec(q, args...)
}
func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error {
return cc.cache.Take(v, key, func(v interface{}) error {
return query(cc.db, v)
})
}
func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string,
indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
var primaryKey interface{}
var found bool
if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) {
primaryKey, err = indexQuery(cc.db, v)
if err != nil {
return
}
found = true
return cc.cache.SetCacheWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary)
}); err != nil {
return err
}
if found {
return nil
}
return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error {
return primaryQuery(cc.db, v, primaryKey)
})
}
func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error {
return cc.db.QueryRow(v, q, args...)
}
// QueryRowsNoCache doesn't use cache, because it might cause consistency problem.
func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error {
return cc.db.QueryRows(v, q, args...)
}
func (cc CachedConn) SetCache(key string, v interface{}) error {
return cc.cache.SetCache(key, v)
}
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
return cc.db.Transact(fn)
}

View File

@@ -0,0 +1,508 @@
package sqlc
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"zero/core/logx"
"zero/core/stat"
"zero/core/stores/cache"
"zero/core/stores/redis"
"zero/core/stores/sqlx"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
)
func init() {
logx.Disable()
stat.SetReporter(nil)
}
func TestCachedConn_GetCache(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
var value string
err = c.GetCache("any", &value)
assert.Equal(t, ErrNotFound, err)
s.Set("any", `"value"`)
err = c.GetCache("any", &value)
assert.Nil(t, err)
assert.Equal(t, "value", value)
}
func TestStat(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
for i := 0; i < 10; i++ {
var str string
err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
*v.(*string) = "zero"
return nil
})
if err != nil {
t.Error(err)
}
}
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
var str string
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
return fmt.Sprintf("%s/1234", s)
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
*v.(*string) = "zero"
return "primary", nil
}, func(conn sqlx.SqlConn, v, pri interface{}) error {
assert.Equal(t, "primary", pri)
*v.(*string) = "xin"
return nil
})
assert.Nil(t, err)
assert.Equal(t, "zero", str)
val, err := r.Get("index")
assert.Nil(t, err)
assert.Equal(t, `"primary"`, val)
val, err = r.Get("primary/1234")
assert.Nil(t, err)
assert.Equal(t, `"zero"`, val)
}
func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
cache.WithNotFoundExpiry(time.Second))
var str string
r.Set("index", `"primary"`)
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
return fmt.Sprintf("%s/1234", s)
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
assert.Fail(t, "should not go here")
return "primary", nil
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
*v.(*string) = "xin"
assert.Equal(t, "primary", primary)
return nil
})
assert.Nil(t, err)
assert.Equal(t, "xin", str)
val, err := r.Get("index")
assert.Nil(t, err)
assert.Equal(t, `"primary"`, val)
val, err = r.Get("primary/1234")
assert.Nil(t, err)
assert.Equal(t, `"xin"`, val)
}
func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
caches := map[string]string{
"index": "primary",
"primary/1234": "xin",
}
for k, v := range caches {
t.Run(k+"/"+v, func(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
cache.WithNotFoundExpiry(time.Second))
var str string
r.Set(k, v)
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
return fmt.Sprintf("%s/1234", s)
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
*v.(*string) = "xin"
return "primary", nil
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
*v.(*string) = "xin"
assert.Equal(t, "primary", primary)
return nil
})
assert.Nil(t, err)
assert.Equal(t, "xin", str)
val, err := r.Get("index")
assert.Nil(t, err)
assert.Equal(t, `"primary"`, val)
val, err = r.Get("primary/1234")
assert.Nil(t, err)
assert.Equal(t, `"xin"`, val)
})
}
}
func TestStatCacheFails(t *testing.T) {
resetStats()
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stdout)
r := redis.NewRedis("localhost:59999", redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
for i := 0; i < 20; i++ {
var str string
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
return errors.New("db failed")
})
assert.NotNil(t, err)
}
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
}
func TestStatDbFails(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
for i := 0; i < 20; i++ {
var str string
err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
return errors.New("db failed")
})
assert.NotNil(t, err)
}
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
}
func TestStatFromMemory(t *testing.T) {
resetStats()
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
var all sync.WaitGroup
var wait sync.WaitGroup
all.Add(10)
wait.Add(4)
go func() {
var str string
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
*v.(*string) = "zero"
return nil
})
if err != nil {
t.Error(err)
}
wait.Wait()
runtime.Gosched()
all.Done()
}()
for i := 0; i < 4; i++ {
go func() {
var str string
wait.Done()
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
*v.(*string) = "zero"
return nil
})
if err != nil {
t.Error(err)
}
all.Done()
}()
}
for i := 0; i < 5; i++ {
go func() {
var str string
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
*v.(*string) = "zero"
return nil
})
if err != nil {
t.Error(err)
}
all.Done()
}()
}
all.Wait()
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func TestCachedConnQueryRow(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
const (
key = "user"
value = "any"
)
var conn trackedConn
var user string
var ran bool
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
ran = true
user = value
return nil
})
assert.Nil(t, err)
actualValue, err := s.Get(key)
assert.Nil(t, err)
var actual string
assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
assert.Equal(t, value, actual)
assert.Equal(t, value, user)
assert.True(t, ran)
}
func TestCachedConnQueryRowFromCache(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
const (
key = "user"
value = "any"
)
var conn trackedConn
var user string
var ran bool
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
assert.Nil(t, c.SetCache(key, value))
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
ran = true
user = value
return nil
})
assert.Nil(t, err)
actualValue, err := s.Get(key)
assert.Nil(t, err)
var actual string
assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
assert.Equal(t, value, actual)
assert.Equal(t, value, user)
assert.False(t, ran)
}
func TestQueryRowNotFound(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
const key = "user"
var conn trackedConn
var user string
var ran int
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
for i := 0; i < 20; i++ {
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
ran++
return sql.ErrNoRows
})
assert.Exactly(t, sqlx.ErrNotFound, err)
}
assert.Equal(t, 1, ran)
}
func TestCachedConnExec(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
_, err = c.ExecNoCache("delete from user_table where id='kevin'")
assert.Nil(t, err)
assert.True(t, conn.execValue)
}
func TestCachedConnExecDropCache(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
const (
key = "user"
value = "any"
)
var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
assert.Nil(t, c.SetCache(key, value))
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return conn.Exec("delete from user_table where id='kevin'")
}, key)
assert.Nil(t, err)
assert.True(t, conn.execValue)
_, err = s.Get(key)
assert.Exactly(t, miniredis.ErrKeyNotFound, err)
}
func TestCachedConnExecDropCacheFailed(t *testing.T) {
const key = "user"
var conn trackedConn
r := redis.NewRedis("anyredis:8888", redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
_, err := c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return conn.Exec("delete from user_table where id='kevin'")
}, key)
// async background clean, retry logic
assert.Nil(t, err)
}
func TestCachedConnQueryRows(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
var users []string
err = c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'")
assert.Nil(t, err)
assert.True(t, conn.queryRowsValue)
}
func TestCachedConnTransact(t *testing.T) {
s, err := miniredis.Run()
if err != nil {
t.Error(err)
}
var conn trackedConn
r := redis.NewRedis(s.Addr(), redis.NodeType)
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
err = c.Transact(func(session sqlx.Session) error {
return nil
})
assert.Nil(t, err)
assert.True(t, conn.transactValue)
}
func resetStats() {
atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0)
atomic.StoreUint64(&stats.Miss, 0)
atomic.StoreUint64(&stats.DbFails, 0)
}
type dummySqlConn struct {
}
func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
return nil, nil
}
func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) QueryRows(v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
return nil
}
func (d dummySqlConn) Transact(func(session sqlx.Session) error) error {
return nil
}
type trackedConn struct {
dummySqlConn
execValue bool
queryRowsValue bool
transactValue bool
}
func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
c.execValue = true
return c.dummySqlConn.Exec(query, args...)
}
func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
c.queryRowsValue = true
return c.dummySqlConn.QueryRows(v, query, args...)
}
func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
c.transactValue = true
return c.dummySqlConn.Transact(fn)
}

View File

@@ -0,0 +1,187 @@
package sqlx
import (
"database/sql"
"fmt"
"strings"
"time"
"zero/core/executors"
"zero/core/logx"
"zero/core/stringx"
)
const (
flushInterval = time.Second
maxBulkRows = 1000
valuesKeyword = "values"
)
var emptyBulkStmt bulkStmt
type (
ResultHandler func(sql.Result, error)
BulkInserter struct {
executor *executors.PeriodicalExecutor
inserter *dbInserter
stmt bulkStmt
}
bulkStmt struct {
prefix string
valueFormat string
suffix string
}
)
func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
bkStmt, err := parseInsertStmt(stmt)
if err != nil {
return nil, err
}
inserter := &dbInserter{
sqlConn: sqlConn,
stmt: bkStmt,
}
return &BulkInserter{
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
inserter: inserter,
stmt: bkStmt,
}, nil
}
func (bi *BulkInserter) Flush() {
bi.executor.Flush()
}
func (bi *BulkInserter) Insert(args ...interface{}) error {
value, err := format(bi.stmt.valueFormat, args...)
if err != nil {
return err
}
bi.executor.Add(value)
return nil
}
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
bi.executor.Sync(func() {
bi.inserter.resultHandler = handler
})
}
func (bi *BulkInserter) UpdateOrDelete(fn func()) {
bi.executor.Flush()
fn()
}
func (bi *BulkInserter) UpdateStmt(stmt string) error {
bkStmt, err := parseInsertStmt(stmt)
if err != nil {
return err
}
bi.executor.Flush()
bi.executor.Sync(func() {
bi.inserter.stmt = bkStmt
})
return nil
}
type dbInserter struct {
sqlConn SqlConn
stmt bulkStmt
values []string
resultHandler ResultHandler
}
func (in *dbInserter) AddTask(task interface{}) bool {
in.values = append(in.values, task.(string))
return len(in.values) >= maxBulkRows
}
func (in *dbInserter) Execute(bulk interface{}) {
values := bulk.([]string)
if len(values) == 0 {
return
}
stmtWithoutValues := in.stmt.prefix
valuesStr := strings.Join(values, ", ")
stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
if len(in.stmt.suffix) > 0 {
stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
}
result, err := in.sqlConn.Exec(stmt)
if in.resultHandler != nil {
in.resultHandler(result, err)
} else if err != nil {
logx.Errorf("sql: %s, error: %s", stmt, err)
}
}
func (in *dbInserter) RemoveAll() interface{} {
values := in.values
in.values = nil
return values
}
func parseInsertStmt(stmt string) (bulkStmt, error) {
lower := strings.ToLower(stmt)
pos := strings.Index(lower, valuesKeyword)
if pos <= 0 {
return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
}
var columns int
right := strings.LastIndexByte(lower[:pos], ')')
if right > 0 {
left := strings.LastIndexByte(lower[:right], '(')
if left > 0 {
values := lower[left+1 : right]
values = stringx.Filter(values, func(r rune) bool {
return r == ' ' || r == '\t' || r == '\r' || r == '\n'
})
fields := strings.FieldsFunc(values, func(r rune) bool {
return r == ','
})
columns = len(fields)
}
}
var variables int
var valueFormat string
var suffix string
left := strings.IndexByte(lower[pos:], '(')
if left > 0 {
right = strings.IndexByte(lower[pos+left:], ')')
if right > 0 {
values := lower[pos+left : pos+left+right]
for _, x := range values {
if x == '?' {
variables++
}
}
valueFormat = stmt[pos+left : pos+left+right+1]
suffix = strings.TrimSpace(stmt[pos+left+right+1:])
}
}
if variables == 0 {
return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
}
if columns > 0 && columns != variables {
return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
}
return bulkStmt{
prefix: stmt[:pos+len(valuesKeyword)],
valueFormat: valueFormat,
suffix: suffix,
}, nil
}

View File

@@ -0,0 +1,98 @@
package sqlx
import (
"database/sql"
"strconv"
"testing"
"zero/core/logx"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
type mockedConn struct {
query string
args []interface{}
}
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
c.query = query
c.args = args
return nil, nil
}
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
panic("should not called")
}
func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
panic("should not called")
}
func (c *mockedConn) Transact(func(session Session) error) error {
panic("should not called")
}
func TestBulkInserter(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
assert.Nil(t, err)
for i := 0; i < 5; i++ {
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
}
inserter.Flush()
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
conn.query)
assert.Nil(t, conn.args)
})
}
func TestBulkInserterSuffix(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
assert.Nil(t, err)
for i := 0; i < 5; i++ {
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
}
inserter.Flush()
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
conn.query)
assert.Nil(t, conn.args)
})
}
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

37
core/stores/sqlx/mysql.go Normal file
View File

@@ -0,0 +1,37 @@
package sqlx
import "github.com/go-sql-driver/mysql"
const (
mysqlDriverName = "mysql"
duplicateEntryCode uint16 = 1062
)
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
opts = append(opts, withMysqlAcceptable())
return NewSqlConn(mysqlDriverName, datasource, opts...)
}
func mysqlAcceptable(err error) bool {
if err == nil {
return true
}
myerr, ok := err.(*mysql.MySQLError)
if !ok {
return false
}
switch myerr.Number {
case duplicateEntryCode:
return true
default:
return false
}
}
func withMysqlAcceptable() SqlOption {
return func(conn *commonSqlConn) {
conn.accept = mysqlAcceptable
}
}

View File

@@ -0,0 +1,56 @@
package sqlx
import (
"testing"
"zero/core/breaker"
"zero/core/logx"
"zero/core/stat"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
)
func init() {
stat.SetReporter(nil)
}
func TestBreakerOnDuplicateEntry(t *testing.T) {
logx.Disable()
err := tryOnDuplicateEntryError(t, mysqlAcceptable)
assert.Equal(t, duplicateEntryCode, err.(*mysql.MySQLError).Number)
}
func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
logx.Disable()
var found bool
for i := 0; i < 100; i++ {
if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable {
found = true
}
}
assert.True(t, found)
}
func tryOnDuplicateEntryError(t *testing.T, accept func(error) bool) error {
logx.Disable()
conn := commonSqlConn{
brk: breaker.NewBreaker(),
accept: accept,
}
for i := 0; i < 1000; i++ {
assert.NotNil(t, conn.brk.DoWithAcceptable(func() error {
return &mysql.MySQLError{
Number: duplicateEntryCode,
}
}, conn.acceptable))
}
return conn.brk.DoWithAcceptable(func() error {
return &mysql.MySQLError{
Number: duplicateEntryCode,
}
}, conn.acceptable)
}

254
core/stores/sqlx/orm.go Normal file
View File

@@ -0,0 +1,254 @@
package sqlx
import (
"errors"
"reflect"
"strings"
"zero/core/mapping"
)
const tagName = "db"
var (
ErrNotMatchDestination = errors.New("not matching destination to scan")
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
ErrNotSettable = errors.New("passed in variable is not settable")
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
)
type rowsScanner interface {
Columns() ([]string, error)
Err() error
Next() bool
Scan(v ...interface{}) error
}
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
rt := mapping.Deref(v.Type())
size := rt.NumField()
result := make(map[string]interface{}, size)
for i := 0; i < size; i++ {
key := parseTagName(rt.Field(i))
if len(key) == 0 {
return nil, nil
}
valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
}
}
return result, nil
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
fields := unwrapFields(v)
if strict && len(columns) < len(fields) {
return nil, ErrNotMatchDestination
}
taggedMap, err := getTaggedFieldValueMap(v)
if err != nil {
return nil, err
}
values := make([]interface{}, len(columns))
if len(taggedMap) == 0 {
for i := 0; i < len(values); i++ {
valueField := fields[i]
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
}
}
} else {
for i, column := range columns {
if tagged, ok := taggedMap[column]; ok {
values[i] = tagged
} else {
var anonymous interface{}
values[i] = &anonymous
}
}
}
return values, nil
}
func parseTagName(field reflect.StructField) string {
key := field.Tag.Get(tagName)
if len(key) == 0 {
return ""
} else {
options := strings.Split(key, ",")
return options[0]
}
}
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
if !scanner.Next() {
if err := scanner.Err(); err != nil {
return err
}
return ErrNotFound
}
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rte := reflect.TypeOf(v).Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
if rve.CanSet() {
return scanner.Scan(v)
} else {
return ErrNotSettable
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
return err
} else {
return scanner.Scan(values...)
}
default:
return ErrUnsupportedValueType
}
}
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
return err
}
rt := reflect.TypeOf(v)
rte := rt.Elem()
rve := rv.Elem()
switch rte.Kind() {
case reflect.Slice:
if rve.CanSet() {
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
}
fillFn := func(value interface{}) error {
if rve.CanSet() {
if err := scanner.Scan(value); err != nil {
return err
} else {
appendFn(reflect.ValueOf(value))
return nil
}
}
return ErrNotSettable
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
return err
} else {
if err := scanner.Scan(values...); err != nil {
return err
} else {
appendFn(value)
}
}
}
default:
return ErrUnsupportedValueType
}
return nil
} else {
return ErrNotSettable
}
default:
return ErrUnsupportedValueType
}
}
func unwrapFields(v reflect.Value) []reflect.Value {
var fields []reflect.Value
indirect := reflect.Indirect(v)
for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i)
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))
}
child = reflect.Indirect(child)
childType := indirect.Type().Field(i)
if child.Kind() == reflect.Struct && childType.Anonymous {
fields = append(fields, unwrapFields(child)...)
} else {
fields = append(fields, child)
}
}
return fields
}

View File

@@ -0,0 +1,973 @@
package sqlx
import (
"database/sql"
"testing"
"zero/core/logx"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
func TestUnmarshalRowBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.True(t, value)
})
}
func TestUnmarshalRowInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, 2, value)
})
}
func TestUnmarshalRowInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int8(3), value)
})
}
func TestUnmarshalRowInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.Equal(t, int16(4), value)
})
}
func TestUnmarshalRowInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.Equal(t, int32(5), value)
})
}
func TestUnmarshalRowInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int64(6), value)
})
}
func TestUnmarshalRowUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint(2), value)
})
}
func TestUnmarshalRowUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint8(3), value)
})
}
func TestUnmarshalRowUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(4), value)
})
}
func TestUnmarshalRowUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint32(5), value)
})
}
func TestUnmarshalRowUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(6), value)
})
}
func TestUnmarshalRowFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float32(7), value)
})
}
func TestUnmarshalRowFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float64(8), value)
})
}
func TestUnmarshalRowString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
const expect = "hello"
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowStruct(t *testing.T) {
var value = new(struct {
Name string
Age int
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
var value = new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []bool{true, false}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []int64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []uint64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []float32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []float64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []string{"hello", "world"}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsBoolPtr(t *testing.T) {
yes := true
no := false
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*bool{&yes, &no}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsIntPtr(t *testing.T) {
two := 2
three := 3
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
two := int8(2)
three := int8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
two := int16(2)
three := int16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
two := int32(2)
three := int32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
two := int64(2)
three := int64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*int64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUintPtr(t *testing.T) {
two := uint(2)
three := uint(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
two := uint8(2)
three := uint8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
two := uint16(2)
three := uint16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
two := uint32(2)
three := uint32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
two := uint64(2)
three := uint64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*uint64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
two := float32(2)
three := float32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*float32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
two := float64(2)
three := float64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*float64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsStringPtr(t *testing.T) {
hello := "hello"
world := "world"
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var expect = []*string{&hello, &world}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
}
func TestUnmarshalRowsStruct(t *testing.T) {
var expect = []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
var expect = []struct {
Name string
NullString sql.NullString
}{
{
Name: "first",
NullString: sql.NullString{
String: "firstnullstring",
Valid: true,
},
},
{
Name: "second",
NullString: sql.NullString{
String: "",
Valid: false,
},
},
}
var value []struct {
Name string `db:"name"`
NullString sql.NullString `db:"value"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.NullString.String, value[i].NullString.String)
assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid)
}
})
}
func TestUnmarshalRowsStructWithTags(t *testing.T) {
var expect = []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Age int64 `db:"age"`
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
type Embed struct {
Value int64 `db:"value"`
}
var expect = []struct {
Name string
Age int64
Value int64
}{
{
Name: "first",
Age: 2,
Value: 3,
},
{
Name: "second",
Age: 3,
Value: 4,
},
}
var value []struct {
Name string `db:"name"`
Age int64 `db:"age"`
Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
assert.Equal(t, each.Value, value[i].Value)
}
})
}
func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) {
type Embed struct {
Value int64 `db:"value"`
}
var expect = []struct {
Name string
Age int64
Value int64
}{
{
Name: "first",
Age: 2,
Value: 3,
},
{
Name: "second",
Age: 3,
Value: 4,
},
}
var value []struct {
Name string `db:"name"`
Age int64 `db:"age"`
*Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
assert.Equal(t, each.Value, value[i].Value)
}
})
}
func TestUnmarshalRowsStructPtr(t *testing.T) {
var expect = []*struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []*struct {
Name string
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
var expect = []*struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []*struct {
Age int64 `db:"age"`
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, value[i].Age)
}
})
}
func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
var expect = []*struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []*struct {
Age *int64 `db:"age"`
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
for i, each := range expect {
assert.Equal(t, each.Name, value[i].Name)
assert.Equal(t, each.Age, *value[i].Age)
}
})
}
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var r struct {
User string `db:"user"`
Age int `db:"age"`
}
assert.Nil(t, query(db, func(rows *sql.Rows) error {
return unmarshalRow(&r, rows, false)
}, "select age from users where user=?", "anyone"))
assert.Empty(t, r.User)
assert.Equal(t, 5, r.Age)
})
}
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

204
core/stores/sqlx/sqlconn.go Normal file
View File

@@ -0,0 +1,204 @@
package sqlx
import (
"database/sql"
"zero/core/breaker"
)
var ErrNotFound = sql.ErrNoRows
type (
// Session stands for raw connections or transaction sessions
Session interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (StmtSession, error)
QueryRow(v interface{}, query string, args ...interface{}) error
QueryRowPartial(v interface{}, query string, args ...interface{}) error
QueryRows(v interface{}, query string, args ...interface{}) error
QueryRowsPartial(v interface{}, query string, args ...interface{}) error
}
// SqlConn only stands for raw connections, so Transact method can be called.
SqlConn interface {
Session
Transact(func(session Session) error) error
}
SqlOption func(*commonSqlConn)
StmtSession interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
QueryRow(v interface{}, args ...interface{}) error
QueryRowPartial(v interface{}, args ...interface{}) error
QueryRows(v interface{}, args ...interface{}) error
QueryRowsPartial(v interface{}, args ...interface{}) error
}
// thread-safe
// Because CORBA doesn't support PREPARE, so we need to combine the
// query arguments into one string and do underlying query without arguments
commonSqlConn struct {
driverName string
datasource string
beginTx beginnable
brk breaker.Breaker
accept func(error) bool
}
sessionConn interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
}
statement struct {
stmt *sql.Stmt
}
stmtConn interface {
Exec(args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
}
)
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
conn := &commonSqlConn{
driverName: driverName,
datasource: datasource,
beginTx: begin,
brk: breaker.NewBreaker(),
}
for _, opt := range opts {
opt(conn)
}
return conn
}
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
result, err = exec(conn, q, args...)
return err
}, db.acceptable)
return
}
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
if st, err := conn.Prepare(query); err != nil {
return err
} else {
stmt = statement{
stmt: st,
}
return nil
}
}, db.acceptable)
return
}
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
return db.brk.DoWithAcceptable(func() error {
return transact(db, db.beginTx, fn)
}, db.acceptable)
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
if db.accept == nil {
return ok
} else {
return ok || db.accept(err)
}
}
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
var qerr error
return db.brk.DoWithAcceptable(func() error {
conn, err := getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
return query(conn, func(rows *sql.Rows) error {
qerr = scanner(rows)
return qerr
}, q, args...)
}, func(err error) bool {
return qerr == err || db.acceptable(err)
})
}
func (s statement) Close() error {
return s.stmt.Close()
}
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, args...)
}
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, args...)
}
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, args...)
}
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, args...)
}
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, args...)
}

View File

@@ -0,0 +1,74 @@
package sqlx
import (
"database/sql"
"io"
"sync"
"time"
"zero/core/syncx"
)
const (
maxIdleConns = 64
maxOpenConns = 64
maxLifetime = time.Minute
)
var connManager = syncx.NewResourceManager()
type pingedDB struct {
*sql.DB
once sync.Once
}
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
val, err := connManager.GetResource(server, func() (io.Closer, error) {
conn, err := newDBConnection(driverName, server)
if err != nil {
return nil, err
}
return &pingedDB{
DB: conn,
}, nil
})
if err != nil {
return nil, err
}
return val.(*pingedDB), nil
}
func getSqlConn(driverName, server string) (*sql.DB, error) {
pdb, err := getCachedSqlConn(driverName, server)
if err != nil {
return nil, err
}
pdb.once.Do(func() {
err = pdb.Ping()
})
if err != nil {
return nil, err
}
return pdb.DB, nil
}
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
conn, err := sql.Open(driverName, datasource)
if err != nil {
return nil, err
}
// we need to do this until the issue https://github.com/golang/go/issues/9851 get fixed
// discussed here https://github.com/go-sql-driver/mysql/issues/257
// if the discussed SetMaxIdleTimeout methods added, we'll change this behavior
// 8 means we can't have more than 8 goroutines to concurrently access the same database.
conn.SetMaxIdleConns(maxIdleConns)
conn.SetMaxOpenConns(maxOpenConns)
conn.SetConnMaxLifetime(maxLifetime)
return conn, nil
}

92
core/stores/sqlx/stmt.go Normal file
View File

@@ -0,0 +1,92 @@
package sqlx
import (
"database/sql"
"fmt"
"time"
"zero/core/logx"
"zero/core/timex"
)
const slowThreshold = time.Millisecond * 500
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
result, err := conn.Exec(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
}
return result, err
}
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
stmt := fmt.Sprint(args...)
startTime := timex.Now()
rows, err := conn.Query(args...)
duration := timex.Since(startTime)
if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
return err
}
defer rows.Close()
return scanner(rows)
}

103
core/stores/sqlx/tx.go Normal file
View File

@@ -0,0 +1,103 @@
package sqlx
import (
"database/sql"
"fmt"
)
type (
beginnable func(*sql.DB) (trans, error)
trans interface {
Session
Commit() error
Rollback() error
}
txSession struct {
*sql.Tx
}
)
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
return exec(t.Tx, q, args...)
}
func (t txSession) Prepare(q string) (StmtSession, error) {
if stmt, err := t.Tx.Prepare(q); err != nil {
return nil, err
} else {
return statement{
stmt: stmt,
}, nil
}
}
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
func begin(db *sql.DB) (trans, error) {
if tx, err := db.Begin(); err != nil {
return nil, err
} else {
return txSession{
Tx: tx,
}, nil
}
}
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
conn, err := getSqlConn(db.driverName, db.datasource)
if err != nil {
logInstanceError(db.datasource, err)
return err
}
return transactOnConn(conn, b, fn)
}
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
var tx trans
tx, err = b(conn)
if err != nil {
return
}
defer func() {
if p := recover(); p != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
} else {
err = fmt.Errorf("recoveer from %#v", p)
}
} else if err != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
}
} else {
err = tx.Commit()
}
}()
return fn(tx)
}

View File

@@ -0,0 +1,76 @@
package sqlx
import (
"database/sql"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
const (
mockCommit = 1
mockRollback = 2
)
type mockTx struct {
status int
}
func (mt *mockTx) Commit() error {
mt.status |= mockCommit
return nil
}
func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) Rollback() error {
mt.status |= mockRollback
return nil
}
func beginMock(mock *mockTx) beginnable {
return func(*sql.DB) (trans, error) {
return mock, nil
}
}
func TestTransactCommit(t *testing.T) {
mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error {
return nil
})
assert.Equal(t, mockCommit, mock.status)
assert.Nil(t, err)
}
func TestTransactRollback(t *testing.T) {
mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error {
return errors.New("rollback")
})
assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err)
}

101
core/stores/sqlx/utils.go Normal file
View File

@@ -0,0 +1,101 @@
package sqlx
import (
"fmt"
"strings"
"zero/core/logx"
"zero/core/mapping"
)
func desensitize(datasource string) string {
// remove account
pos := strings.LastIndex(datasource, "@")
if 0 <= pos && pos+1 < len(datasource) {
datasource = datasource[pos+1:]
}
return datasource
}
func escape(input string) string {
var b strings.Builder
for _, ch := range input {
switch ch {
case '\x00':
b.WriteString(`\x00`)
case '\r':
b.WriteString(`\r`)
case '\n':
b.WriteString(`\n`)
case '\\':
b.WriteString(`\\`)
case '\'':
b.WriteString(`\'`)
case '"':
b.WriteString(`\"`)
case '\x1a':
b.WriteString(`\x1a`)
default:
b.WriteRune(ch)
}
}
return b.String()
}
func format(query string, args ...interface{}) (string, error) {
numArgs := len(args)
if numArgs == 0 {
return query, nil
}
var b strings.Builder
argIndex := 0
for _, ch := range query {
if ch == '?' {
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
}
arg := args[argIndex]
argIndex++
switch v := arg.(type) {
case bool:
if v {
b.WriteByte('1')
} else {
b.WriteByte('0')
}
case string:
b.WriteByte('\'')
b.WriteString(escape(v))
b.WriteByte('\'')
default:
b.WriteString(mapping.Repr(v))
}
} else {
b.WriteRune(ch)
}
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
}
return b.String(), nil
}
func logInstanceError(datasource string, err error) {
datasource = desensitize(datasource)
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
}
func logSqlError(stmt string, err error) {
if err != nil && err != ErrNotFound {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
}
}

View File

@@ -0,0 +1,30 @@
package sqlx
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestEscape(t *testing.T) {
s := "a\x00\n\r\\'\"\x1ab"
out := escape(s)
assert.Equal(t, `a\x00\n\r\\\'\"\x1ab`, out)
}
func TestDesensitize(t *testing.T) {
datasource := "user:pass@tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
datasource = desensitize(datasource)
assert.False(t, strings.Contains(datasource, "user"))
assert.False(t, strings.Contains(datasource, "pass"))
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
}
func TestDesensitize_WithoutAccount(t *testing.T) {
datasource := "tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
datasource = desensitize(datasource)
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
}