mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-13 09:50:00 +08:00
initial import
This commit is contained in:
5
core/stores/cache/cacheconf.go
vendored
Normal file
5
core/stores/cache/cacheconf.go
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
package cache
|
||||
|
||||
import "zero/core/stores/internal"
|
||||
|
||||
type CacheConf = internal.ClusterConf
|
||||
21
core/stores/cache/cacheopt.go
vendored
Normal file
21
core/stores/cache/cacheopt.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"zero/core/stores/internal"
|
||||
)
|
||||
|
||||
type Option = internal.Option
|
||||
|
||||
func WithExpiry(expiry time.Duration) Option {
|
||||
return func(o *internal.Options) {
|
||||
o.Expiry = expiry
|
||||
}
|
||||
}
|
||||
|
||||
func WithNotFoundExpiry(expiry time.Duration) Option {
|
||||
return func(o *internal.Options) {
|
||||
o.NotFoundExpiry = expiry
|
||||
}
|
||||
}
|
||||
13
core/stores/clickhouse/clickhouse.go
Normal file
13
core/stores/clickhouse/clickhouse.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"zero/core/stores/sqlx"
|
||||
|
||||
_ "github.com/kshvakov/clickhouse"
|
||||
)
|
||||
|
||||
const clickHouseDriverName = "clickhouse"
|
||||
|
||||
func New(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
|
||||
return sqlx.NewSqlConn(clickHouseDriverName, datasource, opts...)
|
||||
}
|
||||
129
core/stores/internal/cache.go
Normal file
129
core/stores/internal/cache.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"zero/core/errorx"
|
||||
"zero/core/hash"
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
type (
|
||||
Cache interface {
|
||||
DelCache(keys ...string) error
|
||||
GetCache(key string, v interface{}) error
|
||||
SetCache(key string, v interface{}) error
|
||||
SetCacheWithExpire(key string, v interface{}, expire time.Duration) error
|
||||
Take(v interface{}, key string, query func(v interface{}) error) error
|
||||
TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error
|
||||
}
|
||||
|
||||
cacheCluster struct {
|
||||
dispatcher *hash.ConsistentHash
|
||||
errNotFound error
|
||||
}
|
||||
)
|
||||
|
||||
func NewCache(c ClusterConf, barrier syncx.SharedCalls, st *CacheStat, errNotFound error,
|
||||
opts ...Option) Cache {
|
||||
if len(c) == 0 || TotalWeights(c) <= 0 {
|
||||
log.Fatal("no cache nodes")
|
||||
}
|
||||
|
||||
if len(c) == 1 {
|
||||
return NewCacheNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
|
||||
}
|
||||
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := NewCacheNode(node.NewRedis(), barrier, st, errNotFound, opts...)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
return cacheCluster{
|
||||
dispatcher: dispatcher,
|
||||
errNotFound: errNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc cacheCluster) DelCache(keys ...string) error {
|
||||
switch len(keys) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
key := keys[0]
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
return cc.errNotFound
|
||||
}
|
||||
|
||||
return c.(Cache).DelCache(key)
|
||||
default:
|
||||
var be errorx.BatchError
|
||||
nodes := make(map[interface{}][]string)
|
||||
for _, key := range keys {
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
be.Add(fmt.Errorf("key %q not found", key))
|
||||
continue
|
||||
}
|
||||
|
||||
nodes[c] = append(nodes[c], key)
|
||||
}
|
||||
for c, ks := range nodes {
|
||||
if err := c.(Cache).DelCache(ks...); err != nil {
|
||||
be.Add(err)
|
||||
}
|
||||
}
|
||||
|
||||
return be.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (cc cacheCluster) GetCache(key string, v interface{}) error {
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
return cc.errNotFound
|
||||
}
|
||||
|
||||
return c.(Cache).GetCache(key, v)
|
||||
}
|
||||
|
||||
func (cc cacheCluster) SetCache(key string, v interface{}) error {
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
return cc.errNotFound
|
||||
}
|
||||
|
||||
return c.(Cache).SetCache(key, v)
|
||||
}
|
||||
|
||||
func (cc cacheCluster) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
return cc.errNotFound
|
||||
}
|
||||
|
||||
return c.(Cache).SetCacheWithExpire(key, v, expire)
|
||||
}
|
||||
|
||||
func (cc cacheCluster) Take(v interface{}, key string, query func(v interface{}) error) error {
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
return cc.errNotFound
|
||||
}
|
||||
|
||||
return c.(Cache).Take(v, key, query)
|
||||
}
|
||||
|
||||
func (cc cacheCluster) TakeWithExpire(v interface{}, key string,
|
||||
query func(v interface{}, expire time.Duration) error) error {
|
||||
c, ok := cc.dispatcher.Get(key)
|
||||
if !ok {
|
||||
return cc.errNotFound
|
||||
}
|
||||
|
||||
return c.(Cache).TakeWithExpire(v, key, query)
|
||||
}
|
||||
201
core/stores/internal/cache_test.go
Normal file
201
core/stores/internal/cache_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/errorx"
|
||||
"zero/core/hash"
|
||||
"zero/core/stores/redis"
|
||||
"zero/core/syncx"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockedNode struct {
|
||||
vals map[string][]byte
|
||||
errNotFound error
|
||||
}
|
||||
|
||||
func (mc *mockedNode) DelCache(keys ...string) error {
|
||||
var be errorx.BatchError
|
||||
for _, key := range keys {
|
||||
if _, ok := mc.vals[key]; !ok {
|
||||
be.Add(mc.errNotFound)
|
||||
} else {
|
||||
delete(mc.vals, key)
|
||||
}
|
||||
}
|
||||
return be.Err()
|
||||
}
|
||||
|
||||
func (mc *mockedNode) GetCache(key string, v interface{}) error {
|
||||
bs, ok := mc.vals[key]
|
||||
if ok {
|
||||
return json.Unmarshal(bs, v)
|
||||
}
|
||||
|
||||
return mc.errNotFound
|
||||
}
|
||||
|
||||
func (mc *mockedNode) SetCache(key string, v interface{}) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mc.vals[key] = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mockedNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
|
||||
return mc.SetCache(key, v)
|
||||
}
|
||||
|
||||
func (mc *mockedNode) Take(v interface{}, key string, query func(v interface{}) error) error {
|
||||
if _, ok := mc.vals[key]; ok {
|
||||
return mc.GetCache(key, v)
|
||||
}
|
||||
|
||||
if err := query(v); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mc.SetCache(key, v)
|
||||
}
|
||||
|
||||
func (mc *mockedNode) TakeWithExpire(v interface{}, key string, query func(v interface{}, expire time.Duration) error) error {
|
||||
return mc.Take(v, key, func(v interface{}) error {
|
||||
return query(v, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_SetDel(t *testing.T) {
|
||||
const total = 1000
|
||||
r1 := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r1.Start())
|
||||
defer r1.Close()
|
||||
r2 := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r2.Start())
|
||||
defer r2.Close()
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := NewCache(conf, syncx.NewSharedCalls(), NewCacheStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.SetCache(fmt.Sprintf("key/%d", i), i))
|
||||
} else {
|
||||
assert.Nil(t, c.SetCacheWithExpire(fmt.Sprintf("key/%d", i), i, 0))
|
||||
}
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var v int
|
||||
assert.Nil(t, c.GetCache(fmt.Sprintf("key/%d", i), &v))
|
||||
assert.Equal(t, i, v)
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, c.DelCache(fmt.Sprintf("key/%d", i)))
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var v int
|
||||
assert.Equal(t, errPlaceholder, c.GetCache(fmt.Sprintf("key/%d", i), &v))
|
||||
assert.Equal(t, 0, v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_Balance(t *testing.T) {
|
||||
const (
|
||||
numNodes = 100
|
||||
total = 10000
|
||||
)
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
maps := make([]map[string][]byte, numNodes)
|
||||
for i := 0; i < numNodes; i++ {
|
||||
maps[i] = map[string][]byte{
|
||||
strconv.Itoa(i): []byte(strconv.Itoa(i)),
|
||||
}
|
||||
}
|
||||
for i := 0; i < numNodes; i++ {
|
||||
dispatcher.AddWithWeight(&mockedNode{
|
||||
vals: maps[i],
|
||||
errNotFound: errPlaceholder,
|
||||
}, 100)
|
||||
}
|
||||
|
||||
c := cacheCluster{
|
||||
dispatcher: dispatcher,
|
||||
errNotFound: errPlaceholder,
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, c.SetCache(strconv.Itoa(i), i))
|
||||
}
|
||||
|
||||
counts := make(map[int]int)
|
||||
for i, m := range maps {
|
||||
counts[i] = len(m)
|
||||
}
|
||||
entropy := calcEntropy(counts, total)
|
||||
assert.True(t, len(counts) > 1)
|
||||
assert.True(t, entropy > .95, fmt.Sprintf("entropy should be greater than 0.95, but got %.2f", entropy))
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
var v int
|
||||
assert.Nil(t, c.GetCache(strconv.Itoa(i), &v))
|
||||
assert.Equal(t, i, v)
|
||||
}
|
||||
|
||||
for i := 0; i < total/10; i++ {
|
||||
assert.Nil(t, c.DelCache(strconv.Itoa(i*10), strconv.Itoa(i*10+1), strconv.Itoa(i*10+2)))
|
||||
assert.Nil(t, c.DelCache(strconv.Itoa(i*10+9)))
|
||||
}
|
||||
|
||||
var count int
|
||||
for i := 0; i < total/10; i++ {
|
||||
var val int
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Take(&val, strconv.Itoa(i*10), func(v interface{}) error {
|
||||
*v.(*int) = i
|
||||
count++
|
||||
return nil
|
||||
}))
|
||||
} else {
|
||||
assert.Nil(t, c.TakeWithExpire(&val, strconv.Itoa(i*10), func(v interface{}, expire time.Duration) error {
|
||||
*v.(*int) = i
|
||||
count++
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
assert.Equal(t, i, val)
|
||||
}
|
||||
assert.Equal(t, total/10, count)
|
||||
}
|
||||
|
||||
func calcEntropy(m map[int]int, total int) float64 {
|
||||
var entropy float64
|
||||
|
||||
for _, v := range m {
|
||||
proba := float64(v) / float64(total)
|
||||
entropy -= proba * math.Log2(proba)
|
||||
}
|
||||
|
||||
return entropy / math.Log2(float64(len(m)))
|
||||
}
|
||||
208
core/stores/internal/cachenode.go
Normal file
208
core/stores/internal/cachenode.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mathx"
|
||||
"zero/core/stat"
|
||||
"zero/core/stores/redis"
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
const (
|
||||
notFoundPlaceholder = "*"
|
||||
// make the expiry unstable to avoid lots of cached items expire at the same time
|
||||
// make the unstable expiry to be [0.95, 1.05] * seconds
|
||||
expiryDeviation = 0.05
|
||||
)
|
||||
|
||||
// indicates there is no such value associate with the key
|
||||
var errPlaceholder = errors.New("placeholder")
|
||||
|
||||
type cacheNode struct {
|
||||
rds *redis.Redis
|
||||
expiry time.Duration
|
||||
notFoundExpiry time.Duration
|
||||
barrier syncx.SharedCalls
|
||||
r *rand.Rand
|
||||
lock *sync.Mutex
|
||||
unstableExpiry mathx.Unstable
|
||||
stat *CacheStat
|
||||
errNotFound error
|
||||
}
|
||||
|
||||
func NewCacheNode(rds *redis.Redis, barrier syncx.SharedCalls, st *CacheStat,
|
||||
errNotFound error, opts ...Option) Cache {
|
||||
o := newOptions(opts...)
|
||||
return cacheNode{
|
||||
rds: rds,
|
||||
expiry: o.Expiry,
|
||||
notFoundExpiry: o.NotFoundExpiry,
|
||||
barrier: barrier,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: st,
|
||||
errNotFound: errNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
func (c cacheNode) DelCache(keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := c.rds.Del(keys...); err != nil {
|
||||
logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
|
||||
c.asyncRetryDelCache(keys...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c cacheNode) GetCache(key string, v interface{}) error {
|
||||
if err := c.doGetCache(key, v); err == errPlaceholder {
|
||||
return c.errNotFound
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (c cacheNode) SetCache(key string, v interface{}) error {
|
||||
return c.SetCacheWithExpire(key, v, c.aroundDuration(c.expiry))
|
||||
}
|
||||
|
||||
func (c cacheNode) SetCacheWithExpire(key string, v interface{}, expire time.Duration) error {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.rds.Setex(key, string(data), int(expire.Seconds()))
|
||||
}
|
||||
|
||||
func (c cacheNode) String() string {
|
||||
return c.rds.Addr
|
||||
}
|
||||
|
||||
func (c cacheNode) Take(v interface{}, key string, query func(v interface{}) error) error {
|
||||
return c.doTake(v, key, query, func(v interface{}) error {
|
||||
return c.SetCache(key, v)
|
||||
})
|
||||
}
|
||||
|
||||
func (c cacheNode) TakeWithExpire(v interface{}, key string,
|
||||
query func(v interface{}, expire time.Duration) error) error {
|
||||
expire := c.aroundDuration(c.expiry)
|
||||
return c.doTake(v, key, func(v interface{}) error {
|
||||
return query(v, expire)
|
||||
}, func(v interface{}) error {
|
||||
return c.SetCacheWithExpire(key, v, expire)
|
||||
})
|
||||
}
|
||||
|
||||
func (c cacheNode) aroundDuration(duration time.Duration) time.Duration {
|
||||
return c.unstableExpiry.AroundDuration(duration)
|
||||
}
|
||||
|
||||
func (c cacheNode) asyncRetryDelCache(keys ...string) {
|
||||
AddCleanTask(func() error {
|
||||
_, err := c.rds.Del(keys...)
|
||||
return err
|
||||
}, keys...)
|
||||
}
|
||||
|
||||
func (c cacheNode) doGetCache(key string, v interface{}) error {
|
||||
c.stat.IncrementTotal()
|
||||
data, err := c.rds.Get(key)
|
||||
if err != nil {
|
||||
c.stat.IncrementMiss()
|
||||
return err
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
c.stat.IncrementMiss()
|
||||
return c.errNotFound
|
||||
}
|
||||
|
||||
c.stat.IncrementHit()
|
||||
if data == notFoundPlaceholder {
|
||||
return errPlaceholder
|
||||
}
|
||||
|
||||
return c.processCache(key, data, v)
|
||||
}
|
||||
|
||||
func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) error,
|
||||
cacheVal func(v interface{}) error) error {
|
||||
val, fresh, err := c.barrier.DoEx(key, func() (interface{}, error) {
|
||||
if err := c.doGetCache(key, v); err != nil {
|
||||
if err == errPlaceholder {
|
||||
return nil, c.errNotFound
|
||||
} else if err != c.errNotFound {
|
||||
// why we just return the error instead of query from db,
|
||||
// because we don't allow the disaster pass to the dbs.
|
||||
// fail fast, in case we bring down the dbs.
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = query(v); err == c.errNotFound {
|
||||
if err = c.setCacheWithNotFound(key); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
|
||||
return nil, c.errNotFound
|
||||
} else if err != nil {
|
||||
c.stat.IncrementDbFails()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = cacheVal(v); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(v)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fresh {
|
||||
return nil
|
||||
} else {
|
||||
// got the result from previous ongoing query
|
||||
c.stat.IncrementTotal()
|
||||
c.stat.IncrementHit()
|
||||
}
|
||||
|
||||
return json.Unmarshal(val.([]byte), v)
|
||||
}
|
||||
|
||||
func (c cacheNode) processCache(key string, data string, v interface{}) error {
|
||||
err := json.Unmarshal([]byte(data), v)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
report := fmt.Sprintf("unmarshal cache, node: %s, key: %s, value: %s, error: %v",
|
||||
c.rds.Addr, key, data, err)
|
||||
logx.Error(report)
|
||||
stat.Report(report)
|
||||
if _, e := c.rds.Del(key); e != nil {
|
||||
logx.Errorf("delete invalid cache, node: %s, key: %s, value: %s, error: %v",
|
||||
c.rds.Addr, key, data, e)
|
||||
}
|
||||
|
||||
// returns errNotFound to reload the value by the given queryFn
|
||||
return c.errNotFound
|
||||
}
|
||||
|
||||
func (c cacheNode) setCacheWithNotFound(key string) error {
|
||||
return c.rds.Setex(key, notFoundPlaceholder, int(c.aroundDuration(c.notFoundExpiry).Seconds()))
|
||||
}
|
||||
66
core/stores/internal/cachenode_test.go
Normal file
66
core/stores/internal/cachenode_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mathx"
|
||||
"zero/core/stat"
|
||||
"zero/core/stores/redis"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCache(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer s.Close()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: redis.NewRedis(s.Addr(), redis.NodeType),
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewCacheStat("any"),
|
||||
errNotFound: errors.New("any"),
|
||||
}
|
||||
assert.Nil(t, cn.DelCache())
|
||||
assert.Nil(t, cn.DelCache([]string{}...))
|
||||
assert.Nil(t, cn.DelCache(make([]string, 0)...))
|
||||
cn.SetCache("first", "one")
|
||||
assert.Nil(t, cn.DelCache("first"))
|
||||
cn.SetCache("first", "one")
|
||||
cn.SetCache("second", "two")
|
||||
assert.Nil(t, cn.DelCache("first", "second"))
|
||||
}
|
||||
|
||||
func TestCacheNode_InvalidCache(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer s.Close()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: redis.NewRedis(s.Addr(), redis.NodeType),
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewCacheStat("any"),
|
||||
errNotFound: errors.New("any"),
|
||||
}
|
||||
s.Set("any", "value")
|
||||
var str string
|
||||
assert.NotNil(t, cn.GetCache("any", &str))
|
||||
assert.Equal(t, "", str)
|
||||
_, err = s.Get("any")
|
||||
assert.Equal(t, miniredis.ErrKeyNotFound, err)
|
||||
}
|
||||
33
core/stores/internal/cacheopt.go
Normal file
33
core/stores/internal/cacheopt.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultExpiry = time.Hour * 24 * 7
|
||||
defaultNotFoundExpiry = time.Minute
|
||||
)
|
||||
|
||||
type (
|
||||
Options struct {
|
||||
Expiry time.Duration
|
||||
NotFoundExpiry time.Duration
|
||||
}
|
||||
|
||||
Option func(o *Options)
|
||||
)
|
||||
|
||||
func newOptions(opts ...Option) Options {
|
||||
var o Options
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
|
||||
if o.Expiry <= 0 {
|
||||
o.Expiry = defaultExpiry
|
||||
}
|
||||
if o.NotFoundExpiry <= 0 {
|
||||
o.NotFoundExpiry = defaultNotFoundExpiry
|
||||
}
|
||||
|
||||
return o
|
||||
}
|
||||
67
core/stores/internal/cachestat.go
Normal file
67
core/stores/internal/cachestat.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
)
|
||||
|
||||
const statInterval = time.Minute
|
||||
|
||||
type CacheStat struct {
|
||||
name string
|
||||
// export the fields to let the unit tests working,
|
||||
// reside in internal package, doesn't matter.
|
||||
Total uint64
|
||||
Hit uint64
|
||||
Miss uint64
|
||||
DbFails uint64
|
||||
}
|
||||
|
||||
func NewCacheStat(name string) *CacheStat {
|
||||
ret := &CacheStat{
|
||||
name: name,
|
||||
}
|
||||
go ret.statLoop()
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (cs *CacheStat) IncrementTotal() {
|
||||
atomic.AddUint64(&cs.Total, 1)
|
||||
}
|
||||
|
||||
func (cs *CacheStat) IncrementHit() {
|
||||
atomic.AddUint64(&cs.Hit, 1)
|
||||
}
|
||||
|
||||
func (cs *CacheStat) IncrementMiss() {
|
||||
atomic.AddUint64(&cs.Miss, 1)
|
||||
}
|
||||
|
||||
func (cs *CacheStat) IncrementDbFails() {
|
||||
atomic.AddUint64(&cs.DbFails, 1)
|
||||
}
|
||||
|
||||
func (cs *CacheStat) statLoop() {
|
||||
ticker := time.NewTicker(statInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
total := atomic.SwapUint64(&cs.Total, 0)
|
||||
if total == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hit := atomic.SwapUint64(&cs.Hit, 0)
|
||||
percent := 100 * float32(hit) / float32(total)
|
||||
miss := atomic.SwapUint64(&cs.Miss, 0)
|
||||
dbf := atomic.SwapUint64(&cs.DbFails, 0)
|
||||
logx.Statf("dbcache(%s) - qpm: %d, hit_ratio: %.1f%%, hit: %d, miss: %d, db_fails: %d",
|
||||
cs.name, total, percent, hit, miss, dbf)
|
||||
}
|
||||
}
|
||||
}
|
||||
85
core/stores/internal/cleaner.go
Normal file
85
core/stores/internal/cleaner.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"zero/core/collection"
|
||||
"zero/core/lang"
|
||||
"zero/core/logx"
|
||||
"zero/core/proc"
|
||||
"zero/core/stat"
|
||||
"zero/core/stringx"
|
||||
"zero/core/threading"
|
||||
)
|
||||
|
||||
const (
|
||||
timingWheelSlots = 300
|
||||
cleanWorkers = 5
|
||||
taskKeyLen = 8
|
||||
)
|
||||
|
||||
var (
|
||||
timingWheel *collection.TimingWheel
|
||||
taskRunner = threading.NewTaskRunner(cleanWorkers)
|
||||
)
|
||||
|
||||
type delayTask struct {
|
||||
delay time.Duration
|
||||
task func() error
|
||||
keys []string
|
||||
}
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
timingWheel, err = collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
|
||||
lang.Must(err)
|
||||
|
||||
proc.AddShutdownListener(func() {
|
||||
timingWheel.Drain(clean)
|
||||
})
|
||||
}
|
||||
|
||||
func AddCleanTask(task func() error, keys ...string) {
|
||||
timingWheel.SetTimer(stringx.Randn(taskKeyLen), delayTask{
|
||||
delay: time.Second,
|
||||
task: task,
|
||||
keys: keys,
|
||||
}, time.Second)
|
||||
}
|
||||
|
||||
func clean(key, value interface{}) {
|
||||
taskRunner.Schedule(func() {
|
||||
dt := value.(delayTask)
|
||||
err := dt.task()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
next, ok := nextDelay(dt.delay)
|
||||
if ok {
|
||||
dt.delay = next
|
||||
timingWheel.SetTimer(key, dt, next)
|
||||
} else {
|
||||
msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v",
|
||||
formatKeys(dt.keys), err)
|
||||
logx.Error(msg)
|
||||
stat.Report(msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func nextDelay(delay time.Duration) (time.Duration, bool) {
|
||||
switch delay {
|
||||
case time.Second:
|
||||
return time.Second * 5, true
|
||||
case time.Second * 5:
|
||||
return time.Minute, true
|
||||
case time.Minute:
|
||||
return time.Minute * 5, true
|
||||
case time.Minute * 5:
|
||||
return time.Hour, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
12
core/stores/internal/config.go
Normal file
12
core/stores/internal/config.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package internal
|
||||
|
||||
import "zero/core/stores/redis"
|
||||
|
||||
type (
|
||||
ClusterConf []NodeConf
|
||||
|
||||
NodeConf struct {
|
||||
redis.RedisConf
|
||||
Weight int `json:",default=100"`
|
||||
}
|
||||
)
|
||||
22
core/stores/internal/util.go
Normal file
22
core/stores/internal/util.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package internal
|
||||
|
||||
import "strings"
|
||||
|
||||
const keySeparator = ","
|
||||
|
||||
func TotalWeights(c []NodeConf) int {
|
||||
var weights int
|
||||
|
||||
for _, node := range c {
|
||||
if node.Weight < 0 {
|
||||
node.Weight = 0
|
||||
}
|
||||
weights += node.Weight
|
||||
}
|
||||
|
||||
return weights
|
||||
}
|
||||
|
||||
func formatKeys(keys []string) string {
|
||||
return strings.Join(keys, keySeparator)
|
||||
}
|
||||
5
core/stores/kv/config.go
Normal file
5
core/stores/kv/config.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package kv
|
||||
|
||||
import "zero/core/stores/internal"
|
||||
|
||||
type KvConf = internal.ClusterConf
|
||||
653
core/stores/kv/store.go
Normal file
653
core/stores/kv/store.go
Normal file
@@ -0,0 +1,653 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"zero/core/errorx"
|
||||
"zero/core/hash"
|
||||
"zero/core/stores/internal"
|
||||
"zero/core/stores/redis"
|
||||
)
|
||||
|
||||
var ErrNoRedisNode = errors.New("no redis node")
|
||||
|
||||
type (
|
||||
Store interface {
|
||||
Del(keys ...string) (int, error)
|
||||
Eval(script string, key string, args ...interface{}) (interface{}, error)
|
||||
Exists(key string) (bool, error)
|
||||
Expire(key string, seconds int) error
|
||||
Expireat(key string, expireTime int64) error
|
||||
Get(key string) (string, error)
|
||||
Hdel(key, field string) (bool, error)
|
||||
Hexists(key, field string) (bool, error)
|
||||
Hget(key, field string) (string, error)
|
||||
Hgetall(key string) (map[string]string, error)
|
||||
Hincrby(key, field string, increment int) (int, error)
|
||||
Hkeys(key string) ([]string, error)
|
||||
Hlen(key string) (int, error)
|
||||
Hmget(key string, fields ...string) ([]string, error)
|
||||
Hset(key, field, value string) error
|
||||
Hsetnx(key, field, value string) (bool, error)
|
||||
Hmset(key string, fieldsAndValues map[string]string) error
|
||||
Hvals(key string) ([]string, error)
|
||||
Incr(key string) (int64, error)
|
||||
Incrby(key string, increment int64) (int64, error)
|
||||
Llen(key string) (int, error)
|
||||
Lpop(key string) (string, error)
|
||||
Lpush(key string, values ...interface{}) (int, error)
|
||||
Lrange(key string, start int, stop int) ([]string, error)
|
||||
Lrem(key string, count int, value string) (int, error)
|
||||
Persist(key string) (bool, error)
|
||||
Pfadd(key string, values ...interface{}) (bool, error)
|
||||
Pfcount(key string) (int64, error)
|
||||
Rpush(key string, values ...interface{}) (int, error)
|
||||
Sadd(key string, values ...interface{}) (int, error)
|
||||
Scard(key string) (int64, error)
|
||||
Set(key string, value string) error
|
||||
Setex(key, value string, seconds int) error
|
||||
Setnx(key, value string) (bool, error)
|
||||
SetnxEx(key, value string, seconds int) (bool, error)
|
||||
Sismember(key string, value interface{}) (bool, error)
|
||||
Smembers(key string) ([]string, error)
|
||||
Spop(key string) (string, error)
|
||||
Srandmember(key string, count int) ([]string, error)
|
||||
Srem(key string, values ...interface{}) (int, error)
|
||||
Sscan(key string, cursor uint64, match string, count int64) (keys []string, cur uint64, err error)
|
||||
Ttl(key string) (int, error)
|
||||
Zadd(key string, score int64, value string) (bool, error)
|
||||
Zadds(key string, ps ...redis.Pair) (int64, error)
|
||||
Zcard(key string) (int, error)
|
||||
Zcount(key string, start, stop int64) (int, error)
|
||||
Zincrby(key string, increment int64, field string) (int64, error)
|
||||
Zrange(key string, start, stop int64) ([]string, error)
|
||||
ZrangeWithScores(key string, start, stop int64) ([]redis.Pair, error)
|
||||
ZrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error)
|
||||
ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error)
|
||||
Zrank(key, field string) (int64, error)
|
||||
Zrem(key string, values ...interface{}) (int, error)
|
||||
Zremrangebyrank(key string, start, stop int64) (int, error)
|
||||
Zremrangebyscore(key string, start, stop int64) (int, error)
|
||||
Zrevrange(key string, start, stop int64) ([]string, error)
|
||||
ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error)
|
||||
ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error)
|
||||
Zscore(key string, value string) (int64, error)
|
||||
}
|
||||
|
||||
clusterStore struct {
|
||||
dispatcher *hash.ConsistentHash
|
||||
}
|
||||
)
|
||||
|
||||
func NewStore(c KvConf) Store {
|
||||
if len(c) == 0 || internal.TotalWeights(c) <= 0 {
|
||||
log.Fatal("no cache nodes")
|
||||
}
|
||||
|
||||
// even if only one node, we chose to use consistent hash,
|
||||
// because Store and redis.Redis has different methods.
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := node.NewRedis()
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
return clusterStore{
|
||||
dispatcher: dispatcher,
|
||||
}
|
||||
}
|
||||
|
||||
func (cs clusterStore) Del(keys ...string) (int, error) {
|
||||
var val int
|
||||
var be errorx.BatchError
|
||||
|
||||
for _, key := range keys {
|
||||
node, e := cs.getRedis(key)
|
||||
if e != nil {
|
||||
be.Add(e)
|
||||
continue
|
||||
}
|
||||
|
||||
if v, e := node.Del(key); e != nil {
|
||||
be.Add(e)
|
||||
} else {
|
||||
val += v
|
||||
}
|
||||
}
|
||||
|
||||
return val, be.Err()
|
||||
}
|
||||
|
||||
func (cs clusterStore) Eval(script string, key string, args ...interface{}) (interface{}, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Eval(script, []string{key}, args...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Exists(key string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Exists(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Expire(key string, seconds int) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return node.Expire(key, seconds)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Expireat(key string, expireTime int64) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return node.Expireat(key, expireTime)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Get(key string) (string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return node.Get(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hdel(key, field string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Hdel(key, field)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hexists(key, field string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Hexists(key, field)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hget(key, field string) (string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return node.Hget(key, field)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hgetall(key string) (map[string]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Hgetall(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hincrby(key, field string, increment int) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Hincrby(key, field, increment)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hkeys(key string) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Hkeys(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hlen(key string) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Hlen(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hmget(key string, fields ...string) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Hmget(key, fields...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hset(key, field, value string) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return node.Hset(key, field, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hsetnx(key, field, value string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Hsetnx(key, field, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hmset(key string, fieldsAndValues map[string]string) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return node.Hmset(key, fieldsAndValues)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Hvals(key string) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Hvals(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Incr(key string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Incr(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Incrby(key string, increment int64) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Incrby(key, increment)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Llen(key string) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Llen(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Lpop(key string) (string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return node.Lpop(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Lpush(key string, values ...interface{}) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Lpush(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Lrange(key string, start int, stop int) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Lrange(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Lrem(key string, count int, value string) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Lrem(key, count, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Persist(key string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Persist(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Pfadd(key string, values ...interface{}) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Pfadd(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Pfcount(key string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Pfcount(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Rpush(key string, values ...interface{}) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Rpush(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Sadd(key string, values ...interface{}) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Sadd(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Scard(key string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Scard(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Set(key string, value string) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return node.Set(key, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Setex(key, value string, seconds int) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return node.Setex(key, value, seconds)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Setnx(key, value string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Setnx(key, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) SetnxEx(key, value string, seconds int) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.SetnxEx(key, value, seconds)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Sismember(key string, value interface{}) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Sismember(key, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Smembers(key string) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Smembers(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Spop(key string) (string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return node.Spop(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Srandmember(key string, count int) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Srandmember(key, count)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Srem(key string, values ...interface{}) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Srem(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Sscan(key string, cursor uint64, match string, count int64) (
|
||||
keys []string, cur uint64, err error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return node.Sscan(key, cursor, match, count)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Ttl(key string) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Ttl(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zadd(key string, score int64, value string) (bool, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return node.Zadd(key, score, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zadds(key string, ps ...redis.Pair) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zadds(key, ps...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zcard(key string) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zcard(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zcount(key string, start, stop int64) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zcount(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zincrby(key string, increment int64, field string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zincrby(key, increment, field)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zrank(key, field string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zrank(key, field)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zrange(key string, start, stop int64) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Zrange(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) ZrangeWithScores(key string, start, stop int64) ([]redis.Pair, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.ZrangeWithScores(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) ZrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.ZrangebyscoreWithScores(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) ZrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) (
|
||||
[]redis.Pair, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.ZrangebyscoreWithScoresAndLimit(key, start, stop, page, size)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zrem(key string, values ...interface{}) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zrem(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zremrangebyrank(key string, start, stop int64) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zremrangebyrank(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zremrangebyscore(key string, start, stop int64) (int, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zremrangebyscore(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zrevrange(key string, start, stop int64) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.Zrevrange(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.ZrevrangebyscoreWithScores(key, start, stop)
|
||||
}
|
||||
|
||||
func (cs clusterStore) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) (
|
||||
[]redis.Pair, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return node.ZrevrangebyscoreWithScoresAndLimit(key, start, stop, page, size)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zscore(key string, value string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return node.Zscore(key, value)
|
||||
}
|
||||
|
||||
func (cs clusterStore) getRedis(key string) (*redis.Redis, error) {
|
||||
if val, ok := cs.dispatcher.Get(key); !ok {
|
||||
return nil, ErrNoRedisNode
|
||||
} else {
|
||||
return val.(*redis.Redis), nil
|
||||
}
|
||||
}
|
||||
498
core/stores/kv/store_test.go
Normal file
498
core/stores/kv/store_test.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/stores/internal"
|
||||
"zero/core/stores/redis"
|
||||
"zero/core/stringx"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var s1, _ = miniredis.Run()
|
||||
var s2, _ = miniredis.Run()
|
||||
|
||||
func TestRedis_Exists(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
ok, err := client.Exists("a")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, client.Set("a", "b"))
|
||||
ok, err = client.Exists("a")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Eval(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
_, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, "notexist")
|
||||
assert.Equal(t, redis.Nil, err)
|
||||
err = client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, "key1")
|
||||
assert.Equal(t, redis.Nil, err)
|
||||
val, err := client.Eval(`return redis.call("EXISTS", KEYS[1])`, "key1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hgetall(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hgetall("a")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"aa": "aaa",
|
||||
"bb": "bbb",
|
||||
}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hvals(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hsetnx(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
ok, err := client.Hsetnx("a", "bb", "ccc")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
ok, err = client.Hsetnx("a", "dd", "ddd")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"aaa", "bbb", "ddd"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_HdelHlen(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
num, err := client.Hlen("a")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
val, err := client.Hdel("a", "aa")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, val)
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_HIncrBy(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
val, err := client.Hincrby("key", "field", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, val)
|
||||
val, err = client.Hincrby("key", "field", 3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hkeys(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hkeys("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"aa", "bb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hmget(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hmget("a", "aa", "bb")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
|
||||
vals, err = client.Hmget("a", "aa", "no", "bb")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aaa", "", "bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hmset(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
assert.Nil(t, client.Hmset("a", map[string]string{
|
||||
"aa": "aaa",
|
||||
"bb": "bbb",
|
||||
}))
|
||||
vals, err := client.Hmget("a", "aa", "bb")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Incr(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
val, err := client.Incr("a")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
val, err = client.Incr("a")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_IncrBy(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
val, err := client.Incrby("a", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
val, err = client.Incrby("a", 3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_List(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
val, err := client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, val)
|
||||
val, err = client.Rpush("key", "value3", "value4")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, val)
|
||||
val, err = client.Llen("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, val)
|
||||
vals, err := client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
|
||||
v, err := client.Lpop("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value2", v)
|
||||
val, err = client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, val)
|
||||
val, err = client.Rpush("key", "value3", "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 7, val)
|
||||
n, err := client.Lrem("key", 2, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, n)
|
||||
vals, err = client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals)
|
||||
n, err = client.Lrem("key", -2, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, n)
|
||||
vals, err = client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Persist(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
ok, err := client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = client.Set("key", "value")
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = client.Expire("key", 5)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
err = client.Expireat("key", time.Now().Unix()+5)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Sscan(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
key := "list"
|
||||
var list []string
|
||||
for i := 0; i < 1550; i++ {
|
||||
list = append(list, stringx.Randn(i))
|
||||
}
|
||||
lens, err := client.Sadd(key, list)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, lens, 1550)
|
||||
|
||||
var cursor uint64 = 0
|
||||
sum := 0
|
||||
for {
|
||||
keys, next, err := client.Sscan(key, cursor, "", 100)
|
||||
assert.Nil(t, err)
|
||||
sum += len(keys)
|
||||
if next == 0 {
|
||||
break
|
||||
}
|
||||
cursor = next
|
||||
}
|
||||
|
||||
assert.Equal(t, sum, 1550)
|
||||
_, err = client.Del(key)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Set(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
num, err := client.Sadd("key", 1, 2, 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
val, err := client.Scard("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(4), val)
|
||||
ok, err := client.Sismember("key", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
num, err = client.Srem("key", 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
vals, err := client.Smembers("key")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"1", "2"}, vals)
|
||||
members, err := client.Srandmember("key", 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, members, 1)
|
||||
assert.Contains(t, []string{"1", "2"}, members[0])
|
||||
member, err := client.Spop("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, []string{"1", "2"}, member)
|
||||
vals, err = client.Smembers("key")
|
||||
assert.Nil(t, err)
|
||||
assert.NotContains(t, vals, member)
|
||||
num, err = client.Sadd("key1", 1, 2, 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
num, err = client.Sadd("key2", 2, 3, 4, 5)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetGetDel(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
err := client.Set("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
val, err := client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
ret, err := client.Del("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetExNx(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
err := client.Setex("hello", "world", 5)
|
||||
assert.Nil(t, err)
|
||||
ok, err := client.Setnx("hello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
ok, err = client.Setnx("newhello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
val, err := client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.Get("newhello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "newworld", val)
|
||||
ttl, err := client.Ttl("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ttl > 0)
|
||||
ok, err = client.SetnxEx("newhello", "newworld", 5)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
num, err := client.Del("newhello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, num)
|
||||
ok, err = client.SetnxEx("newhello", "newworld", 5)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
val, err = client.Get("newhello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "newworld", val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetGetDelHashField(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
err := client.Hset("key", "field", "value")
|
||||
assert.Nil(t, err)
|
||||
val, err := client.Hget("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", val)
|
||||
ok, err := client.Hexists("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ret, err := client.Hdel("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ret)
|
||||
ok, err = client.Hexists("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SortedSet(t *testing.T) {
|
||||
runOnCluster(t, func(client Store) {
|
||||
ok, err := client.Zadd("key", 1, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 2, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
val, err := client.Zscore("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
val, err = client.Zincrby("key", 3, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
val, err = client.Zscore("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
rank, err := client.Zrank("key", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), rank)
|
||||
rank, err = client.Zrank("key", "value4")
|
||||
assert.Equal(t, redis.Nil, err)
|
||||
num, err := client.Zrem("key", "value2", "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 8, "value4")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
num, err = client.Zremrangebyscore("key", 6, 7)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
num, err = client.Zcount("key", 6, 7)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
num, err = client.Zremrangebyrank("key", 1, 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
card, err := client.Zcard("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, card)
|
||||
vals, err := client.Zrange("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value1", "value4"}, vals)
|
||||
vals, err = client.Zrevrange("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value4", "value1"}, vals)
|
||||
pairs, err := client.ZrangeWithScores("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []redis.Pair{
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []redis.Pair{
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []redis.Pair{
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []redis.Pair{
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []redis.Pair{
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
}, pairs)
|
||||
})
|
||||
}
|
||||
|
||||
func runOnCluster(t *testing.T, fn func(cluster Store)) {
|
||||
s1.FlushAll()
|
||||
s2.FlushAll()
|
||||
|
||||
store := NewStore([]internal.NodeConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: s1.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: s2.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
})
|
||||
|
||||
fn(store)
|
||||
}
|
||||
87
core/stores/mongo/bulkinserter.go
Normal file
87
core/stores/mongo/bulkinserter.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"zero/core/executors"
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
const (
|
||||
flushInterval = time.Second
|
||||
maxBulkRows = 1000
|
||||
)
|
||||
|
||||
type (
|
||||
ResultHandler func(*mgo.BulkResult, error)
|
||||
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
}
|
||||
)
|
||||
|
||||
func NewBulkInserter(session *mgo.Session, dbName string, collectionNamer func() string) *BulkInserter {
|
||||
inserter := &dbInserter{
|
||||
session: session,
|
||||
dbName: dbName,
|
||||
collectionNamer: collectionNamer,
|
||||
}
|
||||
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
|
||||
inserter: inserter,
|
||||
}
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) Flush() {
|
||||
bi.executor.Flush()
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) Insert(doc interface{}) {
|
||||
bi.executor.Add(doc)
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.resultHandler = handler
|
||||
})
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
session *mgo.Session
|
||||
dbName string
|
||||
collectionNamer func() string
|
||||
documents []interface{}
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
func (in *dbInserter) AddTask(doc interface{}) bool {
|
||||
in.documents = append(in.documents, doc)
|
||||
return len(in.documents) >= maxBulkRows
|
||||
}
|
||||
|
||||
func (in *dbInserter) Execute(objs interface{}) {
|
||||
docs := objs.([]interface{})
|
||||
if len(docs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
bulk := in.session.DB(in.dbName).C(in.collectionNamer()).Bulk()
|
||||
bulk.Insert(docs...)
|
||||
bulk.Unordered()
|
||||
result, err := bulk.Run()
|
||||
if in.resultHandler != nil {
|
||||
in.resultHandler(result, err)
|
||||
} else if err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (in *dbInserter) RemoveAll() interface{} {
|
||||
documents := in.documents
|
||||
in.documents = nil
|
||||
return documents
|
||||
}
|
||||
238
core/stores/mongo/collection.go
Normal file
238
core/stores/mongo/collection.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/logx"
|
||||
"zero/core/timex"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
var ErrNotFound = mgo.ErrNotFound
|
||||
|
||||
type (
|
||||
Collection interface {
|
||||
Find(query interface{}) Query
|
||||
FindId(id interface{}) Query
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) Pipe
|
||||
Remove(selector interface{}) error
|
||||
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
|
||||
RemoveId(id interface{}) error
|
||||
Update(selector, update interface{}) error
|
||||
UpdateId(id, update interface{}) error
|
||||
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
|
||||
}
|
||||
|
||||
decoratedCollection struct {
|
||||
*mgo.Collection
|
||||
brk breaker.Breaker
|
||||
}
|
||||
|
||||
keepablePromise struct {
|
||||
promise breaker.Promise
|
||||
log func(error)
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection *mgo.Collection) Collection {
|
||||
return &decoratedCollection{
|
||||
Collection: collection,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Find(query interface{}) Query {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedQuery{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedQuery{
|
||||
Query: c.Collection.Find(query),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("find", duration, err, query)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindId(id interface{}) Query {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedQuery{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedQuery{
|
||||
Query: c.Collection.FindId(id),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("findId", duration, err, id)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Insert(docs ...interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("insert", duration, err, docs...)
|
||||
}()
|
||||
|
||||
return c.Collection.Insert(docs...)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedPipe{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedPipe{
|
||||
Pipe: c.Collection.Pipe(pipeline),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("pipe", duration, err, pipeline)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Remove(selector interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("remove", duration, err, selector)
|
||||
}()
|
||||
|
||||
return c.Collection.Remove(selector)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeInfo, err error) {
|
||||
err = c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("removeAll", duration, err, selector)
|
||||
}()
|
||||
|
||||
info, err = c.Collection.RemoveAll(selector)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) RemoveId(id interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("removeId", duration, err, id)
|
||||
}()
|
||||
|
||||
return c.Collection.RemoveId(id)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Update(selector, update interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("update", duration, err, selector, update)
|
||||
}()
|
||||
|
||||
return c.Collection.Update(selector, update)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("updateId", duration, err, id, update)
|
||||
}()
|
||||
|
||||
return c.Collection.UpdateId(id, update)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.ChangeInfo, err error) {
|
||||
err = c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("upsert", duration, err, selector, update)
|
||||
}()
|
||||
|
||||
info, err = c.Collection.Upsert(selector, update)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) logDuration(method string, duration time.Duration, err error, docs ...interface{}) {
|
||||
content, e := json.Marshal(docs)
|
||||
if e != nil {
|
||||
logx.Error(err)
|
||||
} else if err != nil {
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
|
||||
c.FullName, method, err.Error(), string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
|
||||
c.FullName, method, err.Error(), string(content))
|
||||
}
|
||||
} else {
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
|
||||
c.FullName, method, string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.FullName, method, string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p keepablePromise) accept(err error) error {
|
||||
p.promise.Accept()
|
||||
p.log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p keepablePromise) keep(err error) error {
|
||||
if acceptable(err) {
|
||||
p.promise.Accept()
|
||||
} else {
|
||||
p.promise.Reject(err.Error())
|
||||
}
|
||||
|
||||
p.log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
return err == nil || err == mgo.ErrNotFound
|
||||
}
|
||||
71
core/stores/mongo/collection_test.go
Normal file
71
core/stores/mongo/collection_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"zero/core/stringx"
|
||||
)
|
||||
|
||||
func TestKeepPromise_accept(t *testing.T) {
|
||||
p := new(mockPromise)
|
||||
kp := keepablePromise{
|
||||
promise: p,
|
||||
log: func(error) {},
|
||||
}
|
||||
assert.Nil(t, kp.accept(nil))
|
||||
assert.Equal(t, mgo.ErrNotFound, kp.accept(mgo.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestKeepPromise_keep(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
p := new(mockPromise)
|
||||
kp := keepablePromise{
|
||||
promise: p,
|
||||
log: func(error) {},
|
||||
}
|
||||
assert.Equal(t, test.err, kp.keep(test.err))
|
||||
assert.Equal(t, test.accepted, p.accepted)
|
||||
assert.Equal(t, test.reason, p.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
accepted bool
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *mockPromise) Accept() {
|
||||
p.accepted = true
|
||||
}
|
||||
|
||||
func (p *mockPromise) Reject(reason string) {
|
||||
p.reason = reason
|
||||
}
|
||||
96
core/stores/mongo/iter.go
Normal file
96
core/stores/mongo/iter.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:generate mockgen -package mongo -destination iter_mock.go -source iter.go Iter
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"zero/core/breaker"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
)
|
||||
|
||||
type (
|
||||
Iter interface {
|
||||
All(result interface{}) error
|
||||
Close() error
|
||||
Done() bool
|
||||
Err() error
|
||||
For(result interface{}, f func() error) error
|
||||
Next(result interface{}) bool
|
||||
State() (int64, []bson.Raw)
|
||||
Timeout() bool
|
||||
}
|
||||
|
||||
ClosableIter struct {
|
||||
Iter
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
promisedIter struct {
|
||||
Iter
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedIter struct{}
|
||||
)
|
||||
|
||||
func (i promisedIter) All(result interface{}) error {
|
||||
return i.promise.keep(i.Iter.All(result))
|
||||
}
|
||||
|
||||
func (i promisedIter) Close() error {
|
||||
return i.promise.keep(i.Iter.Close())
|
||||
}
|
||||
|
||||
func (i promisedIter) Err() error {
|
||||
return i.Iter.Err()
|
||||
}
|
||||
|
||||
func (i promisedIter) For(result interface{}, f func() error) error {
|
||||
var ferr error
|
||||
err := i.Iter.For(result, func() error {
|
||||
ferr = f()
|
||||
return ferr
|
||||
})
|
||||
if ferr == err {
|
||||
return i.promise.accept(err)
|
||||
}
|
||||
|
||||
return i.promise.keep(err)
|
||||
}
|
||||
|
||||
func (it *ClosableIter) Close() error {
|
||||
err := it.Iter.Close()
|
||||
it.Cleanup()
|
||||
return err
|
||||
}
|
||||
|
||||
func (i rejectedIter) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Close() error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Done() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i rejectedIter) Err() error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) For(result interface{}, f func() error) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Next(result interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i rejectedIter) State() (int64, []bson.Raw) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (i rejectedIter) Timeout() bool {
|
||||
return false
|
||||
}
|
||||
147
core/stores/mongo/iter_mock.go
Normal file
147
core/stores/mongo/iter_mock.go
Normal file
@@ -0,0 +1,147 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: iter.go
|
||||
|
||||
// Package mongo is a generated GoMock package.
|
||||
package mongo
|
||||
|
||||
import (
|
||||
bson "github.com/globalsign/mgo/bson"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockIter is a mock of Iter interface
|
||||
type MockIter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockIterMockRecorder
|
||||
}
|
||||
|
||||
// MockIterMockRecorder is the mock recorder for MockIter
|
||||
type MockIterMockRecorder struct {
|
||||
mock *MockIter
|
||||
}
|
||||
|
||||
// NewMockIter creates a new mock instance
|
||||
func NewMockIter(ctrl *gomock.Controller) *MockIter {
|
||||
mock := &MockIter{ctrl: ctrl}
|
||||
mock.recorder = &MockIterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockIter) EXPECT() *MockIterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// All mocks base method
|
||||
func (m *MockIter) All(result interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "All", result)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// All indicates an expected call of All
|
||||
func (mr *MockIterMockRecorder) All(result interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "All", reflect.TypeOf((*MockIter)(nil).All), result)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockIter) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockIterMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIter)(nil).Close))
|
||||
}
|
||||
|
||||
// Done mocks base method
|
||||
func (m *MockIter) Done() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Done")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Done indicates an expected call of Done
|
||||
func (mr *MockIterMockRecorder) Done() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockIter)(nil).Done))
|
||||
}
|
||||
|
||||
// Err mocks base method
|
||||
func (m *MockIter) Err() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Err")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Err indicates an expected call of Err
|
||||
func (mr *MockIterMockRecorder) Err() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockIter)(nil).Err))
|
||||
}
|
||||
|
||||
// For mocks base method
|
||||
func (m *MockIter) For(result interface{}, f func() error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "For", result, f)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// For indicates an expected call of For
|
||||
func (mr *MockIterMockRecorder) For(result, f interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockIter)(nil).For), result, f)
|
||||
}
|
||||
|
||||
// Next mocks base method
|
||||
func (m *MockIter) Next(result interface{}) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Next", result)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Next indicates an expected call of Next
|
||||
func (mr *MockIterMockRecorder) Next(result interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIter)(nil).Next), result)
|
||||
}
|
||||
|
||||
// State mocks base method
|
||||
func (m *MockIter) State() (int64, []bson.Raw) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "State")
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].([]bson.Raw)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// State indicates an expected call of State
|
||||
func (mr *MockIterMockRecorder) State() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockIter)(nil).State))
|
||||
}
|
||||
|
||||
// Timeout mocks base method
|
||||
func (m *MockIter) Timeout() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Timeout")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Timeout indicates an expected call of Timeout
|
||||
func (mr *MockIterMockRecorder) Timeout() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockIter)(nil).Timeout))
|
||||
}
|
||||
265
core/stores/mongo/iter_test.go
Normal file
265
core/stores/mongo/iter_test.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/stringx"
|
||||
"zero/core/syncx"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClosableIter_Close(t *testing.T) {
|
||||
errs := []error{
|
||||
nil,
|
||||
mgo.ErrNotFound,
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
cleaned := syncx.NewAtomicBool()
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Close().Return(err)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: func() {
|
||||
cleaned.Set(true)
|
||||
},
|
||||
}
|
||||
assert.Equal(t, err, ci.Close())
|
||||
assert.True(t, cleaned.True())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_AllAndClose(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().All(gomock.Any()).Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.All(nil))
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Close().Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.Close())
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_Err(t *testing.T) {
|
||||
errs := []error{
|
||||
nil,
|
||||
mgo.ErrNotFound,
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Err().Return(err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, err, pi.Err())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_For(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().For(gomock.Any(), gomock.Any()).Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.For(nil, nil))
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectedIter_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).All(nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_Close(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Close())
|
||||
}
|
||||
|
||||
func TestRejectedIter_Done(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Done())
|
||||
}
|
||||
|
||||
func TestRejectedIter_Err(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Err())
|
||||
}
|
||||
|
||||
func TestRejectedIter_For(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).For(nil, nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_Next(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Next(nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_State(t *testing.T) {
|
||||
n, raw := new(rejectedIter).State()
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Nil(t, raw)
|
||||
}
|
||||
|
||||
func TestRejectedIter_Timeout(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Timeout())
|
||||
}
|
||||
|
||||
func TestIter_Done(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Done().Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Done())
|
||||
}
|
||||
|
||||
func TestIter_Next(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Next(gomock.Any()).Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Next(nil))
|
||||
}
|
||||
|
||||
func TestIter_State(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().State().Return(int64(1), nil)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
n, raw := ci.State()
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.Nil(t, raw)
|
||||
}
|
||||
|
||||
func TestIter_Timeout(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Timeout().Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Timeout())
|
||||
}
|
||||
164
core/stores/mongo/model.go
Normal file
164
core/stores/mongo/model.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
type (
|
||||
options struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
Option func(opts *options)
|
||||
|
||||
Model struct {
|
||||
session *concurrentSession
|
||||
db *mgo.Database
|
||||
collection string
|
||||
opts []Option
|
||||
}
|
||||
)
|
||||
|
||||
func MustNewModel(url, database, collection string, opts ...Option) *Model {
|
||||
model, err := NewModel(url, database, collection, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
func NewModel(url, database, collection string, opts ...Option) (*Model, error) {
|
||||
session, err := getConcurrentSession(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Model{
|
||||
session: session,
|
||||
db: session.DB(database),
|
||||
collection: collection,
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (mm *Model) Find(query interface{}) (Query, error) {
|
||||
return mm.query(func(c Collection) Query {
|
||||
return c.Find(query)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) FindId(id interface{}) (Query, error) {
|
||||
return mm.query(func(c Collection) Query {
|
||||
return c.FindId(id)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
||||
return newCollection(mm.db.C(mm.collection).With(session))
|
||||
}
|
||||
|
||||
func (mm *Model) Insert(docs ...interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Insert(docs...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Pipe(pipeline interface{}) (Pipe, error) {
|
||||
return mm.pipe(func(c Collection) Pipe {
|
||||
return c.Pipe(pipeline)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) PutSession(session *mgo.Session) {
|
||||
mm.session.putSession(session)
|
||||
}
|
||||
|
||||
func (mm *Model) Remove(selector interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Remove(selector)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAll(selector)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveId(id interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.RemoveId(id)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) TakeSession() (*mgo.Session, error) {
|
||||
return mm.session.takeSession(mm.opts...)
|
||||
}
|
||||
|
||||
func (mm *Model) Update(selector, update interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Update(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) UpdateId(id, update interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.UpdateId(id, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
|
||||
return c.Upsert(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) change(fn func(c Collection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) execute(fn func(c Collection) error) error {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) pipe(fn func(c Collection) Pipe) (Pipe, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
func (mm *Model) query(fn func(c Collection) Query) (Query, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.timeout = timeout
|
||||
}
|
||||
}
|
||||
100
core/stores/mongo/pipe.go
Normal file
100
core/stores/mongo/pipe.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"zero/core/breaker"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
type (
|
||||
Pipe interface {
|
||||
All(result interface{}) error
|
||||
AllowDiskUse() Pipe
|
||||
Batch(n int) Pipe
|
||||
Collation(collation *mgo.Collation) Pipe
|
||||
Explain(result interface{}) error
|
||||
Iter() Iter
|
||||
One(result interface{}) error
|
||||
SetMaxTime(d time.Duration) Pipe
|
||||
}
|
||||
|
||||
promisedPipe struct {
|
||||
*mgo.Pipe
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedPipe struct{}
|
||||
)
|
||||
|
||||
func (p promisedPipe) All(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.All(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) AllowDiskUse() Pipe {
|
||||
p.Pipe.AllowDiskUse()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Batch(n int) Pipe {
|
||||
p.Pipe.Batch(n)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Collation(collation *mgo.Collation) Pipe {
|
||||
p.Pipe.Collation(collation)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Explain(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.Explain(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) Iter() Iter {
|
||||
return promisedIter{
|
||||
Iter: p.Pipe.Iter(),
|
||||
promise: p.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (p promisedPipe) One(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.One(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) SetMaxTime(d time.Duration) Pipe {
|
||||
p.Pipe.SetMaxTime(d)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) AllowDiskUse() Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Batch(n int) Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Collation(collation *mgo.Collation) Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Explain(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Iter() Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
|
||||
func (p rejectedPipe) One(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) SetMaxTime(d time.Duration) Pipe {
|
||||
return p
|
||||
}
|
||||
45
core/stores/mongo/pipe_test.go
Normal file
45
core/stores/mongo/pipe_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/breaker"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRejectedPipe_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).All(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_AllowDiskUse(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.AllowDiskUse())
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Batch(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.Batch(1))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Collation(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.Collation(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Explain(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).Explain(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Iter(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedPipe).Iter())
|
||||
}
|
||||
|
||||
func TestRejectedPipe_One(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).One(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_SetMaxTime(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.SetMaxTime(0))
|
||||
}
|
||||
285
core/stores/mongo/query.go
Normal file
285
core/stores/mongo/query.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"zero/core/breaker"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
type (
|
||||
Query interface {
|
||||
All(result interface{}) error
|
||||
Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error)
|
||||
Batch(n int) Query
|
||||
Collation(collation *mgo.Collation) Query
|
||||
Comment(comment string) Query
|
||||
Count() (int, error)
|
||||
Distinct(key string, result interface{}) error
|
||||
Explain(result interface{}) error
|
||||
For(result interface{}, f func() error) error
|
||||
Hint(indexKey ...string) Query
|
||||
Iter() Iter
|
||||
Limit(n int) Query
|
||||
LogReplay() Query
|
||||
MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error)
|
||||
One(result interface{}) error
|
||||
Prefetch(p float64) Query
|
||||
Select(selector interface{}) Query
|
||||
SetMaxScan(n int) Query
|
||||
SetMaxTime(d time.Duration) Query
|
||||
Skip(n int) Query
|
||||
Snapshot() Query
|
||||
Sort(fields ...string) Query
|
||||
Tail(timeout time.Duration) Iter
|
||||
}
|
||||
|
||||
promisedQuery struct {
|
||||
*mgo.Query
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedQuery struct{}
|
||||
)
|
||||
|
||||
func (q promisedQuery) All(result interface{}) error {
|
||||
return q.promise.keep(q.Query.All(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
info, err := q.Query.Apply(change, result)
|
||||
return info, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Batch(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Batch(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Collation(collation *mgo.Collation) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Collation(collation),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Comment(comment string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Comment(comment),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Count() (int, error) {
|
||||
v, err := q.Query.Count()
|
||||
return v, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Distinct(key string, result interface{}) error {
|
||||
return q.promise.keep(q.Query.Distinct(key, result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Explain(result interface{}) error {
|
||||
return q.promise.keep(q.Query.Explain(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) For(result interface{}, f func() error) error {
|
||||
var ferr error
|
||||
err := q.Query.For(result, func() error {
|
||||
ferr = f()
|
||||
return ferr
|
||||
})
|
||||
if ferr == err {
|
||||
return q.promise.accept(err)
|
||||
}
|
||||
|
||||
return q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Hint(indexKey ...string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Hint(indexKey...),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Iter() Iter {
|
||||
return promisedIter{
|
||||
Iter: q.Query.Iter(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Limit(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Limit(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) LogReplay() Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.LogReplay(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
info, err := q.Query.MapReduce(job, result)
|
||||
return info, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) One(result interface{}) error {
|
||||
return q.promise.keep(q.Query.One(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Prefetch(p float64) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Prefetch(p),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Select(selector interface{}) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Select(selector),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) SetMaxScan(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.SetMaxScan(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) SetMaxTime(d time.Duration) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.SetMaxTime(d),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Skip(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Skip(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Snapshot() Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Snapshot(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Sort(fields ...string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Sort(fields...),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Tail(timeout time.Duration) Iter {
|
||||
return promisedIter{
|
||||
Iter: q.Query.Tail(timeout),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q rejectedQuery) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Batch(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Collation(collation *mgo.Collation) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Comment(comment string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Count() (int, error) {
|
||||
return 0, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Distinct(key string, result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Explain(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) For(result interface{}, f func() error) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Hint(indexKey ...string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Iter() Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Limit(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) LogReplay() Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
return nil, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) One(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Prefetch(p float64) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Select(selector interface{}) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) SetMaxScan(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) SetMaxTime(d time.Duration) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Skip(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Snapshot() Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Sort(fields ...string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Tail(timeout time.Duration) Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
121
core/stores/mongo/query_test.go
Normal file
121
core/stores/mongo/query_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/breaker"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_rejectedQuery_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).All(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Apply(t *testing.T) {
|
||||
info, err := new(rejectedQuery).Apply(mgo.Change{}, nil)
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Batch(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Batch(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Collation(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Collation(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Comment(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Comment(""))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Count(t *testing.T) {
|
||||
n, err := new(rejectedQuery).Count()
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Distinct(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Distinct("", nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Explain(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Explain(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_For(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).For(nil, nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Hint(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Hint())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Iter(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Iter())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Limit(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Limit(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_LogReplay(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.LogReplay())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_MapReduce(t *testing.T) {
|
||||
info, err := new(rejectedQuery).MapReduce(nil, nil)
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_One(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).One(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Prefetch(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Prefetch(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Select(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Select(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_SetMaxScan(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.SetMaxScan(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_SetMaxTime(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.SetMaxTime(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Skip(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Skip(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Snapshot(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Snapshot())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Sort(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Sort())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Tail(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Tail(0))
|
||||
}
|
||||
73
core/stores/mongo/sessionmanager.go
Normal file
73
core/stores/mongo/sessionmanager.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/syncx"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultConcurrency = 50
|
||||
defaultTimeout = time.Second
|
||||
)
|
||||
|
||||
var sessionManager = syncx.NewResourceManager()
|
||||
|
||||
type concurrentSession struct {
|
||||
*mgo.Session
|
||||
limit syncx.TimeoutLimit
|
||||
}
|
||||
|
||||
func (cs *concurrentSession) Close() error {
|
||||
cs.Session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func getConcurrentSession(url string) (*concurrentSession, error) {
|
||||
val, err := sessionManager.GetResource(url, func() (io.Closer, error) {
|
||||
mgoSession, err := mgo.Dial(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
concurrentSess := &concurrentSession{
|
||||
Session: mgoSession,
|
||||
limit: syncx.NewTimeoutLimit(defaultConcurrency),
|
||||
}
|
||||
|
||||
return concurrentSess, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*concurrentSession), nil
|
||||
}
|
||||
|
||||
func (cs *concurrentSession) putSession(session *mgo.Session) {
|
||||
if err := cs.limit.Return(); err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
|
||||
// anyway, we need to close the session
|
||||
session.Close()
|
||||
}
|
||||
|
||||
func (cs *concurrentSession) takeSession(opts ...Option) (*mgo.Session, error) {
|
||||
o := &options{
|
||||
timeout: defaultTimeout,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
if err := cs.limit.Borrow(o.timeout); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return cs.Copy(), nil
|
||||
}
|
||||
}
|
||||
9
core/stores/mongo/util.go
Normal file
9
core/stores/mongo/util.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package mongo
|
||||
|
||||
import "strings"
|
||||
|
||||
const mongoAddrSep = ","
|
||||
|
||||
func FormatAddr(hosts []string) string {
|
||||
return strings.Join(hosts, mongoAddrSep)
|
||||
}
|
||||
35
core/stores/mongo/utils_test.go
Normal file
35
core/stores/mongo/utils_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFormatAddrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
addrs []string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
addrs: []string{"a", "b"},
|
||||
expect: "a,b",
|
||||
},
|
||||
{
|
||||
addrs: []string{"a", "b", "c"},
|
||||
expect: "a,b,c",
|
||||
},
|
||||
{
|
||||
addrs: []string{},
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
addrs: nil,
|
||||
expect: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.expect, FormatAddr(test.addrs))
|
||||
}
|
||||
}
|
||||
171
core/stores/mongoc/cachedcollection.go
Normal file
171
core/stores/mongoc/cachedcollection.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package mongoc
|
||||
|
||||
import (
|
||||
"zero/core/stores/internal"
|
||||
"zero/core/stores/mongo"
|
||||
"zero/core/syncx"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotFound = mgo.ErrNotFound
|
||||
|
||||
// can't use one SharedCalls per conn, because multiple conns may share the same cache key.
|
||||
sharedCalls = syncx.NewSharedCalls()
|
||||
stats = internal.NewCacheStat("mongoc")
|
||||
)
|
||||
|
||||
type (
|
||||
QueryOption func(query mongo.Query) mongo.Query
|
||||
|
||||
cachedCollection struct {
|
||||
collection mongo.Collection
|
||||
cache internal.Cache
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection mongo.Collection, c internal.Cache) *cachedCollection {
|
||||
return &cachedCollection{
|
||||
collection: collection,
|
||||
cache: c,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Count(query interface{}) (int, error) {
|
||||
return c.collection.Find(query).Count()
|
||||
}
|
||||
|
||||
func (c *cachedCollection) DelCache(keys ...string) error {
|
||||
return c.cache.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) GetCache(key string, v interface{}) error {
|
||||
return c.cache.GetCache(key, v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error {
|
||||
q := c.collection.Find(query)
|
||||
for _, opt := range opts {
|
||||
q = opt(q)
|
||||
}
|
||||
return q.All(v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOne(v interface{}, key string, query interface{}) error {
|
||||
return c.cache.Take(v, key, func(v interface{}) error {
|
||||
q := c.collection.Find(query)
|
||||
return q.One(v)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneNoCache(v interface{}, query interface{}) error {
|
||||
q := c.collection.Find(query)
|
||||
return q.One(v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneId(v interface{}, key string, id interface{}) error {
|
||||
return c.cache.Take(v, key, func(v interface{}) error {
|
||||
q := c.collection.FindId(id)
|
||||
return q.One(v)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneIdNoCache(v interface{}, id interface{}) error {
|
||||
q := c.collection.FindId(id)
|
||||
return q.One(v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Insert(docs ...interface{}) error {
|
||||
return c.collection.Insert(docs...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Pipe(pipeline interface{}) mongo.Pipe {
|
||||
return c.collection.Pipe(pipeline)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Remove(selector interface{}, keys ...string) error {
|
||||
if err := c.RemoveNoCache(selector); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveNoCache(selector interface{}) error {
|
||||
return c.collection.Remove(selector)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
info, err := c.RemoveAllNoCache(selector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.DelCache(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return c.collection.RemoveAll(selector)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveId(id interface{}, keys ...string) error {
|
||||
if err := c.RemoveIdNoCache(id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) RemoveIdNoCache(id interface{}) error {
|
||||
return c.collection.RemoveId(id)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) SetCache(key string, v interface{}) error {
|
||||
return c.cache.SetCache(key, v)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Update(selector, update interface{}, keys ...string) error {
|
||||
if err := c.UpdateNoCache(selector, update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpdateNoCache(selector, update interface{}) error {
|
||||
return c.collection.Update(selector, update)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpdateId(id, update interface{}, keys ...string) error {
|
||||
if err := c.UpdateIdNoCache(id, update); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpdateIdNoCache(id, update interface{}) error {
|
||||
return c.collection.UpdateId(id, update)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
info, err := c.UpsertNoCache(selector, update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.DelCache(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (c *cachedCollection) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return c.collection.Upsert(selector, update)
|
||||
}
|
||||
300
core/stores/mongoc/cachedcollection_test.go
Normal file
300
core/stores/mongoc/cachedcollection_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package mongoc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/stat"
|
||||
"zero/core/stores/internal"
|
||||
"zero/core/stores/mongo"
|
||||
"zero/core/stores/redis"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/globalsign/mgo/bson"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestStat(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
var str string
|
||||
if err = c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func TestStatCacheFails(t *testing.T) {
|
||||
resetStats()
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stdout)
|
||||
|
||||
r := redis.NewRedis("localhost:59999", redis.NodeType)
|
||||
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
var str string
|
||||
err := c.FindOne(&str, "name", bson.M{})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
|
||||
}
|
||||
|
||||
func TestStatDbFails(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
var str string
|
||||
err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
return errors.New("db failed")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
|
||||
}
|
||||
|
||||
func TestStatFromMemory(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
cach := internal.NewCacheNode(r, sharedCalls, stats, mgo.ErrNotFound)
|
||||
c := newCollection(dummyConn{}, cach)
|
||||
|
||||
var all sync.WaitGroup
|
||||
var wait sync.WaitGroup
|
||||
all.Add(10)
|
||||
wait.Add(4)
|
||||
go func() {
|
||||
var str string
|
||||
if err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wait.Wait()
|
||||
runtime.Gosched()
|
||||
all.Done()
|
||||
}()
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
go func() {
|
||||
var str string
|
||||
wait.Done()
|
||||
if err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
all.Done()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
var str string
|
||||
if err := c.cache.Take(&str, "name", func(v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
all.Done()
|
||||
}()
|
||||
}
|
||||
all.Wait()
|
||||
|
||||
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func resetStats() {
|
||||
atomic.StoreUint64(&stats.Total, 0)
|
||||
atomic.StoreUint64(&stats.Hit, 0)
|
||||
atomic.StoreUint64(&stats.Miss, 0)
|
||||
atomic.StoreUint64(&stats.DbFails, 0)
|
||||
}
|
||||
|
||||
type dummyConn struct {
|
||||
}
|
||||
|
||||
func (c dummyConn) Find(query interface{}) mongo.Query {
|
||||
return dummyQuery{}
|
||||
}
|
||||
|
||||
func (c dummyConn) FindId(id interface{}) mongo.Query {
|
||||
return dummyQuery{}
|
||||
}
|
||||
|
||||
func (c dummyConn) Insert(docs ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c dummyConn) Remove(selector interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dummyConn) Pipe(pipeline interface{}) mongo.Pipe {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c dummyConn) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c dummyConn) RemoveId(id interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c dummyConn) Update(selector, update interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c dummyConn) UpdateId(id, update interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c dummyConn) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type dummyQuery struct {
|
||||
}
|
||||
|
||||
func (d dummyQuery) All(result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Count() (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Distinct(key string, result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Explain(result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) For(result interface{}, f func() error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) One(result interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummyQuery) Batch(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Collation(collation *mgo.Collation) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Comment(comment string) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Hint(indexKey ...string) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Iter() mongo.Iter {
|
||||
return &mgo.Iter{}
|
||||
}
|
||||
|
||||
func (d dummyQuery) Limit(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) LogReplay() mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Prefetch(p float64) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Select(selector interface{}) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) SetMaxScan(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) SetMaxTime(duration time.Duration) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Skip(n int) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Snapshot() mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Sort(fields ...string) mongo.Query {
|
||||
return d
|
||||
}
|
||||
|
||||
func (d dummyQuery) Tail(timeout time.Duration) mongo.Iter {
|
||||
return &mgo.Iter{}
|
||||
}
|
||||
243
core/stores/mongoc/cachedmodel.go
Normal file
243
core/stores/mongoc/cachedmodel.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package mongoc
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"zero/core/stores/cache"
|
||||
"zero/core/stores/internal"
|
||||
"zero/core/stores/mongo"
|
||||
"zero/core/stores/redis"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
*mongo.Model
|
||||
cache internal.Cache
|
||||
generateCollection func(*mgo.Session) *cachedCollection
|
||||
}
|
||||
|
||||
func MustNewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
|
||||
model, err := NewNodeModel(url, database, collection, rds, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
func MustNewModel(url, database, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
|
||||
model, err := NewModel(url, database, collection, c, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
func NewNodeModel(url, database, collection string, rds *redis.Redis, opts ...cache.Option) (*Model, error) {
|
||||
c := internal.NewCacheNode(rds, sharedCalls, stats, mgo.ErrNotFound, opts...)
|
||||
return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection {
|
||||
return newCollection(collection, c)
|
||||
})
|
||||
}
|
||||
|
||||
func NewModel(url, database, collection string, conf cache.CacheConf, opts ...cache.Option) (*Model, error) {
|
||||
c := internal.NewCache(conf, sharedCalls, stats, mgo.ErrNotFound, opts...)
|
||||
return createModel(url, database, collection, c, func(collection mongo.Collection) *cachedCollection {
|
||||
return newCollection(collection, c)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Count(query interface{}) (int, error) {
|
||||
return mm.executeInt(func(c *cachedCollection) (int, error) {
|
||||
return c.Count(query)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) DelCache(keys ...string) error {
|
||||
return mm.cache.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (mm *Model) GetCache(key string, v interface{}) error {
|
||||
return mm.cache.GetCache(key, v)
|
||||
}
|
||||
|
||||
func (mm *Model) GetCollection(session *mgo.Session) *cachedCollection {
|
||||
return mm.generateCollection(session)
|
||||
}
|
||||
|
||||
func (mm *Model) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.FindAllNoCache(v, query, opts...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) FindOne(v interface{}, key string, query interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.FindOne(v, key, query)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) FindOneNoCache(v interface{}, query interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.FindOneNoCache(v, query)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) FindOneId(v interface{}, key string, id interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.FindOneId(v, key, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) FindOneIdNoCache(v interface{}, id interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.FindOneIdNoCache(v, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Insert(docs ...interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.Insert(docs...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Pipe(pipeline interface{}) (mongo.Pipe, error) {
|
||||
return mm.pipe(func(c *cachedCollection) mongo.Pipe {
|
||||
return c.Pipe(pipeline)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Remove(selector interface{}, keys ...string) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.Remove(selector, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveNoCache(selector interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.RemoveNoCache(selector)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveAll(selector interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAll(selector, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveAllNoCache(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAllNoCache(selector)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveId(id interface{}, keys ...string) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.RemoveId(id, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) RemoveIdNoCache(id interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.RemoveIdNoCache(id)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) SetCache(key string, v interface{}) error {
|
||||
return mm.cache.SetCache(key, v)
|
||||
}
|
||||
|
||||
func (mm *Model) Update(selector, update interface{}, keys ...string) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.Update(selector, update, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) UpdateNoCache(selector, update interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.UpdateNoCache(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) UpdateId(id, update interface{}, keys ...string) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.UpdateId(id, update, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) UpdateIdNoCache(id, update interface{}) error {
|
||||
return mm.execute(func(c *cachedCollection) error {
|
||||
return c.UpdateIdNoCache(id, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) Upsert(selector, update interface{}, keys ...string) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.Upsert(selector, update, keys...)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) UpsertNoCache(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c *cachedCollection) (*mgo.ChangeInfo, error) {
|
||||
return c.UpsertNoCache(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) change(fn func(c *cachedCollection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) execute(fn func(c *cachedCollection) error) error {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) executeInt(fn func(c *cachedCollection) (int, error)) (int, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) pipe(fn func(c *cachedCollection) mongo.Pipe) (mongo.Pipe, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
func createModel(url, database, collection string, c internal.Cache,
|
||||
create func(mongo.Collection) *cachedCollection) (*Model, error) {
|
||||
model, err := mongo.NewModel(url, database, collection)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Model{
|
||||
Model: model,
|
||||
cache: c,
|
||||
generateCollection: func(session *mgo.Session) *cachedCollection {
|
||||
collection := model.GetCollection(session)
|
||||
return create(collection)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
13
core/stores/postgres/postgresql.go
Normal file
13
core/stores/postgres/postgresql.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"zero/core/stores/sqlx"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
const postgreDriverName = "postgres"
|
||||
|
||||
func NewPostgre(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
|
||||
return sqlx.NewSqlConn(postgreDriverName, datasource, opts...)
|
||||
}
|
||||
50
core/stores/redis/conf.go
Normal file
50
core/stores/redis/conf.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package redis
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrEmptyHost = errors.New("empty redis host")
|
||||
ErrEmptyType = errors.New("empty redis type")
|
||||
ErrEmptyKey = errors.New("empty redis key")
|
||||
)
|
||||
|
||||
type (
|
||||
RedisConf struct {
|
||||
Host string
|
||||
Type string `json:",default=node,options=node|cluster"`
|
||||
Pass string `json:",optional"`
|
||||
}
|
||||
|
||||
RedisKeyConf struct {
|
||||
RedisConf
|
||||
Key string `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
func (rc RedisConf) NewRedis() *Redis {
|
||||
return NewRedis(rc.Host, rc.Type, rc.Pass)
|
||||
}
|
||||
|
||||
func (rc RedisConf) Validate() error {
|
||||
if len(rc.Host) == 0 {
|
||||
return ErrEmptyHost
|
||||
}
|
||||
|
||||
if len(rc.Type) == 0 {
|
||||
return ErrEmptyType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rkc RedisKeyConf) Validate() error {
|
||||
if err := rkc.RedisConf.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(rkc.Key) == 0 {
|
||||
return ErrEmptyKey
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
33
core/stores/redis/process.go
Normal file
33
core/stores/redis/process.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mapping"
|
||||
"zero/core/timex"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
)
|
||||
|
||||
func process(proc func(red.Cmder) error) func(red.Cmder) error {
|
||||
return func(cmd red.Cmder) error {
|
||||
start := timex.Now()
|
||||
|
||||
defer func() {
|
||||
duration := timex.Since(start)
|
||||
if duration > slowThreshold {
|
||||
var buf strings.Builder
|
||||
for i, arg := range cmd.Args() {
|
||||
if i > 0 {
|
||||
buf.WriteByte(' ')
|
||||
}
|
||||
buf.WriteString(mapping.Repr(arg))
|
||||
}
|
||||
logx.WithDuration(duration).Slowf("[REDIS] slowcall on executing: %s", buf.String())
|
||||
}
|
||||
}()
|
||||
|
||||
return proc(cmd)
|
||||
}
|
||||
}
|
||||
1339
core/stores/redis/redis.go
Normal file
1339
core/stores/redis/redis.go
Normal file
File diff suppressed because it is too large
Load Diff
580
core/stores/redis/redis_test.go
Normal file
580
core/stores/redis/redis_test.go
Normal file
@@ -0,0 +1,580 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRedis_Exists(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
ok, err := client.Exists("a")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, client.Set("a", "b"))
|
||||
ok, err = client.Exists("a")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Eval(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
|
||||
assert.Equal(t, Nil, err)
|
||||
err = client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
_, err = client.Eval(`redis.call("EXISTS", KEYS[1])`, []string{"key1"})
|
||||
assert.Equal(t, Nil, err)
|
||||
val, err := client.Eval(`return redis.call("EXISTS", KEYS[1])`, []string{"key1"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hgetall(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hgetall("a")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, map[string]string{
|
||||
"aa": "aaa",
|
||||
"bb": "bbb",
|
||||
}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hvals(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"aaa", "bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hsetnx(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
ok, err := client.Hsetnx("a", "bb", "ccc")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
ok, err = client.Hsetnx("a", "dd", "ddd")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"aaa", "bbb", "ddd"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_HdelHlen(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
num, err := client.Hlen("a")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
val, err := client.Hdel("a", "aa")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, val)
|
||||
vals, err := client.Hvals("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_HIncrBy(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
val, err := client.Hincrby("key", "field", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, val)
|
||||
val, err = client.Hincrby("key", "field", 3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hkeys(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hkeys("a")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"aa", "bb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hmget(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hset("a", "aa", "aaa"))
|
||||
assert.Nil(t, client.Hset("a", "bb", "bbb"))
|
||||
vals, err := client.Hmget("a", "aa", "bb")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
|
||||
vals, err = client.Hmget("a", "aa", "no", "bb")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aaa", "", "bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Hmset(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
assert.Nil(t, client.Hmset("a", map[string]string{
|
||||
"aa": "aaa",
|
||||
"bb": "bbb",
|
||||
}))
|
||||
vals, err := client.Hmget("a", "aa", "bb")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"aaa", "bbb"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Incr(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
val, err := client.Incr("a")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
val, err = client.Incr("a")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_IncrBy(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
val, err := client.Incrby("a", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
val, err = client.Incrby("a", 3)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Keys(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "value2")
|
||||
assert.Nil(t, err)
|
||||
keys, err := client.Keys("*")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"key1", "key2"}, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_List(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
val, err := client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, val)
|
||||
val, err = client.Rpush("key", "value3", "value4")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, val)
|
||||
val, err = client.Llen("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, val)
|
||||
vals, err := client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value1", "value3", "value4"}, vals)
|
||||
v, err := client.Lpop("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value2", v)
|
||||
val, err = client.Lpush("key", "value1", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, val)
|
||||
val, err = client.Rpush("key", "value3", "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 7, val)
|
||||
n, err := client.Lrem("key", 2, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, n)
|
||||
vals, err = client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value3", "value4", "value3", "value3"}, vals)
|
||||
n, err = client.Lrem("key", -2, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, n)
|
||||
vals, err = client.Lrange("key", 0, 10)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value2", "value3", "value4"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Mget(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "value2")
|
||||
assert.Nil(t, err)
|
||||
vals, err := client.Mget("key1", "key0", "key2", "key3")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value1", "", "value2", ""}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetBit(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.SetBit("key", 1, 1)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetBit(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.SetBit("key", 2, 1)
|
||||
assert.Nil(t, err)
|
||||
val, err := client.GetBit("key", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Persist(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
ok, err := client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = client.Set("key", "value")
|
||||
assert.Nil(t, err)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
err = client.Expire("key", 5)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
err = client.Expireat("key", time.Now().Unix()+5)
|
||||
ok, err = client.Persist("key")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Ping(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
ok := client.Ping()
|
||||
assert.True(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Scan(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("key2", "value2")
|
||||
assert.Nil(t, err)
|
||||
keys, _, err := client.Scan(0, "*", 100)
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"key1", "key2"}, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Sscan(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
key := "list"
|
||||
var list []string
|
||||
for i := 0; i < 1550; i++ {
|
||||
list = append(list, randomStr(i))
|
||||
}
|
||||
lens, err := client.Sadd(key, list)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, lens, 1550)
|
||||
|
||||
var cursor uint64 = 0
|
||||
sum := 0
|
||||
for {
|
||||
keys, next, err := client.Sscan(key, cursor, "", 100)
|
||||
assert.Nil(t, err)
|
||||
sum += len(keys)
|
||||
if next == 0 {
|
||||
break
|
||||
}
|
||||
cursor = next
|
||||
}
|
||||
|
||||
assert.Equal(t, sum, 1550)
|
||||
_, err = client.Del(key)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Set(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
num, err := client.Sadd("key", 1, 2, 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
val, err := client.Scard("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(4), val)
|
||||
ok, err := client.Sismember("key", 2)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
num, err = client.Srem("key", 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
vals, err := client.Smembers("key")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"1", "2"}, vals)
|
||||
members, err := client.Srandmember("key", 1)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, members, 1)
|
||||
assert.Contains(t, []string{"1", "2"}, members[0])
|
||||
member, err := client.Spop("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, []string{"1", "2"}, member)
|
||||
vals, err = client.Smembers("key")
|
||||
assert.Nil(t, err)
|
||||
assert.NotContains(t, vals, member)
|
||||
num, err = client.Sadd("key1", 1, 2, 3, 4)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
num, err = client.Sadd("key2", 2, 3, 4, 5)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 4, num)
|
||||
vals, err = client.Sunion("key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"1", "2", "3", "4", "5"}, vals)
|
||||
num, err = client.Sunionstore("key3", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 5, num)
|
||||
vals, err = client.Sdiff("key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"1"}, vals)
|
||||
num, err = client.Sdiffstore("key4", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, num)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetGetDel(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Set("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
val, err := client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
ret, err := client.Del("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetExNx(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Setex("hello", "world", 5)
|
||||
assert.Nil(t, err)
|
||||
ok, err := client.Setnx("hello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
ok, err = client.Setnx("newhello", "newworld")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
val, err := client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.Get("newhello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "newworld", val)
|
||||
ttl, err := client.Ttl("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ttl > 0)
|
||||
ok, err = client.SetnxEx("newhello", "newworld", 5)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
num, err := client.Del("newhello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, num)
|
||||
ok, err = client.SetnxEx("newhello", "newworld", 5)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
val, err = client.Get("newhello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "newworld", val)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SetGetDelHashField(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Hset("key", "field", "value")
|
||||
assert.Nil(t, err)
|
||||
val, err := client.Hget("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", val)
|
||||
ok, err := client.Hexists("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ret, err := client.Hdel("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ret)
|
||||
ok, err = client.Hexists("key", "field")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_SortedSet(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
ok, err := client.Zadd("key", 1, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 2, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
val, err := client.Zscore("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), val)
|
||||
val, err = client.Zincrby("key", 3, "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
val, err = client.Zscore("key", "value1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(5), val)
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
rank, err := client.Zrank("key", "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), rank)
|
||||
rank, err = client.Zrank("key", "value4")
|
||||
assert.Equal(t, Nil, err)
|
||||
num, err := client.Zrem("key", "value2", "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 8, "value4")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
num, err = client.Zremrangebyscore("key", 6, 7)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
ok, err = client.Zadd("key", 6, "value2")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
ok, err = client.Zadd("key", 7, "value3")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
num, err = client.Zcount("key", 6, 7)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
num, err = client.Zremrangebyrank("key", 1, 2)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, num)
|
||||
card, err := client.Zcard("key")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, card)
|
||||
vals, err := client.Zrange("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value1", "value4"}, vals)
|
||||
vals, err = client.Zrevrange("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []string{"value4", "value1"}, vals)
|
||||
pairs, err := client.ZrangeWithScores("key", 0, -1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []Pair{
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrangebyscoreWithScores("key", 5, 8)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []Pair{
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []Pair{
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrevrangebyscoreWithScores("key", 5, 8)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []Pair{
|
||||
{
|
||||
Key: "value4",
|
||||
Score: 8,
|
||||
},
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
}, pairs)
|
||||
pairs, err = client.ZrevrangebyscoreWithScoresAndLimit("key", 5, 8, 1, 1)
|
||||
assert.Nil(t, err)
|
||||
assert.EqualValues(t, []Pair{
|
||||
{
|
||||
Key: "value1",
|
||||
Score: 5,
|
||||
},
|
||||
}, pairs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_Pipelined(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
err := client.Pipelined(
|
||||
func(pipe Pipeliner) error {
|
||||
pipe.Incr("pipelined_counter")
|
||||
pipe.Expire("pipelined_counter", time.Hour)
|
||||
pipe.ZAdd("zadd", Z{Score: 12, Member: "zadd"})
|
||||
return nil
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
ttl, err := client.Ttl("pipelined_counter")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3600, ttl)
|
||||
value, err := client.Get("pipelined_counter")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "1", value)
|
||||
score, err := client.Zscore("zadd", "zadd")
|
||||
assert.Equal(t, int64(12), score)
|
||||
})
|
||||
}
|
||||
|
||||
func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
defer func() {
|
||||
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
|
||||
return nil, errors.New("should already exist")
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
client.Close()
|
||||
}()
|
||||
|
||||
fn(NewRedis(s.Addr(), NodeType))
|
||||
}
|
||||
66
core/stores/redis/redisblockingnode.go
Normal file
66
core/stores/redis/redisblockingnode.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
)
|
||||
|
||||
type ClosableNode interface {
|
||||
RedisNode
|
||||
Close()
|
||||
}
|
||||
|
||||
func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||
timeout := readWriteTimeout + blockingQueryTimeout
|
||||
|
||||
switch r.Type {
|
||||
case NodeType:
|
||||
client := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
MinIdleConns: 1,
|
||||
ReadTimeout: timeout,
|
||||
})
|
||||
return &clientBridge{client}, nil
|
||||
case ClusterType:
|
||||
client := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: []string{r.Addr},
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
MinIdleConns: 1,
|
||||
ReadTimeout: timeout,
|
||||
})
|
||||
return &clusterBridge{client}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown redis type: %s", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
type (
|
||||
clientBridge struct {
|
||||
*red.Client
|
||||
}
|
||||
|
||||
clusterBridge struct {
|
||||
*red.ClusterClient
|
||||
}
|
||||
)
|
||||
|
||||
func (bridge *clientBridge) Close() {
|
||||
if err := bridge.Client.Close(); err != nil {
|
||||
logx.Errorf("Error occurred on close redis client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (bridge *clusterBridge) Close() {
|
||||
if err := bridge.ClusterClient.Close(); err != nil {
|
||||
logx.Errorf("Error occurred on close redis cluster: %s", err)
|
||||
}
|
||||
}
|
||||
36
core/stores/redis/redisclientmanager.go
Normal file
36
core/stores/redis/redisclientmanager.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"zero/core/syncx"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDatabase = 0
|
||||
maxRetries = 3
|
||||
idleConns = 8
|
||||
)
|
||||
|
||||
var clientManager = syncx.NewResourceManager()
|
||||
|
||||
func getClient(server, pass string) (*red.Client, error) {
|
||||
val, err := clientManager.GetResource(server, func() (io.Closer, error) {
|
||||
store := red.NewClient(&red.Options{
|
||||
Addr: server,
|
||||
Password: pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
})
|
||||
store.WrapProcess(process)
|
||||
return store, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*red.Client), nil
|
||||
}
|
||||
30
core/stores/redis/redisclustermanager.go
Normal file
30
core/stores/redis/redisclustermanager.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"zero/core/syncx"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
)
|
||||
|
||||
var clusterManager = syncx.NewResourceManager()
|
||||
|
||||
func getCluster(server, pass string) (*red.ClusterClient, error) {
|
||||
val, err := clusterManager.GetResource(server, func() (io.Closer, error) {
|
||||
store := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: []string{server},
|
||||
Password: pass,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
})
|
||||
store.WrapProcess(process)
|
||||
|
||||
return store, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*red.ClusterClient), nil
|
||||
}
|
||||
96
core/stores/redis/redislock.go
Normal file
96
core/stores/redis/redislock.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
red "github.com/go-redis/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2])
|
||||
return "OK"
|
||||
else
|
||||
return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2])
|
||||
end`
|
||||
delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end`
|
||||
randomLen = 16
|
||||
tolerance = 500 // milliseconds
|
||||
millisPerSecond = 1000
|
||||
)
|
||||
|
||||
type RedisLock struct {
|
||||
store *Redis
|
||||
seconds uint32
|
||||
key string
|
||||
id string
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func NewRedisLock(store *Redis, key string) *RedisLock {
|
||||
return &RedisLock{
|
||||
store: store,
|
||||
key: key,
|
||||
id: randomStr(randomLen),
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RedisLock) Acquire() (bool, error) {
|
||||
seconds := atomic.LoadUint32(&rl.seconds)
|
||||
resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{
|
||||
rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance)})
|
||||
if err == red.Nil {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
logx.Errorf("Error on acquiring lock for %s, %s", rl.key, err.Error())
|
||||
return false, err
|
||||
} else if resp == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
reply, ok := resp.(string)
|
||||
if ok && reply == "OK" {
|
||||
return true, nil
|
||||
} else {
|
||||
logx.Errorf("Unknown reply when acquiring lock for %s: %v", rl.key, resp)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RedisLock) Release() (bool, error) {
|
||||
resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if reply, ok := resp.(int64); !ok {
|
||||
return false, nil
|
||||
} else {
|
||||
return reply == 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RedisLock) SetExpire(seconds int) {
|
||||
atomic.StoreUint32(&rl.seconds, uint32(seconds))
|
||||
}
|
||||
|
||||
func randomStr(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
34
core/stores/redis/redislock_test.go
Normal file
34
core/stores/redis/redislock_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/stringx"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRedisLock(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
key := stringx.Rand()
|
||||
firstLock := NewRedisLock(client, key)
|
||||
firstLock.SetExpire(5)
|
||||
firstAcquire, err := firstLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, firstAcquire)
|
||||
|
||||
secondLock := NewRedisLock(client, key)
|
||||
secondLock.SetExpire(5)
|
||||
againAcquire, err := secondLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, againAcquire)
|
||||
|
||||
release, err := firstLock.Release()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, release)
|
||||
|
||||
endAcquire, err := secondLock.Acquire()
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, endAcquire)
|
||||
})
|
||||
}
|
||||
48
core/stores/redis/scriptcache.go
Normal file
48
core/stores/redis/scriptcache.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
lock sync.Mutex
|
||||
instance *ScriptCache
|
||||
)
|
||||
|
||||
type (
|
||||
Map map[string]string
|
||||
|
||||
ScriptCache struct {
|
||||
atomic.Value
|
||||
}
|
||||
)
|
||||
|
||||
func GetScriptCache() *ScriptCache {
|
||||
once.Do(func() {
|
||||
instance = &ScriptCache{}
|
||||
instance.Store(make(Map))
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func (sc *ScriptCache) GetSha(script string) (string, bool) {
|
||||
cache := sc.Load().(Map)
|
||||
ret, ok := cache[script]
|
||||
return ret, ok
|
||||
}
|
||||
|
||||
func (sc *ScriptCache) SetSha(script, sha string) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
cache := sc.Load().(Map)
|
||||
newCache := make(Map)
|
||||
for k, v := range cache {
|
||||
newCache[k] = v
|
||||
}
|
||||
newCache[script] = sha
|
||||
sc.Store(newCache)
|
||||
}
|
||||
122
core/stores/sqlc/cachedsql.go
Normal file
122
core/stores/sqlc/cachedsql.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"zero/core/stores/cache"
|
||||
"zero/core/stores/internal"
|
||||
"zero/core/stores/redis"
|
||||
"zero/core/stores/sqlx"
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
// see doc/sql-cache.md
|
||||
const cacheSafeGapBetweenIndexAndPrimary = time.Second * 5
|
||||
|
||||
var (
|
||||
ErrNotFound = sqlx.ErrNotFound
|
||||
|
||||
// can't use one SharedCalls per conn, because multiple conns may share the same cache key.
|
||||
exclusiveCalls = syncx.NewSharedCalls()
|
||||
stats = internal.NewCacheStat("sqlc")
|
||||
)
|
||||
|
||||
type (
|
||||
ExecFn func(conn sqlx.SqlConn) (sql.Result, error)
|
||||
IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error)
|
||||
PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error
|
||||
QueryFn func(conn sqlx.SqlConn, v interface{}) error
|
||||
|
||||
CachedConn struct {
|
||||
db sqlx.SqlConn
|
||||
cache internal.Cache
|
||||
}
|
||||
)
|
||||
|
||||
func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
|
||||
return CachedConn{
|
||||
db: db,
|
||||
cache: internal.NewCacheNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
|
||||
return CachedConn{
|
||||
db: db,
|
||||
cache: internal.NewCache(c, exclusiveCalls, stats, sql.ErrNoRows, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
func (cc CachedConn) DelCache(keys ...string) error {
|
||||
return cc.cache.DelCache(keys...)
|
||||
}
|
||||
|
||||
func (cc CachedConn) GetCache(key string, v interface{}) error {
|
||||
return cc.cache.GetCache(key, v)
|
||||
}
|
||||
|
||||
func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
|
||||
res, err := exec(cc.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := cc.DelCache(keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) {
|
||||
return cc.db.Exec(q, args...)
|
||||
}
|
||||
|
||||
func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error {
|
||||
return cc.cache.Take(v, key, func(v interface{}) error {
|
||||
return query(cc.db, v)
|
||||
})
|
||||
}
|
||||
|
||||
func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string,
|
||||
indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
|
||||
var primaryKey interface{}
|
||||
var found bool
|
||||
if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) {
|
||||
primaryKey, err = indexQuery(cc.db, v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
found = true
|
||||
return cc.cache.SetCacheWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if found {
|
||||
return nil
|
||||
}
|
||||
|
||||
return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error {
|
||||
return primaryQuery(cc.db, v, primaryKey)
|
||||
})
|
||||
}
|
||||
|
||||
func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error {
|
||||
return cc.db.QueryRow(v, q, args...)
|
||||
}
|
||||
|
||||
// QueryRowsNoCache doesn't use cache, because it might cause consistency problem.
|
||||
func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error {
|
||||
return cc.db.QueryRows(v, q, args...)
|
||||
}
|
||||
|
||||
func (cc CachedConn) SetCache(key string, v interface{}) error {
|
||||
return cc.cache.SetCache(key, v)
|
||||
}
|
||||
|
||||
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
|
||||
return cc.db.Transact(fn)
|
||||
}
|
||||
508
core/stores/sqlc/cachedsql_test.go
Normal file
508
core/stores/sqlc/cachedsql_test.go
Normal file
@@ -0,0 +1,508 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
"zero/core/stores/cache"
|
||||
"zero/core/stores/redis"
|
||||
"zero/core/stores/sqlx"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestCachedConn_GetCache(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
|
||||
var value string
|
||||
err = c.GetCache("any", &value)
|
||||
assert.Equal(t, ErrNotFound, err)
|
||||
s.Set("any", `"value"`)
|
||||
err = c.GetCache("any", &value)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", value)
|
||||
}
|
||||
|
||||
func TestStat(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
var str string
|
||||
err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func TestCachedConn_QueryRowIndex_NoCache(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
|
||||
|
||||
var str string
|
||||
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
|
||||
return fmt.Sprintf("%s/1234", s)
|
||||
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
|
||||
*v.(*string) = "zero"
|
||||
return "primary", nil
|
||||
}, func(conn sqlx.SqlConn, v, pri interface{}) error {
|
||||
assert.Equal(t, "primary", pri)
|
||||
*v.(*string) = "xin"
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "zero", str)
|
||||
val, err := r.Get("index")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `"primary"`, val)
|
||||
val, err = r.Get("primary/1234")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `"zero"`, val)
|
||||
}
|
||||
|
||||
func TestCachedConn_QueryRowIndex_HasCache(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
|
||||
cache.WithNotFoundExpiry(time.Second))
|
||||
|
||||
var str string
|
||||
r.Set("index", `"primary"`)
|
||||
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
|
||||
return fmt.Sprintf("%s/1234", s)
|
||||
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
|
||||
assert.Fail(t, "should not go here")
|
||||
return "primary", nil
|
||||
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
|
||||
*v.(*string) = "xin"
|
||||
assert.Equal(t, "primary", primary)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "xin", str)
|
||||
val, err := r.Get("index")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `"primary"`, val)
|
||||
val, err = r.Get("primary/1234")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `"xin"`, val)
|
||||
}
|
||||
|
||||
func TestCachedConn_QueryRowIndex_HasWrongCache(t *testing.T) {
|
||||
caches := map[string]string{
|
||||
"index": "primary",
|
||||
"primary/1234": "xin",
|
||||
}
|
||||
|
||||
for k, v := range caches {
|
||||
t.Run(k+"/"+v, func(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10),
|
||||
cache.WithNotFoundExpiry(time.Second))
|
||||
|
||||
var str string
|
||||
r.Set(k, v)
|
||||
err = c.QueryRowIndex(&str, "index", func(s interface{}) string {
|
||||
return fmt.Sprintf("%s/1234", s)
|
||||
}, func(conn sqlx.SqlConn, v interface{}) (interface{}, error) {
|
||||
*v.(*string) = "xin"
|
||||
return "primary", nil
|
||||
}, func(conn sqlx.SqlConn, v, primary interface{}) error {
|
||||
*v.(*string) = "xin"
|
||||
assert.Equal(t, "primary", primary)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "xin", str)
|
||||
val, err := r.Get("index")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `"primary"`, val)
|
||||
val, err = r.Get("primary/1234")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, `"xin"`, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatCacheFails(t *testing.T) {
|
||||
resetStats()
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stdout)
|
||||
|
||||
r := redis.NewRedis("localhost:59999", redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
var str string
|
||||
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
return errors.New("db failed")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Miss))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.DbFails))
|
||||
}
|
||||
|
||||
func TestStatDbFails(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
var str string
|
||||
err = c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
return errors.New("db failed")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(0), atomic.LoadUint64(&stats.Hit))
|
||||
assert.Equal(t, uint64(20), atomic.LoadUint64(&stats.DbFails))
|
||||
}
|
||||
|
||||
func TestStatFromMemory(t *testing.T) {
|
||||
resetStats()
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(dummySqlConn{}, r, cache.WithExpiry(time.Second*10))
|
||||
|
||||
var all sync.WaitGroup
|
||||
var wait sync.WaitGroup
|
||||
all.Add(10)
|
||||
wait.Add(4)
|
||||
go func() {
|
||||
var str string
|
||||
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
wait.Wait()
|
||||
runtime.Gosched()
|
||||
all.Done()
|
||||
}()
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
go func() {
|
||||
var str string
|
||||
wait.Done()
|
||||
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
all.Done()
|
||||
}()
|
||||
}
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
var str string
|
||||
err := c.QueryRow(&str, "name", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
*v.(*string) = "zero"
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
all.Done()
|
||||
}()
|
||||
}
|
||||
all.Wait()
|
||||
|
||||
assert.Equal(t, uint64(10), atomic.LoadUint64(&stats.Total))
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func TestCachedConnQueryRow(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
const (
|
||||
key = "user"
|
||||
value = "any"
|
||||
)
|
||||
var conn trackedConn
|
||||
var user string
|
||||
var ran bool
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
|
||||
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
|
||||
ran = true
|
||||
user = value
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
actualValue, err := s.Get(key)
|
||||
assert.Nil(t, err)
|
||||
var actual string
|
||||
assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
|
||||
assert.Equal(t, value, actual)
|
||||
assert.Equal(t, value, user)
|
||||
assert.True(t, ran)
|
||||
}
|
||||
|
||||
func TestCachedConnQueryRowFromCache(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
const (
|
||||
key = "user"
|
||||
value = "any"
|
||||
)
|
||||
var conn trackedConn
|
||||
var user string
|
||||
var ran bool
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
|
||||
assert.Nil(t, c.SetCache(key, value))
|
||||
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
|
||||
ran = true
|
||||
user = value
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
actualValue, err := s.Get(key)
|
||||
assert.Nil(t, err)
|
||||
var actual string
|
||||
assert.Nil(t, json.Unmarshal([]byte(actualValue), &actual))
|
||||
assert.Equal(t, value, actual)
|
||||
assert.Equal(t, value, user)
|
||||
assert.False(t, ran)
|
||||
}
|
||||
|
||||
func TestQueryRowNotFound(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
const key = "user"
|
||||
var conn trackedConn
|
||||
var user string
|
||||
var ran int
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
|
||||
for i := 0; i < 20; i++ {
|
||||
err = c.QueryRow(&user, key, func(conn sqlx.SqlConn, v interface{}) error {
|
||||
ran++
|
||||
return sql.ErrNoRows
|
||||
})
|
||||
assert.Exactly(t, sqlx.ErrNotFound, err)
|
||||
}
|
||||
assert.Equal(t, 1, ran)
|
||||
}
|
||||
|
||||
func TestCachedConnExec(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var conn trackedConn
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
|
||||
_, err = c.ExecNoCache("delete from user_table where id='kevin'")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, conn.execValue)
|
||||
}
|
||||
|
||||
func TestCachedConnExecDropCache(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
const (
|
||||
key = "user"
|
||||
value = "any"
|
||||
)
|
||||
var conn trackedConn
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*30))
|
||||
assert.Nil(t, c.SetCache(key, value))
|
||||
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
|
||||
return conn.Exec("delete from user_table where id='kevin'")
|
||||
}, key)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, conn.execValue)
|
||||
_, err = s.Get(key)
|
||||
assert.Exactly(t, miniredis.ErrKeyNotFound, err)
|
||||
}
|
||||
|
||||
func TestCachedConnExecDropCacheFailed(t *testing.T) {
|
||||
const key = "user"
|
||||
var conn trackedConn
|
||||
r := redis.NewRedis("anyredis:8888", redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
|
||||
_, err := c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
|
||||
return conn.Exec("delete from user_table where id='kevin'")
|
||||
}, key)
|
||||
// async background clean, retry logic
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestCachedConnQueryRows(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var conn trackedConn
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
|
||||
var users []string
|
||||
err = c.QueryRowsNoCache(&users, "select user from user_table where id='kevin'")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, conn.queryRowsValue)
|
||||
}
|
||||
|
||||
func TestCachedConnTransact(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var conn trackedConn
|
||||
r := redis.NewRedis(s.Addr(), redis.NodeType)
|
||||
c := NewNodeConn(&conn, r, cache.WithExpiry(time.Second*10))
|
||||
err = c.Transact(func(session sqlx.Session) error {
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, conn.transactValue)
|
||||
}
|
||||
|
||||
func resetStats() {
|
||||
atomic.StoreUint64(&stats.Total, 0)
|
||||
atomic.StoreUint64(&stats.Hit, 0)
|
||||
atomic.StoreUint64(&stats.Miss, 0)
|
||||
atomic.StoreUint64(&stats.DbFails, 0)
|
||||
}
|
||||
|
||||
type dummySqlConn struct {
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRows(v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Transact(func(session sqlx.Session) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
dummySqlConn
|
||||
execValue bool
|
||||
queryRowsValue bool
|
||||
transactValue bool
|
||||
}
|
||||
|
||||
func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
c.execValue = true
|
||||
return c.dummySqlConn.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
|
||||
c.queryRowsValue = true
|
||||
return c.dummySqlConn.QueryRows(v, query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
|
||||
c.transactValue = true
|
||||
return c.dummySqlConn.Transact(fn)
|
||||
}
|
||||
187
core/stores/sqlx/bulkinserter.go
Normal file
187
core/stores/sqlx/bulkinserter.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"zero/core/executors"
|
||||
"zero/core/logx"
|
||||
"zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
flushInterval = time.Second
|
||||
maxBulkRows = 1000
|
||||
valuesKeyword = "values"
|
||||
)
|
||||
|
||||
var emptyBulkStmt bulkStmt
|
||||
|
||||
type (
|
||||
ResultHandler func(sql.Result, error)
|
||||
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
stmt bulkStmt
|
||||
}
|
||||
|
||||
bulkStmt struct {
|
||||
prefix string
|
||||
valueFormat string
|
||||
suffix string
|
||||
}
|
||||
)
|
||||
|
||||
func NewBulkInserter(sqlConn SqlConn, stmt string) (*BulkInserter, error) {
|
||||
bkStmt, err := parseInsertStmt(stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inserter := &dbInserter{
|
||||
sqlConn: sqlConn,
|
||||
stmt: bkStmt,
|
||||
}
|
||||
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
|
||||
inserter: inserter,
|
||||
stmt: bkStmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) Flush() {
|
||||
bi.executor.Flush()
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) Insert(args ...interface{}) error {
|
||||
value, err := format(bi.stmt.valueFormat, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bi.executor.Add(value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.resultHandler = handler
|
||||
})
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) UpdateOrDelete(fn func()) {
|
||||
bi.executor.Flush()
|
||||
fn()
|
||||
}
|
||||
|
||||
func (bi *BulkInserter) UpdateStmt(stmt string) error {
|
||||
bkStmt, err := parseInsertStmt(stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bi.executor.Flush()
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.stmt = bkStmt
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
sqlConn SqlConn
|
||||
stmt bulkStmt
|
||||
values []string
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
func (in *dbInserter) AddTask(task interface{}) bool {
|
||||
in.values = append(in.values, task.(string))
|
||||
return len(in.values) >= maxBulkRows
|
||||
}
|
||||
|
||||
func (in *dbInserter) Execute(bulk interface{}) {
|
||||
values := bulk.([]string)
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
stmtWithoutValues := in.stmt.prefix
|
||||
valuesStr := strings.Join(values, ", ")
|
||||
stmt := strings.Join([]string{stmtWithoutValues, valuesStr}, " ")
|
||||
if len(in.stmt.suffix) > 0 {
|
||||
stmt = strings.Join([]string{stmt, in.stmt.suffix}, " ")
|
||||
}
|
||||
result, err := in.sqlConn.Exec(stmt)
|
||||
if in.resultHandler != nil {
|
||||
in.resultHandler(result, err)
|
||||
} else if err != nil {
|
||||
logx.Errorf("sql: %s, error: %s", stmt, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (in *dbInserter) RemoveAll() interface{} {
|
||||
values := in.values
|
||||
in.values = nil
|
||||
return values
|
||||
}
|
||||
|
||||
func parseInsertStmt(stmt string) (bulkStmt, error) {
|
||||
lower := strings.ToLower(stmt)
|
||||
pos := strings.Index(lower, valuesKeyword)
|
||||
if pos <= 0 {
|
||||
return emptyBulkStmt, fmt.Errorf("bad sql: %q", stmt)
|
||||
}
|
||||
|
||||
var columns int
|
||||
right := strings.LastIndexByte(lower[:pos], ')')
|
||||
if right > 0 {
|
||||
left := strings.LastIndexByte(lower[:right], '(')
|
||||
if left > 0 {
|
||||
values := lower[left+1 : right]
|
||||
values = stringx.Filter(values, func(r rune) bool {
|
||||
return r == ' ' || r == '\t' || r == '\r' || r == '\n'
|
||||
})
|
||||
fields := strings.FieldsFunc(values, func(r rune) bool {
|
||||
return r == ','
|
||||
})
|
||||
columns = len(fields)
|
||||
}
|
||||
}
|
||||
|
||||
var variables int
|
||||
var valueFormat string
|
||||
var suffix string
|
||||
left := strings.IndexByte(lower[pos:], '(')
|
||||
if left > 0 {
|
||||
right = strings.IndexByte(lower[pos+left:], ')')
|
||||
if right > 0 {
|
||||
values := lower[pos+left : pos+left+right]
|
||||
for _, x := range values {
|
||||
if x == '?' {
|
||||
variables++
|
||||
}
|
||||
}
|
||||
valueFormat = stmt[pos+left : pos+left+right+1]
|
||||
suffix = strings.TrimSpace(stmt[pos+left+right+1:])
|
||||
}
|
||||
}
|
||||
|
||||
if variables == 0 {
|
||||
return emptyBulkStmt, fmt.Errorf("no variables: %q", stmt)
|
||||
}
|
||||
if columns > 0 && columns != variables {
|
||||
return emptyBulkStmt, fmt.Errorf("columns and variables mismatch: %q", stmt)
|
||||
}
|
||||
|
||||
return bulkStmt{
|
||||
prefix: stmt[:pos+len(valuesKeyword)],
|
||||
valueFormat: valueFormat,
|
||||
suffix: suffix,
|
||||
}, nil
|
||||
}
|
||||
98
core/stores/sqlx/bulkinserter_test.go
Normal file
98
core/stores/sqlx/bulkinserter_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockedConn struct {
|
||||
query string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
c.query = query
|
||||
c.args = args
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRow(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowPartial(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...interface{}) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) Transact(func(session Session) error) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func TestBulkInserter(t *testing.T) {
|
||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var conn mockedConn
|
||||
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
|
||||
assert.Nil(t, err)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
|
||||
}
|
||||
inserter.Flush()
|
||||
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
|
||||
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
|
||||
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4)`,
|
||||
conn.query)
|
||||
assert.Nil(t, conn.args)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBulkInserterSuffix(t *testing.T) {
|
||||
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var conn mockedConn
|
||||
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
|
||||
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
|
||||
assert.Nil(t, err)
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.Nil(t, inserter.Insert("class_"+strconv.Itoa(i), "user_"+strconv.Itoa(i), i))
|
||||
}
|
||||
inserter.Flush()
|
||||
assert.Equal(t, `INSERT INTO classroom_dau(classroom, user, count) VALUES `+
|
||||
`('class_0', 'user_0', 0), ('class_1', 'user_1', 1), ('class_2', 'user_2', 2), `+
|
||||
`('class_3', 'user_3', 3), ('class_4', 'user_4', 4) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`,
|
||||
conn.query)
|
||||
assert.Nil(t, conn.args)
|
||||
})
|
||||
}
|
||||
|
||||
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
||||
logx.Disable()
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fn(db, mock)
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
37
core/stores/sqlx/mysql.go
Normal file
37
core/stores/sqlx/mysql.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package sqlx
|
||||
|
||||
import "github.com/go-sql-driver/mysql"
|
||||
|
||||
const (
|
||||
mysqlDriverName = "mysql"
|
||||
duplicateEntryCode uint16 = 1062
|
||||
)
|
||||
|
||||
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
|
||||
opts = append(opts, withMysqlAcceptable())
|
||||
return NewSqlConn(mysqlDriverName, datasource, opts...)
|
||||
}
|
||||
|
||||
func mysqlAcceptable(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
myerr, ok := err.(*mysql.MySQLError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch myerr.Number {
|
||||
case duplicateEntryCode:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func withMysqlAcceptable() SqlOption {
|
||||
return func(conn *commonSqlConn) {
|
||||
conn.accept = mysqlAcceptable
|
||||
}
|
||||
}
|
||||
56
core/stores/sqlx/mysql_test.go
Normal file
56
core/stores/sqlx/mysql_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"zero/core/breaker"
|
||||
"zero/core/logx"
|
||||
"zero/core/stat"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
func TestBreakerOnDuplicateEntry(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
err := tryOnDuplicateEntryError(t, mysqlAcceptable)
|
||||
assert.Equal(t, duplicateEntryCode, err.(*mysql.MySQLError).Number)
|
||||
}
|
||||
|
||||
func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
|
||||
logx.Disable()
|
||||
|
||||
var found bool
|
||||
for i := 0; i < 100; i++ {
|
||||
if tryOnDuplicateEntryError(t, nil) == breaker.ErrServiceUnavailable {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func tryOnDuplicateEntryError(t *testing.T, accept func(error) bool) error {
|
||||
logx.Disable()
|
||||
|
||||
conn := commonSqlConn{
|
||||
brk: breaker.NewBreaker(),
|
||||
accept: accept,
|
||||
}
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.NotNil(t, conn.brk.DoWithAcceptable(func() error {
|
||||
return &mysql.MySQLError{
|
||||
Number: duplicateEntryCode,
|
||||
}
|
||||
}, conn.acceptable))
|
||||
}
|
||||
return conn.brk.DoWithAcceptable(func() error {
|
||||
return &mysql.MySQLError{
|
||||
Number: duplicateEntryCode,
|
||||
}
|
||||
}, conn.acceptable)
|
||||
}
|
||||
254
core/stores/sqlx/orm.go
Normal file
254
core/stores/sqlx/orm.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
const tagName = "db"
|
||||
|
||||
var (
|
||||
ErrNotMatchDestination = errors.New("not matching destination to scan")
|
||||
ErrNotReadableValue = errors.New("value not addressable or interfaceable")
|
||||
ErrNotSettable = errors.New("passed in variable is not settable")
|
||||
ErrUnsupportedValueType = errors.New("unsupported unmarshal type")
|
||||
)
|
||||
|
||||
type rowsScanner interface {
|
||||
Columns() ([]string, error)
|
||||
Err() error
|
||||
Next() bool
|
||||
Scan(v ...interface{}) error
|
||||
}
|
||||
|
||||
func getTaggedFieldValueMap(v reflect.Value) (map[string]interface{}, error) {
|
||||
rt := mapping.Deref(v.Type())
|
||||
size := rt.NumField()
|
||||
result := make(map[string]interface{}, size)
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
key := parseTagName(rt.Field(i))
|
||||
if len(key) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
valueField := reflect.Indirect(v).Field(i)
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
result[key] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
result[key] = valueField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]interface{}, error) {
|
||||
fields := unwrapFields(v)
|
||||
if strict && len(columns) < len(fields) {
|
||||
return nil, ErrNotMatchDestination
|
||||
}
|
||||
|
||||
taggedMap, err := getTaggedFieldValueMap(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(columns))
|
||||
if len(taggedMap) == 0 {
|
||||
for i := 0; i < len(values); i++ {
|
||||
valueField := fields[i]
|
||||
switch valueField.Kind() {
|
||||
case reflect.Ptr:
|
||||
if !valueField.CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
if valueField.IsNil() {
|
||||
baseValueType := mapping.Deref(valueField.Type())
|
||||
valueField.Set(reflect.New(baseValueType))
|
||||
}
|
||||
values[i] = valueField.Interface()
|
||||
default:
|
||||
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
|
||||
return nil, ErrNotReadableValue
|
||||
}
|
||||
values[i] = valueField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, column := range columns {
|
||||
if tagged, ok := taggedMap[column]; ok {
|
||||
values[i] = tagged
|
||||
} else {
|
||||
var anonymous interface{}
|
||||
values[i] = &anonymous
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
func parseTagName(field reflect.StructField) string {
|
||||
key := field.Tag.Get(tagName)
|
||||
if len(key) == 0 {
|
||||
return ""
|
||||
} else {
|
||||
options := strings.Split(key, ",")
|
||||
return options[0]
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalRow(v interface{}, scanner rowsScanner, strict bool) error {
|
||||
if !scanner.Next() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rte := reflect.TypeOf(v).Elem()
|
||||
rve := rv.Elem()
|
||||
switch rte.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
if rve.CanSet() {
|
||||
return scanner.Scan(v)
|
||||
} else {
|
||||
return ErrNotSettable
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if values, err := mapStructFieldsIntoSlice(rve, columns, strict); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return scanner.Scan(values...)
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalRows(v interface{}, scanner rowsScanner, strict bool) error {
|
||||
rv := reflect.ValueOf(v)
|
||||
if err := mapping.ValidatePtr(&rv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rt := reflect.TypeOf(v)
|
||||
rte := rt.Elem()
|
||||
rve := rv.Elem()
|
||||
switch rte.Kind() {
|
||||
case reflect.Slice:
|
||||
if rve.CanSet() {
|
||||
ptr := rte.Elem().Kind() == reflect.Ptr
|
||||
appendFn := func(item reflect.Value) {
|
||||
if ptr {
|
||||
rve.Set(reflect.Append(rve, item))
|
||||
} else {
|
||||
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
|
||||
}
|
||||
}
|
||||
fillFn := func(value interface{}) error {
|
||||
if rve.CanSet() {
|
||||
if err := scanner.Scan(value); err != nil {
|
||||
return err
|
||||
} else {
|
||||
appendFn(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrNotSettable
|
||||
}
|
||||
|
||||
base := mapping.Deref(rte.Elem())
|
||||
switch base.Kind() {
|
||||
case reflect.Bool,
|
||||
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64,
|
||||
reflect.String:
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if err := fillFn(value.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
columns, err := scanner.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for scanner.Next() {
|
||||
value := reflect.New(base)
|
||||
if values, err := mapStructFieldsIntoSlice(value, columns, strict); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if err := scanner.Scan(values...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
appendFn(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
|
||||
return nil
|
||||
} else {
|
||||
return ErrNotSettable
|
||||
}
|
||||
default:
|
||||
return ErrUnsupportedValueType
|
||||
}
|
||||
}
|
||||
|
||||
func unwrapFields(v reflect.Value) []reflect.Value {
|
||||
var fields []reflect.Value
|
||||
indirect := reflect.Indirect(v)
|
||||
|
||||
for i := 0; i < indirect.NumField(); i++ {
|
||||
child := indirect.Field(i)
|
||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||
baseValueType := mapping.Deref(child.Type())
|
||||
child.Set(reflect.New(baseValueType))
|
||||
}
|
||||
|
||||
child = reflect.Indirect(child)
|
||||
childType := indirect.Type().Field(i)
|
||||
if child.Kind() == reflect.Struct && childType.Anonymous {
|
||||
fields = append(fields, unwrapFields(child)...)
|
||||
} else {
|
||||
fields = append(fields, child)
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
973
core/stores/sqlx/orm_test.go
Normal file
973
core/stores/sqlx/orm_test.go
Normal file
@@ -0,0 +1,973 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"zero/core/logx"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUnmarshalRowBool(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.True(t, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, 2, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, int8(3), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.Equal(t, int16(4), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.Equal(t, int32(5), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowInt64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, int64(6), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint(2), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint8(3), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint16(4), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint32(5), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowUint64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint16(6), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowFloat32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, float32(7), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowFloat64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, float64(8), value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowString(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
const expect = "hello"
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStruct(t *testing.T) {
|
||||
var value = new(struct {
|
||||
Name string
|
||||
Age int
|
||||
})
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
var value = new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsBool(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []bool{true, false}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []int64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint8(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint16(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []uint64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []float32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []float64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsString(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []string{"hello", "world"}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
||||
yes := true
|
||||
no := false
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*bool{&yes, &no}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsIntPtr(t *testing.T) {
|
||||
two := 2
|
||||
three := 3
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
||||
two := int8(2)
|
||||
three := int8(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
||||
two := int16(2)
|
||||
three := int16(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
||||
two := int32(2)
|
||||
three := int32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
||||
two := int64(2)
|
||||
three := int64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*int64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUintPtr(t *testing.T) {
|
||||
two := uint(2)
|
||||
three := uint(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
||||
two := uint8(2)
|
||||
three := uint8(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
||||
two := uint16(2)
|
||||
three := uint16(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
||||
two := uint32(2)
|
||||
three := uint32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
||||
two := uint64(2)
|
||||
three := uint64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*uint64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
||||
two := float32(2)
|
||||
three := float32(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*float32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
||||
two := float64(2)
|
||||
three := float64(3)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*float64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
hello := "hello"
|
||||
world := "world"
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var expect = []*string{&hello, &world}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
var expect = []struct {
|
||||
Name string
|
||||
NullString sql.NullString
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
NullString: sql.NullString{
|
||||
String: "firstnullstring",
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
NullString: sql.NullString{
|
||||
String: "",
|
||||
Valid: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string `db:"name"`
|
||||
NullString sql.NullString `db:"value"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
|
||||
"first", "firstnullstring").AddRow("second", nil)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.NullString.String, value[i].NullString.String)
|
||||
assert.Equal(t, each.NullString.Valid, value[i].NullString.Valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTags(t *testing.T) {
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Age int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
|
||||
type Embed struct {
|
||||
Value int64 `db:"value"`
|
||||
}
|
||||
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Value int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
Value: 3,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
Value: 4,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string `db:"name"`
|
||||
Age int64 `db:"age"`
|
||||
Embed
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
assert.Equal(t, each.Value, value[i].Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) {
|
||||
type Embed struct {
|
||||
Value int64 `db:"value"`
|
||||
}
|
||||
|
||||
var expect = []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Value int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
Value: 3,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
Value: 4,
|
||||
},
|
||||
}
|
||||
var value []struct {
|
||||
Name string `db:"name"`
|
||||
Age int64 `db:"age"`
|
||||
*Embed
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
assert.Equal(t, each.Value, value[i].Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []*struct {
|
||||
Age int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
||||
var expect = []*struct {
|
||||
Name string
|
||||
Age int64
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
},
|
||||
}
|
||||
var value []*struct {
|
||||
Age *int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, *value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var r struct {
|
||||
User string `db:"user"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&r, rows, false)
|
||||
}, "select age from users where user=?", "anyone"))
|
||||
assert.Empty(t, r.User)
|
||||
assert.Equal(t, 5, r.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
|
||||
logx.Disable()
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
fn(db, mock)
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
}
|
||||
204
core/stores/sqlx/sqlconn.go
Normal file
204
core/stores/sqlx/sqlconn.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"zero/core/breaker"
|
||||
)
|
||||
|
||||
var ErrNotFound = sql.ErrNoRows
|
||||
|
||||
type (
|
||||
// Session stands for raw connections or transaction sessions
|
||||
Session interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (StmtSession, error)
|
||||
QueryRow(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowPartial(v interface{}, query string, args ...interface{}) error
|
||||
QueryRows(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowsPartial(v interface{}, query string, args ...interface{}) error
|
||||
}
|
||||
|
||||
// SqlConn only stands for raw connections, so Transact method can be called.
|
||||
SqlConn interface {
|
||||
Session
|
||||
Transact(func(session Session) error) error
|
||||
}
|
||||
|
||||
SqlOption func(*commonSqlConn)
|
||||
|
||||
StmtSession interface {
|
||||
Close() error
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
QueryRow(v interface{}, args ...interface{}) error
|
||||
QueryRowPartial(v interface{}, args ...interface{}) error
|
||||
QueryRows(v interface{}, args ...interface{}) error
|
||||
QueryRowsPartial(v interface{}, args ...interface{}) error
|
||||
}
|
||||
|
||||
// thread-safe
|
||||
// Because CORBA doesn't support PREPARE, so we need to combine the
|
||||
// query arguments into one string and do underlying query without arguments
|
||||
commonSqlConn struct {
|
||||
driverName string
|
||||
datasource string
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept func(error) bool
|
||||
}
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
statement struct {
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
stmtConn interface {
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
Query(args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
)
|
||||
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
driverName: driverName,
|
||||
datasource: datasource,
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
result, err = exec(conn, q, args...)
|
||||
return err
|
||||
}, db.acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if st, err := conn.Prepare(query); err != nil {
|
||||
return err
|
||||
} else {
|
||||
stmt = statement{
|
||||
stmt: st,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}, db.acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
return transact(db, db.beginTx, fn)
|
||||
}, db.acceptable)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) acceptable(err error) bool {
|
||||
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
|
||||
if db.accept == nil {
|
||||
return ok
|
||||
} else {
|
||||
return ok || db.accept(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
var qerr error
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return query(conn, func(rows *sql.Rows) error {
|
||||
qerr = scanner(rows)
|
||||
return qerr
|
||||
}, q, args...)
|
||||
}, func(err error) bool {
|
||||
return qerr == err || db.acceptable(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (s statement) Close() error {
|
||||
return s.stmt.Close()
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(s.stmt, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, args...)
|
||||
}
|
||||
74
core/stores/sqlx/sqlmanager.go
Normal file
74
core/stores/sqlx/sqlmanager.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"zero/core/syncx"
|
||||
)
|
||||
|
||||
const (
|
||||
maxIdleConns = 64
|
||||
maxOpenConns = 64
|
||||
maxLifetime = time.Minute
|
||||
)
|
||||
|
||||
var connManager = syncx.NewResourceManager()
|
||||
|
||||
type pingedDB struct {
|
||||
*sql.DB
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func getCachedSqlConn(driverName, server string) (*pingedDB, error) {
|
||||
val, err := connManager.GetResource(server, func() (io.Closer, error) {
|
||||
conn, err := newDBConnection(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pingedDB{
|
||||
DB: conn,
|
||||
}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val.(*pingedDB), nil
|
||||
}
|
||||
|
||||
func getSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
pdb, err := getCachedSqlConn(driverName, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pdb.once.Do(func() {
|
||||
err = pdb.Ping()
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pdb.DB, nil
|
||||
}
|
||||
|
||||
func newDBConnection(driverName, datasource string) (*sql.DB, error) {
|
||||
conn, err := sql.Open(driverName, datasource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we need to do this until the issue https://github.com/golang/go/issues/9851 get fixed
|
||||
// discussed here https://github.com/go-sql-driver/mysql/issues/257
|
||||
// if the discussed SetMaxIdleTimeout methods added, we'll change this behavior
|
||||
// 8 means we can't have more than 8 goroutines to concurrently access the same database.
|
||||
conn.SetMaxIdleConns(maxIdleConns)
|
||||
conn.SetMaxOpenConns(maxOpenConns)
|
||||
conn.SetConnMaxLifetime(maxLifetime)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
92
core/stores/sqlx/stmt.go
Normal file
92
core/stores/sqlx/stmt.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/timex"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql query: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanner(rows)
|
||||
}
|
||||
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanner(rows)
|
||||
}
|
||||
103
core/stores/sqlx/tx.go
Normal file
103
core/stores/sqlx/tx.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type (
|
||||
beginnable func(*sql.DB) (trans, error)
|
||||
|
||||
trans interface {
|
||||
Session
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
txSession struct {
|
||||
*sql.Tx
|
||||
}
|
||||
)
|
||||
|
||||
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(t.Tx, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
if stmt, err := t.Tx.Prepare(q); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return statement{
|
||||
stmt: stmt,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func begin(db *sql.DB) (trans, error) {
|
||||
if tx, err := db.Begin(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return txSession{
|
||||
Tx: tx,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return transactOnConn(conn, b, fn)
|
||||
}
|
||||
|
||||
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
|
||||
var tx trans
|
||||
tx, err = b(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
|
||||
} else {
|
||||
err = fmt.Errorf("recoveer from %#v", p)
|
||||
}
|
||||
} else if err != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
|
||||
}
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(tx)
|
||||
}
|
||||
76
core/stores/sqlx/tx_test.go
Normal file
76
core/stores/sqlx/tx_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
mockCommit = 1
|
||||
mockRollback = 2
|
||||
)
|
||||
|
||||
type mockTx struct {
|
||||
status int
|
||||
}
|
||||
|
||||
func (mt *mockTx) Commit() error {
|
||||
mt.status |= mockCommit
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Rollback() error {
|
||||
mt.status |= mockRollback
|
||||
return nil
|
||||
}
|
||||
|
||||
func beginMock(mock *mockTx) beginnable {
|
||||
return func(*sql.DB) (trans, error) {
|
||||
return mock, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactCommit(t *testing.T) {
|
||||
mock := &mockTx{}
|
||||
err := transactOnConn(nil, beginMock(mock), func(Session) error {
|
||||
return nil
|
||||
})
|
||||
assert.Equal(t, mockCommit, mock.status)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTransactRollback(t *testing.T) {
|
||||
mock := &mockTx{}
|
||||
err := transactOnConn(nil, beginMock(mock), func(Session) error {
|
||||
return errors.New("rollback")
|
||||
})
|
||||
assert.Equal(t, mockRollback, mock.status)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
101
core/stores/sqlx/utils.go
Normal file
101
core/stores/sqlx/utils.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"zero/core/logx"
|
||||
"zero/core/mapping"
|
||||
)
|
||||
|
||||
func desensitize(datasource string) string {
|
||||
// remove account
|
||||
pos := strings.LastIndex(datasource, "@")
|
||||
if 0 <= pos && pos+1 < len(datasource) {
|
||||
datasource = datasource[pos+1:]
|
||||
}
|
||||
|
||||
return datasource
|
||||
}
|
||||
|
||||
func escape(input string) string {
|
||||
var b strings.Builder
|
||||
|
||||
for _, ch := range input {
|
||||
switch ch {
|
||||
case '\x00':
|
||||
b.WriteString(`\x00`)
|
||||
case '\r':
|
||||
b.WriteString(`\r`)
|
||||
case '\n':
|
||||
b.WriteString(`\n`)
|
||||
case '\\':
|
||||
b.WriteString(`\\`)
|
||||
case '\'':
|
||||
b.WriteString(`\'`)
|
||||
case '"':
|
||||
b.WriteString(`\"`)
|
||||
case '\x1a':
|
||||
b.WriteString(`\x1a`)
|
||||
default:
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
argIndex := 0
|
||||
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
argIndex++
|
||||
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
b.WriteByte('1')
|
||||
} else {
|
||||
b.WriteByte('0')
|
||||
}
|
||||
case string:
|
||||
b.WriteByte('\'')
|
||||
b.WriteString(escape(v))
|
||||
b.WriteByte('\'')
|
||||
default:
|
||||
b.WriteString(mapping.Repr(v))
|
||||
}
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func logInstanceError(datasource string, err error) {
|
||||
datasource = desensitize(datasource)
|
||||
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||
}
|
||||
|
||||
func logSqlError(stmt string, err error) {
|
||||
if err != nil && err != ErrNotFound {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
30
core/stores/sqlx/utils_test.go
Normal file
30
core/stores/sqlx/utils_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEscape(t *testing.T) {
|
||||
s := "a\x00\n\r\\'\"\x1ab"
|
||||
|
||||
out := escape(s)
|
||||
|
||||
assert.Equal(t, `a\x00\n\r\\\'\"\x1ab`, out)
|
||||
}
|
||||
|
||||
func TestDesensitize(t *testing.T) {
|
||||
datasource := "user:pass@tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
|
||||
datasource = desensitize(datasource)
|
||||
assert.False(t, strings.Contains(datasource, "user"))
|
||||
assert.False(t, strings.Contains(datasource, "pass"))
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
|
||||
func TestDesensitize_WithoutAccount(t *testing.T) {
|
||||
datasource := "tcp(111.222.333.44:3306)/any_table?charset=utf8mb4&parseTime=true&loc=Asia%2FShanghai"
|
||||
datasource = desensitize(datasource)
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
Reference in New Issue
Block a user