mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-12 01:10:00 +08:00
Compare commits
164 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bd95aa007 | ||
|
|
e8376936d5 | ||
|
|
71c0288023 | ||
|
|
9e2f07a842 | ||
|
|
24fd34413f | ||
|
|
3f47251892 | ||
|
|
0b6bc69afa | ||
|
|
5b9bdc8d02 | ||
|
|
ded22e296e | ||
|
|
f0ed2370a3 | ||
|
|
6bf6cfdd01 | ||
|
|
5cc9eb0de4 | ||
|
|
f070d447ef | ||
|
|
f6d9e19ecb | ||
|
|
56807aabf6 | ||
|
|
861dcf2f36 | ||
|
|
c837dc21bb | ||
|
|
96a35ecf1a | ||
|
|
bdec5f2349 | ||
|
|
bc92b57bdb | ||
|
|
d8905b9e9e | ||
|
|
dec6309c55 | ||
|
|
10805577f5 | ||
|
|
a4d8286e36 | ||
|
|
84d2b64e7c | ||
|
|
6476da4a18 | ||
|
|
79eab0ea2f | ||
|
|
3b683fd498 | ||
|
|
d179b342b2 | ||
|
|
58874779e7 | ||
|
|
8829c31c0d | ||
|
|
b42f3fa047 | ||
|
|
9bdadf2381 | ||
|
|
20f665ede8 | ||
|
|
0325d8e92d | ||
|
|
2125977281 | ||
|
|
c26c187e11 | ||
|
|
4ef1859f0b | ||
|
|
407a6cbf9c | ||
|
|
76fc1ef460 | ||
|
|
423955c55f | ||
|
|
db95b3f0e3 | ||
|
|
4bee60eb7f | ||
|
|
7618139dad | ||
|
|
6fd08027ff | ||
|
|
b9e268aae8 | ||
|
|
4c1bb1148b | ||
|
|
50a6bbe6b9 | ||
|
|
dfb3cb510a | ||
|
|
519db812b4 | ||
|
|
3203f8e06b | ||
|
|
b71ac2042a | ||
|
|
d0f9e57022 | ||
|
|
aa68210cde | ||
|
|
280e837c9e | ||
|
|
f669e1226c | ||
|
|
cd15c19250 | ||
|
|
5b35fa17de | ||
|
|
9672298fa8 | ||
|
|
bf3ce16823 | ||
|
|
189721da16 | ||
|
|
a523ab1f93 | ||
|
|
7ea8b636d9 | ||
|
|
b2fea65faa | ||
|
|
a1fe8bf6cd | ||
|
|
67ee9e4391 | ||
|
|
9c1ee50497 | ||
|
|
7c842f22d0 | ||
|
|
14ec29991c | ||
|
|
c7f5aad83a | ||
|
|
e77747cff8 | ||
|
|
f2612db4b1 | ||
|
|
a21ff71373 | ||
|
|
fc04ad7854 | ||
|
|
fbf2eebc42 | ||
|
|
dc43430812 | ||
|
|
c6642bc2e6 | ||
|
|
bdca24dd3b | ||
|
|
00c5734021 | ||
|
|
33f87cf1f0 | ||
|
|
69935c1ba3 | ||
|
|
1fb356f328 | ||
|
|
0b0406f41a | ||
|
|
cc264dcf55 | ||
|
|
e024aebb66 | ||
|
|
f204729482 | ||
|
|
d20cf56a69 | ||
|
|
54d57c7d4b | ||
|
|
28a7c9d38f | ||
|
|
872e75e10d | ||
|
|
af1730079e | ||
|
|
04521e2d24 | ||
|
|
02adcccbf4 | ||
|
|
a74aaf1823 | ||
|
|
1eb2089c69 | ||
|
|
f7f3730e1a | ||
|
|
0ee7654407 | ||
|
|
16cc990fdd | ||
|
|
00061c2e5b | ||
|
|
6793f7a1de | ||
|
|
c8428a7f65 | ||
|
|
a5e1d0d0dc | ||
|
|
8270c7deed | ||
|
|
9f4a882a1b | ||
|
|
cb7b7cb72e | ||
|
|
603c93aa4a | ||
|
|
cb8d9d413a | ||
|
|
ff7443c6a7 | ||
|
|
b812e74d6f | ||
|
|
089cdaa75f | ||
|
|
476026e393 | ||
|
|
75952308f9 | ||
|
|
df0550d6dc | ||
|
|
e481b63b21 | ||
|
|
e47079f0f4 | ||
|
|
9b2a279948 | ||
|
|
db87fd3239 | ||
|
|
598fda0c97 | ||
|
|
b0e335e7b0 | ||
|
|
efdf475da4 | ||
|
|
22a1315136 | ||
|
|
5b22823018 | ||
|
|
9ccb997ed8 | ||
|
|
01c92a6bc5 | ||
|
|
c9a2a60e28 | ||
|
|
b0739d63c0 | ||
|
|
c22f84cb5f | ||
|
|
60450bab02 | ||
|
|
3e8cec5c78 | ||
|
|
74ee163761 | ||
|
|
ea4f680052 | ||
|
|
58cdba2c5d | ||
|
|
a2fbc14c70 | ||
|
|
158df8c270 | ||
|
|
30ec236a87 | ||
|
|
ac3653b3f9 | ||
|
|
8520db4fd9 | ||
|
|
14141fed62 | ||
|
|
5d86cc2f20 | ||
|
|
8a6e4b7580 | ||
|
|
453f949638 | ||
|
|
75a330184d | ||
|
|
546fcd8bab | ||
|
|
3022f93b6d | ||
|
|
8ffc392c66 | ||
|
|
ae7d85dadf | ||
|
|
e89268ac37 | ||
|
|
aaa3623404 | ||
|
|
8998f16054 | ||
|
|
94417be018 | ||
|
|
f300408fc0 | ||
|
|
aaa39e17a3 | ||
|
|
73906f996d | ||
|
|
73417f54db | ||
|
|
491213afb8 | ||
|
|
edf743cd72 | ||
|
|
78a88be787 | ||
|
|
9f6a574f97 | ||
|
|
ea01cc78f0 | ||
|
|
a87978568a | ||
|
|
14cecb9b31 | ||
|
|
0ce54100a4 | ||
|
|
d28ac35ff7 | ||
|
|
a5962f677f |
@@ -1,4 +1,3 @@
|
||||
comment: false
|
||||
ignore:
|
||||
- "doc"
|
||||
- "example"
|
||||
- "tools"
|
||||
- "tools"
|
||||
6
.github/workflows/go.yml
vendored
6
.github/workflows/go.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ^1.13
|
||||
go-version: ^1.14
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
@@ -29,6 +29,4 @@ jobs:
|
||||
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
||||
|
||||
- name: Codecov
|
||||
uses: codecov/codecov-action@v1.0.6
|
||||
with:
|
||||
token: ${{secrets.CODECOV_TOKEN}}
|
||||
uses: codecov/codecov-action@v2
|
||||
|
||||
19
.github/workflows/issues.yml
vendored
Normal file
19
.github/workflows/issues.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "30 1 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v3
|
||||
with:
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 14
|
||||
stale-issue-label: "stale"
|
||||
stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
|
||||
close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -10,10 +10,14 @@
|
||||
!*/
|
||||
!api
|
||||
|
||||
# ignore
|
||||
.idea
|
||||
**/.DS_Store
|
||||
**/logs
|
||||
|
||||
# ignore adhoc test code
|
||||
**/adhoc
|
||||
|
||||
# gitlab ci
|
||||
.cache
|
||||
|
||||
|
||||
102
CONTRIBUTING.md
Normal file
102
CONTRIBUTING.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Contributing
|
||||
|
||||
Welcome to go-zero!
|
||||
|
||||
- [Before you get started](#before-you-get-started)
|
||||
- [Code of Conduct](#code-of-conduct)
|
||||
- [Community Expectations](#community-expectations)
|
||||
- [Getting started](#getting-started)
|
||||
- [Your First Contribution](#your-first-contribution)
|
||||
- [Find something to work on](#find-something-to-work-on)
|
||||
- [Find a good first topic](#find-a-good-first-topic)
|
||||
- [Work on an Issue](#work-on-an-issue)
|
||||
- [File an Issue](#file-an-issue)
|
||||
- [Contributor Workflow](#contributor-workflow)
|
||||
- [Creating Pull Requests](#creating-pull-requests)
|
||||
- [Code Review](#code-review)
|
||||
- [Testing](#testing)
|
||||
|
||||
# Before you get started
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
Please make sure to read and observe our [Code of Conduct](/code-of-conduct.md).
|
||||
|
||||
## Community Expectations
|
||||
|
||||
go-zero is a community project driven by its community which strives to promote a healthy, friendly and productive environment.
|
||||
go-zero is a web and rpc framework written in Go. It's born to ensure the stability of the busy sites with resilient design. Builtin goctl greatly improves the development productivity.
|
||||
|
||||
# Getting started
|
||||
|
||||
- Fork the repository on GitHub.
|
||||
- Make your changes on your fork repository.
|
||||
- Submit a PR.
|
||||
|
||||
|
||||
# Your First Contribution
|
||||
|
||||
We will help you to contribute in different areas like filing issues, developing features, fixing critical bugs and
|
||||
getting your work reviewed and merged.
|
||||
|
||||
If you have questions about the development process,
|
||||
feel free to [file an issue](https://github.com/tal-tech/go-zero/issues/new/choose).
|
||||
|
||||
## Find something to work on
|
||||
|
||||
We are always in need of help, be it fixing documentation, reporting bugs or writing some code.
|
||||
Look at places where you feel best coding practices aren't followed, code refactoring is needed or tests are missing.
|
||||
Here is how you get started.
|
||||
|
||||
### Find a good first topic
|
||||
|
||||
[go-zero](https://github.com/tal-tech/go-zero) has beginner-friendly issues that provide a good first issue.
|
||||
For example, [go-zero](https://github.com/tal-tech/go-zero) has
|
||||
[help wanted](https://github.com/tal-tech/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) and
|
||||
[good first issue](https://github.com/tal-tech/go-zero/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
|
||||
labels for issues that should not need deep knowledge of the system.
|
||||
We can help new contributors who wish to work on such issues.
|
||||
|
||||
Another good way to contribute is to find a documentation improvement, such as a missing/broken link.
|
||||
Please see [Contributing](#contributing) below for the workflow.
|
||||
|
||||
#### Work on an issue
|
||||
|
||||
When you are willing to take on an issue, just reply on the issue. The maintainer will assign it to you.
|
||||
|
||||
### File an Issue
|
||||
|
||||
While we encourage everyone to contribute code, it is also appreciated when someone reports an issue.
|
||||
|
||||
Please follow the prompted submission guidelines while opening an issue.
|
||||
|
||||
# Contributor Workflow
|
||||
|
||||
Please do not ever hesitate to ask a question or send a pull request.
|
||||
|
||||
This is a rough outline of what a contributor's workflow looks like:
|
||||
|
||||
- Create a topic branch from where to base the contribution. This is usually master.
|
||||
- Make commits of logical units.
|
||||
- Push changes in a topic branch to a personal fork of the repository.
|
||||
- Submit a pull request to [go-zero](https://github.com/tal-tech/go-zero).
|
||||
|
||||
## Creating Pull Requests
|
||||
|
||||
Pull requests are often called simply "PR".
|
||||
go-zero generally follows the standard [github pull request](https://help.github.com/articles/about-pull-requests/) process.
|
||||
To submit a proposed change, please develop the code/fix and add new test cases.
|
||||
After that, run these local verifications before submitting pull request to predict the pass or
|
||||
fail of continuous integration.
|
||||
|
||||
* Format the code with `gofmt`
|
||||
* Run the test with data race enabled `go test -race ./...`
|
||||
|
||||
## Code Review
|
||||
|
||||
To make it easier for your PR to receive reviews, consider the reviewers will need you to:
|
||||
|
||||
* follow [good coding guidelines](https://github.com/golang/go/wiki/CodeReviewComments).
|
||||
* write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
* break large changes into a logical series of smaller patches which individually make easily understandable changes, and in aggregate solve a broader issue.
|
||||
|
||||
21
ROADMAP.md
Normal file
21
ROADMAP.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# go-zero Roadmap
|
||||
|
||||
This document defines a high level roadmap for go-zero development and upcoming releases.
|
||||
Community and contributor involvement is vital for successfully implementing all desired items for each release.
|
||||
We hope that the items listed below will inspire further engagement from the community to keep go-zero progressing and shipping exciting and valuable features.
|
||||
|
||||
## 2021 Q2
|
||||
- Support TLS in redis connections
|
||||
- Support service discovery through K8S watch api
|
||||
- Log full sql statements for easier sql problem solving
|
||||
|
||||
## 2021 Q3
|
||||
- Support `goctl mock` command to start a mocking server with given `.api` file
|
||||
- Adapt builtin tracing mechanism to opentracing solutions
|
||||
- Support `goctl model pg` to support PostgreSQL code generation
|
||||
|
||||
## 2021 Q4
|
||||
- Support `goctl doctor` command to report potential issues for given service
|
||||
- Support `context` in redis related methods for timeout and tracing
|
||||
- Support `context` in sql related methods for timeout and tracing
|
||||
- Support `context` in mongodb related methods for timeout and tracing
|
||||
76
code-of-conduct.md
Normal file
76
code-of-conduct.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to make participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all project spaces, and it also applies when
|
||||
an individual is representing the project or its community in public spaces.
|
||||
Examples of representing a project or community include using an official
|
||||
project e-mail address, posting via an official social media account, or acting
|
||||
as an appointed representative at an online or offline event. Representation of
|
||||
a project may be further defined and clarified by project maintainers.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at [INSERT EMAIL ADDRESS]. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
||||
@@ -94,7 +94,7 @@ func (b *googleBreaker) markFailure() {
|
||||
b.stat.Add(0)
|
||||
}
|
||||
|
||||
func (b *googleBreaker) history() (accepts int64, total int64) {
|
||||
func (b *googleBreaker) history() (accepts, total int64) {
|
||||
b.stat.Reduce(func(b *collection.Bucket) {
|
||||
accepts += int64(b.Sum)
|
||||
total += b.Count
|
||||
|
||||
@@ -34,7 +34,7 @@ type (
|
||||
expire time.Duration
|
||||
timingWheel *TimingWheel
|
||||
lruCache lru
|
||||
barrier syncx.SharedCalls
|
||||
barrier syncx.SingleFlight
|
||||
unstableExpiry mathx.Unstable
|
||||
stats *cacheStat
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func NewCache(expire time.Duration, opts ...CacheOption) (*Cache, error) {
|
||||
data: make(map[string]interface{}),
|
||||
expire: expire,
|
||||
lruCache: emptyLruCache,
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
}
|
||||
|
||||
|
||||
@@ -106,9 +106,7 @@ func (s *Set) KeysInt() []int {
|
||||
var keys []int
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int); !ok {
|
||||
continue
|
||||
} else {
|
||||
if intKey, ok := key.(int); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
@@ -121,9 +119,7 @@ func (s *Set) KeysInt64() []int64 {
|
||||
var keys []int64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int64); !ok {
|
||||
continue
|
||||
} else {
|
||||
if intKey, ok := key.(int64); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
@@ -136,9 +132,7 @@ func (s *Set) KeysUint() []uint {
|
||||
var keys []uint
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint); !ok {
|
||||
continue
|
||||
} else {
|
||||
if intKey, ok := key.(uint); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
@@ -151,9 +145,7 @@ func (s *Set) KeysUint64() []uint64 {
|
||||
var keys []uint64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint64); !ok {
|
||||
continue
|
||||
} else {
|
||||
if intKey, ok := key.(uint64); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
@@ -166,9 +158,7 @@ func (s *Set) KeysStr() []string {
|
||||
var keys []string
|
||||
|
||||
for key := range s.data {
|
||||
if strKey, ok := key.(string); !ok {
|
||||
continue
|
||||
} else {
|
||||
if strKey, ok := key.(string); ok {
|
||||
keys = append(keys, strKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) {
|
||||
func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos, circle int) {
|
||||
steps := int(d / tw.interval)
|
||||
pos = (tw.tickedPos + steps) % tw.numSlots
|
||||
circle = (steps - 1) / tw.numSlots
|
||||
|
||||
@@ -47,9 +47,11 @@ func TestUnmarshalContextWithMissing(t *testing.T) {
|
||||
Name string `ctx:"name"`
|
||||
Age int `ctx:"age"`
|
||||
}
|
||||
type name string
|
||||
const PersonNameKey name = "name"
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "name", "kevin")
|
||||
ctx = context.WithValue(ctx, PersonNameKey, "kevin")
|
||||
|
||||
var person Person
|
||||
err := For(ctx, &person)
|
||||
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
)
|
||||
|
||||
func TestContextCancel(t *testing.T) {
|
||||
c := context.WithValue(context.Background(), "key", "value")
|
||||
type key string
|
||||
var nameKey key = "name"
|
||||
c := context.WithValue(context.Background(), nameKey, "value")
|
||||
c1, cancel := context.WithCancel(c)
|
||||
o := ValueOnlyFrom(c1)
|
||||
c2, cancel2 := context.WithCancel(o)
|
||||
|
||||
@@ -5,7 +5,7 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ package internal
|
||||
|
||||
import (
|
||||
context "context"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
clientv3 "go.etcd.io/etcd/clientv3"
|
||||
grpc "google.golang.org/grpc"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// MockEtcdClient is a mock of EtcdClient interface
|
||||
|
||||
@@ -2,5 +2,5 @@ package internal
|
||||
|
||||
// Listener interface wraps the OnUpdate method.
|
||||
type Listener interface {
|
||||
OnUpdate(keys []string, values []string, newKey string)
|
||||
OnUpdate(keys, values []string, newKey string)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/syncx"
|
||||
"github.com/tal-tech/go-zero/core/threading"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -260,26 +260,34 @@ func (c *cluster) reload(cli EtcdClient) {
|
||||
}
|
||||
|
||||
func (c *cluster) watch(cli EtcdClient, key string) {
|
||||
for {
|
||||
if c.watchStream(cli, key) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cluster) watchStream(cli EtcdClient, key string) bool {
|
||||
rch := cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
|
||||
for {
|
||||
select {
|
||||
case wresp, ok := <-rch:
|
||||
if !ok {
|
||||
logx.Error("etcd monitor chan has been closed")
|
||||
return
|
||||
return false
|
||||
}
|
||||
if wresp.Canceled {
|
||||
logx.Error("etcd monitor chan has been canceled")
|
||||
return
|
||||
logx.Errorf("etcd monitor chan has been canceled, error: %v", wresp.Err())
|
||||
return false
|
||||
}
|
||||
if wresp.Err() != nil {
|
||||
logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
c.handleWatchEvents(key, wresp.Events)
|
||||
case <-c.done:
|
||||
return
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/contextx"
|
||||
"github.com/tal-tech/go-zero/core/lang"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/stringx"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
"go.etcd.io/etcd/mvcc/mvccpb"
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
var mockLock sync.Mutex
|
||||
@@ -202,11 +203,13 @@ func TestClusterWatch_RespFailures(t *testing.T) {
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
c := new(cluster)
|
||||
c.done = make(chan lang.PlaceholderType)
|
||||
go func() {
|
||||
ch <- resp
|
||||
close(c.done)
|
||||
}()
|
||||
c.watch(cli, "any")
|
||||
})
|
||||
@@ -220,11 +223,13 @@ func TestClusterWatch_CloseChan(t *testing.T) {
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
ch := make(chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch)
|
||||
cli.EXPECT().Watch(gomock.Any(), "any/", gomock.Any()).Return(ch).AnyTimes()
|
||||
cli.EXPECT().Ctx().Return(context.Background()).AnyTimes()
|
||||
c := new(cluster)
|
||||
c.done = make(chan lang.PlaceholderType)
|
||||
go func() {
|
||||
close(ch)
|
||||
close(c.done)
|
||||
}()
|
||||
c.watch(cli, "any")
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/proc"
|
||||
"github.com/tal-tech/go-zero/core/syncx"
|
||||
"github.com/tal-tech/go-zero/core/threading"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
type (
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/discov/internal"
|
||||
"github.com/tal-tech/go-zero/core/lang"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"go.etcd.io/etcd/clientv3"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package fs
|
||||
|
||||
@@ -49,6 +49,11 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// Concat returns a concatenated Stream.
|
||||
func Concat(s Stream, others ...Stream) Stream {
|
||||
return s.Concat(others...)
|
||||
}
|
||||
|
||||
// From constructs a Stream from the given GenerateFunc.
|
||||
func From(generate GenerateFunc) Stream {
|
||||
source := make(chan interface{})
|
||||
@@ -79,16 +84,42 @@ func Range(source <-chan interface{}) Stream {
|
||||
}
|
||||
}
|
||||
|
||||
// AllMach returns whether all elements of this stream match the provided predicate.
|
||||
// May not evaluate the predicate on all elements if not necessary for determining the result.
|
||||
// If the stream is empty then true is returned and the predicate is not evaluated.
|
||||
func (s Stream) AllMach(predicate func(item interface{}) bool) bool {
|
||||
for item := range s.source {
|
||||
if !predicate(item) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AnyMach returns whether any elements of this stream match the provided predicate.
|
||||
// May not evaluate the predicate on all elements if not necessary for determining the result.
|
||||
// If the stream is empty then false is returned and the predicate is not evaluated.
|
||||
func (s Stream) AnyMach(predicate func(item interface{}) bool) bool {
|
||||
for item := range s.source {
|
||||
if predicate(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Buffer buffers the items into a queue with size n.
|
||||
// It can balance the producer and the consumer if their processing throughput don't match.
|
||||
func (p Stream) Buffer(n int) Stream {
|
||||
func (s Stream) Buffer(n int) Stream {
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
source := make(chan interface{}, n)
|
||||
go func() {
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
source <- item
|
||||
}
|
||||
close(source)
|
||||
@@ -97,23 +128,51 @@ func (p Stream) Buffer(n int) Stream {
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Concat returns a Stream that concatenated other streams
|
||||
func (s Stream) Concat(others ...Stream) Stream {
|
||||
source := make(chan interface{})
|
||||
|
||||
go func() {
|
||||
group := threading.NewRoutineGroup()
|
||||
group.Run(func() {
|
||||
for item := range s.source {
|
||||
source <- item
|
||||
}
|
||||
})
|
||||
|
||||
for _, each := range others {
|
||||
each := each
|
||||
group.Run(func() {
|
||||
for item := range each.source {
|
||||
source <- item
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
group.Wait()
|
||||
close(source)
|
||||
}()
|
||||
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Count counts the number of elements in the result.
|
||||
func (p Stream) Count() (count int) {
|
||||
for range p.source {
|
||||
func (s Stream) Count() (count int) {
|
||||
for range s.source {
|
||||
count++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Distinct removes the duplicated items base on the given KeyFunc.
|
||||
func (p Stream) Distinct(fn KeyFunc) Stream {
|
||||
func (s Stream) Distinct(fn KeyFunc) Stream {
|
||||
source := make(chan interface{})
|
||||
|
||||
threading.GoSafe(func() {
|
||||
defer close(source)
|
||||
|
||||
keys := make(map[interface{}]lang.PlaceholderType)
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
key := fn(item)
|
||||
if _, ok := keys[key]; !ok {
|
||||
source <- item
|
||||
@@ -126,14 +185,14 @@ func (p Stream) Distinct(fn KeyFunc) Stream {
|
||||
}
|
||||
|
||||
// Done waits all upstreaming operations to be done.
|
||||
func (p Stream) Done() {
|
||||
for range p.source {
|
||||
func (s Stream) Done() {
|
||||
for range s.source {
|
||||
}
|
||||
}
|
||||
|
||||
// Filter filters the items by the given FilterFunc.
|
||||
func (p Stream) Filter(fn FilterFunc, opts ...Option) Stream {
|
||||
return p.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
func (s Stream) Filter(fn FilterFunc, opts ...Option) Stream {
|
||||
return s.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
if fn(item) {
|
||||
pipe <- item
|
||||
}
|
||||
@@ -141,21 +200,21 @@ func (p Stream) Filter(fn FilterFunc, opts ...Option) Stream {
|
||||
}
|
||||
|
||||
// ForAll handles the streaming elements from the source and no later streams.
|
||||
func (p Stream) ForAll(fn ForAllFunc) {
|
||||
fn(p.source)
|
||||
func (s Stream) ForAll(fn ForAllFunc) {
|
||||
fn(s.source)
|
||||
}
|
||||
|
||||
// ForEach seals the Stream with the ForEachFunc on each item, no successive operations.
|
||||
func (p Stream) ForEach(fn ForEachFunc) {
|
||||
for item := range p.source {
|
||||
func (s Stream) ForEach(fn ForEachFunc) {
|
||||
for item := range s.source {
|
||||
fn(item)
|
||||
}
|
||||
}
|
||||
|
||||
// Group groups the elements into different groups based on their keys.
|
||||
func (p Stream) Group(fn KeyFunc) Stream {
|
||||
func (s Stream) Group(fn KeyFunc) Stream {
|
||||
groups := make(map[interface{}][]interface{})
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
key := fn(item)
|
||||
groups[key] = append(groups[key], item)
|
||||
}
|
||||
@@ -172,7 +231,7 @@ func (p Stream) Group(fn KeyFunc) Stream {
|
||||
}
|
||||
|
||||
// Head returns the first n elements in p.
|
||||
func (p Stream) Head(n int64) Stream {
|
||||
func (s Stream) Head(n int64) Stream {
|
||||
if n < 1 {
|
||||
panic("n must be greater than 0")
|
||||
}
|
||||
@@ -180,7 +239,7 @@ func (p Stream) Head(n int64) Stream {
|
||||
source := make(chan interface{})
|
||||
|
||||
go func() {
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
n--
|
||||
if n >= 0 {
|
||||
source <- item
|
||||
@@ -201,16 +260,16 @@ func (p Stream) Head(n int64) Stream {
|
||||
}
|
||||
|
||||
// Map converts each item to another corresponding item, which means it's a 1:1 model.
|
||||
func (p Stream) Map(fn MapFunc, opts ...Option) Stream {
|
||||
return p.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
func (s Stream) Map(fn MapFunc, opts ...Option) Stream {
|
||||
return s.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
pipe <- fn(item)
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
// Merge merges all the items into a slice and generates a new stream.
|
||||
func (p Stream) Merge() Stream {
|
||||
func (s Stream) Merge() Stream {
|
||||
var items []interface{}
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
items = append(items, item)
|
||||
}
|
||||
|
||||
@@ -222,21 +281,21 @@ func (p Stream) Merge() Stream {
|
||||
}
|
||||
|
||||
// Parallel applies the given ParallelFunc to each item concurrently with given number of workers.
|
||||
func (p Stream) Parallel(fn ParallelFunc, opts ...Option) {
|
||||
p.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
|
||||
s.Walk(func(item interface{}, pipe chan<- interface{}) {
|
||||
fn(item)
|
||||
}, opts...).Done()
|
||||
}
|
||||
|
||||
// Reduce is a utility method to let the caller deal with the underlying channel.
|
||||
func (p Stream) Reduce(fn ReduceFunc) (interface{}, error) {
|
||||
return fn(p.source)
|
||||
func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
|
||||
return fn(s.source)
|
||||
}
|
||||
|
||||
// Reverse reverses the elements in the stream.
|
||||
func (p Stream) Reverse() Stream {
|
||||
func (s Stream) Reverse() Stream {
|
||||
var items []interface{}
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
items = append(items, item)
|
||||
}
|
||||
// reverse, official method
|
||||
@@ -248,10 +307,36 @@ func (p Stream) Reverse() Stream {
|
||||
return Just(items...)
|
||||
}
|
||||
|
||||
// Skip returns a Stream that skips size elements.
|
||||
func (s Stream) Skip(n int64) Stream {
|
||||
if n < 0 {
|
||||
panic("n must not be negative")
|
||||
}
|
||||
if n == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
source := make(chan interface{})
|
||||
|
||||
go func() {
|
||||
for item := range s.source {
|
||||
n--
|
||||
if n >= 0 {
|
||||
continue
|
||||
} else {
|
||||
source <- item
|
||||
}
|
||||
}
|
||||
close(source)
|
||||
}()
|
||||
|
||||
return Range(source)
|
||||
}
|
||||
|
||||
// Sort sorts the items from the underlying source.
|
||||
func (p Stream) Sort(less LessFunc) Stream {
|
||||
func (s Stream) Sort(less LessFunc) Stream {
|
||||
var items []interface{}
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
items = append(items, item)
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
@@ -263,7 +348,7 @@ func (p Stream) Sort(less LessFunc) Stream {
|
||||
|
||||
// Split splits the elements into chunk with size up to n,
|
||||
// might be less than n on tailing elements.
|
||||
func (p Stream) Split(n int) Stream {
|
||||
func (s Stream) Split(n int) Stream {
|
||||
if n < 1 {
|
||||
panic("n should be greater than 0")
|
||||
}
|
||||
@@ -271,7 +356,7 @@ func (p Stream) Split(n int) Stream {
|
||||
source := make(chan interface{})
|
||||
go func() {
|
||||
var chunk []interface{}
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
chunk = append(chunk, item)
|
||||
if len(chunk) == n {
|
||||
source <- chunk
|
||||
@@ -288,7 +373,7 @@ func (p Stream) Split(n int) Stream {
|
||||
}
|
||||
|
||||
// Tail returns the last n elements in p.
|
||||
func (p Stream) Tail(n int64) Stream {
|
||||
func (s Stream) Tail(n int64) Stream {
|
||||
if n < 1 {
|
||||
panic("n should be greater than 0")
|
||||
}
|
||||
@@ -297,7 +382,7 @@ func (p Stream) Tail(n int64) Stream {
|
||||
|
||||
go func() {
|
||||
ring := collection.NewRing(int(n))
|
||||
for item := range p.source {
|
||||
for item := range s.source {
|
||||
ring.Add(item)
|
||||
}
|
||||
for _, item := range ring.Take() {
|
||||
@@ -310,16 +395,16 @@ func (p Stream) Tail(n int64) Stream {
|
||||
}
|
||||
|
||||
// Walk lets the callers handle each item, the caller may write zero, one or more items base on the given item.
|
||||
func (p Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
||||
func (s Stream) Walk(fn WalkFunc, opts ...Option) Stream {
|
||||
option := buildOptions(opts...)
|
||||
if option.unlimitedWorkers {
|
||||
return p.walkUnlimited(fn, option)
|
||||
return s.walkUnlimited(fn, option)
|
||||
}
|
||||
|
||||
return p.walkLimited(fn, option)
|
||||
return s.walkLimited(fn, option)
|
||||
}
|
||||
|
||||
func (p Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
|
||||
func (s Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
|
||||
pipe := make(chan interface{}, option.workers)
|
||||
|
||||
go func() {
|
||||
@@ -328,7 +413,7 @@ func (p Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
|
||||
|
||||
for {
|
||||
pool <- lang.Placeholder
|
||||
item, ok := <-p.source
|
||||
item, ok := <-s.source
|
||||
if !ok {
|
||||
<-pool
|
||||
break
|
||||
@@ -353,14 +438,14 @@ func (p Stream) walkLimited(fn WalkFunc, option *rxOptions) Stream {
|
||||
return Range(pipe)
|
||||
}
|
||||
|
||||
func (p Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream {
|
||||
func (s Stream) walkUnlimited(fn WalkFunc, option *rxOptions) Stream {
|
||||
pipe := make(chan interface{}, defaultWorkers)
|
||||
|
||||
go func() {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for {
|
||||
item, ok := <-p.source
|
||||
item, ok := <-s.source
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ package fx
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -330,6 +333,29 @@ func TestWalk(t *testing.T) {
|
||||
assert.Equal(t, 9, result)
|
||||
}
|
||||
|
||||
func BenchmarkParallelMapReduce(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
mapper := func(v interface{}) interface{} {
|
||||
return v.(int64) * v.(int64)
|
||||
}
|
||||
reducer := func(input <-chan interface{}) (interface{}, error) {
|
||||
var result int64
|
||||
for v := range input {
|
||||
result += v.(int64)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
b.ResetTimer()
|
||||
From(func(input chan<- interface{}) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
input <- int64(rand.Int())
|
||||
}
|
||||
})
|
||||
}).Map(mapper).Reduce(reducer)
|
||||
}
|
||||
|
||||
func BenchmarkMapReduce(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
@@ -343,12 +369,103 @@ func BenchmarkMapReduce(b *testing.B) {
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
b.ResetTimer()
|
||||
From(func(input chan<- interface{}) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
input <- int64(rand.Int())
|
||||
}
|
||||
}).Map(mapper).Reduce(reducer)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
From(func(input chan<- interface{}) {
|
||||
for j := 0; j < 2; j++ {
|
||||
input <- int64(j)
|
||||
}
|
||||
}).Map(mapper).Reduce(reducer)
|
||||
func equal(t *testing.T, stream Stream, data []interface{}) {
|
||||
items := make([]interface{}, 0)
|
||||
for item := range stream.source {
|
||||
items = append(items, item)
|
||||
}
|
||||
if !reflect.DeepEqual(items, data) {
|
||||
t.Errorf(" %v, want %v", items, data)
|
||||
}
|
||||
}
|
||||
|
||||
func assetEqual(t *testing.T, except, data interface{}) {
|
||||
if !reflect.DeepEqual(except, data) {
|
||||
t.Errorf(" %v, want %v", data, except)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_AnyMach(t *testing.T) {
|
||||
assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
|
||||
return 4 == item.(int)
|
||||
}))
|
||||
assetEqual(t, false, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
|
||||
return 0 == item.(int)
|
||||
}))
|
||||
assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
|
||||
return 2 == item.(int)
|
||||
}))
|
||||
assetEqual(t, true, Just(1, 2, 3).AnyMach(func(item interface{}) bool {
|
||||
return 2 == item.(int)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestStream_AllMach(t *testing.T) {
|
||||
assetEqual(
|
||||
t, true, Just(1, 2, 3).AllMach(func(item interface{}) bool {
|
||||
return true
|
||||
}),
|
||||
)
|
||||
assetEqual(
|
||||
t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
|
||||
return false
|
||||
}),
|
||||
)
|
||||
assetEqual(
|
||||
t, false, Just(1, 2, 3).AllMach(func(item interface{}) bool {
|
||||
return item.(int) == 1
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func TestConcat(t *testing.T) {
|
||||
a1 := []interface{}{1, 2, 3}
|
||||
a2 := []interface{}{4, 5, 6}
|
||||
s1 := Just(a1...)
|
||||
s2 := Just(a2...)
|
||||
stream := Concat(s1, s2)
|
||||
var items []interface{}
|
||||
for item := range stream.source {
|
||||
items = append(items, item)
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].(int) < items[j].(int)
|
||||
})
|
||||
ints := make([]interface{}, 0)
|
||||
ints = append(ints, a1...)
|
||||
ints = append(ints, a2...)
|
||||
assetEqual(t, ints, items)
|
||||
}
|
||||
|
||||
func TestStream_Skip(t *testing.T) {
|
||||
assetEqual(t, 3, Just(1, 2, 3, 4).Skip(1).Count())
|
||||
assetEqual(t, 1, Just(1, 2, 3, 4).Skip(3).Count())
|
||||
assetEqual(t, 4, Just(1, 2, 3, 4).Skip(0).Count())
|
||||
equal(t, Just(1, 2, 3, 4).Skip(3), []interface{}{4})
|
||||
assert.Panics(t, func() {
|
||||
Just(1, 2, 3, 4).Skip(-1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStream_Concat(t *testing.T) {
|
||||
stream := Just(1).Concat(Just(2), Just(3))
|
||||
var items []interface{}
|
||||
for item := range stream.source {
|
||||
items = append(items, item)
|
||||
}
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].(int) < items[j].(int)
|
||||
})
|
||||
assetEqual(t, []interface{}{1, 2, 3}, items)
|
||||
|
||||
just := Just(1)
|
||||
equal(t, just.Concat(just), []interface{}{1})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package fx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -30,7 +33,8 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicChan <- p
|
||||
// attach call stack to avoid missing in different goroutine
|
||||
panicChan <- fmt.Sprintf("%+v\n\n%s", p, strings.TrimSpace(string(debug.Stack())))
|
||||
}
|
||||
}()
|
||||
done <- fn()
|
||||
|
||||
@@ -83,7 +83,7 @@ func (h *ConsistentHash) AddWithReplicas(node interface{}, replicas int) {
|
||||
h.ring[hash] = append(h.ring[hash], node)
|
||||
}
|
||||
|
||||
sort.Slice(h.keys, func(i int, j int) bool {
|
||||
sort.Slice(h.keys, func(i, j int) bool {
|
||||
return h.keys[i] < h.keys[j]
|
||||
})
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@ var (
|
||||
|
||||
// default to be enabled
|
||||
enabled = syncx.ForAtomicBool(true)
|
||||
// default to be enabled
|
||||
logEnabled = syncx.ForAtomicBool(true)
|
||||
// make it a variable for unit test
|
||||
systemOverloadChecker = func(cpuThreshold int64) bool {
|
||||
return stat.CpuUsage() >= cpuThreshold
|
||||
@@ -80,6 +82,11 @@ func Disable() {
|
||||
enabled.Set(false)
|
||||
}
|
||||
|
||||
// DisableLog disables the stat logs for load shedding.
|
||||
func DisableLog() {
|
||||
logEnabled.Set(false)
|
||||
}
|
||||
|
||||
// NewAdaptiveShedder returns an adaptive shedder.
|
||||
// opts can be used to customize the Shedder.
|
||||
func NewAdaptiveShedder(opts ...ShedderOption) Shedder {
|
||||
|
||||
@@ -25,6 +25,7 @@ func init() {
|
||||
}
|
||||
|
||||
func TestAdaptiveShedder(t *testing.T) {
|
||||
DisableLog()
|
||||
shedder := NewAdaptiveShedder(WithWindow(bucketDuration), WithBuckets(buckets), WithCpuThreshold(100))
|
||||
var wg sync.WaitGroup
|
||||
var drop int64
|
||||
|
||||
@@ -48,6 +48,25 @@ func (s *SheddingStat) IncrementDrop() {
|
||||
atomic.AddInt64(&s.drop, 1)
|
||||
}
|
||||
|
||||
func (s *SheddingStat) loop(c <-chan time.Time) {
|
||||
for range c {
|
||||
st := s.reset()
|
||||
|
||||
if !logEnabled.True() {
|
||||
continue
|
||||
}
|
||||
|
||||
c := stat.CpuUsage()
|
||||
if st.Drop == 0 {
|
||||
logx.Statf("(%s) shedding_stat [1m], cpu: %d, total: %d, pass: %d, drop: %d",
|
||||
s.name, c, st.Total, st.Pass, st.Drop)
|
||||
} else {
|
||||
logx.Statf("(%s) shedding_stat_drop [1m], cpu: %d, total: %d, pass: %d, drop: %d",
|
||||
s.name, c, st.Total, st.Pass, st.Drop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SheddingStat) reset() snapshot {
|
||||
return snapshot{
|
||||
Total: atomic.SwapInt64(&s.total, 0),
|
||||
@@ -59,15 +78,6 @@ func (s *SheddingStat) reset() snapshot {
|
||||
func (s *SheddingStat) run() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
c := stat.CpuUsage()
|
||||
st := s.reset()
|
||||
if st.Drop == 0 {
|
||||
logx.Statf("(%s) shedding_stat [1m], cpu: %d, total: %d, pass: %d, drop: %d",
|
||||
s.name, c, st.Total, st.Pass, st.Drop)
|
||||
} else {
|
||||
logx.Statf("(%s) shedding_stat_drop [1m], cpu: %d, total: %d, pass: %d, drop: %d",
|
||||
s.name, c, st.Total, st.Pass, st.Drop)
|
||||
}
|
||||
}
|
||||
|
||||
s.loop(ticker.C)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package load
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -22,3 +23,32 @@ func TestSheddingStat(t *testing.T) {
|
||||
assert.Equal(t, int64(5), result.Pass)
|
||||
assert.Equal(t, int64(7), result.Drop)
|
||||
}
|
||||
|
||||
func TestLoopTrue(t *testing.T) {
|
||||
ch := make(chan time.Time, 1)
|
||||
ch <- time.Now()
|
||||
close(ch)
|
||||
st := new(SheddingStat)
|
||||
logEnabled.Set(true)
|
||||
st.loop(ch)
|
||||
}
|
||||
|
||||
func TestLoopTrueAndDrop(t *testing.T) {
|
||||
ch := make(chan time.Time, 1)
|
||||
ch <- time.Now()
|
||||
close(ch)
|
||||
st := new(SheddingStat)
|
||||
st.IncrementDrop()
|
||||
logEnabled.Set(true)
|
||||
st.loop(ch)
|
||||
}
|
||||
|
||||
func TestLoopFalseAndDrop(t *testing.T) {
|
||||
ch := make(chan time.Time, 1)
|
||||
ch <- time.Now()
|
||||
close(ch)
|
||||
st := new(SheddingStat)
|
||||
st.IncrementDrop()
|
||||
logEnabled.Set(false)
|
||||
st.loop(ch)
|
||||
}
|
||||
|
||||
@@ -20,49 +20,67 @@ func WithDuration(d time.Duration) Logger {
|
||||
}
|
||||
|
||||
func (l *durationLogger) Error(v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Errorf(format string, v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Errorv(v interface{}) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(errorLog, levelError, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Info(v ...interface{}) {
|
||||
if shouldLog(InfoLevel) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.write(infoLog, levelInfo, fmt.Sprint(v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Infof(format string, v ...interface{}) {
|
||||
if shouldLog(InfoLevel) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.write(infoLog, levelInfo, fmt.Sprintf(format, v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Infov(v interface{}) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.write(infoLog, levelInfo, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Slow(v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(slowLog, levelSlow, fmt.Sprint(v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Slowf(format string, v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(slowLog, levelSlow, fmt.Sprintf(format, v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) Slowv(v interface{}) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(slowLog, levelSlow, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *durationLogger) WithDuration(duration time.Duration) Logger {
|
||||
l.Duration = timex.ReprOfDuration(duration)
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *durationLogger) write(writer io.Writer, level, content string) {
|
||||
func (l *durationLogger) write(writer io.Writer, level string, val interface{}) {
|
||||
l.Timestamp = getTimestamp()
|
||||
l.Level = level
|
||||
l.Content = content
|
||||
outputJson(writer, logEntry(*l))
|
||||
l.Content = val
|
||||
outputJson(writer, l)
|
||||
}
|
||||
|
||||
@@ -65,12 +65,14 @@ var (
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07"
|
||||
writeConsole bool
|
||||
logLevel uint32
|
||||
infoLog io.WriteCloser
|
||||
errorLog io.WriteCloser
|
||||
severeLog io.WriteCloser
|
||||
slowLog io.WriteCloser
|
||||
statLog io.WriteCloser
|
||||
stackLog io.Writer
|
||||
// use uint32 for atomic operations
|
||||
disableStat uint32
|
||||
infoLog io.WriteCloser
|
||||
errorLog io.WriteCloser
|
||||
severeLog io.WriteCloser
|
||||
slowLog io.WriteCloser
|
||||
statLog io.WriteCloser
|
||||
stackLog io.Writer
|
||||
|
||||
once sync.Once
|
||||
initialized uint32
|
||||
@@ -79,10 +81,10 @@ var (
|
||||
|
||||
type (
|
||||
logEntry struct {
|
||||
Timestamp string `json:"@timestamp"`
|
||||
Level string `json:"level"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
Content string `json:"content"`
|
||||
Timestamp string `json:"@timestamp"`
|
||||
Level string `json:"level"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
logOptions struct {
|
||||
@@ -98,10 +100,13 @@ type (
|
||||
Logger interface {
|
||||
Error(...interface{})
|
||||
Errorf(string, ...interface{})
|
||||
Errorv(interface{})
|
||||
Info(...interface{})
|
||||
Infof(string, ...interface{})
|
||||
Infov(interface{})
|
||||
Slow(...interface{})
|
||||
Slowf(string, ...interface{})
|
||||
Slowv(interface{})
|
||||
WithDuration(time.Duration) Logger
|
||||
}
|
||||
)
|
||||
@@ -133,7 +138,7 @@ func SetUp(c LogConf) error {
|
||||
|
||||
// Alert alerts v in alert level, and the message is written to error log.
|
||||
func Alert(v string) {
|
||||
output(errorLog, levelAlert, v)
|
||||
outputText(errorLog, levelAlert, v)
|
||||
}
|
||||
|
||||
// Close closes the logging.
|
||||
@@ -195,24 +200,29 @@ func Disable() {
|
||||
})
|
||||
}
|
||||
|
||||
// DisableStat disables the stat logs.
|
||||
func DisableStat() {
|
||||
atomic.StoreUint32(&disableStat, 1)
|
||||
}
|
||||
|
||||
// Error writes v into error log.
|
||||
func Error(v ...interface{}) {
|
||||
ErrorCaller(1, v...)
|
||||
}
|
||||
|
||||
// Errorf writes v with format into error log.
|
||||
func Errorf(format string, v ...interface{}) {
|
||||
ErrorCallerf(1, format, v...)
|
||||
}
|
||||
|
||||
// ErrorCaller writes v with context into error log.
|
||||
func ErrorCaller(callDepth int, v ...interface{}) {
|
||||
errorSync(fmt.Sprint(v...), callDepth+callerInnerDepth)
|
||||
errorTextSync(fmt.Sprint(v...), callDepth+callerInnerDepth)
|
||||
}
|
||||
|
||||
// ErrorCallerf writes v with context in format into error log.
|
||||
func ErrorCallerf(callDepth int, format string, v ...interface{}) {
|
||||
errorSync(fmt.Sprintf(format, v...), callDepth+callerInnerDepth)
|
||||
errorTextSync(fmt.Sprintf(format, v...), callDepth+callerInnerDepth)
|
||||
}
|
||||
|
||||
// Errorf writes v with format into error log.
|
||||
func Errorf(format string, v ...interface{}) {
|
||||
ErrorCallerf(1, format, v...)
|
||||
}
|
||||
|
||||
// ErrorStack writes v along with call stack into error log.
|
||||
@@ -227,14 +237,25 @@ func ErrorStackf(format string, v ...interface{}) {
|
||||
stackSync(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Errorv writes v into error log with json content.
|
||||
// No call stack attached, because not elegant to pack the messages.
|
||||
func Errorv(v interface{}) {
|
||||
errorAnySync(v)
|
||||
}
|
||||
|
||||
// Info writes v into access log.
|
||||
func Info(v ...interface{}) {
|
||||
infoSync(fmt.Sprint(v...))
|
||||
infoTextSync(fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Infof writes v with format into access log.
|
||||
func Infof(format string, v ...interface{}) {
|
||||
infoSync(fmt.Sprintf(format, v...))
|
||||
infoTextSync(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Infov writes v into access log with json content.
|
||||
func Infov(v interface{}) {
|
||||
infoAnySync(v)
|
||||
}
|
||||
|
||||
// Must checks if err is nil, otherwise logs the err and exits.
|
||||
@@ -242,7 +263,7 @@ func Must(err error) {
|
||||
if err != nil {
|
||||
msg := formatWithCaller(err.Error(), 3)
|
||||
log.Print(msg)
|
||||
output(severeLog, levelFatal, msg)
|
||||
outputText(severeLog, levelFatal, msg)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -264,12 +285,17 @@ func Severef(format string, v ...interface{}) {
|
||||
|
||||
// Slow writes v into slow log.
|
||||
func Slow(v ...interface{}) {
|
||||
slowSync(fmt.Sprint(v...))
|
||||
slowTextSync(fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Slowf writes v with format into slow log.
|
||||
func Slowf(format string, v ...interface{}) {
|
||||
slowSync(fmt.Sprintf(format, v...))
|
||||
slowTextSync(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Slowv writes v into slow log with json content.
|
||||
func Slowv(v interface{}) {
|
||||
slowAnySync(v)
|
||||
}
|
||||
|
||||
// Stat writes v into stat log.
|
||||
@@ -312,8 +338,14 @@ func createOutput(path string) (io.WriteCloser, error) {
|
||||
options.gzipEnabled), options.gzipEnabled)
|
||||
}
|
||||
|
||||
func errorSync(msg string, callDepth int) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
func errorAnySync(v interface{}) {
|
||||
if shallLog(ErrorLevel) {
|
||||
outputAny(errorLog, levelError, v)
|
||||
}
|
||||
}
|
||||
|
||||
func errorTextSync(msg string, callDepth int) {
|
||||
if shallLog(ErrorLevel) {
|
||||
outputError(errorLog, msg, callDepth)
|
||||
}
|
||||
}
|
||||
@@ -362,13 +394,28 @@ func handleOptions(opts []LogOption) {
|
||||
}
|
||||
}
|
||||
|
||||
func infoSync(msg string) {
|
||||
if shouldLog(InfoLevel) {
|
||||
output(infoLog, levelInfo, msg)
|
||||
func infoAnySync(val interface{}) {
|
||||
if shallLog(InfoLevel) {
|
||||
outputAny(infoLog, levelInfo, val)
|
||||
}
|
||||
}
|
||||
|
||||
func output(writer io.Writer, level, msg string) {
|
||||
func infoTextSync(msg string) {
|
||||
if shallLog(InfoLevel) {
|
||||
outputText(infoLog, levelInfo, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func outputAny(writer io.Writer, level string, val interface{}) {
|
||||
info := logEntry{
|
||||
Timestamp: getTimestamp(),
|
||||
Level: level,
|
||||
Content: val,
|
||||
}
|
||||
outputJson(writer, info)
|
||||
}
|
||||
|
||||
func outputText(writer io.Writer, level, msg string) {
|
||||
info := logEntry{
|
||||
Timestamp: getTimestamp(),
|
||||
Level: level,
|
||||
@@ -379,7 +426,7 @@ func output(writer io.Writer, level, msg string) {
|
||||
|
||||
func outputError(writer io.Writer, msg string, callDepth int) {
|
||||
content := formatWithCaller(msg, callDepth)
|
||||
output(writer, levelError, content)
|
||||
outputText(writer, levelError, content)
|
||||
}
|
||||
|
||||
func outputJson(writer io.Writer, info interface{}) {
|
||||
@@ -481,30 +528,40 @@ func setupWithVolume(c LogConf) error {
|
||||
}
|
||||
|
||||
func severeSync(msg string) {
|
||||
if shouldLog(SevereLevel) {
|
||||
output(severeLog, levelSevere, fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
||||
if shallLog(SevereLevel) {
|
||||
outputText(severeLog, levelSevere, fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
||||
}
|
||||
}
|
||||
|
||||
func shouldLog(level uint32) bool {
|
||||
func shallLog(level uint32) bool {
|
||||
return atomic.LoadUint32(&logLevel) <= level
|
||||
}
|
||||
|
||||
func slowSync(msg string) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
output(slowLog, levelSlow, msg)
|
||||
func shallLogStat() bool {
|
||||
return atomic.LoadUint32(&disableStat) == 0
|
||||
}
|
||||
|
||||
func slowAnySync(v interface{}) {
|
||||
if shallLog(ErrorLevel) {
|
||||
outputAny(slowLog, levelSlow, v)
|
||||
}
|
||||
}
|
||||
|
||||
func slowTextSync(msg string) {
|
||||
if shallLog(ErrorLevel) {
|
||||
outputText(slowLog, levelSlow, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func stackSync(msg string) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
output(stackLog, levelError, fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
||||
if shallLog(ErrorLevel) {
|
||||
outputText(stackLog, levelError, fmt.Sprintf("%s\n%s", msg, string(debug.Stack())))
|
||||
}
|
||||
}
|
||||
|
||||
func statSync(msg string) {
|
||||
if shouldLog(InfoLevel) {
|
||||
output(statLog, levelStat, msg)
|
||||
if shallLogStat() && shallLog(InfoLevel) {
|
||||
outputText(statLog, levelStat, msg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -92,6 +92,30 @@ func TestStructedLogAlert(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogError(t *testing.T) {
|
||||
doTestStructedLog(t, levelError, func(writer io.WriteCloser) {
|
||||
errorLog = writer
|
||||
}, func(v ...interface{}) {
|
||||
Error(v...)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogErrorf(t *testing.T) {
|
||||
doTestStructedLog(t, levelError, func(writer io.WriteCloser) {
|
||||
errorLog = writer
|
||||
}, func(v ...interface{}) {
|
||||
Errorf("%s", fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogErrorv(t *testing.T) {
|
||||
doTestStructedLog(t, levelError, func(writer io.WriteCloser) {
|
||||
errorLog = writer
|
||||
}, func(v ...interface{}) {
|
||||
Errorv(fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogInfo(t *testing.T) {
|
||||
doTestStructedLog(t, levelInfo, func(writer io.WriteCloser) {
|
||||
infoLog = writer
|
||||
@@ -100,6 +124,22 @@ func TestStructedLogInfo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogInfof(t *testing.T) {
|
||||
doTestStructedLog(t, levelInfo, func(writer io.WriteCloser) {
|
||||
infoLog = writer
|
||||
}, func(v ...interface{}) {
|
||||
Infof("%s", fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogInfov(t *testing.T) {
|
||||
doTestStructedLog(t, levelInfo, func(writer io.WriteCloser) {
|
||||
infoLog = writer
|
||||
}, func(v ...interface{}) {
|
||||
Infov(fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogSlow(t *testing.T) {
|
||||
doTestStructedLog(t, levelSlow, func(writer io.WriteCloser) {
|
||||
slowLog = writer
|
||||
@@ -116,6 +156,14 @@ func TestStructedLogSlowf(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogSlowv(t *testing.T) {
|
||||
doTestStructedLog(t, levelSlow, func(writer io.WriteCloser) {
|
||||
slowLog = writer
|
||||
}, func(v ...interface{}) {
|
||||
Slowv(fmt.Sprint(v...))
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructedLogStat(t *testing.T) {
|
||||
doTestStructedLog(t, levelStat, func(writer io.WriteCloser) {
|
||||
statLog = writer
|
||||
@@ -246,6 +294,17 @@ func TestDisable(t *testing.T) {
|
||||
assert.Nil(t, Close())
|
||||
}
|
||||
|
||||
func TestDisableStat(t *testing.T) {
|
||||
DisableStat()
|
||||
|
||||
const message = "hello there"
|
||||
writer := new(mockWriter)
|
||||
statLog = writer
|
||||
atomic.StoreUint32(&initialized, 1)
|
||||
Stat(message)
|
||||
assert.Equal(t, 0, writer.builder.Len())
|
||||
}
|
||||
|
||||
func TestWithGzip(t *testing.T) {
|
||||
fn := WithGzip()
|
||||
var opt logOptions
|
||||
@@ -357,7 +416,9 @@ func doTestStructedLog(t *testing.T, level string, setup func(writer io.WriteClo
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(t, level, entry.Level)
|
||||
assert.True(t, strings.Contains(entry.Content, message))
|
||||
val, ok := entry.Content.(string)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, strings.Contains(val, message))
|
||||
}
|
||||
|
||||
func testSetLevelTwiceWithMode(t *testing.T, mode string) {
|
||||
|
||||
@@ -44,5 +44,5 @@ func captureOutput(f func()) string {
|
||||
func getContent(jsonStr string) string {
|
||||
var entry logEntry
|
||||
json.Unmarshal([]byte(jsonStr), &entry)
|
||||
return entry.Content
|
||||
return entry.Content.(string)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"github.com/tal-tech/go-zero/core/trace/tracespec"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type traceLogger struct {
|
||||
@@ -18,50 +19,68 @@ type traceLogger struct {
|
||||
}
|
||||
|
||||
func (l *traceLogger) Error(v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(errorLog, levelError, formatWithCaller(fmt.Sprint(v...), durationCallerDepth))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Errorf(format string, v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(errorLog, levelError, formatWithCaller(fmt.Sprintf(format, v...), durationCallerDepth))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Errorv(v interface{}) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(errorLog, levelError, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Info(v ...interface{}) {
|
||||
if shouldLog(InfoLevel) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.write(infoLog, levelInfo, fmt.Sprint(v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Infof(format string, v ...interface{}) {
|
||||
if shouldLog(InfoLevel) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.write(infoLog, levelInfo, fmt.Sprintf(format, v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Infov(v interface{}) {
|
||||
if shallLog(InfoLevel) {
|
||||
l.write(infoLog, levelInfo, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Slow(v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(slowLog, levelSlow, fmt.Sprint(v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Slowf(format string, v ...interface{}) {
|
||||
if shouldLog(ErrorLevel) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(slowLog, levelSlow, fmt.Sprintf(format, v...))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) Slowv(v interface{}) {
|
||||
if shallLog(ErrorLevel) {
|
||||
l.write(slowLog, levelSlow, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *traceLogger) WithDuration(duration time.Duration) Logger {
|
||||
l.Duration = timex.ReprOfDuration(duration)
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *traceLogger) write(writer io.Writer, level, content string) {
|
||||
func (l *traceLogger) write(writer io.Writer, level string, val interface{}) {
|
||||
l.Timestamp = getTimestamp()
|
||||
l.Level = level
|
||||
l.Content = content
|
||||
l.Content = val
|
||||
l.Trace = traceIdFromContext(l.ctx)
|
||||
l.Span = spanIdFromContext(l.ctx)
|
||||
outputJson(writer, l)
|
||||
@@ -75,6 +94,11 @@ func WithContext(ctx context.Context) Logger {
|
||||
}
|
||||
|
||||
func spanIdFromContext(ctx context.Context) string {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.IsRecording() {
|
||||
return span.SpanContext().SpanID().String()
|
||||
}
|
||||
|
||||
t, ok := ctx.Value(tracespec.TracingKey).(tracespec.Trace)
|
||||
if !ok {
|
||||
return ""
|
||||
@@ -84,6 +108,11 @@ func spanIdFromContext(ctx context.Context) string {
|
||||
}
|
||||
|
||||
func traceIdFromContext(ctx context.Context) string {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.IsRecording() {
|
||||
return span.SpanContext().SpanID().String()
|
||||
}
|
||||
|
||||
t, ok := ctx.Value(tracespec.TracingKey).(tracespec.Trace)
|
||||
if !ok {
|
||||
return ""
|
||||
|
||||
@@ -112,5 +112,5 @@ func (t mockTrace) Follow(ctx context.Context, serviceName, operationName string
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (t mockTrace) Visit(fn func(key string, val string) bool) {
|
||||
func (t mockTrace) Visit(fn func(key, val string) bool) {
|
||||
}
|
||||
|
||||
@@ -43,7 +43,8 @@ type (
|
||||
UnmarshalOption func(*unmarshalOptions)
|
||||
|
||||
unmarshalOptions struct {
|
||||
fromString bool
|
||||
fromString bool
|
||||
canonicalKey func(key string) string
|
||||
}
|
||||
|
||||
keyCache map[string][]string
|
||||
@@ -229,7 +230,7 @@ func (u *Unmarshaler) processFieldPrimitive(field reflect.StructField, value ref
|
||||
default:
|
||||
switch v := mapValue.(type) {
|
||||
case json.Number:
|
||||
return u.processFieldPrimitiveWithJsonNumber(field, value, v, opts, fullName)
|
||||
return u.processFieldPrimitiveWithJSONNumber(field, value, v, opts, fullName)
|
||||
default:
|
||||
if typeKind == valueKind {
|
||||
if err := validateValueInOptions(opts.options(), mapValue); err != nil {
|
||||
@@ -244,7 +245,7 @@ func (u *Unmarshaler) processFieldPrimitive(field reflect.StructField, value ref
|
||||
return newTypeMismatchError(fullName)
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) processFieldPrimitiveWithJsonNumber(field reflect.StructField, value reflect.Value,
|
||||
func (u *Unmarshaler) processFieldPrimitiveWithJSONNumber(field reflect.StructField, value reflect.Value,
|
||||
v json.Number, opts *fieldOptionsWithContext, fullName string) error {
|
||||
fieldType := field.Type
|
||||
fieldKind := fieldType.Kind()
|
||||
@@ -323,7 +324,11 @@ func (u *Unmarshaler) processNamedField(field reflect.StructField, value reflect
|
||||
}
|
||||
|
||||
fullName = join(fullName, key)
|
||||
mapValue, hasValue := getValue(m, key)
|
||||
canonicalKey := key
|
||||
if u.opts.canonicalKey != nil {
|
||||
canonicalKey = u.opts.canonicalKey(key)
|
||||
}
|
||||
mapValue, hasValue := getValue(m, canonicalKey)
|
||||
if hasValue {
|
||||
return u.processNamedFieldWithValue(field, value, mapValue, key, opts, fullName)
|
||||
}
|
||||
@@ -457,6 +462,10 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, map
|
||||
} else {
|
||||
conv.Index(i).Set(target.Elem())
|
||||
}
|
||||
case reflect.Slice:
|
||||
if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if err := u.fillSliceValue(conv, i, dereffedBaseKind, ithValue); err != nil {
|
||||
return err
|
||||
@@ -492,17 +501,30 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect.
|
||||
}
|
||||
|
||||
func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, baseKind reflect.Kind, value interface{}) error {
|
||||
ithVal := slice.Index(index)
|
||||
switch v := value.(type) {
|
||||
case json.Number:
|
||||
return setValue(baseKind, slice.Index(index), v.String())
|
||||
return setValue(baseKind, ithVal, v.String())
|
||||
default:
|
||||
// don't need to consider the difference between int, int8, int16, int32, int64,
|
||||
// uint, uint8, uint16, uint32, uint64, because they're handled as json.Number.
|
||||
if slice.Index(index).Kind() != reflect.TypeOf(value).Kind() {
|
||||
if ithVal.Kind() == reflect.Ptr {
|
||||
baseType := Deref(ithVal.Type())
|
||||
if baseType.Kind() != reflect.TypeOf(value).Kind() {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
target := reflect.New(baseType).Elem()
|
||||
target.Set(reflect.ValueOf(value))
|
||||
ithVal.Set(target.Addr())
|
||||
return nil
|
||||
}
|
||||
|
||||
if ithVal.Kind() != reflect.TypeOf(value).Kind() {
|
||||
return errTypeMismatch
|
||||
}
|
||||
|
||||
slice.Index(index).Set(reflect.ValueOf(value))
|
||||
ithVal.Set(reflect.ValueOf(value))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -604,6 +626,13 @@ func WithStringValues() UnmarshalOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithCanonicalKeyFunc customizes a Unmarshaler with Canonical Key func
|
||||
func WithCanonicalKeyFunc(f func(string) string) UnmarshalOption {
|
||||
return func(opt *unmarshalOptions) {
|
||||
opt.canonicalKey = f
|
||||
}
|
||||
}
|
||||
|
||||
func fillDurationValue(fieldKind reflect.Kind, value reflect.Value, dur string) error {
|
||||
d, err := time.ParseDuration(dur)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package mapping
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -2480,3 +2481,40 @@ func BenchmarkUnmarshal(b *testing.B) {
|
||||
UnmarshalKey(data, &an)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonReaderMultiArray(t *testing.T) {
|
||||
payload := `{"a": "133", "b": [["add", "cccd"], ["eeee"]]}`
|
||||
var res struct {
|
||||
A string `json:"a"`
|
||||
B [][]string `json:"b"`
|
||||
}
|
||||
reader := strings.NewReader(payload)
|
||||
err := UnmarshalJsonReader(reader, &res)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(res.B))
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonReaderPtrMultiArray(t *testing.T) {
|
||||
payload := `{"a": "133", "b": [["add", "cccd"], ["eeee"]]}`
|
||||
var res struct {
|
||||
A string `json:"a"`
|
||||
B [][]*string `json:"b"`
|
||||
}
|
||||
reader := strings.NewReader(payload)
|
||||
err := UnmarshalJsonReader(reader, &res)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(res.B))
|
||||
assert.Equal(t, 2, len(res.B[0]))
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonReaderPtrArray(t *testing.T) {
|
||||
payload := `{"a": "133", "b": ["add", "cccd", "eeee"]}`
|
||||
var res struct {
|
||||
A string `json:"a"`
|
||||
B []*string `json:"b"`
|
||||
}
|
||||
reader := strings.NewReader(payload)
|
||||
err := UnmarshalJsonReader(reader, &res)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, len(res.B))
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func parseNumberRange(str string) (*numberRange, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseOption(fieldOpts *fieldOptions, fieldName string, option string) error {
|
||||
func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
|
||||
switch {
|
||||
case option == stringOption:
|
||||
fieldOpts.FromString = true
|
||||
|
||||
@@ -112,6 +112,12 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
|
||||
opts ...Option) (interface{}, error) {
|
||||
options := buildOptions(opts...)
|
||||
output := make(chan interface{})
|
||||
defer func() {
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
collector := make(chan interface{}, options.workers)
|
||||
done := syncx.NewDoneChan()
|
||||
writer := newGuardedWriter(output, done.Done())
|
||||
@@ -136,14 +142,16 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
|
||||
if r := recover(); r != nil {
|
||||
cancel(fmt.Errorf("%v", r))
|
||||
} else {
|
||||
finish()
|
||||
}
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
drain(collector)
|
||||
}()
|
||||
|
||||
go executeMappers(func(item interface{}, w Writer) {
|
||||
@@ -165,7 +173,6 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
|
||||
func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
|
||||
_, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
|
||||
reducer(input, cancel)
|
||||
drain(input)
|
||||
// We need to write a placeholder to let MapReduce to continue on reducer done,
|
||||
// otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce.
|
||||
writer.Write(lang.Placeholder)
|
||||
|
||||
@@ -202,6 +202,22 @@ func TestMapReduce(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapReduceWithReduerWriteMoreThanOnce(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
MapReduce(func(source chan<- interface{}) {
|
||||
for i := 0; i < 10; i++ {
|
||||
source <- i
|
||||
}
|
||||
}, func(item interface{}, writer Writer, cancel func(error)) {
|
||||
writer.Write(item)
|
||||
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
|
||||
drain(pipe)
|
||||
writer.Write("one")
|
||||
writer.Write("two")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMapReduceVoid(t *testing.T) {
|
||||
var value uint32
|
||||
tests := []struct {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
@@ -14,5 +15,5 @@ func AddWrapUpListener(fn func()) func() {
|
||||
return fn
|
||||
}
|
||||
|
||||
func SetTimeoutToForceQuit(duration time.Duration) {
|
||||
func SetTimeToForceQuit(duration time.Duration) {
|
||||
}
|
||||
|
||||
10
core/proc/signals+polyfill.go
Normal file
10
core/proc/signals+polyfill.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package proc
|
||||
|
||||
import "context"
|
||||
|
||||
func Done() <-chan struct{} {
|
||||
return context.Background().Done()
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build linux || darwin
|
||||
// +build linux darwin
|
||||
|
||||
package proc
|
||||
@@ -12,6 +13,8 @@ import (
|
||||
|
||||
const timeFormat = "0102150405"
|
||||
|
||||
var done = make(chan struct{})
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
var profiler Stopper
|
||||
@@ -33,6 +36,13 @@ func init() {
|
||||
profiler = nil
|
||||
}
|
||||
case syscall.SIGTERM:
|
||||
select {
|
||||
case <-done:
|
||||
// already closed
|
||||
default:
|
||||
close(done)
|
||||
}
|
||||
|
||||
gracefulStop(signals)
|
||||
default:
|
||||
logx.Error("Got unregistered signal:", v)
|
||||
@@ -40,3 +50,8 @@ func init() {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Done returns the channel that notifies the process quitting.
|
||||
func Done() <-chan struct{} {
|
||||
return done
|
||||
}
|
||||
|
||||
16
core/proc/signals_test.go
Normal file
16
core/proc/signals_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package proc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDone(t *testing.T) {
|
||||
select {
|
||||
case <-Done():
|
||||
assert.Fail(t, "should run")
|
||||
default:
|
||||
}
|
||||
assert.NotNil(t, Done())
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func NewQueue(producerFactory ProducerFactory, consumerFactory ConsumerFactory)
|
||||
return q
|
||||
}
|
||||
|
||||
// AddListener adds a litener to q.
|
||||
// AddListener adds a listener to q.
|
||||
func (q *Queue) AddListener(listener Listener) {
|
||||
q.listeners = append(q.listeners, listener)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package search
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
colon = ':'
|
||||
@@ -8,16 +11,16 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDupItem means adding duplicated item.
|
||||
ErrDupItem = errors.New("duplicated item")
|
||||
// ErrDupSlash means item is started with more than one slash.
|
||||
ErrDupSlash = errors.New("duplicated slash")
|
||||
// ErrEmptyItem means adding empty item.
|
||||
ErrEmptyItem = errors.New("empty item")
|
||||
// ErrInvalidState means search tree is in an invalid state.
|
||||
ErrInvalidState = errors.New("search tree is in an invalid state")
|
||||
// ErrNotFromRoot means path is not starting with slash.
|
||||
ErrNotFromRoot = errors.New("path should start with /")
|
||||
// errDupItem means adding duplicated item.
|
||||
errDupItem = errors.New("duplicated item")
|
||||
// errDupSlash means item is started with more than one slash.
|
||||
errDupSlash = errors.New("duplicated slash")
|
||||
// errEmptyItem means adding empty item.
|
||||
errEmptyItem = errors.New("empty item")
|
||||
// errInvalidState means search tree is in an invalid state.
|
||||
errInvalidState = errors.New("search tree is in an invalid state")
|
||||
// errNotFromRoot means path is not starting with slash.
|
||||
errNotFromRoot = errors.New("path should start with /")
|
||||
|
||||
// NotFound is used to hold the not found result.
|
||||
NotFound Result
|
||||
@@ -58,14 +61,22 @@ func NewTree() *Tree {
|
||||
// Add adds item to associate with route.
|
||||
func (t *Tree) Add(route string, item interface{}) error {
|
||||
if len(route) == 0 || route[0] != slash {
|
||||
return ErrNotFromRoot
|
||||
return errNotFromRoot
|
||||
}
|
||||
|
||||
if item == nil {
|
||||
return ErrEmptyItem
|
||||
return errEmptyItem
|
||||
}
|
||||
|
||||
return add(t.root, route[1:], item)
|
||||
err := add(t.root, route[1:], item)
|
||||
switch err {
|
||||
case errDupItem:
|
||||
return duplicatedItem(route)
|
||||
case errDupSlash:
|
||||
return duplicatedSlash(route)
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Search searches item that associates with given route.
|
||||
@@ -86,22 +97,22 @@ func (t *Tree) next(n *node, route string, result *Result) bool {
|
||||
}
|
||||
|
||||
for i := range route {
|
||||
if route[i] == slash {
|
||||
token := route[:i]
|
||||
return n.forEach(func(k string, v *node) bool {
|
||||
if r := match(k, token); r.found {
|
||||
if t.next(v, route[i+1:], result) {
|
||||
if r.named {
|
||||
addParam(result, r.key, r.value)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
if route[i] != slash {
|
||||
continue
|
||||
}
|
||||
|
||||
token := route[:i]
|
||||
return n.forEach(func(k string, v *node) bool {
|
||||
r := match(k, token)
|
||||
if !r.found || !t.next(v, route[i+1:], result) {
|
||||
return false
|
||||
}
|
||||
if r.named {
|
||||
addParam(result, r.key, r.value)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return n.forEach(func(k string, v *node) bool {
|
||||
@@ -141,7 +152,7 @@ func (nd *node) getChildren(route string) map[string]*node {
|
||||
func add(nd *node, route string, item interface{}) error {
|
||||
if len(route) == 0 {
|
||||
if nd.item != nil {
|
||||
return ErrDupItem
|
||||
return errDupItem
|
||||
}
|
||||
|
||||
nd.item = item
|
||||
@@ -149,31 +160,33 @@ func add(nd *node, route string, item interface{}) error {
|
||||
}
|
||||
|
||||
if route[0] == slash {
|
||||
return ErrDupSlash
|
||||
return errDupSlash
|
||||
}
|
||||
|
||||
for i := range route {
|
||||
if route[i] == slash {
|
||||
token := route[:i]
|
||||
children := nd.getChildren(token)
|
||||
if child, ok := children[token]; ok {
|
||||
if child != nil {
|
||||
return add(child, route[i+1:], item)
|
||||
}
|
||||
if route[i] != slash {
|
||||
continue
|
||||
}
|
||||
|
||||
return ErrInvalidState
|
||||
token := route[:i]
|
||||
children := nd.getChildren(token)
|
||||
if child, ok := children[token]; ok {
|
||||
if child != nil {
|
||||
return add(child, route[i+1:], item)
|
||||
}
|
||||
|
||||
child := newNode(nil)
|
||||
children[token] = child
|
||||
return add(child, route[i+1:], item)
|
||||
return errInvalidState
|
||||
}
|
||||
|
||||
child := newNode(nil)
|
||||
children[token] = child
|
||||
return add(child, route[i+1:], item)
|
||||
}
|
||||
|
||||
children := nd.getChildren(route)
|
||||
if child, ok := children[route]; ok {
|
||||
if child.item != nil {
|
||||
return ErrDupItem
|
||||
return errDupItem
|
||||
}
|
||||
|
||||
child.item = item
|
||||
@@ -192,6 +205,14 @@ func addParam(result *Result, k, v string) {
|
||||
result.Params[k] = v
|
||||
}
|
||||
|
||||
func duplicatedItem(item string) error {
|
||||
return fmt.Errorf("duplicated item for %s", item)
|
||||
}
|
||||
|
||||
func duplicatedSlash(item string) error {
|
||||
return fmt.Errorf("duplicated slash for %s", item)
|
||||
}
|
||||
|
||||
func match(pat, token string) innerResult {
|
||||
if pat[0] == colon {
|
||||
return innerResult{
|
||||
|
||||
@@ -151,9 +151,9 @@ func TestAddDuplicate(t *testing.T) {
|
||||
err := tree.Add("/a/b", 1)
|
||||
assert.Nil(t, err)
|
||||
err = tree.Add("/a/b", 2)
|
||||
assert.Equal(t, ErrDupItem, err)
|
||||
assert.Error(t, errDupItem, err)
|
||||
err = tree.Add("/a/b/", 2)
|
||||
assert.Equal(t, ErrDupItem, err)
|
||||
assert.Error(t, errDupItem, err)
|
||||
}
|
||||
|
||||
func TestPlain(t *testing.T) {
|
||||
@@ -169,19 +169,19 @@ func TestPlain(t *testing.T) {
|
||||
func TestSearchWithDoubleSlashes(t *testing.T) {
|
||||
tree := NewTree()
|
||||
err := tree.Add("//a", 1)
|
||||
assert.Error(t, ErrDupSlash, err)
|
||||
assert.Error(t, errDupSlash, err)
|
||||
}
|
||||
|
||||
func TestSearchInvalidRoute(t *testing.T) {
|
||||
tree := NewTree()
|
||||
err := tree.Add("", 1)
|
||||
assert.Equal(t, ErrNotFromRoot, err)
|
||||
assert.Equal(t, errNotFromRoot, err)
|
||||
err = tree.Add("bad", 1)
|
||||
assert.Equal(t, ErrNotFromRoot, err)
|
||||
assert.Equal(t, errNotFromRoot, err)
|
||||
}
|
||||
|
||||
func TestSearchInvalidItem(t *testing.T) {
|
||||
tree := NewTree()
|
||||
err := tree.Add("/", nil)
|
||||
assert.Equal(t, ErrEmptyItem, err)
|
||||
assert.Equal(t, errEmptyItem, err)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@ type ServiceConf struct {
|
||||
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"`
|
||||
MetricsUrl string `json:",optional"`
|
||||
Prometheus prometheus.Config `json:",optional"`
|
||||
// TODO: enable it in v1.2.1
|
||||
// Telemetry opentelemetry.Config `json:",optional"`
|
||||
}
|
||||
|
||||
// MustSetUp sets up the service, exits on error.
|
||||
@@ -49,6 +51,13 @@ func (sc ServiceConf) SetUp() error {
|
||||
|
||||
sc.initMode()
|
||||
prometheus.StartAgent(sc.Prometheus)
|
||||
|
||||
// TODO: enable it in v1.2.1
|
||||
// if len(sc.Telemetry.Name) == 0 {
|
||||
// sc.Telemetry.Name = sc.Name
|
||||
// }
|
||||
// opentelemetry.StartAgent(sc.Telemetry)
|
||||
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package stat
|
||||
@@ -22,7 +22,7 @@ var (
|
||||
cores uint64
|
||||
)
|
||||
|
||||
// if /proc not present, ignore the cpu calcuation, like wsl linux
|
||||
// if /proc not present, ignore the cpu calculation, like wsl linux
|
||||
func init() {
|
||||
cpus, err := perCpuUsage()
|
||||
if err != nil {
|
||||
|
||||
@@ -38,7 +38,9 @@ func init() {
|
||||
atomic.StoreInt64(&cpuUsage, usage)
|
||||
})
|
||||
case <-allTicker.C:
|
||||
printUsage()
|
||||
if logEnabled.True() {
|
||||
printUsage()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
2
core/stores/cache/cache.go
vendored
2
core/stores/cache/cache.go
vendored
@@ -29,7 +29,7 @@ type (
|
||||
)
|
||||
|
||||
// New returns a Cache.
|
||||
func New(c ClusterConf, barrier syncx.SharedCalls, st *Stat, errNotFound error,
|
||||
func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
|
||||
opts ...Option) Cache {
|
||||
if len(c) == 0 || TotalWeights(c) <= 0 {
|
||||
log.Fatal("no cache nodes")
|
||||
|
||||
6
core/stores/cache/cache_test.go
vendored
6
core/stores/cache/cache_test.go
vendored
@@ -23,6 +23,7 @@ type mockedNode struct {
|
||||
|
||||
func (mc *mockedNode) Del(keys ...string) error {
|
||||
var be errorx.BatchError
|
||||
|
||||
for _, key := range keys {
|
||||
if _, ok := mc.vals[key]; !ok {
|
||||
be.Add(mc.errNotFound)
|
||||
@@ -30,6 +31,7 @@ func (mc *mockedNode) Del(keys ...string) error {
|
||||
delete(mc.vals, key)
|
||||
}
|
||||
}
|
||||
|
||||
return be.Err()
|
||||
}
|
||||
|
||||
@@ -102,7 +104,7 @@ func TestCache_SetDel(t *testing.T) {
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSharedCalls(), NewStat("mock"), errPlaceholder)
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
|
||||
@@ -140,7 +142,7 @@ func TestCache_OneNode(t *testing.T) {
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSharedCalls(), NewStat("mock"), errPlaceholder)
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
|
||||
|
||||
21
core/stores/cache/cachenode.go
vendored
21
core/stores/cache/cachenode.go
vendored
@@ -29,7 +29,7 @@ type cacheNode struct {
|
||||
rds *redis.Redis
|
||||
expiry time.Duration
|
||||
notFoundExpiry time.Duration
|
||||
barrier syncx.SharedCalls
|
||||
barrier syncx.SingleFlight
|
||||
r *rand.Rand
|
||||
lock *sync.Mutex
|
||||
unstableExpiry mathx.Unstable
|
||||
@@ -43,7 +43,7 @@ type cacheNode struct {
|
||||
// st is used to stat the cache.
|
||||
// errNotFound defines the error that returned on cache not found.
|
||||
// opts are the options that customize the cacheNode.
|
||||
func NewNode(rds *redis.Redis, barrier syncx.SharedCalls, st *Stat,
|
||||
func NewNode(rds *redis.Redis, barrier syncx.SingleFlight, st *Stat,
|
||||
errNotFound error, opts ...Option) Cache {
|
||||
o := newOptions(opts...)
|
||||
return cacheNode{
|
||||
@@ -65,9 +65,18 @@ func (c cacheNode) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := c.rds.Del(keys...); err != nil {
|
||||
logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
|
||||
c.asyncRetryDelCache(keys...)
|
||||
if len(keys) > 1 && c.rds.Type == redis.ClusterType {
|
||||
for _, key := range keys {
|
||||
if _, err := c.rds.Del(key); err != nil {
|
||||
logx.Errorf("failed to clear cache with key: %q, error: %v", key, err)
|
||||
c.asyncRetryDelCache(key)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := c.rds.Del(keys...); err != nil {
|
||||
logx.Errorf("failed to clear cache with keys: %q, error: %v", formatKeys(keys), err)
|
||||
c.asyncRetryDelCache(keys...)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -205,7 +214,7 @@ func (c cacheNode) doTake(v interface{}, key string, query func(v interface{}) e
|
||||
return jsonx.Unmarshal(val.([]byte), v)
|
||||
}
|
||||
|
||||
func (c cacheNode) processCache(key string, data string, v interface{}) error {
|
||||
func (c cacheNode) processCache(key, data string, v interface{}) error {
|
||||
err := jsonx.Unmarshal([]byte(data), v)
|
||||
if err == nil {
|
||||
return nil
|
||||
|
||||
28
core/stores/cache/cachenode_test.go
vendored
28
core/stores/cache/cachenode_test.go
vendored
@@ -29,6 +29,7 @@ func init() {
|
||||
func TestCacheNode_DelCache(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
defer clean()
|
||||
|
||||
cn := cacheNode{
|
||||
@@ -49,6 +50,23 @@ func TestCacheNode_DelCache(t *testing.T) {
|
||||
assert.Nil(t, cn.Del("first", "second"))
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCacheWithErrors(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
clean()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
assert.Nil(t, cn.Del("third", "fourth"))
|
||||
}
|
||||
|
||||
func TestCacheNode_InvalidCache(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
@@ -78,7 +96,7 @@ func TestCacheNode_Take(t *testing.T) {
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
@@ -105,7 +123,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
@@ -144,7 +162,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
@@ -171,7 +189,7 @@ func TestCacheNode_String(t *testing.T) {
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
@@ -188,7 +206,7 @@ func TestCacheValueWithBigInt(t *testing.T) {
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSharedCalls(),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
// imports the driver.
|
||||
// imports the driver, don't remove this comment, golint requires.
|
||||
_ "github.com/ClickHouse/clickhouse-go"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ type (
|
||||
// Store interface represents a KV store.
|
||||
Store interface {
|
||||
Del(keys ...string) (int, error)
|
||||
Eval(script string, key string, args ...interface{}) (interface{}, error)
|
||||
Eval(script, key string, args ...interface{}) (interface{}, error)
|
||||
Exists(key string) (bool, error)
|
||||
Expire(key string, seconds int) error
|
||||
Expireat(key string, expireTime int64) error
|
||||
@@ -39,7 +39,7 @@ type (
|
||||
Llen(key string) (int, error)
|
||||
Lpop(key string) (string, error)
|
||||
Lpush(key string, values ...interface{}) (int, error)
|
||||
Lrange(key string, start int, stop int) ([]string, error)
|
||||
Lrange(key string, start, stop int) ([]string, error)
|
||||
Lrem(key string, count int, value string) (int, error)
|
||||
Persist(key string) (bool, error)
|
||||
Pfadd(key string, values ...interface{}) (bool, error)
|
||||
@@ -47,7 +47,7 @@ type (
|
||||
Rpush(key string, values ...interface{}) (int, error)
|
||||
Sadd(key string, values ...interface{}) (int, error)
|
||||
Scard(key string) (int64, error)
|
||||
Set(key string, value string) error
|
||||
Set(key, value string) error
|
||||
Setex(key, value string, seconds int) error
|
||||
Setnx(key, value string) (bool, error)
|
||||
SetnxEx(key, value string, seconds int) (bool, error)
|
||||
@@ -74,7 +74,7 @@ type (
|
||||
Zrevrange(key string, start, stop int64) ([]string, error)
|
||||
ZrevrangebyscoreWithScores(key string, start, stop int64) ([]redis.Pair, error)
|
||||
ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64, page, size int) ([]redis.Pair, error)
|
||||
Zscore(key string, value string) (int64, error)
|
||||
Zscore(key, value string) (int64, error)
|
||||
Zrevrank(key, field string) (int64, error)
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (cs clusterStore) Del(keys ...string) (int, error) {
|
||||
return val, be.Err()
|
||||
}
|
||||
|
||||
func (cs clusterStore) Eval(script string, key string, args ...interface{}) (interface{}, error) {
|
||||
func (cs clusterStore) Eval(script, key string, args ...interface{}) (interface{}, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -321,7 +321,7 @@ func (cs clusterStore) Lpush(key string, values ...interface{}) (int, error) {
|
||||
return node.Lpush(key, values...)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Lrange(key string, start int, stop int) ([]string, error) {
|
||||
func (cs clusterStore) Lrange(key string, start, stop int) ([]string, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -393,7 +393,7 @@ func (cs clusterStore) Scard(key string) (int64, error) {
|
||||
return node.Scard(key)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Set(key string, value string) error {
|
||||
func (cs clusterStore) Set(key, value string) error {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -648,7 +648,7 @@ func (cs clusterStore) Zrevrank(key, field string) (int64, error) {
|
||||
return node.Zrevrank(key, field)
|
||||
}
|
||||
|
||||
func (cs clusterStore) Zscore(key string, value string) (int64, error) {
|
||||
func (cs clusterStore) Zscore(key, value string) (int64, error) {
|
||||
node, err := cs.getRedis(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -11,8 +11,8 @@ var (
|
||||
// ErrNotFound is an alias of mgo.ErrNotFound.
|
||||
ErrNotFound = mgo.ErrNotFound
|
||||
|
||||
// can't use one SharedCalls per conn, because multiple conns may share the same cache key.
|
||||
sharedCalls = syncx.NewSharedCalls()
|
||||
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
|
||||
sharedCalls = syncx.NewSingleFlight()
|
||||
stats = cache.NewStat("mongoc")
|
||||
)
|
||||
|
||||
@@ -24,11 +24,11 @@ type (
|
||||
CachedCollection interface {
|
||||
Count(query interface{}) (int, error)
|
||||
DelCache(keys ...string) error
|
||||
FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error
|
||||
FindAllNoCache(v, query interface{}, opts ...QueryOption) error
|
||||
FindOne(v interface{}, key string, query interface{}) error
|
||||
FindOneNoCache(v interface{}, query interface{}) error
|
||||
FindOneNoCache(v, query interface{}) error
|
||||
FindOneId(v interface{}, key string, id interface{}) error
|
||||
FindOneIdNoCache(v interface{}, id interface{}) error
|
||||
FindOneIdNoCache(v, id interface{}) error
|
||||
GetCache(key string, v interface{}) error
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) mongo.Pipe
|
||||
@@ -68,7 +68,7 @@ func (c *cachedCollection) DelCache(keys ...string) error {
|
||||
return c.cache.Del(keys...)
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error {
|
||||
func (c *cachedCollection) FindAllNoCache(v, query interface{}, opts ...QueryOption) error {
|
||||
q := c.collection.Find(query)
|
||||
for _, opt := range opts {
|
||||
q = opt(q)
|
||||
@@ -83,7 +83,7 @@ func (c *cachedCollection) FindOne(v interface{}, key string, query interface{})
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneNoCache(v interface{}, query interface{}) error {
|
||||
func (c *cachedCollection) FindOneNoCache(v, query interface{}) error {
|
||||
q := c.collection.Find(query)
|
||||
return q.One(v)
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func (c *cachedCollection) FindOneId(v interface{}, key string, id interface{})
|
||||
})
|
||||
}
|
||||
|
||||
func (c *cachedCollection) FindOneIdNoCache(v interface{}, id interface{}) error {
|
||||
func (c *cachedCollection) FindOneIdNoCache(v, id interface{}) error {
|
||||
q := c.collection.FindId(id)
|
||||
return q.One(v)
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ func (mm *Model) GetCollection(session *mgo.Session) CachedCollection {
|
||||
}
|
||||
|
||||
// FindAllNoCache finds all records without cache.
|
||||
func (mm *Model) FindAllNoCache(v interface{}, query interface{}, opts ...QueryOption) error {
|
||||
func (mm *Model) FindAllNoCache(v, query interface{}, opts ...QueryOption) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindAllNoCache(v, query, opts...)
|
||||
})
|
||||
@@ -89,7 +89,7 @@ func (mm *Model) FindOne(v interface{}, key string, query interface{}) error {
|
||||
}
|
||||
|
||||
// FindOneNoCache unmarshals a record into v with query, without cache.
|
||||
func (mm *Model) FindOneNoCache(v interface{}, query interface{}) error {
|
||||
func (mm *Model) FindOneNoCache(v, query interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindOneNoCache(v, query)
|
||||
})
|
||||
@@ -103,7 +103,7 @@ func (mm *Model) FindOneId(v interface{}, key string, id interface{}) error {
|
||||
}
|
||||
|
||||
// FindOneIdNoCache unmarshals a record into v with query, without cache.
|
||||
func (mm *Model) FindOneIdNoCache(v interface{}, id interface{}) error {
|
||||
func (mm *Model) FindOneIdNoCache(v, id interface{}) error {
|
||||
return mm.execute(func(c CachedCollection) error {
|
||||
return c.FindOneIdNoCache(v, id)
|
||||
})
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
// imports the driver.
|
||||
// imports the driver, don't remove this comment, golint requires.
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
@@ -71,6 +71,8 @@ type (
|
||||
IntCmd = red.IntCmd
|
||||
// FloatCmd is an alias of redis.FloatCmd.
|
||||
FloatCmd = red.FloatCmd
|
||||
// StringCmd is an alias of redis.StringCmd.
|
||||
StringCmd = red.StringCmd
|
||||
)
|
||||
|
||||
// New returns a Redis with given options.
|
||||
@@ -180,7 +182,7 @@ func (s *Redis) BitOpXor(destKey string, keys ...string) (val int64, err error)
|
||||
}
|
||||
|
||||
// BitPos is redis bitpos command implementation.
|
||||
func (s *Redis) BitPos(key string, bit int64, start, end int64) (val int64, err error) {
|
||||
func (s *Redis) BitPos(key string, bit, start, end int64) (val int64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
@@ -346,7 +348,7 @@ func (s *Redis) GeoAdd(key string, geoLocation ...*GeoLocation) (val int64, err
|
||||
}
|
||||
|
||||
// GeoDist is the implementation of redis geodist command.
|
||||
func (s *Redis) GeoDist(key string, member1, member2, unit string) (val float64, err error) {
|
||||
func (s *Redis) GeoDist(key, member1, member2, unit string) (val float64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
@@ -795,7 +797,7 @@ func (s *Redis) Lpush(key string, values ...interface{}) (val int, err error) {
|
||||
}
|
||||
|
||||
// Lrange is the implementation of redis lrange command.
|
||||
func (s *Redis) Lrange(key string, start int, stop int) (val []string, err error) {
|
||||
func (s *Redis) Lrange(key string, start, stop int) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
@@ -1074,7 +1076,7 @@ func (s *Redis) ScriptLoad(script string) (string, error) {
|
||||
}
|
||||
|
||||
// Set is the implementation of redis set command.
|
||||
func (s *Redis) Set(key string, value string) error {
|
||||
func (s *Redis) Set(key, value string) error {
|
||||
return s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
@@ -1282,6 +1284,41 @@ func (s *Redis) Sdiffstore(destination string, keys ...string) (val int, err err
|
||||
return
|
||||
}
|
||||
|
||||
// Sinter is the implementation of redis sinter command.
|
||||
func (s *Redis) Sinter(keys ...string) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = conn.SInter(keys...).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Sinterstore is the implementation of redis sinterstore command.
|
||||
func (s *Redis) Sinterstore(destination string, keys ...string) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.SInterStore(destination, keys...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = int(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Ttl is the implementation of redis ttl command.
|
||||
func (s *Redis) Ttl(key string) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1412,7 +1449,7 @@ func (s *Redis) Zincrby(key string, increment int64, field string) (val int64, e
|
||||
}
|
||||
|
||||
// Zscore is the implementation of redis zscore command.
|
||||
func (s *Redis) Zscore(key string, value string) (val int64, err error) {
|
||||
func (s *Redis) Zscore(key, value string) (val int64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
@@ -1684,7 +1721,7 @@ func (s *Redis) ZrevrangebyscoreWithScoresAndLimit(key string, start, stop int64
|
||||
}
|
||||
|
||||
// Zrevrank is the implementation of redis zrevrank command.
|
||||
func (s *Redis) Zrevrank(key string, field string) (val int64, err error) {
|
||||
func (s *Redis) Zrevrank(key, field string) (val int64, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
|
||||
@@ -638,6 +638,16 @@ func TestRedis_Set(t *testing.T) {
|
||||
num, err = client.Sdiffstore("key4", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, num)
|
||||
_, err = New(client.Addr, badType()).Sinter("key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
vals, err = client.Sinter("key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"2", "3", "4"}, vals)
|
||||
_, err = New(client.Addr, badType()).Sinterstore("key4", "key1", "key2")
|
||||
assert.NotNil(t, err)
|
||||
num, err = client.Sinterstore("key4", "key1", "key2")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 3, num)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ var (
|
||||
// ErrNotFound is an alias of sqlx.ErrNotFound.
|
||||
ErrNotFound = sqlx.ErrNotFound
|
||||
|
||||
// can't use one SharedCalls per conn, because multiple conns may share the same cache key.
|
||||
exclusiveCalls = syncx.NewSharedCalls()
|
||||
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
|
||||
exclusiveCalls = syncx.NewSingleFlight()
|
||||
stats = cache.NewStat("sqlc")
|
||||
)
|
||||
|
||||
|
||||
@@ -600,6 +600,10 @@ func (d dummySqlConn) QueryRowsPartial(v interface{}, query string, args ...inte
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) RawDB() (*sql.DB, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Transact(func(session sqlx.Session) error) error {
|
||||
return nil
|
||||
}
|
||||
@@ -621,6 +625,10 @@ func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}
|
||||
return c.dummySqlConn.QueryRows(v, query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) RawDB() (*sql.DB, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
|
||||
c.transactValue = true
|
||||
return c.dummySqlConn.Transact(fn)
|
||||
|
||||
@@ -43,6 +43,10 @@ func (c *mockedConn) QueryRowsPartial(v interface{}, query string, args ...inter
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) RawDB() (*sql.DB, error) {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) Transact(func(session Session) error) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// ErrNotFound is an alias of sql.ErrNoRows
|
||||
@@ -23,6 +24,8 @@ type (
|
||||
// SqlConn only stands for raw connections, so Transact method can be called.
|
||||
SqlConn interface {
|
||||
Session
|
||||
// RawDB is for other ORM to operate with, use it with caution.
|
||||
RawDB() (*sql.DB, error)
|
||||
Transact(func(session Session) error) error
|
||||
}
|
||||
|
||||
@@ -43,20 +46,23 @@ type (
|
||||
// Because CORBA doesn't support PREPARE, so we need to combine the
|
||||
// query arguments into one string and do underlying query without arguments
|
||||
commonSqlConn struct {
|
||||
driverName string
|
||||
datasource string
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept func(error) bool
|
||||
connProv connProvider
|
||||
onError func(error)
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept func(error) bool
|
||||
}
|
||||
|
||||
connProvider func() (*sql.DB, error)
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
statement struct {
|
||||
stmt *sql.Stmt
|
||||
query string
|
||||
stmt *sql.Stmt
|
||||
}
|
||||
|
||||
stmtConn interface {
|
||||
@@ -68,10 +74,34 @@ type (
|
||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
driverName: driverName,
|
||||
datasource: datasource,
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
connProv: func() (*sql.DB, error) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
},
|
||||
onError: func(err error) {
|
||||
logInstanceError(datasource, err)
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// NewSqlConnFromDB returns a SqlConn with the given sql.DB.
|
||||
// Use it with caution, it's provided for other ORM to interact with.
|
||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
return db, nil
|
||||
},
|
||||
onError: func(err error) {
|
||||
logx.Errorf("Error on getting sql instance: %v", err)
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
@@ -83,9 +113,9 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
conn, err = db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
db.onError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -99,9 +129,9 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
|
||||
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = getSqlConn(db.driverName, db.datasource)
|
||||
conn, err = db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
db.onError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -111,7 +141,8 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
}
|
||||
|
||||
stmt = statement{
|
||||
stmt: st,
|
||||
query: query,
|
||||
stmt: st,
|
||||
}
|
||||
return nil
|
||||
}, db.acceptable)
|
||||
@@ -143,6 +174,10 @@ func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...inter
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||
return db.connProv()
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
return transact(db, db.beginTx, fn)
|
||||
@@ -161,9 +196,9 @@ func (db *commonSqlConn) acceptable(err error) bool {
|
||||
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
var qerr error
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
db.onError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -181,29 +216,29 @@ func (s statement) Close() error {
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(s.stmt, args...)
|
||||
return execStmt(s.stmt, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, args...)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
@@ -21,12 +21,15 @@ func TestSqlConn(t *testing.T) {
|
||||
mock.ExpectExec("any")
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||
conn := NewMysql(mockedDatasource)
|
||||
db, err := conn.RawDB()
|
||||
assert.Nil(t, err)
|
||||
rawConn := NewSqlConnFromDB(db, withMysqlAcceptable())
|
||||
badConn := NewMysql("badsql")
|
||||
_, err := conn.Exec("any", "value")
|
||||
_, err = conn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = badConn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = conn.Prepare("any")
|
||||
_, err = rawConn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
_, err = badConn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
@@ -2,7 +2,6 @@ package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
@@ -12,10 +11,14 @@ import (
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := formatForPrint(q, args)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -28,11 +31,15 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -46,10 +53,14 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -64,8 +75,12 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
|
||||
return scanner(rows)
|
||||
}
|
||||
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error {
|
||||
stmt := fmt.Sprint(args...)
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(args...)
|
||||
duration := timex.Since(startTime)
|
||||
|
||||
@@ -14,6 +14,7 @@ var errMockedPlaceholder = errors.New("placeholder")
|
||||
func TestStmt_exec(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
delay bool
|
||||
hasError bool
|
||||
@@ -23,18 +24,28 @@ func TestStmt_exec(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
lastInsertId: 1,
|
||||
rowsAffected: 2,
|
||||
},
|
||||
{
|
||||
name: "exec error",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "exec more args error",
|
||||
query: "select user from users where id=? and name=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "slowcall",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
delay: true,
|
||||
lastInsertId: 1,
|
||||
@@ -51,7 +62,7 @@ func TestStmt_exec(t *testing.T) {
|
||||
rowsAffected: test.rowsAffected,
|
||||
err: test.err,
|
||||
delay: test.delay,
|
||||
}, "select user from users where id=?", args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
func(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(&mockedStmtConn{
|
||||
@@ -59,7 +70,7 @@ func TestStmt_exec(t *testing.T) {
|
||||
rowsAffected: test.rowsAffected,
|
||||
err: test.err,
|
||||
delay: test.delay,
|
||||
}, args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -89,23 +100,34 @@ func TestStmt_exec(t *testing.T) {
|
||||
func TestStmt_query(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
delay bool
|
||||
hasError bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
args: []interface{}{1},
|
||||
name: "normal",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "query more args error",
|
||||
query: "select user from users where id=? and name=?",
|
||||
args: []interface{}{1},
|
||||
hasError: true,
|
||||
err: errors.New("exec"),
|
||||
},
|
||||
{
|
||||
name: "slowcall",
|
||||
query: "select user from users where id=?",
|
||||
args: []interface{}{1},
|
||||
delay: true,
|
||||
},
|
||||
@@ -120,7 +142,7 @@ func TestStmt_query(t *testing.T) {
|
||||
delay: test.delay,
|
||||
}, func(rows *sql.Rows) error {
|
||||
return nil
|
||||
}, "select user from users where id=?", args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
func(args ...interface{}) error {
|
||||
return queryStmt(&mockedStmtConn{
|
||||
@@ -128,7 +150,7 @@ func TestStmt_query(t *testing.T) {
|
||||
delay: test.delay,
|
||||
}, func(rows *sql.Rows) error {
|
||||
return nil
|
||||
}, args...)
|
||||
}, test.query, args...)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -143,7 +165,7 @@ func TestStmt_query(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, errMockedPlaceholder, err)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,8 @@ func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
}
|
||||
|
||||
return statement{
|
||||
stmt: stmt,
|
||||
query: q,
|
||||
stmt: stmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -70,9 +71,9 @@ func begin(db *sql.DB) (trans, error) {
|
||||
}
|
||||
|
||||
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
|
||||
conn, err := getSqlConn(db.driverName, db.datasource)
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
logInstanceError(db.datasource, err)
|
||||
db.onError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package sqlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
@@ -45,24 +46,6 @@ func escape(input string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatForPrint(query string, args ...interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, arg := range args {
|
||||
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteByte('[')
|
||||
b.WriteString(strings.Join(vals, ", "))
|
||||
b.WriteByte(']')
|
||||
|
||||
return strings.Join([]string{query, b.String()}, " ")
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
@@ -70,38 +53,52 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
argIndex := 0
|
||||
var argIndex int
|
||||
bytes := len(query)
|
||||
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
for i := 0; i < bytes; i++ {
|
||||
ch := query[i]
|
||||
switch ch {
|
||||
case '?':
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
writeValue(&b, args[argIndex])
|
||||
argIndex++
|
||||
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
b.WriteByte('1')
|
||||
} else {
|
||||
b.WriteByte('0')
|
||||
case '$':
|
||||
var j int
|
||||
for j = i + 1; j < bytes; j++ {
|
||||
char := query[j]
|
||||
if char < '0' || '9' < char {
|
||||
break
|
||||
}
|
||||
case string:
|
||||
b.WriteByte('\'')
|
||||
b.WriteString(escape(v))
|
||||
b.WriteByte('\'')
|
||||
default:
|
||||
b.WriteString(mapping.Repr(v))
|
||||
}
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
if j > i+1 {
|
||||
index, err := strconv.Atoi(query[i+1 : j])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// index starts from 1 for pg
|
||||
if index > argIndex {
|
||||
argIndex = index
|
||||
}
|
||||
index--
|
||||
if index < 0 || numArgs <= index {
|
||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
||||
}
|
||||
|
||||
writeValue(&b, args[index])
|
||||
i = j - 1
|
||||
}
|
||||
default:
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
@@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func writeValue(buf *strings.Builder, arg interface{}) {
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
buf.WriteByte('1')
|
||||
} else {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
case string:
|
||||
buf.WriteByte('\'')
|
||||
buf.WriteString(escape(v))
|
||||
buf.WriteByte('\'')
|
||||
default:
|
||||
buf.WriteString(mapping.Repr(v))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,30 +29,63 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
|
||||
func TestFormatForPrint(t *testing.T) {
|
||||
func TestFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
expect string
|
||||
hasErr bool
|
||||
}{
|
||||
{
|
||||
name: "no args",
|
||||
query: "select user, name from table where id=?",
|
||||
expect: `select user, name from table where id=?`,
|
||||
name: "mysql normal",
|
||||
query: "select name, age from users where bool=? and phone=?",
|
||||
args: []interface{}{true, "133"},
|
||||
expect: "select name, age from users where bool=1 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "one arg",
|
||||
query: "select user, name from table where id=?",
|
||||
args: []interface{}{"kevin"},
|
||||
expect: `select user, name from table where id=? ["kevin"]`,
|
||||
name: "mysql normal",
|
||||
query: "select name, age from users where bool=? and phone=?",
|
||||
args: []interface{}{false, "133"},
|
||||
expect: "select name, age from users where bool=0 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "pg normal",
|
||||
query: "select name, age from users where bool=$1 and phone=$2",
|
||||
args: []interface{}{true, "133"},
|
||||
expect: "select name, age from users where bool=1 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "pg normal reverse",
|
||||
query: "select name, age from users where bool=$2 and phone=$1",
|
||||
args: []interface{}{"133", false},
|
||||
expect: "select name, age from users where bool=0 and phone='133'",
|
||||
},
|
||||
{
|
||||
name: "pg error not number",
|
||||
query: "select name, age from users where bool=$a and phone=$1",
|
||||
args: []interface{}{"133", false},
|
||||
hasErr: true,
|
||||
},
|
||||
{
|
||||
name: "pg error more args",
|
||||
query: "select name, age from users where bool=$2 and phone=$1 and nickname=$3",
|
||||
args: []interface{}{"133", false},
|
||||
hasErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
actual := formatForPrint(test.query, test.args...)
|
||||
assert.Equal(t, test.expect, actual)
|
||||
t.Parallel()
|
||||
|
||||
actual, err := format(test.query, test.args...)
|
||||
if test.hasErr {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Equal(t, test.expect, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,24 @@ func Filter(s string, filter func(r rune) bool) string {
|
||||
return string(chars[:n])
|
||||
}
|
||||
|
||||
// FirstN returns first n runes from s.
|
||||
func FirstN(s string, n int, ellipsis ...string) string {
|
||||
var i int
|
||||
|
||||
for j := range s {
|
||||
if i == n {
|
||||
ret := s[:j]
|
||||
for _, each := range ellipsis {
|
||||
ret += each
|
||||
}
|
||||
return ret
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// HasEmpty checks if there are empty strings in args.
|
||||
func HasEmpty(args ...string) bool {
|
||||
for _, arg := range args {
|
||||
@@ -86,7 +104,7 @@ func Reverse(s string) string {
|
||||
}
|
||||
|
||||
// Substr returns runes between start and stop [start, stop) regardless of the chars are ascii or utf8.
|
||||
func Substr(str string, start int, stop int) (string, error) {
|
||||
func Substr(str string, start, stop int) (string, error) {
|
||||
rs := []rune(str)
|
||||
length := len(rs)
|
||||
|
||||
|
||||
@@ -92,6 +92,61 @@ func TestFilter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
n int
|
||||
ellipsis string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "english string",
|
||||
input: "anything that we use",
|
||||
n: 8,
|
||||
expect: "anything",
|
||||
},
|
||||
{
|
||||
name: "english string with ellipsis",
|
||||
input: "anything that we use",
|
||||
n: 8,
|
||||
ellipsis: "...",
|
||||
expect: "anything...",
|
||||
},
|
||||
{
|
||||
name: "english string more",
|
||||
input: "anything that we use",
|
||||
n: 80,
|
||||
expect: "anything that we use",
|
||||
},
|
||||
{
|
||||
name: "chinese string",
|
||||
input: "我是中国人",
|
||||
n: 2,
|
||||
expect: "我是",
|
||||
},
|
||||
{
|
||||
name: "chinese string with ellipsis",
|
||||
input: "我是中国人",
|
||||
n: 2,
|
||||
ellipsis: "...",
|
||||
expect: "我是...",
|
||||
},
|
||||
{
|
||||
name: "chinese string",
|
||||
input: "我是中国人",
|
||||
n: 10,
|
||||
expect: "我是中国人",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.expect, FirstN(test.input, test.n, test.ellipsis))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
cases := []struct {
|
||||
input []string
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewManagedResource(generate func() interface{}, equals func(a, b interface{
|
||||
}
|
||||
}
|
||||
|
||||
// MarkBroken marks the resouce broken.
|
||||
// MarkBroken marks the resource broken.
|
||||
func (mr *ManagedResource) MarkBroken(resource interface{}) {
|
||||
mr.lock.Lock()
|
||||
defer mr.lock.Unlock()
|
||||
|
||||
@@ -2,7 +2,7 @@ package syncx
|
||||
|
||||
import "sync"
|
||||
|
||||
// Once returns a func that guanartees fn can only called once.
|
||||
// Once returns a func that guarantees fn can only called once.
|
||||
func Once(fn func()) func() {
|
||||
once := new(sync.Once)
|
||||
return func() {
|
||||
|
||||
@@ -2,7 +2,7 @@ package syncx
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// A OnceGuard is used to make sure a resouce can be taken once.
|
||||
// A OnceGuard is used to make sure a resource can be taken once.
|
||||
type OnceGuard struct {
|
||||
done uint32
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ type (
|
||||
|
||||
// A Pool is used to pool resources.
|
||||
// The difference between sync.Pool is that:
|
||||
// 1. the limit of the resouces
|
||||
// 1. the limit of the resources
|
||||
// 2. max age of the resources can be set
|
||||
// 3. the method to destroy resources can be customized
|
||||
Pool struct {
|
||||
@@ -56,7 +56,7 @@ func NewPool(n int, create func() interface{}, destroy func(interface{}), opts .
|
||||
return pool
|
||||
}
|
||||
|
||||
// Get gets a resouce.
|
||||
// Get gets a resource.
|
||||
func (p *Pool) Get() interface{} {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
// ErrUseOfCleaned is an error that indicates using a cleaned resource.
|
||||
var ErrUseOfCleaned = errors.New("using a cleaned resource")
|
||||
|
||||
// A RefResource is used to reference counting a resouce.
|
||||
// A RefResource is used to reference counting a resource.
|
||||
type RefResource struct {
|
||||
lock sync.Mutex
|
||||
ref int32
|
||||
|
||||
@@ -9,16 +9,16 @@ import (
|
||||
|
||||
// A ResourceManager is a manager that used to manage resources.
|
||||
type ResourceManager struct {
|
||||
resources map[string]io.Closer
|
||||
sharedCalls SharedCalls
|
||||
lock sync.RWMutex
|
||||
resources map[string]io.Closer
|
||||
singleFlight SingleFlight
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewResourceManager returns a ResourceManager.
|
||||
func NewResourceManager() *ResourceManager {
|
||||
return &ResourceManager{
|
||||
resources: make(map[string]io.Closer),
|
||||
sharedCalls: NewSharedCalls(),
|
||||
resources: make(map[string]io.Closer),
|
||||
singleFlight: NewSingleFlight(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (manager *ResourceManager) Close() error {
|
||||
|
||||
// GetResource returns the resource associated with given key.
|
||||
func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (io.Closer, error) {
|
||||
val, err := manager.sharedCalls.Do(key, func() (interface{}, error) {
|
||||
val, err := manager.singleFlight.Do(key, func() (interface{}, error) {
|
||||
manager.lock.RLock()
|
||||
resource, ok := manager.resources[key]
|
||||
manager.lock.RUnlock()
|
||||
|
||||
@@ -3,13 +3,17 @@ package syncx
|
||||
import "sync"
|
||||
|
||||
type (
|
||||
// SharedCalls lets the concurrent calls with the same key to share the call result.
|
||||
// SharedCalls is an alias of SingleFlight.
|
||||
// Deprecated: use SingleFlight.
|
||||
SharedCalls = SingleFlight
|
||||
|
||||
// SingleFlight lets the concurrent calls with the same key to share the call result.
|
||||
// For example, A called F, before it's done, B called F. Then B would not execute F,
|
||||
// and shared the result returned by F which called by A.
|
||||
// The calls with the same key are dependent, concurrent calls share the returned values.
|
||||
// A ------->calls F with key<------------------->returns val
|
||||
// B --------------------->calls F with key------>returns val
|
||||
SharedCalls interface {
|
||||
SingleFlight interface {
|
||||
Do(key string, fn func() (interface{}, error)) (interface{}, error)
|
||||
DoEx(key string, fn func() (interface{}, error)) (interface{}, bool, error)
|
||||
}
|
||||
@@ -20,20 +24,26 @@ type (
|
||||
err error
|
||||
}
|
||||
|
||||
sharedGroup struct {
|
||||
flightGroup struct {
|
||||
calls map[string]*call
|
||||
lock sync.Mutex
|
||||
}
|
||||
)
|
||||
|
||||
// NewSharedCalls returns a SharedCalls.
|
||||
func NewSharedCalls() SharedCalls {
|
||||
return &sharedGroup{
|
||||
// NewSingleFlight returns a SingleFlight.
|
||||
func NewSingleFlight() SingleFlight {
|
||||
return &flightGroup{
|
||||
calls: make(map[string]*call),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
// NewSharedCalls returns a SingleFlight.
|
||||
// Deprecated: use NewSingleFlight.
|
||||
func NewSharedCalls() SingleFlight {
|
||||
return NewSingleFlight()
|
||||
}
|
||||
|
||||
func (g *flightGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
|
||||
c, done := g.createCall(key)
|
||||
if done {
|
||||
return c.val, c.err
|
||||
@@ -43,7 +53,7 @@ func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{
|
||||
return c.val, c.err
|
||||
}
|
||||
|
||||
func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) {
|
||||
func (g *flightGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) {
|
||||
c, done := g.createCall(key)
|
||||
if done {
|
||||
return c.val, false, c.err
|
||||
@@ -53,7 +63,7 @@ func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val inte
|
||||
return c.val, true, c.err
|
||||
}
|
||||
|
||||
func (g *sharedGroup) createCall(key string) (c *call, done bool) {
|
||||
func (g *flightGroup) createCall(key string) (c *call, done bool) {
|
||||
g.lock.Lock()
|
||||
if c, ok := g.calls[key]; ok {
|
||||
g.lock.Unlock()
|
||||
@@ -69,7 +79,7 @@ func (g *sharedGroup) createCall(key string) (c *call, done bool) {
|
||||
return c, false
|
||||
}
|
||||
|
||||
func (g *sharedGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
|
||||
func (g *flightGroup) makeCall(c *call, key string, fn func() (interface{}, error)) {
|
||||
defer func() {
|
||||
g.lock.Lock()
|
||||
delete(g.calls, key)
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestExclusiveCallDo(t *testing.T) {
|
||||
g := NewSharedCalls()
|
||||
g := NewSingleFlight()
|
||||
v, err := g.Do("key", func() (interface{}, error) {
|
||||
return "bar", nil
|
||||
})
|
||||
@@ -23,7 +23,7 @@ func TestExclusiveCallDo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExclusiveCallDoErr(t *testing.T) {
|
||||
g := NewSharedCalls()
|
||||
g := NewSingleFlight()
|
||||
someErr := errors.New("some error")
|
||||
v, err := g.Do("key", func() (interface{}, error) {
|
||||
return nil, someErr
|
||||
@@ -37,7 +37,7 @@ func TestExclusiveCallDoErr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExclusiveCallDoDupSuppress(t *testing.T) {
|
||||
g := NewSharedCalls()
|
||||
g := NewSingleFlight()
|
||||
c := make(chan string)
|
||||
var calls int32
|
||||
fn := func() (interface{}, error) {
|
||||
@@ -69,7 +69,7 @@ func TestExclusiveCallDoDupSuppress(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExclusiveCallDoDiffDupSuppress(t *testing.T) {
|
||||
g := NewSharedCalls()
|
||||
g := NewSingleFlight()
|
||||
broadcast := make(chan struct{})
|
||||
var calls int32
|
||||
tests := []string{"e", "a", "e", "a", "b", "c", "b", "a", "c", "d", "b", "c", "d"}
|
||||
@@ -102,7 +102,7 @@ func TestExclusiveCallDoDiffDupSuppress(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExclusiveCallDoExDupSuppress(t *testing.T) {
|
||||
g := NewSharedCalls()
|
||||
g := NewSingleFlight()
|
||||
c := make(chan string)
|
||||
var calls int32
|
||||
fn := func() (interface{}, error) {
|
||||
@@ -1,6 +1,8 @@
|
||||
package trace
|
||||
|
||||
const (
|
||||
traceIdKey = "X-Trace-ID"
|
||||
spanIdKey = "X-Span-ID"
|
||||
// TraceIdKey is the trace id header.
|
||||
TraceIdKey = "X-Trace-ID"
|
||||
|
||||
spanIdKey = "X-Span-ID"
|
||||
)
|
||||
|
||||
68
core/trace/opentelemetry/agent.go
Normal file
68
core/trace/opentelemetry/agent.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/syncx"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/jaeger"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
)
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
enabled syncx.AtomicBool
|
||||
)
|
||||
|
||||
// Enabled returns if opentelemetry is enabled.
|
||||
func Enabled() bool {
|
||||
return enabled.True()
|
||||
}
|
||||
|
||||
// StartAgent starts a opentelemetry agent.
|
||||
func StartAgent(c Config) {
|
||||
once.Do(func() {
|
||||
if len(c.Endpoint) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Just support jaeger now
|
||||
if c.Batcher != "jaeger" {
|
||||
return
|
||||
}
|
||||
|
||||
exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(c.Endpoint)))
|
||||
if err != nil {
|
||||
logx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
tp := sdktrace.NewTracerProvider(
|
||||
// Set the sampling rate based on the parent span to 100%
|
||||
sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(c.Sampler))),
|
||||
// Always be sure to batch in production.
|
||||
sdktrace.WithBatcher(exp),
|
||||
// Record information about this application in an Resource.
|
||||
sdktrace.WithResource(resource.NewSchemaless(semconv.ServiceNameKey.String(c.Name))),
|
||||
)
|
||||
|
||||
otel.SetTracerProvider(tp)
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
|
||||
otel.SetErrorHandler(otelErrHandler{})
|
||||
|
||||
enabled.Set(true)
|
||||
})
|
||||
}
|
||||
|
||||
// errHandler handing otel errors.
|
||||
type otelErrHandler struct{}
|
||||
|
||||
var _ otel.ErrorHandler = otelErrHandler{}
|
||||
|
||||
func (o otelErrHandler) Handle(err error) {
|
||||
logx.Errorf("[otel] error: %v", err)
|
||||
}
|
||||
40
core/trace/opentelemetry/attributes.go
Normal file
40
core/trace/opentelemetry/attributes.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
gcodes "google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
const (
|
||||
// GRPCStatusCodeKey is convention for numeric status code of a gRPC request.
|
||||
GRPCStatusCodeKey = attribute.Key("rpc.grpc.status_code")
|
||||
// RPCNameKey is the name of message transmitted or received.
|
||||
RPCNameKey = attribute.Key("name")
|
||||
// RPCMessageTypeKey is the type of message transmitted or received.
|
||||
RPCMessageTypeKey = attribute.Key("message.type")
|
||||
// RPCMessageIDKey is the identifier of message transmitted or received.
|
||||
RPCMessageIDKey = attribute.Key("message.id")
|
||||
// RPCMessageCompressedSizeKey is the compressed size of the message transmitted or received in bytes.
|
||||
RPCMessageCompressedSizeKey = attribute.Key("message.compressed_size")
|
||||
// RPCMessageUncompressedSizeKey is the uncompressed size of the message
|
||||
// transmitted or received in bytes.
|
||||
RPCMessageUncompressedSizeKey = attribute.Key("message.uncompressed_size")
|
||||
)
|
||||
|
||||
// Semantic conventions for common RPC attributes.
|
||||
var (
|
||||
// RPCSystemGRPC is the semantic convention for gRPC as the remoting system.
|
||||
RPCSystemGRPC = semconv.RPCSystemKey.String("grpc")
|
||||
// RPCNameMessage is the semantic convention for a message named message.
|
||||
RPCNameMessage = RPCNameKey.String("message")
|
||||
// RPCMessageTypeSent is the semantic conventions for sent RPC message types.
|
||||
RPCMessageTypeSent = RPCMessageTypeKey.String("SENT")
|
||||
// RPCMessageTypeReceived is the semantic conventions for the received RPC message types.
|
||||
RPCMessageTypeReceived = RPCMessageTypeKey.String("RECEIVED")
|
||||
)
|
||||
|
||||
// StatusCodeAttr returns a attribute.KeyValue that represents the give c.
|
||||
func StatusCodeAttr(c gcodes.Code) attribute.KeyValue {
|
||||
return GRPCStatusCodeKey.Int64(int64(c))
|
||||
}
|
||||
121
core/trace/opentelemetry/clientstream.go
Normal file
121
core/trace/opentelemetry/clientstream.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
receiveEndEvent streamEventType = iota
|
||||
errorEvent
|
||||
)
|
||||
|
||||
type (
|
||||
streamEventType int
|
||||
|
||||
streamEvent struct {
|
||||
Type streamEventType
|
||||
Err error
|
||||
}
|
||||
|
||||
clientStream struct {
|
||||
grpc.ClientStream
|
||||
Finished chan error
|
||||
desc *grpc.StreamDesc
|
||||
events chan streamEvent
|
||||
eventsDone chan struct{}
|
||||
receivedMessageID int
|
||||
sentMessageID int
|
||||
}
|
||||
)
|
||||
|
||||
func (w *clientStream) RecvMsg(m interface{}) error {
|
||||
err := w.ClientStream.RecvMsg(m)
|
||||
if err == nil && !w.desc.ServerStreams {
|
||||
w.sendStreamEvent(receiveEndEvent, nil)
|
||||
} else if err == io.EOF {
|
||||
w.sendStreamEvent(receiveEndEvent, nil)
|
||||
} else if err != nil {
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
} else {
|
||||
w.receivedMessageID++
|
||||
MessageReceived.Event(w.Context(), w.receivedMessageID, m)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *clientStream) SendMsg(m interface{}) error {
|
||||
err := w.ClientStream.SendMsg(m)
|
||||
w.sentMessageID++
|
||||
MessageSent.Event(w.Context(), w.sentMessageID, m)
|
||||
if err != nil {
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *clientStream) Header() (metadata.MD, error) {
|
||||
md, err := w.ClientStream.Header()
|
||||
if err != nil {
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
}
|
||||
|
||||
return md, err
|
||||
}
|
||||
|
||||
func (w *clientStream) CloseSend() error {
|
||||
err := w.ClientStream.CloseSend()
|
||||
if err != nil {
|
||||
w.sendStreamEvent(errorEvent, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
|
||||
select {
|
||||
case <-w.eventsDone:
|
||||
case w.events <- streamEvent{Type: eventType, Err: err}:
|
||||
}
|
||||
}
|
||||
|
||||
// WrapClientStream wraps s with given ctx and desc.
|
||||
func WrapClientStream(ctx context.Context, s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
|
||||
events := make(chan streamEvent)
|
||||
eventsDone := make(chan struct{})
|
||||
finished := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(eventsDone)
|
||||
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
switch event.Type {
|
||||
case receiveEndEvent:
|
||||
finished <- nil
|
||||
return
|
||||
case errorEvent:
|
||||
finished <- event.Err
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
finished <- ctx.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return &clientStream{
|
||||
ClientStream: s,
|
||||
desc: desc,
|
||||
events: events,
|
||||
eventsDone: eventsDone,
|
||||
Finished: finished,
|
||||
}
|
||||
}
|
||||
12
core/trace/opentelemetry/config.go
Normal file
12
core/trace/opentelemetry/config.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package opentelemetry
|
||||
|
||||
// TraceName represents the tracing name.
|
||||
const TraceName = "go-zero"
|
||||
|
||||
// A Config is a opentelemetry config.
|
||||
type Config struct {
|
||||
Name string `json:",optional"`
|
||||
Endpoint string `json:",optional"`
|
||||
Sampler float64 `json:",default=1.0"`
|
||||
Batcher string `json:",default=jaeger"`
|
||||
}
|
||||
38
core/trace/opentelemetry/message.go
Normal file
38
core/trace/opentelemetry/message.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const messageEvent = "message"
|
||||
|
||||
var (
|
||||
// MessageSent is the type of sent messages.
|
||||
MessageSent = messageType(RPCMessageTypeSent)
|
||||
// MessageReceived is the type of received messages.
|
||||
MessageReceived = messageType(RPCMessageTypeReceived)
|
||||
)
|
||||
|
||||
type messageType attribute.KeyValue
|
||||
|
||||
// Event adds an event of the messageType to the span associated with the
|
||||
// passed context with id and size (if message is a proto message).
|
||||
func (m messageType) Event(ctx context.Context, id int, message interface{}) {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if p, ok := message.(proto.Message); ok {
|
||||
span.AddEvent(messageEvent, trace.WithAttributes(
|
||||
attribute.KeyValue(m),
|
||||
RPCMessageIDKey.Int(id),
|
||||
RPCMessageUncompressedSizeKey.Int(proto.Size(p)),
|
||||
))
|
||||
} else {
|
||||
span.AddEvent(messageEvent, trace.WithAttributes(
|
||||
attribute.KeyValue(m),
|
||||
RPCMessageIDKey.Int(id),
|
||||
))
|
||||
}
|
||||
}
|
||||
47
core/trace/opentelemetry/serverstream.go
Normal file
47
core/trace/opentelemetry/serverstream.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// serverStream wraps around the embedded grpc.ServerStream, and intercepts the RecvMsg and
|
||||
// SendMsg method call.
|
||||
type serverStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
|
||||
receivedMessageID int
|
||||
sentMessageID int
|
||||
}
|
||||
|
||||
func (w *serverStream) Context() context.Context {
|
||||
return w.ctx
|
||||
}
|
||||
|
||||
func (w *serverStream) RecvMsg(m interface{}) error {
|
||||
err := w.ServerStream.RecvMsg(m)
|
||||
if err == nil {
|
||||
w.receivedMessageID++
|
||||
MessageReceived.Event(w.Context(), w.receivedMessageID, m)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *serverStream) SendMsg(m interface{}) error {
|
||||
err := w.ServerStream.SendMsg(m)
|
||||
w.sentMessageID++
|
||||
MessageSent.Event(w.Context(), w.sentMessageID, m)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// WrapServerStream wraps the given grpc.ServerStream with the given context.
|
||||
func WrapServerStream(ctx context.Context, ss grpc.ServerStream) *serverStream {
|
||||
return &serverStream{
|
||||
ServerStream: ss,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
56
core/trace/opentelemetry/tracer.go
Normal file
56
core/trace/opentelemetry/tracer.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/baggage"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
sdktrace "go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// assert that metadataSupplier implements the TextMapCarrier interface
|
||||
var _ propagation.TextMapCarrier = new(metadataSupplier)
|
||||
|
||||
type metadataSupplier struct {
|
||||
metadata *metadata.MD
|
||||
}
|
||||
|
||||
func (s *metadataSupplier) Get(key string) string {
|
||||
values := s.metadata.Get(key)
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return values[0]
|
||||
}
|
||||
|
||||
func (s *metadataSupplier) Set(key, value string) {
|
||||
s.metadata.Set(key, value)
|
||||
}
|
||||
|
||||
func (s *metadataSupplier) Keys() []string {
|
||||
out := make([]string, 0, len(*s.metadata))
|
||||
for key := range *s.metadata {
|
||||
out = append(out, key)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Inject injects the metadata into ctx.
|
||||
func Inject(ctx context.Context, p propagation.TextMapPropagator, metadata *metadata.MD) {
|
||||
p.Inject(ctx, &metadataSupplier{
|
||||
metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// Extract extracts the metadata from ctx.
|
||||
func Extract(ctx context.Context, p propagation.TextMapPropagator, metadata *metadata.MD) (
|
||||
baggage.Baggage, sdktrace.SpanContext) {
|
||||
ctx = p.Extract(ctx, &metadataSupplier{
|
||||
metadata: metadata,
|
||||
})
|
||||
|
||||
return baggage.FromContext(ctx), sdktrace.SpanContextFromContext(ctx)
|
||||
}
|
||||
346
core/trace/opentelemetry/tracer_test.go
Normal file
346
core/trace/opentelemetry/tracer_test.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
traceIDStr = "4bf92f3577b34da6a3ce929d0e0e4736"
|
||||
spanIDStr = "00f067aa0ba902b7"
|
||||
)
|
||||
|
||||
var (
|
||||
traceID = mustTraceIDFromHex(traceIDStr)
|
||||
spanID = mustSpanIDFromHex(spanIDStr)
|
||||
)
|
||||
|
||||
func mustTraceIDFromHex(s string) (t trace.TraceID) {
|
||||
var err error
|
||||
t, err = trace.TraceIDFromHex(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func mustSpanIDFromHex(s string) (t trace.SpanID) {
|
||||
var err error
|
||||
t, err = trace.SpanIDFromHex(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestExtractValidTraceContext(t *testing.T) {
|
||||
stateStr := "key1=value1,key2=value2"
|
||||
state, err := trace.ParseTraceState(stateStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
traceparent string
|
||||
tracestate string
|
||||
sc trace.SpanContext
|
||||
}{
|
||||
{
|
||||
name: "not sampled",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "sampled",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceFlags: trace.FlagsSampled,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "valid tracestate",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
tracestate: stateStr,
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceState: state,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "invalid tracestate perserves traceparent",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
tracestate: "invalid$@#=invalid",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "future version not sampled",
|
||||
traceparent: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "future version sampled",
|
||||
traceparent: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceFlags: trace.FlagsSampled,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "future version sample bit set",
|
||||
traceparent: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceFlags: trace.FlagsSampled,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "future version sample bit not set",
|
||||
traceparent: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-08",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "future version additional data",
|
||||
traceparent: "02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00-XYZxsf09",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "B3 format ending in dash",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00-",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "future version B3 format ending in dash",
|
||||
traceparent: "03-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00-",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
}
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
|
||||
propagator := otel.GetTextMapPropagator()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
md := metadata.MD{}
|
||||
md.Set("traceparent", tt.traceparent)
|
||||
md.Set("tracestate", tt.tracestate)
|
||||
_, spanCtx := Extract(ctx, propagator, &md)
|
||||
assert.Equal(t, tt.sc, spanCtx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractInvalidTraceContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{
|
||||
name: "wrong version length",
|
||||
header: "0000-00000000000000000000000000000000-0000000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "wrong trace ID length",
|
||||
header: "00-ab00000000000000000000000000000000-cd00000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "wrong span ID length",
|
||||
header: "00-ab000000000000000000000000000000-cd0000000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "wrong trace flag length",
|
||||
header: "00-ab000000000000000000000000000000-cd00000000000000-0100",
|
||||
},
|
||||
{
|
||||
name: "bogus version",
|
||||
header: "qw-00000000000000000000000000000000-0000000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "bogus trace ID",
|
||||
header: "00-qw000000000000000000000000000000-cd00000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "bogus span ID",
|
||||
header: "00-ab000000000000000000000000000000-qw00000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "bogus trace flag",
|
||||
header: "00-ab000000000000000000000000000000-cd00000000000000-qw",
|
||||
},
|
||||
{
|
||||
name: "upper case version",
|
||||
header: "A0-00000000000000000000000000000000-0000000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "upper case trace ID",
|
||||
header: "00-AB000000000000000000000000000000-cd00000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "upper case span ID",
|
||||
header: "00-ab000000000000000000000000000000-CD00000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "upper case trace flag",
|
||||
header: "00-ab000000000000000000000000000000-cd00000000000000-A1",
|
||||
},
|
||||
{
|
||||
name: "zero trace ID and span ID",
|
||||
header: "00-00000000000000000000000000000000-0000000000000000-01",
|
||||
},
|
||||
{
|
||||
name: "trace-flag unused bits set",
|
||||
header: "00-ab000000000000000000000000000000-cd00000000000000-09",
|
||||
},
|
||||
{
|
||||
name: "missing options",
|
||||
header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7",
|
||||
},
|
||||
{
|
||||
name: "empty options",
|
||||
header: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-",
|
||||
},
|
||||
}
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
|
||||
propagator := otel.GetTextMapPropagator()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
md := metadata.MD{}
|
||||
md.Set("traceparent", tt.header)
|
||||
_, spanCtx := Extract(ctx, propagator, &md)
|
||||
assert.Equal(t, trace.SpanContext{}, spanCtx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectValidTraceContext(t *testing.T) {
|
||||
stateStr := "key1=value1,key2=value2"
|
||||
state, err := trace.ParseTraceState(stateStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
traceparent string
|
||||
tracestate string
|
||||
sc trace.SpanContext
|
||||
}{
|
||||
{
|
||||
name: "not sampled",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "sampled",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceFlags: trace.FlagsSampled,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "unsupported trace flag bits dropped",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceFlags: 0xff,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "with tracestate",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
tracestate: stateStr,
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
TraceID: traceID,
|
||||
SpanID: spanID,
|
||||
TraceState: state,
|
||||
Remote: true,
|
||||
}),
|
||||
},
|
||||
}
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
|
||||
propagator := otel.GetTextMapPropagator()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = trace.ContextWithRemoteSpanContext(ctx, tt.sc)
|
||||
|
||||
want := metadata.MD{}
|
||||
want.Set("traceparent", tt.traceparent)
|
||||
if len(tt.tracestate) > 0 {
|
||||
want.Set("tracestate", tt.tracestate)
|
||||
}
|
||||
|
||||
md := metadata.MD{}
|
||||
Inject(ctx, propagator, &md)
|
||||
assert.Equal(t, want, md)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidSpanContextDropped(t *testing.T) {
|
||||
invalidSC := trace.SpanContext{}
|
||||
require.False(t, invalidSC.IsValid())
|
||||
ctx := trace.ContextWithRemoteSpanContext(context.Background(), invalidSC)
|
||||
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
|
||||
propagator := otel.GetTextMapPropagator()
|
||||
|
||||
md := metadata.MD{}
|
||||
Inject(ctx, propagator, &md)
|
||||
mm := &metadataSupplier{
|
||||
metadata: &md,
|
||||
}
|
||||
assert.Equal(t, "", mm.Get("traceparent"), "injected invalid SpanContext")
|
||||
}
|
||||
69
core/trace/opentelemetry/utils.go
Normal file
69
core/trace/opentelemetry/utils.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const localhost = "127.0.0.1"
|
||||
|
||||
// PeerFromCtx returns the peer from ctx.
|
||||
func PeerFromCtx(ctx context.Context) string {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return p.Addr.String()
|
||||
}
|
||||
|
||||
// SpanInfo returns the span info.
|
||||
func SpanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) {
|
||||
attrs := []attribute.KeyValue{RPCSystemGRPC}
|
||||
name, mAttrs := ParseFullMethod(fullMethod)
|
||||
attrs = append(attrs, mAttrs...)
|
||||
attrs = append(attrs, PeerAttr(peerAddress)...)
|
||||
return name, attrs
|
||||
}
|
||||
|
||||
// ParseFullMethod returns the method name and attributes.
|
||||
func ParseFullMethod(fullMethod string) (string, []attribute.KeyValue) {
|
||||
name := strings.TrimLeft(fullMethod, "/")
|
||||
parts := strings.SplitN(name, "/", 2)
|
||||
if len(parts) != 2 {
|
||||
// Invalid format, does not follow `/package.service/method`.
|
||||
return name, []attribute.KeyValue(nil)
|
||||
}
|
||||
|
||||
var attrs []attribute.KeyValue
|
||||
if service := parts[0]; service != "" {
|
||||
attrs = append(attrs, semconv.RPCServiceKey.String(service))
|
||||
}
|
||||
if method := parts[1]; method != "" {
|
||||
attrs = append(attrs, semconv.RPCMethodKey.String(method))
|
||||
}
|
||||
|
||||
return name, attrs
|
||||
}
|
||||
|
||||
// PeerAttr returns the peer attributes.
|
||||
func PeerAttr(addr string) []attribute.KeyValue {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return []attribute.KeyValue(nil)
|
||||
}
|
||||
|
||||
if len(host) == 0 {
|
||||
host = localhost
|
||||
}
|
||||
|
||||
return []attribute.KeyValue{
|
||||
semconv.NetPeerIPKey.String(host),
|
||||
semconv.NetPeerPortKey.String(port),
|
||||
}
|
||||
}
|
||||
70
core/trace/opentelemetry/utils_test.go
Normal file
70
core/trace/opentelemetry/utils_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package opentelemetry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
|
||||
)
|
||||
|
||||
func TestParseFullMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
fullMethod string
|
||||
name string
|
||||
attr []attribute.KeyValue
|
||||
}{
|
||||
{
|
||||
fullMethod: "/grpc.test.EchoService/Echo",
|
||||
name: "grpc.test.EchoService/Echo",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("grpc.test.EchoService"),
|
||||
semconv.RPCMethodKey.String("Echo"),
|
||||
},
|
||||
}, {
|
||||
fullMethod: "/com.example.ExampleRmiService/exampleMethod",
|
||||
name: "com.example.ExampleRmiService/exampleMethod",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("com.example.ExampleRmiService"),
|
||||
semconv.RPCMethodKey.String("exampleMethod"),
|
||||
},
|
||||
}, {
|
||||
fullMethod: "/MyCalcService.Calculator/Add",
|
||||
name: "MyCalcService.Calculator/Add",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("MyCalcService.Calculator"),
|
||||
semconv.RPCMethodKey.String("Add"),
|
||||
},
|
||||
}, {
|
||||
fullMethod: "/MyServiceReference.ICalculator/Add",
|
||||
name: "MyServiceReference.ICalculator/Add",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("MyServiceReference.ICalculator"),
|
||||
semconv.RPCMethodKey.String("Add"),
|
||||
},
|
||||
}, {
|
||||
fullMethod: "/MyServiceWithNoPackage/theMethod",
|
||||
name: "MyServiceWithNoPackage/theMethod",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("MyServiceWithNoPackage"),
|
||||
semconv.RPCMethodKey.String("theMethod"),
|
||||
},
|
||||
}, {
|
||||
fullMethod: "/pkg.srv",
|
||||
name: "pkg.srv",
|
||||
attr: []attribute.KeyValue(nil),
|
||||
}, {
|
||||
fullMethod: "/pkg.srv/",
|
||||
name: "pkg.srv/",
|
||||
attr: []attribute.KeyValue{
|
||||
semconv.RPCServiceKey.String("pkg.srv"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
n, a := ParseFullMethod(test.fullMethod)
|
||||
assert.Equal(t, test.name, n)
|
||||
assert.Equal(t, test.attr, a)
|
||||
}
|
||||
}
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
|
||||
func TestHttpPropagator_Extract(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set(traceIdKey, "trace")
|
||||
req.Header.Set(TraceIdKey, "trace")
|
||||
req.Header.Set(spanIdKey, "span")
|
||||
carrier, err := Extract(HttpFormat, req.Header)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "trace", carrier.Get(traceIdKey))
|
||||
assert.Equal(t, "trace", carrier.Get(TraceIdKey))
|
||||
assert.Equal(t, "span", carrier.Get(spanIdKey))
|
||||
|
||||
_, err = Extract(HttpFormat, req)
|
||||
@@ -24,11 +24,11 @@ func TestHttpPropagator_Extract(t *testing.T) {
|
||||
|
||||
func TestHttpPropagator_Inject(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Set(traceIdKey, "trace")
|
||||
req.Header.Set(TraceIdKey, "trace")
|
||||
req.Header.Set(spanIdKey, "span")
|
||||
carrier, err := Inject(HttpFormat, req.Header)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "trace", carrier.Get(traceIdKey))
|
||||
assert.Equal(t, "trace", carrier.Get(TraceIdKey))
|
||||
assert.Equal(t, "span", carrier.Get(spanIdKey))
|
||||
|
||||
_, err = Inject(HttpFormat, req)
|
||||
@@ -37,12 +37,12 @@ func TestHttpPropagator_Inject(t *testing.T) {
|
||||
|
||||
func TestGrpcPropagator_Extract(t *testing.T) {
|
||||
md := metadata.New(map[string]string{
|
||||
traceIdKey: "trace",
|
||||
TraceIdKey: "trace",
|
||||
spanIdKey: "span",
|
||||
})
|
||||
carrier, err := Extract(GrpcFormat, md)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "trace", carrier.Get(traceIdKey))
|
||||
assert.Equal(t, "trace", carrier.Get(TraceIdKey))
|
||||
assert.Equal(t, "span", carrier.Get(spanIdKey))
|
||||
|
||||
_, err = Extract(GrpcFormat, 1)
|
||||
@@ -53,12 +53,12 @@ func TestGrpcPropagator_Extract(t *testing.T) {
|
||||
|
||||
func TestGrpcPropagator_Inject(t *testing.T) {
|
||||
md := metadata.New(map[string]string{
|
||||
traceIdKey: "trace",
|
||||
TraceIdKey: "trace",
|
||||
spanIdKey: "span",
|
||||
})
|
||||
carrier, err := Inject(GrpcFormat, md)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "trace", carrier.Get(traceIdKey))
|
||||
assert.Equal(t, "trace", carrier.Get(TraceIdKey))
|
||||
assert.Equal(t, "span", carrier.Get(spanIdKey))
|
||||
|
||||
_, err = Inject(GrpcFormat, 1)
|
||||
|
||||
@@ -34,7 +34,7 @@ type Span struct {
|
||||
func newServerSpan(carrier Carrier, serviceName, operationName string) tracespec.Trace {
|
||||
traceId := stringx.TakeWithPriority(func() string {
|
||||
if carrier != nil {
|
||||
return carrier.Get(traceIdKey)
|
||||
return carrier.Get(TraceIdKey)
|
||||
}
|
||||
return ""
|
||||
}, stringx.RandId)
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestServerSpan(t *testing.T) {
|
||||
|
||||
func TestServerSpan_WithCarrier(t *testing.T) {
|
||||
md := metadata.New(map[string]string{
|
||||
traceIdKey: "a",
|
||||
TraceIdKey: "a",
|
||||
spanIdKey: "0.1",
|
||||
})
|
||||
ctx, span := StartServerSpan(context.Background(), grpcCarrier(md), "service", "operation")
|
||||
@@ -99,7 +99,7 @@ func TestSpan_Follow(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
md := metadata.New(map[string]string{
|
||||
traceIdKey: "a",
|
||||
TraceIdKey: "a",
|
||||
spanIdKey: test.span,
|
||||
})
|
||||
ctx, span := StartServerSpan(context.Background(), grpcCarrier(md),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user