Compare commits

...

189 Commits

Author SHA1 Message Date
Kevin Wan
421e6617b1 chore: add more tests (#3592) 2023-09-27 22:33:27 +08:00
Kevin Wan
0ee7a271d3 fix: avoid float overflow in mapping.Unmarshal (#3590) 2023-09-26 13:46:34 +00:00
dependabot[bot]
af022b9655 chore(deps): bump google.golang.org/grpc from 1.58.1 to 1.58.2 in /tools/goctl (#3584) 2023-09-25 16:23:29 +08:00
dependabot[bot]
98d46261d9 chore(deps): bump google.golang.org/grpc from 1.58.1 to 1.58.2 (#3585)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-09-24 23:27:02 +08:00
Kevin Wan
4222fd97bc chore: add test for logging rotate size (#3587) 2023-09-24 22:28:03 +08:00
dependabot[bot]
814852f0b8 chore(deps): bump github.com/fullstorydev/grpcurl from 1.8.7 to 1.8.8 (#3586)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-09-24 19:04:34 +08:00
Kevin Wan
ded2888759 fix: avoid integer overflow in mapping.Unmarshal (#3582) 2023-09-21 14:22:33 +00:00
Kevin Wan
18d66a795d chore: add more tests (#3578) 2023-09-20 23:52:10 +08:00
Kevin Wan
4211672bfd chore: add more tests (#3577) 2023-09-20 00:01:26 +08:00
Kevin Wan
68df0c3620 chore: add more tests (#3575) 2023-09-18 11:01:46 +08:00
xt-inking
5e435b6a76 fix: avoid losing logs before closing (#3573) 2023-09-17 11:38:53 +00:00
Kevin Wan
0dcede6457 chore: refactor log limit in rest (#3572) 2023-09-16 22:33:30 +08:00
Awadabang
cc21f5fae2 update: limit logBrief http body size (#3498)
Co-authored-by: 常公征 <changgz@yealink.com>
2023-09-16 11:58:21 +00:00
dependabot[bot]
b22ad50d59 chore(deps): bump google.golang.org/grpc from 1.58.0 to 1.58.1 in /tools/goctl (#3568)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-09-16 16:52:01 +08:00
Kevin Wan
974252980c chore: upgrade grpc (#3570) 2023-09-15 23:21:22 +08:00
dependabot[bot]
8d83986d27 chore(deps): bump google.golang.org/grpc from 1.57.0 to 1.58.0 in /tools/goctl (#3546) 2023-09-12 20:00:52 +08:00
Kevin Wan
6821b0a7dd chore: upgrade grpc (#3558) 2023-09-12 10:30:26 +08:00
Kevin Wan
1ba1724c65 chore: refactor (#3545) 2023-09-06 22:36:43 +08:00
Xinwei Xiong
ca5a7df5b0 feat: Optimize Encoding Functions and Add Descriptive Comments (#3543) 2023-09-06 14:19:50 +00:00
dependabot[bot]
69a3024853 chore(deps): bump golang.org/x/net from 0.14.0 to 0.15.0 (#3544)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-09-06 20:28:03 +08:00
dependabot[bot]
fd3abf3717 chore(deps): bump golang.org/x/text from 0.12.0 to 0.13.0 in /tools/goctl (#3542)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2023-09-06 20:20:12 +08:00
dependabot[bot]
99b3750d10 chore(deps): bump golang.org/x/sys from 0.11.0 to 0.12.0 (#3541)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-09-06 00:02:00 +08:00
POABOB
33f6d7ebb8 fix: goctl pg gen will extract all fields when the same table name exists in different schemas (#3496) (#3517) 2023-09-04 20:48:26 +08:00
kesonan
c4ef9ceb68 Add api version (#3536) 2023-09-02 01:45:48 +00:00
dependabot[bot]
e95861f28a chore(deps): bump github.com/pelletier/go-toml/v2 from 2.0.9 to 2.1.0 (#3532)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-31 22:52:23 +08:00
Kevin Wan
d3cd7b17c0 Revert "add:func() QueryRowsPartial,QueryRowPartial into cachedsql.go" (#3523) 2023-08-27 21:36:14 +08:00
liumin-go
a50515496c add:func() QueryRowsPartial,QueryRowPartial into cachedsql.go (#3512)
Co-authored-by: 刘敏 <liumin@liumindeMac-mini.local>
2023-08-27 08:05:02 +00:00
Kevin Wan
0423313d9b feat: support json:"-" in mapping (#3521) 2023-08-27 16:04:38 +08:00
dependabot[bot]
7bbe7de05f chore(deps): bump github.com/google/uuid from 1.3.0 to 1.3.1 (#3511)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-24 14:59:49 +08:00
dependabot[bot]
83a451f2f4 chore(deps): bump github.com/jhump/protoreflect from 1.15.1 to 1.15.2 (#3518)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-24 14:39:50 +08:00
Kevin Wan
d2a874f21d chore: upgrade go-zero, and update goctl version (#3509) 2023-08-21 09:09:51 +08:00
Kevin Wan
fd85b24b25 Update readme-cn.md 2023-08-20 23:38:39 +08:00
Kevin Wan
14fcbd7658 fix #3499 (#3508) 2023-08-19 22:17:24 +08:00
Kevin Wan
cb3ffc76a3 fix: #3478 (#3493) 2023-08-14 14:22:22 +00:00
dependabot[bot]
45fbd7dc35 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.30.4 to 2.30.5 (#3490)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-10 19:53:17 +08:00
dependabot[bot]
af821cf794 chore(deps): bump golang.org/x/net from 0.13.0 to 0.14.0 (#3484)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-09 10:39:32 +08:00
dependabot[bot]
ec69950153 chore(deps): bump github.com/jackc/pgx/v5 from 5.4.2 to 5.4.3 (#3483)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-09 10:18:27 +08:00
Kevin Wan
ce5e78db53 chore: use jsonTagKey to replace json literals (#3479) 2023-08-06 22:00:24 +08:00
dependabot[bot]
ed75802eaa chore(deps): bump golang.org/x/text from 0.11.0 to 0.12.0 in /tools/goctl (#3477)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-06 09:57:19 +08:00
dependabot[bot]
76c92b571d chore(deps): bump golang.org/x/sys from 0.10.0 to 0.11.0 (#3476)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-06 00:19:00 +08:00
dependabot[bot]
a2e703c53e chore(deps): bump golang.org/x/net from 0.12.0 to 0.13.0 (#3463)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2023-08-04 20:39:17 +08:00
dependabot[bot]
ca698deb2a chore(deps): bump go.mongodb.org/mongo-driver from 1.12.0 to 1.12.1 (#3472)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-04 19:13:16 +08:00
guangwu
a9f4aab86b fix: "EXPRIMENTAL" is a misspelling of "EXPERIMENTAL" (#3462) 2023-08-02 23:59:37 +08:00
Kevin Wan
c3f57e9b0a chore: fix potential nil pointer errors (#3454) 2023-07-30 21:37:41 +08:00
Kevin Wan
ad4cce959d chore: add more tests (#3453) 2023-07-29 22:34:16 +08:00
Shyunn
279123f4a7 feat: add prometheus summary metrics (#3440)
Co-authored-by: chen quan <chenquan.dev@gmail.com>
2023-07-29 16:51:43 +08:00
dependabot[bot]
457eb1961b chore(deps): bump google.golang.org/grpc from 1.56.2 to 1.57.0 (#3445)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-28 22:18:14 +08:00
dependabot[bot]
63df384a4b chore(deps): bump google.golang.org/grpc from 1.56.2 to 1.57.0 in /tools/goctl (#3446)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-28 22:17:23 +08:00
MarkJoyMa
42bfa26e2b fix: remove mapping redundant error (#3439) 2023-07-24 00:10:50 +08:00
Kevin Wan
ff04356704 fix: format error should not trigger circuit breaker in sqlx (#3437) 2023-07-23 20:40:03 +08:00
MarkJoyMa
05db706c62 feat: optimize mapping error (#3438) 2023-07-23 12:10:41 +00:00
dependabot[bot]
ef2e0d859d chore(deps): bump go.uber.org/automaxprocs from 1.5.2 to 1.5.3 (#3435)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-22 00:03:07 +08:00
dependabot[bot]
05ec16ae9d chore(deps): bump github.com/gookit/color from 1.5.3 to 1.5.4 in /tools/goctl (#3433)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-21 23:27:49 +08:00
dependabot[bot]
13e685e0db chore(deps): bump github.com/emicklei/proto from 1.12.0 to 1.12.1 in /tools/goctl (#3431)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-21 23:12:06 +08:00
dependabot[bot]
c10f44b74e chore(deps): bump github.com/emicklei/proto from 1.11.2 to 1.12.0 in /tools/goctl (#3429) 2023-07-18 09:41:13 +08:00
Kevin Wan
57644420ed chore: update go-zero for goctl (#3426) 2023-07-14 21:38:42 +08:00
dependabot[bot]
b245159417 chore(deps): bump github.com/pelletier/go-toml/v2 from 2.0.8 to 2.0.9 (#3423)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-14 19:57:36 +08:00
dependabot[bot]
c26ea17669 chore(deps): bump github.com/iancoleman/strcase from 0.2.0 to 0.3.0 in /tools/goctl (#3424) 2023-07-14 13:20:06 +08:00
Kevin Wan
a7daff3587 chore: make servicegroup panic as demand (#3422) 2023-07-13 14:08:35 +00:00
dependabot[bot]
6719d06146 chore(deps): bump github.com/jackc/pgx/v5 from 5.4.1 to 5.4.2 (#3417)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-13 10:56:18 +08:00
Kevin Wan
0c6eaeda9f chore: coding style (#3413) 2023-07-12 01:08:09 +08:00
Xinyan Lu
b9c0c0f8b5 feat: add detail type mismatch info in number fields check (#3386) (#3387) 2023-07-11 16:29:42 +00:00
Kevin Wan
77da459165 chore: make test stable (#3412) 2023-07-11 16:20:41 +00:00
Kevin Wan
13cdbdc98b chore: avoid nested WithCodeResponseWriter (#3406) 2023-07-11 15:59:43 +00:00
guangwu
e8c1e6e09b fix: log format error (#3409) 2023-07-11 05:28:53 +00:00
guangwu
f1171e01f2 chore: slice replace loop (#3410) 2023-07-11 05:27:46 +00:00
cong
61e562d0c7 refactor(rest): keep rest log collector context key private (#3407) 2023-07-10 01:52:26 +00:00
chen quan
b71453985c feat(sqlx): support for custom Acceptable function (#3405) 2023-07-10 01:16:45 +00:00
Kevin Wan
31b9ba19a2 chore: refactor httpx.TimeoutHandler (#3400) 2023-07-09 07:04:59 +00:00
dependabot[bot]
3170afd57b chore(deps): bump google.golang.org/grpc from 1.56.1 to 1.56.2 in /tools/goctl (#3403)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-08 16:21:50 +08:00
dependabot[bot]
03e365a5d8 chore(deps): bump google.golang.org/grpc from 1.56.1 to 1.56.2 (#3402)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-08 15:34:45 +08:00
dependabot[bot]
7d4fce9588 chore(deps): bump golang.org/x/net from 0.11.0 to 0.12.0 (#3401)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-06 14:54:13 +08:00
扶桑花间
916cea858f 1. Fix w. (http. Flusher). Flush() error (#3388) 2023-07-05 15:27:15 +00:00
dependabot[bot]
a86942d532 chore(deps): bump golang.org/x/sys from 0.9.0 to 0.10.0 (#3396)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-05 22:49:07 +08:00
dependabot[bot]
f76c70ea9a chore(deps): bump golang.org/x/text from 0.10.0 to 0.11.0 in /tools/goctl (#3397)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-05 22:19:04 +08:00
MarkJoyMa
4cbfdb3d74 feat: optimize must log add stack (#3384) 2023-06-30 01:11:03 +00:00
dependabot[bot]
aefa6dfb50 chore(deps): bump google.golang.org/protobuf from 1.30.0 to 1.31.0 (#3376) 2023-06-30 09:06:20 +08:00
dependabot[bot]
9047029475 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.30.3 to 2.30.4 (#3381) 2023-06-30 09:05:05 +08:00
dependabot[bot]
f296c182f7 chore(deps): bump google.golang.org/protobuf from 1.30.0 to 1.31.0 in /tools/goctl (#3377)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-27 22:45:00 +08:00
Kevin Wan
40e7a4cd07 chore: refactor httpx.SetOkHandler (#3373) 2023-06-26 00:27:26 +08:00
Kevin Wan
92e5819e91 opt: improve logx performance (#3371) 2023-06-25 15:41:28 +08:00
2822132073
8d23ab158b fix In goctl new api, occur error invalid character 'A' looking for beginning of value (#3357) 2023-06-25 07:26:21 +00:00
唐小鸭
bcccfab824 [fix] The directory is not recognized when it is in a soft link (#3337) 2023-06-25 05:06:27 +00:00
dependabot[bot]
f7e701a634 chore(deps): bump google.golang.org/grpc from 1.56.0 to 1.56.1 (#3367) 2023-06-25 13:04:13 +08:00
dependabot[bot]
7c2d8e5cc2 chore(deps): bump google.golang.org/grpc from 1.56.0 to 1.56.1 in /tools/goctl (#3369) 2023-06-25 12:57:03 +08:00
dependabot[bot]
5b622d6265 chore(deps): bump go.mongodb.org/mongo-driver from 1.11.7 to 1.12.0 (#3370)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-23 14:50:48 +08:00
dependabot[bot]
c5510a4e1b chore(deps): bump github.com/jackc/pgx/v5 from 5.4.0 to 5.4.1 (#3363)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-21 15:16:06 +08:00
Kevin Wan
2a33b74b35 chore: coding style (#3362) 2023-06-17 22:59:00 +08:00
anqiansong
45bb547a81 (goctl)fix: #3328 (#3348)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2023-06-17 13:23:57 +00:00
Mikael
f5f5261556 add whether to generate rpc client option (#3361)
Co-authored-by: admin <admin@admindeMacBook-Pro.local>
2023-06-17 12:52:49 +00:00
Kevin Wan
b176d5d434 chore: add more tests (#3359) 2023-06-17 20:51:33 +08:00
MarkJoyMa
92f6c48349 fix: NewClientWithTarget miss default config (#3358) 2023-06-16 23:29:52 +08:00
dependabot[bot]
71e8230e65 chore(deps): bump google.golang.org/grpc from 1.55.0 to 1.56.0 (#3354) 2023-06-16 07:46:33 +08:00
dependabot[bot]
018fa8e0a0 chore(deps): bump google.golang.org/grpc from 1.55.0 to 1.56.0 in /tools/goctl (#3355) 2023-06-16 07:31:48 +08:00
dependabot[bot]
979fe9718a chore(deps): bump github.com/prometheus/client_golang from 1.15.1 to 1.16.0 (#3352) 2023-06-16 07:29:38 +08:00
Kevin Wan
f998803131 chore: refactor and add more tests (#3351) 2023-06-16 01:04:58 +08:00
TaoYu
1262266ac2 feat: httpx add common handler (#3269) 2023-06-15 15:31:15 +00:00
dependabot[bot]
9c32bf8478 chore(deps): bump github.com/zeromicro/ddl-parser from 1.0.4 to 1.0.5 in /tools/goctl (#3350) 2023-06-15 07:34:19 +08:00
dependabot[bot]
37ec7f6443 chore(deps): bump github.com/jackc/pgx/v5 from 5.3.1 to 5.4.0 (#3349) 2023-06-15 07:31:16 +08:00
dependabot[bot]
2fdc4dfc0f chore(deps): bump golang.org/x/net from 0.10.0 to 0.11.0 (#3346) 2023-06-14 07:05:24 +08:00
dependabot[bot]
4b2a6ba3de chore(deps): bump golang.org/x/text from 0.9.0 to 0.10.0 in /tools/goctl (#3342) 2023-06-13 07:37:52 +08:00
dependabot[bot]
7fa3f10f22 chore(deps): bump golang.org/x/sys from 0.8.0 to 0.9.0 (#3341) 2023-06-13 07:37:30 +08:00
elza
4a29a0b642 fix: fixed goctl api go --home parameter error when loading non-exist… (#3319)
Co-authored-by: yuanyou <yuanyou@kezaihui.com>
2023-06-12 16:00:41 +00:00
Kevin Wan
a62745a152 Update readme-cn.md 2023-06-12 23:35:33 +08:00
Kevin Wan
28314326e7 chore: more tests (#3340) 2023-06-12 23:29:23 +08:00
Kevin Wan
f6bdb6e1de chore: add more tests (#3338) 2023-06-12 01:22:20 +08:00
Kevin Wan
efa6940001 chore: improve logx gzip (#3332) 2023-06-09 22:50:59 +08:00
Ron_haur
da81d8f774 Fix: logx with Compress auto delete old logs (#3329)
Co-authored-by: haoran.ren <haoran.ren@mihoyo.com>
2023-06-08 11:08:04 +00:00
dependabot[bot]
fd84b27bdc chore(deps): bump go.mongodb.org/mongo-driver from 1.11.6 to 1.11.7 (#3325) 2023-06-08 07:21:49 +08:00
Kevin Wan
6b4d0d89c0 chore: add more tests (#3324) 2023-06-07 00:46:43 +08:00
Kevin Wan
d61a55f779 chore: update readme to remove upgrade parts. (#3318) 2023-06-04 23:34:38 +08:00
Kevin Wan
8ef4164209 chore: make test stable (#3317) 2023-06-04 23:20:58 +08:00
Kevin Wan
50e29e2075 chore: update go-zero for goctl (#3316) 2023-06-04 17:28:27 +08:00
Kevin Wan
452c9dbcaf chore: add more tests (#3315) 2023-06-01 21:08:44 +08:00
dependabot[bot]
3564e36a35 chore(deps): bump github.com/alicebob/miniredis/v2 from 2.30.2 to 2.30.3 (#3309)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-01 10:42:26 +08:00
dependabot[bot]
e479e47634 chore(deps): bump github.com/stretchr/testify from 1.8.3 to 1.8.4 in /tools/goctl (#3306)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-31 11:28:15 +08:00
dependabot[bot]
ad921a6419 chore(deps): bump github.com/stretchr/testify from 1.8.3 to 1.8.4 (#3305)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-31 11:20:12 +08:00
Kevin Wan
44c8d6f269 chore: add more tests (#3304) 2023-05-30 23:27:27 +08:00
Kevin Wan
8a4cc4f98d chore: add more tests (#3299) 2023-05-29 23:44:36 +08:00
kangqi
e751736516 trace exporter: add new type file (#3298)
Co-authored-by: zhaikangqi <794556486@qq.com>
2023-05-29 18:24:59 +08:00
Kevin Wan
032f2419a2 Update readme-cn.md 2023-05-29 17:36:10 +08:00
Kevin Wan
84adc054bc chore: add more tests (#3296) 2023-05-29 07:39:41 +08:00
Kevin Wan
b92e706ce1 chore: refactor (#3295) 2023-05-28 21:31:36 +08:00
MiNG
1b5946346e feat: support optional otel global initialization for #3284 (#3292) 2023-05-28 11:41:48 +00:00
Kevin Wan
28d3905731 chore: add more tests (#3294) 2023-05-28 19:26:45 +08:00
hc
3726851c7f feat: sqlc add SetCacheWithExpire method (#3249)
Co-authored-by: luohancai <luohancai@taqu.cn>
2023-05-28 12:27:30 +08:00
Kevin Wan
2f2ddd373b chore: refactor retry (#3291) 2023-05-28 12:11:55 +08:00
Xiaoju Jiang
8d48e34eed update: expand the retry method to support timeout and interval control (#3283) 2023-05-28 10:17:50 +08:00
Kevin Wan
32f78668db chore: add more tests (#3290) 2023-05-27 23:57:33 +08:00
Kevin Wan
cd0f3726ed chore: add more tests (#3288) 2023-05-27 21:49:11 +08:00
me-cs
0217044900 update:Use the Milliseconds method of duration to get the number of milliseconds (#3285)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2023-05-26 14:32:46 +00:00
Kevin Wan
8b4382dcec chore: add more tests (#3286) 2023-05-26 22:30:03 +08:00
Kevin Wan
fa33329a44 chore: add more tests (#3282) 2023-05-26 00:21:47 +08:00
dependabot[bot]
d76a39ac26 chore(deps): bump github.com/pelletier/go-toml/v2 from 2.0.7 to 2.0.8 (#3280)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-25 16:50:36 +08:00
guangwu
76a7a17e57 typo (#3281) 2023-05-25 15:58:11 +08:00
Kevin Wan
4a2a8d9e45 chore: add more tests (#3279) 2023-05-24 23:58:45 +08:00
guangwu
ef26b39b4c misspelling (#3248) 2023-05-24 08:15:27 +00:00
anqiansong
3ca40001b4 feat(goctl): Add with session for model tpl (#3272) 2023-05-24 07:34:26 +00:00
Toby
278ae3d26a feat: add OtlpHttpPath config support for ZincObserve Telemetry (#3271)
Signed-off-by: Toby Yan <me@tobyan.com>
Co-authored-by: cong <zhangcong1992@gmail.com>
2023-05-23 03:11:58 +00:00
dependabot[bot]
fa1d6d50a8 chore(deps): bump github.com/stretchr/testify from 1.8.2 to 1.8.3 (#3267) 2023-05-22 09:41:17 +08:00
dependabot[bot]
0f4973be06 chore(deps): bump github.com/stretchr/testify from 1.8.2 to 1.8.3 in /tools/goctl (#3268) 2023-05-21 21:43:24 +08:00
Kevin Wan
a9aac7e420 chore: add more tests (#3265) 2023-05-19 23:29:30 +08:00
Kevin Wan
925cf8d3d1 chore: add more tests (#3261) 2023-05-19 12:15:43 +08:00
Kevin Wan
99ce24e2ab chore: add more tests (#3260) 2023-05-19 00:56:50 +08:00
Kevin Wan
701bb31ed2 chore: add more tests (#3259) 2023-05-18 23:43:50 +08:00
Kevin Wan
55e2c7ee83 chore: add more tests (#3258) 2023-05-18 23:11:32 +08:00
Kevin Wan
90839965fa chore: remove directive for tests (#3257) 2023-05-18 21:34:33 +08:00
Kevin Wan
f7228e9af1 chore: add more tests (#3256) 2023-05-18 12:24:04 +08:00
Kevin Wan
f95adae3c1 Update readme-cn.md 2023-05-18 11:00:03 +08:00
Kevin Wan
bff5b81ad9 feat: support using session to execute statements in transaction (#3252) 2023-05-17 14:15:24 +00:00
guangwu
f0bdfb928f fix call error func (#3245) 2023-05-12 04:34:08 +00:00
Kevin Wan
e4a1b7bb39 chore: format the code (#3243) 2023-05-12 10:38:05 +08:00
dependabot[bot]
b6906b5d21 chore(deps): bump go.etcd.io/etcd/client/v3 from 3.5.8 to 3.5.9 (#3242)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-12 08:27:46 +08:00
guangwu
116da96178 add ignore file (#3240) 2023-05-11 08:53:13 +00:00
lchjczw
9fa98c2bd3 api imports take the form of relative paths (#3201)
Co-authored-by: 李春华 <lichunhua@threesoft.cn>
2023-05-10 03:40:07 +00:00
Kevin Wan
b1c4c4736f chore: better comments (#3232) 2023-05-09 22:58:40 +08:00
Thirteen
ef410e8083 fix: generate client directory for goctl (#3166) 2023-05-08 21:03:48 +00:00
fondoger
c22bc1c8ea [dart-gen] Fix lists containing atomic types (#3210) 2023-05-08 21:00:46 +00:00
Toby
1853428011 feat: add otlptracegrpc otlptracehttp headers support for Uptrace (#3219)
Signed-off-by: Toby Yan <me@tobyan.com>
Co-authored-by: cong <zhangcong1992@gmail.com>
2023-05-08 20:58:29 +00:00
dependabot[bot]
3637e10815 chore(deps): bump golang.org/x/net from 0.9.0 to 0.10.0 (#3230) 2023-05-09 04:53:33 +08:00
Kevin Wan
93124329ac chore: add more tests (#3229) 2023-05-08 23:49:13 +08:00
yangtao
851a72f1cc Add RunSafe with context (#3224)
Co-authored-by: yangtao <mrynag8614@163.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
2023-05-08 23:08:31 +08:00
SleeplessBot
a93c24ce84 Add method label for prometheus middleware metrics (#3226)
Co-authored-by: 蓝益尤 <lan.yiyou@intellif.com>
2023-05-08 12:59:20 +00:00
Kevin Wan
9f42eda9ff fix: timeout handler not implementing http.Flusher (#3225) 2023-05-08 18:07:02 +08:00
soasurs
8762a3b7ba fix: Errorv should generate JSON Object for content field in log (#3222)
Signed-off-by: soasurs <soasurs@gmail.com>
2023-05-08 09:16:44 +00:00
Kevin Wan
2684a157ff chore: remove fgprof, use pprof directly (#3220) 2023-05-07 18:21:10 +08:00
guangwu
63368d8b0c io/ioutil deprecated (#3217) 2023-05-06 09:50:54 +00:00
guangwu
4f13fe8188 io/ioutil deprecated (#3215) 2023-05-06 16:37:43 +08:00
Kevin Wan
9fc7874336 chore: optimize stat calculation (#3213) 2023-05-06 15:31:56 +08:00
dependabot[bot]
e6518521eb chore(deps): bump google.golang.org/grpc from 1.54.0 to 1.55.0 in /tools/goctl (#3209)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-06 14:29:51 +08:00
Kevin Wan
8f5a0a2de7 fix: remove etcd pings to avoid too-many-pings error (#3212) 2023-05-06 12:39:19 +08:00
ALMAS
774e8d1d08 feat: replaced color package to support Windows (#3207) 2023-05-05 13:09:54 +00:00
cong
8ad0668612 fix(zrpc): remove default keepalive params for NewClientWithTarget (#3208) 2023-05-05 13:00:59 +00:00
dependabot[bot]
8a043d2443 chore(deps): bump golang.org/x/sys from 0.7.0 to 0.8.0 (#3204)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-05 11:03:18 +08:00
dependabot[bot]
0e2ee97a02 chore(deps): bump go.mongodb.org/mongo-driver from 1.11.4 to 1.11.6 (#3205)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-05 10:41:23 +08:00
Kevin Wan
42300a7d83 chore: add more tests (#3203) 2023-05-04 23:43:34 +08:00
dependabot[bot]
fe97fab274 chore(deps): bump github.com/prometheus/client_golang from 1.15.0 to 1.15.1 (#3200)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-04 11:00:51 +08:00
dependabot[bot]
f93e752f98 chore(deps): bump github.com/emicklei/proto from 1.11.1 to 1.11.2 in /tools/goctl (#3199)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-04 10:24:24 +08:00
fondoger
3a66fc038f [dart-gen] Fix nullable list item issue (#3192) 2023-05-01 15:01:28 +08:00
Kevin Wan
b028ed058d chore: change port to 6060 by default in devserver (#3191) 2023-05-01 11:18:15 +08:00
guonaihong
1fd0c3992b fix panic (#3176) 2023-04-30 14:58:30 +00:00
Kevin Wan
1aebb3e5e4 chore: update readme (#3190) 2023-04-30 11:21:28 +08:00
Kevin Wan
8ffe4c01d1 chore: use logx.Must instead of log.Fatal (#3189) 2023-04-29 23:46:04 +08:00
Kevin Wan
a31256b327 chore: add more tests (#3187) 2023-04-29 22:59:07 +08:00
Kevin Wan
14caf5c799 chore: simplify tests with logtest (#3184) 2023-04-29 20:36:29 +08:00
chensy
c0f8a58ed7 no \t to space line (#3177)
Co-authored-by: chenjieping <chenjieping@kezaihui.com>
2023-04-29 08:35:04 +00:00
dependabot[bot]
3189ec7be6 chore(deps): bump github.com/go-sql-driver/mysql from 1.7.0 to 1.7.1 (#3174)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-28 23:27:01 +08:00
dependabot[bot]
f51e9f0ea7 chore(deps): bump github.com/go-sql-driver/mysql from 1.7.0 to 1.7.1 in /tools/goctl (#3175)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-27 23:37:28 +08:00
cong
ba9d510cdb fix(metrics): enable prometheus global switch when user enable devsever metrics (#3169) 2023-04-24 20:02:04 +08:00
220 changed files with 7451 additions and 1943 deletions

View File

@@ -6,3 +6,4 @@ ignore:
- "tools"
- "**/mock"
- "**/*_mock.go"
- "**/*test"

View File

@@ -61,5 +61,5 @@ jobs:
run: |
go mod verify
go mod download
go test -v -race ./...
go test ./...
cd tools/goctl && go build -v goctl.go

4
.gitignore vendored
View File

@@ -12,11 +12,13 @@
# ignore
**/.idea
**/.vscode
**/.DS_Store
**/logs
**/adhoc
**/coverage.txt
# for test purpose
**/adhoc
go.work
go.work.sum

View File

@@ -29,6 +29,8 @@ func NewSafeMap() *SafeMap {
// Del deletes the value with the given key from m.
func (m *SafeMap) Del(key any) {
m.lock.Lock()
defer m.lock.Unlock()
if _, ok := m.dirtyOld[key]; ok {
delete(m.dirtyOld, key)
m.deletionOld++
@@ -52,7 +54,6 @@ func (m *SafeMap) Del(key any) {
m.dirtyNew = make(map[any]any)
m.deletionNew = 0
}
m.lock.Unlock()
}
// Get gets the value with the given key from m.
@@ -89,6 +90,8 @@ func (m *SafeMap) Range(f func(key, val any) bool) {
// Set sets the value into m with the given key.
func (m *SafeMap) Set(key, value any) {
m.lock.Lock()
defer m.lock.Unlock()
if m.deletionOld <= maxDeletion {
if _, ok := m.dirtyNew[key]; ok {
delete(m.dirtyNew, key)
@@ -102,7 +105,6 @@ func (m *SafeMap) Set(key, value any) {
}
m.dirtyNew[key] = value
}
m.lock.Unlock()
}
// Size returns the size of m.

View File

@@ -147,3 +147,65 @@ func TestSafeMap_Range(t *testing.T) {
assert.Equal(t, m.dirtyNew, newMap.dirtyNew)
assert.Equal(t, m.dirtyOld, newMap.dirtyOld)
}
func TestSetManyTimes(t *testing.T) {
const iteration = maxDeletion * 2
m := NewSafeMap()
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
var count int
m.Range(func(k, v any) bool {
count++
return count < maxDeletion/2
})
assert.Equal(t, maxDeletion/2, count)
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
for i := 0; i < iteration; i++ {
m.Set(i, i)
if i%3 == 0 {
m.Del(i / 2)
}
}
count = 0
m.Range(func(k, v any) bool {
count++
return count < maxDeletion
})
assert.Equal(t, maxDeletion, count)
}
func TestSetManyTimesNew(t *testing.T) {
m := NewSafeMap()
for i := 0; i < maxDeletion*3; i++ {
m.Set(i, i)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i)
}
for i := 0; i < maxDeletion*3; i++ {
m.Set(i+maxDeletion*3, i+maxDeletion*3)
}
for i := 0; i < maxDeletion*2; i++ {
m.Del(i + maxDeletion*2)
}
for i := 0; i < maxDeletion-copyThreshold+1; i++ {
m.Del(i + maxDeletion*2)
}
assert.Equal(t, 0, len(m.dirtyNew))
}

View File

@@ -35,11 +35,11 @@ func TestConfigJson(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$112"
}`
t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -81,8 +81,7 @@ b = 1
c = "${FOO}"
d = "abcd!@#$112"
`
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -207,8 +206,7 @@ b = 1
c = "${FOO}"
d = "abcd!@#112"
`
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
t.Setenv("FOO", "2")
tmpfile, err := createTempFile(".toml", text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
@@ -239,11 +237,10 @@ func TestConfigJsonEnv(t *testing.T) {
"c": "${FOO}",
"d": "abcd!@#$a12 3"
}`
t.Setenv("FOO", "2")
for _, test := range tests {
test := test
t.Run(test, func(t *testing.T) {
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
tmpfile, err := createTempFile(test, text)
assert.Nil(t, err)
defer os.Remove(tmpfile)

View File

@@ -45,8 +45,7 @@ func TestPropertiesEnv(t *testing.T) {
assert.Nil(t, err)
defer os.Remove(tmpfile)
os.Setenv("FOO", "2")
defer os.Unsetenv("FOO")
t.Setenv("FOO", "2")
props, err := LoadProperties(tmpfile, UseEnv())
assert.Nil(t, err)

View File

@@ -337,13 +337,11 @@ func (c *cluster) watchConnState(cli EtcdClient) {
// DialClient dials an etcd cluster with given endpoints.
func DialClient(endpoints []string) (EtcdClient, error) {
cfg := clientv3.Config{
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
DialKeepAliveTime: dialKeepAliveTime,
DialKeepAliveTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
Endpoints: endpoints,
AutoSyncInterval: autoSyncInterval,
DialTimeout: DialTimeout,
RejectOldCluster: true,
PermitWithoutStream: true,
}
if account, ok := GetAccount(endpoints); ok {
cfg.Username = account.User

View File

@@ -9,7 +9,6 @@ const (
autoSyncInterval = time.Minute
coolDownInterval = time.Second
dialTimeout = 5 * time.Second
dialKeepAliveTime = 5 * time.Second
requestTimeout = 3 * time.Second
endpointsSeparator = ","
)

View File

@@ -78,7 +78,7 @@ func TestBulkExecutorFlush(t *testing.T) {
wait.Wait()
}
func TestBuldExecutorFlushSlowTasks(t *testing.T) {
func TestBulkExecutorFlushSlowTasks(t *testing.T) {
const total = 1500
lock := new(sync.Mutex)
result := make([]any, 0, 10000)

View File

@@ -168,23 +168,23 @@ func TestPeriodicalExecutor_FlushPanic(t *testing.T) {
func TestPeriodicalExecutor_Wait(t *testing.T) {
var lock sync.Mutex
executer := NewBulkExecutor(func(tasks []any) {
executor := NewBulkExecutor(func(tasks []any) {
lock.Lock()
defer lock.Unlock()
time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(time.Second))
for i := 0; i < 10; i++ {
executer.Add(1)
executor.Add(1)
}
executer.Flush()
executer.Wait()
executor.Flush()
executor.Wait()
}
func TestPeriodicalExecutor_WaitFast(t *testing.T) {
const total = 3
var cnt int
var lock sync.Mutex
executer := NewBulkExecutor(func(tasks []any) {
executor := NewBulkExecutor(func(tasks []any) {
defer func() {
cnt++
}()
@@ -193,10 +193,10 @@ func TestPeriodicalExecutor_WaitFast(t *testing.T) {
time.Sleep(10 * time.Millisecond)
}, WithBulkTasks(1), WithBulkInterval(10*time.Millisecond))
for i := 0; i < total; i++ {
executer.Add(2)
executor.Add(2)
}
executer.Flush()
executer.Wait()
executor.Flush()
executor.Wait()
assert.Equal(t, total, cnt)
}

View File

@@ -74,6 +74,11 @@ func TestFirstLineShort(t *testing.T) {
assert.Equal(t, "first line", val)
}
func TestFirstLineError(t *testing.T) {
_, err := FirstLine("/tmp/does-not-exist")
assert.Error(t, err)
}
func TestLastLine(t *testing.T) {
filename, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
@@ -113,3 +118,8 @@ func TestLastLineWithLastNewlineShort(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "last line", val)
}
func TestLastLineError(t *testing.T) {
_, err := LastLine("/tmp/does-not-exist")
assert.Error(t, err)
}

View File

@@ -11,29 +11,29 @@ import (
// The file is kept as open, the caller should close the file handle,
// and remove the file by name.
func TempFileWithText(text string) (*os.File, error) {
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text)))
if err != nil {
return nil, err
}
if err := os.WriteFile(tmpfile.Name(), []byte(text), os.ModeTemporary); err != nil {
if err := os.WriteFile(tmpFile.Name(), []byte(text), os.ModeTemporary); err != nil {
return nil, err
}
return tmpfile, nil
return tmpFile, nil
}
// TempFilenameWithText creates the file with the given content,
// and returns the filename (full path).
// The caller should remove the file after use.
func TempFilenameWithText(text string) (string, error) {
tmpfile, err := TempFileWithText(text)
tmpFile, err := TempFileWithText(text)
if err != nil {
return "", err
}
filename := tmpfile.Name()
if err = tmpfile.Close(); err != nil {
filename := tmpFile.Name()
if err = tmpFile.Close(); err != nil {
return "", err
}

View File

@@ -1,31 +1,87 @@
package fx
import "github.com/zeromicro/go-zero/core/errorx"
import (
"context"
"errors"
"time"
"github.com/zeromicro/go-zero/core/errorx"
)
const defaultRetryTimes = 3
var errTimeout = errors.New("retry timeout")
type (
// RetryOption defines the method to customize DoWithRetry.
RetryOption func(*retryOptions)
retryOptions struct {
times int
times int
interval time.Duration
timeout time.Duration
}
)
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetry(fn func() error, opts ...RetryOption) error {
return retry(func(errChan chan error, retryCount int) {
errChan <- fn()
}, opts...)
}
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
// fn retryCount indicates the current number of retries, starting from 0
// Note that if the fn function accesses global variables outside the function
// and performs modification operations, it is best to lock them,
// otherwise there may be data race issues
func DoWithRetryCtx(ctx context.Context, fn func(ctx context.Context, retryCount int) error,
opts ...RetryOption) error {
return retry(func(errChan chan error, retryCount int) {
errChan <- fn(ctx, retryCount)
}, opts...)
}
func retry(fn func(errChan chan error, retryCount int), opts ...RetryOption) error {
options := newRetryOptions()
for _, opt := range opts {
opt(options)
}
var berr errorx.BatchError
var cancelFunc context.CancelFunc
ctx := context.Background()
if options.timeout > 0 {
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
defer cancelFunc()
}
errChan := make(chan error, 1)
for i := 0; i < options.times; i++ {
if err := fn(); err != nil {
berr.Add(err)
} else {
return nil
go fn(errChan, i)
select {
case err := <-errChan:
if err != nil {
berr.Add(err)
} else {
return nil
}
case <-ctx.Done():
berr.Add(errTimeout)
return berr.Err()
}
if options.interval > 0 {
select {
case <-ctx.Done():
berr.Add(errTimeout)
return berr.Err()
case <-time.After(options.interval):
}
}
}
@@ -39,6 +95,18 @@ func WithRetry(times int) RetryOption {
}
}
func WithInterval(interval time.Duration) RetryOption {
return func(options *retryOptions) {
options.interval = interval
}
}
func WithTimeout(timeout time.Duration) RetryOption {
return func(options *retryOptions) {
options.timeout = timeout
}
}
func newRetryOptions() *retryOptions {
return &retryOptions{
times: defaultRetryTimes,

View File

@@ -1,8 +1,10 @@
package fx
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
@@ -12,31 +14,103 @@ func TestRetry(t *testing.T) {
return errors.New("any")
}))
var times int
times1 := 0
assert.Nil(t, DoWithRetry(func() error {
times++
if times == defaultRetryTimes {
times1++
if times1 == defaultRetryTimes {
return nil
}
return errors.New("any")
}))
times = 0
times2 := 0
assert.NotNil(t, DoWithRetry(func() error {
times++
if times == defaultRetryTimes+1 {
times2++
if times2 == defaultRetryTimes+1 {
return nil
}
return errors.New("any")
}))
total := 2 * defaultRetryTimes
times = 0
times3 := 0
assert.Nil(t, DoWithRetry(func() error {
times++
if times == total {
times3++
if times3 == total {
return nil
}
return errors.New("any")
}, WithRetry(total)))
}
func TestRetryWithTimeout(t *testing.T) {
assert.Nil(t, DoWithRetry(func() error {
return nil
}, WithTimeout(time.Millisecond*500)))
times1 := 0
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any ")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250)))
total := defaultRetryTimes
times2 := 0
assert.Nil(t, DoWithRetry(func() error {
times2++
if times2 == total {
return nil
}
time.Sleep(time.Millisecond * 50)
return errors.New("any")
}, WithTimeout(time.Millisecond*50*(time.Duration(total)+2))))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}, WithTimeout(time.Millisecond*250)))
}
func TestRetryWithInterval(t *testing.T) {
times1 := 0
assert.NotNil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
times2 := 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == 2 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
}
func TestRetryCtx(t *testing.T) {
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 0 {
return errors.New("any")
}
time.Sleep(time.Millisecond * 150)
return nil
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
assert.NotNil(t, DoWithRetryCtx(context.Background(), func(ctx context.Context, retryCount int) error {
if retryCount == 1 {
return nil
}
time.Sleep(time.Millisecond * 150)
return errors.New("any ")
}, WithTimeout(time.Millisecond*250), WithInterval(time.Millisecond*150)))
}

View File

@@ -32,6 +32,10 @@ func (bp *BufferPool) Get() *bytes.Buffer {
// Put returns buf into bp.
func (bp *BufferPool) Put(buf *bytes.Buffer) {
if buf == nil {
return
}
if buf.Cap() < bp.capability {
bp.pool.Put(buf)
}

View File

@@ -13,3 +13,26 @@ func TestBufferPool(t *testing.T) {
pool.Put(bytes.NewBuffer(make([]byte, 0, 2*capacity)))
assert.True(t, pool.Get().Cap() <= capacity)
}
func TestBufferPool_Put(t *testing.T) {
t.Run("with nil buf", func(t *testing.T) {
pool := NewBufferPool(1024)
pool.Put(nil)
val := pool.Get()
assert.IsType(t, new(bytes.Buffer), val)
})
t.Run("with less-cap buf", func(t *testing.T) {
pool := NewBufferPool(1024)
pool.Put(bytes.NewBuffer(make([]byte, 0, 512)))
val := pool.Get()
assert.IsType(t, new(bytes.Buffer), val)
})
t.Run("with more-cap buf", func(t *testing.T) {
pool := NewBufferPool(1024)
pool.Put(bytes.NewBuffer(make([]byte, 0, 1024<<1)))
val := pool.Get()
assert.IsType(t, new(bytes.Buffer), val)
})
}

View File

@@ -0,0 +1,12 @@
package iox
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNopCloser(t *testing.T) {
closer := NopCloser(nil)
assert.NoError(t, closer.Close())
}

View File

@@ -35,6 +35,16 @@ func KeepSpace() TextReadOption {
}
}
// LimitDupReadCloser returns two io.ReadCloser that read from the first will be written to the second.
// But the second io.ReadCloser is limited to up to n bytes.
// The first returned reader needs to be read first, because the content
// read from it will be written to the underlying buffer of the second reader.
func LimitDupReadCloser(reader io.ReadCloser, n int64) (io.ReadCloser, io.ReadCloser) {
var buf bytes.Buffer
tee := LimitTeeReader(reader, &buf, n)
return io.NopCloser(tee), io.NopCloser(&buf)
}
// ReadBytes reads exactly the bytes with the length of len(buf)
func ReadBytes(reader io.Reader, buf []byte) error {
var got int

View File

@@ -40,17 +40,22 @@ b`,
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
tmpfile, err := fs.TempFilenameWithText(test.input)
tmpFile, err := fs.TempFilenameWithText(test.input)
assert.Nil(t, err)
defer os.Remove(tmpfile)
defer os.Remove(tmpFile)
content, err := ReadText(tmpfile)
content, err := ReadText(tmpFile)
assert.Nil(t, err)
assert.Equal(t, test.expect, content)
})
}
}
func TestReadTextError(t *testing.T) {
_, err := ReadText("not-exist")
assert.NotNil(t, err)
}
func TestReadTextLines(t *testing.T) {
text := `1
@@ -59,9 +64,9 @@ func TestReadTextLines(t *testing.T) {
#a
3`
tmpfile, err := fs.TempFilenameWithText(text)
tmpFile, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
defer os.Remove(tmpfile)
defer os.Remove(tmpFile)
tests := []struct {
options []TextReadOption
@@ -87,13 +92,18 @@ func TestReadTextLines(t *testing.T) {
for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) {
lines, err := ReadTextLines(tmpfile, test.options...)
lines, err := ReadTextLines(tmpFile, test.options...)
assert.Nil(t, err)
assert.Equal(t, test.expectLines, len(lines))
})
}
}
func TestReadTextLinesError(t *testing.T) {
_, err := ReadTextLines("not-exist")
assert.NotNil(t, err)
}
func TestDupReadCloser(t *testing.T) {
input := "hello"
reader := io.NopCloser(bytes.NewBufferString(input))
@@ -108,6 +118,29 @@ func TestDupReadCloser(t *testing.T) {
verify(r2)
}
func TestLimitDupReadCloser(t *testing.T) {
input := "hello world"
limitBytes := int64(4)
reader := io.NopCloser(bytes.NewBufferString(input))
r1, r2 := LimitDupReadCloser(reader, limitBytes)
verify := func(r io.Reader) {
output, err := io.ReadAll(r)
assert.Nil(t, err)
assert.Equal(t, input, string(output))
}
verifyLimit := func(r io.Reader, limit int64) {
output, err := io.ReadAll(r)
if limit < int64(len(input)) {
input = input[:limit]
}
assert.Nil(t, err)
assert.Equal(t, input, string(output))
}
verify(r1)
verifyLimit(r2, limitBytes)
}
func TestReadBytes(t *testing.T) {
reader := io.NopCloser(bytes.NewBufferString("helloworld"))
buf := make([]byte, 5)

35
core/iox/tee.go Normal file
View File

@@ -0,0 +1,35 @@
package iox
import "io"
// LimitTeeReader returns a Reader that writes up to n bytes to w what it reads from r.
// First n bytes reads from r performed through it are matched with
// corresponding writes to w. There is no internal buffering -
// the write must complete before the first n bytes read completes.
// Any error encountered while writing is reported as a read error.
func LimitTeeReader(r io.Reader, w io.Writer, n int64) io.Reader {
return &limitTeeReader{r, w, n}
}
type limitTeeReader struct {
r io.Reader
w io.Writer
n int64 // limit bytes remaining
}
func (t *limitTeeReader) Read(p []byte) (n int, err error) {
n, err = t.r.Read(p)
if n > 0 && t.n > 0 {
limit := int64(n)
if limit > t.n {
limit = t.n
}
if n, err := t.w.Write(p[:limit]); err != nil {
return n, err
}
t.n -= limit
}
return
}

40
core/iox/tee_test.go Normal file
View File

@@ -0,0 +1,40 @@
package iox
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLimitTeeReader(t *testing.T) {
limit := int64(4)
src := []byte("hello, world")
dst := make([]byte, len(src))
rb := bytes.NewBuffer(src)
wb := new(bytes.Buffer)
r := LimitTeeReader(rb, wb, limit)
if n, err := io.ReadFull(r, dst); err != nil || n != len(src) {
t.Fatalf("ReadFull(r, dst) = %d, %v; want %d, nil", n, err, len(src))
}
if !bytes.Equal(dst, src) {
t.Errorf("bytes read = %q want %q", dst, src)
}
if !bytes.Equal(wb.Bytes(), src[:limit]) {
t.Errorf("bytes written = %q want %q", wb.Bytes(), src)
}
n, err := r.Read(dst)
assert.Equal(t, 0, n)
assert.Equal(t, io.EOF, err)
rb = bytes.NewBuffer(src)
pr, pw := io.Pipe()
if assert.NoError(t, pr.Close()) {
r = LimitTeeReader(rb, pw, limit)
n, err := io.ReadFull(r, dst)
assert.Equal(t, 0, n)
assert.Equal(t, io.ErrClosedPipe, err)
}
}

View File

@@ -2,6 +2,7 @@ package iox
import (
"bytes"
"errors"
"io"
"os"
)
@@ -26,7 +27,7 @@ func CountLines(file string) (int, error) {
count += bytes.Count(buf[:c], lineSep)
switch {
case err == io.EOF:
case errors.Is(err, io.EOF):
if noEol {
count++
}

View File

@@ -24,3 +24,8 @@ func TestCountLines(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, 4, lines)
}
func TestCountLinesError(t *testing.T) {
_, err := CountLines("not-exist")
assert.NotNil(t, err)
}

View File

@@ -3,6 +3,7 @@ package iox
import (
"strings"
"testing"
"testing/iotest"
"github.com/stretchr/testify/assert"
)
@@ -22,3 +23,10 @@ func TestScanner(t *testing.T) {
}
assert.EqualValues(t, []string{"1", "2", "3", "4"}, lines)
}
func TestBadScanner(t *testing.T) {
scanner := NewTextLineScanner(iotest.ErrReader(iotest.ErrTimeout))
assert.False(t, scanner.Scan())
_, err := scanner.Line()
assert.ErrorIs(t, err, iotest.ErrTimeout)
}

View File

@@ -1,7 +1,6 @@
package logc
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -11,14 +10,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestAddGlobalFields(t *testing.T) {
var buf bytes.Buffer
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
Info(context.Background(), "hello")
buf.Reset()
@@ -34,155 +30,90 @@ func TestAddGlobalFields(t *testing.T) {
}
func TestAlert(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
Alert(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), "foo"), buf.String())
}
func TestError(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Error(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestErrorf(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Errorf(context.Background(), "foo %s", "bar")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestErrorv(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Errorv(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestErrorw(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Errorw(context.Background(), "foo", Field("a", "b"))
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestInfo(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Info(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestInfof(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Infof(context.Background(), "foo %s", "bar")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestInfov(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Infov(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestInfow(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Infow(context.Background(), "foo", Field("a", "b"))
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebug(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Debug(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebugf(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Debugf(context.Background(), "foo %s", "bar")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebugv(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Debugv(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
}
func TestDebugw(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Debugw(context.Background(), "foo", Field("a", "b"))
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
@@ -204,48 +135,28 @@ func TestMisc(t *testing.T) {
}
func TestSlow(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Slow(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
}
func TestSlowf(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Slowf(context.Background(), "foo %s", "bar")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
}
func TestSlowv(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Slowv(context.Background(), "foo")
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())
}
func TestSloww(t *testing.T) {
var buf strings.Builder
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
buf := logtest.NewCollector(t)
file, line := getFileLine()
Sloww(context.Background(), "foo", Field("a", "b"))
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)), buf.String())

View File

@@ -32,7 +32,7 @@ type LogConf struct {
StackCooldownMillis int `json:",default=100"`
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
// Only take effect when RotationRuleType is `size`.
// Even thougth `MaxBackups` sets 0, log files will still be removed
// Even though `MaxBackups` sets 0, log files will still be removed
// if the `KeepDays` limitation is reached.
MaxBackups int `json:",default=0"`
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.

40
core/logx/fs.go Normal file
View File

@@ -0,0 +1,40 @@
package logx
import (
"io"
"os"
)
var fileSys realFileSystem
type (
fileSystem interface {
Close(closer io.Closer) error
Copy(writer io.Writer, reader io.Reader) (int64, error)
Create(name string) (*os.File, error)
Open(name string) (*os.File, error)
Remove(name string) error
}
realFileSystem struct{}
)
func (fs realFileSystem) Close(closer io.Closer) error {
return closer.Close()
}
func (fs realFileSystem) Copy(writer io.Writer, reader io.Reader) (int64, error) {
return io.Copy(writer, reader)
}
func (fs realFileSystem) Create(name string) (*os.File, error) {
return os.Create(name)
}
func (fs realFileSystem) Open(name string) (*os.File, error) {
return os.Open(name)
}
func (fs realFileSystem) Remove(name string) error {
return os.Remove(name)
}

View File

@@ -68,22 +68,30 @@ func Close() error {
// Debug writes v into access log.
func Debug(v ...any) {
writeDebug(fmt.Sprint(v...))
if shallLog(DebugLevel) {
writeDebug(fmt.Sprint(v...))
}
}
// Debugf writes v with format into access log.
func Debugf(format string, v ...any) {
writeDebug(fmt.Sprintf(format, v...))
if shallLog(DebugLevel) {
writeDebug(fmt.Sprintf(format, v...))
}
}
// Debugv writes v into access log with json content.
func Debugv(v any) {
writeDebug(v)
if shallLog(DebugLevel) {
writeDebug(v)
}
}
// Debugw writes msg along with fields into access log.
func Debugw(msg string, fields ...LogField) {
writeDebug(msg, fields...)
if shallLog(DebugLevel) {
writeDebug(msg, fields...)
}
}
// Disable disables the logging.
@@ -99,35 +107,47 @@ func DisableStat() {
// Error writes v into error log.
func Error(v ...any) {
writeError(fmt.Sprint(v...))
if shallLog(ErrorLevel) {
writeError(fmt.Sprint(v...))
}
}
// Errorf writes v with format into error log.
func Errorf(format string, v ...any) {
writeError(fmt.Errorf(format, v...).Error())
if shallLog(ErrorLevel) {
writeError(fmt.Errorf(format, v...).Error())
}
}
// ErrorStack writes v along with call stack into error log.
func ErrorStack(v ...any) {
// there is newline in stack string
writeStack(fmt.Sprint(v...))
if shallLog(ErrorLevel) {
// there is newline in stack string
writeStack(fmt.Sprint(v...))
}
}
// ErrorStackf writes v along with call stack in format into error log.
func ErrorStackf(format string, v ...any) {
// there is newline in stack string
writeStack(fmt.Sprintf(format, v...))
if shallLog(ErrorLevel) {
// there is newline in stack string
writeStack(fmt.Sprintf(format, v...))
}
}
// Errorv writes v into error log with json content.
// No call stack attached, because not elegant to pack the messages.
func Errorv(v any) {
writeError(v)
if shallLog(ErrorLevel) {
writeError(v)
}
}
// Errorw writes msg along with fields into error log.
func Errorw(msg string, fields ...LogField) {
writeError(msg, fields...)
if shallLog(ErrorLevel) {
writeError(msg, fields...)
}
}
// Field returns a LogField for the given key and value.
@@ -170,22 +190,30 @@ func Field(key string, value any) LogField {
// Info writes v into access log.
func Info(v ...any) {
writeInfo(fmt.Sprint(v...))
if shallLog(InfoLevel) {
writeInfo(fmt.Sprint(v...))
}
}
// Infof writes v with format into access log.
func Infof(format string, v ...any) {
writeInfo(fmt.Sprintf(format, v...))
if shallLog(InfoLevel) {
writeInfo(fmt.Sprintf(format, v...))
}
}
// Infov writes v into access log with json content.
func Infov(v any) {
writeInfo(v)
if shallLog(InfoLevel) {
writeInfo(v)
}
}
// Infow writes msg along with fields into access log.
func Infow(msg string, fields ...LogField) {
writeInfo(msg, fields...)
if shallLog(InfoLevel) {
writeInfo(msg, fields...)
}
}
// Must checks if err is nil, otherwise logs the error and exits.
@@ -194,7 +222,7 @@ func Must(err error) {
return
}
msg := err.Error()
msg := fmt.Sprintf("%+v\n\n%s", err.Error(), debug.Stack())
log.Print(msg)
getWriter().Severe(msg)
@@ -269,42 +297,58 @@ func SetUp(c LogConf) (err error) {
// Severe writes v into severe log.
func Severe(v ...any) {
writeSevere(fmt.Sprint(v...))
if shallLog(SevereLevel) {
writeSevere(fmt.Sprint(v...))
}
}
// Severef writes v with format into severe log.
func Severef(format string, v ...any) {
writeSevere(fmt.Sprintf(format, v...))
if shallLog(SevereLevel) {
writeSevere(fmt.Sprintf(format, v...))
}
}
// Slow writes v into slow log.
func Slow(v ...any) {
writeSlow(fmt.Sprint(v...))
if shallLog(ErrorLevel) {
writeSlow(fmt.Sprint(v...))
}
}
// Slowf writes v with format into slow log.
func Slowf(format string, v ...any) {
writeSlow(fmt.Sprintf(format, v...))
if shallLog(ErrorLevel) {
writeSlow(fmt.Sprintf(format, v...))
}
}
// Slowv writes v into slow log with json content.
func Slowv(v any) {
writeSlow(v)
if shallLog(ErrorLevel) {
writeSlow(v)
}
}
// Sloww writes msg along with fields into slow log.
func Sloww(msg string, fields ...LogField) {
writeSlow(msg, fields...)
if shallLog(ErrorLevel) {
writeSlow(msg, fields...)
}
}
// Stat writes v into stat log.
func Stat(v ...any) {
writeStat(fmt.Sprint(v...))
if shallLogStat() && shallLog(InfoLevel) {
writeStat(fmt.Sprint(v...))
}
}
// Statf writes v with format into stat log.
func Statf(format string, v ...any) {
writeStat(fmt.Sprintf(format, v...))
if shallLogStat() && shallLog(InfoLevel) {
writeStat(fmt.Sprintf(format, v...))
}
}
// WithCooldownMillis customizes logging on writing call stack interval.
@@ -358,14 +402,16 @@ func createOutput(path string) (io.WriteCloser, error) {
return nil, ErrLogPathNotSet
}
var rule RotateRule
switch options.rotationRule {
case sizeRotationRule:
return NewLogger(path, NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays,
options.maxSize, options.maxBackups, options.gzipEnabled), options.gzipEnabled)
rule = NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays, options.maxSize,
options.maxBackups, options.gzipEnabled)
default:
return NewLogger(path, DefaultRotateRule(path, backupFileDelimiter, options.keepDays,
options.gzipEnabled), options.gzipEnabled)
rule = DefaultRotateRule(path, backupFileDelimiter, options.keepDays, options.gzipEnabled)
}
return NewLogger(path, rule, options.gzipEnabled)
}
func getWriter() Writer {
@@ -427,44 +473,58 @@ func shallLogStat() bool {
return atomic.LoadUint32(&disableStat) == 0
}
// writeDebug writes v into debug log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeDebug(val any, fields ...LogField) {
if shallLog(DebugLevel) {
getWriter().Debug(val, addCaller(fields...)...)
}
getWriter().Debug(val, addCaller(fields...)...)
}
// writeError writes v into error log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeError(val any, fields ...LogField) {
if shallLog(ErrorLevel) {
getWriter().Error(val, addCaller(fields...)...)
}
getWriter().Error(val, addCaller(fields...)...)
}
// writeInfo writes v into info log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeInfo(val any, fields ...LogField) {
if shallLog(InfoLevel) {
getWriter().Info(val, addCaller(fields...)...)
}
getWriter().Info(val, addCaller(fields...)...)
}
// writeSevere writes v into severe log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeSevere(msg string) {
if shallLog(SevereLevel) {
getWriter().Severe(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
}
getWriter().Severe(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
}
// writeSlow writes v into slow log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeSlow(val any, fields ...LogField) {
if shallLog(ErrorLevel) {
getWriter().Slow(val, addCaller(fields...)...)
}
getWriter().Slow(val, addCaller(fields...)...)
}
// writeStack writes v into stack log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeStack(msg string) {
if shallLog(ErrorLevel) {
getWriter().Stack(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
}
getWriter().Stack(fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
}
// writeStat writes v into stat log.
// Not checking shallLog here is for performance consideration.
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
// The caller should check shallLog before calling this function.
func writeStat(msg string) {
if shallLogStat() && shallLog(InfoLevel) {
getWriter().Stat(msg, addCaller()...)
}
getWriter().Stat(msg, addCaller()...)
}

View File

@@ -0,0 +1,84 @@
package logtest
import (
"bytes"
"encoding/json"
"io"
"testing"
"github.com/zeromicro/go-zero/core/logx"
)
type Buffer struct {
buf *bytes.Buffer
t *testing.T
}
func Discard(t *testing.T) {
prev := logx.Reset()
logx.SetWriter(logx.NewWriter(io.Discard))
t.Cleanup(func() {
logx.SetWriter(prev)
})
}
func NewCollector(t *testing.T) *Buffer {
var buf bytes.Buffer
writer := logx.NewWriter(&buf)
prev := logx.Reset()
logx.SetWriter(writer)
t.Cleanup(func() {
logx.SetWriter(prev)
})
return &Buffer{
buf: &buf,
t: t,
}
}
func (b *Buffer) Bytes() []byte {
return b.buf.Bytes()
}
func (b *Buffer) Content() string {
var m map[string]interface{}
if err := json.Unmarshal(b.buf.Bytes(), &m); err != nil {
return ""
}
content, ok := m["content"]
if !ok {
return ""
}
switch val := content.(type) {
case string:
return val
default:
// err is impossible to be not nil, unmarshaled from b.buf.Bytes()
bs, _ := json.Marshal(content)
return string(bs)
}
}
func (b *Buffer) Reset() {
b.buf.Reset()
}
func (b *Buffer) String() string {
return b.buf.String()
}
func PanicOnFatal(t *testing.T) {
ok := logx.ExitOnFatal.CompareAndSwap(true, false)
if !ok {
return
}
t.Cleanup(func() {
logx.ExitOnFatal.CompareAndSwap(false, true)
})
}

View File

@@ -0,0 +1,44 @@
package logtest
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
)
func TestCollector(t *testing.T) {
const input = "hello"
c := NewCollector(t)
logx.Info(input)
assert.Equal(t, input, c.Content())
assert.Contains(t, c.String(), input)
c.Reset()
assert.Empty(t, c.Bytes())
}
func TestPanicOnFatal(t *testing.T) {
const input = "hello"
Discard(t)
logx.Info(input)
PanicOnFatal(t)
PanicOnFatal(t)
assert.Panics(t, func() {
logx.Must(errors.New("foo"))
})
}
func TestCollectorContent(t *testing.T) {
const input = "hello"
c := NewCollector(t)
c.buf.WriteString(input)
assert.Empty(t, c.Content())
c.Reset()
c.buf.WriteString(`{}`)
assert.Empty(t, c.Content())
c.Reset()
c.buf.WriteString(`{"content":1}`)
assert.Equal(t, "1", c.Content())
}

View File

@@ -65,7 +65,7 @@ func (l *richLogger) Errorf(format string, v ...any) {
}
func (l *richLogger) Errorv(v any) {
l.err(fmt.Sprint(v))
l.err(v)
}
func (l *richLogger) Errorw(msg string, fields ...LogField) {

View File

@@ -66,6 +66,9 @@ func TestTraceDebug(t *testing.T) {
l.WithDuration(time.Second).Debugv(testlog)
validate(t, w.String(), true, true)
w.Reset()
l.WithDuration(time.Second).Debugv(testobj)
validateContentType(t, w.String(), map[string]any{}, true, true)
w.Reset()
l.WithDuration(time.Second).Debugw(testlog, Field("foo", "bar"))
validate(t, w.String(), true, true)
assert.True(t, strings.Contains(w.String(), "foo"), w.String())
@@ -103,6 +106,9 @@ func TestTraceError(t *testing.T) {
l.WithDuration(time.Second).Errorv(testlog)
validate(t, w.String(), true, true)
w.Reset()
l.WithDuration(time.Second).Errorv(testobj)
validateContentType(t, w.String(), map[string]any{}, true, true)
w.Reset()
l.WithDuration(time.Second).Errorw(testlog, Field("basket", "ball"))
validate(t, w.String(), true, true)
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
@@ -137,6 +143,9 @@ func TestTraceInfo(t *testing.T) {
l.WithDuration(time.Second).Infov(testlog)
validate(t, w.String(), true, true)
w.Reset()
l.WithDuration(time.Second).Infov(testobj)
validateContentType(t, w.String(), map[string]any{}, true, true)
w.Reset()
l.WithDuration(time.Second).Infow(testlog, Field("basket", "ball"))
validate(t, w.String(), true, true)
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
@@ -173,6 +182,9 @@ func TestTraceInfoConsole(t *testing.T) {
w.Reset()
l.WithDuration(time.Second).Infov(testlog)
validate(t, w.String(), true, true)
w.Reset()
l.WithDuration(time.Second).Infov(testobj)
validateContentType(t, w.String(), map[string]any{}, true, true)
}
func TestTraceSlow(t *testing.T) {
@@ -204,6 +216,9 @@ func TestTraceSlow(t *testing.T) {
l.WithDuration(time.Second).Slowv(testlog)
validate(t, w.String(), true, true)
w.Reset()
l.WithDuration(time.Second).Slowv(testobj)
validateContentType(t, w.String(), map[string]any{}, true, true)
w.Reset()
l.WithDuration(time.Second).Sloww(testlog, Field("basket", "ball"))
validate(t, w.String(), true, true)
assert.True(t, strings.Contains(w.String(), "basket"), w.String())
@@ -311,8 +326,32 @@ func validate(t *testing.T, body string, expectedTrace, expectedSpan bool) {
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
}
type mockValue struct {
Trace string `json:"trace"`
Span string `json:"span"`
Foo string `json:"foo"`
func validateContentType(t *testing.T, body string, expectedType any, expectedTrace, expectedSpan bool) {
var val mockValue
dec := json.NewDecoder(strings.NewReader(body))
for {
var doc mockValue
err := dec.Decode(&doc)
if err == io.EOF {
// all done
break
}
if err != nil {
continue
}
val = doc
}
assert.IsType(t, expectedType, val.Content, body)
assert.Equal(t, expectedTrace, len(val.Trace) > 0, body)
assert.Equal(t, expectedSpan, len(val.Span) > 0, body)
}
type mockValue struct {
Trace string `json:"trace"`
Span string `json:"span"`
Foo string `json:"foo"`
Content any `json:"content"`
}

View File

@@ -4,7 +4,6 @@ import (
"compress/gzip"
"errors"
"fmt"
"io"
"log"
"os"
"path"
@@ -299,6 +298,7 @@ func (l *RotateLogger) initialize() error {
if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil {
return err
}
l.currentSize = fileInfo.Size()
}
@@ -382,7 +382,15 @@ func (l *RotateLogger) startWorker() {
case event := <-l.channel:
l.write(event)
case <-l.done:
return
// avoid losing logs before closing.
for {
select {
case event := <-l.channel:
l.write(event)
default:
return
}
}
}
}
}()
@@ -406,7 +414,7 @@ func (l *RotateLogger) write(v []byte) {
func compressLogFile(file string) {
start := time.Now()
Infof("compressing log file: %s", file)
if err := gzipFile(file); err != nil {
if err := gzipFile(file, fileSys); err != nil {
Errorf("compress error: %s", err)
} else {
Infof("compressed log file: %s, took %s", file, time.Since(start))
@@ -421,25 +429,37 @@ func getNowDateInRFC3339Format() string {
return time.Now().Format(fileTimeFormat)
}
func gzipFile(file string) error {
in, err := os.Open(file)
func gzipFile(file string, fsys fileSystem) (err error) {
in, err := fsys.Open(file)
if err != nil {
return err
}
defer in.Close()
defer func() {
if e := fsys.Close(in); e != nil {
Errorf("failed to close file: %s, error: %v", file, e)
}
if err == nil {
// only remove the original file when compression is successful
err = fsys.Remove(file)
}
}()
out, err := os.Create(fmt.Sprintf("%s%s", file, gzipExt))
out, err := fsys.Create(fmt.Sprintf("%s%s", file, gzipExt))
if err != nil {
return err
}
defer out.Close()
defer func() {
e := fsys.Close(out)
if err == nil {
err = e
}
}()
w := gzip.NewWriter(out)
if _, err = io.Copy(w, in); err != nil {
return err
} else if err = w.Close(); err != nil {
if _, err = fsys.Copy(w, in); err != nil {
// failed to copy, no need to close w
return err
}
return os.Remove(file)
return fsys.Close(w)
}

View File

@@ -1,8 +1,12 @@
package logx
import (
"errors"
"io"
"os"
"path"
"path/filepath"
"sync/atomic"
"syscall"
"testing"
"time"
@@ -13,18 +17,58 @@ import (
)
func TestDailyRotateRuleMarkRotated(t *testing.T) {
var rule DailyRotateRule
rule.MarkRotated()
assert.Equal(t, getNowDate(), rule.rotatedTime)
t.Run("daily rule", func(t *testing.T) {
var rule DailyRotateRule
rule.MarkRotated()
assert.Equal(t, getNowDate(), rule.rotatedTime)
})
t.Run("daily rule", func(t *testing.T) {
rule := DefaultRotateRule("test", "-", 1, false)
_, ok := rule.(*DailyRotateRule)
assert.True(t, ok)
})
}
func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
var rule DailyRotateRule
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
t.Run("no files", func(t *testing.T) {
var rule DailyRotateRule
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
})
t.Run("bad files", func(t *testing.T) {
rule := DailyRotateRule{
filename: "[a-z",
}
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
})
t.Run("temp files", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
_ = f1.Close()
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
_ = f2.Close()
t.Cleanup(func() {
_ = os.Remove(f1.Name())
_ = os.Remove(f2.Name())
})
rule := DailyRotateRule{
filename: path.Join(os.TempDir(), "go-zero-test-"),
days: 1,
}
assert.NotEmpty(t, rule.OutdatedFiles())
})
}
func TestDailyRotateRuleShallRotate(t *testing.T) {
@@ -34,20 +78,101 @@ func TestDailyRotateRuleShallRotate(t *testing.T) {
}
func TestSizeLimitRotateRuleMarkRotated(t *testing.T) {
var rule SizeLimitRotateRule
rule.MarkRotated()
assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime)
t.Run("size limit rule", func(t *testing.T) {
var rule SizeLimitRotateRule
rule.MarkRotated()
assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime)
})
t.Run("size limit rule", func(t *testing.T) {
rule := NewSizeLimitRotateRule("foo", "-", 1, 1, 1, false)
rule.MarkRotated()
assert.Equal(t, getNowDateInRFC3339Format(), rule.(*SizeLimitRotateRule).rotatedTime)
})
}
func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
var rule SizeLimitRotateRule
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
rule.maxBackups = 0
assert.Empty(t, rule.OutdatedFiles())
t.Run("no files", func(t *testing.T) {
var rule SizeLimitRotateRule
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
rule.maxBackups = 0
assert.Empty(t, rule.OutdatedFiles())
})
t.Run("bad files", func(t *testing.T) {
rule := SizeLimitRotateRule{
DailyRotateRule: DailyRotateRule{
filename: "[a-z",
},
}
assert.Empty(t, rule.OutdatedFiles())
rule.days = 1
assert.Empty(t, rule.OutdatedFiles())
rule.gzip = true
assert.Empty(t, rule.OutdatedFiles())
})
t.Run("temp files", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
assert.NoError(t, err)
t.Cleanup(func() {
_ = f1.Close()
_ = os.Remove(f1.Name())
_ = f2.Close()
_ = os.Remove(f2.Name())
_ = f3.Close()
_ = os.Remove(f3.Name())
})
rule := SizeLimitRotateRule{
DailyRotateRule: DailyRotateRule{
filename: path.Join(os.TempDir(), "go-zero-test-"),
days: 1,
},
maxBackups: 3,
}
assert.NotEmpty(t, rule.OutdatedFiles())
})
t.Run("no backups", func(t *testing.T) {
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
assert.NoError(t, err)
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
assert.NoError(t, err)
t.Cleanup(func() {
_ = f1.Close()
_ = os.Remove(f1.Name())
_ = f2.Close()
_ = os.Remove(f2.Name())
_ = f3.Close()
_ = os.Remove(f3.Name())
})
rule := SizeLimitRotateRule{
DailyRotateRule: DailyRotateRule{
filename: path.Join(os.TempDir(), "go-zero-test-"),
days: 1,
},
}
assert.NotEmpty(t, rule.OutdatedFiles())
logger := new(RotateLogger)
logger.rule = &rule
logger.maybeDeleteOutdatedFiles()
assert.Empty(t, rule.OutdatedFiles())
})
}
func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
@@ -61,14 +186,47 @@ func TestSizeLimitRotateRuleShallRotate(t *testing.T) {
}
func TestRotateLoggerClose(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
logger, err := NewLogger(filename, new(DailyRotateRule), false)
assert.Nil(t, err)
assert.Nil(t, logger.Close())
t.Run("close", func(t *testing.T) {
filename, err := fs.TempFilenameWithText("foo")
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
logger, err := NewLogger(filename, new(DailyRotateRule), false)
assert.Nil(t, err)
_, err = logger.Write([]byte("foo"))
assert.Nil(t, err)
assert.Nil(t, logger.Close())
})
t.Run("close and write", func(t *testing.T) {
logger := new(RotateLogger)
logger.done = make(chan struct{})
close(logger.done)
_, err := logger.Write([]byte("foo"))
assert.ErrorIs(t, err, ErrLogFileClosed)
})
t.Run("close without losing logs", func(t *testing.T) {
text := "foo"
filename, err := fs.TempFilenameWithText(text)
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
logger, err := NewLogger(filename, new(DailyRotateRule), false)
assert.Nil(t, err)
msg := []byte("foo")
n := 100
for i := 0; i < n; i++ {
_, err = logger.Write(msg)
assert.Nil(t, err)
}
assert.Nil(t, logger.Close())
bs, err := os.ReadFile(filename)
assert.Nil(t, err)
assert.Equal(t, len(msg)*n+len(text), len(bs))
})
}
func TestRotateLoggerGetBackupFilename(t *testing.T) {
@@ -179,7 +337,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleClose(t *testing.T) {
}
logger, err := NewLogger(filename, new(SizeLimitRotateRule), false)
assert.Nil(t, err)
assert.Nil(t, logger.Close())
_ = logger.Close()
}
func TestRotateLoggerGetBackupWithSizeLimitRotateRuleFilename(t *testing.T) {
@@ -295,6 +453,85 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
logger.write([]byte(`baz`))
}
func TestGzipFile(t *testing.T) {
err := errors.New("any error")
t.Run("gzip file open failed", func(t *testing.T) {
fsys := &fakeFileSystem{
openFn: func(name string) (*os.File, error) {
return nil, err
},
}
assert.ErrorIs(t, err, gzipFile("any", fsys))
assert.False(t, fsys.Removed())
})
t.Run("gzip file create failed", func(t *testing.T) {
fsys := &fakeFileSystem{
createFn: func(name string) (*os.File, error) {
return nil, err
},
}
assert.ErrorIs(t, err, gzipFile("any", fsys))
assert.False(t, fsys.Removed())
})
t.Run("gzip file copy failed", func(t *testing.T) {
fsys := &fakeFileSystem{
copyFn: func(writer io.Writer, reader io.Reader) (int64, error) {
return 0, err
},
}
assert.ErrorIs(t, err, gzipFile("any", fsys))
assert.False(t, fsys.Removed())
})
t.Run("gzip file last close failed", func(t *testing.T) {
var called int32
fsys := &fakeFileSystem{
closeFn: func(closer io.Closer) error {
if atomic.AddInt32(&called, 1) > 2 {
return err
}
return nil
},
}
assert.NoError(t, gzipFile("any", fsys))
assert.True(t, fsys.Removed())
})
t.Run("gzip file remove failed", func(t *testing.T) {
fsys := &fakeFileSystem{
removeFn: func(name string) error {
return err
},
}
assert.Error(t, err, gzipFile("any", fsys))
assert.True(t, fsys.Removed())
})
t.Run("gzip file everything ok", func(t *testing.T) {
fsys := &fakeFileSystem{}
assert.NoError(t, gzipFile("any", fsys))
assert.True(t, fsys.Removed())
})
}
func TestRotateLogger_WithExistingFile(t *testing.T) {
const body = "foo"
filename, err := fs.TempFilenameWithText(body)
assert.Nil(t, err)
if len(filename) > 0 {
defer os.Remove(filename)
}
rule := NewSizeLimitRotateRule(filename, "-", 1, 100, 3, false)
logger, err := NewLogger(filename, rule, false)
assert.Nil(t, err)
assert.Equal(t, int64(len(body)), logger.currentSize)
assert.Nil(t, logger.Close())
}
func BenchmarkRotateLogger(b *testing.B) {
filename := "./test.log"
filename2 := "./test2.log"
@@ -346,3 +583,53 @@ func BenchmarkRotateLogger(b *testing.B) {
}
})
}
type fakeFileSystem struct {
removed int32
closeFn func(closer io.Closer) error
copyFn func(writer io.Writer, reader io.Reader) (int64, error)
createFn func(name string) (*os.File, error)
openFn func(name string) (*os.File, error)
removeFn func(name string) error
}
func (f *fakeFileSystem) Close(closer io.Closer) error {
if f.closeFn != nil {
return f.closeFn(closer)
}
return nil
}
func (f *fakeFileSystem) Copy(writer io.Writer, reader io.Reader) (int64, error) {
if f.copyFn != nil {
return f.copyFn(writer, reader)
}
return 0, nil
}
func (f *fakeFileSystem) Create(name string) (*os.File, error) {
if f.createFn != nil {
return f.createFn(name)
}
return nil, nil
}
func (f *fakeFileSystem) Open(name string) (*os.File, error) {
if f.openFn != nil {
return f.openFn(name)
}
return nil, nil
}
func (f *fakeFileSystem) Remove(name string) error {
atomic.AddInt32(&f.removed, 1)
if f.removeFn != nil {
return f.removeFn(name)
}
return nil
}
func (f *fakeFileSystem) Removed() bool {
return atomic.LoadInt32(&f.removed) > 0
}

View File

@@ -12,6 +12,8 @@ import (
const testlog = "Stay hungry, stay foolish."
var testobj = map[string]any{"foo": "bar"}
func TestCollectSysLog(t *testing.T) {
CollectSysLog()
content := getContent(captureOutput(func() {

View File

@@ -97,6 +97,15 @@ func TestConsoleWriter(t *testing.T) {
w.(*concreteWriter).statLog = easyToCloseWriter{}
}
func TestNewFileWriter(t *testing.T) {
t.Run("access", func(t *testing.T) {
_, err := newFileWriter(LogConf{
Path: "/not-exists",
})
assert.Error(t, err)
})
}
func TestNopWriter(t *testing.T) {
assert.NotPanics(t, func() {
var w nopWriter

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
"strconv"
"strings"
@@ -20,6 +21,7 @@ import (
const (
defaultKeyName = "key"
delimiter = '.'
ignoreKey = "-"
)
var (
@@ -49,6 +51,7 @@ type (
unmarshalOptions struct {
fillDefault bool
fromString bool
opaqueKeys bool
canonicalKey func(key string) string
}
)
@@ -72,7 +75,11 @@ func UnmarshalKey(m map[string]any, v any) error {
}
// Unmarshal unmarshals m into v.
func (u *Unmarshaler) Unmarshal(i any, v any) error {
func (u *Unmarshaler) Unmarshal(i, v any) error {
return u.unmarshal(i, v, "")
}
func (u *Unmarshaler) unmarshal(i, v any, fullName string) error {
valueType := reflect.TypeOf(v)
if valueType.Kind() != reflect.Ptr {
return errValueNotSettable
@@ -85,13 +92,13 @@ func (u *Unmarshaler) Unmarshal(i any, v any) error {
return errTypeMismatch
}
return u.UnmarshalValuer(mapValuer(iv), v)
return u.unmarshalValuer(mapValuer(iv), v, fullName)
case []any:
if elemType.Kind() != reflect.Slice {
return errTypeMismatch
}
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv)
return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName)
default:
return errUnsupportedType
}
@@ -99,17 +106,21 @@ func (u *Unmarshaler) Unmarshal(i any, v any) error {
// UnmarshalValuer unmarshals m into v.
func (u *Unmarshaler) UnmarshalValuer(m Valuer, v any) error {
return u.unmarshalWithFullName(simpleValuer{current: m}, v, "")
return u.unmarshalValuer(simpleValuer{current: m}, v, "")
}
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value, mapValue any) error {
func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error {
return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName)
}
func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error {
if !value.CanSet() {
return errValueNotSettable
}
fieldKeyType := fieldType.Key()
fieldElemType := fieldType.Elem()
targetValue, err := u.generateMap(fieldKeyType, fieldElemType, mapValue)
targetValue, err := u.generateMap(fieldKeyType, fieldElemType, mapValue, fullName)
if err != nil {
return err
}
@@ -143,19 +154,22 @@ func (u *Unmarshaler) fillMapFromString(value reflect.Value, mapValue any) error
return nil
}
func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, mapValue any) error {
func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error {
if !value.CanSet() {
return errValueNotSettable
}
refValue := reflect.ValueOf(mapValue)
if refValue.Kind() != reflect.Slice {
return newTypeMismatchErrorWithHint(fullName, reflect.Slice.String(), refValue.Type().String())
}
if refValue.IsNil() {
return nil
}
baseType := fieldType.Elem()
dereffedBaseType := Deref(baseType)
dereffedBaseKind := dereffedBaseType.Kind()
refValue := reflect.ValueOf(mapValue)
if refValue.IsNil() {
return nil
}
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
if refValue.Len() == 0 {
value.Set(conv)
@@ -170,20 +184,27 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
}
valid = true
sliceFullName := fmt.Sprintf("%s[%d]", fullName, i)
switch dereffedBaseKind {
case reflect.Struct:
target := reflect.New(dereffedBaseType)
if err := u.Unmarshal(ithValue.(map[string]any), target.Interface()); err != nil {
val, ok := ithValue.(map[string]any)
if !ok {
return errTypeMismatch
}
if err := u.unmarshal(val, target.Interface(), sliceFullName); err != nil {
return err
}
SetValue(fieldType.Elem(), conv.Index(i), target.Elem())
case reflect.Slice:
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue); err != nil {
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue, sliceFullName); err != nil {
return err
}
default:
if err := u.fillSliceValue(conv, i, dereffedBaseKind, ithValue); err != nil {
if err := u.fillSliceValue(conv, i, dereffedBaseKind, ithValue, sliceFullName); err != nil {
return err
}
}
@@ -197,7 +218,7 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
}
func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.Value,
mapValue any) error {
mapValue any, fullName string) error {
var slice []any
switch v := mapValue.(type) {
case fmt.Stringer:
@@ -217,7 +238,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice))
for i := 0; i < len(slice); i++ {
if err := u.fillSliceValue(conv, i, baseFieldKind, slice[i]); err != nil {
if err := u.fillSliceValue(conv, i, baseFieldKind, slice[i], fullName); err != nil {
return err
}
}
@@ -227,7 +248,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
}
func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
baseKind reflect.Kind, value any) error {
baseKind reflect.Kind, value any, fullName string) error {
ithVal := slice.Index(index)
switch v := value.(type) {
case fmt.Stringer:
@@ -235,7 +256,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
case string:
return setValueFromString(baseKind, ithVal, v)
case map[string]any:
return u.fillMap(ithVal.Type(), ithVal, value)
return u.fillMap(ithVal.Type(), ithVal, value, fullName)
default:
// don't need to consider the difference between int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, uint64, because they're handled as json.Number.
@@ -261,7 +282,7 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int,
}
func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value reflect.Value,
defaultValue string) error {
defaultValue, fullName string) error {
baseFieldType := Deref(derefedType.Elem())
baseFieldKind := baseFieldType.Kind()
defaultCacheLock.Lock()
@@ -279,10 +300,10 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle
defaultCacheLock.Unlock()
}
return u.fillSlice(derefedType, value, slice)
return u.fillSlice(derefedType, value, slice, fullName)
}
func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any) (reflect.Value, error) {
func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any, fullName string) (reflect.Value, error) {
mapType := reflect.MapOf(keyType, elemType)
valueType := reflect.TypeOf(mapValue)
if mapType == valueType {
@@ -301,11 +322,12 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
for _, key := range refValue.MapKeys() {
keythValue := refValue.MapIndex(key)
keythData := keythValue.Interface()
mapFullName := fmt.Sprintf("%s[%s]", fullName, key.String())
switch dereffedElemKind {
case reflect.Slice:
target := reflect.New(dereffedElemType)
if err := u.fillSlice(elemType, target.Elem(), keythData); err != nil {
if err := u.fillSlice(elemType, target.Elem(), keythData, mapFullName); err != nil {
return emptyValue, err
}
@@ -317,7 +339,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
}
target := reflect.New(dereffedElemType)
if err := u.Unmarshal(keythMap, target.Interface()); err != nil {
if err := u.unmarshal(keythMap, target.Interface(), mapFullName); err != nil {
return emptyValue, err
}
@@ -328,7 +350,7 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
return emptyValue, errTypeMismatch
}
innerValue, err := u.generateMap(elemType.Key(), elemType.Elem(), keythMap)
innerValue, err := u.generateMap(elemType.Key(), elemType.Elem(), keythMap, mapFullName)
if err != nil {
return emptyValue, err
}
@@ -347,7 +369,12 @@ func (u *Unmarshaler) generateMap(keyType, elemType reflect.Type, mapValue any)
return emptyValue, errTypeMismatch
}
targetValue.SetMapIndex(key, reflect.ValueOf(v))
val := reflect.ValueOf(v)
if !val.Type().AssignableTo(dereffedElemType) {
return emptyValue, errTypeMismatch
}
targetValue.SetMapIndex(key, val)
case json.Number:
target := reflect.New(dereffedElemType)
if err := setValueFromString(dereffedElemKind, target.Elem(), v.String()); err != nil {
@@ -412,6 +439,10 @@ func (u *Unmarshaler) processAnonymousField(field reflect.StructField, value ref
return err
}
if key == ignoreKey {
return nil
}
if options.optional() {
return u.processAnonymousFieldOptional(field, value, key, m, fullName)
}
@@ -470,7 +501,7 @@ func (u *Unmarshaler) processAnonymousStructFieldOptional(fieldType reflect.Type
return err
}
_, hasValue := getValue(m, fieldKey)
_, hasValue := getValue(m, fieldKey, u.opts.opaqueKeys)
if hasValue {
if !filled {
filled = true
@@ -513,8 +544,8 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
vp valueWithParent, opts *fieldOptionsWithContext, fullName string) error {
derefedFieldType := Deref(fieldType)
typeKind := derefedFieldType.Kind()
valueKind := reflect.TypeOf(vp.value).Kind()
mapValue := vp.value
valueKind := reflect.TypeOf(mapValue).Kind()
switch {
case valueKind == reflect.Map && typeKind == reflect.Struct:
@@ -527,12 +558,14 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
current: mapValuer(mv),
parent: vp.parent,
}, fullName)
case typeKind == reflect.Slice && valueKind == reflect.Slice:
return u.fillSlice(fieldType, value, mapValue, fullName)
case valueKind == reflect.Map && typeKind == reflect.Map:
return u.fillMap(fieldType, value, mapValue)
return u.fillMap(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && typeKind == reflect.Map:
return u.fillMapFromString(value, mapValue)
case valueKind == reflect.String && typeKind == reflect.Slice:
return u.fillSliceFromString(fieldType, value, mapValue)
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
case valueKind == reflect.String && derefedFieldType == durationType:
return fillDurationValue(fieldType, value, mapValue.(string))
default:
@@ -545,23 +578,16 @@ func (u *Unmarshaler) processFieldPrimitive(fieldType reflect.Type, value reflec
typeKind := Deref(fieldType).Kind()
valueKind := reflect.TypeOf(mapValue).Kind()
switch {
case typeKind == reflect.Slice && valueKind == reflect.Slice:
return u.fillSlice(fieldType, value, mapValue)
case typeKind == reflect.Map && valueKind == reflect.Map:
return u.fillMap(fieldType, value, mapValue)
switch v := mapValue.(type) {
case json.Number:
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
default:
switch v := mapValue.(type) {
case json.Number:
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, v, opts, fullName)
default:
if typeKind == valueKind {
if err := validateValueInOptions(mapValue, opts.options()); err != nil {
return err
}
return fillWithSameType(fieldType, value, mapValue, opts)
if typeKind == valueKind {
if err := validateValueInOptions(mapValue, opts.options()); err != nil {
return err
}
return fillWithSameType(fieldType, value, mapValue, opts)
}
}
@@ -584,25 +610,23 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
target := reflect.New(Deref(fieldType)).Elem()
switch typeKind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
iValue, err := v.Int64()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if err := setValueFromString(typeKind, target, v.String()); err != nil {
return err
}
case reflect.Float32:
fValue, err := v.Float64()
if err != nil {
return err
}
target.SetInt(iValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
iValue, err := v.Int64()
if err != nil {
return err
if fValue > math.MaxFloat32 {
return float32OverflowError(v.String())
}
if iValue < 0 {
return fmt.Errorf("unmarshal %q with bad value %q", fullName, v.String())
}
target.SetUint(uint64(iValue))
case reflect.Float32, reflect.Float64:
target.SetFloat(fValue)
case reflect.Float64:
fValue, err := v.Float64()
if err != nil {
return err
@@ -610,7 +634,7 @@ func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(fieldType reflect.Type
target.SetFloat(fValue)
default:
return newTypeMismatchError(fullName)
return newTypeMismatchErrorWithHint(fullName, typeKind.String(), value.Type().String())
}
SetValue(fieldType, value, target)
@@ -704,6 +728,10 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
return err
}
if key == ignoreKey {
return nil
}
fullName = join(fullName, key)
if opts != nil && len(opts.EnvVar) > 0 {
envVal := proc.Env(opts.EnvVar)
@@ -718,7 +746,7 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
}
valuer := createValuer(m, opts)
mapValue, hasValue := getValue(valuer, canonicalKey)
mapValue, hasValue := getValue(valuer, canonicalKey, u.opts.opaqueKeys)
// When fillDefault is used, m is a null value, hasValue must be false, all priority judgments fillDefault.
if u.opts.fillDefault {
@@ -811,7 +839,7 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
switch fieldKind {
case reflect.Array, reflect.Slice:
return u.fillSliceWithDefault(derefedType, value, defaultValue)
return u.fillSliceWithDefault(derefedType, value, defaultValue, fullName)
default:
return setValueFromString(fieldKind, value, defaultValue)
}
@@ -859,7 +887,7 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu
func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error {
rv := reflect.ValueOf(v)
if err := ValidatePtr(&rv); err != nil {
if err := ValidatePtr(rv); err != nil {
return err
}
@@ -881,11 +909,6 @@ func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName
typeField := baseType.Field(i)
valueField := valElem.Field(i)
if err := u.processField(typeField, valueField, m, fullName); err != nil {
if len(fullName) > 0 {
err = fmt.Errorf("%w, fullName: %s, field: %s, type: %s",
err, fullName, typeField.Name, valueField.Type().Name())
}
return err
}
}
@@ -914,6 +937,14 @@ func WithDefault() UnmarshalOption {
}
}
// WithOpaqueKeys customizes an Unmarshaler with opaque keys.
// Opaque keys are keys that are not processed by the unmarshaler.
func WithOpaqueKeys() UnmarshalOption {
return func(opt *unmarshalOptions) {
opt.opaqueKeys = true
}
}
func createValuer(v valuerWithParent, opts *fieldOptionsWithContext) valuerWithParent {
if opts.inherit() {
return recursiveValuer{
@@ -991,8 +1022,8 @@ func fillWithSameType(fieldType reflect.Type, value reflect.Value, mapValue any,
}
// getValue gets the value for the specific key, the key can be in the format of parentKey.childKey
func getValue(m valuerWithParent, key string) (any, bool) {
keys := readKeys(key)
func getValue(m valuerWithParent, key string, opaque bool) (any, bool) {
keys := readKeys(key, opaque)
return getValueWithChainedKeys(m, keys)
}
@@ -1046,7 +1077,16 @@ func newTypeMismatchError(name string) error {
return fmt.Errorf("type mismatch for field %q", name)
}
func readKeys(key string) []string {
func newTypeMismatchErrorWithHint(name, expectType, actualType string) error {
return fmt.Errorf("type mismatch for field %q, expect %q, actual %q",
name, expectType, actualType)
}
func readKeys(key string, opaque bool) []string {
if opaque {
return []string{key}
}
cacheKeysLock.Lock()
keys, ok := cacheKeys[key]
cacheKeysLock.Unlock()

File diff suppressed because it is too large Load Diff

View File

@@ -42,6 +42,10 @@ var (
)
type (
integer interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}
optionsCacheValue struct {
key string
options *fieldOptions
@@ -79,7 +83,7 @@ func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) {
}
// ValidatePtr validates v if it's a valid pointer.
func ValidatePtr(v *reflect.Value) error {
func ValidatePtr(v reflect.Value) error {
// sequence is very important, IsNil must be called after checking Kind() with reflect.Ptr,
// panic otherwise
if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() {
@@ -103,21 +107,32 @@ func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as int", str)
return 0, err
}
return intValue, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(str, 10, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as uint", str)
return 0, err
}
return uintValue, nil
case reflect.Float32, reflect.Float64:
case reflect.Float32:
floatValue, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0, fmt.Errorf("the value %q cannot parsed as float", str)
return 0, err
}
if floatValue > math.MaxFloat32 {
return 0, float32OverflowError(str)
}
return floatValue, nil
case reflect.Float64:
floatValue, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0, err
}
return floatValue, nil
@@ -215,6 +230,10 @@ func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
return false, nil
}
func intOverflowError[T integer](v T, kind reflect.Kind) error {
return fmt.Errorf("parsing \"%d\" as %s: value out of range", v, kind.String())
}
func isLeftInclude(b byte) (bool, error) {
switch b {
case '[':
@@ -237,6 +256,10 @@ func isRightInclude(b byte) (bool, error) {
}
}
func float32OverflowError(str string) error {
return fmt.Errorf("parsing %q as float32: value out of range", str)
}
func maybeNewValue(fieldType reflect.Type, value reflect.Value) {
if fieldType.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
@@ -372,8 +395,6 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
default:
return fmt.Errorf("field %q has wrong optional", fieldName)
}
case option == optionalOption:
fieldOpts.Optional = true
case strings.HasPrefix(option, optionsOption):
val, err := parseProperty(fieldName, optionsOption, option)
if err != nil {
@@ -484,22 +505,61 @@ func parseSegments(val string) []string {
return segments
}
func setIntValue(value reflect.Value, v any, min, max int64) error {
iv := v.(int64)
if iv < min || iv > max {
return intOverflowError(iv, value.Kind())
}
value.SetInt(iv)
return nil
}
func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v any) error {
switch kind {
case reflect.Bool:
value.SetBool(v.(bool))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return nil
case reflect.Int: // int depends on int size, 32 or 64
return setIntValue(value, v, math.MinInt, math.MaxInt)
case reflect.Int8:
return setIntValue(value, v, math.MinInt8, math.MaxInt8)
case reflect.Int16:
return setIntValue(value, v, math.MinInt16, math.MaxInt16)
case reflect.Int32:
return setIntValue(value, v, math.MinInt32, math.MaxInt32)
case reflect.Int64:
value.SetInt(v.(int64))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return nil
case reflect.Uint: // uint depends on int size, 32 or 64
return setUintValue(value, v, math.MaxUint)
case reflect.Uint8:
return setUintValue(value, v, math.MaxUint8)
case reflect.Uint16:
return setUintValue(value, v, math.MaxUint16)
case reflect.Uint32:
return setUintValue(value, v, math.MaxUint32)
case reflect.Uint64:
value.SetUint(v.(uint64))
return nil
case reflect.Float32, reflect.Float64:
value.SetFloat(v.(float64))
return nil
case reflect.String:
value.SetString(v.(string))
return nil
default:
return errUnsupportedType
}
}
func setUintValue(value reflect.Value, v any, boundary uint64) error {
iv := v.(uint64)
if iv > boundary {
return intOverflowError(iv, value.Kind())
}
value.SetUint(iv)
return nil
}
@@ -577,7 +637,8 @@ func usingDifferentKeys(key string, field reflect.StructField) bool {
return false
}
func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opts *fieldOptionsWithContext) error {
func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string,
opts *fieldOptionsWithContext) error {
if !value.CanSet() {
return errValueNotSettable
}

View File

@@ -218,30 +218,31 @@ func TestParseSegments(t *testing.T) {
func TestValidatePtrWithNonPtr(t *testing.T) {
var foo string
rve := reflect.ValueOf(foo)
assert.NotNil(t, ValidatePtr(&rve))
assert.NotNil(t, ValidatePtr(rve))
}
func TestValidatePtrWithPtr(t *testing.T) {
var foo string
rve := reflect.ValueOf(&foo)
assert.Nil(t, ValidatePtr(&rve))
assert.Nil(t, ValidatePtr(rve))
}
func TestValidatePtrWithNilPtr(t *testing.T) {
var foo *string
rve := reflect.ValueOf(foo)
assert.NotNil(t, ValidatePtr(&rve))
assert.NotNil(t, ValidatePtr(rve))
}
func TestValidatePtrWithZeroValue(t *testing.T) {
var s string
e := reflect.Zero(reflect.TypeOf(s))
assert.NotNil(t, ValidatePtr(&e))
assert.NotNil(t, ValidatePtr(e))
}
func TestSetValueNotSettable(t *testing.T) {
var i int
assert.NotNil(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
assert.Error(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
assert.Error(t, validateAndSetValue(reflect.Int, reflect.ValueOf(i), "1", nil))
}
func TestParseKeyAndOptionsErrors(t *testing.T) {
@@ -300,3 +301,36 @@ func TestSetValueFormatErrors(t *testing.T) {
})
}
}
func TestValidateValueRange(t *testing.T) {
t.Run("float", func(t *testing.T) {
assert.NoError(t, validateValueRange(1.2, nil))
})
t.Run("float number range", func(t *testing.T) {
assert.NoError(t, validateNumberRange(1.2, nil))
})
t.Run("bad float", func(t *testing.T) {
assert.Error(t, validateValueRange("a", &fieldOptionsWithContext{
Range: &numberRange{},
}))
})
t.Run("bad float validate", func(t *testing.T) {
var v struct {
Foo float32
}
assert.Error(t, validateAndSetValue(reflect.Int, reflect.ValueOf(&v).Elem().Field(0),
"1", &fieldOptionsWithContext{
Range: &numberRange{
left: 2,
right: 3,
},
}))
})
}
func TestSetMatchedPrimitiveValue(t *testing.T) {
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
func TestNewHistogramVec(t *testing.T) {
@@ -48,6 +47,4 @@ func TestHistogramObserve(t *testing.T) {
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
assert.Nil(t, err)
proc.Shutdown()
}

65
core/metric/summary.go Normal file
View File

@@ -0,0 +1,65 @@
package metric
import (
prom "github.com/prometheus/client_golang/prometheus"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/prometheus"
)
type (
// A SummaryVecOpts is a summary vector options
SummaryVecOpts struct {
VecOpt VectorOpts
Objectives map[float64]float64
}
// A SummaryVec interface represents a summary vector.
SummaryVec interface {
// Observe adds observation v to labels.
Observe(v float64, labels ...string)
close() bool
}
promSummaryVec struct {
summary *prom.SummaryVec
}
)
// NewSummaryVec return a SummaryVec
func NewSummaryVec(cfg *SummaryVecOpts) SummaryVec {
if cfg == nil {
return nil
}
vec := prom.NewSummaryVec(
prom.SummaryOpts{
Namespace: cfg.VecOpt.Namespace,
Subsystem: cfg.VecOpt.Subsystem,
Name: cfg.VecOpt.Name,
Help: cfg.VecOpt.Help,
Objectives: cfg.Objectives,
},
cfg.VecOpt.Labels,
)
prom.MustRegister(vec)
sv := &promSummaryVec{
summary: vec,
}
proc.AddShutdownListener(func() {
sv.close()
})
return sv
}
func (sv *promSummaryVec) Observe(v float64, labels ...string) {
if !prometheus.Enabled() {
return
}
sv.summary.WithLabelValues(labels...).Observe(v)
}
func (sv *promSummaryVec) close() bool {
return prom.Unregister(sv.summary)
}

View File

@@ -0,0 +1,68 @@
package metric
import (
"strings"
"testing"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/proc"
)
func TestNewSummaryVec(t *testing.T) {
summaryVec := NewSummaryVec(&SummaryVecOpts{
VecOpt: VectorOpts{
Namespace: "http_server",
Subsystem: "requests",
Name: "duration_quantiles",
Help: "rpc client requests duration(ms) φ quantiles ",
Labels: []string{"method"},
},
Objectives: map[float64]float64{
0.5: 0.01,
0.9: 0.01,
},
})
defer summaryVec.close()
summaryVecNil := NewSummaryVec(nil)
assert.NotNil(t, summaryVec)
assert.Nil(t, summaryVecNil)
}
func TestSummaryObserve(t *testing.T) {
startAgent()
summaryVec := NewSummaryVec(&SummaryVecOpts{
VecOpt: VectorOpts{
Namespace: "http_server",
Subsystem: "requests",
Name: "duration_quantiles",
Help: "rpc client requests duration(ms) φ quantiles ",
Labels: []string{"method"},
},
Objectives: map[float64]float64{
0.3: 0.01,
0.6: 0.01,
1: 0.01,
},
})
defer summaryVec.close()
sv := summaryVec.(*promSummaryVec)
sv.Observe(100, "GET")
sv.Observe(200, "GET")
sv.Observe(300, "GET")
metadata := `
# HELP http_server_requests_duration_quantiles rpc client requests duration(ms) φ quantiles
# TYPE http_server_requests_duration_quantiles summary
`
val := `
http_server_requests_duration_quantiles{method="GET",quantile="0.3"} 100
http_server_requests_duration_quantiles{method="GET",quantile="0.6"} 200
http_server_requests_duration_quantiles{method="GET",quantile="1"} 300
http_server_requests_duration_quantiles_sum{method="GET"} 600
http_server_requests_duration_quantiles_count{method="GET"} 3
`
err := testutil.CollectAndCompare(sv.summary, strings.NewReader(metadata+val))
assert.Nil(t, err)
proc.Shutdown()
}

View File

@@ -3,7 +3,7 @@ package mr
import (
"context"
"errors"
"io/ioutil"
"io"
"log"
"runtime"
"sync/atomic"
@@ -17,7 +17,7 @@ import (
var errDummy = errors.New("dummy")
func init() {
log.SetOutput(ioutil.Discard)
log.SetOutput(io.Discard)
}
func TestFinish(t *testing.T) {
@@ -574,6 +574,7 @@ func TestMapReduceWithContext(t *testing.T) {
cancel()
}
writer.Write(i)
time.Sleep(time.Millisecond)
}, func(pipe <-chan int, cancel func(error)) {
for item := range pipe {
i := item

View File

@@ -1,7 +1,6 @@
package proc
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
@@ -21,13 +20,11 @@ func TestEnvInt(t *testing.T) {
val, ok := EnvInt("any")
assert.Equal(t, 0, val)
assert.False(t, ok)
err := os.Setenv("anyInt", "10")
assert.Nil(t, err)
t.Setenv("anyInt", "10")
val, ok = EnvInt("anyInt")
assert.Equal(t, 10, val)
assert.True(t, ok)
err = os.Setenv("anyString", "a")
assert.Nil(t, err)
t.Setenv("anyString", "a")
val, ok = EnvInt("anyString")
assert.Equal(t, 0, val)
assert.False(t, ok)

View File

@@ -1,6 +0,0 @@
//go:build windows
package proc
func dumpGoroutines() {
}

View File

@@ -18,7 +18,11 @@ const (
debugLevel = 2
)
func dumpGoroutines() {
type creator interface {
Create(name string) (file *os.File, err error)
}
func dumpGoroutines(ctor creator) {
command := path.Base(os.Args[0])
pid := syscall.Getpid()
dumpFile := path.Join(os.TempDir(), fmt.Sprintf("%s-%d-goroutines-%s.dump",
@@ -26,10 +30,16 @@ func dumpGoroutines() {
logx.Infof("Got dump goroutine signal, printing goroutine profile to %s", dumpFile)
if f, err := os.Create(dumpFile); err != nil {
if f, err := ctor.Create(dumpFile); err != nil {
logx.Errorf("Failed to dump goroutine profile, error: %v", err)
} else {
defer f.Close()
pprof.Lookup(goroutineProfile).WriteTo(f, debugLevel)
}
}
type fileCreator struct{}
func (fc fileCreator) Create(name string) (file *os.File, err error) {
return os.Create(name)
}

View File

@@ -1,23 +1,41 @@
//go:build linux || darwin
package proc
import (
"errors"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestDumpGoroutines(t *testing.T) {
var buf strings.Builder
w := logx.NewWriter(&buf)
o := logx.Reset()
logx.SetWriter(w)
defer func() {
logx.Reset()
logx.SetWriter(o)
}()
t.Run("real file", func(t *testing.T) {
buf := logtest.NewCollector(t)
dumpGoroutines(fileCreator{})
assert.True(t, strings.Contains(buf.String(), ".dump"))
})
dumpGoroutines()
assert.True(t, strings.Contains(buf.String(), ".dump"))
t.Run("fake file", func(t *testing.T) {
const msg = "any message"
buf := logtest.NewCollector(t)
err := errors.New(msg)
dumpGoroutines(fakeCreator{
file: &os.File{},
err: err,
})
assert.True(t, strings.Contains(buf.String(), msg))
})
}
type fakeCreator struct {
file *os.File
err error
}
func (fc fakeCreator) Create(name string) (file *os.File, err error) {
return fc.file, fc.err
}

View File

@@ -5,25 +5,16 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestProfile(t *testing.T) {
var buf strings.Builder
w := logx.NewWriter(&buf)
o := logx.Reset()
logx.SetWriter(w)
defer func() {
logx.Reset()
logx.SetWriter(o)
}()
c := logtest.NewCollector(t)
profiler := StartProfile()
// start again should not work
assert.NotNil(t, StartProfile())
profiler.Stop()
// stop twice
profiler.Stop()
assert.True(t, strings.Contains(buf.String(), ".pprof"))
assert.True(t, strings.Contains(c.String(), ".pprof"))
}

View File

@@ -96,4 +96,6 @@ func (lm *listenerManager) notifyListeners() {
group.RunSafe(listener)
}
group.Wait()
lm.listeners = nil
}

View File

@@ -28,3 +28,33 @@ func TestShutdown(t *testing.T) {
called()
assert.Equal(t, 3, val)
}
func TestNotifyMoreThanOnce(t *testing.T) {
ch := make(chan struct{}, 1)
go func() {
var val int
called := AddWrapUpListener(func() {
val++
})
WrapUp()
WrapUp()
called()
assert.Equal(t, 1, val)
called = AddShutdownListener(func() {
val += 2
})
Shutdown()
Shutdown()
called()
assert.Equal(t, 3, val)
ch <- struct{}{}
}()
select {
case <-ch:
case <-time.After(time.Second):
t.Fatal("timeout, check error logs")
}
}

View File

@@ -26,7 +26,7 @@ func init() {
v := <-signals
switch v {
case syscall.SIGUSR1:
dumpGoroutines()
dumpGoroutines(fileCreator{})
case syscall.SIGUSR2:
if profiler == nil {
profiler = StartProfile()

View File

@@ -1,3 +1,5 @@
//go:build linux || darwin
package proc
import (

View File

@@ -2,6 +2,8 @@ package prof
import (
"fmt"
"io"
"os"
"runtime"
"time"
)
@@ -13,6 +15,10 @@ const (
// DisplayStats prints the goroutine, memory, GC stats with given interval, default to 5 seconds.
func DisplayStats(interval ...time.Duration) {
displayStatsWithWriter(os.Stdout, interval...)
}
func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
duration := defaultInterval
for _, val := range interval {
duration = val
@@ -24,7 +30,7 @@ func DisplayStats(interval ...time.Duration) {
for range ticker.C {
var m runtime.MemStats
runtime.ReadMemStats(&m)
fmt.Printf("Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
}
}()

36
core/prof/runtime_test.go Normal file
View File

@@ -0,0 +1,36 @@
package prof
import (
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestDisplayStats(t *testing.T) {
writer := &threadSafeBuffer{
buf: strings.Builder{},
}
displayStatsWithWriter(writer, time.Millisecond*10)
time.Sleep(time.Millisecond * 50)
assert.Contains(t, writer.String(), "Goroutines: ")
}
type threadSafeBuffer struct {
buf strings.Builder
lock sync.Mutex
}
func (b *threadSafeBuffer) String() string {
b.lock.Lock()
defer b.lock.Unlock()
return b.buf.String()
}
func (b *threadSafeBuffer) Write(p []byte) (n int, err error) {
b.lock.Lock()
defer b.lock.Unlock()
return b.buf.Write(p)
}

View File

@@ -21,6 +21,11 @@ func Enabled() bool {
return enabled.True()
}
// Enable enables prometheus.
func Enable() {
enabled.Set(true)
}
// StartAgent starts a prometheus agent.
func StartAgent(c Config) {
if len(c.Host) == 0 {

View File

@@ -1,6 +1,8 @@
package queue
import (
"errors"
"math"
"sync"
"sync/atomic"
"testing"
@@ -37,10 +39,82 @@ func TestQueue(t *testing.T) {
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
}
func TestQueue_Broadcast(t *testing.T) {
producer := newMockedProducer(math.MaxInt32)
consumer := newMockedConsumer()
consumer.wait.Add(consumers)
q := NewQueue(func() (Producer, error) {
return producer, nil
}, func() (Consumer, error) {
return consumer, nil
})
q.AddListener(new(mockedListener))
q.SetName("mockqueue")
q.SetNumConsumer(consumers)
q.SetNumProducer(1)
go func() {
time.Sleep(time.Millisecond * 100)
q.Stop()
}()
go q.Start()
time.Sleep(time.Millisecond * 50)
q.Broadcast("message")
consumer.wait.Wait()
assert.Equal(t, int32(consumers), atomic.LoadInt32(&consumer.events))
}
func TestQueue_PauseResume(t *testing.T) {
producer := newMockedProducer(rounds)
consumer := newMockedConsumer()
consumer.wait.Add(consumers)
q := NewQueue(func() (Producer, error) {
return producer, nil
}, func() (Consumer, error) {
return consumer, nil
})
q.AddListener(new(mockedListener))
q.SetName("mockqueue")
q.SetNumConsumer(consumers)
q.SetNumProducer(1)
go func() {
producer.wait.Wait()
q.Stop()
}()
q.Start()
producer.listener.OnProducerPause()
assert.Equal(t, int32(0), atomic.LoadInt32(&q.active))
producer.listener.OnProducerResume()
assert.Equal(t, int32(1), atomic.LoadInt32(&q.active))
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
}
func TestQueue_ConsumeError(t *testing.T) {
producer := newMockedProducer(rounds)
consumer := newMockedConsumer()
consumer.consumeErr = errors.New("consume error")
consumer.wait.Add(consumers)
q := NewQueue(func() (Producer, error) {
return producer, nil
}, func() (Consumer, error) {
return consumer, nil
})
q.AddListener(new(mockedListener))
q.SetName("mockqueue")
q.SetNumConsumer(consumers)
q.SetNumProducer(1)
go func() {
producer.wait.Wait()
q.Stop()
}()
q.Start()
assert.Equal(t, int32(rounds), atomic.LoadInt32(&consumer.count))
}
type mockedConsumer struct {
count int32
events int32
wait sync.WaitGroup
count int32
events int32
consumeErr error
wait sync.WaitGroup
}
func newMockedConsumer() *mockedConsumer {
@@ -49,7 +123,7 @@ func newMockedConsumer() *mockedConsumer {
func (c *mockedConsumer) Consume(string) error {
atomic.AddInt32(&c.count, 1)
return nil
return c.consumeErr
}
func (c *mockedConsumer) OnEvent(any) {
@@ -59,9 +133,10 @@ func (c *mockedConsumer) OnEvent(any) {
}
type mockedProducer struct {
total int32
count int32
wait sync.WaitGroup
total int32
count int32
listener ProduceListener
wait sync.WaitGroup
}
func newMockedProducer(total int32) *mockedProducer {
@@ -72,6 +147,7 @@ func newMockedProducer(total int32) *mockedProducer {
}
func (p *mockedProducer) AddListener(listener ProduceListener) {
p.listener = listener
}
func (p *mockedProducer) Produce() (string, bool) {

View File

@@ -1,6 +1,11 @@
package rescue
import "github.com/zeromicro/go-zero/core/logx"
import (
"context"
"runtime/debug"
"github.com/zeromicro/go-zero/core/logx"
)
// Recover is used with defer to do cleanup on panics.
// Use it like:
@@ -15,3 +20,14 @@ func Recover(cleanups ...func()) {
logx.ErrorStack(p)
}
}
// RecoverCtx is used with defer to do cleanup on panics.
func RecoverCtx(ctx context.Context, cleanups ...func()) {
for _, cleanup := range cleanups {
cleanup()
}
if p := recover(); p != nil {
logx.WithContext(ctx).Errorf("%+v\n%s", p, debug.Stack())
}
}

View File

@@ -1,6 +1,7 @@
package rescue
import (
"context"
"sync/atomic"
"testing"
@@ -25,3 +26,17 @@ func TestRescue(t *testing.T) {
})
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
}
func TestRescueCtx(t *testing.T) {
var count int32
assert.NotPanics(t, func() {
defer RecoverCtx(context.Background(), func() {
atomic.AddInt32(&count, 2)
}, func() {
atomic.AddInt32(&count, 3)
})
panic("hello")
})
assert.Equal(t, int32(5), atomic.LoadInt32(&count))
}

View File

@@ -171,11 +171,11 @@ func add(nd *node, route string, item any) error {
token := route[:i]
children := nd.getChildren(token)
if child, ok := children[token]; ok {
if child != nil {
return add(child, route[i+1:], item)
if child == nil {
return errInvalidState
}
return errInvalidState
return add(child, route[i+1:], item)
}
child := newNode(nil)

View File

@@ -11,7 +11,7 @@ import (
type mockedRoute struct {
route string
value int
value any
}
func TestSearch(t *testing.T) {
@@ -187,6 +187,12 @@ func TestSearchInvalidItem(t *testing.T) {
assert.Equal(t, errEmptyItem, err)
}
func TestSearchInvalidState(t *testing.T) {
nd := newNode("0")
nd.children[0]["1"] = nil
assert.Error(t, add(nd, "1/2", "2"))
}
func BenchmarkSearchTree(b *testing.B) {
const (
avgLen = 1000

View File

@@ -1,8 +1,6 @@
package service
import (
"log"
"github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/proc"
@@ -39,9 +37,7 @@ type ServiceConf struct {
// MustSetUp sets up the service, exits on error.
func (sc ServiceConf) MustSetUp() {
if err := sc.SetUp(); err != nil {
log.Fatal(err)
}
logx.Must(sc.SetUp())
}
// SetUp sets up the service.

View File

@@ -68,7 +68,7 @@ func (sg *ServiceGroup) doStart() {
for i := range sg.services {
service := sg.services[i]
routineGroup.RunSafe(func() {
routineGroup.Run(func() {
service.Start()
})
}

View File

@@ -14,30 +14,6 @@ var (
done = make(chan struct{})
)
type mockedService struct {
quit chan struct{}
multiplier int
}
func newMockedService(multiplier int) *mockedService {
return &mockedService{
quit: make(chan struct{}),
multiplier: multiplier,
}
}
func (s *mockedService) Start() {
mutex.Lock()
number *= s.multiplier
mutex.Unlock()
done <- struct{}{}
<-s.quit
}
func (s *mockedService) Stop() {
close(s.quit)
}
func TestServiceGroup(t *testing.T) {
multipliers := []int{2, 3, 5, 7}
want := 1
@@ -126,3 +102,27 @@ type mockedStarter struct {
func (s mockedStarter) Start() {
s.fn()
}
type mockedService struct {
quit chan struct{}
multiplier int
}
func newMockedService(multiplier int) *mockedService {
return &mockedService{
quit: make(chan struct{}),
multiplier: multiplier,
}
}
func (s *mockedService) Start() {
mutex.Lock()
number *= s.multiplier
mutex.Unlock()
done <- struct{}{}
<-s.quit
}
func (s *mockedService) Stop() {
close(s.quit)
}

View File

@@ -3,7 +3,6 @@
package stat
import (
"os"
"strconv"
"sync/atomic"
"testing"
@@ -12,8 +11,7 @@ import (
)
func TestReport(t *testing.T) {
os.Setenv(clusterNameKey, "test-cluster")
defer os.Unsetenv(clusterNameKey)
t.Setenv(clusterNameKey, "test-cluster")
var count int32
SetReporter(func(s string) {

View File

@@ -3,6 +3,7 @@ package internal
import (
"bufio"
"fmt"
"math"
"os"
"path"
"strconv"
@@ -218,6 +219,7 @@ func parseUints(val string) ([]uint64, error) {
return nil, nil
}
var sets []uint64
ints := make(map[uint64]lang.PlaceholderType)
cols := strings.Split(val, ",")
for _, r := range cols {
@@ -238,7 +240,10 @@ func parseUints(val string) ([]uint64, error) {
}
for i := min; i <= max; i++ {
ints[i] = lang.Placeholder
if _, ok := ints[i]; !ok {
ints[i] = lang.Placeholder
sets = append(sets, i)
}
}
} else {
v, err := parseUint(r)
@@ -246,19 +251,17 @@ func parseUints(val string) ([]uint64, error) {
return nil, err
}
ints[v] = lang.Placeholder
if _, ok := ints[v]; !ok {
ints[v] = lang.Placeholder
sets = append(sets, v)
}
}
}
var sets []uint64
for k := range ints {
sets = append(sets, k)
}
return sets, nil
}
// runningInUserNS detects whether we are currently running in an user namespace.
// runningInUserNS detects whether we are currently running in a user namespace.
func runningInUserNS() bool {
nsOnce.Do(func() {
file, err := os.Open("/proc/self/uid_map")
@@ -280,9 +283,10 @@ func runningInUserNS() bool {
// We assume we are in the initial user namespace if we have a full
// range - 4294967295 uids starting at uid 0.
if a == 0 && b == 0 && c == 4294967295 {
if a == 0 && b == 0 && c == math.MaxUint32 {
return
}
inUserNS = true
})

View File

@@ -0,0 +1,71 @@
package internal
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRunningInUserNS(t *testing.T) {
// should be false in docker
assert.False(t, runningInUserNS())
}
func TestCgroupV1(t *testing.T) {
if isCgroup2UnifiedMode() {
cg, err := currentCgroupV1()
assert.NoError(t, err)
_, err = cg.cpus()
assert.Error(t, err)
_, err = cg.cpuPeriodUs()
assert.Error(t, err)
_, err = cg.cpuQuotaUs()
assert.Error(t, err)
_, err = cg.usageAllCpus()
assert.Error(t, err)
}
}
func TestParseUint(t *testing.T) {
tests := []struct {
input string
want uint64
err error
}{
{"0", 0, nil},
{"123", 123, nil},
{"-1", 0, nil},
{"-18446744073709551616", 0, nil},
{"foo", 0, fmt.Errorf("cgroup: bad int format: foo")},
}
for _, tt := range tests {
got, err := parseUint(tt.input)
assert.Equal(t, tt.err, err)
assert.Equal(t, tt.want, got)
}
}
func TestParseUints(t *testing.T) {
tests := []struct {
input string
want []uint64
err error
}{
{"", nil, nil},
{"1,2,3", []uint64{1, 2, 3}, nil},
{"1-3", []uint64{1, 2, 3}, nil},
{"1-3,5,7-9", []uint64{1, 2, 3, 5, 7, 8, 9}, nil},
{"foo", nil, fmt.Errorf("cgroup: bad int format: foo")},
{"1-bar", nil, fmt.Errorf("cgroup: bad int list format: 1-bar")},
{"bar-3", nil, fmt.Errorf("cgroup: bad int list format: bar-3")},
{"3-1", nil, fmt.Errorf("cgroup: bad int list format: 3-1")},
}
for _, tt := range tests {
got, err := parseUints(tt.input)
assert.Equal(t, tt.err, err)
assert.Equal(t, tt.want, got)
}
}

View File

@@ -141,7 +141,7 @@ func (c *metricsContainer) Execute(v any) {
report.Median = float32(medianTask.Duration) / float32(time.Millisecond)
tenPercent := fiftyPercent / 5
if tenPercent > 0 {
top10pTasks := topK(tasks, tenPercent)
top10pTasks := topK(top50pTasks, tenPercent)
task90th := top10pTasks[0]
report.Top90th = float32(task90th.Duration) / float32(time.Millisecond)
onePercent := tenPercent / 10
@@ -163,7 +163,7 @@ func (c *metricsContainer) Execute(v any) {
report.Top99p9th = mostDuration
}
} else {
mostDuration := getTopDuration(tasks)
mostDuration := getTopDuration(top50pTasks)
report.Top90th = mostDuration
report.Top99th = mostDuration
report.Top99p9th = mostDuration

View File

@@ -1,12 +1,11 @@
package stat
import (
"bytes"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestBToMb(t *testing.T) {
@@ -41,15 +40,11 @@ func TestBToMb(t *testing.T) {
}
func TestPrintUsage(t *testing.T) {
var buf bytes.Buffer
writer := logx.NewWriter(&buf)
old := logx.Reset()
logx.SetWriter(writer)
defer logx.SetWriter(old)
c := logtest.NewCollector(t)
printUsage()
output := buf.String()
output := c.String()
assert.Contains(t, output, "CPU:")
assert.Contains(t, output, "MEMORY:")
assert.Contains(t, output, "Alloc=")

View File

@@ -69,3 +69,62 @@ func TestFieldNamesWithDashTagAndOptions(t *testing.T) {
assert.Equal(t, expected, out)
})
}
func TestPostgreSqlJoin(t *testing.T) {
// Test with empty input array
var input []string
var expectedOutput string
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
// Test with single element input array
input = []string{"foo"}
expectedOutput = "foo = $2"
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
// Test with multiple elements input array
input = []string{"foo", "bar", "baz"}
expectedOutput = "foo = $2, bar = $3, baz = $4"
assert.Equal(t, expectedOutput, PostgreSqlJoin(input))
}
type testStruct struct {
Foo string `db:"foo"`
Bar int `db:"bar"`
Baz bool `db:"-"`
}
func TestRawFieldNames(t *testing.T) {
// Test with a struct without tags
in := struct {
Foo string
Bar int
}{}
expectedOutput := []string{"`Foo`", "`Bar`"}
assert.ElementsMatch(t, expectedOutput, RawFieldNames(in))
// Test pg without db tag
expectedOutput = []string{"Foo", "Bar"}
assert.ElementsMatch(t, expectedOutput, RawFieldNames(in, true))
// Test with a struct with tags
input := testStruct{}
expectedOutput = []string{"`foo`", "`bar`"}
assert.ElementsMatch(t, expectedOutput, RawFieldNames(input))
// Test with nil input (pointer)
var nilInput *testStruct
assert.Panics(t, func() {
RawFieldNames(nilInput)
}, "RawFieldNames should panic with nil input")
// Test with non-struct input
inputInt := 42
assert.Panics(t, func() {
RawFieldNames(inputInt)
}, "RawFieldNames should panic with non-struct input")
// Test with PostgreSQL flag
input = testStruct{}
expectedOutput = []string{"foo", "bar"}
assert.ElementsMatch(t, expectedOutput, RawFieldNames(input, true))
}

View File

@@ -1,6 +1,3 @@
//go:build !race
// Disable data race detection is because of the timingWheel in cacheNode.
package cache
import (
@@ -34,8 +31,10 @@ func init() {
func TestCacheNode_DelCache(t *testing.T) {
t.Run("del cache", func(t *testing.T) {
store := redistest.CreateRedis(t)
store.Type = redis.ClusterType
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
store := redis.New(r.Addr(), redis.Cluster())
cn := cacheNode{
rds: store,
@@ -56,16 +55,16 @@ func TestCacheNode_DelCache(t *testing.T) {
})
t.Run("del cache with errors", func(t *testing.T) {
old := timingWheel
old := timingWheel.Load()
ticker := timex.NewFakeTicker()
var err error
timingWheel, err = collection.NewTimingWheelWithTicker(
tw, err := collection.NewTimingWheelWithTicker(
time.Millisecond, timingWheelSlots, func(key, value any) {
clean(key, value)
}, ticker)
timingWheel.Store(tw)
assert.NoError(t, err)
t.Cleanup(func() {
timingWheel = old
timingWheel.Store(old)
})
r, err := miniredis.Run()
@@ -166,40 +165,99 @@ func TestCacheNode_TakeBadRedis(t *testing.T) {
}
func TestCacheNode_TakeNotFound(t *testing.T) {
store := redistest.CreateRedis(t)
t.Run("not found", func(t *testing.T) {
store := redistest.CreateRedis(t)
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSingleFlight(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
var str string
err := cn.Take(&str, "any", func(v any) error {
return errTestNotFound
})
assert.True(t, cn.IsNotFound(err))
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
val, err := store.Get("any")
assert.Nil(t, err)
assert.Equal(t, `*`, val)
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSingleFlight(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
var str string
err := cn.Take(&str, "any", func(v any) error {
return errTestNotFound
})
assert.True(t, cn.IsNotFound(err))
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
val, err := store.Get("any")
assert.Nil(t, err)
assert.Equal(t, `*`, val)
store.Set("any", "*")
err = cn.Take(&str, "any", func(v any) error {
return nil
})
assert.True(t, cn.IsNotFound(err))
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
store.Set("any", "*")
err = cn.Take(&str, "any", func(v any) error {
return nil
})
assert.True(t, cn.IsNotFound(err))
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
store.Del("any")
errDummy := errors.New("dummy")
err = cn.Take(&str, "any", func(v any) error {
return errDummy
store.Del("any")
errDummy := errors.New("dummy")
err = cn.Take(&str, "any", func(v any) error {
return errDummy
})
assert.Equal(t, errDummy, err)
})
t.Run("not found with redis error", func(t *testing.T) {
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
store, err := redis.NewRedis(redis.RedisConf{
Host: r.Addr(),
Type: redis.NodeType,
})
assert.NoError(t, err)
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSingleFlight(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
var str string
err = cn.Take(&str, "any", func(v any) error {
r.SetError("mock error")
return errTestNotFound
})
assert.True(t, cn.IsNotFound(err))
})
}
func TestCacheNode_TakeCtxWithRedisError(t *testing.T) {
t.Run("not found with redis error", func(t *testing.T) {
r, err := miniredis.Run()
assert.NoError(t, err)
defer r.Close()
store, err := redis.NewRedis(redis.RedisConf{
Host: r.Addr(),
Type: redis.NodeType,
})
assert.NoError(t, err)
cn := cacheNode{
rds: store,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
barrier: syncx.NewSingleFlight(),
lock: new(sync.Mutex),
unstableExpiry: mathx.NewUnstable(expiryDeviation),
stat: NewStat("any"),
errNotFound: errTestNotFound,
}
var str string
err = cn.Take(&str, "any", func(v any) error {
str = "foo"
r.SetError("mock error")
return nil
})
assert.NoError(t, err)
})
assert.Equal(t, errDummy, err)
}
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {

28
core/stores/cache/cacheopt_test.go vendored Normal file
View File

@@ -0,0 +1,28 @@
package cache
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCacheOptions(t *testing.T) {
t.Run("default options", func(t *testing.T) {
o := newOptions()
assert.Equal(t, defaultExpiry, o.Expiry)
assert.Equal(t, defaultNotFoundExpiry, o.NotFoundExpiry)
})
t.Run("with expiry", func(t *testing.T) {
o := newOptions(WithExpiry(time.Second))
assert.Equal(t, time.Second, o.Expiry)
assert.Equal(t, defaultNotFoundExpiry, o.NotFoundExpiry)
})
t.Run("with not found expiry", func(t *testing.T) {
o := newOptions(WithNotFoundExpiry(time.Second))
assert.Equal(t, defaultExpiry, o.Expiry)
assert.Equal(t, time.Second, o.NotFoundExpiry)
})
}

View File

@@ -2,6 +2,7 @@ package cache
import (
"fmt"
"sync/atomic"
"time"
"github.com/zeromicro/go-zero/core/collection"
@@ -19,7 +20,8 @@ const (
)
var (
timingWheel *collection.TimingWheel
// use atomic to avoid data race in unit tests
timingWheel atomic.Value
taskRunner = threading.NewTaskRunner(cleanWorkers)
)
@@ -30,22 +32,27 @@ type delayTask struct {
}
func init() {
var err error
timingWheel, err = collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
tw, err := collection.NewTimingWheel(time.Second, timingWheelSlots, clean)
logx.Must(err)
timingWheel.Store(tw)
proc.AddShutdownListener(func() {
timingWheel.Drain(clean)
if err := tw.Drain(clean); err != nil {
logx.Errorf("failed to drain timing wheel: %v", err)
}
})
}
// AddCleanTask adds a clean task on given keys.
func AddCleanTask(task func() error, keys ...string) {
timingWheel.SetTimer(stringx.Randn(taskKeyLen), delayTask{
tw := timingWheel.Load().(*collection.TimingWheel)
if err := tw.SetTimer(stringx.Randn(taskKeyLen), delayTask{
delay: time.Second,
task: task,
keys: keys,
}, time.Second)
}, time.Second); err != nil {
logx.Errorf("failed to set timer for keys: %q, error: %v", formatKeys(keys), err)
}
}
func clean(key, value any) {
@@ -59,7 +66,10 @@ func clean(key, value any) {
next, ok := nextDelay(dt.delay)
if ok {
dt.delay = next
timingWheel.SetTimer(key, dt, next)
tw := timingWheel.Load().(*collection.TimingWheel)
if err = tw.SetTimer(key, dt, next); err != nil {
logx.Errorf("failed to set timer for key: %s, error: %v", key, err)
}
} else {
msg := fmt.Sprintf("retried but failed to clear cache with keys: %q, error: %v",
formatKeys(dt.keys), err)

View File

@@ -5,7 +5,9 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/core/proc"
"github.com/zeromicro/go-zero/core/timex"
)
func TestNextDelay(t *testing.T) {
@@ -49,6 +51,18 @@ func TestNextDelay(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
old := timingWheel.Load()
ticker := timex.NewFakeTicker()
tw, err := collection.NewTimingWheelWithTicker(
time.Millisecond, timingWheelSlots, func(key, value any) {
clean(key, value)
}, ticker)
timingWheel.Store(tw)
assert.NoError(t, err)
t.Cleanup(func() {
timingWheel.Store(old)
})
next, ok := nextDelay(test.input)
assert.Equal(t, test.ok, ok)
assert.Equal(t, test.output, next)

View File

@@ -3,12 +3,11 @@ package mon
import (
"context"
"errors"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logx/logtest"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/bson"
@@ -573,15 +572,7 @@ func TestDecoratedCollection_LogDuration(t *testing.T) {
brk: breaker.NewBreaker(),
}
var buf strings.Builder
w := logx.NewWriter(&buf)
o := logx.Reset()
logx.SetWriter(w)
defer func() {
logx.Reset()
logx.SetWriter(o)
}()
buf := logtest.NewCollector(t)
buf.Reset()
c.logDuration(context.Background(), "foo", timex.Now(), nil, "bar")

View File

@@ -2,10 +2,10 @@ package mon
import (
"context"
"log"
"strings"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/mongo"
mopt "go.mongodb.org/mongo-driver/mongo/options"
@@ -39,10 +39,7 @@ type (
// MustNewModel returns a Model, exits on errors.
func MustNewModel(uri, db, collection string, opts ...Option) *Model {
model, err := NewModel(uri, db, collection, opts...)
if err != nil {
log.Fatal(err)
}
logx.Must(err)
return model
}

View File

@@ -3,12 +3,11 @@ package mon
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/logx/logtest"
)
func TestFormatAddrs(t *testing.T) {
@@ -40,15 +39,7 @@ func TestFormatAddrs(t *testing.T) {
}
func Test_logDuration(t *testing.T) {
var buf strings.Builder
w := logx.NewWriter(&buf)
o := logx.Reset()
logx.SetWriter(w)
defer func() {
logx.Reset()
logx.SetWriter(o)
}()
buf := logtest.NewCollector(t)
buf.Reset()
logDuration(context.Background(), "foo", "bar", time.Millisecond, nil)

View File

@@ -2,8 +2,8 @@ package monc
import (
"context"
"log"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"github.com/zeromicro/go-zero/core/stores/redis"
@@ -30,20 +30,14 @@ type Model struct {
// MustNewModel returns a Model with a cache cluster, exists on errors.
func MustNewModel(uri, db, collection string, c cache.CacheConf, opts ...cache.Option) *Model {
model, err := NewModel(uri, db, collection, c, opts...)
if err != nil {
log.Fatal(err)
}
logx.Must(err)
return model
}
// MustNewNodeModel returns a Model with a cache node, exists on errors.
func MustNewNodeModel(uri, db, collection string, rds *redis.Redis, opts ...cache.Option) *Model {
model, err := NewNodeModel(uri, db, collection, rds, opts...)
if err != nil {
log.Fatal(err)
}
logx.Must(err)
return model
}

View File

@@ -56,7 +56,7 @@ func (h hook) AfterProcess(ctx context.Context, cmd red.Cmder) error {
logDuration(ctx, []red.Cmder{cmd}, duration)
}
metricReqDur.Observe(int64(duration/time.Millisecond), cmd.Name())
metricReqDur.Observe(duration.Milliseconds(), cmd.Name())
if msg := formatError(err); len(msg) > 0 {
metricReqErr.Inc(cmd.Name(), msg)
}
@@ -103,7 +103,7 @@ func (h hook) AfterProcessPipeline(ctx context.Context, cmds []red.Cmder) error
logDuration(ctx, cmds, duration)
}
metricReqDur.Observe(int64(duration/time.Millisecond), "Pipeline")
metricReqDur.Observe(duration.Milliseconds(), "Pipeline")
if msg := formatError(batchError.Err()); len(msg) > 0 {
metricReqErr.Inc("Pipeline", msg)
}

View File

@@ -2,14 +2,18 @@ package redis
import (
"context"
"errors"
"io"
"log"
"net"
"strings"
"testing"
"time"
red "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx/logtest"
ztrace "github.com/zeromicro/go-zero/core/trace"
tracesdk "go.opentelemetry.io/otel/trace"
)
@@ -47,8 +51,7 @@ func TestHookProcessCase2(t *testing.T) {
})
defer ztrace.StopAgent()
w, restore := injectLog()
defer restore()
w := logtest.NewCollector(t)
ctx, err := durationHook.BeforeProcess(context.Background(), red.NewCmd(context.Background()))
if err != nil {
@@ -115,8 +118,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
})
defer ztrace.StopAgent()
w, restore := injectLog()
defer restore()
w := logtest.NewCollector(t)
ctx, err := durationHook.BeforeProcessPipeline(context.Background(), []red.Cmder{
red.NewCmd(context.Background()),
@@ -135,8 +137,7 @@ func TestHookProcessPipelineCase2(t *testing.T) {
}
func TestHookProcessPipelineCase3(t *testing.T) {
w, restore := injectLog()
defer restore()
w := logtest.NewCollector(t)
assert.Nil(t, durationHook.AfterProcessPipeline(context.Background(), []red.Cmder{
red.NewCmd(context.Background()),
@@ -145,8 +146,7 @@ func TestHookProcessPipelineCase3(t *testing.T) {
}
func TestHookProcessPipelineCase4(t *testing.T) {
w, restore := injectLog()
defer restore()
w := logtest.NewCollector(t)
ctx := context.WithValue(context.Background(), startTimeKey, "foo")
assert.Nil(t, durationHook.AfterProcessPipeline(ctx, []red.Cmder{
@@ -169,8 +169,7 @@ func TestHookProcessPipelineCase5(t *testing.T) {
}
func TestLogDuration(t *testing.T) {
w, restore := injectLog()
defer restore()
w := logtest.NewCollector(t)
logDuration(context.Background(), []red.Cmder{
red.NewCmd(context.Background(), "get", "foo"),
@@ -184,14 +183,39 @@ func TestLogDuration(t *testing.T) {
assert.True(t, strings.Contains(w.String(), `get foo\nset bar 0`))
}
func injectLog() (r *strings.Builder, restore func()) {
var buf strings.Builder
w := logx.NewWriter(&buf)
o := logx.Reset()
logx.SetWriter(w)
return &buf, func() {
logx.Reset()
logx.SetWriter(o)
func TestFormatError(t *testing.T) {
// Test case: err is OpError
err := &net.OpError{
Err: mockOpError{},
}
assert.Equal(t, "timeout", formatError(err))
// Test case: err is nil
assert.Equal(t, "", formatError(nil))
// Test case: err is red.Nil
assert.Equal(t, "", formatError(red.Nil))
// Test case: err is io.EOF
assert.Equal(t, "eof", formatError(io.EOF))
// Test case: err is context.DeadlineExceeded
assert.Equal(t, "context deadline", formatError(context.DeadlineExceeded))
// Test case: err is breaker.ErrServiceUnavailable
assert.Equal(t, "breaker", formatError(breaker.ErrServiceUnavailable))
// Test case: err is unknown
assert.Equal(t, "unexpected error", formatError(errors.New("some error")))
}
type mockOpError struct {
}
func (mockOpError) Error() string {
return "mock error"
}
func (mockOpError) Timeout() bool {
return true
}

View File

@@ -4,13 +4,13 @@ import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
red "github.com/go-redis/redis/v8"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/mapping"
"github.com/zeromicro/go-zero/core/syncx"
)
@@ -91,22 +91,19 @@ type (
Script = red.Script
)
// MustNewRedis returns a Redis with given options.
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
rds, err := NewRedis(conf, opts...)
logx.Must(err)
return rds
}
// New returns a Redis with given options.
// Deprecated: use MustNewRedis or NewRedis instead.
func New(addr string, opts ...Option) *Redis {
return newRedis(addr, opts...)
}
// MustNewRedis returns a Redis with given options.
func MustNewRedis(conf RedisConf, opts ...Option) *Redis {
rds, err := NewRedis(conf, opts...)
if err != nil {
log.Fatal(err)
}
return rds
}
// NewRedis returns a Redis with given options.
func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) {
if err := conf.Validate(); err != nil {

View File

@@ -3,6 +3,8 @@ package redis
import (
"testing"
"github.com/alicebob/miniredis/v2"
red "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/assert"
)
@@ -41,3 +43,17 @@ func TestSplitClusterAddrs(t *testing.T) {
})
}
}
func TestGetCluster(t *testing.T) {
r := miniredis.RunT(t)
defer r.Close()
c, err := getCluster(&Redis{
Addr: r.Addr(),
Type: ClusterType,
tls: true,
hooks: []red.Hook{durationHook},
})
if assert.NoError(t, err) {
assert.NotNil(t, c)
}
}

View File

@@ -97,6 +97,9 @@ func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
}
// ExecCtx runs given exec on given keys, and returns execution result.
// If DB operation succeeds, it will delete cache with given keys,
// if DB operation fails, it will return nil result and non-nil error,
// if DB operation succeeds but cache deletion fails, it will return result and non-nil error.
func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (
sql.Result, error) {
res, err := exec(ctx, cc.db)
@@ -104,11 +107,7 @@ func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string
return nil, err
}
if err := cc.DelCacheCtx(ctx, keys...); err != nil {
return nil, err
}
return res, nil
return res, cc.DelCacheCtx(ctx, keys...)
}
// ExecNoCache runs exec with given sql statement, without affecting cache.
@@ -214,6 +213,17 @@ func (cc CachedConn) SetCacheCtx(ctx context.Context, key string, val any) error
return cc.cache.SetCtx(ctx, key, val)
}
// SetCacheWithExpire sets v into cache with given key with given expire.
func (cc CachedConn) SetCacheWithExpire(key string, val any, expire time.Duration) error {
return cc.SetCacheWithExpireCtx(context.Background(), key, val, expire)
}
// SetCacheWithExpireCtx sets v into cache with given key with given expire.
func (cc CachedConn) SetCacheWithExpireCtx(ctx context.Context, key string, val any,
expire time.Duration) error {
return cc.cache.SetWithExpireCtx(ctx, key, val, expire)
}
// Transact runs given fn in transaction mode.
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
fnCtx := func(_ context.Context, session sqlx.Session) error {
@@ -226,3 +236,15 @@ func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
return cc.db.TransactCtx(ctx, fn)
}
// WithSession returns a new CachedConn with given session.
// If query from session, the uncommitted data might be returned.
// Don't query for the uncommitted data, you should just use it,
// and don't use the cache for the uncommitted data.
// Not recommend to use cache within transactions due to consistency problem.
func (cc CachedConn) WithSession(session sqlx.Session) CachedConn {
return CachedConn{
db: sqlx.NewSqlConnFromSession(session),
cache: cc.cache,
}
}

View File

@@ -15,6 +15,7 @@ import (
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fx"
@@ -24,6 +25,8 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/internal/dbtest"
)
func init() {
@@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) {
var value string
err := c.GetCache("any", &value)
assert.Equal(t, ErrNotFound, err)
r.Set("any", `"value"`)
_ = r.Set("any", `"value"`)
err = c.GetCache("any", &value)
assert.Nil(t, err)
assert.Equal(t, "value", value)
@@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) {
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func TestCachedConn_DelCache(t *testing.T) {
r := redistest.CreateRedis(t)
const (
key = "user"
value = "any"
)
assert.NoError(t, r.Set(key, value))
c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
err := c.DelCache(key)
assert.Nil(t, err)
val, err := r.Get(key)
assert.Nil(t, err)
assert.Empty(t, val)
}
func TestCachedConnQueryRow(t *testing.T) {
r := redistest.CreateRedis(t)
@@ -450,6 +471,36 @@ func TestCachedConnExec(t *testing.T) {
}
func TestCachedConnExecDropCache(t *testing.T) {
t.Run("drop cache", func(t *testing.T) {
r, err := miniredis.Run()
assert.Nil(t, err)
defer fx.DoWithTimeout(func() error {
r.Close()
return nil
}, time.Second)
const (
key = "user"
value = "any"
)
var conn trackedConn
c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
assert.Nil(t, c.SetCache(key, value))
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return conn.Exec("delete from user_table where id='kevin'")
}, key)
assert.Nil(t, err)
assert.True(t, conn.execValue)
_, err = r.Get(key)
assert.Exactly(t, miniredis.ErrKeyNotFound, err)
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return nil, errors.New("foo")
}, key)
assert.NotNil(t, err)
})
}
func TestCachedConn_SetCacheWithExpire(t *testing.T) {
r, err := miniredis.Run()
assert.Nil(t, err)
defer fx.DoWithTimeout(func() error {
@@ -463,18 +514,13 @@ func TestCachedConnExecDropCache(t *testing.T) {
)
var conn trackedConn
c := NewNodeConn(&conn, redis.New(r.Addr()), cache.WithExpiry(time.Second*30))
assert.Nil(t, c.SetCache(key, value))
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return conn.Exec("delete from user_table where id='kevin'")
}, key)
assert.Nil(t, err)
assert.True(t, conn.execValue)
_, err = r.Get(key)
assert.Exactly(t, miniredis.ErrKeyNotFound, err)
_, err = c.Exec(func(conn sqlx.SqlConn) (result sql.Result, e error) {
return nil, errors.New("foo")
}, key)
assert.NotNil(t, err)
assert.Nil(t, c.SetCacheWithExpire(key, value, time.Minute))
val, err := r.Get(key)
if assert.NoError(t, err) {
ttl := r.TTL(key)
assert.True(t, ttl > 0 && ttl <= time.Minute)
assert.Equal(t, fmt.Sprintf("%q", value), val)
}
}
func TestCachedConnExecDropCacheFailed(t *testing.T) {
@@ -543,6 +589,125 @@ func TestNewConnWithCache(t *testing.T) {
assert.True(t, conn.execValue)
}
func TestCachedConn_WithSession(t *testing.T) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
r := redistest.CreateRedis(t)
conn := CachedConn{
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
conn = conn.WithSession(session)
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.Error(t, conn.Transact(func(session sqlx.Session) error {
conn = conn.WithSession(session)
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "bar")
return err
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
var val string
conn = conn.WithSession(session)
err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
return conn.QueryRow(v, "any")
})
assert.Equal(t, "2", val)
return err
}))
val, err := r.Get("foo")
assert.NoError(t, err)
assert.Equal(t, `"2"`, val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
var val string
conn = conn.WithSession(session)
assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
return conn.QueryRow(v, "any")
}))
assert.Equal(t, "2", val)
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
return err
}))
val, err := r.Get("foo")
assert.NoError(t, err)
assert.Empty(t, val)
})
}
func resetStats() {
atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0)
@@ -554,35 +719,35 @@ type dummySqlConn struct {
queryRow func(any, string, ...any) error
}
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
return nil, nil
}
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
return nil
}
func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) {
func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
return nil, nil
}
@@ -597,15 +762,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args .
return nil
}
func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRows(v any, query string, args ...any) error {
func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/internal/dbtest"
)
type mockedConn struct {
@@ -81,7 +81,7 @@ func (c *mockedConn) Transact(func(session Session) error) error {
}
func TestBulkInserter(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
assert.Nil(t, err)
@@ -98,7 +98,7 @@ func TestBulkInserter(t *testing.T) {
}
func TestBulkInserterSuffix(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
@@ -119,7 +119,7 @@ func TestBulkInserterSuffix(t *testing.T) {
}
func TestBulkInserterBadStatement(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
_, err := NewBulkInserter(&conn, "foo")
assert.NotNil(t, err)
@@ -144,19 +144,3 @@ func TestBulkInserter_Update(t *testing.T) {
assert.NotNil(t, inserter.UpdateStmt("foo"))
assert.NotNil(t, inserter.Insert("foo", "bar"))
}
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

View File

@@ -0,0 +1,14 @@
package sqlx
import (
"database/sql"
"errors"
)
var (
// ErrNotFound is an alias of sql.ErrNoRows
ErrNotFound = sql.ErrNoRows
errCantNestTx = errors.New("cannot nest transactions")
errNoRawDBFromTx = errors.New("cannot get raw db from transaction")
)

View File

@@ -32,7 +32,5 @@ func mysqlAcceptable(err error) bool {
}
func withMysqlAcceptable() SqlOption {
return func(conn *commonSqlConn) {
conn.accept = mysqlAcceptable
}
return WithAcceptable(mysqlAcceptable)
}

View File

@@ -54,27 +54,39 @@ func getTaggedFieldValueMap(v reflect.Value) (map[string]any, error) {
}
valueField := reflect.Indirect(v).Field(i)
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
result[key] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
result[key] = valueField.Addr().Interface()
valueData, err := getValueInterface(valueField)
if err != nil {
return nil, err
}
result[key] = valueData
}
return result, nil
}
func getValueInterface(value reflect.Value) (any, error) {
switch value.Kind() {
case reflect.Ptr:
if !value.CanInterface() {
return nil, ErrNotReadableValue
}
if value.IsNil() {
baseValueType := mapping.Deref(value.Type())
value.Set(reflect.New(baseValueType))
}
return value.Interface(), nil
default:
if !value.CanAddr() || !value.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
return value.Addr().Interface(), nil
}
}
func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([]any, error) {
fields := unwrapFields(v)
if strict && len(columns) < len(fields) {
@@ -88,24 +100,18 @@ func mapStructFieldsIntoSlice(v reflect.Value, columns []string, strict bool) ([
values := make([]any, len(columns))
if len(taggedMap) == 0 {
if len(fields) < len(values) {
return nil, ErrNotMatchDestination
}
for i := 0; i < len(values); i++ {
valueField := fields[i]
switch valueField.Kind() {
case reflect.Ptr:
if !valueField.CanInterface() {
return nil, ErrNotReadableValue
}
if valueField.IsNil() {
baseValueType := mapping.Deref(valueField.Type())
valueField.Set(reflect.New(baseValueType))
}
values[i] = valueField.Interface()
default:
if !valueField.CanAddr() || !valueField.Addr().CanInterface() {
return nil, ErrNotReadableValue
}
values[i] = valueField.Addr().Interface()
valueData, err := getValueInterface(valueField)
if err != nil {
return nil, err
}
values[i] = valueData
}
} else {
for i, column := range columns {
@@ -140,7 +146,7 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
}
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
if err := mapping.ValidatePtr(rv); err != nil {
return err
}
@@ -152,11 +158,11 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
if rve.CanSet() {
return scanner.Scan(v)
if !rve.CanSet() {
return ErrNotSettable
}
return ErrNotSettable
return scanner.Scan(v)
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
@@ -176,76 +182,73 @@ func unmarshalRow(v any, scanner rowsScanner, strict bool) error {
func unmarshalRows(v any, scanner rowsScanner, strict bool) error {
rv := reflect.ValueOf(v)
if err := mapping.ValidatePtr(&rv); err != nil {
if err := mapping.ValidatePtr(rv); err != nil {
return err
}
rt := reflect.TypeOf(v)
rte := rt.Elem()
rve := rv.Elem()
if !rve.CanSet() {
return ErrNotSettable
}
switch rte.Kind() {
case reflect.Slice:
if rve.CanSet() {
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
ptr := rte.Elem().Kind() == reflect.Ptr
appendFn := func(item reflect.Value) {
if ptr {
rve.Set(reflect.Append(rve, item))
} else {
rve.Set(reflect.Append(rve, reflect.Indirect(item)))
}
fillFn := func(value any) error {
if rve.CanSet() {
if err := scanner.Scan(value); err != nil {
return err
}
appendFn(reflect.ValueOf(value))
return nil
}
return ErrNotSettable
}
fillFn := func(value any) error {
if err := scanner.Scan(value); err != nil {
return err
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
appendFn(reflect.ValueOf(value))
return nil
}
base := mapping.Deref(rte.Elem())
switch base.Kind() {
case reflect.Bool,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String:
for scanner.Next() {
value := reflect.New(base)
if err := fillFn(value.Interface()); err != nil {
return err
}
case reflect.Struct:
columns, err := scanner.Columns()
}
case reflect.Struct:
columns, err := scanner.Columns()
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
values, err := mapStructFieldsIntoSlice(value, columns, strict)
if err != nil {
return err
}
for scanner.Next() {
value := reflect.New(base)
values, err := mapStructFieldsIntoSlice(value, columns, strict)
if err != nil {
return err
}
if err := scanner.Scan(values...); err != nil {
return err
}
appendFn(value)
if err := scanner.Scan(values...); err != nil {
return err
}
default:
return ErrUnsupportedValueType
}
return nil
appendFn(value)
}
default:
return ErrUnsupportedValueType
}
return ErrNotSettable
return nil
default:
return ErrUnsupportedValueType
}
@@ -257,6 +260,10 @@ func unwrapFields(v reflect.Value) []reflect.Value {
for i := 0; i < indirect.NumField(); i++ {
child := indirect.Field(i)
if !child.CanSet() {
continue
}
if child.Kind() == reflect.Ptr && child.IsNil() {
baseValueType := mapping.Deref(child.Type())
child.Set(reflect.New(baseValueType))

View File

@@ -8,11 +8,11 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/internal/dbtest"
)
func TestUnmarshalRowBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -22,10 +22,22 @@ func TestUnmarshalRowBool(t *testing.T) {
}, "select value from users where user=?", "anyone"))
assert.True(t, value)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value struct {
Value bool `db:"value"`
}
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
}
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -37,7 +49,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
}
func TestUnmarshalRowInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -50,7 +62,7 @@ func TestUnmarshalRowInt(t *testing.T) {
}
func TestUnmarshalRowInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -63,7 +75,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
}
func TestUnmarshalRowInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -76,7 +88,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
}
func TestUnmarshalRowInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -89,7 +101,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
}
func TestUnmarshalRowInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -102,7 +114,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
}
func TestUnmarshalRowUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -115,7 +127,7 @@ func TestUnmarshalRowUint(t *testing.T) {
}
func TestUnmarshalRowUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -128,7 +140,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
}
func TestUnmarshalRowUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -141,7 +153,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
}
func TestUnmarshalRowUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -154,7 +166,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
}
func TestUnmarshalRowUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -167,7 +179,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
}
func TestUnmarshalRowFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -180,7 +192,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
}
func TestUnmarshalRowFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -193,7 +205,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
}
func TestUnmarshalRowString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
const expect = "hello"
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -207,12 +219,12 @@ func TestUnmarshalRowString(t *testing.T) {
}
func TestUnmarshalRowStruct(t *testing.T) {
value := new(struct {
Name string
Age int
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
Age int
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -222,15 +234,58 @@ func TestUnmarshalRowStruct(t *testing.T) {
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
Age int
})
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, &mockedScanner{
colErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
age *int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
type myString chan int
var value myString
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
})
}
func TestUnmarshalRowStructWithTags(t *testing.T) {
value := new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Age int `db:"age"`
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -240,6 +295,51 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
assert.Equal(t, "liao", value.Name)
assert.Equal(t, 5, value.Age)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
age *int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value struct {
Age *int `db:"age"`
Name *string `db:"name"`
}
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", *value.Name)
assert.Equal(t, 5, *value.Age)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Age int `db:"age"`
Name string
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, 5, value.Age)
})
}
func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
@@ -248,7 +348,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -259,7 +359,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
}
func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []bool{true, false}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -270,10 +370,46 @@ func TestUnmarshalRowsBool(t *testing.T) {
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value struct {
value []bool `db:"value"`
}
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"), ErrUnsupportedValueType)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
errAny := errors.New("any")
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
scanErr: errAny,
next: 1,
}, true)
}, "select value from users where user=?", "anyone"), errAny)
})
}
func TestUnmarshalRowsInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -287,7 +423,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
}
func TestUnmarshalRowsInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -301,7 +437,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
}
func TestUnmarshalRowsInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -315,7 +451,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
}
func TestUnmarshalRowsInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -329,7 +465,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
}
func TestUnmarshalRowsInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -343,7 +479,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
}
func TestUnmarshalRowsUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -357,7 +493,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
}
func TestUnmarshalRowsUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -371,7 +507,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
}
func TestUnmarshalRowsUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -385,7 +521,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
}
func TestUnmarshalRowsUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -399,7 +535,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
}
func TestUnmarshalRowsUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -413,7 +549,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
}
func TestUnmarshalRowsFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []float32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -427,7 +563,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
}
func TestUnmarshalRowsFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []float64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -441,7 +577,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
}
func TestUnmarshalRowsString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []string{"hello", "world"}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -457,7 +593,7 @@ func TestUnmarshalRowsString(t *testing.T) {
func TestUnmarshalRowsBoolPtr(t *testing.T) {
yes := true
no := false
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*bool{&yes, &no}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -473,7 +609,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
func TestUnmarshalRowsIntPtr(t *testing.T) {
two := 2
three := 3
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -489,7 +625,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
two := int8(2)
three := int8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -505,7 +641,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
two := int16(2)
three := int16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -521,7 +657,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
two := int32(2)
three := int32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -537,7 +673,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
two := int64(2)
three := int64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -553,7 +689,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
func TestUnmarshalRowsUintPtr(t *testing.T) {
two := uint(2)
three := uint(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -569,7 +705,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
two := uint8(2)
three := uint8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -585,7 +721,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
two := uint16(2)
three := uint16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -601,7 +737,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
two := uint32(2)
three := uint32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -617,7 +753,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
two := uint64(2)
three := uint64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -633,7 +769,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
two := float32(2)
three := float32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*float32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -649,7 +785,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
two := float64(2)
three := float64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*float64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -665,7 +801,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
func TestUnmarshalRowsStringPtr(t *testing.T) {
hello := "hello"
world := "world"
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*string{&hello, &world}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -679,25 +815,25 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
}
func TestUnmarshalRowsStruct(t *testing.T) {
expect := []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []struct {
Name string
Age int64
}{
{
Name: "first",
Age: 2,
},
{
Name: "second",
Age: 3,
},
}
var value []struct {
Name string
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -709,6 +845,56 @@ func TestUnmarshalRowsStruct(t *testing.T) {
assert.Equal(t, each.Age, value[i].Age)
}
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []struct {
Name string
Age int64
}
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
colErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []struct {
Name string
Age int64
}
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
cols: []string{"name", "age"},
scanErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), errAny)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value []chan int
errAny := errors.New("any error")
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, &mockedScanner{
cols: []string{"name", "age"},
scanErr: errAny,
next: 1,
}, true)
}, "select name, age from users where user=?", "anyone"), ErrUnsupportedValueType)
})
}
func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
@@ -736,7 +922,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
NullString sql.NullString `db:"value"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -771,7 +957,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -812,7 +998,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -854,7 +1040,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
*Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -888,7 +1074,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -921,7 +1107,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -954,7 +1140,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -969,7 +1155,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
}
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -1019,7 +1205,7 @@ func TestUnmarshalRowError(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
"anyone").WillReturnRows(rs)
@@ -1091,7 +1277,7 @@ func TestAnonymousStructPr(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{
"name",
"age",
@@ -1139,7 +1325,7 @@ func TestAnonymousStructPrError(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{
"name",
"age",
@@ -1154,7 +1340,7 @@ func TestAnonymousStructPrError(t *testing.T) {
WithArgs("anyone").WillReturnRows(rs)
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age,grade,discipline,class_name,score from users where user=?",
}, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone"))
if len(value) > 0 {
assert.Equal(t, value[0].score, 0)
@@ -1162,23 +1348,8 @@ func TestAnonymousStructPrError(t *testing.T) {
})
}
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
type mockedScanner struct {
cols []string
colErr error
scanErr error
err error
@@ -1186,7 +1357,7 @@ type mockedScanner struct {
}
func (m *mockedScanner) Columns() ([]string, error) {
return nil, m.colErr
return m.cols, m.colErr
}
func (m *mockedScanner) Err() error {

View File

@@ -11,9 +11,6 @@ import (
// spanName is used to identify the span name for the SQL execution.
const spanName = "sql"
// ErrNotFound is an alias of sql.ErrNoRows
var ErrNotFound = sql.ErrNoRows
type (
// Session stands for raw connections or transaction sessions
Session interface {
@@ -131,6 +128,13 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
return conn
}
// NewSqlConnFromSession returns a SqlConn with the given session.
func NewSqlConnFromSession(session Session) SqlConn {
return txConn{
Session: session,
}
}
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
return db.ExecCtx(context.Background(), q, args...)
}
@@ -287,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
if db.accept == nil {
return ok
if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
return true
}
return ok || db.accept(err)
if _, ok := err.(acceptableError); ok {
return true
}
if db.accept == nil {
return false
}
return db.accept(err)
}
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
@@ -395,3 +406,11 @@ func (s statement) QueryRowsPartialCtx(ctx context.Context, v any, args ...any)
return unmarshalRows(v, rows, false)
}, s.query, args...)
}
// WithAcceptable returns a SqlOption that setting the acceptable function.
// acceptable is the func to check if the error can be accepted.
func WithAcceptable(acceptable func(err error) bool) SqlOption {
return func(conn *commonSqlConn) {
conn.accept = acceptable
}
}

View File

@@ -2,13 +2,16 @@ package sqlx
import (
"database/sql"
"errors"
"io"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/trace/tracetest"
"github.com/zeromicro/go-zero/internal/dbtest"
)
const mockedDatasource = "sqlmock"
@@ -54,8 +57,214 @@ func TestSqlConn(t *testing.T) {
assert.Equal(t, 14, len(me.GetSpans()))
}
func TestSqlConn_RawDB(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
conn := NewSqlConnFromDB(db)
var val string
assert.NoError(t, conn.QueryRow(&val, "any"))
assert.Equal(t, "bar", val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
conn := NewSqlConnFromDB(db)
var val string
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
assert.Equal(t, "bar", val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
conn := NewSqlConnFromDB(db)
var vals []string
assert.NoError(t, conn.QueryRows(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
conn := NewSqlConnFromDB(db)
var vals []string
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
}
func TestSqlConn_Errors(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db)
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
return nil, errors.New("error")
}
_, err := conn.Prepare("any")
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnError(breaker.ErrServiceUnavailable)
conn := NewSqlConnFromDB(db)
_, err := conn.Exec("any")
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillReturnError(breaker.ErrServiceUnavailable)
conn := NewSqlConnFromDB(db)
_, err := conn.Prepare("any")
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
return breaker.ErrServiceUnavailable
})
assert.Equal(t, breaker.ErrServiceUnavailable, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectQuery("any").WillReturnError(breaker.ErrServiceUnavailable)
conn := NewSqlConnFromDB(db)
var vals []string
err := conn.QueryRows(&vals, "any")
assert.Equal(t, breaker.ErrServiceUnavailable, err)
})
}
func TestStatement(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillBeClosed()
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
assert.NoError(t, stmt.Close())
})
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any").WillBeClosed()
stmt, err := tx.Prepare("any")
assert.NoError(t, err)
st := statement{
query: "foo",
stmt: stmt,
}
assert.NoError(t, st.Close())
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
res, err := stmt.Exec()
assert.NoError(t, err)
lastInsertID, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), lastInsertID)
rowsAffected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), rowsAffected)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
var val string
err = stmt.QueryRow(&val)
assert.NoError(t, err)
assert.Equal(t, "bar", val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(row)
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
var val string
err = stmt.QueryRowPartial(&val)
assert.NoError(t, err)
assert.Equal(t, "bar", val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
var vals []string
assert.NoError(t, stmt.QueryRows(&vals))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
conn := NewSqlConnFromDB(db)
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
var vals []string
assert.NoError(t, stmt.QueryRowsPartial(&vals))
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
})
}
func TestBreakerWithFormatError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
var val string
if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
conn.QueryRow(&val, "any ?, ?", "foo")) {
break
}
}
})
}
func TestBreakerWithScanError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val int
if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
break
}
}
})
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()

View File

@@ -136,7 +136,7 @@ func (e *realSqlGuard) finish(ctx context.Context, err error) {
logSqlError(ctx, e.stmt, err)
}
metricReqDur.Observe(int64(duration/time.Millisecond), e.command)
metricReqDur.Observe(duration.Milliseconds(), e.command)
}
func (e *realSqlGuard) start(q string, args ...any) error {

View File

@@ -15,11 +15,27 @@ type (
Rollback() error
}
txConn struct {
Session
}
txSession struct {
*sql.Tx
}
)
func (s txConn) RawDB() (*sql.DB, error) {
return nil, errNoRawDBFromTx
}
func (s txConn) Transact(_ func(Session) error) error {
return errCantNestTx
}
func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session) error) error {
return errCantNestTx
}
// NewSessionFromTx returns a Session with the given sql.Tx.
// Use it with caution, it's provided for other ORM to interact with.
func NewSessionFromTx(tx *sql.Tx) Session {

View File

@@ -6,7 +6,10 @@ import (
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/internal/dbtest"
)
const (
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
return nil
}
func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err)
}
func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
assert.NoError(t, conn.Transact(func(session Session) error {
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return nil, errors.New("foo")
},
beginTx: begin,
onError: func(ctx context.Context, err error) {},
brk: breaker.NewBreaker(),
}
assert.Error(t, conn.Transact(func(session Session) error {
return nil
}))
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
_, err := conn.RawDB()
assert.Equal(t, errNoRawDBFromTx, err)
assert.Equal(t, errCantNestTx, conn.Transact(nil))
assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
return errors.New("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback().WillReturnError(errors.New("foo"))
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic(errors.New("foo"))
}))
})
}
func TestTxSession(t *testing.T) {
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
res, err := conn.Exec("any")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
_, err = conn.Exec("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
assert.NotNil(t, stmt)
mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
_, err = conn.Prepare("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRow(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRow(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRowPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRows(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRows(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRowsPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowsPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
}
func TestTxRollback(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
return c.QueryRow(&val, "foo")
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
if _, err := c.Exec("any"); err != nil {
return err
}
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
return nil
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
assert.Equal(t, "bar", val)
return nil
})
assert.NoError(t, err)
})
}
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
sess := NewSessionFromTx(tx)
conn := NewSqlConnFromSession(sess)
f(conn, mock)
})
}

View File

@@ -51,7 +51,13 @@ func escape(input string) string {
return b.String()
}
func format(query string, args ...any) (string, error) {
func format(query string, args ...any) (val string, err error) {
defer func() {
if err != nil {
err = newAcceptableError(err)
}
}()
numArgs := len(args)
if numArgs == 0 {
return query, nil
@@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) {
switch ch {
case '?':
if argIndex >= numArgs {
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
argIndex+1, numArgs)
}
writeValue(&b, args[argIndex])
@@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) {
buf.WriteString(mapping.Repr(v))
}
}
type acceptableError struct {
err error
}
func newAcceptableError(err error) error {
return acceptableError{
err: err,
}
}
func (e acceptableError) Error() string {
return e.err.Error()
}

View File

@@ -20,12 +20,14 @@ func ForAtomicBool(val bool) *AtomicBool {
// CompareAndSwap compares current value with given old, if equals, set to given val.
func (b *AtomicBool) CompareAndSwap(old, val bool) bool {
var ov, nv uint32
if old {
ov = 1
}
if val {
nv = 1
}
return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv)
}

View File

@@ -42,7 +42,8 @@ func (manager *ResourceManager) Close() error {
}
// GetResource returns the resource associated with given key.
func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (io.Closer, error) {
func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (
io.Closer, error) {
val, err := manager.singleFlight.Do(key, func() (any, error) {
manager.lock.RLock()
resource, ok := manager.resources[key]

View File

@@ -9,25 +9,44 @@ import (
)
func TestTimeoutLimit(t *testing.T) {
limit := NewTimeoutLimit(2)
assert.Nil(t, limit.Borrow(time.Millisecond*200))
assert.Nil(t, limit.Borrow(time.Millisecond*200))
var wait1, wait2, wait3 sync.WaitGroup
wait1.Add(1)
wait2.Add(1)
wait3.Add(1)
go func() {
wait1.Wait()
wait2.Done()
assert.Nil(t, limit.Return())
wait3.Done()
}()
wait1.Done()
wait2.Wait()
assert.Nil(t, limit.Borrow(time.Second))
wait3.Wait()
assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100))
assert.Nil(t, limit.Return())
assert.Nil(t, limit.Return())
assert.Equal(t, ErrLimitReturn, limit.Return())
tests := []struct {
name string
interval time.Duration
}{
{
name: "no wait",
},
{
name: "wait",
interval: time.Millisecond * 100,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
limit := NewTimeoutLimit(2)
assert.Nil(t, limit.Borrow(time.Millisecond*200))
assert.Nil(t, limit.Borrow(time.Millisecond*200))
var wait1, wait2, wait3 sync.WaitGroup
wait1.Add(1)
wait2.Add(1)
wait3.Add(1)
go func() {
wait1.Wait()
wait2.Done()
time.Sleep(test.interval)
assert.Nil(t, limit.Return())
wait3.Done()
}()
wait1.Done()
wait2.Wait()
assert.Nil(t, limit.Borrow(time.Second))
wait3.Wait()
assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100))
assert.Nil(t, limit.Return())
assert.Nil(t, limit.Return())
assert.Equal(t, ErrLimitReturn, limit.Return())
})
}
}

View File

@@ -2,6 +2,7 @@ package threading
import (
"bytes"
"context"
"runtime"
"strconv"
@@ -13,6 +14,11 @@ func GoSafe(fn func()) {
go RunSafe(fn)
}
// GoSafeCtx runs the given fn using another goroutine, recovers if fn panics with ctx.
func GoSafeCtx(ctx context.Context, fn func()) {
go RunSafeCtx(ctx, fn)
}
// RoutineId is only for debug, never use it in production.
func RoutineId() uint64 {
b := make([]byte, 64)
@@ -31,3 +37,10 @@ func RunSafe(fn func()) {
fn()
}
// RunSafeCtx runs the given fn, recovers if fn panics with ctx.
func RunSafeCtx(ctx context.Context, fn func()) {
defer rescue.RecoverCtx(ctx)
fn()
}

Some files were not shown because too many files have changed in this diff Show More