Compare commits

..

121 Commits

Author SHA1 Message Date
anqiansong
888551627c optimize code (#579)
* optimize code

* optimize returns & unit test
2021-03-27 17:33:17 +08:00
Kevin Wan
bd623aaac3 support postgresql (#583)
support postgresql
2021-03-27 17:14:32 +08:00
Kevin Wan
9e6c2ba2c0 avoid goroutine leak after timeout (#575) 2021-03-21 16:54:34 +08:00
Kevin Wan
c0db8d017d gofmt logs (#574) 2021-03-20 16:40:09 +08:00
TonyWang
52b4f8ca91 add timezone and timeformat (#572)
* add timezone and timeformat

* rm time zone and keep time format

Co-authored-by: Tony Wang <tonywang.data@gmail.com>
2021-03-20 16:36:19 +08:00
Kevin Wan
4884a7b3c6 zrpc timeout & unit tests (#573)
* zrpc timeout & unit tests
2021-03-19 18:41:26 +08:00
Kevin Wan
3c6951577d make hijack more stable (#565) 2021-03-15 20:11:09 +08:00
Kevin Wan
fcd15c9b17 refactor, and add comments to describe graceful shutdown (#564) 2021-03-14 08:51:10 +08:00
Kevin Wan
155e6061cb fix golint issues (#561) 2021-03-12 23:08:04 +08:00
anqiansong
dda7666097 Feature mongo gen (#546)
* add feature: mongo code generation

* upgrade version

* update doc

* format code

* update update.tpl of mysql
2021-03-12 17:49:28 +08:00
hanhotfox
c954568b61 Hdel support for multiple key deletion (#542)
* Hdel support for multiple key deletion

* Hdel field -> fields

Co-authored-by: duanyan <duanyan@xiaoheiban.cn>
2021-03-12 17:47:21 +08:00
Kevin Wan
c2acc43a52 add important notes in readme (#560) 2021-03-12 16:48:25 +08:00
Kevin Wan
1a1a6f5239 add http hijack methods (#555) 2021-03-09 21:30:45 +08:00
anqiansong
60c7edf8f8 fix spelling (#551) 2021-03-08 18:23:12 +08:00
Kevin Wan
7ad86a52f3 update doc link (#552) 2021-03-08 17:56:03 +08:00
kingxt
1e4e5a02b2 rename (#543) 2021-03-04 17:13:07 +08:00
Kevin Wan
39540e21d2 fix golint issues (#540) 2021-03-03 17:16:09 +08:00
hexiaoen
b321622c95 暴露redis EvalSha 以及ScriptLoad接口 (#538)
Co-authored-by: shanehe <shanehe@zego.im>
2021-03-03 17:09:27 +08:00
kingxt
a25cba5380 fix collection breaker (#537)
* fix collection breaker

* optimized

* optimized

* optimized
2021-03-03 10:44:29 +08:00
Kevin Wan
f01472c9ea fix golint issues (#535) 2021-03-02 11:02:57 +08:00
Kevin Wan
af531cf264 fix golint issues (#533) 2021-03-02 00:11:18 +08:00
Kevin Wan
c4b2cddef7 fix golint issues (#532) 2021-03-02 00:04:12 +08:00
Kevin Wan
51de0d0620 fix golint issues in zrpc (#531) 2021-03-01 23:52:44 +08:00
anqiansong
dd393351cc patch 1.1.5 (#530) 2021-03-01 21:14:07 +08:00
Kevin Wan
655ae8034c fix golint issues in rest (#529) 2021-03-01 19:15:35 +08:00
anqiansong
d894b88c3e feature 1.1.5 (#411) 2021-03-01 17:29:07 +08:00
Kevin Wan
791e76bcf0 fix broken build (#528) 2021-02-28 23:53:58 +08:00
Kevin Wan
c566b5ff82 fix golint issues in core/stores (#527) 2021-02-28 23:02:49 +08:00
Kevin Wan
490241d639 fix golint issues in core/syncx (#526) 2021-02-28 16:16:22 +08:00
Kevin Wan
f02711a9cb golint core/discov (#525) 2021-02-27 23:56:18 +08:00
Kevin Wan
ad32f9de23 fix golint issues in core/threading (#524) 2021-02-26 16:27:04 +08:00
Kevin Wan
f309e9f80c fix golint issues in core/utils (#520)
* fix golint issues in core/utils

* fix golint issues in core/trace

* fix golint issues in core/trace
2021-02-26 16:20:47 +08:00
hao
2087ac1e89 修正http转发头字段值错误 (#521) 2021-02-26 16:17:30 +08:00
kingxt
e6ef1fca12 Code optimized (#523)
* optimized markdown generator

* optimized markdown generator

* optimized markdown generator

* add more comment

* add comment

* add comment

* add comments for rpc tool

* add comments for model tool

* add comments for model tool

* add comments for model tool

* add comments for config tool

* add comments for config tool

* add comments

* add comments

* add comments

* add comments

* add comment

* remove rpc main head info

* add comment

* optimized

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
2021-02-26 16:11:47 +08:00
Kevin Wan
ef146cf5ba fix golint issues in core/timex (#517) 2021-02-24 16:27:11 +08:00
Kevin Wan
04b0f26182 fix golint issues in core/stringx (#516) 2021-02-24 16:09:07 +08:00
Kevin Wan
acdaee0fb6 fix golint issues in core/stat (#515)
* change to use ServiceGroup to make it more clear

* fix golint issues in core/stat
2021-02-24 15:13:56 +08:00
Kevin Wan
56ad4776d4 fix misspelling (#513) 2021-02-23 13:53:19 +08:00
Kevin Wan
904d168f18 fix golint issues in core/service (#512) 2021-02-22 22:43:24 +08:00
Kevin Wan
4bd4981bfb fix golint issues in core/search (#509) 2021-02-22 18:58:03 +08:00
Kevin Wan
90562df826 fix golint issues in core/rescue (#508) 2021-02-22 16:47:02 +08:00
Kevin Wan
497762ab47 fix golint issues in core/queue (#507) 2021-02-22 16:38:42 +08:00
Kevin Wan
6e4c98e52d fix golint issues in core/prometheus (#506) 2021-02-22 14:55:04 +08:00
Kevin Wan
b4bb5c0323 fix broken links in readme (#505) 2021-02-22 14:13:33 +08:00
Kevin Wan
a58fac9000 fix golint issues in core/prof (#503) 2021-02-22 10:20:54 +08:00
Kevin Wan
d84e3d4b53 fix golint issues in core/proc (#502) 2021-02-22 10:07:39 +08:00
Kevin Wan
221f923fae fix golint issues in core/netx (#501) 2021-02-22 09:56:56 +08:00
Kevin Wan
bbb9126302 fix golint issues in core/mr (#500) 2021-02-22 09:47:06 +08:00
Kevin Wan
e7c9ef16fe fix golint issues in core/metric (#499) 2021-02-21 21:18:07 +08:00
Kevin Wan
8872d7cbd3 fix golint issues in core/mathx (#498) 2021-02-21 20:47:01 +08:00
Kevin Wan
334ee4213f fix golint issues in core/mapping (#497) 2021-02-20 23:18:22 +08:00
Kevin Wan
226513ed60 fix golint issues in core/logx (#496) 2021-02-20 22:45:58 +08:00
Kevin Wan
dac00d10c1 fix golint issues in core/load (#495) 2021-02-20 22:02:09 +08:00
Kevin Wan
84d2b6f8f5 fix golint issues in core/limit (#494) 2021-02-20 21:55:54 +08:00
kingxt
f98c9246b2 Code optimized (#493) 2021-02-20 19:50:03 +08:00
Kevin Wan
059027bc9d fix golint issues in core/lang (#492) 2021-02-20 18:21:23 +08:00
Kevin Wan
af68caeaf6 fix golint issues in core/jsonx (#491) 2021-02-20 16:59:31 +08:00
Zcc、
fdeacfc89f add redis bitmap command (#490)
Co-authored-by: zhoudeyu <zhoudeyu@xiaoheiban.cn>
2021-02-20 16:26:49 +08:00
Kevin Wan
5b33dd59d9 fix golint issues in core/jsontype (#489) 2021-02-20 15:07:49 +08:00
Kevin Wan
1f92bfde6a fix golint issues in core/iox (#488) 2021-02-19 18:40:26 +08:00
Kevin Wan
0c094cb2d7 fix golint issues in core/hash (#487) 2021-02-19 18:14:34 +08:00
Kevin Wan
f238290dd3 fix golint issues in core/fx (#486) 2021-02-19 17:49:39 +08:00
Kevin Wan
c376ffc351 fix golint issues in core/filex (#485) 2021-02-19 14:30:38 +08:00
Kevin Wan
802549ac7c fix golint issues in core/executors (#484) 2021-02-19 12:03:05 +08:00
Zcc、
72580dee38 redis add bitcount (#483)
Co-authored-by: zhoudeyu <zhoudeyu@xiaoheiban.cn>
2021-02-19 11:41:01 +08:00
Kevin Wan
086113c843 prevent negative timeout settings (#482)
* prevent negative timeout settings

* fix misleading comment
2021-02-19 10:44:39 +08:00
HarryWang29
d239952d2d zrpc client support block (#412) 2021-02-19 10:24:03 +08:00
Kevin Wan
7472d1e70b fix golint issues in core/errorx (#480) 2021-02-19 10:08:38 +08:00
Kevin Wan
2446d8a668 fix golint issues in core/discov (#479) 2021-02-18 22:56:35 +08:00
Kevin Wan
f6894448bd fix golint issues in core/contextx (#477) 2021-02-18 18:00:20 +08:00
Kevin Wan
425be6b4a1 fix golint issues in core/conf (#476) 2021-02-18 15:56:19 +08:00
Kevin Wan
457048bfac fix golint issues in core/collection, refine cache interface (#475) 2021-02-18 15:49:56 +08:00
kingxt
f14ab70035 Code optimized (#474)
* optimized markdown generator

* optimized markdown generator

* optimized markdown generator

* optimized markdown generator
2021-02-18 15:08:20 +08:00
Kevin Wan
8f1c88e07d fix golint issues in core/codec (#473) 2021-02-18 14:11:09 +08:00
Kevin Wan
9602494454 fix issue #469 (#471) 2021-02-17 21:42:22 +08:00
Kevin Wan
38abfb80ed fix gocyclo warnings (#468) 2021-02-17 14:01:05 +08:00
Kevin Wan
87938bcc09 fix golint issues in core/cmdline (#467) 2021-02-17 11:08:30 +08:00
Kevin Wan
8ebf6750b9 fix golint issues in core/breaker (#466) 2021-02-17 10:45:55 +08:00
Kevin Wan
6f92daae12 fix golint issues in core/bloom (#465) 2021-02-17 09:58:35 +08:00
Kevin Wan
80e1c85b50 add more tests for service (#463) 2021-02-11 23:48:19 +08:00
Kevin Wan
395a1db22f add more tests for rest (#462) 2021-02-10 23:08:48 +08:00
bittoy
28009c4224 Update serviceconf.go (#460)
add regression environment config
2021-02-09 15:35:50 +08:00
Kevin Wan
211f3050e9 fix golint issues (#459) 2021-02-09 14:10:38 +08:00
Kevin Wan
03b5fd4a10 fix golint issues (#458) 2021-02-09 14:03:19 +08:00
Kevin Wan
5e969cbef0 fix golint issues, else blocks (#457) 2021-02-09 13:50:21 +08:00
Kevin Wan
42883d0899 fix golint issues, redis methods (#455) 2021-02-09 10:58:11 +08:00
Kevin Wan
06f6dc9937 fix golint issues, package comments (#454) 2021-02-08 22:31:52 +08:00
Kevin Wan
1789b12db2 move examples into zero-examples (#453)
* move examples to zero-examples

* tidy go.mod

* add examples refer in readme
2021-02-08 22:23:36 +08:00
Kevin Wan
c7f3e6119d remove images, use zero-doc instead (#452) 2021-02-08 21:57:40 +08:00
Kevin Wan
54414db91d fix golint issues, exported doc (#451) 2021-02-08 21:31:56 +08:00
Kevin Wan
9b0625bb83 fix golint issues (#450) 2021-02-08 17:08:40 +08:00
Kevin Wan
0dda05fd57 add api doc (#449) 2021-02-08 11:10:55 +08:00
Kevin Wan
5b79ba2618 add discov tests (#448) 2021-02-07 20:24:47 +08:00
Kevin Wan
22a1fa649e remove etcd facade, added for testing purpose (#447) 2021-02-07 19:07:15 +08:00
Kevin Wan
745e76c335 add more tests for stores (#446) 2021-02-07 17:22:47 +08:00
Kevin Wan
852891dbd8 add more tests for stores (#445) 2021-02-07 15:27:01 +08:00
Kevin Wan
316195e912 add more tests for mongoc (#443) 2021-02-07 14:41:00 +08:00
Kevin Wan
8e889d694d add more tests for sqlx (#442)
* add more tests for sqlx

* add more tests for sqlx
2021-02-07 11:54:41 +08:00
Kevin Wan
ec6132b754 add more tests for zrpc (#441) 2021-02-06 12:25:45 +08:00
Kevin Wan
c282bb1d86 add more tests for sqlx (#440) 2021-02-05 22:53:21 +08:00
Kevin Wan
d04b54243d add more tests for proc (#439) 2021-02-05 15:11:27 +08:00
Kevin Wan
b88ba14597 fixes issue #425 (#438) 2021-02-05 13:32:56 +08:00
理工男
7b3c3de35e ring struct add lock (#434)
Co-authored-by: liuhuan210 <liuhuan210@jd.com>
2021-02-03 21:41:10 +08:00
Kevin Wan
abab7c2852 Update readme.md 2021-02-03 15:43:35 +08:00
Kevin Wan
30f5ab0b99 update readme for broken links (#432) 2021-02-03 12:02:22 +08:00
foyon
8b273a075c Support redis command Rpop (#431)
* ss

* ss

* add go-zero:stores:redis-command:Rpop and redis_test

* Delete 1.go

* support redis command Rpop

Co-authored-by: fanhongyi <fanhongyi@tal.com>
2021-02-03 10:19:42 +08:00
Liang Zheng
76026fc211 fix readme.md error (#429)
Signed-off-by: Liang Zheng <microyahoo@163.com>
2021-02-03 10:18:28 +08:00
Hkesd
04284e31cd support hscan in redis (#428) 2021-02-02 17:02:18 +08:00
Kevin Wan
c3b9c3c5ab use english readme as default, because of github ranking (#427) 2021-02-02 16:58:45 +08:00
FengZhang
a8b550e7ef Modify the http content-length max range : 30MB --> 32MB (#424)
Because we are programmer :)
2021-01-30 18:49:33 +08:00
FengZhang
cbfbebed00 modify the maximum content-length to 30MB (#413) 2021-01-29 22:14:48 +08:00
kingxt
2b07f22672 optimize code (#417)
* optimize code

* optimize code

* optimize code

* optimize code
2021-01-26 17:37:22 +08:00
Kevin Wan
a784982030 support zunionstore in redis (#410) 2021-01-21 21:03:24 +08:00
Kevin Wan
ebec5aafab use env if necessary in loading config (#409) 2021-01-21 19:33:34 +08:00
Kevin Wan
572b32729f update goctl version to 1.1.3 (#402) 2021-01-18 16:34:00 +08:00
kingxt
43e712d86a fix type convert error (#395) 2021-01-16 18:24:11 +08:00
kingxt
4db20677f7 optimized (#392) 2021-01-15 11:36:37 +08:00
Kevin Wan
6887fb22de add more tests for codec (#391) 2021-01-14 23:39:44 +08:00
Kevin Wan
50fbdbcfd7 update readme (#390) 2021-01-14 22:26:31 +08:00
ALMAS
c77b8489d7 Update periodicalexecutor.go (#389) 2021-01-14 22:20:09 +08:00
Kevin Wan
eca4ed2cc0 format code (#386) 2021-01-14 13:24:24 +08:00
647 changed files with 7615 additions and 11504 deletions

View File

@@ -27,43 +27,43 @@ return true
` `
) )
// ErrTooLargeOffset indicates the offset is too large in bitset.
var ErrTooLargeOffset = errors.New("too large offset") var ErrTooLargeOffset = errors.New("too large offset")
type ( type (
BitSetProvider interface { // A Filter is a bloom filter.
Filter struct {
bits uint
bitSet bitSetProvider
}
bitSetProvider interface {
check([]uint) (bool, error) check([]uint) (bool, error)
set([]uint) error set([]uint) error
} }
BloomFilter struct {
bits uint
bitSet BitSetProvider
}
) )
// New create a BloomFilter, store is the backed redis, key is the key for the bloom filter, // New create a Filter, store is the backed redis, key is the key for the bloom filter,
// bits is how many bits will be used, maps is how many hashes for each addition. // bits is how many bits will be used, maps is how many hashes for each addition.
// best practices: // best practices:
// elements - means how many actual elements // elements - means how many actual elements
// when maps = 14, formula: 0.7*(bits/maps), bits = 20*elements, the error rate is 0.000067 < 1e-4 // when maps = 14, formula: 0.7*(bits/maps), bits = 20*elements, the error rate is 0.000067 < 1e-4
// for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html // for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html
func New(store *redis.Redis, key string, bits uint) *BloomFilter { func New(store *redis.Redis, key string, bits uint) *Filter {
return &BloomFilter{ return &Filter{
bits: bits, bits: bits,
bitSet: newRedisBitSet(store, key, bits), bitSet: newRedisBitSet(store, key, bits),
} }
} }
func (f *BloomFilter) Add(data []byte) error { // Add adds data into f.
func (f *Filter) Add(data []byte) error {
locations := f.getLocations(data) locations := f.getLocations(data)
err := f.bitSet.set(locations) return f.bitSet.set(locations)
if err != nil {
return err
}
return nil
} }
func (f *BloomFilter) Exists(data []byte) (bool, error) { // Exists checks if data is in f.
func (f *Filter) Exists(data []byte) (bool, error) {
locations := f.getLocations(data) locations := f.getLocations(data)
isSet, err := f.bitSet.check(locations) isSet, err := f.bitSet.check(locations)
if err != nil { if err != nil {
@@ -76,7 +76,7 @@ func (f *BloomFilter) Exists(data []byte) (bool, error) {
return true, nil return true, nil
} }
func (f *BloomFilter) getLocations(data []byte) []uint { func (f *Filter) getLocations(data []byte) []uint {
locations := make([]uint, maps) locations := make([]uint, maps)
for i := uint(0); i < maps; i++ { for i := uint(0); i < maps; i++ {
hashValue := hash.Hash(append(data, byte(i))) hashValue := hash.Hash(append(data, byte(i)))
@@ -127,11 +127,12 @@ func (r *redisBitSet) check(offsets []uint) (bool, error) {
return false, err return false, err
} }
if exists, ok := resp.(int64); !ok { exists, ok := resp.(int64)
if !ok {
return false, nil return false, nil
} else {
return exists == 1, nil
} }
return exists == 1, nil
} }
func (r *redisBitSet) del() error { func (r *redisBitSet) del() error {
@@ -152,7 +153,7 @@ func (r *redisBitSet) set(offsets []uint) error {
_, err = r.store.Eval(setScript, []string{r.key}, args) _, err = r.store.Eval(setScript, []string{r.key}, args)
if err == redis.Nil { if err == redis.Nil {
return nil return nil
} else {
return err
} }
return err
} }

View File

@@ -18,12 +18,14 @@ const (
timeFormat = "15:04:05" timeFormat = "15:04:05"
) )
// ErrServiceUnavailable is returned when the CB state is open // ErrServiceUnavailable is returned when the Breaker state is open.
var ErrServiceUnavailable = errors.New("circuit breaker is open") var ErrServiceUnavailable = errors.New("circuit breaker is open")
type ( type (
// Acceptable is the func to check if the error can be accepted.
Acceptable func(err error) bool Acceptable func(err error) bool
// A Breaker represents a circuit breaker.
Breaker interface { Breaker interface {
// Name returns the name of the Breaker. // Name returns the name of the Breaker.
Name() string Name() string
@@ -61,10 +63,14 @@ type (
DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error
} }
// Option defines the method to customize a Breaker.
Option func(breaker *circuitBreaker) Option func(breaker *circuitBreaker)
// Promise interface defines the callbacks that returned by Breaker.Allow.
Promise interface { Promise interface {
// Accept tells the Breaker that the call is successful.
Accept() Accept()
// Reject tells the Breaker that the call is failed.
Reject(reason string) Reject(reason string)
} }
@@ -89,6 +95,8 @@ type (
} }
) )
// NewBreaker returns a Breaker object.
// opts can be used to customize the Breaker.
func NewBreaker(opts ...Option) Breaker { func NewBreaker(opts ...Option) Breaker {
var b circuitBreaker var b circuitBreaker
for _, opt := range opts { for _, opt := range opts {
@@ -127,6 +135,7 @@ func (cb *circuitBreaker) Name() string {
return cb.name return cb.name
} }
// WithName returns a function to set the name of a Breaker.
func WithName(name string) Option { func WithName(name string) Option {
return func(b *circuitBreaker) { return func(b *circuitBreaker) {
b.name = name b.name = name

View File

@@ -7,24 +7,28 @@ var (
breakers = make(map[string]Breaker) breakers = make(map[string]Breaker)
) )
// Do calls Breaker.Do on the Breaker with given name.
func Do(name string, req func() error) error { func Do(name string, req func() error) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
return b.Do(req) return b.Do(req)
}) })
} }
// DoWithAcceptable calls Breaker.DoWithAcceptable on the Breaker with given name.
func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error { func DoWithAcceptable(name string, req func() error, acceptable Acceptable) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
return b.DoWithAcceptable(req, acceptable) return b.DoWithAcceptable(req, acceptable)
}) })
} }
// DoWithFallback calls Breaker.DoWithFallback on the Breaker with given name.
func DoWithFallback(name string, req func() error, fallback func(err error) error) error { func DoWithFallback(name string, req func() error, fallback func(err error) error) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
return b.DoWithFallback(req, fallback) return b.DoWithFallback(req, fallback)
}) })
} }
// DoWithFallbackAcceptable calls Breaker.DoWithFallbackAcceptable on the Breaker with given name.
func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error, func DoWithFallbackAcceptable(name string, req func() error, fallback func(err error) error,
acceptable Acceptable) error { acceptable Acceptable) error {
return do(name, func(b Breaker) error { return do(name, func(b Breaker) error {
@@ -32,6 +36,7 @@ func DoWithFallbackAcceptable(name string, req func() error, fallback func(err e
}) })
} }
// GetBreaker returns the Breaker with the given name.
func GetBreaker(name string) Breaker { func GetBreaker(name string) Breaker {
lock.RLock() lock.RLock()
b, ok := breakers[name] b, ok := breakers[name]
@@ -51,7 +56,8 @@ func GetBreaker(name string) Breaker {
return b return b
} }
func NoBreakFor(name string) { // NoBreakerFor disables the circuit breaker for the given name.
func NoBreakerFor(name string) {
lock.Lock() lock.Lock()
breakers[name] = newNoOpBreaker() breakers[name] = newNoOpBreaker()
lock.Unlock() lock.Unlock()

View File

@@ -55,7 +55,7 @@ func TestBreakersDoWithAcceptable(t *testing.T) {
} }
func TestBreakersNoBreakerFor(t *testing.T) { func TestBreakersNoBreakerFor(t *testing.T) {
NoBreakFor("any") NoBreakerFor("any")
errDummy := errors.New("any") errDummy := errors.New("any")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
assert.Equal(t, errDummy, GetBreaker("any").Do(func() error { assert.Equal(t, errDummy, GetBreaker("any").Do(func() error {

View File

@@ -64,9 +64,9 @@ func (b *googleBreaker) doReq(req func() error, fallback func(err error) error,
if err := b.accept(); err != nil { if err := b.accept(); err != nil {
if fallback != nil { if fallback != nil {
return fallback(err) return fallback(err)
} else {
return err
} }
return err
} }
defer func() { defer func() {

View File

@@ -7,11 +7,13 @@ import (
"strings" "strings"
) )
// EnterToContinue let stdin waiting for an enter key to continue.
func EnterToContinue() { func EnterToContinue() {
fmt.Print("Press 'Enter' to continue...") fmt.Print("Press 'Enter' to continue...")
bufio.NewReader(os.Stdin).ReadBytes('\n') bufio.NewReader(os.Stdin).ReadBytes('\n')
} }
// ReadLine shows prompt to stdout and read a line from stdin.
func ReadLine(prompt string) string { func ReadLine(prompt string) string {
fmt.Print(prompt) fmt.Print(prompt)
input, _ := bufio.NewReader(os.Stdin).ReadString('\n') input, _ := bufio.NewReader(os.Stdin).ReadString('\n')

View File

@@ -10,6 +10,7 @@ import (
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
) )
// ErrPaddingSize indicates bad padding size.
var ErrPaddingSize = errors.New("padding size error") var ErrPaddingSize = errors.New("padding size error")
type ecb struct { type ecb struct {
@@ -26,6 +27,7 @@ func newECB(b cipher.Block) *ecb {
type ecbEncrypter ecb type ecbEncrypter ecb
// NewECBEncrypter returns an ECB encrypter.
func NewECBEncrypter(b cipher.Block) cipher.BlockMode { func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b)) return (*ecbEncrypter)(newECB(b))
} }
@@ -52,6 +54,7 @@ func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
type ecbDecrypter ecb type ecbDecrypter ecb
// NewECBDecrypter returns an ECB decrypter.
func NewECBDecrypter(b cipher.Block) cipher.BlockMode { func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b)) return (*ecbDecrypter)(newECB(b))
} }
@@ -78,6 +81,7 @@ func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
} }
} }
// EcbDecrypt decrypts src with the given key.
func EcbDecrypt(key, src []byte) ([]byte, error) { func EcbDecrypt(key, src []byte) ([]byte, error) {
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
@@ -92,6 +96,8 @@ func EcbDecrypt(key, src []byte) ([]byte, error) {
return pkcs5Unpadding(decrypted, decrypter.BlockSize()) return pkcs5Unpadding(decrypted, decrypter.BlockSize())
} }
// EcbDecryptBase64 decrypts base64 encoded src with the given base64 encoded key.
// The returned string is also base64 encoded.
func EcbDecryptBase64(key, src string) (string, error) { func EcbDecryptBase64(key, src string) (string, error) {
keyBytes, err := getKeyBytes(key) keyBytes, err := getKeyBytes(key)
if err != nil { if err != nil {
@@ -111,6 +117,7 @@ func EcbDecryptBase64(key, src string) (string, error) {
return base64.StdEncoding.EncodeToString(decryptedBytes), nil return base64.StdEncoding.EncodeToString(decryptedBytes), nil
} }
// EcbEncrypt encrypts src with the given key.
func EcbEncrypt(key, src []byte) ([]byte, error) { func EcbEncrypt(key, src []byte) ([]byte, error) {
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
if err != nil { if err != nil {
@@ -126,6 +133,8 @@ func EcbEncrypt(key, src []byte) ([]byte, error) {
return crypted, nil return crypted, nil
} }
// EcbEncryptBase64 encrypts base64 encoded src with the given base64 encoded key.
// The returned string is also base64 encoded.
func EcbEncryptBase64(key, src string) (string, error) { func EcbEncryptBase64(key, src string) (string, error) {
keyBytes, err := getKeyBytes(key) keyBytes, err := getKeyBytes(key)
if err != nil { if err != nil {
@@ -146,15 +155,16 @@ func EcbEncryptBase64(key, src string) (string, error) {
} }
func getKeyBytes(key string) ([]byte, error) { func getKeyBytes(key string) ([]byte, error) {
if len(key) > 32 { if len(key) <= 32 {
if keyBytes, err := base64.StdEncoding.DecodeString(key); err != nil { return []byte(key), nil
return nil, err
} else {
return keyBytes, nil
}
} }
return []byte(key), nil keyBytes, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
return keyBytes, nil
} }
func pkcs5Padding(ciphertext []byte, blockSize int) []byte { func pkcs5Padding(ciphertext []byte, blockSize int) []byte {

65
core/codec/aesecb_test.go Normal file
View File

@@ -0,0 +1,65 @@
package codec
import (
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
)
// TestAesEcb verifies AES-ECB round-tripping on raw bytes: encryption and
// decryption must reject invalid key lengths, and decrypting the ciphertext
// with the correct key must reproduce the original plaintext.
func TestAesEcb(t *testing.T) {
	key := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
	val := []byte("hello")
	shortKey := []byte("aaaaaaaaa")
	// more than 32 chars
	longKey := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")

	// both malformed keys must fail to encrypt
	_, err := EcbEncrypt(shortKey, val)
	assert.NotNil(t, err)
	_, err = EcbEncrypt(longKey, val)
	assert.NotNil(t, err)

	// encrypt with the valid 32-byte key
	ciphertext, err := EcbEncrypt(key, val)
	assert.Nil(t, err)

	// both malformed keys must fail to decrypt as well
	_, err = EcbDecrypt(shortKey, ciphertext)
	assert.NotNil(t, err)
	_, err = EcbDecrypt(longKey, ciphertext)
	assert.NotNil(t, err)

	_, err = EcbDecrypt(key, val)
	// not enough block, just nil
	assert.Nil(t, err)

	// round trip: decrypting the ciphertext restores the plaintext
	plaintext, err := EcbDecrypt(key, ciphertext)
	assert.Nil(t, err)
	assert.Equal(t, val, plaintext)
}
// TestAesEcbBase64 verifies the base64-encoded AES-ECB helpers: bad keys and
// non-base64 inputs must be rejected, and a full encrypt/decrypt round trip
// through base64 must recover the original value.
func TestAesEcbBase64(t *testing.T) {
	const (
		val     = "hello"
		badKey1 = "aaaaaaaaa"
		// more than 32 chars
		badKey2 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	)
	rawKey := []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
	encodedKey := base64.StdEncoding.EncodeToString(rawKey)
	encodedVal := base64.StdEncoding.EncodeToString([]byte(val))

	// malformed keys must fail to encrypt
	_, err := EcbEncryptBase64(badKey1, val)
	assert.NotNil(t, err)
	_, err = EcbEncryptBase64(badKey2, val)
	assert.NotNil(t, err)
	// a plain (non-base64) value must be rejected even with a good key
	_, err = EcbEncryptBase64(encodedKey, val)
	assert.NotNil(t, err)

	// encrypt the base64-encoded value with the base64-encoded key
	ciphertext, err := EcbEncryptBase64(encodedKey, encodedVal)
	assert.Nil(t, err)

	// malformed keys must fail to decrypt as well
	_, err = EcbDecryptBase64(badKey1, ciphertext)
	assert.NotNil(t, err)
	_, err = EcbDecryptBase64(badKey2, ciphertext)
	assert.NotNil(t, err)
	// a plain (non-base64) payload must be rejected
	_, err = EcbDecryptBase64(encodedKey, val)
	assert.NotNil(t, err)

	// round trip: decrypt, then base64-decode back to the original string
	decrypted, err := EcbDecryptBase64(encodedKey, ciphertext)
	assert.Nil(t, err)
	decoded, err := base64.StdEncoding.DecodeString(decrypted)
	assert.Nil(t, err)
	assert.Equal(t, val, string(decoded))
}

View File

@@ -11,8 +11,11 @@ import (
// 2048-bit MODP Group // 2048-bit MODP Group
var ( var (
ErrInvalidPriKey = errors.New("invalid private key") // ErrInvalidPriKey indicates the invalid private key.
ErrInvalidPubKey = errors.New("invalid public key") ErrInvalidPriKey = errors.New("invalid private key")
// ErrInvalidPubKey indicates the invalid public key.
ErrInvalidPubKey = errors.New("invalid public key")
// ErrPubKeyOutOfBound indicates the public key is out of bound.
ErrPubKeyOutOfBound = errors.New("public key out of bound") ErrPubKeyOutOfBound = errors.New("public key out of bound")
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
@@ -20,11 +23,13 @@ var (
zero = big.NewInt(0) zero = big.NewInt(0)
) )
// DhKey defines the Diffie Hellman key.
type DhKey struct { type DhKey struct {
PriKey *big.Int PriKey *big.Int
PubKey *big.Int PubKey *big.Int
} }
// ComputeKey returns a key from public key and private key.
func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) { func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
if pubKey == nil { if pubKey == nil {
return nil, ErrInvalidPubKey return nil, ErrInvalidPubKey
@@ -41,6 +46,7 @@ func ComputeKey(pubKey, priKey *big.Int) (*big.Int, error) {
return new(big.Int).Exp(pubKey, priKey, p), nil return new(big.Int).Exp(pubKey, priKey, p), nil
} }
// GenerateKey returns a Diffie Hellman key.
func GenerateKey() (*DhKey, error) { func GenerateKey() (*DhKey, error) {
var err error var err error
var x *big.Int var x *big.Int
@@ -63,10 +69,12 @@ func GenerateKey() (*DhKey, error) {
return key, nil return key, nil
} }
// NewPublicKey returns a public key from the given bytes.
func NewPublicKey(bs []byte) *big.Int { func NewPublicKey(bs []byte) *big.Int {
return new(big.Int).SetBytes(bs) return new(big.Int).SetBytes(bs)
} }
// Bytes returns public key bytes.
func (k *DhKey) Bytes() []byte { func (k *DhKey) Bytes() []byte {
if k.PubKey == nil { if k.PubKey == nil {
return nil return nil

View File

@@ -8,6 +8,7 @@ import (
const unzipLimit = 100 * 1024 * 1024 // 100MB const unzipLimit = 100 * 1024 * 1024 // 100MB
// Gzip compresses bs.
func Gzip(bs []byte) []byte { func Gzip(bs []byte) []byte {
var b bytes.Buffer var b bytes.Buffer
@@ -18,6 +19,7 @@ func Gzip(bs []byte) []byte {
return b.Bytes() return b.Bytes()
} }
// Gunzip uncompresses bs.
func Gunzip(bs []byte) ([]byte, error) { func Gunzip(bs []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewBuffer(bs)) r, err := gzip.NewReader(bytes.NewBuffer(bs))
if err != nil { if err != nil {

View File

@@ -7,12 +7,14 @@ import (
"io" "io"
) )
// Hmac returns HMAC bytes for body with the given key.
func Hmac(key []byte, body string) []byte { func Hmac(key []byte, body string) []byte {
h := hmac.New(sha256.New, key) h := hmac.New(sha256.New, key)
io.WriteString(h, body) io.WriteString(h, body)
return h.Sum(nil) return h.Sum(nil)
} }
// HmacBase64 returns the base64 encoded string of HMAC for body with the given key.
func HmacBase64(key []byte, body string) string { func HmacBase64(key []byte, body string) string {
return base64.StdEncoding.EncodeToString(Hmac(key, body)) return base64.StdEncoding.EncodeToString(Hmac(key, body))
} }

View File

@@ -11,17 +11,22 @@ import (
) )
var ( var (
// ErrPrivateKey indicates the invalid private key.
ErrPrivateKey = errors.New("private key error") ErrPrivateKey = errors.New("private key error")
ErrPublicKey = errors.New("failed to parse PEM block containing the public key") // ErrPublicKey indicates the invalid public key.
ErrNotRsaKey = errors.New("key type is not RSA") ErrPublicKey = errors.New("failed to parse PEM block containing the public key")
// ErrNotRsaKey indicates the invalid RSA key.
ErrNotRsaKey = errors.New("key type is not RSA")
) )
type ( type (
// RsaDecrypter represents a RSA decrypter.
RsaDecrypter interface { RsaDecrypter interface {
Decrypt(input []byte) ([]byte, error) Decrypt(input []byte) ([]byte, error)
DecryptBase64(input string) ([]byte, error) DecryptBase64(input string) ([]byte, error)
} }
// RsaEncrypter represents a RSA encrypter.
RsaEncrypter interface { RsaEncrypter interface {
Encrypt(input []byte) ([]byte, error) Encrypt(input []byte) ([]byte, error)
} }
@@ -41,6 +46,7 @@ type (
} }
) )
// NewRsaDecrypter returns a RsaDecrypter with the given file.
func NewRsaDecrypter(file string) (RsaDecrypter, error) { func NewRsaDecrypter(file string) (RsaDecrypter, error) {
content, err := ioutil.ReadFile(file) content, err := ioutil.ReadFile(file)
if err != nil { if err != nil {
@@ -84,6 +90,7 @@ func (r *rsaDecrypter) DecryptBase64(input string) ([]byte, error) {
return r.Decrypt(base64Decoded) return r.Decrypt(base64Decoded)
} }
// NewRsaEncrypter returns a RsaEncrypter with the given key.
func NewRsaEncrypter(key []byte) (RsaEncrypter, error) { func NewRsaEncrypter(key []byte) (RsaEncrypter, error) {
block, _ := pem.Decode(key) block, _ := pem.Decode(key)
if block == nil { if block == nil {

View File

@@ -23,8 +23,10 @@ const (
var emptyLruCache = emptyLru{} var emptyLruCache = emptyLru{}
type ( type (
// CacheOption defines the method to customize a Cache.
CacheOption func(cache *Cache) CacheOption func(cache *Cache)
// A Cache object is a in-memory cache.
Cache struct { Cache struct {
name string name string
lock sync.Mutex lock sync.Mutex
@@ -38,6 +40,7 @@ type (
} }
) )
// NewCache returns a Cache with given expire.
func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) { func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
cache := &Cache{ cache := &Cache{
data: make(map[string]interface{}), data: make(map[string]interface{}),
@@ -72,6 +75,7 @@ func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
return cache, nil return cache, nil
} }
// Del deletes the item with the given key from c.
func (c *Cache) Del(key string) { func (c *Cache) Del(key string) {
c.lock.Lock() c.lock.Lock()
delete(c.data, key) delete(c.data, key)
@@ -80,6 +84,7 @@ func (c *Cache) Del(key string) {
c.timingWheel.RemoveTimer(key) c.timingWheel.RemoveTimer(key)
} }
// Get returns the item with the given key from c.
func (c *Cache) Get(key string) (interface{}, bool) { func (c *Cache) Get(key string) (interface{}, bool) {
value, ok := c.doGet(key) value, ok := c.doGet(key)
if ok { if ok {
@@ -91,6 +96,7 @@ func (c *Cache) Get(key string) (interface{}, bool) {
return value, ok return value, ok
} }
// Set sets value into c with key.
func (c *Cache) Set(key string, value interface{}) { func (c *Cache) Set(key string, value interface{}) {
c.lock.Lock() c.lock.Lock()
_, ok := c.data[key] _, ok := c.data[key]
@@ -106,6 +112,9 @@ func (c *Cache) Set(key string, value interface{}) {
} }
} }
// Take returns the item with the given key.
// If the item is in c, return it directly.
// If not, use fetch method to get the item, set into c and return it.
func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) { func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}, error) {
if val, ok := c.doGet(key); ok { if val, ok := c.doGet(key); ok {
c.stats.IncrementHit() c.stats.IncrementHit()
@@ -136,11 +145,10 @@ func (c *Cache) Take(key string, fetch func() (interface{}, error)) (interface{}
if fresh { if fresh {
c.stats.IncrementMiss() c.stats.IncrementMiss()
return val, nil return val, nil
} else {
// got the result from previous ongoing query
c.stats.IncrementHit()
} }
// got the result from previous ongoing query
c.stats.IncrementHit()
return val, nil return val, nil
} }
@@ -168,6 +176,7 @@ func (c *Cache) size() int {
return len(c.data) return len(c.data)
} }
// WithLimit customizes a Cache with items up to limit.
func WithLimit(limit int) CacheOption { func WithLimit(limit int) CacheOption {
return func(cache *Cache) { return func(cache *Cache) {
if limit > 0 { if limit > 0 {
@@ -176,6 +185,7 @@ func WithLimit(limit int) CacheOption {
} }
} }
// WithName customizes a Cache with the given name.
func WithName(name string) CacheOption { func WithName(name string) CacheOption {
return func(cache *Cache) { return func(cache *Cache) {
cache.name = name cache.name = name

View File

@@ -2,6 +2,7 @@ package collection
import "sync" import "sync"
// A Queue is a FIFO queue.
type Queue struct { type Queue struct {
lock sync.Mutex lock sync.Mutex
elements []interface{} elements []interface{}
@@ -11,6 +12,7 @@ type Queue struct {
count int count int
} }
// NewQueue returns a Queue object.
func NewQueue(size int) *Queue { func NewQueue(size int) *Queue {
return &Queue{ return &Queue{
elements: make([]interface{}, size), elements: make([]interface{}, size),
@@ -18,6 +20,7 @@ func NewQueue(size int) *Queue {
} }
} }
// Empty checks if q is empty.
func (q *Queue) Empty() bool { func (q *Queue) Empty() bool {
q.lock.Lock() q.lock.Lock()
empty := q.count == 0 empty := q.count == 0
@@ -26,6 +29,7 @@ func (q *Queue) Empty() bool {
return empty return empty
} }
// Put puts element into q at the last position.
func (q *Queue) Put(element interface{}) { func (q *Queue) Put(element interface{}) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
@@ -44,6 +48,7 @@ func (q *Queue) Put(element interface{}) {
q.count++ q.count++
} }
// Take takes the first element out of q if not empty.
func (q *Queue) Take() (interface{}, bool) { func (q *Queue) Take() (interface{}, bool) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()

View File

@@ -1,10 +1,15 @@
package collection package collection
import "sync"
// A Ring can be used as fixed size ring.
type Ring struct { type Ring struct {
elements []interface{} elements []interface{}
index int index int
lock sync.Mutex
} }
// NewRing returns a Ring object with the given size n.
func NewRing(n int) *Ring { func NewRing(n int) *Ring {
if n < 1 { if n < 1 {
panic("n should be greater than 0") panic("n should be greater than 0")
@@ -15,12 +20,20 @@ func NewRing(n int) *Ring {
} }
} }
// Add adds v into r.
func (r *Ring) Add(v interface{}) { func (r *Ring) Add(v interface{}) {
r.lock.Lock()
defer r.lock.Unlock()
r.elements[r.index%len(r.elements)] = v r.elements[r.index%len(r.elements)] = v
r.index++ r.index++
} }
// Take takes all items from r.
func (r *Ring) Take() []interface{} { func (r *Ring) Take() []interface{} {
r.lock.Lock()
defer r.lock.Unlock()
var size int var size int
var start int var start int
if r.index > len(r.elements) { if r.index > len(r.elements) {

View File

@@ -1,6 +1,7 @@
package collection package collection
import ( import (
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -29,3 +30,30 @@ func TestRingMore(t *testing.T) {
elements := ring.Take() elements := ring.Take()
assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements) assert.ElementsMatch(t, []interface{}{6, 7, 8, 9, 10}, elements)
} }
func TestRingAdd(t *testing.T) {
ring := NewRing(5051)
wg := sync.WaitGroup{}
for i := 1; i <= 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 1; j <= i; j++ {
ring.Add(i)
}
}(i)
}
wg.Wait()
assert.Equal(t, 5050, len(ring.Take()))
}
func BenchmarkRingAdd(b *testing.B) {
ring := NewRing(500)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
for i := 0; i < b.N; i++ {
ring.Add(i)
}
}
})
}

View File

@@ -73,9 +73,9 @@ func (rw *RollingWindow) span() int {
offset := int(timex.Since(rw.lastTime) / rw.interval) offset := int(timex.Since(rw.lastTime) / rw.interval)
if 0 <= offset && offset < rw.size { if 0 <= offset && offset < rw.size {
return offset return offset
} else {
return rw.size
} }
return rw.size
} }
func (rw *RollingWindow) updateOffset() { func (rw *RollingWindow) updateOffset() {

View File

@@ -18,6 +18,7 @@ type SafeMap struct {
dirtyNew map[interface{}]interface{} dirtyNew map[interface{}]interface{}
} }
// NewSafeMap returns a SafeMap.
func NewSafeMap() *SafeMap { func NewSafeMap() *SafeMap {
return &SafeMap{ return &SafeMap{
dirtyOld: make(map[interface{}]interface{}), dirtyOld: make(map[interface{}]interface{}),
@@ -25,6 +26,7 @@ func NewSafeMap() *SafeMap {
} }
} }
// Del deletes the value with the given key from m.
func (m *SafeMap) Del(key interface{}) { func (m *SafeMap) Del(key interface{}) {
m.lock.Lock() m.lock.Lock()
if _, ok := m.dirtyOld[key]; ok { if _, ok := m.dirtyOld[key]; ok {
@@ -53,18 +55,20 @@ func (m *SafeMap) Del(key interface{}) {
m.lock.Unlock() m.lock.Unlock()
} }
// Get gets the value with the given key from m.
func (m *SafeMap) Get(key interface{}) (interface{}, bool) { func (m *SafeMap) Get(key interface{}) (interface{}, bool) {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
if val, ok := m.dirtyOld[key]; ok { if val, ok := m.dirtyOld[key]; ok {
return val, true return val, true
} else {
val, ok := m.dirtyNew[key]
return val, ok
} }
val, ok := m.dirtyNew[key]
return val, ok
} }
// Set sets the value into m with the given key.
func (m *SafeMap) Set(key, value interface{}) { func (m *SafeMap) Set(key, value interface{}) {
m.lock.Lock() m.lock.Lock()
if m.deletionOld <= maxDeletion { if m.deletionOld <= maxDeletion {
@@ -83,6 +87,7 @@ func (m *SafeMap) Set(key, value interface{}) {
m.lock.Unlock() m.lock.Unlock()
} }
// Size returns the size of m.
func (m *SafeMap) Size() int { func (m *SafeMap) Size() int {
m.lock.RLock() m.lock.RLock()
size := len(m.dirtyOld) + len(m.dirtyNew) size := len(m.dirtyOld) + len(m.dirtyNew)

View File

@@ -21,6 +21,7 @@ type Set struct {
tp int tp int
} }
// NewSet returns a managed Set, can only put the values with the same type.
func NewSet() *Set { func NewSet() *Set {
return &Set{ return &Set{
data: make(map[interface{}]lang.PlaceholderType), data: make(map[interface{}]lang.PlaceholderType),
@@ -28,6 +29,7 @@ func NewSet() *Set {
} }
} }
// NewUnmanagedSet returns a unmanaged Set, which can put values with different types.
func NewUnmanagedSet() *Set { func NewUnmanagedSet() *Set {
return &Set{ return &Set{
data: make(map[interface{}]lang.PlaceholderType), data: make(map[interface{}]lang.PlaceholderType),
@@ -35,42 +37,49 @@ func NewUnmanagedSet() *Set {
} }
} }
// Add adds i into s.
func (s *Set) Add(i ...interface{}) { func (s *Set) Add(i ...interface{}) {
for _, each := range i { for _, each := range i {
s.add(each) s.add(each)
} }
} }
// AddInt adds int values ii into s.
func (s *Set) AddInt(ii ...int) { func (s *Set) AddInt(ii ...int) {
for _, each := range ii { for _, each := range ii {
s.add(each) s.add(each)
} }
} }
// AddInt64 adds int64 values ii into s.
func (s *Set) AddInt64(ii ...int64) { func (s *Set) AddInt64(ii ...int64) {
for _, each := range ii { for _, each := range ii {
s.add(each) s.add(each)
} }
} }
// AddUint adds uint values ii into s.
func (s *Set) AddUint(ii ...uint) { func (s *Set) AddUint(ii ...uint) {
for _, each := range ii { for _, each := range ii {
s.add(each) s.add(each)
} }
} }
// AddUint64 adds uint64 values ii into s.
func (s *Set) AddUint64(ii ...uint64) { func (s *Set) AddUint64(ii ...uint64) {
for _, each := range ii { for _, each := range ii {
s.add(each) s.add(each)
} }
} }
// AddStr adds string values ss into s.
func (s *Set) AddStr(ss ...string) { func (s *Set) AddStr(ss ...string) {
for _, each := range ss { for _, each := range ss {
s.add(each) s.add(each)
} }
} }
// Contains checks if i is in s.
func (s *Set) Contains(i interface{}) bool { func (s *Set) Contains(i interface{}) bool {
if len(s.data) == 0 { if len(s.data) == 0 {
return false return false
@@ -81,6 +90,7 @@ func (s *Set) Contains(i interface{}) bool {
return ok return ok
} }
// Keys returns the keys in s.
func (s *Set) Keys() []interface{} { func (s *Set) Keys() []interface{} {
var keys []interface{} var keys []interface{}
@@ -91,6 +101,7 @@ func (s *Set) Keys() []interface{} {
return keys return keys
} }
// KeysInt returns the int keys in s.
func (s *Set) KeysInt() []int { func (s *Set) KeysInt() []int {
var keys []int var keys []int
@@ -105,6 +116,7 @@ func (s *Set) KeysInt() []int {
return keys return keys
} }
// KeysInt64 returns int64 keys in s.
func (s *Set) KeysInt64() []int64 { func (s *Set) KeysInt64() []int64 {
var keys []int64 var keys []int64
@@ -119,6 +131,7 @@ func (s *Set) KeysInt64() []int64 {
return keys return keys
} }
// KeysUint returns uint keys in s.
func (s *Set) KeysUint() []uint { func (s *Set) KeysUint() []uint {
var keys []uint var keys []uint
@@ -133,6 +146,7 @@ func (s *Set) KeysUint() []uint {
return keys return keys
} }
// KeysUint64 returns uint64 keys in s.
func (s *Set) KeysUint64() []uint64 { func (s *Set) KeysUint64() []uint64 {
var keys []uint64 var keys []uint64
@@ -147,6 +161,7 @@ func (s *Set) KeysUint64() []uint64 {
return keys return keys
} }
// KeysStr returns string keys in s.
func (s *Set) KeysStr() []string { func (s *Set) KeysStr() []string {
var keys []string var keys []string
@@ -161,11 +176,13 @@ func (s *Set) KeysStr() []string {
return keys return keys
} }
// Remove removes i from s.
func (s *Set) Remove(i interface{}) { func (s *Set) Remove(i interface{}) {
s.validate(i) s.validate(i)
delete(s.data, i) delete(s.data, i)
} }
// Count returns the number of items in s.
func (s *Set) Count() int { func (s *Set) Count() int {
return len(s.data) return len(s.data)
} }

View File

@@ -13,8 +13,10 @@ import (
const drainWorkers = 8 const drainWorkers = 8
type ( type (
// Execute defines the method to execute the task.
Execute func(key, value interface{}) Execute func(key, value interface{})
// A TimingWheel is a timing wheel object to schedule tasks.
TimingWheel struct { TimingWheel struct {
interval time.Duration interval time.Duration
ticker timex.Ticker ticker timex.Ticker
@@ -54,6 +56,7 @@ type (
} }
) )
// NewTimingWheel returns a TimingWheel.
func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) { func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
if interval <= 0 || numSlots <= 0 || execute == nil { if interval <= 0 || numSlots <= 0 || execute == nil {
return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute) return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute)
@@ -85,10 +88,12 @@ func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execu
return tw, nil return tw, nil
} }
// Drain drains all items and executes them.
func (tw *TimingWheel) Drain(fn func(key, value interface{})) { func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
tw.drainChannel <- fn tw.drainChannel <- fn
} }
// MoveTimer moves the task with the given key to the given delay.
func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) { func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return return
@@ -100,6 +105,7 @@ func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
} }
} }
// RemoveTimer removes the task with the given key.
func (tw *TimingWheel) RemoveTimer(key interface{}) { func (tw *TimingWheel) RemoveTimer(key interface{}) {
if key == nil { if key == nil {
return return
@@ -108,6 +114,7 @@ func (tw *TimingWheel) RemoveTimer(key interface{}) {
tw.removeChannel <- key tw.removeChannel <- key
} }
// SetTimer sets the task value with the given key to the delay.
func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) { func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
if delay <= 0 || key == nil { if delay <= 0 || key == nil {
return return
@@ -122,6 +129,7 @@ func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
} }
} }
// Stop stops tw.
func (tw *TimingWheel) Stop() { func (tw *TimingWheel) Stop() {
close(tw.stopChannel) close(tw.stopChannel)
} }

View File

@@ -16,26 +16,43 @@ var loaders = map[string]func([]byte, interface{}) error{
".yml": LoadConfigFromYamlBytes, ".yml": LoadConfigFromYamlBytes,
} }
func LoadConfig(file string, v interface{}) error { // LoadConfig loads config into v from file, .json, .yaml and .yml are acceptable.
if content, err := ioutil.ReadFile(file); err != nil { func LoadConfig(file string, v interface{}, opts ...Option) error {
content, err := ioutil.ReadFile(file)
if err != nil {
return err return err
} else if loader, ok := loaders[path.Ext(file)]; ok { }
return loader([]byte(os.ExpandEnv(string(content))), v)
} else { loader, ok := loaders[path.Ext(file)]
if !ok {
return fmt.Errorf("unrecoginized file type: %s", file) return fmt.Errorf("unrecoginized file type: %s", file)
} }
var opt options
for _, o := range opts {
o(&opt)
}
if opt.env {
return loader([]byte(os.ExpandEnv(string(content))), v)
}
return loader(content, v)
} }
// LoadConfigFromJsonBytes loads config into v from content json bytes.
func LoadConfigFromJsonBytes(content []byte, v interface{}) error { func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
return mapping.UnmarshalJsonBytes(content, v) return mapping.UnmarshalJsonBytes(content, v)
} }
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
func LoadConfigFromYamlBytes(content []byte, v interface{}) error { func LoadConfigFromYamlBytes(content []byte, v interface{}) error {
return mapping.UnmarshalYamlBytes(content, v) return mapping.UnmarshalYamlBytes(content, v)
} }
func MustLoad(path string, v interface{}) { // MustLoad loads config into v from path, exits on error.
if err := LoadConfig(path, v); err != nil { func MustLoad(path string, v interface{}, opts ...Option) {
if err := LoadConfig(path, v, opts...); err != nil {
log.Fatalf("error: config file %s, %s", path, err.Error()) log.Fatalf("error: config file %s, %s", path, err.Error())
} }
} }

View File

@@ -30,7 +30,8 @@ func TestConfigJson(t *testing.T) {
text := `{ text := `{
"a": "foo", "a": "foo",
"b": 1, "b": 1,
"c": "${FOO}" "c": "${FOO}",
"d": "abcd!@#$112"
}` }`
for _, test := range tests { for _, test := range tests {
test := test test := test
@@ -45,11 +46,49 @@ func TestConfigJson(t *testing.T) {
A string `json:"a"` A string `json:"a"`
B int `json:"b"` B int `json:"b"`
C string `json:"c"` C string `json:"c"`
D string `json:"d"`
} }
MustLoad(tmpfile, &val) MustLoad(tmpfile, &val)
assert.Equal(t, "foo", val.A) assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B) assert.Equal(t, 1, val.B)
assert.Equal(t, "${FOO}", val.C)
assert.Equal(t, "abcd!@#$112", val.D)
})
}
}
func TestConfigJsonEnv(t *testing.T) {
tests := []string{
".json",
".yaml",
".yml",
}
text := `{
"a": "foo",
"b": 1,
"c": "${FOO}",
"d": "abcd!@#$a12 3"
}`
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
var val struct {
A string `json:"a"`
B int `json:"b"`
C string `json:"c"`
D string `json:"d"`
}
MustLoad(tmpfile, &val, UseEnv())
assert.Equal(t, "foo", val.A)
assert.Equal(t, 1, val.B)
assert.Equal(t, "2", val.C) assert.Equal(t, "2", val.C)
assert.Equal(t, "abcd!@# 3", val.D)
}) })
} }
} }

17
core/conf/options.go Normal file
View File

@@ -0,0 +1,17 @@
package conf
type (
// Option defines the method to customize the config options.
Option func(opt *options)
options struct {
env bool
}
)
// UseEnv customizes the config to use environment variables.
func UseEnv() Option {
return func(opt *options) {
opt.env = true
}
}

View File

@@ -2,6 +2,7 @@ package conf
import ( import (
"fmt" "fmt"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -30,14 +31,19 @@ type mapBasedProperties struct {
lock sync.RWMutex lock sync.RWMutex
} }
// Loads the properties into a properties configuration instance. // LoadProperties loads the properties into a properties configuration instance.
// Returns an error that indicates if there was a problem loading the configuration. // Returns an error that indicates if there was a problem loading the configuration.
func LoadProperties(filename string) (Properties, error) { func LoadProperties(filename string, opts ...Option) (Properties, error) {
lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#")) lines, err := iox.ReadTextLines(filename, iox.WithoutBlank(), iox.OmitWithPrefix("#"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
var opt options
for _, o := range opts {
o(&opt)
}
raw := make(map[string]string) raw := make(map[string]string)
for i := range lines { for i := range lines {
pair := strings.Split(lines[i], "=") pair := strings.Split(lines[i], "=")
@@ -50,7 +56,11 @@ func LoadProperties(filename string) (Properties, error) {
key := strings.TrimSpace(pair[0]) key := strings.TrimSpace(pair[0])
value := strings.TrimSpace(pair[1]) value := strings.TrimSpace(pair[1])
raw[key] = value if opt.env {
raw[key] = os.ExpandEnv(value)
} else {
raw[key] = value
}
} }
return &mapBasedProperties{ return &mapBasedProperties{
@@ -87,7 +97,7 @@ func (config *mapBasedProperties) SetInt(key string, value int) {
config.lock.Unlock() config.lock.Unlock()
} }
// Dumps the configuration internal map into a string. // ToString dumps the configuration internal map into a string.
func (config *mapBasedProperties) ToString() string { func (config *mapBasedProperties) ToString() string {
config.lock.RLock() config.lock.RLock()
ret := fmt.Sprintf("%s", config.properties) ret := fmt.Sprintf("%s", config.properties)
@@ -96,12 +106,12 @@ func (config *mapBasedProperties) ToString() string {
return ret return ret
} }
// Returns the error message. // Error returns the error message.
func (configError *PropertyError) Error() string { func (configError *PropertyError) Error() string {
return configError.message return configError.message
} }
// Builds a new properties configuration structure // NewProperties builds a new properties configuration structure.
func NewProperties() Properties { func NewProperties() Properties {
return &mapBasedProperties{ return &mapBasedProperties{
properties: make(map[string]string), properties: make(map[string]string),

View File

@@ -31,6 +31,39 @@ func TestProperties(t *testing.T) {
assert.Contains(t, val, "app.threads") assert.Contains(t, val, "app.threads")
} }
func TestPropertiesEnv(t *testing.T) {
text := `app.name = test
app.program=app
app.env1 = ${FOO}
app.env2 = $none
# this is comment
app.threads = 5`
tmpfile, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
props, err := LoadProperties(tmpfile, UseEnv())
assert.Nil(t, err)
assert.Equal(t, "test", props.GetString("app.name"))
assert.Equal(t, "app", props.GetString("app.program"))
assert.Equal(t, 5, props.GetInt("app.threads"))
assert.Equal(t, "2", props.GetString("app.env1"))
assert.Equal(t, "", props.GetString("app.env2"))
val := props.ToString()
assert.Contains(t, val, "app.name")
assert.Contains(t, val, "app.program")
assert.Contains(t, val, "app.threads")
assert.Contains(t, val, "app.env1")
assert.Contains(t, val, "app.env2")
}
func TestLoadProperties_badContent(t *testing.T) { func TestLoadProperties_badContent(t *testing.T) {
filename, err := fs.TempFilenameWithText("hello") filename, err := fs.TempFilenameWithText("hello")
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -5,6 +5,8 @@ import (
"time" "time"
) )
// ShrinkDeadline returns a new Context with proper deadline base on the given ctx and timeout.
// And returns a cancel function as well.
func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) { func ShrinkDeadline(ctx context.Context, timeout time.Duration) (context.Context, func()) {
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
leftTime := time.Until(deadline) leftTime := time.Until(deadline)

View File

@@ -19,6 +19,7 @@ func (cv contextValuer) Value(key string) (interface{}, bool) {
return v, v != nil return v, v != nil
} }
// For unmarshals ctx into v.
func For(ctx context.Context, v interface{}) error { func For(ctx context.Context, v interface{}) error {
return unmarshaler.UnmarshalValuer(contextValuer{ return unmarshaler.UnmarshalValuer(contextValuer{
Context: ctx, Context: ctx,

View File

@@ -21,6 +21,7 @@ func (valueOnlyContext) Err() error {
return nil return nil
} }
// ValueOnlyFrom takes all values from the given ctx, without deadline and error control.
func ValueOnlyFrom(ctx context.Context) context.Context { func ValueOnlyFrom(ctx context.Context) context.Context {
return valueOnlyContext{ return valueOnlyContext{
Context: ctx, Context: ctx,

View File

@@ -14,6 +14,7 @@ const (
const timeToLive int64 = 10 const timeToLive int64 = 10
// TimeToLive is seconds to live in etcd.
var TimeToLive = timeToLive var TimeToLive = timeToLive
func extract(etcdKey string, index int) (string, bool) { func extract(etcdKey string, index int) (string, bool) {

View File

@@ -28,6 +28,9 @@ func TestExtract(t *testing.T) {
_, ok = extract("any", -1) _, ok = extract("any", -1)
assert.False(t, ok) assert.False(t, ok)
_, ok = extract("any", 10)
assert.False(t, ok)
} }
func TestMakeKey(t *testing.T) { func TestMakeKey(t *testing.T) {

View File

@@ -2,11 +2,13 @@ package discov
import "errors" import "errors"
// EtcdConf is the config item with the given key on etcd.
type EtcdConf struct { type EtcdConf struct {
Hosts []string Hosts []string
Key string Key string
} }
// Validate validates c.
func (c EtcdConf) Validate() error { func (c EtcdConf) Validate() error {
if len(c.Hosts) == 0 { if len(c.Hosts) == 0 {
return errors.New("empty etcd hosts") return errors.New("empty etcd hosts")

View File

@@ -1,47 +0,0 @@
package discov
import (
"github.com/tal-tech/go-zero/core/discov/internal"
"github.com/tal-tech/go-zero/core/logx"
)
type (
Facade struct {
endpoints []string
registry *internal.Registry
}
FacadeListener interface {
OnAdd(key, val string)
OnDelete(key string)
}
)
func NewFacade(endpoints []string) Facade {
return Facade{
endpoints: endpoints,
registry: internal.GetRegistry(),
}
}
func (f Facade) Client() internal.EtcdClient {
conn, err := f.registry.GetConn(f.endpoints)
logx.Must(err)
return conn
}
func (f Facade) Monitor(key string, l FacadeListener) {
f.registry.Monitor(f.endpoints, key, listenerAdapter{l})
}
type listenerAdapter struct {
l FacadeListener
}
func (la listenerAdapter) OnAdd(kv internal.KV) {
la.l.OnAdd(kv.Key, kv.Val)
}
func (la listenerAdapter) OnDelete(kv internal.KV) {
la.l.OnDelete(kv.Key)
}

View File

@@ -1,4 +1,5 @@
//go:generate mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient //go:generate mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
package internal package internal
import ( import (
@@ -8,6 +9,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
) )
// EtcdClient interface represents an etcd client.
type EtcdClient interface { type EtcdClient interface {
ActiveConnection() *grpc.ClientConn ActiveConnection() *grpc.ClientConn
Close() error Close() error

View File

@@ -1,5 +1,6 @@
package internal package internal
// Listener interface wraps the OnUpdate method.
type Listener interface { type Listener interface {
OnUpdate(keys []string, values []string, newKey string) OnUpdate(keys []string, values []string, newKey string)
} }

View File

@@ -18,19 +18,31 @@ import (
) )
var ( var (
registryInstance = Registry{ registry = Registry{
clusters: make(map[string]*cluster), clusters: make(map[string]*cluster),
} }
connManager = syncx.NewResourceManager() connManager = syncx.NewResourceManager()
) )
// A Registry is a registry that manages the etcd client connections.
type Registry struct { type Registry struct {
clusters map[string]*cluster clusters map[string]*cluster
lock sync.Mutex lock sync.Mutex
} }
// GetRegistry returns a global Registry.
func GetRegistry() *Registry { func GetRegistry() *Registry {
return &registryInstance return &registry
}
// GetConn returns an etcd client connection associated with given endpoints.
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
return r.getCluster(endpoints).getClient()
}
// Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
return r.getCluster(endpoints).monitor(key, l)
} }
func (r *Registry) getCluster(endpoints []string) *cluster { func (r *Registry) getCluster(endpoints []string) *cluster {
@@ -46,14 +58,6 @@ func (r *Registry) getCluster(endpoints []string) *cluster {
return c return c
} }
func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
return r.getCluster(endpoints).getClient()
}
func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
return r.getCluster(endpoints).monitor(key, l)
}
type cluster struct { type cluster struct {
endpoints []string endpoints []string
key string key string
@@ -288,6 +292,7 @@ func (c *cluster) watchConnState(cli EtcdClient) {
watcher.watch(cli.ActiveConnection()) watcher.watch(cli.ActiveConnection())
} }
// DialClient dials an etcd cluster with given endpoints.
func DialClient(endpoints []string) (EtcdClient, error) { func DialClient(endpoints []string) (EtcdClient, error) {
return clientv3.New(clientv3.Config{ return clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,

View File

@@ -1,4 +1,5 @@
//go:generate mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn //go:generate mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
package internal package internal
import ( import (

View File

@@ -1,12 +1,15 @@
//go:generate mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener //go:generate mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
package internal package internal
type ( type (
// A KV is used to store an etcd entry with key and value.
KV struct { KV struct {
Key string Key string
Val string Val string
} }
// UpdateListener wraps the OnAdd and OnDelete methods.
UpdateListener interface { UpdateListener interface {
OnAdd(kv KV) OnAdd(kv KV)
OnDelete(kv KV) OnDelete(kv KV)

View File

@@ -3,17 +3,22 @@ package internal
import "time" import "time"
const ( const (
// Delimiter is a separator that separates the etcd path.
Delimiter = '/'
autoSyncInterval = time.Minute autoSyncInterval = time.Minute
coolDownInterval = time.Second coolDownInterval = time.Second
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
dialKeepAliveTime = 5 * time.Second dialKeepAliveTime = 5 * time.Second
requestTimeout = 3 * time.Second requestTimeout = 3 * time.Second
Delimiter = '/'
endpointsSeparator = "," endpointsSeparator = ","
) )
var ( var (
DialTimeout = dialTimeout // DialTimeout is the dial timeout.
DialTimeout = dialTimeout
// RequestTimeout is the request timeout.
RequestTimeout = requestTimeout RequestTimeout = requestTimeout
NewClient = DialClient // NewClient is used to create etcd clients.
NewClient = DialClient
) )

View File

@@ -0,0 +1,64 @@
apiVersion: apps/v1
kind: StatefulSet
metadata:
name: "etcd"
namespace: discov
labels:
app: "etcd"
spec:
serviceName: "etcd"
replicas: 5
template:
metadata:
name: "etcd"
labels:
app: "etcd"
spec:
volumes:
- name: etcd-pvc
persistentVolumeClaim:
claimName: etcd-pvc
containers:
- name: "etcd"
image: quay.io/coreos/etcd:latest
ports:
- containerPort: 2379
name: client
- containerPort: 2380
name: peer
env:
- name: CLUSTER_SIZE
value: "5"
- name: SET_NAME
value: "etcd"
- name: VOLNAME
valueFrom:
fieldRef:
apiVersion: v1
fieldPath: metadata.name
volumeMounts:
- name: etcd-pvc
mountPath: /var/lib/etcd
subPathExpr: $(VOLNAME) # data mounted respectively in each pod
command:
- "/bin/sh"
- "-ecx"
- |
chmod 700 /var/lib/etcd
IP=$(hostname -i)
PEERS=""
for i in $(seq 0 $((${CLUSTER_SIZE} - 1))); do
PEERS="${PEERS}${PEERS:+,}${SET_NAME}-${i}=http://${SET_NAME}-${i}.${SET_NAME}:2380"
done
exec etcd --name ${HOSTNAME} \
--listen-peer-urls http://0.0.0.0:2380 \
--listen-client-urls http://0.0.0.0:2379 \
--advertise-client-urls http://${HOSTNAME}.${SET_NAME}.discov:2379 \
--initial-advertise-peer-urls http://${HOSTNAME}.${SET_NAME}:2380 \
--initial-cluster ${PEERS} \
--initial-cluster-state new \
--logger zap \
--data-dir /var/lib/etcd \
--auto-compaction-retention 1

View File

@@ -11,8 +11,10 @@ import (
) )
type ( type (
// PublisherOption defines the method to customize a Publisher.
PublisherOption func(client *Publisher) PublisherOption func(client *Publisher)
// A Publisher can be used to publish the value to an etcd cluster on the given key.
Publisher struct { Publisher struct {
endpoints []string endpoints []string
key string key string
@@ -26,6 +28,10 @@ type (
} }
) )
// NewPublisher returns a Publisher.
// endpoints is the hosts of the etcd cluster.
// key:value are a pair to be published.
// opts are used to customize the Publisher.
func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption) *Publisher { func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption) *Publisher {
publisher := &Publisher{ publisher := &Publisher{
endpoints: endpoints, endpoints: endpoints,
@@ -43,6 +49,7 @@ func NewPublisher(endpoints []string, key, value string, opts ...PublisherOption
return publisher return publisher
} }
// KeepAlive keeps key:value alive.
func (p *Publisher) KeepAlive() error { func (p *Publisher) KeepAlive() error {
cli, err := internal.GetRegistry().GetConn(p.endpoints) cli, err := internal.GetRegistry().GetConn(p.endpoints)
if err != nil { if err != nil {
@@ -61,14 +68,17 @@ func (p *Publisher) KeepAlive() error {
return p.keepAliveAsync(cli) return p.keepAliveAsync(cli)
} }
// Pause pauses the renewing of key:value.
func (p *Publisher) Pause() { func (p *Publisher) Pause() {
p.pauseChan <- lang.Placeholder p.pauseChan <- lang.Placeholder
} }
// Resume resumes the renewing of key:value.
func (p *Publisher) Resume() { func (p *Publisher) Resume() {
p.resumeChan <- lang.Placeholder p.resumeChan <- lang.Placeholder
} }
// Stop stops the renewing and revokes the registration.
func (p *Publisher) Stop() { func (p *Publisher) Stop() {
p.quit.Close() p.quit.Close()
} }
@@ -135,6 +145,7 @@ func (p *Publisher) revoke(cli internal.EtcdClient) {
} }
} }
// WithId customizes a Publisher with the id.
func WithId(id int64) PublisherOption { func WithId(id int64) PublisherOption {
return func(publisher *Publisher) { return func(publisher *Publisher) {
publisher.id = id publisher.id = id

View File

@@ -4,10 +4,12 @@ import (
"errors" "errors"
"sync" "sync"
"testing" "testing"
"time"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/discov/internal" "github.com/tal-tech/go-zero/core/discov/internal"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/clientv3"
) )
@@ -152,3 +154,16 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
pub.Pause() pub.Pause()
wg.Wait() wg.Wait()
} }
func TestPublisher_Resume(t *testing.T) {
publisher := new(Publisher)
publisher.resumeChan = make(chan lang.PlaceholderType)
go func() {
publisher.Resume()
}()
go func() {
time.Sleep(time.Minute)
t.Fail()
}()
<-publisher.resumeChan
}

View File

@@ -13,13 +13,19 @@ type (
exclusive bool exclusive bool
} }
// SubOption defines the method to customize a Subscriber.
SubOption func(opts *subOptions) SubOption func(opts *subOptions)
// A Subscriber is used to subscribe the given key on a etcd cluster.
Subscriber struct { Subscriber struct {
items *container items *container
} }
) )
// NewSubscriber returns a Subscriber.
// endpoints is the hosts of the etcd cluster.
// key is the key to subscribe.
// opts are used to customize the Subscriber.
func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) { func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscriber, error) {
var subOpts subOptions var subOpts subOptions
for _, opt := range opts { for _, opt := range opts {
@@ -36,15 +42,17 @@ func NewSubscriber(endpoints []string, key string, opts ...SubOption) (*Subscrib
return sub, nil return sub, nil
} }
// AddListener adds listener to s.
func (s *Subscriber) AddListener(listener func()) { func (s *Subscriber) AddListener(listener func()) {
s.items.addListener(listener) s.items.addListener(listener)
} }
// Values returns all the subscription values.
func (s *Subscriber) Values() []string { func (s *Subscriber) Values() []string {
return s.items.getValues() return s.items.getValues()
} }
// exclusive means that key value can only be 1:1, // Exclusive means that key value can only be 1:1,
// which means later added value will remove the keys associated with the same value previously. // which means later added value will remove the keys associated with the same value previously.
func Exclusive() SubOption { func Exclusive() SubOption {
return func(opts *subOptions) { return func(opts *subOptions) {
@@ -100,9 +108,9 @@ func (c *container) addKv(key, value string) ([]string, bool) {
if early { if early {
return previous, true return previous, true
} else {
return nil, false
} }
return nil, false
} }
func (c *container) addListener(listener func()) { func (c *container) addListener(listener func()) {

View File

@@ -1,6 +1,7 @@
package discov package discov
import ( import (
"sync/atomic"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -198,3 +199,18 @@ func TestContainer(t *testing.T) {
} }
} }
} }
func TestSubscriber(t *testing.T) {
var opt subOptions
Exclusive()(&opt)
sub := new(Subscriber)
sub.items = newContainer(opt.exclusive)
var count int32
sub.AddListener(func() {
atomic.AddInt32(&count, 1)
})
sub.items.notifyChange()
assert.Empty(t, sub.Values())
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
}

View File

@@ -2,14 +2,17 @@ package errorx
import "sync/atomic" import "sync/atomic"
// AtomicError defines an atomic error.
type AtomicError struct { type AtomicError struct {
err atomic.Value // error err atomic.Value // error
} }
// Set sets the error.
func (ae *AtomicError) Set(err error) { func (ae *AtomicError) Set(err error) {
ae.err.Store(err) ae.err.Store(err)
} }
// Load returns the error.
func (ae *AtomicError) Load() error { func (ae *AtomicError) Load() error {
if v := ae.err.Load(); v != nil { if v := ae.err.Load(); v != nil {
return v.(error) return v.(error)

View File

@@ -3,6 +3,7 @@ package errorx
import "bytes" import "bytes"
type ( type (
// A BatchError is an error that can hold multiple errors.
BatchError struct { BatchError struct {
errs errorArray errs errorArray
} }
@@ -10,12 +11,14 @@ type (
errorArray []error errorArray []error
) )
// Add adds err to be.
func (be *BatchError) Add(err error) { func (be *BatchError) Add(err error) {
if err != nil { if err != nil {
be.errs = append(be.errs, err) be.errs = append(be.errs, err)
} }
} }
// Err returns an error that represents all errors.
func (be *BatchError) Err() error { func (be *BatchError) Err() error {
switch len(be.errs) { switch len(be.errs) {
case 0: case 0:
@@ -27,10 +30,12 @@ func (be *BatchError) Err() error {
} }
} }
// NotNil checks if any error inside.
func (be *BatchError) NotNil() bool { func (be *BatchError) NotNil() bool {
return len(be.errs) > 0 return len(be.errs) > 0
} }
// Error returns a string that represents inside errors.
func (ea errorArray) Error() string { func (ea errorArray) Error() string {
var buf bytes.Buffer var buf bytes.Buffer

View File

@@ -1,5 +1,6 @@
package errorx package errorx
// Chain runs funs one by one until an error occurred.
func Chain(fns ...func() error) error { func Chain(fns ...func() error) error {
for _, fn := range fns { for _, fn := range fns {
if err := fn(); err != nil { if err := fn(); err != nil {

View File

@@ -5,8 +5,12 @@ import "time"
const defaultBulkTasks = 1000 const defaultBulkTasks = 1000
type ( type (
// BulkOption defines the method to customize a BulkExecutor.
BulkOption func(options *bulkOptions) BulkOption func(options *bulkOptions)
// A BulkExecutor is an executor that can execute tasks on either requirement meets:
// 1. up to given size of tasks
// 2. flush interval time elapsed
BulkExecutor struct { BulkExecutor struct {
executor *PeriodicalExecutor executor *PeriodicalExecutor
container *bulkContainer container *bulkContainer
@@ -18,6 +22,7 @@ type (
} }
) )
// NewBulkExecutor returns a BulkExecutor.
func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor { func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
options := newBulkOptions() options := newBulkOptions()
for _, opt := range opts { for _, opt := range opts {
@@ -36,25 +41,30 @@ func NewBulkExecutor(execute Execute, opts ...BulkOption) *BulkExecutor {
return executor return executor
} }
// Add adds task into be.
func (be *BulkExecutor) Add(task interface{}) error { func (be *BulkExecutor) Add(task interface{}) error {
be.executor.Add(task) be.executor.Add(task)
return nil return nil
} }
// Flush forces be to flush and execute tasks.
func (be *BulkExecutor) Flush() { func (be *BulkExecutor) Flush() {
be.executor.Flush() be.executor.Flush()
} }
// Wait waits be to done with the task execution.
func (be *BulkExecutor) Wait() { func (be *BulkExecutor) Wait() {
be.executor.Wait() be.executor.Wait()
} }
// WithBulkTasks customizes a BulkExecutor with given tasks limit.
func WithBulkTasks(tasks int) BulkOption { func WithBulkTasks(tasks int) BulkOption {
return func(options *bulkOptions) { return func(options *bulkOptions) {
options.cachedTasks = tasks options.cachedTasks = tasks
} }
} }
// WithBulkInterval customizes a BulkExecutor with given flush interval.
func WithBulkInterval(duration time.Duration) BulkOption { func WithBulkInterval(duration time.Duration) BulkOption {
return func(options *bulkOptions) { return func(options *bulkOptions) {
options.flushInterval = duration options.flushInterval = duration

View File

@@ -5,8 +5,12 @@ import "time"
const defaultChunkSize = 1024 * 1024 // 1M const defaultChunkSize = 1024 * 1024 // 1M
type ( type (
// ChunkOption defines the method to customize a ChunkExecutor.
ChunkOption func(options *chunkOptions) ChunkOption func(options *chunkOptions)
// A ChunkExecutor is an executor to execute tasks when either requirement meets:
// 1. up to given chunk size
// 2. flush interval elapsed
ChunkExecutor struct { ChunkExecutor struct {
executor *PeriodicalExecutor executor *PeriodicalExecutor
container *chunkContainer container *chunkContainer
@@ -18,6 +22,7 @@ type (
} }
) )
// NewChunkExecutor returns a ChunkExecutor.
func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor { func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
options := newChunkOptions() options := newChunkOptions()
for _, opt := range opts { for _, opt := range opts {
@@ -36,6 +41,7 @@ func NewChunkExecutor(execute Execute, opts ...ChunkOption) *ChunkExecutor {
return executor return executor
} }
// Add adds task with given chunk size into ce.
func (ce *ChunkExecutor) Add(task interface{}, size int) error { func (ce *ChunkExecutor) Add(task interface{}, size int) error {
ce.executor.Add(chunk{ ce.executor.Add(chunk{
val: task, val: task,
@@ -44,20 +50,24 @@ func (ce *ChunkExecutor) Add(task interface{}, size int) error {
return nil return nil
} }
// Flush forces ce to flush and execute tasks.
func (ce *ChunkExecutor) Flush() { func (ce *ChunkExecutor) Flush() {
ce.executor.Flush() ce.executor.Flush()
} }
// Wait waits the execution to be done.
func (ce *ChunkExecutor) Wait() { func (ce *ChunkExecutor) Wait() {
ce.executor.Wait() ce.executor.Wait()
} }
// WithChunkBytes customizes a ChunkExecutor with the given chunk size.
func WithChunkBytes(size int) ChunkOption { func WithChunkBytes(size int) ChunkOption {
return func(options *chunkOptions) { return func(options *chunkOptions) {
options.chunkSize = size options.chunkSize = size
} }
} }
// WithFlushInterval customizes a ChunkExecutor with the given flush interval.
func WithFlushInterval(duration time.Duration) ChunkOption { func WithFlushInterval(duration time.Duration) ChunkOption {
return func(options *chunkOptions) { return func(options *chunkOptions) {
options.flushInterval = duration options.flushInterval = duration

View File

@@ -7,6 +7,7 @@ import (
"github.com/tal-tech/go-zero/core/threading" "github.com/tal-tech/go-zero/core/threading"
) )
// A DelayExecutor delays a tasks on given delay interval.
type DelayExecutor struct { type DelayExecutor struct {
fn func() fn func()
delay time.Duration delay time.Duration
@@ -14,6 +15,7 @@ type DelayExecutor struct {
lock sync.Mutex lock sync.Mutex
} }
// NewDelayExecutor returns a DelayExecutor with given fn and delay.
func NewDelayExecutor(fn func(), delay time.Duration) *DelayExecutor { func NewDelayExecutor(fn func(), delay time.Duration) *DelayExecutor {
return &DelayExecutor{ return &DelayExecutor{
fn: fn, fn: fn,
@@ -21,6 +23,7 @@ func NewDelayExecutor(fn func(), delay time.Duration) *DelayExecutor {
} }
} }
// Trigger triggers the task to be executed after given delay, safe to trigger more than once.
func (de *DelayExecutor) Trigger() { func (de *DelayExecutor) Trigger() {
de.lock.Lock() de.lock.Lock()
defer de.lock.Unlock() defer de.lock.Unlock()

View File

@@ -7,11 +7,13 @@ import (
"github.com/tal-tech/go-zero/core/timex" "github.com/tal-tech/go-zero/core/timex"
) )
// A LessExecutor is an executor to limit execution once within given time interval.
type LessExecutor struct { type LessExecutor struct {
threshold time.Duration threshold time.Duration
lastTime *syncx.AtomicDuration lastTime *syncx.AtomicDuration
} }
// NewLessExecutor returns a LessExecutor with given threshold as time interval.
func NewLessExecutor(threshold time.Duration) *LessExecutor { func NewLessExecutor(threshold time.Duration) *LessExecutor {
return &LessExecutor{ return &LessExecutor{
threshold: threshold, threshold: threshold,
@@ -19,6 +21,8 @@ func NewLessExecutor(threshold time.Duration) *LessExecutor {
} }
} }
// DoOrDiscard executes or discards the task depends on if
// another task was executed within the time interval.
func (le *LessExecutor) DoOrDiscard(execute func()) bool { func (le *LessExecutor) DoOrDiscard(execute func()) bool {
now := timex.Now() now := timex.Now()
lastTime := le.lastTime.Load() lastTime := le.lastTime.Load()

View File

@@ -16,7 +16,7 @@ import (
const idleRound = 10 const idleRound = 10
type ( type (
// A type that satisfies executors.TaskContainer can be used as the underlying // TaskContainer interface defines a type that can be used as the underlying
// container that used to do periodical executions. // container that used to do periodical executions.
TaskContainer interface { TaskContainer interface {
// AddTask adds the task into the container. // AddTask adds the task into the container.
@@ -28,6 +28,7 @@ type (
RemoveAll() interface{} RemoveAll() interface{}
} }
// A PeriodicalExecutor is an executor that periodically execute tasks.
PeriodicalExecutor struct { PeriodicalExecutor struct {
commander chan interface{} commander chan interface{}
interval time.Duration interval time.Duration
@@ -43,6 +44,7 @@ type (
} }
) )
// NewPeriodicalExecutor returns a PeriodicalExecutor with given interval and container.
func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor { func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *PeriodicalExecutor {
executor := &PeriodicalExecutor{ executor := &PeriodicalExecutor{
// buffer 1 to let the caller go quickly // buffer 1 to let the caller go quickly
@@ -51,7 +53,7 @@ func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *Per
container: container, container: container,
confirmChan: make(chan lang.PlaceholderType), confirmChan: make(chan lang.PlaceholderType),
newTicker: func(d time.Duration) timex.Ticker { newTicker: func(d time.Duration) timex.Ticker {
return timex.NewTicker(interval) return timex.NewTicker(d)
}, },
} }
proc.AddShutdownListener(func() { proc.AddShutdownListener(func() {
@@ -61,6 +63,7 @@ func NewPeriodicalExecutor(interval time.Duration, container TaskContainer) *Per
return executor return executor
} }
// Add adds tasks into pe.
func (pe *PeriodicalExecutor) Add(task interface{}) { func (pe *PeriodicalExecutor) Add(task interface{}) {
if vals, ok := pe.addAndCheck(task); ok { if vals, ok := pe.addAndCheck(task); ok {
pe.commander <- vals pe.commander <- vals
@@ -68,6 +71,7 @@ func (pe *PeriodicalExecutor) Add(task interface{}) {
} }
} }
// Flush forces pe to execute tasks.
func (pe *PeriodicalExecutor) Flush() bool { func (pe *PeriodicalExecutor) Flush() bool {
pe.enterExecution() pe.enterExecution()
return pe.executeTasks(func() interface{} { return pe.executeTasks(func() interface{} {
@@ -77,12 +81,14 @@ func (pe *PeriodicalExecutor) Flush() bool {
}()) }())
} }
// Sync lets caller to run fn thread-safe with pe, especially for the underlying container.
func (pe *PeriodicalExecutor) Sync(fn func()) { func (pe *PeriodicalExecutor) Sync(fn func()) {
pe.lock.Lock() pe.lock.Lock()
defer pe.lock.Unlock() defer pe.lock.Unlock()
fn() fn()
} }
// Wait waits the execution to be done.
func (pe *PeriodicalExecutor) Wait() { func (pe *PeriodicalExecutor) Wait() {
pe.Flush() pe.Flush()
pe.wgBarrier.Guard(func() { pe.wgBarrier.Guard(func() {

View File

@@ -4,4 +4,5 @@ import "time"
const defaultFlushInterval = time.Second const defaultFlushInterval = time.Second
// Execute defines the method to execute tasks.
type Execute func(tasks []interface{}) type Execute func(tasks []interface{})

View File

@@ -7,6 +7,7 @@ import (
const bufSize = 1024 const bufSize = 1024
// FirstLine returns the first line of the file.
func FirstLine(filename string) (string, error) { func FirstLine(filename string) (string, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
@@ -17,6 +18,7 @@ func FirstLine(filename string) (string, error) {
return firstLine(file) return firstLine(file)
} }
// LastLine returns the last line of the file.
func LastLine(filename string) (string, error) { func LastLine(filename string) (string, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
@@ -69,11 +71,11 @@ func lastLine(filename string, file *os.File) (string, error) {
if buf[n-1] == '\n' { if buf[n-1] == '\n' {
buf = buf[:n-1] buf = buf[:n-1]
n -= 1 n--
} else { } else {
buf = buf[:n] buf = buf[:n]
} }
for n -= 1; n >= 0; n-- { for n--; n >= 0; n-- {
if buf[n] == '\n' { if buf[n] == '\n' {
return string(append(buf[n+1:], last...)), nil return string(append(buf[n+1:], last...)), nil
} }

View File

@@ -5,12 +5,15 @@ import (
"os" "os"
) )
// OffsetRange represents a content block of a file.
type OffsetRange struct { type OffsetRange struct {
File string File string
Start int64 Start int64
Stop int64 Stop int64
} }
// SplitLineChunks splits file into chunks.
// The whole line are guaranteed to be split in the same chunk.
func SplitLineChunks(filename string, chunks int) ([]OffsetRange, error) { func SplitLineChunks(filename string, chunks int) ([]OffsetRange, error) {
info, err := os.Stat(filename) info, err := os.Stat(filename)
if err != nil { if err != nil {

View File

@@ -3,8 +3,11 @@ package filex
import "gopkg.in/cheggaaa/pb.v1" import "gopkg.in/cheggaaa/pb.v1"
type ( type (
// A Scanner is used to read lines.
Scanner interface { Scanner interface {
// Scan checks if has remaining to read.
Scan() bool Scan() bool
// Text returns next line.
Text() string Text() string
} }
@@ -14,6 +17,7 @@ type (
} }
) )
// NewProgressScanner returns a Scanner with progress indicator.
func NewProgressScanner(scanner Scanner, bar *pb.ProgressBar) Scanner { func NewProgressScanner(scanner Scanner, bar *pb.ProgressBar) Scanner {
return &progressScanner{ return &progressScanner{
Scanner: scanner, Scanner: scanner,

View File

@@ -5,12 +5,14 @@ import (
"os" "os"
) )
// A RangeReader is used to read a range of content from a file.
type RangeReader struct { type RangeReader struct {
file *os.File file *os.File
start int64 start int64
stop int64 stop int64
} }
// NewRangeReader returns a RangeReader, which will read the range of content from file.
func NewRangeReader(file *os.File, start, stop int64) *RangeReader { func NewRangeReader(file *os.File, start, stop int64) *RangeReader {
return &RangeReader{ return &RangeReader{
file: file, file: file,
@@ -19,6 +21,7 @@ func NewRangeReader(file *os.File, start, stop int64) *RangeReader {
} }
} }
// Read reads the range of content into p.
func (rr *RangeReader) Read(p []byte) (n int, err error) { func (rr *RangeReader) Read(p []byte) (n int, err error) {
stat, err := rr.file.Stat() stat, err := rr.file.Stat()
if err != nil { if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"syscall" "syscall"
) )
// CloseOnExec makes sure closing the file on process forking.
func CloseOnExec(file *os.File) { func CloseOnExec(file *os.File) {
if file != nil { if file != nil {
syscall.CloseOnExec(int(file.Fd())) syscall.CloseOnExec(int(file.Fd()))

View File

@@ -2,6 +2,7 @@ package fx
import "github.com/tal-tech/go-zero/core/threading" import "github.com/tal-tech/go-zero/core/threading"
// Parallel runs fns parallelly and waits for done.
func Parallel(fns ...func()) { func Parallel(fns ...func()) {
group := threading.NewRoutineGroup() group := threading.NewRoutineGroup()
for _, fn := range fns { for _, fn := range fns {

View File

@@ -5,6 +5,7 @@ import "github.com/tal-tech/go-zero/core/errorx"
const defaultRetryTimes = 3 const defaultRetryTimes = 3
type ( type (
// RetryOption defines the method to customize DoWithRetry.
RetryOption func(*retryOptions) RetryOption func(*retryOptions)
retryOptions struct { retryOptions struct {
@@ -12,7 +13,8 @@ type (
} }
) )
func DoWithRetries(fn func() error, opts ...RetryOption) error { // DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
func DoWithRetry(fn func() error, opts ...RetryOption) error {
var options = newRetryOptions() var options = newRetryOptions()
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)
@@ -30,7 +32,8 @@ func DoWithRetries(fn func() error, opts ...RetryOption) error {
return berr.Err() return berr.Err()
} }
func WithRetries(times int) RetryOption { // WithRetry customize a DoWithRetry call with given retry times.
func WithRetry(times int) RetryOption {
return func(options *retryOptions) { return func(options *retryOptions) {
options.times = times options.times = times
} }

View File

@@ -8,12 +8,12 @@ import (
) )
func TestRetry(t *testing.T) { func TestRetry(t *testing.T) {
assert.NotNil(t, DoWithRetries(func() error { assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any") return errors.New("any")
})) }))
var times int var times int
assert.Nil(t, DoWithRetries(func() error { assert.Nil(t, DoWithRetry(func() error {
times++ times++
if times == defaultRetryTimes { if times == defaultRetryTimes {
return nil return nil
@@ -22,7 +22,7 @@ func TestRetry(t *testing.T) {
})) }))
times = 0 times = 0
assert.NotNil(t, DoWithRetries(func() error { assert.NotNil(t, DoWithRetry(func() error {
times++ times++
if times == defaultRetryTimes+1 { if times == defaultRetryTimes+1 {
return nil return nil
@@ -32,11 +32,11 @@ func TestRetry(t *testing.T) {
var total = 2 * defaultRetryTimes var total = 2 * defaultRetryTimes
times = 0 times = 0
assert.Nil(t, DoWithRetries(func() error { assert.Nil(t, DoWithRetry(func() error {
times++ times++
if times == total { if times == total {
return nil return nil
} }
return errors.New("any") return errors.New("any")
}, WithRetries(total))) }, WithRetry(total)))
} }

View File

@@ -20,18 +20,30 @@ type (
workers int workers int
} }
FilterFunc func(item interface{}) bool // FilterFunc defines the method to filter a Stream.
ForAllFunc func(pipe <-chan interface{}) FilterFunc func(item interface{}) bool
ForEachFunc func(item interface{}) // ForAllFunc defines the method to handle all elements in a Stream.
ForAllFunc func(pipe <-chan interface{})
// ForEachFunc defines the method to handle each element in a Stream.
ForEachFunc func(item interface{})
// GenerateFunc defines the method to send elements into a Stream.
GenerateFunc func(source chan<- interface{}) GenerateFunc func(source chan<- interface{})
KeyFunc func(item interface{}) interface{} // KeyFunc defines the method to generate keys for the elements in a Stream.
LessFunc func(a, b interface{}) bool KeyFunc func(item interface{}) interface{}
MapFunc func(item interface{}) interface{} // LessFunc defines the method to compare the elements in a Stream.
Option func(opts *rxOptions) LessFunc func(a, b interface{}) bool
// MapFunc defines the method to map each element to another object in a Stream.
MapFunc func(item interface{}) interface{}
// Option defines the method to customize a Stream.
Option func(opts *rxOptions)
// ParallelFunc defines the method to handle elements parallelly.
ParallelFunc func(item interface{}) ParallelFunc func(item interface{})
ReduceFunc func(pipe <-chan interface{}) (interface{}, error) // ReduceFunc defines the method to reduce all the elements in a Stream.
WalkFunc func(item interface{}, pipe chan<- interface{}) ReduceFunc func(pipe <-chan interface{}) (interface{}, error)
// WalkFunc defines the method to walk through all the elements in a Stream.
WalkFunc func(item interface{}, pipe chan<- interface{})
// A Stream is a stream that can be used to do stream processing.
Stream struct { Stream struct {
source <-chan interface{} source <-chan interface{}
} }
@@ -159,6 +171,7 @@ func (p Stream) Group(fn KeyFunc) Stream {
return Range(source) return Range(source)
} }
// Head returns the first n elements in p.
func (p Stream) Head(n int64) Stream { func (p Stream) Head(n int64) Stream {
if n < 1 { if n < 1 {
panic("n must be greater than 0") panic("n must be greater than 0")
@@ -187,7 +200,7 @@ func (p Stream) Head(n int64) Stream {
return Range(source) return Range(source)
} }
// Maps converts each item to another corresponding item, which means it's a 1:1 model. // Map converts each item to another corresponding item, which means it's a 1:1 model.
func (p Stream) Map(fn MapFunc, opts ...Option) Stream { func (p Stream) Map(fn MapFunc, opts ...Option) Stream {
return p.Walk(func(item interface{}, pipe chan<- interface{}) { return p.Walk(func(item interface{}, pipe chan<- interface{}) {
pipe <- fn(item) pipe <- fn(item)
@@ -274,6 +287,7 @@ func (p Stream) Split(n int) Stream {
return Range(source) return Range(source)
} }
// Tail returns the last n elements in p.
func (p Stream) Tail(n int64) Stream { func (p Stream) Tail(n int64) Stream {
if n < 1 { if n < 1 {
panic("n should be greater than 0") panic("n should be greater than 0")
@@ -300,9 +314,9 @@ func (p Stream) Walk(fn WalkFunc, opts ...Option) Stream {
option := buildOptions(opts...) option := buildOptions(opts...)
if option.unlimitedWorkers { if option.unlimitedWorkers {
return p.walkUnlimited(fn, option) return p.walkUnlimited(fn, option)
} else {
return p.walkLimited(fn, option)
} }
return p.walkLimited(fn, option)
} }
func (p Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream { func (p Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {

View File

@@ -3,24 +3,31 @@ package fx
import ( import (
"context" "context"
"time" "time"
"github.com/tal-tech/go-zero/core/contextx"
) )
var ( var (
// ErrCanceled is the error returned when the context is canceled.
ErrCanceled = context.Canceled ErrCanceled = context.Canceled
ErrTimeout = context.DeadlineExceeded // ErrTimeout is the error returned when the context's deadline passes.
ErrTimeout = context.DeadlineExceeded
) )
type FxOption func() context.Context // DoOption defines the method to customize a DoWithTimeout call.
type DoOption func() context.Context
func DoWithTimeout(fn func() error, timeout time.Duration, opts ...FxOption) error { // DoWithTimeout runs fn with timeout control.
func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) error {
parentCtx := context.Background() parentCtx := context.Background()
for _, opt := range opts { for _, opt := range opts {
parentCtx = opt() parentCtx = opt()
} }
ctx, cancel := context.WithTimeout(parentCtx, timeout) ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
defer cancel() defer cancel()
done := make(chan error) // create channel with buffer size 1 to avoid goroutine leak
done := make(chan error, 1)
panicChan := make(chan interface{}, 1) panicChan := make(chan interface{}, 1)
go func() { go func() {
defer func() { defer func() {
@@ -29,7 +36,6 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...FxOption) err
} }
}() }()
done <- fn() done <- fn()
close(done)
}() }()
select { select {
@@ -42,7 +48,8 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...FxOption) err
} }
} }
func WithContext(ctx context.Context) FxOption { // WithContext customizes a DoWithTimeout call with given ctx.
func WithContext(ctx context.Context) DoOption {
return func() context.Context { return func() context.Context {
return ctx return ctx
} }

View File

@@ -11,6 +11,7 @@ import (
) )
const ( const (
// TopWeight is the top weight that one entry might set.
TopWeight = 100 TopWeight = 100
minReplicas = 100 minReplicas = 100
@@ -18,10 +19,12 @@ const (
) )
type ( type (
HashFunc func(data []byte) uint64 // Func defines the hash method.
Func func(data []byte) uint64
// A ConsistentHash is a ring hash implementation.
ConsistentHash struct { ConsistentHash struct {
hashFunc HashFunc hashFunc Func
replicas int replicas int
keys []uint64 keys []uint64
ring map[uint64][]interface{} ring map[uint64][]interface{}
@@ -30,11 +33,13 @@ type (
} }
) )
// NewConsistentHash returns a ConsistentHash.
func NewConsistentHash() *ConsistentHash { func NewConsistentHash() *ConsistentHash {
return NewCustomConsistentHash(minReplicas, Hash) return NewCustomConsistentHash(minReplicas, Hash)
} }
func NewCustomConsistentHash(replicas int, fn HashFunc) *ConsistentHash { // NewCustomConsistentHash returns a ConsistentHash with given replicas and hash func.
func NewCustomConsistentHash(replicas int, fn Func) *ConsistentHash {
if replicas < minReplicas { if replicas < minReplicas {
replicas = minReplicas replicas = minReplicas
} }
@@ -92,6 +97,7 @@ func (h *ConsistentHash) AddWithWeight(node interface{}, weight int) {
h.AddWithReplicas(node, replicas) h.AddWithReplicas(node, replicas)
} }
// Get returns the corresponding node from h base on the given v.
func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) { func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) {
h.lock.RLock() h.lock.RLock()
defer h.lock.RUnlock() defer h.lock.RUnlock()
@@ -118,6 +124,7 @@ func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) {
} }
} }
// Remove removes the given node from h.
func (h *ConsistentHash) Remove(node interface{}) { func (h *ConsistentHash) Remove(node interface{}) {
nodeRepr := repr(node) nodeRepr := repr(node)

View File

@@ -132,8 +132,8 @@ func TestConsistentHash_RemoveInterface(t *testing.T) {
assert.Equal(t, 1, len(ch.nodes)) assert.Equal(t, 1, len(ch.nodes))
node, ok := ch.Get(1) node, ok := ch.Get(1)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, key, node.(*MockNode).Addr) assert.Equal(t, key, node.(*mockNode).addr)
assert.Equal(t, 2, node.(*MockNode).Id) assert.Equal(t, 2, node.(*mockNode).id)
} }
func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[int]string, map[int]string) { func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[int]string, map[int]string) {
@@ -164,18 +164,18 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
return keys, newKeys return keys, newKeys
} }
type MockNode struct { type mockNode struct {
Addr string addr string
Id int id int
} }
func newMockNode(addr string, id int) *MockNode { func newMockNode(addr string, id int) *mockNode {
return &MockNode{ return &mockNode{
Addr: addr, addr: addr,
Id: id, id: id,
} }
} }
func (n *MockNode) String() string { func (n *mockNode) String() string {
return n.Addr return n.addr
} }

View File

@@ -7,16 +7,19 @@ import (
"github.com/spaolacci/murmur3" "github.com/spaolacci/murmur3"
) )
// Hash returns the hash value of data.
func Hash(data []byte) uint64 { func Hash(data []byte) uint64 {
return murmur3.Sum64(data) return murmur3.Sum64(data)
} }
// Md5 returns the md5 bytes of data.
func Md5(data []byte) []byte { func Md5(data []byte) []byte {
digest := md5.New() digest := md5.New()
digest.Write(data) digest.Write(data)
return digest.Sum(nil) return digest.Sum(nil)
} }
// Md5Hex returns the md5 hex string of data.
func Md5Hex(data []byte) string { func Md5Hex(data []byte) string {
return fmt.Sprintf("%x", Md5(data)) return fmt.Sprintf("%x", Md5(data))
} }

View File

@@ -5,11 +5,13 @@ import (
"sync" "sync"
) )
// A BufferPool is a pool to buffer bytes.Buffer objects.
type BufferPool struct { type BufferPool struct {
capability int capability int
pool *sync.Pool pool *sync.Pool
} }
// NewBufferPool returns a BufferPool.
func NewBufferPool(capability int) *BufferPool { func NewBufferPool(capability int) *BufferPool {
return &BufferPool{ return &BufferPool{
capability: capability, capability: capability,
@@ -21,12 +23,14 @@ func NewBufferPool(capability int) *BufferPool {
} }
} }
// Get returns a bytes.Buffer object from bp.
func (bp *BufferPool) Get() *bytes.Buffer { func (bp *BufferPool) Get() *bytes.Buffer {
buf := bp.pool.Get().(*bytes.Buffer) buf := bp.pool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
return buf return buf
} }
// Put returns buf into bp.
func (bp *BufferPool) Put(buf *bytes.Buffer) { func (bp *BufferPool) Put(buf *bytes.Buffer) {
if buf.Cap() < bp.capability { if buf.Cap() < bp.capability {
bp.pool.Put(buf) bp.pool.Put(buf)

View File

@@ -10,6 +10,7 @@ func (nopCloser) Close() error {
return nil return nil
} }
// NopCloser returns a io.WriteCloser that does nothing on calling Close.
func NopCloser(w io.Writer) io.WriteCloser { func NopCloser(w io.Writer) io.WriteCloser {
return nopCloser{w} return nopCloser{w}
} }

View File

@@ -16,9 +16,11 @@ type (
omitPrefix string omitPrefix string
} }
// TextReadOption defines the method to customize the text reading functions.
TextReadOption func(*textReadOptions) TextReadOption func(*textReadOptions)
) )
// DupReadCloser returns two io.ReadCloser that read from the first will be written to the second.
// The first returned reader needs to be read first, because the content // The first returned reader needs to be read first, because the content
// read from it will be written to the underlying buffer of the second reader. // read from it will be written to the underlying buffer of the second reader.
func DupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) { func DupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
@@ -27,6 +29,7 @@ func DupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
return ioutil.NopCloser(tee), ioutil.NopCloser(&buf) return ioutil.NopCloser(tee), ioutil.NopCloser(&buf)
} }
// KeepSpace customizes the reading functions to keep leading and tailing spaces.
func KeepSpace() TextReadOption { func KeepSpace() TextReadOption {
return func(o *textReadOptions) { return func(o *textReadOptions) {
o.keepSpace = true o.keepSpace = true
@@ -49,6 +52,7 @@ func ReadBytes(reader io.Reader, buf []byte) error {
return nil return nil
} }
// ReadText reads content from the given file with leading and tailing spaces trimmed.
func ReadText(filename string) (string, error) { func ReadText(filename string) (string, error) {
content, err := ioutil.ReadFile(filename) content, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
@@ -58,6 +62,7 @@ func ReadText(filename string) (string, error) {
return strings.TrimSpace(string(content)), nil return strings.TrimSpace(string(content)), nil
} }
// ReadTextLines reads the text lines from given file.
func ReadTextLines(filename string, opts ...TextReadOption) ([]string, error) { func ReadTextLines(filename string, opts ...TextReadOption) ([]string, error) {
var readOpts textReadOptions var readOpts textReadOptions
for _, opt := range opts { for _, opt := range opts {
@@ -90,12 +95,14 @@ func ReadTextLines(filename string, opts ...TextReadOption) ([]string, error) {
return lines, scanner.Err() return lines, scanner.Err()
} }
// WithoutBlank customizes the reading functions to ignore blank lines.
func WithoutBlank() TextReadOption { func WithoutBlank() TextReadOption {
return func(o *textReadOptions) { return func(o *textReadOptions) {
o.withoutBlanks = true o.withoutBlanks = true
} }
} }
// OmitWithPrefix customizes the reading functions to ignore the lines with given leading prefix.
func OmitWithPrefix(prefix string) TextReadOption { func OmitWithPrefix(prefix string) TextReadOption {
return func(o *textReadOptions) { return func(o *textReadOptions) {
o.omitPrefix = prefix o.omitPrefix = prefix

View File

@@ -8,6 +8,7 @@ import (
const bufSize = 32 * 1024 const bufSize = 32 * 1024
// CountLines returns the number of lines in file.
func CountLines(file string) (int, error) { func CountLines(file string) (int, error) {
f, err := os.Open(file) f, err := os.Open(file)
if err != nil { if err != nil {

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
) )
// A TextLineScanner is a scanner that can scan lines from given reader.
type TextLineScanner struct { type TextLineScanner struct {
reader *bufio.Reader reader *bufio.Reader
hasNext bool hasNext bool
@@ -13,6 +14,7 @@ type TextLineScanner struct {
err error err error
} }
// NewTextLineScanner returns a TextLineScanner with given reader.
func NewTextLineScanner(reader io.Reader) *TextLineScanner { func NewTextLineScanner(reader io.Reader) *TextLineScanner {
return &TextLineScanner{ return &TextLineScanner{
reader: bufio.NewReader(reader), reader: bufio.NewReader(reader),
@@ -20,6 +22,7 @@ func NewTextLineScanner(reader io.Reader) *TextLineScanner {
} }
} }
// Scan checks if scanner has more lines to read.
func (scanner *TextLineScanner) Scan() bool { func (scanner *TextLineScanner) Scan() bool {
if !scanner.hasNext { if !scanner.hasNext {
return false return false
@@ -37,6 +40,7 @@ func (scanner *TextLineScanner) Scan() bool {
return true return true
} }
// Line returns the next available line.
func (scanner *TextLineScanner) Line() (string, error) { func (scanner *TextLineScanner) Line() (string, error) {
return scanner.line, scanner.err return scanner.line, scanner.err
} }

View File

@@ -7,32 +7,38 @@ import (
"github.com/globalsign/mgo/bson" "github.com/globalsign/mgo/bson"
) )
// MilliTime represents time.Time that works better with mongodb.
type MilliTime struct { type MilliTime struct {
time.Time time.Time
} }
// MarshalJSON marshals mt to json bytes.
func (mt MilliTime) MarshalJSON() ([]byte, error) { func (mt MilliTime) MarshalJSON() ([]byte, error) {
return json.Marshal(mt.Milli()) return json.Marshal(mt.Milli())
} }
// UnmarshalJSON unmarshals data into mt.
func (mt *MilliTime) UnmarshalJSON(data []byte) error { func (mt *MilliTime) UnmarshalJSON(data []byte) error {
var milli int64 var milli int64
if err := json.Unmarshal(data, &milli); err != nil { if err := json.Unmarshal(data, &milli); err != nil {
return err return err
} else {
mt.Time = time.Unix(0, milli*int64(time.Millisecond))
return nil
} }
mt.Time = time.Unix(0, milli*int64(time.Millisecond))
return nil
} }
// GetBSON returns BSON base on mt.
func (mt MilliTime) GetBSON() (interface{}, error) { func (mt MilliTime) GetBSON() (interface{}, error) {
return mt.Time, nil return mt.Time, nil
} }
// SetBSON sets raw into mt.
func (mt *MilliTime) SetBSON(raw bson.Raw) error { func (mt *MilliTime) SetBSON(raw bson.Raw) error {
return raw.Unmarshal(&mt.Time) return raw.Unmarshal(&mt.Time)
} }
// Milli returns milliseconds for mt.
func (mt MilliTime) Milli() int64 { func (mt MilliTime) Milli() int64 {
return mt.UnixNano() / int64(time.Millisecond) return mt.UnixNano() / int64(time.Millisecond)
} }

View File

@@ -8,10 +8,12 @@ import (
"strings" "strings"
) )
// Marshal marshals v into json bytes.
func Marshal(v interface{}) ([]byte, error) { func Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v) return json.Marshal(v)
} }
// Unmarshal unmarshals data bytes into v.
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
decoder := json.NewDecoder(bytes.NewReader(data)) decoder := json.NewDecoder(bytes.NewReader(data))
if err := unmarshalUseNumber(decoder, v); err != nil { if err := unmarshalUseNumber(decoder, v); err != nil {
@@ -21,6 +23,7 @@ func Unmarshal(data []byte, v interface{}) error {
return nil return nil
} }
// UnmarshalFromString unmarshals v from str.
func UnmarshalFromString(str string, v interface{}) error { func UnmarshalFromString(str string, v interface{}) error {
decoder := json.NewDecoder(strings.NewReader(str)) decoder := json.NewDecoder(strings.NewReader(str))
if err := unmarshalUseNumber(decoder, v); err != nil { if err := unmarshalUseNumber(decoder, v); err != nil {
@@ -30,6 +33,7 @@ func UnmarshalFromString(str string, v interface{}) error {
return nil return nil
} }
// UnmarshalFromReader unmarshals v from reader.
func UnmarshalFromReader(reader io.Reader, v interface{}) error { func UnmarshalFromReader(reader io.Reader, v interface{}) error {
var buf strings.Builder var buf strings.Builder
teeReader := io.TeeReader(reader, &buf) teeReader := io.TeeReader(reader, &buf)

View File

@@ -1,8 +1,11 @@
package lang package lang
// Placeholder is a placeholder object that can be used globally.
var Placeholder PlaceholderType var Placeholder PlaceholderType
type ( type (
GenericType = interface{} // GenericType can be used to hold any type.
GenericType = interface{}
// PlaceholderType represents a placeholder type.
PlaceholderType = struct{} PlaceholderType = struct{}
) )

View File

@@ -27,9 +27,13 @@ end`
) )
const ( const (
// Unknown means not initialized state.
Unknown = iota Unknown = iota
// Allowed means allowed state.
Allowed Allowed
// HitQuota means this request exactly hit the quota.
HitQuota HitQuota
// OverQuota means passed the quota.
OverQuota OverQuota
internalOverQuota = 0 internalOverQuota = 0
@@ -37,11 +41,14 @@ const (
internalHitQuota = 2 internalHitQuota = 2
) )
// ErrUnknownCode is an error that represents unknown status code.
var ErrUnknownCode = errors.New("unknown status code") var ErrUnknownCode = errors.New("unknown status code")
type ( type (
LimitOption func(l *PeriodLimit) // PeriodOption defines the method to customize a PeriodLimit.
PeriodOption func(l *PeriodLimit)
// A PeriodLimit is used to limit requests during a period of time.
PeriodLimit struct { PeriodLimit struct {
period int period int
quota int quota int
@@ -51,8 +58,9 @@ type (
} }
) )
// NewPeriodLimit returns a PeriodLimit with given parameters.
func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
opts ...LimitOption) *PeriodLimit { opts ...PeriodOption) *PeriodLimit {
limiter := &PeriodLimit{ limiter := &PeriodLimit{
period: period, period: period,
quota: quota, quota: quota,
@@ -67,6 +75,7 @@ func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string
return limiter return limiter
} }
// Take requests a permit, it returns the permit state.
func (h *PeriodLimit) Take(key string) (int, error) { func (h *PeriodLimit) Take(key string) (int, error) {
resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{ resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
strconv.Itoa(h.quota), strconv.Itoa(h.quota),
@@ -97,12 +106,13 @@ func (h *PeriodLimit) calcExpireSeconds() int {
if h.align { if h.align {
unix := time.Now().Unix() + zoneDiff unix := time.Now().Unix() + zoneDiff
return h.period - int(unix%int64(h.period)) return h.period - int(unix%int64(h.period))
} else {
return h.period
} }
return h.period
} }
func Align() LimitOption { // Align returns a func to customize a PeriodLimit with alignment.
func Align() PeriodOption {
return func(l *PeriodLimit) { return func(l *PeriodLimit) {
l.align = true l.align = true
} }

View File

@@ -33,7 +33,7 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) {
assert.Equal(t, 0, val) assert.Equal(t, 0, val)
} }
func testPeriodLimit(t *testing.T, opts ...LimitOption) { func testPeriodLimit(t *testing.T, opts ...PeriodOption) {
store, clean, err := redistest.CreateRedis() store, clean, err := redistest.CreateRedis()
assert.Nil(t, err) assert.Nil(t, err)
defer clean() defer clean()

View File

@@ -26,6 +26,7 @@ const (
) )
var ( var (
// ErrServiceOverloaded is returned by Shedder.Allow when the service is overloaded.
ErrServiceOverloaded = errors.New("service overloaded") ErrServiceOverloaded = errors.New("service overloaded")
// default to be enabled // default to be enabled
@@ -37,15 +38,22 @@ var (
) )
type ( type (
// A Promise interface is returned by Shedder.Allow to let callers tell
// whether the processing request is successful or not.
Promise interface { Promise interface {
// Pass lets the caller tell that the call is successful.
Pass() Pass()
// Fail lets the caller tell that the call is failed.
Fail() Fail()
} }
// Shedder is the interface that wraps the Allow method.
Shedder interface { Shedder interface {
// Allow returns the Promise if allowed, otherwise ErrServiceOverloaded.
Allow() (Promise, error) Allow() (Promise, error)
} }
// ShedderOption lets caller customize the Shedder.
ShedderOption func(opts *shedderOptions) ShedderOption func(opts *shedderOptions)
shedderOptions struct { shedderOptions struct {
@@ -67,10 +75,13 @@ type (
} }
) )
// Disable lets callers disable load shedding.
func Disable() { func Disable() {
enabled.Set(false) enabled.Set(false)
} }
// NewAdaptiveShedder returns an adaptive shedder.
// opts can be used to customize the Shedder.
func NewAdaptiveShedder(opts ...ShedderOption) Shedder { func NewAdaptiveShedder(opts ...ShedderOption) Shedder {
if !enabled.True() { if !enabled.True() {
return newNopShedder() return newNopShedder()
@@ -97,6 +108,7 @@ func NewAdaptiveShedder(opts ...ShedderOption) Shedder {
} }
} }
// Allow implements Shedder.Allow.
func (as *adaptiveShedder) Allow() (Promise, error) { func (as *adaptiveShedder) Allow() (Promise, error) {
if as.shouldDrop() { if as.shouldDrop() {
as.dropTime.Set(timex.Now()) as.dropTime.Set(timex.Now())
@@ -213,18 +225,21 @@ func (as *adaptiveShedder) systemOverloaded() bool {
return systemOverloadChecker(as.cpuThreshold) return systemOverloadChecker(as.cpuThreshold)
} }
// WithBuckets customizes the Shedder with given number of buckets.
func WithBuckets(buckets int) ShedderOption { func WithBuckets(buckets int) ShedderOption {
return func(opts *shedderOptions) { return func(opts *shedderOptions) {
opts.buckets = buckets opts.buckets = buckets
} }
} }
// WithCpuThreshold customizes the Shedder with given cpu threshold.
func WithCpuThreshold(threshold int64) ShedderOption { func WithCpuThreshold(threshold int64) ShedderOption {
return func(opts *shedderOptions) { return func(opts *shedderOptions) {
opts.cpuThreshold = threshold opts.cpuThreshold = threshold
} }
} }
// WithWindow customizes the Shedder with given
func WithWindow(window time.Duration) ShedderOption { func WithWindow(window time.Duration) ShedderOption {
return func(opts *shedderOptions) { return func(opts *shedderOptions) {
opts.window = window opts.window = window

View File

@@ -6,11 +6,13 @@ import (
"github.com/tal-tech/go-zero/core/syncx" "github.com/tal-tech/go-zero/core/syncx"
) )
// A ShedderGroup is a manager to manage key based shedders.
type ShedderGroup struct { type ShedderGroup struct {
options []ShedderOption options []ShedderOption
manager *syncx.ResourceManager manager *syncx.ResourceManager
} }
// NewShedderGroup returns a ShedderGroup.
func NewShedderGroup(opts ...ShedderOption) *ShedderGroup { func NewShedderGroup(opts ...ShedderOption) *ShedderGroup {
return &ShedderGroup{ return &ShedderGroup{
options: opts, options: opts,
@@ -18,6 +20,7 @@ func NewShedderGroup(opts ...ShedderOption) *ShedderGroup {
} }
} }
// GetShedder gets the Shedder for the given key.
func (g *ShedderGroup) GetShedder(key string) Shedder { func (g *ShedderGroup) GetShedder(key string) Shedder {
shedder, _ := g.manager.GetResource(key, func() (closer io.Closer, e error) { shedder, _ := g.manager.GetResource(key, func() (closer io.Closer, e error) {
return nopCloser{ return nopCloser{

View File

@@ -9,6 +9,7 @@ import (
) )
type ( type (
// A SheddingStat is used to store the statistics for load shedding.
SheddingStat struct { SheddingStat struct {
name string name string
total int64 total int64
@@ -23,6 +24,7 @@ type (
} }
) )
// NewSheddingStat returns a SheddingStat.
func NewSheddingStat(name string) *SheddingStat { func NewSheddingStat(name string) *SheddingStat {
st := &SheddingStat{ st := &SheddingStat{
name: name, name: name,
@@ -31,14 +33,17 @@ func NewSheddingStat(name string) *SheddingStat {
return st return st
} }
// IncrementTotal increments the total requests.
func (s *SheddingStat) IncrementTotal() { func (s *SheddingStat) IncrementTotal() {
atomic.AddInt64(&s.total, 1) atomic.AddInt64(&s.total, 1)
} }
// IncrementPass increments the passed requests.
func (s *SheddingStat) IncrementPass() { func (s *SheddingStat) IncrementPass() {
atomic.AddInt64(&s.pass, 1) atomic.AddInt64(&s.pass, 1)
} }
// IncrementDrop increments the dropped requests.
func (s *SheddingStat) IncrementDrop() { func (s *SheddingStat) IncrementDrop() {
atomic.AddInt64(&s.drop, 1) atomic.AddInt64(&s.drop, 1)
} }

View File

@@ -1,8 +1,10 @@
package logx package logx
// A LogConf is a logging config.
type LogConf struct { type LogConf struct {
ServiceName string `json:",optional"` ServiceName string `json:",optional"`
Mode string `json:",default=console,options=console|file|volume"` Mode string `json:",default=console,options=console|file|volume"`
TimeFormat string `json:",optional"`
Path string `json:",default=logs"` Path string `json:",default=logs"`
Level string `json:",default=info,options=info|error|severe"` Level string `json:",default=info,options=info|error|severe"`
Compress bool `json:",optional"` Compress bool `json:",optional"`

View File

@@ -12,6 +12,7 @@ const durationCallerDepth = 3
type durationLogger logEntry type durationLogger logEntry
// WithDuration returns a Logger which logs the given duration.
func WithDuration(d time.Duration) Logger { func WithDuration(d time.Duration) Logger {
return &durationLogger{ return &durationLogger{
Duration: timex.ReprOfDuration(d), Duration: timex.ReprOfDuration(d),

View File

@@ -1,21 +1,25 @@
package logx package logx
// A LessLogger is a logger that control to log once during the given duration.
type LessLogger struct { type LessLogger struct {
*limitedExecutor *limitedExecutor
} }
// NewLessLogger returns a LessLogger.
func NewLessLogger(milliseconds int) *LessLogger { func NewLessLogger(milliseconds int) *LessLogger {
return &LessLogger{ return &LessLogger{
limitedExecutor: newLimitedExecutor(milliseconds), limitedExecutor: newLimitedExecutor(milliseconds),
} }
} }
// Error logs v into error log or discard it if more than once in the given duration.
func (logger *LessLogger) Error(v ...interface{}) { func (logger *LessLogger) Error(v ...interface{}) {
logger.logOrDiscard(func() { logger.logOrDiscard(func() {
Error(v...) Error(v...)
}) })
} }
// Errorf logs v with format into error log or discard it if more than once in the given duration.
func (logger *LessLogger) Errorf(format string, v ...interface{}) { func (logger *LessLogger) Errorf(format string, v ...interface{}) {
logger.logOrDiscard(func() { logger.logOrDiscard(func() {
Errorf(format, v...) Errorf(format, v...)

View File

@@ -7,7 +7,7 @@ type lessWriter struct {
writer io.Writer writer io.Writer
} }
func NewLessWriter(writer io.Writer, milliseconds int) *lessWriter { func newLessWriter(writer io.Writer, milliseconds int) *lessWriter {
return &lessWriter{ return &lessWriter{
limitedExecutor: newLimitedExecutor(milliseconds), limitedExecutor: newLimitedExecutor(milliseconds),
writer: writer, writer: writer,

View File

@@ -9,7 +9,7 @@ import (
func TestLessWriter(t *testing.T) { func TestLessWriter(t *testing.T) {
var builder strings.Builder var builder strings.Builder
w := NewLessWriter(&builder, 500) w := newLessWriter(&builder, 500)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
_, err := w.Write([]byte("hello")) _, err := w.Write([]byte("hello"))
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -32,8 +32,6 @@ const (
) )
const ( const (
timeFormat = "2006-01-02T15:04:05.000Z07"
accessFilename = "access.log" accessFilename = "access.log"
errorFilename = "error.log" errorFilename = "error.log"
severeFilename = "severe.log" severeFilename = "severe.log"
@@ -57,10 +55,14 @@ const (
) )
var ( var (
ErrLogPathNotSet = errors.New("log path must be set") // ErrLogPathNotSet is an error that indicates the log path is not set.
ErrLogNotInitialized = errors.New("log not initialized") ErrLogPathNotSet = errors.New("log path must be set")
// ErrLogNotInitialized is an error that log is not initialized.
ErrLogNotInitialized = errors.New("log not initialized")
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
ErrLogServiceNameNotSet = errors.New("log service name must be set") ErrLogServiceNameNotSet = errors.New("log service name must be set")
timeFormat = "2006-01-02T15:04:05.000Z07"
writeConsole bool writeConsole bool
logLevel uint32 logLevel uint32
infoLog io.WriteCloser infoLog io.WriteCloser
@@ -89,8 +91,10 @@ type (
keepDays int keepDays int
} }
// LogOption defines the method to customize the logging.
LogOption func(options *logOptions) LogOption func(options *logOptions)
// A Logger represents a logger.
Logger interface { Logger interface {
Error(...interface{}) Error(...interface{})
Errorf(string, ...interface{}) Errorf(string, ...interface{})
@@ -102,6 +106,7 @@ type (
} }
) )
// MustSetup sets up logging with given config c. It exits on error.
func MustSetup(c LogConf) { func MustSetup(c LogConf) {
Must(SetUp(c)) Must(SetUp(c))
} }
@@ -111,6 +116,10 @@ func MustSetup(c LogConf) {
// we need to allow different service frameworks to initialize logx respectively. // we need to allow different service frameworks to initialize logx respectively.
// the same logic for SetUp // the same logic for SetUp
func SetUp(c LogConf) error { func SetUp(c LogConf) error {
if len(c.TimeFormat) > 0 {
timeFormat = c.TimeFormat
}
switch c.Mode { switch c.Mode {
case consoleMode: case consoleMode:
setupWithConsole(c) setupWithConsole(c)
@@ -122,10 +131,12 @@ func SetUp(c LogConf) error {
} }
} }
// Alert alerts v in alert level, and the message is written to error log.
func Alert(v string) { func Alert(v string) {
output(errorLog, levelAlert, v) output(errorLog, levelAlert, v)
} }
// Close closes the logging.
func Close() error { func Close() error {
if writeConsole { if writeConsole {
return nil return nil
@@ -170,6 +181,7 @@ func Close() error {
return nil return nil
} }
// Disable disables the logging.
func Disable() { func Disable() {
once.Do(func() { once.Do(func() {
atomic.StoreUint32(&initialized, 1) atomic.StoreUint32(&initialized, 1)
@@ -183,40 +195,49 @@ func Disable() {
}) })
} }
// Error writes v into error log.
func Error(v ...interface{}) { func Error(v ...interface{}) {
ErrorCaller(1, v...) ErrorCaller(1, v...)
} }
// Errorf writes v with format into error log.
func Errorf(format string, v ...interface{}) { func Errorf(format string, v ...interface{}) {
ErrorCallerf(1, format, v...) ErrorCallerf(1, format, v...)
} }
// ErrorCaller writes v with context into error log.
func ErrorCaller(callDepth int, v ...interface{}) { func ErrorCaller(callDepth int, v ...interface{}) {
errorSync(fmt.Sprint(v...), callDepth+callerInnerDepth) errorSync(fmt.Sprint(v...), callDepth+callerInnerDepth)
} }
// ErrorCallerf writes v with context in format into error log.
func ErrorCallerf(callDepth int, format string, v ...interface{}) { func ErrorCallerf(callDepth int, format string, v ...interface{}) {
errorSync(fmt.Sprintf(format, v...), callDepth+callerInnerDepth) errorSync(fmt.Sprintf(format, v...), callDepth+callerInnerDepth)
} }
// ErrorStack writes v along with call stack into error log.
func ErrorStack(v ...interface{}) { func ErrorStack(v ...interface{}) {
// there is newline in stack string // there is newline in stack string
stackSync(fmt.Sprint(v...)) stackSync(fmt.Sprint(v...))
} }
// ErrorStackf writes v along with call stack in format into error log.
func ErrorStackf(format string, v ...interface{}) { func ErrorStackf(format string, v ...interface{}) {
// there is newline in stack string // there is newline in stack string
stackSync(fmt.Sprintf(format, v...)) stackSync(fmt.Sprintf(format, v...))
} }
// Info writes v into access log.
func Info(v ...interface{}) { func Info(v ...interface{}) {
infoSync(fmt.Sprint(v...)) infoSync(fmt.Sprint(v...))
} }
// Infof writes v with format into access log.
func Infof(format string, v ...interface{}) { func Infof(format string, v ...interface{}) {
infoSync(fmt.Sprintf(format, v...)) infoSync(fmt.Sprintf(format, v...))
} }
// Must checks if err is nil, otherwise logs the err and exits.
func Must(err error) { func Must(err error) {
if err != nil { if err != nil {
msg := formatWithCaller(err.Error(), 3) msg := formatWithCaller(err.Error(), 3)
@@ -226,46 +247,56 @@ func Must(err error) {
} }
} }
// SetLevel sets the logging level. It can be used to suppress some logs.
func SetLevel(level uint32) { func SetLevel(level uint32) {
atomic.StoreUint32(&logLevel, level) atomic.StoreUint32(&logLevel, level)
} }
// Severe writes v into severe log.
func Severe(v ...interface{}) { func Severe(v ...interface{}) {
severeSync(fmt.Sprint(v...)) severeSync(fmt.Sprint(v...))
} }
// Severef writes v with format into severe log.
func Severef(format string, v ...interface{}) { func Severef(format string, v ...interface{}) {
severeSync(fmt.Sprintf(format, v...)) severeSync(fmt.Sprintf(format, v...))
} }
// Slow writes v into slow log.
func Slow(v ...interface{}) { func Slow(v ...interface{}) {
slowSync(fmt.Sprint(v...)) slowSync(fmt.Sprint(v...))
} }
// Slowf writes v with format into slow log.
func Slowf(format string, v ...interface{}) { func Slowf(format string, v ...interface{}) {
slowSync(fmt.Sprintf(format, v...)) slowSync(fmt.Sprintf(format, v...))
} }
// Stat writes v into stat log.
func Stat(v ...interface{}) { func Stat(v ...interface{}) {
statSync(fmt.Sprint(v...)) statSync(fmt.Sprint(v...))
} }
// Statf writes v with format into stat log.
func Statf(format string, v ...interface{}) { func Statf(format string, v ...interface{}) {
statSync(fmt.Sprintf(format, v...)) statSync(fmt.Sprintf(format, v...))
} }
// WithCooldownMillis customizes logging on writing call stack interval.
func WithCooldownMillis(millis int) LogOption { func WithCooldownMillis(millis int) LogOption {
return func(opts *logOptions) { return func(opts *logOptions) {
opts.logStackCooldownMills = millis opts.logStackCooldownMills = millis
} }
} }
// WithKeepDays customizes logging to keep logs with days.
func WithKeepDays(days int) LogOption { func WithKeepDays(days int) LogOption {
return func(opts *logOptions) { return func(opts *logOptions) {
opts.keepDays = days opts.keepDays = days
} }
} }
// WithGzip customizes logging to automatically gzip the log files.
func WithGzip() LogOption { func WithGzip() LogOption {
return func(opts *logOptions) { return func(opts *logOptions) {
opts.gzipEnabled = true opts.gzipEnabled = true
@@ -382,7 +413,7 @@ func setupWithConsole(c LogConf) {
errorLog = newLogWriter(log.New(os.Stderr, "", flags)) errorLog = newLogWriter(log.New(os.Stderr, "", flags))
severeLog = newLogWriter(log.New(os.Stderr, "", flags)) severeLog = newLogWriter(log.New(os.Stderr, "", flags))
slowLog = newLogWriter(log.New(os.Stderr, "", flags)) slowLog = newLogWriter(log.New(os.Stderr, "", flags))
stackLog = NewLessWriter(errorLog, options.logStackCooldownMills) stackLog = newLessWriter(errorLog, options.logStackCooldownMills)
statLog = infoLog statLog = infoLog
}) })
} }
@@ -434,7 +465,7 @@ func setupWithFiles(c LogConf) error {
return return
} }
stackLog = NewLessWriter(errorLog, options.logStackCooldownMills) stackLog = newLessWriter(errorLog, options.logStackCooldownMills)
}) })
return err return err

View File

@@ -26,9 +26,11 @@ const (
defaultFileMode = 0600 defaultFileMode = 0600
) )
// ErrLogFileClosed is an error that indicates the log file is already closed.
var ErrLogFileClosed = errors.New("error: log file closed") var ErrLogFileClosed = errors.New("error: log file closed")
type ( type (
// A RotateRule interface is used to define the log rotating rules.
RotateRule interface { RotateRule interface {
BackupFileName() string BackupFileName() string
MarkRotated() MarkRotated()
@@ -36,6 +38,7 @@ type (
ShallRotate() bool ShallRotate() bool
} }
// A RotateLogger is a Logger that can rotate log files with given rules.
RotateLogger struct { RotateLogger struct {
filename string filename string
backup string backup string
@@ -50,6 +53,7 @@ type (
closeOnce sync.Once closeOnce sync.Once
} }
// A DailyRotateRule is a rule to daily rotate the log files.
DailyRotateRule struct { DailyRotateRule struct {
rotatedTime string rotatedTime string
filename string filename string
@@ -59,6 +63,7 @@ type (
} }
) )
// DefaultRotateRule is a default log rotating rule, currently DailyRotateRule.
func DefaultRotateRule(filename, delimiter string, days int, gzip bool) RotateRule { func DefaultRotateRule(filename, delimiter string, days int, gzip bool) RotateRule {
return &DailyRotateRule{ return &DailyRotateRule{
rotatedTime: getNowDate(), rotatedTime: getNowDate(),
@@ -69,14 +74,17 @@ func DefaultRotateRule(filename, delimiter string, days int, gzip bool) RotateRu
} }
} }
// BackupFileName returns the backup filename on rotating.
func (r *DailyRotateRule) BackupFileName() string { func (r *DailyRotateRule) BackupFileName() string {
return fmt.Sprintf("%s%s%s", r.filename, r.delimiter, getNowDate()) return fmt.Sprintf("%s%s%s", r.filename, r.delimiter, getNowDate())
} }
// MarkRotated marks the rotated time of r to be the current time.
func (r *DailyRotateRule) MarkRotated() { func (r *DailyRotateRule) MarkRotated() {
r.rotatedTime = getNowDate() r.rotatedTime = getNowDate()
} }
// OutdatedFiles returns the files that exceeded the keeping days.
func (r *DailyRotateRule) OutdatedFiles() []string { func (r *DailyRotateRule) OutdatedFiles() []string {
if r.days <= 0 { if r.days <= 0 {
return nil return nil
@@ -113,10 +121,12 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
return outdates return outdates
} }
// ShallRotate checks if the file should be rotated.
func (r *DailyRotateRule) ShallRotate() bool { func (r *DailyRotateRule) ShallRotate() bool {
return len(r.rotatedTime) > 0 && getNowDate() != r.rotatedTime return len(r.rotatedTime) > 0 && getNowDate() != r.rotatedTime
} }
// NewLogger returns a RotateLogger with given filename and rule, etc.
func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger, error) { func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger, error) {
l := &RotateLogger{ l := &RotateLogger{
filename: filename, filename: filename,
@@ -133,6 +143,7 @@ func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger,
return l, nil return l, nil
} }
// Close closes l.
func (l *RotateLogger) Close() error { func (l *RotateLogger) Close() error {
var err error var err error
@@ -163,9 +174,9 @@ func (l *RotateLogger) Write(data []byte) (int, error) {
func (l *RotateLogger) getBackupFilename() string { func (l *RotateLogger) getBackupFilename() string {
if len(l.backup) == 0 { if len(l.backup) == 0 {
return l.rule.BackupFileName() return l.rule.BackupFileName()
} else {
return l.backup
} }
return l.backup
} }
func (l *RotateLogger) init() error { func (l *RotateLogger) init() error {

View File

@@ -67,6 +67,7 @@ func (l *traceLogger) write(writer io.Writer, level, content string) {
outputJson(writer, l) outputJson(writer, l)
} }
// WithContext sets ctx to log, for keeping tracing information.
func WithContext(ctx context.Context) Logger { func WithContext(ctx context.Context) Logger {
return &traceLogger{ return &traceLogger{
ctx: ctx, ctx: ctx,

View File

@@ -13,8 +13,8 @@ import (
) )
const ( const (
mockTraceId = "mock-trace-id" mockTraceID = "mock-trace-id"
mockSpanId = "mock-span-id" mockSpanID = "mock-span-id"
) )
var mock tracespec.Trace = new(mockTrace) var mock tracespec.Trace = new(mockTrace)
@@ -24,8 +24,8 @@ func TestTraceLog(t *testing.T) {
atomic.StoreUint32(&initialized, 1) atomic.StoreUint32(&initialized, 1)
ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock) ctx := context.WithValue(context.Background(), tracespec.TracingKey, mock)
WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog) WithContext(ctx).(*traceLogger).write(&buf, levelInfo, testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
} }
func TestTraceError(t *testing.T) { func TestTraceError(t *testing.T) {
@@ -36,12 +36,12 @@ func TestTraceError(t *testing.T) {
l := WithContext(ctx).(*traceLogger) l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel) SetLevel(InfoLevel)
l.WithDuration(time.Second).Error(testlog) l.WithDuration(time.Second).Error(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
buf.Reset() buf.Reset()
l.WithDuration(time.Second).Errorf(testlog) l.WithDuration(time.Second).Errorf(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
} }
func TestTraceInfo(t *testing.T) { func TestTraceInfo(t *testing.T) {
@@ -52,12 +52,12 @@ func TestTraceInfo(t *testing.T) {
l := WithContext(ctx).(*traceLogger) l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel) SetLevel(InfoLevel)
l.WithDuration(time.Second).Info(testlog) l.WithDuration(time.Second).Info(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
buf.Reset() buf.Reset()
l.WithDuration(time.Second).Infof(testlog) l.WithDuration(time.Second).Infof(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
} }
func TestTraceSlow(t *testing.T) { func TestTraceSlow(t *testing.T) {
@@ -68,12 +68,12 @@ func TestTraceSlow(t *testing.T) {
l := WithContext(ctx).(*traceLogger) l := WithContext(ctx).(*traceLogger)
SetLevel(InfoLevel) SetLevel(InfoLevel)
l.WithDuration(time.Second).Slow(testlog) l.WithDuration(time.Second).Slow(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
buf.Reset() buf.Reset()
l.WithDuration(time.Second).Slowf(testlog) l.WithDuration(time.Second).Slowf(testlog)
assert.True(t, strings.Contains(buf.String(), mockTraceId)) assert.True(t, strings.Contains(buf.String(), mockTraceID))
assert.True(t, strings.Contains(buf.String(), mockSpanId)) assert.True(t, strings.Contains(buf.String(), mockSpanID))
} }
func TestTraceWithoutContext(t *testing.T) { func TestTraceWithoutContext(t *testing.T) {
@@ -83,22 +83,22 @@ func TestTraceWithoutContext(t *testing.T) {
l := WithContext(context.Background()).(*traceLogger) l := WithContext(context.Background()).(*traceLogger)
SetLevel(InfoLevel) SetLevel(InfoLevel)
l.WithDuration(time.Second).Info(testlog) l.WithDuration(time.Second).Info(testlog)
assert.False(t, strings.Contains(buf.String(), mockTraceId)) assert.False(t, strings.Contains(buf.String(), mockTraceID))
assert.False(t, strings.Contains(buf.String(), mockSpanId)) assert.False(t, strings.Contains(buf.String(), mockSpanID))
buf.Reset() buf.Reset()
l.WithDuration(time.Second).Infof(testlog) l.WithDuration(time.Second).Infof(testlog)
assert.False(t, strings.Contains(buf.String(), mockTraceId)) assert.False(t, strings.Contains(buf.String(), mockTraceID))
assert.False(t, strings.Contains(buf.String(), mockSpanId)) assert.False(t, strings.Contains(buf.String(), mockSpanID))
} }
type mockTrace struct{} type mockTrace struct{}
func (t mockTrace) TraceId() string { func (t mockTrace) TraceId() string {
return mockTraceId return mockTraceID
} }
func (t mockTrace) SpanId() string { func (t mockTrace) SpanId() string {
return mockSpanId return mockSpanID
} }
func (t mockTrace) Finish() { func (t mockTrace) Finish() {

View File

@@ -35,9 +35,9 @@ func (o *fieldOptionsWithContext) fromString() bool {
func (o *fieldOptionsWithContext) getDefault() (string, bool) { func (o *fieldOptionsWithContext) getDefault() (string, bool) {
if o == nil { if o == nil {
return "", false return "", false
} else {
return o.Default, len(o.Default) > 0
} }
return o.Default, len(o.Default) > 0
} }
func (o *fieldOptionsWithContext) optional() bool { func (o *fieldOptionsWithContext) optional() bool {
@@ -55,9 +55,9 @@ func (o *fieldOptionsWithContext) options() []string {
func (o *fieldOptions) optionalDep() string { func (o *fieldOptions) optionalDep() string {
if o == nil { if o == nil {
return "" return ""
} else {
return o.OptionalDep
} }
return o.OptionalDep
} }
func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName string) ( func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName string) (
@@ -77,29 +77,29 @@ func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName strin
_, selfOn := m.Value(key) _, selfOn := m.Value(key)
if baseOn == selfOn { if baseOn == selfOn {
return nil, fmt.Errorf("set value for either %q or %q in %q", dep, key, fullName) return nil, fmt.Errorf("set value for either %q or %q in %q", dep, key, fullName)
} else {
optional = baseOn
} }
optional = baseOn
} else { } else {
_, baseOn := m.Value(dep) _, baseOn := m.Value(dep)
_, selfOn := m.Value(key) _, selfOn := m.Value(key)
if baseOn != selfOn { if baseOn != selfOn {
return nil, fmt.Errorf("values for %q and %q should be both provided or both not in %q", return nil, fmt.Errorf("values for %q and %q should be both provided or both not in %q",
dep, key, fullName) dep, key, fullName)
} else {
optional = !baseOn
} }
optional = !baseOn
} }
} }
if o.fieldOptionsWithContext.Optional == optional { if o.fieldOptionsWithContext.Optional == optional {
return &o.fieldOptionsWithContext, nil return &o.fieldOptionsWithContext, nil
} else {
return &fieldOptionsWithContext{
FromString: o.FromString,
Optional: optional,
Options: o.Options,
Default: o.Default,
}, nil
} }
return &fieldOptionsWithContext{
FromString: o.FromString,
Optional: optional,
Options: o.Options,
Default: o.Default,
}, nil
} }

View File

@@ -10,10 +10,12 @@ const jsonTagKey = "json"
var jsonUnmarshaler = NewUnmarshaler(jsonTagKey) var jsonUnmarshaler = NewUnmarshaler(jsonTagKey)
// UnmarshalJsonBytes unmarshals content into v.
func UnmarshalJsonBytes(content []byte, v interface{}) error { func UnmarshalJsonBytes(content []byte, v interface{}) error {
return unmarshalJsonBytes(content, v, jsonUnmarshaler) return unmarshalJsonBytes(content, v, jsonUnmarshaler)
} }
// UnmarshalJsonReader unmarshals content from reader into v.
func UnmarshalJsonReader(reader io.Reader, v interface{}) error { func UnmarshalJsonReader(reader io.Reader, v interface{}) error {
return unmarshalJsonReader(reader, v, jsonUnmarshaler) return unmarshalJsonReader(reader, v, jsonUnmarshaler)
} }

View File

@@ -485,41 +485,41 @@ func TestUnmarshalBytesMap(t *testing.T) {
func TestUnmarshalBytesMapStruct(t *testing.T) { func TestUnmarshalBytesMapStruct(t *testing.T) {
var c struct { var c struct {
Persons map[string]struct { Persons map[string]struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`{"Persons": {"first": {"Id": 1, "name": "kevin"}}}`) content := []byte(`{"Persons": {"first": {"ID": 1, "name": "kevin"}}}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c)) assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"].Id) assert.Equal(t, 1, c.Persons["first"].ID)
assert.Equal(t, "kevin", c.Persons["first"].Name) assert.Equal(t, "kevin", c.Persons["first"].Name)
} }
func TestUnmarshalBytesMapStructPtr(t *testing.T) { func TestUnmarshalBytesMapStructPtr(t *testing.T) {
var c struct { var c struct {
Persons map[string]*struct { Persons map[string]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`{"Persons": {"first": {"Id": 1, "name": "kevin"}}}`) content := []byte(`{"Persons": {"first": {"ID": 1, "name": "kevin"}}}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c)) assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"].Id) assert.Equal(t, 1, c.Persons["first"].ID)
assert.Equal(t, "kevin", c.Persons["first"].Name) assert.Equal(t, "kevin", c.Persons["first"].Name)
} }
func TestUnmarshalBytesMapStructMissingPartial(t *testing.T) { func TestUnmarshalBytesMapStructMissingPartial(t *testing.T) {
var c struct { var c struct {
Persons map[string]*struct { Persons map[string]*struct {
Id int ID int
Name string Name string
} }
} }
content := []byte(`{"Persons": {"first": {"Id": 1}}}`) content := []byte(`{"Persons": {"first": {"ID": 1}}}`)
assert.NotNil(t, UnmarshalJsonBytes(content, &c)) assert.NotNil(t, UnmarshalJsonBytes(content, &c))
} }
@@ -527,21 +527,21 @@ func TestUnmarshalBytesMapStructMissingPartial(t *testing.T) {
func TestUnmarshalBytesMapStructOptional(t *testing.T) { func TestUnmarshalBytesMapStructOptional(t *testing.T) {
var c struct { var c struct {
Persons map[string]*struct { Persons map[string]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`{"Persons": {"first": {"Id": 1}}}`) content := []byte(`{"Persons": {"first": {"ID": 1}}}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c)) assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"].Id) assert.Equal(t, 1, c.Persons["first"].ID)
} }
func TestUnmarshalBytesMapEmptyStructSlice(t *testing.T) { func TestUnmarshalBytesMapEmptyStructSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]struct { Persons map[string][]struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
@@ -555,22 +555,22 @@ func TestUnmarshalBytesMapEmptyStructSlice(t *testing.T) {
func TestUnmarshalBytesMapStructSlice(t *testing.T) { func TestUnmarshalBytesMapStructSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]struct { Persons map[string][]struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`{"Persons": {"first": [{"Id": 1, "name": "kevin"}]}}`) content := []byte(`{"Persons": {"first": [{"ID": 1, "name": "kevin"}]}}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c)) assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"][0].Id) assert.Equal(t, 1, c.Persons["first"][0].ID)
assert.Equal(t, "kevin", c.Persons["first"][0].Name) assert.Equal(t, "kevin", c.Persons["first"][0].Name)
} }
func TestUnmarshalBytesMapEmptyStructPtrSlice(t *testing.T) { func TestUnmarshalBytesMapEmptyStructPtrSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
@@ -584,26 +584,26 @@ func TestUnmarshalBytesMapEmptyStructPtrSlice(t *testing.T) {
func TestUnmarshalBytesMapStructPtrSlice(t *testing.T) { func TestUnmarshalBytesMapStructPtrSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`{"Persons": {"first": [{"Id": 1, "name": "kevin"}]}}`) content := []byte(`{"Persons": {"first": [{"ID": 1, "name": "kevin"}]}}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c)) assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"][0].Id) assert.Equal(t, 1, c.Persons["first"][0].ID)
assert.Equal(t, "kevin", c.Persons["first"][0].Name) assert.Equal(t, "kevin", c.Persons["first"][0].Name)
} }
func TestUnmarshalBytesMapStructPtrSliceMissingPartial(t *testing.T) { func TestUnmarshalBytesMapStructPtrSliceMissingPartial(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string Name string
} }
} }
content := []byte(`{"Persons": {"first": [{"Id": 1}]}}`) content := []byte(`{"Persons": {"first": [{"ID": 1}]}}`)
assert.NotNil(t, UnmarshalJsonBytes(content, &c)) assert.NotNil(t, UnmarshalJsonBytes(content, &c))
} }
@@ -611,15 +611,15 @@ func TestUnmarshalBytesMapStructPtrSliceMissingPartial(t *testing.T) {
func TestUnmarshalBytesMapStructPtrSliceOptional(t *testing.T) { func TestUnmarshalBytesMapStructPtrSliceOptional(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`{"Persons": {"first": [{"Id": 1}]}}`) content := []byte(`{"Persons": {"first": [{"ID": 1}]}}`)
assert.Nil(t, UnmarshalJsonBytes(content, &c)) assert.Nil(t, UnmarshalJsonBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"][0].Id) assert.Equal(t, 1, c.Persons["first"][0].ID)
} }
func TestUnmarshalStructOptional(t *testing.T) { func TestUnmarshalStructOptional(t *testing.T) {

View File

@@ -33,23 +33,27 @@ var (
) )
type ( type (
// A Unmarshaler is used to unmarshal with given tag key.
Unmarshaler struct { Unmarshaler struct {
key string key string
opts unmarshalOptions opts unmarshalOptions
} }
// UnmarshalOption defines the method to customize a Unmarshaler.
UnmarshalOption func(*unmarshalOptions)
unmarshalOptions struct { unmarshalOptions struct {
fromString bool fromString bool
} }
keyCache map[string][]string keyCache map[string][]string
UnmarshalOption func(*unmarshalOptions)
) )
func init() { func init() {
cacheKeys.Store(make(keyCache)) cacheKeys.Store(make(keyCache))
} }
// NewUnmarshaler returns a Unmarshaler.
func NewUnmarshaler(key string, opts ...UnmarshalOption) *Unmarshaler { func NewUnmarshaler(key string, opts ...UnmarshalOption) *Unmarshaler {
unmarshaler := Unmarshaler{ unmarshaler := Unmarshaler{
key: key, key: key,
@@ -62,14 +66,17 @@ func NewUnmarshaler(key string, opts ...UnmarshalOption) *Unmarshaler {
return &unmarshaler return &unmarshaler
} }
// UnmarshalKey unmarshals m into v with tag key.
func UnmarshalKey(m map[string]interface{}, v interface{}) error { func UnmarshalKey(m map[string]interface{}, v interface{}) error {
return keyUnmarshaler.Unmarshal(m, v) return keyUnmarshaler.Unmarshal(m, v)
} }
// Unmarshal unmarshals m into v.
func (u *Unmarshaler) Unmarshal(m map[string]interface{}, v interface{}) error { func (u *Unmarshaler) Unmarshal(m map[string]interface{}, v interface{}) error {
return u.UnmarshalValuer(MapValuer(m), v) return u.UnmarshalValuer(MapValuer(m), v)
} }
// UnmarshalValuer unmarshals m into v.
func (u *Unmarshaler) UnmarshalValuer(m Valuer, v interface{}) error { func (u *Unmarshaler) UnmarshalValuer(m Valuer, v interface{}) error {
return u.unmarshalWithFullName(m, v, "") return u.unmarshalWithFullName(m, v, "")
} }
@@ -114,9 +121,9 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
if options.optional() { if options.optional() {
return u.processAnonymousFieldOptional(field, value, key, m, fullName) return u.processAnonymousFieldOptional(field, value, key, m, fullName)
} else {
return u.processAnonymousFieldRequired(field, value, m, fullName)
} }
return u.processAnonymousFieldRequired(field, value, m, fullName)
} }
func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processAnonymousFieldOptional(field reflect.StructField, value reflect.Value,
@@ -184,9 +191,9 @@ func (u *Unmarshaler) processField(field reflect.StructField, value reflect.Valu
if field.Anonymous { if field.Anonymous {
return u.processAnonymousField(field, value, m, fullName) return u.processAnonymousField(field, value, m, fullName)
} else {
return u.processNamedField(field, value, m, fullName)
} }
return u.processNamedField(field, value, m, fullName)
} }
func (u *Unmarshaler) processFieldNotFromString(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processFieldNotFromString(field reflect.StructField, value reflect.Value,
@@ -200,7 +207,7 @@ func (u *Unmarshaler) processFieldNotFromString(field reflect.StructField, value
case valueKind == reflect.Map && typeKind == reflect.Struct: case valueKind == reflect.Map && typeKind == reflect.Struct:
return u.processFieldStruct(field, value, mapValue, fullName) return u.processFieldStruct(field, value, mapValue, fullName)
case valueKind == reflect.String && typeKind == reflect.Slice: case valueKind == reflect.String && typeKind == reflect.Slice:
return u.fillSliceFromString(fieldType, value, mapValue, fullName) return u.fillSliceFromString(fieldType, value, mapValue)
case valueKind == reflect.String && derefedFieldType == durationType: case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType.Kind(), value, mapValue.(string)) return fillDurationValue(fieldType.Kind(), value, mapValue.(string))
default: default:
@@ -319,9 +326,9 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
mapValue, hasValue := getValue(m, key) mapValue, hasValue := getValue(m, key)
if hasValue { if hasValue {
return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName) return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName)
} else {
return u.processNamedFieldWithoutValue(field, value, opts, fullName)
} }
return u.processNamedFieldWithoutValue(field, value, opts, fullName)
} }
func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, value reflect.Value, func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, value reflect.Value,
@@ -329,9 +336,9 @@ func (u *Unmarshaler) processNamedFieldWithValue(field reflect.StructField, valu
if mapValue == nil { if mapValue == nil {
if opts.optional() { if opts.optional() {
return nil return nil
} else {
return fmt.Errorf("field %s mustn't be nil", key)
} }
return fmt.Errorf("field %s mustn't be nil", key)
} }
maybeNewValue(field, value) maybeNewValue(field, value)
@@ -464,8 +471,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
return nil return nil
} }
func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value, func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value, mapValue interface{}) error {
mapValue interface{}, fullName string) error {
var slice []interface{} var slice []interface{}
if err := jsonx.UnmarshalFromString(mapValue.(string), &slice); err != nil { if err := jsonx.UnmarshalFromString(mapValue.(string), &slice); err != nil {
return err return err
@@ -591,6 +597,7 @@ func (u *Unmarshaler) parseOptionsWithContext(field reflect.StructField, m Value
return key, optsWithContext, nil return key, optsWithContext, nil
} }
// WithStringValues customizes a Unmarshaler with number values from strings.
func WithStringValues() UnmarshalOption { func WithStringValues() UnmarshalOption {
return func(opt *unmarshalOptions) { return func(opt *unmarshalOptions) {
opt.fromString = true opt.fromString = true

View File

@@ -754,13 +754,13 @@ func TestUnmarshalJsonNumberInt64(t *testing.T) {
strValue := strconv.FormatInt(intValue, 10) strValue := strconv.FormatInt(intValue, 10)
var number = json.Number(strValue) var number = json.Number(strValue)
m := map[string]interface{}{ m := map[string]interface{}{
"Id": number, "ID": number,
} }
var v struct { var v struct {
Id int64 ID int64
} }
assert.Nil(t, UnmarshalKey(m, &v)) assert.Nil(t, UnmarshalKey(m, &v))
assert.Equal(t, intValue, v.Id) assert.Equal(t, intValue, v.ID)
} }
} }
@@ -770,13 +770,13 @@ func TestUnmarshalJsonNumberUint64(t *testing.T) {
strValue := strconv.FormatUint(intValue, 10) strValue := strconv.FormatUint(intValue, 10)
var number = json.Number(strValue) var number = json.Number(strValue)
m := map[string]interface{}{ m := map[string]interface{}{
"Id": number, "ID": number,
} }
var v struct { var v struct {
Id uint64 ID uint64
} }
assert.Nil(t, UnmarshalKey(m, &v)) assert.Nil(t, UnmarshalKey(m, &v))
assert.Equal(t, intValue, v.Id) assert.Equal(t, intValue, v.ID)
} }
} }
@@ -786,15 +786,15 @@ func TestUnmarshalJsonNumberUint64Ptr(t *testing.T) {
strValue := strconv.FormatUint(intValue, 10) strValue := strconv.FormatUint(intValue, 10)
var number = json.Number(strValue) var number = json.Number(strValue)
m := map[string]interface{}{ m := map[string]interface{}{
"Id": number, "ID": number,
} }
var v struct { var v struct {
Id *uint64 ID *uint64
} }
ast := assert.New(t) ast := assert.New(t)
ast.Nil(UnmarshalKey(m, &v)) ast.Nil(UnmarshalKey(m, &v))
ast.NotNil(v.Id) ast.NotNil(v.ID)
ast.Equal(intValue, *v.Id) ast.Equal(intValue, *v.ID)
} }
} }
@@ -1061,38 +1061,38 @@ func TestUnmarshalWithOptionsAndSet(t *testing.T) {
func TestUnmarshalNestedKey(t *testing.T) { func TestUnmarshalNestedKey(t *testing.T) {
var c struct { var c struct {
Id int `json:"Persons.first.Id"` ID int `json:"Persons.first.ID"`
} }
m := map[string]interface{}{ m := map[string]interface{}{
"Persons": map[string]interface{}{ "Persons": map[string]interface{}{
"first": map[string]interface{}{ "first": map[string]interface{}{
"Id": 1, "ID": 1,
}, },
}, },
} }
assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c))
assert.Equal(t, 1, c.Id) assert.Equal(t, 1, c.ID)
} }
func TestUnmarhsalNestedKeyArray(t *testing.T) { func TestUnmarhsalNestedKeyArray(t *testing.T) {
var c struct { var c struct {
First []struct { First []struct {
Id int ID int
} `json:"Persons.first"` } `json:"Persons.first"`
} }
m := map[string]interface{}{ m := map[string]interface{}{
"Persons": map[string]interface{}{ "Persons": map[string]interface{}{
"first": []map[string]interface{}{ "first": []map[string]interface{}{
{"Id": 1}, {"ID": 1},
{"Id": 2}, {"ID": 2},
}, },
}, },
} }
assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c)) assert.Nil(t, NewUnmarshaler("json").Unmarshal(m, &c))
assert.Equal(t, 2, len(c.First)) assert.Equal(t, 2, len(c.First))
assert.Equal(t, 1, c.First[0].Id) assert.Equal(t, 1, c.First[0].ID)
} }
func TestUnmarshalAnonymousOptionalRequiredProvided(t *testing.T) { func TestUnmarshalAnonymousOptionalRequiredProvided(t *testing.T) {

View File

@@ -45,6 +45,7 @@ type (
} }
) )
// Deref dereferences a type, if pointer type, returns its element type.
func Deref(t reflect.Type) reflect.Type { func Deref(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
t = t.Elem() t = t.Elem()
@@ -53,6 +54,7 @@ func Deref(t reflect.Type) reflect.Type {
return t return t
} }
// Repr returns the string representation of v.
func Repr(v interface{}) string { func Repr(v interface{}) string {
if v == nil { if v == nil {
return "" return ""
@@ -69,46 +71,10 @@ func Repr(v interface{}) string {
val = val.Elem() val = val.Elem()
} }
switch vt := val.Interface().(type) { return reprOfValue(val)
case bool:
return strconv.FormatBool(vt)
case error:
return vt.Error()
case float32:
return strconv.FormatFloat(float64(vt), 'f', -1, 32)
case float64:
return strconv.FormatFloat(vt, 'f', -1, 64)
case fmt.Stringer:
return vt.String()
case int:
return strconv.Itoa(vt)
case int8:
return strconv.Itoa(int(vt))
case int16:
return strconv.Itoa(int(vt))
case int32:
return strconv.Itoa(int(vt))
case int64:
return strconv.FormatInt(vt, 10)
case string:
return vt
case uint:
return strconv.FormatUint(uint64(vt), 10)
case uint8:
return strconv.FormatUint(uint64(vt), 10)
case uint16:
return strconv.FormatUint(uint64(vt), 10)
case uint32:
return strconv.FormatUint(uint64(vt), 10)
case uint64:
return strconv.FormatUint(vt, 10)
case []byte:
return string(vt)
default:
return fmt.Sprint(val.Interface())
}
} }
// ValidatePtr validates v if it's a valid pointer.
func ValidatePtr(v *reflect.Value) error { func ValidatePtr(v *reflect.Value) error {
// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr, // sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
// panic otherwise // panic otherwise
@@ -124,23 +90,26 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) {
case reflect.Bool: case reflect.Bool:
return str == "1" || strings.ToLower(str) == "true", nil return str == "1" || strings.ToLower(str) == "true", nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if intValue, err := strconv.ParseInt(str, 10, 64); err != nil { intValue, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as int", str) return 0, fmt.Errorf("the value %q cannot parsed as int", str)
} else {
return intValue, nil
} }
return intValue, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if uintValue, err := strconv.ParseUint(str, 10, 64); err != nil { uintValue, err := strconv.ParseUint(str, 10, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as uint", str) return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
} else {
return uintValue, nil
} }
return uintValue, nil
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
if floatValue, err := strconv.ParseFloat(str, 64); err != nil { floatValue, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as float", str) return 0, fmt.Errorf("the value %q cannot parsed as float", str)
} else {
return floatValue, nil
} }
return floatValue, nil
case reflect.String: case reflect.String:
return str, nil return str, nil
default: default:
@@ -160,46 +129,8 @@ func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fie
var fieldOpts fieldOptions var fieldOpts fieldOptions
for _, segment := range options { for _, segment := range options {
option := strings.TrimSpace(segment) option := strings.TrimSpace(segment)
switch { if err := parseOption(&fieldOpts, field.Name, option); err != nil {
case option == stringOption: return "", nil, err
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return "", nil, fmt.Errorf("field %s has wrong optional", field.Name)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong options", field.Name)
} else {
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
}
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong default option", field.Name)
} else {
fieldOpts.Default = strings.TrimSpace(segs[1])
}
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return "", nil, fmt.Errorf("field %s has wrong range", field.Name)
}
if nr, err := parseNumberRange(segs[1]); err != nil {
return "", nil, err
} else {
fieldOpts.Range = nr
}
} }
} }
@@ -343,6 +274,95 @@ func parseNumberRange(str string) (*numberRange, error) {
}, nil }, nil
} }
func parseOption(fieldOpts *fieldOptions, fieldName string, option string) error {
switch {
case option == stringOption:
fieldOpts.FromString = true
case strings.HasPrefix(option, optionalOption):
segs := strings.Split(option, equalToken)
switch len(segs) {
case 1:
fieldOpts.Optional = true
case 2:
fieldOpts.Optional = true
fieldOpts.OptionalDep = segs[1]
default:
return fmt.Errorf("field %s has wrong optional", fieldName)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return fmt.Errorf("field %s has wrong options", fieldName)
}
fieldOpts.Options = strings.Split(segs[1], optionSeparator)
case strings.HasPrefix(option, defaultOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return fmt.Errorf("field %s has wrong default option", fieldName)
}
fieldOpts.Default = strings.TrimSpace(segs[1])
case strings.HasPrefix(option, rangeOption):
segs := strings.Split(option, equalToken)
if len(segs) != 2 {
return fmt.Errorf("field %s has wrong range", fieldName)
}
nr, err := parseNumberRange(segs[1])
if err != nil {
return err
}
fieldOpts.Range = nr
}
return nil
}
func reprOfValue(val reflect.Value) string {
switch vt := val.Interface().(type) {
case bool:
return strconv.FormatBool(vt)
case error:
return vt.Error()
case float32:
return strconv.FormatFloat(float64(vt), 'f', -1, 32)
case float64:
return strconv.FormatFloat(vt, 'f', -1, 64)
case fmt.Stringer:
return vt.String()
case int:
return strconv.Itoa(vt)
case int8:
return strconv.Itoa(int(vt))
case int16:
return strconv.Itoa(int(vt))
case int32:
return strconv.Itoa(int(vt))
case int64:
return strconv.FormatInt(vt, 10)
case string:
return vt
case uint:
return strconv.FormatUint(uint64(vt), 10)
case uint8:
return strconv.FormatUint(uint64(vt), 10)
case uint16:
return strconv.FormatUint(uint64(vt), 10)
case uint32:
return strconv.FormatUint(uint64(vt), 10)
case uint64:
return strconv.FormatUint(vt, 10)
case []byte:
return string(vt)
default:
return fmt.Sprint(val.Interface())
}
}
func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error { func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
switch kind { switch kind {
case reflect.Bool: case reflect.Bool:

View File

@@ -1,13 +1,17 @@
package mapping package mapping
type ( type (
// A Valuer interface defines the way to get values from the underlying object with keys.
Valuer interface { Valuer interface {
// Value gets the value associated with the given key.
Value(key string) (interface{}, bool) Value(key string) (interface{}, bool)
} }
// A MapValuer is a map that can use Value method to get values with given keys.
MapValuer map[string]interface{} MapValuer map[string]interface{}
) )
// Value gets the value associated with the given key from mv.
func (mv MapValuer) Value(key string) (interface{}, bool) { func (mv MapValuer) Value(key string) (interface{}, bool) {
v, ok := mv[key] v, ok := mv[key]
return v, ok return v, ok

View File

@@ -13,15 +13,18 @@ import (
const yamlTagKey = "json" const yamlTagKey = "json"
var ( var (
// ErrUnsupportedType is an error that indicates the config format is not supported.
ErrUnsupportedType = errors.New("only map-like configs are suported") ErrUnsupportedType = errors.New("only map-like configs are suported")
yamlUnmarshaler = NewUnmarshaler(yamlTagKey) yamlUnmarshaler = NewUnmarshaler(yamlTagKey)
) )
// UnmarshalYamlBytes unmarshals content into v.
func UnmarshalYamlBytes(content []byte, v interface{}) error { func UnmarshalYamlBytes(content []byte, v interface{}) error {
return unmarshalYamlBytes(content, v, yamlUnmarshaler) return unmarshalYamlBytes(content, v, yamlUnmarshaler)
} }
// UnmarshalYamlReader unmarshals content from reader into v.
func UnmarshalYamlReader(reader io.Reader, v interface{}) error { func UnmarshalYamlReader(reader io.Reader, v interface{}) error {
return unmarshalYamlReader(reader, v, yamlUnmarshaler) return unmarshalYamlReader(reader, v, yamlUnmarshaler)
} }
@@ -34,9 +37,9 @@ func unmarshalYamlBytes(content []byte, v interface{}, unmarshaler *Unmarshaler)
if m, ok := o.(map[string]interface{}); ok { if m, ok := o.(map[string]interface{}); ok {
return unmarshaler.Unmarshal(m, v) return unmarshaler.Unmarshal(m, v)
} else {
return ErrUnsupportedType
} }
return ErrUnsupportedType
} }
func unmarshalYamlReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error { func unmarshalYamlReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {

View File

@@ -502,49 +502,49 @@ func TestUnmarshalYamlBytesMap(t *testing.T) {
func TestUnmarshalYamlBytesMapStruct(t *testing.T) { func TestUnmarshalYamlBytesMapStruct(t *testing.T) {
var c struct { var c struct {
Persons map[string]struct { Persons map[string]struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
Id: 1 ID: 1
name: kevin`) name: kevin`)
assert.Nil(t, UnmarshalYamlBytes(content, &c)) assert.Nil(t, UnmarshalYamlBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"].Id) assert.Equal(t, 1, c.Persons["first"].ID)
assert.Equal(t, "kevin", c.Persons["first"].Name) assert.Equal(t, "kevin", c.Persons["first"].Name)
} }
func TestUnmarshalYamlBytesMapStructPtr(t *testing.T) { func TestUnmarshalYamlBytesMapStructPtr(t *testing.T) {
var c struct { var c struct {
Persons map[string]*struct { Persons map[string]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
Id: 1 ID: 1
name: kevin`) name: kevin`)
assert.Nil(t, UnmarshalYamlBytes(content, &c)) assert.Nil(t, UnmarshalYamlBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"].Id) assert.Equal(t, 1, c.Persons["first"].ID)
assert.Equal(t, "kevin", c.Persons["first"].Name) assert.Equal(t, "kevin", c.Persons["first"].Name)
} }
func TestUnmarshalYamlBytesMapStructMissingPartial(t *testing.T) { func TestUnmarshalYamlBytesMapStructMissingPartial(t *testing.T) {
var c struct { var c struct {
Persons map[string]*struct { Persons map[string]*struct {
Id int ID int
Name string Name string
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
Id: 1`) ID: 1`)
assert.NotNil(t, UnmarshalYamlBytes(content, &c)) assert.NotNil(t, UnmarshalYamlBytes(content, &c))
} }
@@ -552,41 +552,41 @@ func TestUnmarshalYamlBytesMapStructMissingPartial(t *testing.T) {
func TestUnmarshalYamlBytesMapStructOptional(t *testing.T) { func TestUnmarshalYamlBytesMapStructOptional(t *testing.T) {
var c struct { var c struct {
Persons map[string]*struct { Persons map[string]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
Id: 1`) ID: 1`)
assert.Nil(t, UnmarshalYamlBytes(content, &c)) assert.Nil(t, UnmarshalYamlBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"].Id) assert.Equal(t, 1, c.Persons["first"].ID)
} }
func TestUnmarshalYamlBytesMapStructSlice(t *testing.T) { func TestUnmarshalYamlBytesMapStructSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]struct { Persons map[string][]struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
- Id: 1 - ID: 1
name: kevin`) name: kevin`)
assert.Nil(t, UnmarshalYamlBytes(content, &c)) assert.Nil(t, UnmarshalYamlBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"][0].Id) assert.Equal(t, 1, c.Persons["first"][0].ID)
assert.Equal(t, "kevin", c.Persons["first"][0].Name) assert.Equal(t, "kevin", c.Persons["first"][0].Name)
} }
func TestUnmarshalYamlBytesMapEmptyStructSlice(t *testing.T) { func TestUnmarshalYamlBytesMapEmptyStructSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]struct { Persons map[string][]struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
@@ -601,25 +601,25 @@ func TestUnmarshalYamlBytesMapEmptyStructSlice(t *testing.T) {
func TestUnmarshalYamlBytesMapStructPtrSlice(t *testing.T) { func TestUnmarshalYamlBytesMapStructPtrSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
- Id: 1 - ID: 1
name: kevin`) name: kevin`)
assert.Nil(t, UnmarshalYamlBytes(content, &c)) assert.Nil(t, UnmarshalYamlBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"][0].Id) assert.Equal(t, 1, c.Persons["first"][0].ID)
assert.Equal(t, "kevin", c.Persons["first"][0].Name) assert.Equal(t, "kevin", c.Persons["first"][0].Name)
} }
func TestUnmarshalYamlBytesMapEmptyStructPtrSlice(t *testing.T) { func TestUnmarshalYamlBytesMapEmptyStructPtrSlice(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
@@ -634,13 +634,13 @@ func TestUnmarshalYamlBytesMapEmptyStructPtrSlice(t *testing.T) {
func TestUnmarshalYamlBytesMapStructPtrSliceMissingPartial(t *testing.T) { func TestUnmarshalYamlBytesMapStructPtrSliceMissingPartial(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string Name string
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
- Id: 1`) - ID: 1`)
assert.NotNil(t, UnmarshalYamlBytes(content, &c)) assert.NotNil(t, UnmarshalYamlBytes(content, &c))
} }
@@ -648,17 +648,17 @@ func TestUnmarshalYamlBytesMapStructPtrSliceMissingPartial(t *testing.T) {
func TestUnmarshalYamlBytesMapStructPtrSliceOptional(t *testing.T) { func TestUnmarshalYamlBytesMapStructPtrSliceOptional(t *testing.T) {
var c struct { var c struct {
Persons map[string][]*struct { Persons map[string][]*struct {
Id int ID int
Name string `json:"name,optional"` Name string `json:"name,optional"`
} }
} }
content := []byte(`Persons: content := []byte(`Persons:
first: first:
- Id: 1`) - ID: 1`)
assert.Nil(t, UnmarshalYamlBytes(content, &c)) assert.Nil(t, UnmarshalYamlBytes(content, &c))
assert.Equal(t, 1, len(c.Persons)) assert.Equal(t, 1, len(c.Persons))
assert.Equal(t, 1, c.Persons["first"][0].Id) assert.Equal(t, 1, c.Persons["first"][0].ID)
} }
func TestUnmarshalYamlStructOptional(t *testing.T) { func TestUnmarshalYamlStructOptional(t *testing.T) {

View File

@@ -4,6 +4,7 @@ import "math"
const epsilon = 1e-6 const epsilon = 1e-6
// CalcEntropy calculates the entropy of m.
func CalcEntropy(m map[interface{}]int) float64 { func CalcEntropy(m map[interface{}]int) float64 {
if len(m) == 0 || len(m) == 1 { if len(m) == 0 || len(m) == 1 {
return 1 return 1

View File

@@ -1,17 +1,19 @@
package mathx package mathx
// MaxInt returns the larger one of a and b.
func MaxInt(a, b int) int { func MaxInt(a, b int) int {
if a > b { if a > b {
return a return a
} else {
return b
} }
return b
} }
// MinInt returns the smaller one of a and b.
func MinInt(a, b int) int { func MinInt(a, b int) int {
if a < b { if a < b {
return a return a
} else {
return b
} }
return b
} }

View File

@@ -6,18 +6,21 @@ import (
"time" "time"
) )
// A Proba is used to test if true on given probability.
type Proba struct { type Proba struct {
// rand.New(...) returns a non thread safe object // rand.New(...) returns a non thread safe object
r *rand.Rand r *rand.Rand
lock sync.Mutex lock sync.Mutex
} }
// NewProba returns a Proba.
func NewProba() *Proba { func NewProba() *Proba {
return &Proba{ return &Proba{
r: rand.New(rand.NewSource(time.Now().UnixNano())), r: rand.New(rand.NewSource(time.Now().UnixNano())),
} }
} }
// TrueOnProba checks if true on given probability.
func (p *Proba) TrueOnProba(proba float64) (truth bool) { func (p *Proba) TrueOnProba(proba float64) (truth bool) {
p.lock.Lock() p.lock.Lock()
truth = p.r.Float64() < proba truth = p.r.Float64() < proba

Some files were not shown because too many files have changed in this diff Show More