mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-11 00:40:00 +08:00
Compare commits
216 Commits
tools/goct
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c37121ff87 | ||
|
|
f08d9329a8 | ||
|
|
988fb9d9bf | ||
|
|
d212c81bca | ||
|
|
bc43df2641 | ||
|
|
351b8cb37b | ||
|
|
0d681a2e29 | ||
|
|
5ea027c5de | ||
|
|
5de6112dcd | ||
|
|
4fb51723b7 | ||
|
|
06502d1115 | ||
|
|
3854d6dd00 | ||
|
|
895854913a | ||
|
|
ef753b8857 | ||
|
|
9c16fede73 | ||
|
|
ce11adb5e4 | ||
|
|
894e8b1218 | ||
|
|
2ec7e432dd | ||
|
|
870e8352c1 | ||
|
|
de42f27e03 | ||
|
|
955b8016aa | ||
|
|
d728a3b2d9 | ||
|
|
0c205a71fc | ||
|
|
a8c0199d96 | ||
|
|
032a266ec4 | ||
|
|
40b75fbb9b | ||
|
|
afad55045b | ||
|
|
5f54f06ee5 | ||
|
|
20f56ae1d0 | ||
|
|
73d6fcfccd | ||
|
|
20d20ef861 | ||
|
|
a37422b504 | ||
|
|
a81d898408 | ||
|
|
a5d42e20d5 | ||
|
|
4bdb07f225 | ||
|
|
3e6ec9b83d | ||
|
|
f0a3d213dc | ||
|
|
94562ded74 | ||
|
|
d68cf4920c | ||
|
|
31b749ab67 | ||
|
|
3834319278 | ||
|
|
1c9d339361 | ||
|
|
b7f601c912 | ||
|
|
1ebbc6f0c7 | ||
|
|
b41b1b00df | ||
|
|
f36e5fed35 | ||
|
|
2583673c8b | ||
|
|
00e67b9d20 | ||
|
|
9fd1f29845 | ||
|
|
130e1ba963 | ||
|
|
a2b98dbcf7 | ||
|
|
b46d507a1d | ||
|
|
3152581d0d | ||
|
|
46e466f037 | ||
|
|
151b3d1085 | ||
|
|
ea53fe41de | ||
|
|
d9df08b079 | ||
|
|
569c00ad09 | ||
|
|
9da76fbf04 | ||
|
|
b69db5e09d | ||
|
|
ee6b7cee79 | ||
|
|
d150248c52 | ||
|
|
610a7345dc | ||
|
|
b0b31f3993 | ||
|
|
82a937d517 | ||
|
|
93c11a7eb7 | ||
|
|
63ec989376 | ||
|
|
bf75027889 | ||
|
|
d505fae979 | ||
|
|
25f37ca750 | ||
|
|
0be63c3625 | ||
|
|
b011a072c7 | ||
|
|
3c9b6335fb | ||
|
|
bf6ef5f033 | ||
|
|
ff890628b0 | ||
|
|
cc79e3d842 | ||
|
|
f11b78ced9 | ||
|
|
1d2b0d7ab8 | ||
|
|
da987e1270 | ||
|
|
12e03c8843 | ||
|
|
8cf4f95bd7 | ||
|
|
ba0febf308 | ||
|
|
c9ff6a10d3 | ||
|
|
a71e56de52 | ||
|
|
bae8d4f4c8 | ||
|
|
8c6266f338 | ||
|
|
95d5b81f44 | ||
|
|
bca7bbc142 | ||
|
|
df9a52664b | ||
|
|
937cf0db96 | ||
|
|
75cebb65f8 | ||
|
|
410f56e73a | ||
|
|
017909a3ab | ||
|
|
0d31e6c375 | ||
|
|
0ba86b1849 | ||
|
|
4cacc4d9d3 | ||
|
|
a99c14da4a | ||
|
|
985582264a | ||
|
|
8364e341e1 | ||
|
|
0f2b589d4d | ||
|
|
19fec36d24 | ||
|
|
f037bf344d | ||
|
|
d99cf35b07 | ||
|
|
f459f1b5ff | ||
|
|
0140fd417b | ||
|
|
7969e0ca38 | ||
|
|
91c885b5b0 | ||
|
|
d4cccca387 | ||
|
|
4b2095ed03 | ||
|
|
1229eeb2d2 | ||
|
|
9142b146c5 | ||
|
|
8a1b2d5aed | ||
|
|
da5d39e6ca | ||
|
|
68c5a17c67 | ||
|
|
b53f9f5f2d | ||
|
|
36d57626b6 | ||
|
|
4e36ba832f | ||
|
|
a44954a771 | ||
|
|
f3edd4b880 | ||
|
|
2de3e397ff | ||
|
|
a435eb56f2 | ||
|
|
d80761c147 | ||
|
|
e7bd0d8b60 | ||
|
|
b109b3ef4c | ||
|
|
e3c371ac89 | ||
|
|
15eb6f4f6d | ||
|
|
4d3681b71c | ||
|
|
a682bda0bb | ||
|
|
45b27ad93a | ||
|
|
292a8302a1 | ||
|
|
91ab1f6d2b | ||
|
|
5048c350ae | ||
|
|
94edc32f3e | ||
|
|
ec989b2e2a | ||
|
|
82fe802e81 | ||
|
|
072d68f897 | ||
|
|
2e91ba5811 | ||
|
|
5564c43197 | ||
|
|
e55158b0f7 | ||
|
|
69aa7fe346 | ||
|
|
c3820a95c1 | ||
|
|
493f3bad0f | ||
|
|
eb0d5ad3a4 | ||
|
|
14192050ae | ||
|
|
9193e771e3 | ||
|
|
808b4e496a | ||
|
|
e416d01f8d | ||
|
|
789c5de873 | ||
|
|
52078a0c14 | ||
|
|
7ef13116a0 | ||
|
|
6b8053410a | ||
|
|
81c6928445 | ||
|
|
761c2dd716 | ||
|
|
aeceb3cfbe | ||
|
|
15ea07aad1 | ||
|
|
98bebbc74f | ||
|
|
eafd11d949 | ||
|
|
b251ce346e | ||
|
|
812140ba36 | ||
|
|
44735e949c | ||
|
|
bf313c3c56 | ||
|
|
94e7753262 | ||
|
|
9c478626d2 | ||
|
|
801c283478 | ||
|
|
2a54faf997 | ||
|
|
ecd98f3653 | ||
|
|
61641581eb | ||
|
|
6f2730d5ae | ||
|
|
0eff777b62 | ||
|
|
cafbf535f7 | ||
|
|
6edfce63e3 | ||
|
|
cdb0098b18 | ||
|
|
620c7f9693 | ||
|
|
dba444a382 | ||
|
|
b24fb3ebf7 | ||
|
|
967f0926eb | ||
|
|
e68c683df9 | ||
|
|
247985a065 | ||
|
|
80573af0d8 | ||
|
|
c0394b631a | ||
|
|
68d1aba377 | ||
|
|
3315e60272 | ||
|
|
327ef73700 | ||
|
|
eb11521655 | ||
|
|
4c37545e55 | ||
|
|
2f47c1fba4 | ||
|
|
16d54d0ace | ||
|
|
9925bcbf99 | ||
|
|
38a5ecb796 | ||
|
|
af78fc7c5f | ||
|
|
790302b486 | ||
|
|
6a0672b801 | ||
|
|
560c61612c | ||
|
|
6a988dc4a9 | ||
|
|
15842c3c7a | ||
|
|
f2914a74df | ||
|
|
f113d512e8 | ||
|
|
7a4818da59 | ||
|
|
48d0709ca6 | ||
|
|
f747585518 | ||
|
|
507ff96546 | ||
|
|
651eabb4c6 | ||
|
|
e6b4372056 | ||
|
|
24073969a1 | ||
|
|
ca797ed22c | ||
|
|
e347d3f8f8 | ||
|
|
396393b336 | ||
|
|
1f0531b254 | ||
|
|
77fb271a06 | ||
|
|
af7cf79963 | ||
|
|
7926d396d7 | ||
|
|
080cd3df84 | ||
|
|
c4e1a6a2d8 | ||
|
|
4e71e95e44 | ||
|
|
84db9bcd15 | ||
|
|
b28f79ac11 |
13
.codecov.yml
13
.codecov.yml
@@ -1,13 +0,0 @@
|
||||
coverage:
|
||||
status:
|
||||
patch: true
|
||||
project: false # disabled because project coverage is not stable
|
||||
comment:
|
||||
layout: "flags, files"
|
||||
behavior: once
|
||||
require_changes: true
|
||||
ignore:
|
||||
- "tools"
|
||||
- "**/mock"
|
||||
- "**/*_mock.go"
|
||||
- "**/*test"
|
||||
197
.github/copilot-instructions.md
vendored
Normal file
197
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
# GitHub Copilot Instructions for go-zero
|
||||
|
||||
This document provides guidelines for GitHub Copilot when assisting with development in the go-zero project.
|
||||
|
||||
## Project Overview
|
||||
|
||||
go-zero is a web and RPC framework with lots of built-in engineering practices designed to ensure the stability of busy services with resilience design. It has been serving sites with tens of millions of users for years.
|
||||
|
||||
### Key Architecture Components
|
||||
|
||||
- **REST API framework** (`rest/`) - HTTP service framework with middleware support
|
||||
- **RPC framework** (`zrpc/`) - gRPC-based RPC framework with service discovery
|
||||
- **Core utilities** (`core/`) - Foundational components including:
|
||||
- Circuit breakers, rate limiters, load shedding
|
||||
- Caching, stores (Redis, MongoDB, SQL)
|
||||
- Concurrency control, metrics, tracing
|
||||
- Configuration management
|
||||
- **Code generation tool** (`tools/goctl/`) - CLI tool for generating code from API files
|
||||
|
||||
## Coding Standards and Conventions
|
||||
|
||||
### Code Style
|
||||
|
||||
1. **Follow Go conventions**: Use `gofmt` for formatting, follow effective Go practices
|
||||
2. **Package naming**: Use lowercase, single-word package names when possible
|
||||
3. **Error handling**: Always handle errors explicitly, use `errorx.BatchError` for multiple errors
|
||||
4. **Context propagation**: Always pass `context.Context` as the first parameter for functions that may block
|
||||
5. **Configuration structures**: Use struct tags with JSON annotations and default values
|
||||
|
||||
Example configuration pattern:
|
||||
```go
|
||||
type Config struct {
|
||||
Host string `json:",default=0.0.0.0"`
|
||||
Port int `json:",default=8080"`
|
||||
Timeout int `json:",default=3000"`
|
||||
Optional string `json:",optional"`
|
||||
}
|
||||
```
|
||||
|
||||
### Interface Design
|
||||
|
||||
1. **Small interfaces**: Follow Go's preference for small, focused interfaces
|
||||
2. **Context methods**: Provide both context and non-context versions of methods
|
||||
3. **Options pattern**: Use functional options for complex configuration
|
||||
|
||||
Example:
|
||||
```go
|
||||
func (c *Client) Get(key string, val any) error {
|
||||
return c.GetCtx(context.Background(), key, val)
|
||||
}
|
||||
|
||||
func (c *Client) GetCtx(ctx context.Context, key string, val any) error {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Patterns
|
||||
|
||||
1. **Test file naming**: Use `*_test.go` suffix
|
||||
2. **Test function naming**: Use `TestFunctionName` pattern
|
||||
3. **Use testify/assert**: Prefer `assert` package for assertions
|
||||
4. **Table-driven tests**: Use table-driven tests for multiple scenarios
|
||||
5. **Mock interfaces**: Use `go.uber.org/mock` for mocking
|
||||
6. **Test helpers**: Use `redistest`, `mongtest` helpers for database testing
|
||||
|
||||
Example test pattern:
|
||||
```go
|
||||
func TestSomething(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid case", "input", "output", false},
|
||||
{"error case", "bad", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SomeFunction(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Framework-Specific Guidelines
|
||||
|
||||
### REST API Development
|
||||
|
||||
1. **API Definition**: Use `.api` files to define REST APIs
|
||||
2. **Handler pattern**: Separate business logic into logic packages
|
||||
3. **Middleware**: Use built-in middlewares (tracing, logging, metrics, recovery)
|
||||
4. **Response handling**: Use `httpx.WriteJson` for JSON responses
|
||||
5. **Error handling**: Use `httpx.Error` for HTTP error responses
|
||||
|
||||
### RPC Development
|
||||
|
||||
1. **Protocol Buffers**: Use protobuf for service definitions
|
||||
2. **Service discovery**: Integrate with etcd for service registration
|
||||
3. **Load balancing**: Use built-in load balancing strategies
|
||||
4. **Interceptors**: Implement interceptors for cross-cutting concerns
|
||||
|
||||
### Database Operations
|
||||
|
||||
1. **SQL operations**: Use `sqlx` package for database operations
|
||||
2. **Caching**: Implement caching patterns with `cache` package
|
||||
3. **Transactions**: Use proper transaction handling
|
||||
4. **Connection pooling**: Configure appropriate connection pools
|
||||
|
||||
Example cache pattern:
|
||||
```go
|
||||
err := c.QueryRowCtx(ctx, &dest, key, func(ctx context.Context, conn sqlx.SqlConn) error {
|
||||
return conn.QueryRowCtx(ctx, &dest, query, args...)
|
||||
})
|
||||
```
|
||||
|
||||
### Configuration Management
|
||||
|
||||
1. **YAML configuration**: Use YAML for configuration files
|
||||
2. **Environment variables**: Support environment variable overrides
|
||||
3. **Validation**: Include proper validation for configuration parameters
|
||||
4. **Sensible defaults**: Provide reasonable default values
|
||||
|
||||
## Error Handling Best Practices
|
||||
|
||||
1. **Wrap errors**: Use `fmt.Errorf` with `%w` verb to wrap errors
|
||||
2. **Custom errors**: Define custom error types when needed
|
||||
3. **Error logging**: Log errors appropriately with context
|
||||
4. **Graceful degradation**: Implement fallback mechanisms
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
1. **Resource pools**: Use connection pools and worker pools
|
||||
2. **Circuit breakers**: Implement circuit breaker patterns for external calls
|
||||
3. **Rate limiting**: Apply rate limiting to protect services
|
||||
4. **Load shedding**: Implement adaptive load shedding
|
||||
5. **Metrics**: Add appropriate metrics and monitoring
|
||||
|
||||
## Security Guidelines
|
||||
|
||||
1. **Input validation**: Validate all input parameters
|
||||
2. **SQL injection prevention**: Use parameterized queries
|
||||
3. **Authentication**: Implement proper JWT token handling
|
||||
4. **HTTPS**: Support TLS/HTTPS configurations
|
||||
5. **CORS**: Configure CORS appropriately for web APIs
|
||||
|
||||
## Documentation Standards
|
||||
|
||||
1. **Package documentation**: Include package-level documentation
|
||||
2. **Function documentation**: Document exported functions with examples
|
||||
3. **API documentation**: Maintain API documentation in sync
|
||||
4. **README updates**: Update README for significant changes
|
||||
|
||||
## Common Patterns to Follow
|
||||
|
||||
### Service Configuration
|
||||
```go
|
||||
type ServiceConf struct {
|
||||
Name string
|
||||
Log logx.LogConf
|
||||
Mode string `json:",default=pro,options=[dev,test,pre,pro]"`
|
||||
// ... other common fields
|
||||
}
|
||||
```
|
||||
|
||||
### Middleware Implementation
|
||||
```go
|
||||
func SomeMiddleware() rest.Middleware {
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Pre-processing
|
||||
next.ServeHTTP(w, r)
|
||||
// Post-processing
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Resource Management
|
||||
Always implement proper resource cleanup using defer and context cancellation.
|
||||
|
||||
## Build and Test Commands
|
||||
|
||||
- Build: `go build ./...`
|
||||
- Test: `go test ./...`
|
||||
- Test with race detection: `go test -race ./...`
|
||||
- Format: `gofmt -w .`
|
||||
- Generate code: `goctl api go -api *.api -dir .`
|
||||
|
||||
Remember to run tests and ensure all checks pass before submitting changes. The project emphasizes high quality, performance, and reliability, so these should be primary considerations in all development work.
|
||||
2
.github/workflows/codeql-analysis.yml
vendored
2
.github/workflows/codeql-analysis.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
# Initializes the CodeQL tools for scanning.
|
||||
- name: Initialize CodeQL
|
||||
|
||||
13
.github/workflows/go.yml
vendored
13
.github/workflows/go.yml
vendored
@@ -12,10 +12,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
check-latest: true
|
||||
@@ -41,16 +41,21 @@ jobs:
|
||||
|
||||
- name: Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.txt
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
|
||||
test-win:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout codebase
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
# make sure Go version compatible with go-zero
|
||||
go-version-file: go.mod
|
||||
|
||||
18
.github/workflows/issue-translator.yml
vendored
18
.github/workflows/issue-translator.yml
vendored
@@ -1,18 +0,0 @@
|
||||
name: 'issue-translator'
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
issues:
|
||||
types: [opened]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: usthe/issues-translate-action@v2.7
|
||||
with:
|
||||
IS_MODIFY_TITLE: true
|
||||
# not require, default false, . Decide whether to modify the issue title
|
||||
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
|
||||
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑🤝🧑👫🧑🏿🤝🧑🏻👩🏾🤝👨🏿👬🏿
|
||||
# not require. Customize the translation robot prefix message.
|
||||
2
.github/workflows/issues.yml
vendored
2
.github/workflows/issues.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v9
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 90
|
||||
|
||||
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
- goarch: "386"
|
||||
goos: darwin
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
- uses: zeromicro/go-zero-release-action@master
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/reviewdog.yml
vendored
2
.github/workflows/reviewdog.yml
vendored
@@ -5,7 +5,7 @@ jobs:
|
||||
name: runner / staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v5
|
||||
- uses: reviewdog/action-staticcheck@v1
|
||||
with:
|
||||
github_token: ${{ secrets.github_token }}
|
||||
|
||||
42
.github/workflows/version-check.yml
vendored
Normal file
42
.github/workflows/version-check.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: Release Version Check
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'tools/goctl/v*'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
version-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.21'
|
||||
|
||||
- name: Extract tag version
|
||||
id: get_version
|
||||
run: |
|
||||
# Extract version from tools/goctl/v* format
|
||||
VERSION="${GITHUB_REF#refs/tags/tools/goctl/v}"
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
echo "Extracted version: $VERSION"
|
||||
|
||||
- name: Check version in goctl source code
|
||||
run: |
|
||||
# Change to goctl directory
|
||||
cd tools/goctl
|
||||
|
||||
# Check version in BuildVersion constant
|
||||
VERSION_IN_CODE=$(grep -r "const BuildVersion =" . | grep -o '".*"' | tr -d '"')
|
||||
echo "Version in code: $VERSION_IN_CODE"
|
||||
echo "Expected version: $VERSION"
|
||||
|
||||
if [ "$VERSION_IN_CODE" != "$VERSION" ]; then
|
||||
echo "Version mismatch: Version in code ($VERSION_IN_CODE) doesn't match tag version ($VERSION)"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version check passed!"
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -17,6 +17,7 @@
|
||||
**/logs
|
||||
**/adhoc
|
||||
**/coverage.txt
|
||||
**/WARP.md
|
||||
|
||||
# for test purpose
|
||||
go.work
|
||||
|
||||
@@ -8,16 +8,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
numHistoryReasons = 5
|
||||
timeFormat = "15:04:05"
|
||||
)
|
||||
const numHistoryReasons = 5
|
||||
|
||||
// ErrServiceUnavailable is returned when the Breaker state is open.
|
||||
var ErrServiceUnavailable = errors.New("circuit breaker is open")
|
||||
@@ -262,9 +258,9 @@ type errorWindow struct {
|
||||
|
||||
func (ew *errorWindow) add(reason string) {
|
||||
ew.lock.Lock()
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(timeFormat), reason)
|
||||
ew.reasons[ew.index] = fmt.Sprintf("%s %s", time.Now().Format(time.TimeOnly), reason)
|
||||
ew.index = (ew.index + 1) % numHistoryReasons
|
||||
ew.count = mathx.MinInt(ew.count+1, numHistoryReasons)
|
||||
ew.count = min(ew.count+1, numHistoryReasons)
|
||||
ew.lock.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,235 +1,53 @@
|
||||
package collection
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
import "github.com/zeromicro/go-zero/core/lang"
|
||||
|
||||
const (
|
||||
unmanaged = iota
|
||||
untyped
|
||||
intType
|
||||
int64Type
|
||||
uintType
|
||||
uint64Type
|
||||
stringType
|
||||
)
|
||||
|
||||
// Set is not thread-safe, for concurrent use, make sure to use it with synchronization.
|
||||
type Set struct {
|
||||
data map[any]lang.PlaceholderType
|
||||
tp int
|
||||
// Set is a type-safe generic set collection.
|
||||
// It's not thread-safe, use with synchronization for concurrent access.
|
||||
type Set[T comparable] struct {
|
||||
data map[T]lang.PlaceholderType
|
||||
}
|
||||
|
||||
// NewSet returns a managed Set, can only put the values with the same type.
|
||||
func NewSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[any]lang.PlaceholderType),
|
||||
tp: untyped,
|
||||
// NewSet returns a new type-safe set.
|
||||
func NewSet[T comparable]() *Set[T] {
|
||||
return &Set[T]{
|
||||
data: make(map[T]lang.PlaceholderType),
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
|
||||
func NewUnmanagedSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[any]lang.PlaceholderType),
|
||||
tp: unmanaged,
|
||||
// Add adds items to the set. Duplicates are automatically ignored.
|
||||
func (s *Set[T]) Add(items ...T) {
|
||||
for _, item := range items {
|
||||
s.data[item] = lang.Placeholder
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds i into s.
|
||||
func (s *Set) Add(i ...any) {
|
||||
for _, each := range i {
|
||||
s.add(each)
|
||||
}
|
||||
// Clear removes all items from the set.
|
||||
func (s *Set[T]) Clear() {
|
||||
clear(s.data)
|
||||
}
|
||||
|
||||
// AddInt adds int values ii into s.
|
||||
func (s *Set) AddInt(ii ...int) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddInt64 adds int64 values ii into s.
|
||||
func (s *Set) AddInt64(ii ...int64) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddUint adds uint values ii into s.
|
||||
func (s *Set) AddUint(ii ...uint) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddUint64 adds uint64 values ii into s.
|
||||
func (s *Set) AddUint64(ii ...uint64) {
|
||||
for _, each := range ii {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// AddStr adds string values ss into s.
|
||||
func (s *Set) AddStr(ss ...string) {
|
||||
for _, each := range ss {
|
||||
s.add(each)
|
||||
}
|
||||
}
|
||||
|
||||
// Contains checks if i is in s.
|
||||
func (s *Set) Contains(i any) bool {
|
||||
if len(s.data) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
s.validate(i)
|
||||
_, ok := s.data[i]
|
||||
// Contains checks if an item exists in the set.
|
||||
func (s *Set[T]) Contains(item T) bool {
|
||||
_, ok := s.data[item]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Keys returns the keys in s.
|
||||
func (s *Set) Keys() []any {
|
||||
var keys []any
|
||||
|
||||
for key := range s.data {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysInt returns the int keys in s.
|
||||
func (s *Set) KeysInt() []int {
|
||||
var keys []int
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysInt64 returns int64 keys in s.
|
||||
func (s *Set) KeysInt64() []int64 {
|
||||
var keys []int64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(int64); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysUint returns uint keys in s.
|
||||
func (s *Set) KeysUint() []uint {
|
||||
var keys []uint
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysUint64 returns uint64 keys in s.
|
||||
func (s *Set) KeysUint64() []uint64 {
|
||||
var keys []uint64
|
||||
|
||||
for key := range s.data {
|
||||
if intKey, ok := key.(uint64); ok {
|
||||
keys = append(keys, intKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// KeysStr returns string keys in s.
|
||||
func (s *Set) KeysStr() []string {
|
||||
var keys []string
|
||||
|
||||
for key := range s.data {
|
||||
if strKey, ok := key.(string); ok {
|
||||
keys = append(keys, strKey)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// Remove removes i from s.
|
||||
func (s *Set) Remove(i any) {
|
||||
s.validate(i)
|
||||
delete(s.data, i)
|
||||
}
|
||||
|
||||
// Count returns the number of items in s.
|
||||
func (s *Set) Count() int {
|
||||
// Count returns the number of items in the set.
|
||||
func (s *Set[T]) Count() int {
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
func (s *Set) add(i any) {
|
||||
switch s.tp {
|
||||
case unmanaged:
|
||||
// do nothing
|
||||
case untyped:
|
||||
s.setType(i)
|
||||
default:
|
||||
s.validate(i)
|
||||
// Keys returns all elements in the set as a slice.
|
||||
func (s *Set[T]) Keys() []T {
|
||||
keys := make([]T, 0, len(s.data))
|
||||
for key := range s.data {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
s.data[i] = lang.Placeholder
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *Set) setType(i any) {
|
||||
// s.tp can only be untyped here
|
||||
switch i.(type) {
|
||||
case int:
|
||||
s.tp = intType
|
||||
case int64:
|
||||
s.tp = int64Type
|
||||
case uint:
|
||||
s.tp = uintType
|
||||
case uint64:
|
||||
s.tp = uint64Type
|
||||
case string:
|
||||
s.tp = stringType
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set) validate(i any) {
|
||||
if s.tp == unmanaged {
|
||||
return
|
||||
}
|
||||
|
||||
switch i.(type) {
|
||||
case int:
|
||||
if s.tp != intType {
|
||||
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case int64:
|
||||
if s.tp != int64Type {
|
||||
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint:
|
||||
if s.tp != uintType {
|
||||
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint64:
|
||||
if s.tp != uint64Type {
|
||||
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case string:
|
||||
if s.tp != stringType {
|
||||
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
}
|
||||
// Remove removes an item from the set.
|
||||
func (s *Set[T]) Remove(item T) {
|
||||
delete(s.data, item)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,105 @@ func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
// Set functionality tests
|
||||
func TestTypedSetInt(t *testing.T) {
|
||||
set := NewSet[int]()
|
||||
values := []int{1, 2, 3, 2, 1} // Contains duplicates
|
||||
|
||||
// Test adding
|
||||
set.Add(values...)
|
||||
assert.Equal(t, 3, set.Count()) // Should only have 3 elements after deduplication
|
||||
|
||||
// Test contains
|
||||
assert.True(t, set.Contains(1))
|
||||
assert.True(t, set.Contains(2))
|
||||
assert.True(t, set.Contains(3))
|
||||
assert.False(t, set.Contains(4))
|
||||
|
||||
// Test getting all keys
|
||||
keys := set.Keys()
|
||||
sort.Ints(keys)
|
||||
assert.EqualValues(t, []int{1, 2, 3}, keys)
|
||||
|
||||
// Test removal
|
||||
set.Remove(2)
|
||||
assert.False(t, set.Contains(2))
|
||||
assert.Equal(t, 2, set.Count())
|
||||
}
|
||||
|
||||
func TestTypedSetStringOps(t *testing.T) {
|
||||
set := NewSet[string]()
|
||||
values := []string{"a", "b", "c", "b", "a"}
|
||||
|
||||
set.Add(values...)
|
||||
assert.Equal(t, 3, set.Count())
|
||||
|
||||
assert.True(t, set.Contains("a"))
|
||||
assert.True(t, set.Contains("b"))
|
||||
assert.True(t, set.Contains("c"))
|
||||
assert.False(t, set.Contains("d"))
|
||||
|
||||
keys := set.Keys()
|
||||
sort.Strings(keys)
|
||||
assert.EqualValues(t, []string{"a", "b", "c"}, keys)
|
||||
}
|
||||
|
||||
func TestTypedSetClear(t *testing.T) {
|
||||
set := NewSet[int]()
|
||||
set.Add(1, 2, 3)
|
||||
assert.Equal(t, 3, set.Count())
|
||||
|
||||
set.Clear()
|
||||
assert.Equal(t, 0, set.Count())
|
||||
assert.False(t, set.Contains(1))
|
||||
}
|
||||
|
||||
func TestTypedSetEmpty(t *testing.T) {
|
||||
set := NewSet[int]()
|
||||
assert.Equal(t, 0, set.Count())
|
||||
assert.False(t, set.Contains(1))
|
||||
assert.Empty(t, set.Keys())
|
||||
}
|
||||
|
||||
func TestTypedSetMultipleTypes(t *testing.T) {
|
||||
// Test different typed generic sets
|
||||
intSet := NewSet[int]()
|
||||
int64Set := NewSet[int64]()
|
||||
uintSet := NewSet[uint]()
|
||||
uint64Set := NewSet[uint64]()
|
||||
stringSet := NewSet[string]()
|
||||
|
||||
intSet.Add(1, 2, 3)
|
||||
int64Set.Add(1, 2, 3)
|
||||
uintSet.Add(1, 2, 3)
|
||||
uint64Set.Add(1, 2, 3)
|
||||
stringSet.Add("1", "2", "3")
|
||||
|
||||
assert.Equal(t, 3, intSet.Count())
|
||||
assert.Equal(t, 3, int64Set.Count())
|
||||
assert.Equal(t, 3, uintSet.Count())
|
||||
assert.Equal(t, 3, uint64Set.Count())
|
||||
assert.Equal(t, 3, stringSet.Count())
|
||||
}
|
||||
|
||||
// Set benchmarks
|
||||
func BenchmarkTypedIntSet(b *testing.B) {
|
||||
s := NewSet[int]()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTypedStringSet(b *testing.B) {
|
||||
s := NewSet[string]()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(string(rune(i)))
|
||||
_ = s.Contains(string(rune(i)))
|
||||
}
|
||||
}
|
||||
|
||||
// Legacy tests remain unchanged for backward compatibility
|
||||
func BenchmarkRawSet(b *testing.B) {
|
||||
m := make(map[any]struct{})
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -20,26 +119,10 @@ func BenchmarkRawSet(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmanagedSet(b *testing.B) {
|
||||
s := NewUnmanagedSet()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Add(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSet(b *testing.B) {
|
||||
s := NewSet()
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.AddInt(i)
|
||||
_ = s.Contains(i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
// given
|
||||
set := NewUnmanagedSet()
|
||||
values := []any{1, 2, 3}
|
||||
set := NewSet[int]()
|
||||
values := []int{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.Add(values...)
|
||||
@@ -51,82 +134,74 @@ func TestAdd(t *testing.T) {
|
||||
|
||||
func TestAddInt(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[int]()
|
||||
values := []int{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddInt(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
keys := set.KeysInt()
|
||||
keys := set.Keys()
|
||||
sort.Ints(keys)
|
||||
assert.EqualValues(t, values, keys)
|
||||
}
|
||||
|
||||
func TestAddInt64(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[int64]()
|
||||
values := []int64{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddInt64(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(int64(1)) && set.Contains(int64(2)) && set.Contains(int64(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysInt64()))
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddUint(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[uint]()
|
||||
values := []uint{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddUint(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(uint(1)) && set.Contains(uint(2)) && set.Contains(uint(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysUint()))
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddUint64(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[uint64]()
|
||||
values := []uint64{1, 2, 3}
|
||||
|
||||
// when
|
||||
set.AddUint64(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains(uint64(1)) && set.Contains(uint64(2)) && set.Contains(uint64(3)))
|
||||
assert.Equal(t, len(values), len(set.KeysUint64()))
|
||||
assert.True(t, set.Contains(1) && set.Contains(2) && set.Contains(3))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestAddStr(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set := NewSet[string]()
|
||||
values := []string{"1", "2", "3"}
|
||||
|
||||
// when
|
||||
set.AddStr(values...)
|
||||
set.Add(values...)
|
||||
|
||||
// then
|
||||
assert.True(t, set.Contains("1") && set.Contains("2") && set.Contains("3"))
|
||||
assert.Equal(t, len(values), len(set.KeysStr()))
|
||||
assert.Equal(t, len(values), len(set.Keys()))
|
||||
}
|
||||
|
||||
func TestContainsWithoutElements(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
|
||||
// then
|
||||
assert.False(t, set.Contains(1))
|
||||
}
|
||||
|
||||
func TestContainsUnmanagedWithoutElements(t *testing.T) {
|
||||
// given
|
||||
set := NewUnmanagedSet()
|
||||
set := NewSet[int]()
|
||||
|
||||
// then
|
||||
assert.False(t, set.Contains(1))
|
||||
@@ -134,8 +209,8 @@ func TestContainsUnmanagedWithoutElements(t *testing.T) {
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set.Add([]any{1, 2, 3}...)
|
||||
set := NewSet[int]()
|
||||
set.Add([]int{1, 2, 3}...)
|
||||
|
||||
// when
|
||||
set.Remove(2)
|
||||
@@ -146,57 +221,9 @@ func TestRemove(t *testing.T) {
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
// given
|
||||
set := NewSet()
|
||||
set.Add([]any{1, 2, 3}...)
|
||||
set := NewSet[int]()
|
||||
set.Add([]int{1, 2, 3}...)
|
||||
|
||||
// then
|
||||
assert.Equal(t, set.Count(), 3)
|
||||
}
|
||||
|
||||
func TestKeysIntMismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(int64(1))
|
||||
set.add(2)
|
||||
vals := set.KeysInt()
|
||||
assert.EqualValues(t, []int{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysInt64Mismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add(int64(2))
|
||||
vals := set.KeysInt64()
|
||||
assert.EqualValues(t, []int64{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysUintMismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add(uint(2))
|
||||
vals := set.KeysUint()
|
||||
assert.EqualValues(t, []uint{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysUint64Mismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add(uint64(2))
|
||||
vals := set.KeysUint64()
|
||||
assert.EqualValues(t, []uint64{2}, vals)
|
||||
}
|
||||
|
||||
func TestKeysStrMismatch(t *testing.T) {
|
||||
set := NewSet()
|
||||
set.add(1)
|
||||
set.add("2")
|
||||
vals := set.KeysStr()
|
||||
assert.EqualValues(t, []string{"2"}, vals)
|
||||
}
|
||||
|
||||
func TestSetType(t *testing.T) {
|
||||
set := NewUnmanagedSet()
|
||||
set.add(1)
|
||||
set.add("2")
|
||||
vals := set.Keys()
|
||||
assert.ElementsMatch(t, []any{1, "2"}, vals)
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ func toLowerCaseInterface(v any, info *fieldInfo) any {
|
||||
case map[string]any:
|
||||
return toLowerCaseKeyMap(vv, info)
|
||||
case []any:
|
||||
var arr []any
|
||||
arr := make([]any, 0, len(vv))
|
||||
for _, vvv := range vv {
|
||||
arr = append(arr, toLowerCaseInterface(vvv, info))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: etcdclient.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package internal -destination etcdclient_mock.go -source etcdclient.go EtcdClient
|
||||
//
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
@@ -8,35 +13,36 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// MockEtcdClient is a mock of EtcdClient interface
|
||||
// MockEtcdClient is a mock of EtcdClient interface.
|
||||
type MockEtcdClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEtcdClientMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient
|
||||
// MockEtcdClientMockRecorder is the mock recorder for MockEtcdClient.
|
||||
type MockEtcdClientMockRecorder struct {
|
||||
mock *MockEtcdClient
|
||||
}
|
||||
|
||||
// NewMockEtcdClient creates a new mock instance
|
||||
// NewMockEtcdClient creates a new mock instance.
|
||||
func NewMockEtcdClient(ctrl *gomock.Controller) *MockEtcdClient {
|
||||
mock := &MockEtcdClient{ctrl: ctrl}
|
||||
mock.recorder = &MockEtcdClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEtcdClient) EXPECT() *MockEtcdClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ActiveConnection mocks base method
|
||||
// ActiveConnection mocks base method.
|
||||
func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ActiveConnection")
|
||||
@@ -44,13 +50,13 @@ func (m *MockEtcdClient) ActiveConnection() *grpc.ClientConn {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ActiveConnection indicates an expected call of ActiveConnection
|
||||
// ActiveConnection indicates an expected call of ActiveConnection.
|
||||
func (mr *MockEtcdClientMockRecorder) ActiveConnection() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveConnection", reflect.TypeOf((*MockEtcdClient)(nil).ActiveConnection))
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
// Close mocks base method.
|
||||
func (m *MockEtcdClient) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
@@ -58,13 +64,13 @@ func (m *MockEtcdClient) Close() error {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockEtcdClientMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockEtcdClient)(nil).Close))
|
||||
}
|
||||
|
||||
// Ctx mocks base method
|
||||
// Ctx mocks base method.
|
||||
func (m *MockEtcdClient) Ctx() context.Context {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Ctx")
|
||||
@@ -72,13 +78,13 @@ func (m *MockEtcdClient) Ctx() context.Context {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Ctx indicates an expected call of Ctx
|
||||
// Ctx indicates an expected call of Ctx.
|
||||
func (mr *MockEtcdClientMockRecorder) Ctx() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ctx", reflect.TypeOf((*MockEtcdClient)(nil).Ctx))
|
||||
}
|
||||
|
||||
// Get mocks base method
|
||||
// Get mocks base method.
|
||||
func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.OpOption) (*clientv3.GetResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, key}
|
||||
@@ -91,14 +97,14 @@ func (m *MockEtcdClient) Get(ctx context.Context, key string, opts ...clientv3.O
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockEtcdClientMockRecorder) Get(ctx, key any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, key}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockEtcdClient)(nil).Get), varargs...)
|
||||
}
|
||||
|
||||
// Grant mocks base method
|
||||
// Grant mocks base method.
|
||||
func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseGrantResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Grant", ctx, ttl)
|
||||
@@ -107,13 +113,13 @@ func (m *MockEtcdClient) Grant(ctx context.Context, ttl int64) (*clientv3.LeaseG
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Grant indicates an expected call of Grant
|
||||
// Grant indicates an expected call of Grant.
|
||||
func (mr *MockEtcdClientMockRecorder) Grant(ctx, ttl any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Grant", reflect.TypeOf((*MockEtcdClient)(nil).Grant), ctx, ttl)
|
||||
}
|
||||
|
||||
// KeepAlive mocks base method
|
||||
// KeepAlive mocks base method.
|
||||
func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-chan *clientv3.LeaseKeepAliveResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeepAlive", ctx, id)
|
||||
@@ -122,13 +128,13 @@ func (m *MockEtcdClient) KeepAlive(ctx context.Context, id clientv3.LeaseID) (<-
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// KeepAlive indicates an expected call of KeepAlive
|
||||
// KeepAlive indicates an expected call of KeepAlive.
|
||||
func (mr *MockEtcdClientMockRecorder) KeepAlive(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeepAlive", reflect.TypeOf((*MockEtcdClient)(nil).KeepAlive), ctx, id)
|
||||
}
|
||||
|
||||
// Put mocks base method
|
||||
// Put mocks base method.
|
||||
func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clientv3.OpOption) (*clientv3.PutResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, key, val}
|
||||
@@ -141,14 +147,14 @@ func (m *MockEtcdClient) Put(ctx context.Context, key, val string, opts ...clien
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Put indicates an expected call of Put
|
||||
// Put indicates an expected call of Put.
|
||||
func (mr *MockEtcdClientMockRecorder) Put(ctx, key, val any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, key, val}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockEtcdClient)(nil).Put), varargs...)
|
||||
}
|
||||
|
||||
// Revoke mocks base method
|
||||
// Revoke mocks base method.
|
||||
func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clientv3.LeaseRevokeResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Revoke", ctx, id)
|
||||
@@ -157,13 +163,13 @@ func (m *MockEtcdClient) Revoke(ctx context.Context, id clientv3.LeaseID) (*clie
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Revoke indicates an expected call of Revoke
|
||||
// Revoke indicates an expected call of Revoke.
|
||||
func (mr *MockEtcdClientMockRecorder) Revoke(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Revoke", reflect.TypeOf((*MockEtcdClient)(nil).Revoke), ctx, id)
|
||||
}
|
||||
|
||||
// Watch mocks base method
|
||||
// Watch mocks base method.
|
||||
func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3.OpOption) clientv3.WatchChan {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, key}
|
||||
@@ -175,7 +181,7 @@ func (m *MockEtcdClient) Watch(ctx context.Context, key string, opts ...clientv3
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Watch indicates an expected call of Watch
|
||||
// Watch indicates an expected call of Watch.
|
||||
func (mr *MockEtcdClientMockRecorder) Watch(ctx, key any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, key}, opts...)
|
||||
|
||||
@@ -207,7 +207,7 @@ func (c *cluster) getCurrent(key watchKey) []KV {
|
||||
return nil
|
||||
}
|
||||
|
||||
var kvs []KV
|
||||
kvs := make([]KV, 0, len(watcher.values))
|
||||
for k, v := range watcher.values {
|
||||
kvs = append(kvs, KV{
|
||||
Key: k,
|
||||
@@ -308,7 +308,7 @@ func (c *cluster) load(cli EtcdClient, key watchKey) int64 {
|
||||
time.Sleep(coolDownUnstable.AroundDuration(coolDownInterval))
|
||||
}
|
||||
|
||||
var kvs []KV
|
||||
kvs := make([]KV, 0, len(resp.Kvs))
|
||||
for _, ev := range resp.Kvs {
|
||||
kvs = append(kvs, KV{
|
||||
Key: string(ev.Key),
|
||||
@@ -352,7 +352,7 @@ func (c *cluster) reload(cli EtcdClient) {
|
||||
// cancel the previous watches
|
||||
close(c.done)
|
||||
c.watchGroup.Wait()
|
||||
var keys []watchKey
|
||||
keys := make([]watchKey, 0, len(c.watchers))
|
||||
for wk, wval := range c.watchers {
|
||||
keys = append(keys, wk)
|
||||
if wval.cancel != nil {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/contextx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.etcd.io/etcd/client/v3/mock/mockserver"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
var mockLock sync.Mutex
|
||||
@@ -423,7 +423,7 @@ func TestRegistry_Monitor(t *testing.T) {
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
watchers: map[watchKey]*watchValue{
|
||||
watchKey{
|
||||
{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}: {
|
||||
@@ -449,7 +449,7 @@ func TestRegistry_Unmonitor(t *testing.T) {
|
||||
GetRegistry().clusters = map[string]*cluster{
|
||||
getClusterKey(endpoints): {
|
||||
watchers: map[watchKey]*watchValue{
|
||||
watchKey{
|
||||
{
|
||||
key: "foo",
|
||||
exactMatch: true,
|
||||
}: {
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: statewatcher.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package internal -destination statewatcher_mock.go -source statewatcher.go etcdConn
|
||||
//
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
@@ -8,34 +13,35 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
connectivity "google.golang.org/grpc/connectivity"
|
||||
)
|
||||
|
||||
// MocketcdConn is a mock of etcdConn interface
|
||||
// MocketcdConn is a mock of etcdConn interface.
|
||||
type MocketcdConn struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MocketcdConnMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn
|
||||
// MocketcdConnMockRecorder is the mock recorder for MocketcdConn.
|
||||
type MocketcdConnMockRecorder struct {
|
||||
mock *MocketcdConn
|
||||
}
|
||||
|
||||
// NewMocketcdConn creates a new mock instance
|
||||
// NewMocketcdConn creates a new mock instance.
|
||||
func NewMocketcdConn(ctrl *gomock.Controller) *MocketcdConn {
|
||||
mock := &MocketcdConn{ctrl: ctrl}
|
||||
mock.recorder = &MocketcdConnMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MocketcdConn) EXPECT() *MocketcdConnMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetState mocks base method
|
||||
// GetState mocks base method.
|
||||
func (m *MocketcdConn) GetState() connectivity.State {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetState")
|
||||
@@ -43,13 +49,13 @@ func (m *MocketcdConn) GetState() connectivity.State {
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetState indicates an expected call of GetState
|
||||
// GetState indicates an expected call of GetState.
|
||||
func (mr *MocketcdConnMockRecorder) GetState() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MocketcdConn)(nil).GetState))
|
||||
}
|
||||
|
||||
// WaitForStateChange mocks base method
|
||||
// WaitForStateChange mocks base method.
|
||||
func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState connectivity.State) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "WaitForStateChange", ctx, sourceState)
|
||||
@@ -57,7 +63,7 @@ func (m *MocketcdConn) WaitForStateChange(ctx context.Context, sourceState conne
|
||||
return ret0
|
||||
}
|
||||
|
||||
// WaitForStateChange indicates an expected call of WaitForStateChange
|
||||
// WaitForStateChange indicates an expected call of WaitForStateChange.
|
||||
func (mr *MocketcdConnMockRecorder) WaitForStateChange(ctx, sourceState any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForStateChange", reflect.TypeOf((*MocketcdConn)(nil).WaitForStateChange), ctx, sourceState)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"go.uber.org/mock/gomock"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: updatelistener.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package internal -destination updatelistener_mock.go -source updatelistener.go UpdateListener
|
||||
//
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
@@ -7,51 +12,52 @@ package internal
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockUpdateListener is a mock of UpdateListener interface
|
||||
// MockUpdateListener is a mock of UpdateListener interface.
|
||||
type MockUpdateListener struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUpdateListenerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener
|
||||
// MockUpdateListenerMockRecorder is the mock recorder for MockUpdateListener.
|
||||
type MockUpdateListenerMockRecorder struct {
|
||||
mock *MockUpdateListener
|
||||
}
|
||||
|
||||
// NewMockUpdateListener creates a new mock instance
|
||||
// NewMockUpdateListener creates a new mock instance.
|
||||
func NewMockUpdateListener(ctrl *gomock.Controller) *MockUpdateListener {
|
||||
mock := &MockUpdateListener{ctrl: ctrl}
|
||||
mock.recorder = &MockUpdateListenerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockUpdateListener) EXPECT() *MockUpdateListenerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// OnAdd mocks base method
|
||||
// OnAdd mocks base method.
|
||||
func (m *MockUpdateListener) OnAdd(kv KV) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnAdd", kv)
|
||||
}
|
||||
|
||||
// OnAdd indicates an expected call of OnAdd
|
||||
// OnAdd indicates an expected call of OnAdd.
|
||||
func (mr *MockUpdateListenerMockRecorder) OnAdd(kv any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAdd", reflect.TypeOf((*MockUpdateListener)(nil).OnAdd), kv)
|
||||
}
|
||||
|
||||
// OnDelete mocks base method
|
||||
// OnDelete mocks base method.
|
||||
func (m *MockUpdateListener) OnDelete(kv KV) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnDelete", kv)
|
||||
}
|
||||
|
||||
// OnDelete indicates an expected call of OnDelete
|
||||
// OnDelete indicates an expected call of OnDelete.
|
||||
func (mr *MockUpdateListenerMockRecorder) OnDelete(kv any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnDelete", reflect.TypeOf((*MockUpdateListener)(nil).OnDelete), kv)
|
||||
|
||||
@@ -92,12 +92,12 @@ func (p *Publisher) doKeepAlive() error {
|
||||
default:
|
||||
cli, err := p.doRegister()
|
||||
if err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher doRegister: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
if err := p.keepAliveAsync(cli); err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher keepAliveAsync: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -125,23 +125,48 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
}
|
||||
|
||||
threading.GoSafe(func() {
|
||||
wch := cli.Watch(cli.Ctx(), p.fullKey, clientv3.WithFilterPut())
|
||||
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
p.revoke(cli)
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
case c := <-wch:
|
||||
if c.Err() != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher watch: %v", c.Err())
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, evt := range c.Events {
|
||||
if evt.Type == clientv3.EventTypeDelete {
|
||||
logc.Infof(cli.Ctx(), "etcd publisher watch: %s, event: %v",
|
||||
evt.Kv.Key, evt.Type)
|
||||
_, err := cli.Put(cli.Ctx(), p.fullKey, p.value, clientv3.WithLease(p.lease))
|
||||
if err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher re-put key: %v", err)
|
||||
} else {
|
||||
logc.Infof(cli.Ctx(), "etcd publisher re-put key: %s, value: %s",
|
||||
p.fullKey, p.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-p.pauseChan:
|
||||
logc.Infof(cli.Ctx(), "paused etcd renew, key: %s, value: %s", p.key, p.value)
|
||||
p.revoke(cli)
|
||||
select {
|
||||
case <-p.resumeChan:
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher KeepAlive: %v", err)
|
||||
}
|
||||
return
|
||||
case <-p.quit.Done():
|
||||
@@ -176,7 +201,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
|
||||
|
||||
func (p *Publisher) revoke(cli internal.EtcdClient) {
|
||||
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %s", err.Error())
|
||||
logc.Errorf(cli.Ctx(), "etcd publisher revoke: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/discov/internal"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"go.etcd.io/etcd/api/v3/mvccpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/net/http2"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -211,6 +212,9 @@ func TestPublisher_keepAliveAsyncQuit(t *testing.T) {
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
// Add Watch expectation for the new watch mechanism
|
||||
watchChan := make(<-chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
@@ -232,6 +236,9 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
// Add Watch expectation for the new watch mechanism
|
||||
watchChan := make(<-chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -245,6 +252,112 @@ func TestPublisher_keepAliveAsyncPause(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test case for key deletion and re-registration (covers lines 148-155)
|
||||
func TestPublisher_keepAliveAsyncKeyDeletion(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
|
||||
// Create a watch channel that will send a delete event
|
||||
watchChan := make(chan clientv3.WatchResponse, 1)
|
||||
watchResp := clientv3.WatchResponse{
|
||||
Events: []*clientv3.Event{{
|
||||
Type: clientv3.EventTypeDelete,
|
||||
Kv: &mvccpb.KeyValue{
|
||||
Key: []byte("thekey"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
watchChan <- watchResp
|
||||
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1) // Only wait for Revoke call
|
||||
|
||||
// Use a channel to signal when Put has been called
|
||||
putCalled := make(chan struct{})
|
||||
|
||||
// Expect the re-put operation when key is deleted
|
||||
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
|
||||
close(putCalled) // Signal that Put has been called
|
||||
}).Return(nil, nil)
|
||||
|
||||
// Expect revoke when Stop is called
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
wg.Done()
|
||||
})
|
||||
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.fullKey = "thekey"
|
||||
|
||||
assert.Nil(t, pub.keepAliveAsync(cli))
|
||||
|
||||
// Wait for Put to be called, then stop
|
||||
<-putCalled
|
||||
pub.Stop()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test case for key deletion with re-put error (covers error branch in lines 151-152)
|
||||
func TestPublisher_keepAliveAsyncKeyDeletionPutError(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
const id clientv3.LeaseID = 1
|
||||
cli := internal.NewMockEtcdClient(ctrl)
|
||||
restore := setMockClient(cli)
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
|
||||
// Create a watch channel that will send a delete event
|
||||
watchChan := make(chan clientv3.WatchResponse, 1)
|
||||
watchResp := clientv3.WatchResponse{
|
||||
Events: []*clientv3.Event{{
|
||||
Type: clientv3.EventTypeDelete,
|
||||
Kv: &mvccpb.KeyValue{
|
||||
Key: []byte("thekey"),
|
||||
},
|
||||
}},
|
||||
}
|
||||
watchChan <- watchResp
|
||||
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return((<-chan clientv3.WatchResponse)(watchChan))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1) // Only wait for Revoke call
|
||||
|
||||
// Use a channel to signal when Put has been called
|
||||
putCalled := make(chan struct{})
|
||||
|
||||
// Expect the re-put operation to fail
|
||||
cli.EXPECT().Put(gomock.Any(), "thekey", "thevalue", gomock.Any()).Do(func(_, _, _, _ any) {
|
||||
close(putCalled) // Signal that Put has been called
|
||||
}).Return(nil, errors.New("put error"))
|
||||
|
||||
// Expect revoke when Stop is called
|
||||
cli.EXPECT().Revoke(gomock.Any(), id).Do(func(_, _ any) {
|
||||
wg.Done()
|
||||
})
|
||||
|
||||
pub := NewPublisher(nil, "thekey", "thevalue")
|
||||
pub.lease = id
|
||||
pub.fullKey = "thekey"
|
||||
|
||||
assert.Nil(t, pub.keepAliveAsync(cli))
|
||||
|
||||
// Wait for Put to be called, then stop
|
||||
<-putCalled
|
||||
pub.Stop()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPublisher_Resume(t *testing.T) {
|
||||
publisher := new(Publisher)
|
||||
publisher.resumeChan = make(chan lang.PlaceholderType)
|
||||
@@ -273,6 +386,9 @@ func TestPublisher_keepAliveAsync(t *testing.T) {
|
||||
defer restore()
|
||||
cli.EXPECT().Ctx().AnyTimes()
|
||||
cli.EXPECT().KeepAlive(gomock.Any(), id)
|
||||
// Add Watch expectation for the new watch mechanism
|
||||
watchChan := make(<-chan clientv3.WatchResponse)
|
||||
cli.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(watchChan)
|
||||
cli.EXPECT().Grant(gomock.Any(), timeToLive).Return(&clientv3.LeaseGrantResponse{
|
||||
ID: 1,
|
||||
}, nil)
|
||||
|
||||
@@ -86,21 +86,16 @@ func TestConsistentHashIncrementalTransfer(t *testing.T) {
|
||||
|
||||
func TestConsistentHashTransferOnFailure(t *testing.T) {
|
||||
index := 41
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
|
||||
ratio := float32(transferred) / float32(requestSize)
|
||||
assert.True(t, ratio < 2.5/float32(keySize), fmt.Sprintf("%d: %f", index, ratio))
|
||||
ratioNotExists := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratioNotExists == 0, fmt.Sprintf("%d: %f", index, ratioNotExists))
|
||||
index = 13
|
||||
ratio := getTransferRatioOnFailure(t, index)
|
||||
assert.True(t, ratio < 2.5/keySize, fmt.Sprintf("%d: %f", index, ratio))
|
||||
}
|
||||
|
||||
func TestConsistentHashLeastTransferOnFailure(t *testing.T) {
|
||||
prefix := "localhost:"
|
||||
index := 41
|
||||
index := 13
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, prefix, index)
|
||||
for k, v := range keys {
|
||||
newV := newKeys[k]
|
||||
@@ -164,6 +159,17 @@ func getKeysBeforeAndAfterFailure(t *testing.T, prefix string, index int) (map[i
|
||||
return keys, newKeys
|
||||
}
|
||||
|
||||
func getTransferRatioOnFailure(t *testing.T, index int) float32 {
|
||||
keys, newKeys := getKeysBeforeAndAfterFailure(t, "localhost:", index)
|
||||
var transferred int
|
||||
for k, v := range newKeys {
|
||||
if v != keys[k] {
|
||||
transferred++
|
||||
}
|
||||
}
|
||||
return float32(transferred) / float32(requestSize)
|
||||
}
|
||||
|
||||
type mockNode struct {
|
||||
addr string
|
||||
id int
|
||||
|
||||
@@ -2,7 +2,7 @@ package hash
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/spaolacci/murmur3"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ func Md5(data []byte) []byte {
|
||||
}
|
||||
|
||||
// Md5Hex returns the md5 hex string of data.
|
||||
// This function is optimized for better performance than fmt.Sprintf.
|
||||
func Md5Hex(data []byte) string {
|
||||
return fmt.Sprintf("%x", Md5(data))
|
||||
return hex.EncodeToString(Md5(data))
|
||||
}
|
||||
|
||||
@@ -8,9 +8,25 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Marshal marshals v into json bytes.
|
||||
// Marshal marshals v into json bytes, without escaping HTML and removes the trailing newline.
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
// why not use json.Marshal? https://github.com/golang/go/issues/28453
|
||||
// it changes the behavior of json.Marshal, like & -> \u0026, < -> \u003c, > -> \u003e
|
||||
// which is not what we want in API responses
|
||||
var buf bytes.Buffer
|
||||
enc := json.NewEncoder(&buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
if err := enc.Encode(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs := buf.Bytes()
|
||||
// Remove trailing newline added by json.Encoder.Encode
|
||||
if len(bs) > 0 && bs[len(bs)-1] == '\n' {
|
||||
bs = bs[:len(bs)-1]
|
||||
}
|
||||
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
// MarshalToString marshals v into a string.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package jsonx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -101,3 +102,105 @@ func TestUnmarshalFromReaderError(t *testing.T) {
|
||||
err := UnmarshalFromReader(strings.NewReader(s), &v)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_doMarshalJson(t *testing.T) {
|
||||
type args struct {
|
||||
v any
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr assert.ErrorAssertionFunc
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
args: args{nil},
|
||||
want: []byte("null"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "string",
|
||||
args: args{"hello"},
|
||||
want: []byte(`"hello"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
args: args{42},
|
||||
want: []byte("42"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
args: args{true},
|
||||
want: []byte("true"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
args: args{
|
||||
struct {
|
||||
Name string `json:"name"`
|
||||
}{Name: "test"},
|
||||
},
|
||||
want: []byte(`{"name":"test"}`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "slice",
|
||||
args: args{[]int{1, 2, 3}},
|
||||
want: []byte("[1,2,3]"),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
args: args{map[string]int{"a": 1, "b": 2}},
|
||||
want: []byte(`{"a":1,"b":2}`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "unmarshalable type",
|
||||
args: args{complex(1, 2)},
|
||||
want: nil,
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "channel type",
|
||||
args: args{make(chan int)},
|
||||
want: nil,
|
||||
wantErr: assert.Error,
|
||||
},
|
||||
{
|
||||
name: "url with query params",
|
||||
args: args{"https://example.com/api?name=test&age=25"},
|
||||
want: []byte(`"https://example.com/api?name=test&age=25"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "url with encoded query params",
|
||||
args: args{"https://example.com/api?data=hello%20world&special=%26%3D"},
|
||||
want: []byte(`"https://example.com/api?data=hello%20world&special=%26%3D"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
{
|
||||
name: "url with multiple query params",
|
||||
args: args{"http://localhost:8080/users?page=1&limit=10&sort=name&order=asc"},
|
||||
want: []byte(`"http://localhost:8080/users?page=1&limit=10&sort=name&order=asc"`),
|
||||
wantErr: assert.NoError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Marshal(tt.args.v)
|
||||
if !tt.wantErr(t, err, fmt.Sprintf("Marshal(%v)", tt.args.v)) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equalf(t, string(tt.want), string(got), "Marshal(%v)", tt.args.v)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,47 +1,70 @@
|
||||
package logx
|
||||
|
||||
// A LogConf is a logging config.
|
||||
type LogConf struct {
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stat represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||
// if the `KeepDays` limitation is reached.
|
||||
MaxBackups int `json:",default=0"`
|
||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||
// Only take effect when RotationRuleType is `size`
|
||||
MaxSize int `json:",default=0"`
|
||||
// Rotation represents the type of log rotation rule. Default is `daily`.
|
||||
// daily: daily rotation.
|
||||
// size: size limited rotation.
|
||||
Rotation string `json:",default=daily,options=[daily,size]"`
|
||||
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
FileTimeFormat string `json:",optional"`
|
||||
}
|
||||
type (
|
||||
// A LogConf is a logging config.
|
||||
LogConf struct {
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stat represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even though `MaxBackups` sets 0, log files will still be removed
|
||||
// if the `KeepDays` limitation is reached.
|
||||
MaxBackups int `json:",default=0"`
|
||||
// MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`.
|
||||
// Only take effect when RotationRuleType is `size`
|
||||
MaxSize int `json:",default=0"`
|
||||
// Rotation represents the type of log rotation rule. Default is `daily`.
|
||||
// daily: daily rotation.
|
||||
// size: size limited rotation.
|
||||
Rotation string `json:",default=daily,options=[daily,size]"`
|
||||
// FileTimeFormat represents the time format for file name, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
FileTimeFormat string `json:",optional"`
|
||||
// FieldKeys represents the field keys.
|
||||
FieldKeys fieldKeyConf `json:",optional"`
|
||||
}
|
||||
|
||||
fieldKeyConf struct {
|
||||
// CallerKey represents the caller key.
|
||||
CallerKey string `json:",default=caller"`
|
||||
// ContentKey represents the content key.
|
||||
ContentKey string `json:",default=content"`
|
||||
// DurationKey represents the duration key.
|
||||
DurationKey string `json:",default=duration"`
|
||||
// LevelKey represents the level key.
|
||||
LevelKey string `json:",default=level"`
|
||||
// SpanKey represents the span key.
|
||||
SpanKey string `json:",default=span"`
|
||||
// TimestampKey represents the timestamp key.
|
||||
TimestampKey string `json:",default=@timestamp"`
|
||||
// TraceKey represents the trace key.
|
||||
TraceKey string `json:",default=trace"`
|
||||
// TruncatedKey represents the truncated key.
|
||||
TruncatedKey string `json:",default=truncated"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -7,12 +7,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
fieldsContextKey contextKey
|
||||
globalFields atomic.Value
|
||||
globalFieldsLock sync.Mutex
|
||||
)
|
||||
|
||||
type contextKey struct{}
|
||||
type fieldsKey struct{}
|
||||
|
||||
// AddGlobalFields adds global fields.
|
||||
func AddGlobalFields(fields ...LogField) {
|
||||
@@ -29,16 +28,16 @@ func AddGlobalFields(fields ...LogField) {
|
||||
|
||||
// ContextWithFields returns a new context with the given fields.
|
||||
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
||||
if val := ctx.Value(fieldsContextKey); val != nil {
|
||||
if val := ctx.Value(fieldsKey{}); val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
allFields := make([]LogField, 0, len(arr)+len(fields))
|
||||
allFields = append(allFields, arr...)
|
||||
allFields = append(allFields, fields...)
|
||||
return context.WithValue(ctx, fieldsContextKey, allFields)
|
||||
return context.WithValue(ctx, fieldsKey{}, allFields)
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, fieldsContextKey, fields)
|
||||
return context.WithValue(ctx, fieldsKey{}, fields)
|
||||
}
|
||||
|
||||
// WithFields returns a new logger with the given fields.
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestAddGlobalFields(t *testing.T) {
|
||||
|
||||
func TestContextWithFields(t *testing.T) {
|
||||
ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -43,7 +43,7 @@ func TestContextWithFields(t *testing.T) {
|
||||
|
||||
func TestWithFields(t *testing.T) {
|
||||
ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -55,7 +55,7 @@ func TestWithFieldsAppend(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), dummyKey, "dummy")
|
||||
ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2))
|
||||
ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4))
|
||||
vals := ctx.Value(fieldsContextKey)
|
||||
vals := ctx.Value(fieldsKey{})
|
||||
assert.NotNil(t, vals)
|
||||
fields, ok := vals.([]LogField)
|
||||
assert.True(t, ok)
|
||||
@@ -80,8 +80,8 @@ func TestWithFieldsAppendCopy(t *testing.T) {
|
||||
ctxa := ContextWithFields(ctx, af)
|
||||
ctxb := ContextWithFields(ctx, bf)
|
||||
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsKey{}).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsKey{}).([]LogField)[count])
|
||||
}
|
||||
|
||||
func BenchmarkAtomicValue(b *testing.B) {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/sysx"
|
||||
)
|
||||
@@ -187,39 +186,9 @@ func Errorw(msg string, fields ...LogField) {
|
||||
|
||||
// Field returns a LogField for the given key and value.
|
||||
func Field(key string, value any) LogField {
|
||||
switch val := value.(type) {
|
||||
case error:
|
||||
return LogField{Key: key, Value: encodeError(val)}
|
||||
case []error:
|
||||
var errs []string
|
||||
for _, err := range val {
|
||||
errs = append(errs, encodeError(err))
|
||||
}
|
||||
return LogField{Key: key, Value: errs}
|
||||
case time.Duration:
|
||||
return LogField{Key: key, Value: fmt.Sprint(val)}
|
||||
case []time.Duration:
|
||||
var durs []string
|
||||
for _, dur := range val {
|
||||
durs = append(durs, fmt.Sprint(dur))
|
||||
}
|
||||
return LogField{Key: key, Value: durs}
|
||||
case []time.Time:
|
||||
var times []string
|
||||
for _, t := range val {
|
||||
times = append(times, fmt.Sprint(t))
|
||||
}
|
||||
return LogField{Key: key, Value: times}
|
||||
case fmt.Stringer:
|
||||
return LogField{Key: key, Value: encodeStringer(val)}
|
||||
case []fmt.Stringer:
|
||||
var strs []string
|
||||
for _, str := range val {
|
||||
strs = append(strs, encodeStringer(str))
|
||||
}
|
||||
return LogField{Key: key, Value: strs}
|
||||
default:
|
||||
return LogField{Key: key, Value: val}
|
||||
return LogField{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,7 +276,8 @@ func SetUp(c LogConf) (err error) {
|
||||
// Because multiple services in one process might call SetUp respectively.
|
||||
// Need to wait for the first caller to complete the execution.
|
||||
setupOnce.Do(func() {
|
||||
setupLogLevel(c)
|
||||
setupLogLevel(c.Level)
|
||||
setupFieldKeys(c.FieldKeys)
|
||||
|
||||
if !c.Stat {
|
||||
DisableStat()
|
||||
@@ -511,8 +481,35 @@ func handleOptions(opts []LogOption) {
|
||||
}
|
||||
}
|
||||
|
||||
func setupLogLevel(c LogConf) {
|
||||
switch c.Level {
|
||||
func setupFieldKeys(c fieldKeyConf) {
|
||||
if len(c.CallerKey) > 0 {
|
||||
callerKey = c.CallerKey
|
||||
}
|
||||
if len(c.ContentKey) > 0 {
|
||||
contentKey = c.ContentKey
|
||||
}
|
||||
if len(c.DurationKey) > 0 {
|
||||
durationKey = c.DurationKey
|
||||
}
|
||||
if len(c.LevelKey) > 0 {
|
||||
levelKey = c.LevelKey
|
||||
}
|
||||
if len(c.SpanKey) > 0 {
|
||||
spanKey = c.SpanKey
|
||||
}
|
||||
if len(c.TimestampKey) > 0 {
|
||||
timestampKey = c.TimestampKey
|
||||
}
|
||||
if len(c.TraceKey) > 0 {
|
||||
traceKey = c.TraceKey
|
||||
}
|
||||
if len(c.TruncatedKey) > 0 {
|
||||
truncatedKey = c.TruncatedKey
|
||||
}
|
||||
}
|
||||
|
||||
func setupLogLevel(level string) {
|
||||
switch level {
|
||||
case levelDebug:
|
||||
SetLevel(DebugLevel)
|
||||
case levelInfo:
|
||||
@@ -560,7 +557,7 @@ func shallLogStat() bool {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeDebug(val any, fields ...LogField) {
|
||||
getWriter().Debug(val, addCaller(fields...)...)
|
||||
getWriter().Debug(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeError writes v into the error log.
|
||||
@@ -568,7 +565,7 @@ func writeDebug(val any, fields ...LogField) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeError(val any, fields ...LogField) {
|
||||
getWriter().Error(val, addCaller(fields...)...)
|
||||
getWriter().Error(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeInfo writes v into info log.
|
||||
@@ -576,7 +573,7 @@ func writeError(val any, fields ...LogField) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeInfo(val any, fields ...LogField) {
|
||||
getWriter().Info(val, addCaller(fields...)...)
|
||||
getWriter().Info(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeSevere writes v into severe log.
|
||||
@@ -592,7 +589,7 @@ func writeSevere(msg string) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeSlow(val any, fields ...LogField) {
|
||||
getWriter().Slow(val, addCaller(fields...)...)
|
||||
getWriter().Slow(val, mergeGlobalFields(addCaller(fields...))...)
|
||||
}
|
||||
|
||||
// writeStack writes v into stack log.
|
||||
@@ -608,5 +605,5 @@ func writeStack(msg string) {
|
||||
// If we check shallLog here, the fmt.Sprint might be called even if the log level is not enabled.
|
||||
// The caller should check shallLog before calling this function.
|
||||
func writeStat(msg string) {
|
||||
getWriter().Stat(msg, addCaller()...)
|
||||
getWriter().Stat(msg, mergeGlobalFields(addCaller())...)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package logx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -16,6 +17,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -776,15 +779,9 @@ func TestSetup(t *testing.T) {
|
||||
MaxBackups: 3,
|
||||
MaxSize: 1024 * 1024,
|
||||
}))
|
||||
setupLogLevel(LogConf{
|
||||
Level: levelInfo,
|
||||
})
|
||||
setupLogLevel(LogConf{
|
||||
Level: levelError,
|
||||
})
|
||||
setupLogLevel(LogConf{
|
||||
Level: levelSevere,
|
||||
})
|
||||
setupLogLevel(levelInfo)
|
||||
setupLogLevel(levelError)
|
||||
setupLogLevel(levelSevere)
|
||||
_, err := createOutput("")
|
||||
assert.NotNil(t, err)
|
||||
Disable()
|
||||
@@ -856,6 +853,95 @@ func TestWithKeepDays(t *testing.T) {
|
||||
assert.Equal(t, 1, opt.keepDays)
|
||||
}
|
||||
|
||||
func TestWithField_LogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
level uint32
|
||||
fn func(string, ...LogField)
|
||||
count int32
|
||||
}{
|
||||
{
|
||||
name: "debug/info",
|
||||
level: DebugLevel,
|
||||
fn: Infow,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "info/error",
|
||||
level: InfoLevel,
|
||||
fn: Errorw,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "info/info",
|
||||
level: InfoLevel,
|
||||
fn: Infow,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "info/severe",
|
||||
level: InfoLevel,
|
||||
fn: Errorw,
|
||||
count: 1,
|
||||
},
|
||||
{
|
||||
name: "error/info",
|
||||
level: ErrorLevel,
|
||||
fn: Infow,
|
||||
count: 0,
|
||||
},
|
||||
{
|
||||
name: "error/debug",
|
||||
level: ErrorLevel,
|
||||
fn: Debugw,
|
||||
count: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(tt.level)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
var val countingStringer
|
||||
tt.fn("hello there", Field("foo", &val))
|
||||
assert.Equal(t, tt.count, val.Count())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithField_LogLevelWithContext(t *testing.T) {
|
||||
t.Run("context more than once with info/info", func(t *testing.T) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(InfoLevel)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
var val countingStringer
|
||||
ctx := ContextWithFields(context.Background(), Field("foo", &val))
|
||||
logger := WithContext(ctx)
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
assert.True(t, val.Count() > 0)
|
||||
})
|
||||
|
||||
t.Run("context more than once with error/info", func(t *testing.T) {
|
||||
olevel := atomic.LoadUint32(&logLevel)
|
||||
SetLevel(ErrorLevel)
|
||||
defer SetLevel(olevel)
|
||||
|
||||
var val countingStringer
|
||||
ctx := ContextWithFields(context.Background(), Field("foo", &val))
|
||||
logger := WithContext(ctx)
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
logger.Info("hello there")
|
||||
assert.Equal(t, int32(0), val.Count())
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCopyByteSliceAppend(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var buf []byte
|
||||
@@ -1054,3 +1140,79 @@ type panicStringer struct {
|
||||
func (s panicStringer) String() string {
|
||||
panic("panic")
|
||||
}
|
||||
|
||||
type countingStringer struct {
|
||||
count int32
|
||||
}
|
||||
|
||||
func (s *countingStringer) Count() int32 {
|
||||
return atomic.LoadInt32(&s.count)
|
||||
}
|
||||
|
||||
func (s *countingStringer) String() string {
|
||||
atomic.AddInt32(&s.count, 1)
|
||||
return "countingStringer"
|
||||
}
|
||||
|
||||
func TestLogKey(t *testing.T) {
|
||||
setupOnce = sync.Once{}
|
||||
MustSetup(LogConf{
|
||||
ServiceName: "any",
|
||||
Mode: "console",
|
||||
Encoding: "json",
|
||||
TimeFormat: timeFormat,
|
||||
FieldKeys: fieldKeyConf{
|
||||
CallerKey: "_caller",
|
||||
ContentKey: "_content",
|
||||
DurationKey: "_duration",
|
||||
LevelKey: "_level",
|
||||
SpanKey: "_span",
|
||||
TimestampKey: "_timestamp",
|
||||
TraceKey: "_trace",
|
||||
TruncatedKey: "_truncated",
|
||||
},
|
||||
})
|
||||
|
||||
t.Cleanup(func() {
|
||||
setupFieldKeys(fieldKeyConf{
|
||||
CallerKey: defaultCallerKey,
|
||||
ContentKey: defaultContentKey,
|
||||
DurationKey: defaultDurationKey,
|
||||
LevelKey: defaultLevelKey,
|
||||
SpanKey: defaultSpanKey,
|
||||
TimestampKey: defaultTimestampKey,
|
||||
TraceKey: defaultTraceKey,
|
||||
TruncatedKey: defaultTruncatedKey,
|
||||
})
|
||||
})
|
||||
|
||||
const message = "hello there"
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
defer writer.Store(old)
|
||||
|
||||
otp := otel.GetTracerProvider()
|
||||
tp := trace.NewTracerProvider(trace.WithSampler(trace.AlwaysSample()))
|
||||
otel.SetTracerProvider(tp)
|
||||
defer otel.SetTracerProvider(otp)
|
||||
|
||||
ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id")
|
||||
defer span.End()
|
||||
|
||||
WithContext(ctx).WithDuration(time.Second).Info(message)
|
||||
now := time.Now()
|
||||
|
||||
var m map[string]string
|
||||
if err := json.Unmarshal([]byte(w.String()), &m); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(t, "info", m["_level"])
|
||||
assert.Equal(t, message, m["_content"])
|
||||
assert.Equal(t, "1000.0ms", m["_duration"])
|
||||
assert.Regexp(t, `logx/logs_test.go:\d+`, m["_caller"])
|
||||
assert.NotEmpty(t, m["_trace"])
|
||||
assert.NotEmpty(t, m["_span"])
|
||||
parsedTime, err := time.Parse(timeFormat, m["_timestamp"])
|
||||
assert.True(t, err == nil)
|
||||
assert.Equal(t, now.Minute(), parsedTime.Minute())
|
||||
}
|
||||
|
||||
@@ -206,7 +206,9 @@ func (l *richLogger) WithFields(fields ...LogField) Logger {
|
||||
|
||||
func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
fields = append(l.fields, fields...)
|
||||
// caller field should always appear together with global fields
|
||||
fields = append(fields, Field(callerKey, getCaller(callerDepth+l.callerSkip)))
|
||||
fields = mergeGlobalFields(fields)
|
||||
|
||||
if l.ctx == nil {
|
||||
return fields
|
||||
@@ -222,7 +224,7 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
fields = append(fields, Field(spanKey, spanID))
|
||||
}
|
||||
|
||||
val := l.ctx.Value(fieldsContextKey)
|
||||
val := l.ctx.Value(fieldsKey{})
|
||||
if val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
fields = append(fields, arr...)
|
||||
|
||||
@@ -423,3 +423,49 @@ type mockValue struct {
|
||||
Foo string `json:"foo"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type testJson struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func (t testJson) MarshalJSON() ([]byte, error) {
|
||||
type testJsonImpl testJson
|
||||
return json.Marshal(testJsonImpl(t))
|
||||
}
|
||||
|
||||
func (t testJson) String() string {
|
||||
return fmt.Sprintf("%s %d %f", t.Name, t.Age, t.Score)
|
||||
}
|
||||
|
||||
func TestLogWithJson(t *testing.T) {
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
writer.lock.RLock()
|
||||
defer func() {
|
||||
writer.lock.RUnlock()
|
||||
writer.Store(old)
|
||||
}()
|
||||
|
||||
l := WithContext(context.Background()).WithFields(Field("bar", testJson{
|
||||
Name: "foo",
|
||||
Age: 1,
|
||||
Score: 1.0,
|
||||
}))
|
||||
l.Info(testlog)
|
||||
|
||||
type mockValue2 struct {
|
||||
mockValue
|
||||
Bar testJson `json:"bar"`
|
||||
}
|
||||
|
||||
var val mockValue2
|
||||
err := json.Unmarshal([]byte(w.String()), &val)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testlog, val.Content)
|
||||
assert.Equal(t, "foo", val.Bar.Name)
|
||||
assert.Equal(t, 1, val.Bar.Age)
|
||||
assert.Equal(t, 1.0, val.Bar.Score)
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dateFormat = "2006-01-02"
|
||||
hoursPerDay = 24
|
||||
bufferSize = 100
|
||||
defaultDirMode = 0o755
|
||||
@@ -116,7 +115,7 @@ func (r *DailyRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(time.DateOnly)
|
||||
buf.WriteString(r.filename)
|
||||
buf.WriteString(r.delimiter)
|
||||
buf.WriteString(boundary)
|
||||
@@ -212,7 +211,7 @@ func (r *SizeLimitRotateRule) OutdatedFiles() []string {
|
||||
}
|
||||
}
|
||||
|
||||
var result []string
|
||||
result := make([]string, 0, len(outdated))
|
||||
for k := range outdated {
|
||||
result = append(result, k)
|
||||
}
|
||||
@@ -425,7 +424,7 @@ func compressLogFile(file string) {
|
||||
}
|
||||
|
||||
func getNowDate() string {
|
||||
return time.Now().Format(dateFormat)
|
||||
return time.Now().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
func getNowDateInRFC3339Format() string {
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
_ = f1.Close()
|
||||
@@ -73,7 +73,7 @@ func TestDailyRotateRuleOutdatedFiles(t *testing.T) {
|
||||
|
||||
func TestDailyRotateRuleShallRotate(t *testing.T) {
|
||||
var rule DailyRotateRule
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(time.DateOnly)
|
||||
assert.True(t, rule.ShallRotate(0))
|
||||
}
|
||||
|
||||
@@ -117,12 +117,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("temp files", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -144,12 +144,12 @@ func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("no backups", func(t *testing.T) {
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f1, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
f2, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary)
|
||||
assert.NoError(t, err)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(dateFormat)
|
||||
boundary1 := time.Now().Add(time.Hour * time.Duration(hoursPerDay) * 2).Format(time.DateOnly)
|
||||
f3, err := os.CreateTemp(os.TempDir(), "go-zero-test-"+boundary1)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRotateLoggerWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
@@ -447,7 +447,7 @@ func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) {
|
||||
}
|
||||
// the following write calls cannot be changed to Write, because of DATA RACE.
|
||||
logger.write([]byte(`foo`))
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat)
|
||||
rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(time.DateOnly)
|
||||
logger.write([]byte(`bar`))
|
||||
logger.Close()
|
||||
logger.write([]byte(`baz`))
|
||||
|
||||
21
core/logx/sensitive.go
Normal file
21
core/logx/sensitive.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package logx
|
||||
|
||||
// Sensitive is an interface that defines a method for masking sensitive information in logs.
|
||||
// It is typically implemented by types that contain sensitive data,
|
||||
// such as passwords or personal information.
|
||||
// Infov, Errorv, Debugv, and Slowv methods will call this method to mask sensitive data.
|
||||
// The values in LogField will also be masked if they implement the Sensitive interface.
|
||||
type Sensitive interface {
|
||||
// MaskSensitive masks sensitive information in the log.
|
||||
MaskSensitive() any
|
||||
}
|
||||
|
||||
// maskSensitive returns the value returned by MaskSensitive method,
|
||||
// if the value implements Sensitive interface.
|
||||
func maskSensitive(v any) any {
|
||||
if s, ok := v.(Sensitive); ok {
|
||||
return s.MaskSensitive()
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
50
core/logx/sensitive_test.go
Normal file
50
core/logx/sensitive_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package logx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const maskedContent = "******"
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Pass string
|
||||
}
|
||||
|
||||
func (u User) MaskSensitive() any {
|
||||
return User{
|
||||
Name: u.Name,
|
||||
Pass: maskedContent,
|
||||
}
|
||||
}
|
||||
|
||||
type NonSensitiveUser struct {
|
||||
Name string
|
||||
Pass string
|
||||
}
|
||||
|
||||
func TestMaskSensitive(t *testing.T) {
|
||||
t.Run("sensitive", func(t *testing.T) {
|
||||
user := User{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
}
|
||||
|
||||
mu := maskSensitive(user)
|
||||
assert.Equal(t, user.Name, mu.(User).Name)
|
||||
assert.Equal(t, maskedContent, mu.(User).Pass)
|
||||
})
|
||||
|
||||
t.Run("non-sensitive", func(t *testing.T) {
|
||||
user := NonSensitiveUser{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
}
|
||||
|
||||
mu := maskSensitive(user)
|
||||
assert.Equal(t, user.Name, mu.(NonSensitiveUser).Name)
|
||||
assert.Equal(t, user.Pass, mu.(NonSensitiveUser).Pass)
|
||||
})
|
||||
}
|
||||
@@ -53,14 +53,14 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
callerKey = "caller"
|
||||
contentKey = "content"
|
||||
durationKey = "duration"
|
||||
levelKey = "level"
|
||||
spanKey = "span"
|
||||
timestampKey = "@timestamp"
|
||||
traceKey = "trace"
|
||||
truncatedKey = "truncated"
|
||||
defaultCallerKey = "caller"
|
||||
defaultContentKey = "content"
|
||||
defaultDurationKey = "duration"
|
||||
defaultLevelKey = "level"
|
||||
defaultSpanKey = "span"
|
||||
defaultTimestampKey = "@timestamp"
|
||||
defaultTraceKey = "trace"
|
||||
defaultTruncatedKey = "truncated"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -73,3 +73,14 @@ var (
|
||||
|
||||
truncatedField = Field(truncatedKey, true)
|
||||
)
|
||||
|
||||
var (
|
||||
callerKey = defaultCallerKey
|
||||
contentKey = defaultContentKey
|
||||
durationKey = defaultDurationKey
|
||||
levelKey = defaultLevelKey
|
||||
spanKey = defaultSpanKey
|
||||
timestampKey = defaultTimestampKey
|
||||
traceKey = defaultTraceKey
|
||||
truncatedKey = defaultTruncatedKey
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
fatihcolor "github.com/fatih/color"
|
||||
"github.com/zeromicro/go-zero/core/color"
|
||||
@@ -17,15 +18,27 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// Writer is the interface for writing logs.
|
||||
// It's designed to let users customize their own log writer,
|
||||
// such as writing logs to a kafka, a database, or using third-party loggers.
|
||||
Writer interface {
|
||||
// Alert sends an alert message, if your writer implemented alerting functionality.
|
||||
Alert(v any)
|
||||
// Close closes the writer.
|
||||
Close() error
|
||||
// Debug logs a message at debug level.
|
||||
Debug(v any, fields ...LogField)
|
||||
// Error logs a message at error level.
|
||||
Error(v any, fields ...LogField)
|
||||
// Info logs a message at info level.
|
||||
Info(v any, fields ...LogField)
|
||||
// Severe logs a message at severe level.
|
||||
Severe(v any)
|
||||
// Slow logs a message at slow level.
|
||||
Slow(v any, fields ...LogField)
|
||||
// Stack logs a message at error level.
|
||||
Stack(v any)
|
||||
// Stat logs a message at stat level.
|
||||
Stat(v any, fields ...LogField)
|
||||
}
|
||||
|
||||
@@ -199,7 +212,6 @@ func newFileWriter(c LogConf) (Writer, error) {
|
||||
statFile := path.Join(c.Path, statFilename)
|
||||
|
||||
handleOptions(opts)
|
||||
setupLogLevel(c)
|
||||
|
||||
if infoLog, err = createOutput(accessFile); err != nil {
|
||||
return nil, err
|
||||
@@ -324,20 +336,6 @@ func buildPlainFields(fields logEntry) []string {
|
||||
return items
|
||||
}
|
||||
|
||||
func combineGlobalFields(fields []LogField) []LogField {
|
||||
globals := globalFields.Load()
|
||||
if globals == nil {
|
||||
return fields
|
||||
}
|
||||
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func marshalJson(t interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
encoder := json.NewEncoder(&buf)
|
||||
@@ -352,21 +350,40 @@ func marshalJson(t interface{}) ([]byte, error) {
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
func mergeGlobalFields(fields []LogField) []LogField {
|
||||
globals := globalFields.Load()
|
||||
if globals == nil {
|
||||
return fields
|
||||
}
|
||||
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
if v, ok := val.(string); ok {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
if maxLen > 0 && len(v) > int(maxLen) {
|
||||
val = v[:maxLen]
|
||||
fields = append(fields, truncatedField)
|
||||
}
|
||||
case Sensitive:
|
||||
val = v.MaskSensitive()
|
||||
}
|
||||
|
||||
fields = combineGlobalFields(fields)
|
||||
// +3 for timestamp, level and content
|
||||
entry := make(logEntry, len(fields)+3)
|
||||
for _, field := range fields {
|
||||
entry[field.Key] = field.Value
|
||||
// mask sensitive data before processing types,
|
||||
// in case field.Value is a sensitive type and also implemented fmt.Stringer.
|
||||
mval := maskSensitive(field.Value)
|
||||
entry[field.Key] = processFieldValue(mval)
|
||||
}
|
||||
|
||||
switch atomic.LoadUint32(&encoding) {
|
||||
@@ -381,6 +398,45 @@ func output(writer io.Writer, level string, val any, fields ...LogField) {
|
||||
}
|
||||
}
|
||||
|
||||
func processFieldValue(value any) any {
|
||||
switch val := value.(type) {
|
||||
case error:
|
||||
return encodeError(val)
|
||||
case []error:
|
||||
var errs []string
|
||||
for _, err := range val {
|
||||
errs = append(errs, encodeError(err))
|
||||
}
|
||||
return errs
|
||||
case time.Duration:
|
||||
return fmt.Sprint(val)
|
||||
case []time.Duration:
|
||||
var durs []string
|
||||
for _, dur := range val {
|
||||
durs = append(durs, fmt.Sprint(dur))
|
||||
}
|
||||
return durs
|
||||
case []time.Time:
|
||||
var times []string
|
||||
for _, t := range val {
|
||||
times = append(times, fmt.Sprint(t))
|
||||
}
|
||||
return times
|
||||
case json.Marshaler:
|
||||
return val
|
||||
case fmt.Stringer:
|
||||
return encodeStringer(val)
|
||||
case []fmt.Stringer:
|
||||
var strs []string
|
||||
for _, str := range val {
|
||||
strs = append(strs, encodeStringer(str))
|
||||
}
|
||||
return strs
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
func wrapLevelWithColor(level string) string {
|
||||
var colour color.Color
|
||||
switch level {
|
||||
|
||||
@@ -225,6 +225,48 @@ func TestWritePlainDuplicate(t *testing.T) {
|
||||
assert.Contains(t, buf.String(), "second=c")
|
||||
}
|
||||
|
||||
func TestLogWithSensitive(t *testing.T) {
|
||||
old := atomic.SwapUint32(&encoding, plainEncodingType)
|
||||
t.Cleanup(func() {
|
||||
atomic.StoreUint32(&encoding, old)
|
||||
})
|
||||
|
||||
t.Run("sensitive", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
output(&buf, levelInfo, User{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
}, LogField{
|
||||
Key: "first",
|
||||
Value: "a",
|
||||
}, LogField{
|
||||
Key: "first",
|
||||
Value: "b",
|
||||
})
|
||||
assert.Contains(t, buf.String(), maskedContent)
|
||||
assert.NotContains(t, buf.String(), "first=a")
|
||||
assert.Contains(t, buf.String(), "first=b")
|
||||
})
|
||||
|
||||
t.Run("sensitive fields", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
output(&buf, levelInfo, "foo", LogField{
|
||||
Key: "first",
|
||||
Value: User{
|
||||
Name: "kevin",
|
||||
Pass: "123",
|
||||
},
|
||||
}, LogField{
|
||||
Key: "second",
|
||||
Value: "b",
|
||||
})
|
||||
assert.Contains(t, buf.String(), "foo")
|
||||
assert.Contains(t, buf.String(), "first")
|
||||
assert.Contains(t, buf.String(), maskedContent)
|
||||
assert.Contains(t, buf.String(), "second=b")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogWithLimitContentLength(t *testing.T) {
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
atomic.StoreUint32(&maxContentLength, 10)
|
||||
|
||||
@@ -3,6 +3,7 @@ package mapping
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -13,6 +14,15 @@ const (
|
||||
|
||||
// Marshal marshals the given val and returns the map that contains the fields.
|
||||
// optional=another is not implemented, and it's hard to implement and not commonly used.
|
||||
// support anonymous field, e.g.:
|
||||
//
|
||||
// type Foo struct {
|
||||
// Token string `header:"token"`
|
||||
// }
|
||||
// type FooB struct {
|
||||
// Foo
|
||||
// Bar string `json:"bar"`
|
||||
// }
|
||||
func Marshal(val any) (map[string]map[string]any, error) {
|
||||
ret := make(map[string]map[string]any)
|
||||
tp := reflect.TypeOf(val)
|
||||
@@ -44,6 +54,16 @@ func getTag(field reflect.StructField) (string, bool) {
|
||||
return strings.TrimSpace(tag), false
|
||||
}
|
||||
|
||||
func insertValue(collector map[string]map[string]any, tag string, key string, val any) {
|
||||
if m, ok := collector[tag]; ok {
|
||||
m[key] = val
|
||||
} else {
|
||||
collector[tag] = map[string]any{
|
||||
key: val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func processMember(field reflect.StructField, value reflect.Value,
|
||||
collector map[string]map[string]any) error {
|
||||
var key string
|
||||
@@ -69,15 +89,20 @@ func processMember(field reflect.StructField, value reflect.Value,
|
||||
val = fmt.Sprint(val)
|
||||
}
|
||||
|
||||
m, ok := collector[tag]
|
||||
if ok {
|
||||
m[key] = val
|
||||
} else {
|
||||
m = map[string]any{
|
||||
key: val,
|
||||
if field.Anonymous {
|
||||
anonCollector, err := Marshal(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for anonTag, anonMap := range anonCollector {
|
||||
for anonKey, anonVal := range anonMap {
|
||||
insertValue(collector, anonTag, anonKey, anonVal)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
insertValue(collector, tag, key, val)
|
||||
}
|
||||
collector[tag] = m
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -118,7 +143,7 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
||||
if value.IsNil() {
|
||||
return fmt.Errorf("field %q is nil", field.Name)
|
||||
}
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
case reflect.Slice, reflect.Map:
|
||||
if value.IsNil() || value.Len() == 0 {
|
||||
return fmt.Errorf("field %q is empty", field.Name)
|
||||
}
|
||||
@@ -128,15 +153,8 @@ func validateOptional(field reflect.StructField, value reflect.Value) error {
|
||||
}
|
||||
|
||||
func validateOptions(value reflect.Value, opt *fieldOptions) error {
|
||||
var found bool
|
||||
val := fmt.Sprint(value.Interface())
|
||||
for i := range opt.Options {
|
||||
if opt.Options[i] == val {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if !slices.Contains(opt.Options, val) {
|
||||
return fmt.Errorf("field %q not in options", val)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,124 @@ func TestMarshal(t *testing.T) {
|
||||
assert.True(t, m[emptyTag]["Anonymous"].(bool))
|
||||
}
|
||||
|
||||
func TestMarshal_Anonymous(t *testing.T) {
|
||||
t.Run("anonymous", func(t *testing.T) {
|
||||
type BaseHeader struct {
|
||||
Token string `header:"token"`
|
||||
}
|
||||
v := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "token_xxx",
|
||||
},
|
||||
}
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "kevin", m["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m["json"]["address"])
|
||||
assert.Equal(t, 20, m["json"]["age"].(int))
|
||||
assert.Equal(t, "token_xxx", m["header"]["token"])
|
||||
|
||||
v1 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
}
|
||||
m1, err1 := Marshal(v1)
|
||||
assert.Nil(t, err1)
|
||||
assert.Equal(t, "kevin", m1["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m1["json"]["address"])
|
||||
assert.Equal(t, 20, m1["json"]["age"].(int))
|
||||
|
||||
type AnotherHeader struct {
|
||||
Version string `header:"version"`
|
||||
}
|
||||
v2 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
AnotherHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "token_xxx",
|
||||
},
|
||||
AnotherHeader: AnotherHeader{
|
||||
Version: "v1.0",
|
||||
},
|
||||
}
|
||||
m2, err2 := Marshal(v2)
|
||||
assert.Nil(t, err2)
|
||||
assert.Equal(t, "kevin", m2["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m2["json"]["address"])
|
||||
assert.Equal(t, 20, m2["json"]["age"].(int))
|
||||
assert.Equal(t, "token_xxx", m2["header"]["token"])
|
||||
assert.Equal(t, "v1.0", m2["header"]["version"])
|
||||
|
||||
type PointerHeader struct {
|
||||
Ref *string `header:"ref"`
|
||||
}
|
||||
ref := "reference"
|
||||
v3 := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
PointerHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
PointerHeader: PointerHeader{
|
||||
Ref: &ref,
|
||||
},
|
||||
}
|
||||
m3, err3 := Marshal(v3)
|
||||
assert.Nil(t, err3)
|
||||
assert.Equal(t, "kevin", m3["json"]["name"])
|
||||
assert.Equal(t, "shanghai", m3["json"]["address"])
|
||||
assert.Equal(t, 20, m3["json"]["age"].(int))
|
||||
assert.Equal(t, "reference", *m3["header"]["ref"].(*string))
|
||||
})
|
||||
|
||||
t.Run("bad anonymous", func(t *testing.T) {
|
||||
type BaseHeader struct {
|
||||
Token string `json:"token,options=[a,b]"`
|
||||
}
|
||||
|
||||
v := struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address,options=[beijing,shanghai]"`
|
||||
Age int `json:"age"`
|
||||
BaseHeader
|
||||
}{
|
||||
Name: "kevin",
|
||||
Address: "shanghai",
|
||||
Age: 20,
|
||||
BaseHeader: BaseHeader{
|
||||
Token: "c",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := Marshal(v)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMarshal_Ptr(t *testing.T) {
|
||||
v := &struct {
|
||||
Name string `path:"name"`
|
||||
@@ -344,3 +462,15 @@ func TestMarshal_FromString(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10", m["json"]["age"].(string))
|
||||
}
|
||||
|
||||
func TestMarshal_Array(t *testing.T) {
|
||||
v := struct {
|
||||
H [1]int `json:"h,string"`
|
||||
}{
|
||||
H: [1]int{1},
|
||||
}
|
||||
|
||||
m, err := Marshal(v)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "[1]", m["json"]["h"].(string))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -15,11 +16,9 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
comma = ","
|
||||
defaultKeyName = "key"
|
||||
delimiter = '.'
|
||||
ignoreKey = "-"
|
||||
@@ -38,7 +37,6 @@ var (
|
||||
defaultCacheLock sync.Mutex
|
||||
emptyMap = map[string]any{}
|
||||
emptyValue = reflect.ValueOf(lang.Placeholder)
|
||||
stringSliceType = reflect.TypeOf([]string{})
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -152,10 +150,6 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value,
|
||||
return nil
|
||||
}
|
||||
|
||||
if u.opts.fromArray {
|
||||
refValue = makeStringSlice(refValue)
|
||||
}
|
||||
|
||||
var valid bool
|
||||
conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap())
|
||||
|
||||
@@ -628,9 +622,19 @@ func (u *Unmarshaler) processFieldNotFromString(fieldType reflect.Type, value re
|
||||
|
||||
return u.fillSliceFromString(fieldType, value, mapValue, fullName)
|
||||
case valueKind == reflect.String && derefedFieldType == durationType:
|
||||
return fillDurationValue(fieldType, value, mapValue.(string))
|
||||
v, err := convertToString(mapValue, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fillDurationValue(fieldType, value, v)
|
||||
case valueKind == reflect.String && typeKind == reflect.Struct && u.implementsUnmarshaler(fieldType):
|
||||
return u.fillUnmarshalerStruct(fieldType, value, mapValue.(string))
|
||||
v, err := convertToString(mapValue, fullName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return u.fillUnmarshalerStruct(fieldType, value, v)
|
||||
default:
|
||||
return u.processFieldPrimitive(fieldType, value, mapValue, opts, fullName)
|
||||
}
|
||||
@@ -761,24 +765,26 @@ func (u *Unmarshaler) processFieldWithEnvValue(fieldType reflect.Type, value ref
|
||||
return err
|
||||
}
|
||||
|
||||
fieldKind := fieldType.Kind()
|
||||
switch fieldKind {
|
||||
case reflect.Bool:
|
||||
derefType := Deref(fieldType)
|
||||
derefKind := derefType.Kind()
|
||||
switch {
|
||||
case derefKind == reflect.String:
|
||||
SetValue(fieldType, value, toReflectValue(derefType, envVal))
|
||||
return nil
|
||||
case derefKind == reflect.Bool:
|
||||
val, err := strconv.ParseBool(envVal)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
value.SetBool(val)
|
||||
SetValue(fieldType, value, toReflectValue(derefType, val))
|
||||
return nil
|
||||
case durationType.Kind():
|
||||
case derefType == durationType:
|
||||
// time.Duration is a special case, its derefKind is reflect.Int64.
|
||||
if err := fillDurationValue(fieldType, value, envVal); err != nil {
|
||||
return fmt.Errorf("unmarshal field %q with environment variable, %w", fullName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case reflect.String:
|
||||
value.SetString(envVal)
|
||||
return nil
|
||||
default:
|
||||
return u.processFieldPrimitiveWithJSONNumber(fieldType, value, json.Number(envVal), opts, fullName)
|
||||
@@ -900,7 +906,7 @@ func (u *Unmarshaler) processNamedFieldWithValueFromString(fieldType reflect.Typ
|
||||
valueKind.String())
|
||||
}
|
||||
|
||||
if !stringx.Contains(options, checkValue) {
|
||||
if !slices.Contains(options, checkValue) {
|
||||
return fmt.Errorf(`value "%s" for field %q is not defined in options "%v"`,
|
||||
mapValue, key, options)
|
||||
}
|
||||
@@ -1189,35 +1195,6 @@ func join(elem ...string) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func makeStringSlice(refValue reflect.Value) reflect.Value {
|
||||
if refValue.Len() != 1 {
|
||||
return refValue
|
||||
}
|
||||
|
||||
element := refValue.Index(0)
|
||||
if element.Kind() != reflect.String {
|
||||
return refValue
|
||||
}
|
||||
|
||||
val, ok := element.Interface().(string)
|
||||
if !ok {
|
||||
return refValue
|
||||
}
|
||||
|
||||
splits := strings.Split(val, comma)
|
||||
if len(splits) <= 1 {
|
||||
return refValue
|
||||
}
|
||||
|
||||
slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits))
|
||||
for i, split := range splits {
|
||||
// allow empty strings
|
||||
slice.Index(i).Set(reflect.ValueOf(split))
|
||||
}
|
||||
|
||||
return slice
|
||||
}
|
||||
|
||||
func newInitError(name string) error {
|
||||
return fmt.Errorf("field %q is not set", name)
|
||||
}
|
||||
|
||||
@@ -203,6 +203,20 @@ func TestUnmarshalDuration(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalDurationUnexpectedError(t *testing.T) {
|
||||
type inner struct {
|
||||
Duration time.Duration `key:"duration"`
|
||||
}
|
||||
content := "{\"duration\": 1}"
|
||||
var m = map[string]any{}
|
||||
err := jsonx.Unmarshal([]byte(content), &m)
|
||||
assert.NoError(t, err)
|
||||
var in inner
|
||||
err = UnmarshalKey(m, &in)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expect string")
|
||||
}
|
||||
|
||||
func TestUnmarshalDurationDefault(t *testing.T) {
|
||||
type inner struct {
|
||||
Int int `key:"int"`
|
||||
@@ -1462,9 +1476,7 @@ func TestUnmarshalIntSlice(t *testing.T) {
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]int{1, 2}, v.Ages)
|
||||
}
|
||||
ast.Error(unmarshaler.Unmarshal(m, &v))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1546,7 +1558,22 @@ func TestUnmarshalStringSliceFromString(t *testing.T) {
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{"", ""}, v.Names)
|
||||
ast.ElementsMatch([]string{","}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("slice from valid strings with comma", func(t *testing.T) {
|
||||
var v struct {
|
||||
Names []string `key:"names"`
|
||||
}
|
||||
m := map[string]any{
|
||||
"names": []string{"aa,bb"},
|
||||
}
|
||||
|
||||
ast := assert.New(t)
|
||||
unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray())
|
||||
if ast.NoError(unmarshaler.Unmarshal(m, &v)) {
|
||||
ast.ElementsMatch([]string{"aa,bb"}, v.Names)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -4652,6 +4679,23 @@ func TestUnmarshal_EnvInt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvInt64(t *testing.T) {
|
||||
type Value struct {
|
||||
Age int64 `key:"age,env=TEST_NAME_INT64"`
|
||||
}
|
||||
|
||||
const (
|
||||
envName = "TEST_NAME_INT64"
|
||||
envVal = "88"
|
||||
)
|
||||
t.Setenv(envName, envVal)
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, int64(88), v.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvIntOverwrite(t *testing.T) {
|
||||
type Value struct {
|
||||
Age int `key:"age,env=TEST_NAME_INT"`
|
||||
@@ -4757,20 +4801,33 @@ func TestUnmarshal_EnvBoolBad(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvDuration(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
const (
|
||||
envName = "TEST_NAME_DURATION"
|
||||
envVal = "1s"
|
||||
)
|
||||
t.Setenv(envName, envVal)
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, v.Duration)
|
||||
}
|
||||
t.Run("valid duration", func(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, v.Duration)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ptr of duration", func(t *testing.T) {
|
||||
type Value struct {
|
||||
Duration *time.Duration `key:"duration,env=TEST_NAME_DURATION"`
|
||||
}
|
||||
|
||||
var v Value
|
||||
if assert.NoError(t, UnmarshalKey(emptyMap, &v)) {
|
||||
assert.Equal(t, time.Second, *v.Duration)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshal_EnvDurationBadValue(t *testing.T) {
|
||||
@@ -5982,6 +6039,16 @@ func TestUnmarshal_Unmarshaler(t *testing.T) {
|
||||
}, &v))
|
||||
assert.Nil(t, v.Foo)
|
||||
})
|
||||
|
||||
t.Run("json.Number", func(t *testing.T) {
|
||||
v := struct {
|
||||
Foo *mockUnmarshaler `json:"name"`
|
||||
}{}
|
||||
m := map[string]any{
|
||||
"name": json.Number("123"),
|
||||
}
|
||||
assert.Error(t, UnmarshalJsonMap(m, &v))
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseJsonStringValue(t *testing.T) {
|
||||
@@ -6016,6 +6083,105 @@ func TestParseJsonStringValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// issue #5033, string type
|
||||
func TestUnmarshalFromEnvString(t *testing.T) {
|
||||
t.Setenv("STRING_ENV", "dev")
|
||||
|
||||
t.Run("by value", func(t *testing.T) {
|
||||
type (
|
||||
Env string
|
||||
Config struct {
|
||||
Env Env `json:",env=STRING_ENV,default=prod"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env("dev"), c.Env)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("by ptr", func(t *testing.T) {
|
||||
type (
|
||||
Env string
|
||||
Config struct {
|
||||
Env *Env `json:",env=STRING_ENV,default=prod"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env("dev"), *c.Env)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// issue #5033, bool type
|
||||
func TestUnmarshalFromEnvBool(t *testing.T) {
|
||||
t.Setenv("BOOL_ENV", "true")
|
||||
|
||||
t.Run("by value", func(t *testing.T) {
|
||||
type (
|
||||
Env bool
|
||||
Config struct {
|
||||
Env Env `json:",env=BOOL_ENV,default=false"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(true), c.Env)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("by ptr", func(t *testing.T) {
|
||||
type (
|
||||
Env bool
|
||||
Config struct {
|
||||
Env *Env `json:",env=BOOL_ENV,default=false"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(true), *c.Env)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// issue #5033, customized int type
|
||||
func TestUnmarshalFromEnvInt(t *testing.T) {
|
||||
t.Setenv("INT_ENV", "2")
|
||||
|
||||
t.Run("by value", func(t *testing.T) {
|
||||
type (
|
||||
Env int
|
||||
Config struct {
|
||||
Env Env `json:",env=INT_ENV,default=0"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(2), c.Env)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("by ptr", func(t *testing.T) {
|
||||
type (
|
||||
Env int
|
||||
Config struct {
|
||||
Env *Env `json:",env=INT_ENV,default=0"`
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
if assert.NoError(t, UnmarshalJsonMap(map[string]any{}, &c)) {
|
||||
assert.Equal(t, Env(2), *c.Env)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkDefaultValue(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var a struct {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -91,6 +92,15 @@ func ValidatePtr(v reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertToString(val any, fullName string) (string, error) {
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expect string for field %s, but got type %T", fullName, val)
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func convertTypeFromString(kind reflect.Kind, str string) (any, error) {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
@@ -573,6 +583,10 @@ func toFloat64(v any) (float64, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func toReflectValue(tp reflect.Type, v any) reflect.Value {
|
||||
return reflect.ValueOf(v).Convert(Deref(tp))
|
||||
}
|
||||
|
||||
func usingDifferentKeys(key string, field reflect.StructField) bool {
|
||||
if len(field.Tag) > 0 {
|
||||
if _, ok := field.Tag.Lookup(key); !ok {
|
||||
@@ -634,11 +648,11 @@ func validateValueInOptions(val any, options []string) error {
|
||||
if len(options) > 0 {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if !stringx.Contains(options, v) {
|
||||
if !slices.Contains(options, v) {
|
||||
return fmt.Errorf(`error: value %q is not defined in options "%v"`, v, options)
|
||||
}
|
||||
default:
|
||||
if !stringx.Contains(options, Repr(v)) {
|
||||
if !slices.Contains(options, Repr(v)) {
|
||||
return fmt.Errorf(`error: value "%v" is not defined in options "%v"`, val, options)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,13 @@
|
||||
package mathx
|
||||
|
||||
// MaxInt returns the larger one of a and b.
|
||||
// Deprecated: use builtin max instead.
|
||||
func MaxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return max(a, b)
|
||||
}
|
||||
|
||||
// MinInt returns the smaller one of a and b.
|
||||
// Deprecated: use builtin min instead.
|
||||
func MinInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
return min(a, b)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package mr
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -142,89 +145,6 @@ func MapReduceChan[T, U, V any](source <-chan T, mapper MapperFunc[T, U], reduce
|
||||
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
|
||||
}
|
||||
|
||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||
options := buildOptions(opts...)
|
||||
// output is used to write the final result
|
||||
output := make(chan V)
|
||||
defer func() {
|
||||
// reducer can only write once, if more, panic
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
// collector is used to collect data from mapper, and consume in reducer
|
||||
collector := make(chan U, options.workers)
|
||||
// if done is closed, all mappers and reducer should stop processing
|
||||
done := make(chan struct{})
|
||||
writer := newGuardedWriter(options.ctx, output, done)
|
||||
var closeOnce sync.Once
|
||||
// use atomic type to avoid data race
|
||||
var retErr errorx.AtomicError
|
||||
finish := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(done)
|
||||
close(output)
|
||||
})
|
||||
}
|
||||
cancel := once(func(err error) {
|
||||
if err != nil {
|
||||
retErr.Set(err)
|
||||
} else {
|
||||
retErr.Set(ErrCancelWithNil)
|
||||
}
|
||||
|
||||
drain(source)
|
||||
finish()
|
||||
})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(r)
|
||||
}
|
||||
finish()
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
}()
|
||||
|
||||
go executeMappers(mapperContext[T, U]{
|
||||
ctx: options.ctx,
|
||||
mapper: func(item T, w Writer[U]) {
|
||||
mapper(item, w, cancel)
|
||||
},
|
||||
source: source,
|
||||
panicChan: panicChan,
|
||||
collector: collector,
|
||||
doneChan: done,
|
||||
workers: options.workers,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-options.ctx.Done():
|
||||
cancel(context.DeadlineExceeded)
|
||||
err = context.DeadlineExceeded
|
||||
case v := <-panicChan.channel:
|
||||
// drain output here, otherwise for loop panic in defer
|
||||
drain(output)
|
||||
panic(v)
|
||||
case v, ok := <-output:
|
||||
if e := retErr.Load(); e != nil {
|
||||
err = e
|
||||
} else if ok {
|
||||
val = v
|
||||
} else {
|
||||
err = ErrReduceNoOutput
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// MapReduceVoid maps all elements generated from given generate,
|
||||
// and reduce the output elements with given reducer.
|
||||
func MapReduceVoid[T, U any](generate GenerateFunc[T], mapper MapperFunc[T, U],
|
||||
@@ -266,12 +186,16 @@ func buildOptions(opts ...Option) *mapReduceOptions {
|
||||
return options
|
||||
}
|
||||
|
||||
func buildPanicInfo(r any, stack []byte) string {
|
||||
return fmt.Sprintf("%+v\n\n%s", r, strings.TrimSpace(string(stack)))
|
||||
}
|
||||
|
||||
func buildSource[T any](generate GenerateFunc[T], panicChan *onceChan) chan T {
|
||||
source := make(chan T)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(r)
|
||||
panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||
}
|
||||
close(source)
|
||||
}()
|
||||
@@ -318,7 +242,7 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt32(&failed, 1)
|
||||
mCtx.panicChan.write(r)
|
||||
mCtx.panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||
}
|
||||
wg.Done()
|
||||
<-pool
|
||||
@@ -330,6 +254,89 @@ func executeMappers[T, U any](mCtx mapperContext[T, U]) {
|
||||
}
|
||||
}
|
||||
|
||||
// mapReduceWithPanicChan maps all elements from source, and reduce the output elements with given reducer.
|
||||
func mapReduceWithPanicChan[T, U, V any](source <-chan T, panicChan *onceChan, mapper MapperFunc[T, U],
|
||||
reducer ReducerFunc[U, V], opts ...Option) (val V, err error) {
|
||||
options := buildOptions(opts...)
|
||||
// output is used to write the final result
|
||||
output := make(chan V)
|
||||
defer func() {
|
||||
// reducer can only write once, if more, panic
|
||||
for range output {
|
||||
panic("more than one element written in reducer")
|
||||
}
|
||||
}()
|
||||
|
||||
// collector is used to collect data from mapper, and consume in reducer
|
||||
collector := make(chan U, options.workers)
|
||||
// if done is closed, all mappers and reducer should stop processing
|
||||
done := make(chan struct{})
|
||||
writer := newGuardedWriter(options.ctx, output, done)
|
||||
var closeOnce sync.Once
|
||||
// use atomic type to avoid data race
|
||||
var retErr errorx.AtomicError
|
||||
finish := func() {
|
||||
closeOnce.Do(func() {
|
||||
close(done)
|
||||
close(output)
|
||||
})
|
||||
}
|
||||
cancel := once(func(err error) {
|
||||
if err != nil {
|
||||
retErr.Set(err)
|
||||
} else {
|
||||
retErr.Set(ErrCancelWithNil)
|
||||
}
|
||||
|
||||
drain(source)
|
||||
finish()
|
||||
})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
drain(collector)
|
||||
if r := recover(); r != nil {
|
||||
panicChan.write(buildPanicInfo(r, debug.Stack()))
|
||||
}
|
||||
finish()
|
||||
}()
|
||||
|
||||
reducer(collector, writer, cancel)
|
||||
}()
|
||||
|
||||
go executeMappers(mapperContext[T, U]{
|
||||
ctx: options.ctx,
|
||||
mapper: func(item T, w Writer[U]) {
|
||||
mapper(item, w, cancel)
|
||||
},
|
||||
source: source,
|
||||
panicChan: panicChan,
|
||||
collector: collector,
|
||||
doneChan: done,
|
||||
workers: options.workers,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-options.ctx.Done():
|
||||
cancel(context.DeadlineExceeded)
|
||||
err = context.DeadlineExceeded
|
||||
case v := <-panicChan.channel:
|
||||
// drain output here, otherwise for loop panic in defer
|
||||
drain(output)
|
||||
panic(v)
|
||||
case v, ok := <-output:
|
||||
if e := retErr.Load(); e != nil {
|
||||
err = e
|
||||
} else if ok {
|
||||
val = v
|
||||
} else {
|
||||
err = ErrReduceNoOutput
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func newOptions() *mapReduceOptions {
|
||||
return &mapReduceOptions{
|
||||
ctx: context.Background(),
|
||||
|
||||
@@ -3,6 +3,7 @@ package mr
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"runtime"
|
||||
@@ -39,6 +40,36 @@ func TestFinish(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestFinishWithPartialErrors(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
errDummy := errors.New("dummy")
|
||||
|
||||
t.Run("one error", func(t *testing.T) {
|
||||
err := Finish(func() error {
|
||||
return errDummy
|
||||
}, func() error {
|
||||
return nil
|
||||
}, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
|
||||
t.Run("two errors", func(t *testing.T) {
|
||||
err := Finish(func() error {
|
||||
return errDummy
|
||||
}, func() error {
|
||||
return errDummy
|
||||
}, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFinishNone(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
@@ -118,11 +149,28 @@ func TestForEach(t *testing.T) {
|
||||
|
||||
assert.Equal(t, tasks/2, int(count))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("all", func(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
func TestPanics(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
const tasks = 1000
|
||||
verify := func(t *testing.T, r any) {
|
||||
panicStr := fmt.Sprintf("%v", r)
|
||||
assert.Contains(t, panicStr, "foo")
|
||||
assert.Contains(t, panicStr, "goroutine")
|
||||
assert.Contains(t, panicStr, "runtime/debug.Stack")
|
||||
panic(r)
|
||||
}
|
||||
|
||||
t.Run("ForEach run panics", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
verify(t, r)
|
||||
}
|
||||
}()
|
||||
|
||||
assert.PanicsWithValue(t, "foo", func() {
|
||||
ForEach(func(source chan<- int) {
|
||||
for i := 0; i < tasks; i++ {
|
||||
source <- i
|
||||
@@ -132,28 +180,31 @@ func TestForEach(t *testing.T) {
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeneratePanic(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
t.Run("ForEach generate panics", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
verify(t, r)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("all", func(t *testing.T) {
|
||||
assert.PanicsWithValue(t, "foo", func() {
|
||||
ForEach(func(source chan<- int) {
|
||||
panic("foo")
|
||||
}, func(item int) {
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMapperPanic(t *testing.T) {
|
||||
defer goleak.VerifyNone(t)
|
||||
|
||||
const tasks = 1000
|
||||
var run int32
|
||||
t.Run("all", func(t *testing.T) {
|
||||
assert.PanicsWithValue(t, "foo", func() {
|
||||
t.Run("Mapper panics", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
verify(t, r)
|
||||
}
|
||||
}()
|
||||
|
||||
_, _ = MapReduce(func(source chan<- int) {
|
||||
for i := 0; i < tasks; i++ {
|
||||
source <- i
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package prof
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
@@ -28,46 +27,15 @@ type (
|
||||
|
||||
const flushInterval = 5 * time.Minute
|
||||
|
||||
var (
|
||||
pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
updated := func() bool {
|
||||
pc.lock.RLock()
|
||||
defer pc.lock.RUnlock()
|
||||
|
||||
slot, ok := pc.slots[name]
|
||||
if ok {
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
return ok
|
||||
}()
|
||||
|
||||
if !updated {
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
pc.slots[name] = &profileSlot{
|
||||
lifecount: 1,
|
||||
lastcount: 1,
|
||||
lifecycle: int64(duration),
|
||||
lastcycle: int64(duration),
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
once.Do(flushRepeatly)
|
||||
var pc = &profileCenter{
|
||||
slots: make(map[string]*profileSlot),
|
||||
}
|
||||
|
||||
func flushRepeatly() {
|
||||
func init() {
|
||||
flushRepeatedly()
|
||||
}
|
||||
|
||||
func flushRepeatedly() {
|
||||
threading.GoSafe(func() {
|
||||
for {
|
||||
time.Sleep(flushInterval)
|
||||
@@ -76,42 +44,64 @@ func flushRepeatly() {
|
||||
})
|
||||
}
|
||||
|
||||
func report(name string, duration time.Duration) {
|
||||
slot := loadOrStoreSlot(name, duration)
|
||||
|
||||
atomic.AddInt64(&slot.lifecount, 1)
|
||||
atomic.AddInt64(&slot.lastcount, 1)
|
||||
atomic.AddInt64(&slot.lifecycle, int64(duration))
|
||||
atomic.AddInt64(&slot.lastcycle, int64(duration))
|
||||
}
|
||||
|
||||
func loadOrStoreSlot(name string, duration time.Duration) *profileSlot {
|
||||
pc.lock.RLock()
|
||||
slot, ok := pc.slots[name]
|
||||
pc.lock.RUnlock()
|
||||
|
||||
if ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
|
||||
// double-check
|
||||
if slot, ok = pc.slots[name]; ok {
|
||||
return slot
|
||||
}
|
||||
|
||||
slot = &profileSlot{}
|
||||
pc.slots[name] = slot
|
||||
return slot
|
||||
}
|
||||
|
||||
func generateReport() string {
|
||||
var buffer bytes.Buffer
|
||||
buffer.WriteString("Profiling report\n")
|
||||
var data [][]string
|
||||
var builder strings.Builder
|
||||
builder.WriteString("Profiling report\n")
|
||||
builder.WriteString("QUEUE,LIFECOUNT,LIFECYCLE,LASTCOUNT,LASTCYCLE\n")
|
||||
|
||||
calcFn := func(total, count int64) string {
|
||||
if count == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
return (time.Duration(total) / time.Duration(count)).String()
|
||||
}
|
||||
|
||||
func() {
|
||||
pc.lock.Lock()
|
||||
defer pc.lock.Unlock()
|
||||
pc.lock.Lock()
|
||||
for key, slot := range pc.slots {
|
||||
builder.WriteString(fmt.Sprintf("%s,%d,%s,%d,%s\n",
|
||||
key,
|
||||
slot.lifecount,
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
slot.lastcount,
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
))
|
||||
|
||||
for key, slot := range pc.slots {
|
||||
data = append(data, []string{
|
||||
key,
|
||||
strconv.FormatInt(slot.lifecount, 10),
|
||||
calcFn(slot.lifecycle, slot.lifecount),
|
||||
strconv.FormatInt(slot.lastcount, 10),
|
||||
calcFn(slot.lastcycle, slot.lastcount),
|
||||
})
|
||||
// reset last cycle stats
|
||||
atomic.StoreInt64(&slot.lastcount, 0)
|
||||
atomic.StoreInt64(&slot.lastcycle, 0)
|
||||
}
|
||||
pc.lock.Unlock()
|
||||
|
||||
// reset the data for last cycle
|
||||
slot.lastcount = 0
|
||||
slot.lastcycle = 0
|
||||
}
|
||||
}()
|
||||
|
||||
table := tablewriter.NewWriter(&buffer)
|
||||
table.SetHeader([]string{"QUEUE", "LIFECOUNT", "LIFECYCLE", "LASTCOUNT", "LASTCYCLE"})
|
||||
table.SetBorder(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return buffer.String()
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
once.Do(func() {})
|
||||
assert.NotContains(t, generateReport(), "foo")
|
||||
report("foo", time.Second)
|
||||
assert.Contains(t, generateReport(), "foo")
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"runtime/metrics"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -28,10 +30,29 @@ func displayStatsWithWriter(writer io.Writer, interval ...time.Duration) {
|
||||
ticker := time.NewTicker(duration)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
var (
|
||||
alloc, totalAlloc, sys uint64
|
||||
samples = []metrics.Sample{
|
||||
{Name: "/memory/classes/heap/objects:bytes"},
|
||||
{Name: "/gc/heap/allocs:bytes"},
|
||||
{Name: "/memory/classes/total:bytes"},
|
||||
}
|
||||
)
|
||||
metrics.Read(samples)
|
||||
|
||||
if samples[0].Value.Kind() == metrics.KindUint64 {
|
||||
alloc = samples[0].Value.Uint64()
|
||||
}
|
||||
if samples[1].Value.Kind() == metrics.KindUint64 {
|
||||
totalAlloc = samples[1].Value.Uint64()
|
||||
}
|
||||
if samples[2].Value.Kind() == metrics.KindUint64 {
|
||||
sys = samples[2].Value.Uint64()
|
||||
}
|
||||
var stats debug.GCStats
|
||||
debug.ReadGCStats(&stats)
|
||||
fmt.Fprintf(writer, "Goroutines: %d, Alloc: %vm, TotalAlloc: %vm, Sys: %vm, NumGC: %v\n",
|
||||
runtime.NumGoroutine(), m.Alloc/mega, m.TotalAlloc/mega, m.Sys/mega, m.NumGC)
|
||||
runtime.NumGoroutine(), alloc/mega, totalAlloc/mega, sys/mega, stats.NumGC)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"github.com/zeromicro/go-zero/internal/devserver"
|
||||
"github.com/zeromicro/go-zero/internal/profiling"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +39,8 @@ type (
|
||||
Telemetry trace.Config `json:",optional"`
|
||||
DevServer DevServerConfig `json:",optional"`
|
||||
Shutdown proc.ShutdownConf `json:",optional"`
|
||||
// Profiling is the configuration for continuous profiling.
|
||||
Profiling profiling.Config `json:",optional"`
|
||||
}
|
||||
)
|
||||
|
||||
@@ -70,7 +73,9 @@ func (sc ServiceConf) SetUp() error {
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
|
||||
devserver.StartAgent(sc.DevServer)
|
||||
profiling.Start(sc.Profiling)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/threading"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,7 @@ type (
|
||||
// NewServiceGroup returns a ServiceGroup.
|
||||
func NewServiceGroup() *ServiceGroup {
|
||||
sg := new(ServiceGroup)
|
||||
sg.stopOnce = syncx.Once(sg.doStop)
|
||||
sg.stopOnce = sync.OnceFunc(sg.doStop)
|
||||
return sg
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
const (
|
||||
clusterNameKey = "CLUSTER_NAME"
|
||||
testEnv = "test.v"
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,7 +44,7 @@ func Report(msg string) {
|
||||
if fn != nil {
|
||||
reported := lessExecutor.DoOrDiscard(func() {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(timeFormat)))
|
||||
builder.WriteString(fmt.Sprintln(time.Now().Format(time.DateTime)))
|
||||
if len(clusterName) > 0 {
|
||||
builder.WriteString(fmt.Sprintf("cluster: %s\n", clusterName))
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package stat
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"runtime/metrics"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -56,8 +57,28 @@ func bToMb(b uint64) float32 {
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
var (
|
||||
alloc, totalAlloc, sys uint64
|
||||
samples = []metrics.Sample{
|
||||
{Name: "/memory/classes/heap/objects:bytes"},
|
||||
{Name: "/gc/heap/allocs:bytes"},
|
||||
{Name: "/memory/classes/total:bytes"},
|
||||
}
|
||||
stats debug.GCStats
|
||||
)
|
||||
metrics.Read(samples)
|
||||
|
||||
if samples[0].Value.Kind() == metrics.KindUint64 {
|
||||
alloc = samples[0].Value.Uint64()
|
||||
}
|
||||
if samples[1].Value.Kind() == metrics.KindUint64 {
|
||||
totalAlloc = samples[1].Value.Uint64()
|
||||
}
|
||||
if samples[2].Value.Kind() == metrics.KindUint64 {
|
||||
sys = samples[2].Value.Uint64()
|
||||
}
|
||||
debug.ReadGCStats(&stats)
|
||||
|
||||
logx.Statf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d",
|
||||
CpuUsage(), bToMb(m.Alloc), bToMb(m.TotalAlloc), bToMb(m.Sys), m.NumGC)
|
||||
CpuUsage(), bToMb(alloc), bToMb(totalAlloc), bToMb(sys), stats.NumGC)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:generate mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
|
||||
package mon
|
||||
|
||||
import (
|
||||
@@ -6,7 +7,8 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/executors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,10 +29,7 @@ type (
|
||||
|
||||
// NewBulkInserter returns a BulkInserter.
|
||||
func NewBulkInserter(coll Collection, interval ...time.Duration) (*BulkInserter, error) {
|
||||
cloneColl, err := coll.Clone()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cloneColl := coll.Clone()
|
||||
|
||||
inserter := &dbInserter{
|
||||
collection: cloneColl,
|
||||
@@ -64,8 +63,16 @@ func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
})
|
||||
}
|
||||
|
||||
type collectionInserter interface {
|
||||
InsertMany(
|
||||
ctx context.Context,
|
||||
documents interface{},
|
||||
opts ...options.Lister[options.InsertManyOptions],
|
||||
) (*mongo.InsertManyResult, error)
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
collection *mongo.Collection
|
||||
collection collectionInserter
|
||||
documents []any
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
@@ -1,26 +1,131 @@
|
||||
package mon
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestBulkInserter(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
bulk, err := NewBulkInserter(createModel(mt).Collection)
|
||||
assert.Equal(t, err, nil)
|
||||
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result.InsertedIDs))
|
||||
})
|
||||
bulk.Insert(bson.D{{Key: "foo", Value: "bar"}})
|
||||
bulk.Insert(bson.D{{Key: "foo", Value: "baz"}})
|
||||
bulk.Flush()
|
||||
func TestBulkInserter_InsertAndFlush(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().Clone().Return(&mongo.Collection{})
|
||||
bulkInserter, err := NewBulkInserter(mockCollection, time.Second)
|
||||
assert.NoError(t, err)
|
||||
bulkInserter.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result.InsertedIDs))
|
||||
})
|
||||
doc := map[string]interface{}{"name": "test"}
|
||||
bulkInserter.Insert(doc)
|
||||
bulkInserter.Flush()
|
||||
}
|
||||
|
||||
func TestBulkInserter_SetResultHandler(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().Clone().Return(nil)
|
||||
bulkInserter, err := NewBulkInserter(mockCollection)
|
||||
assert.NoError(t, err)
|
||||
mockHandler := func(result *mongo.InsertManyResult, err error) {}
|
||||
bulkInserter.SetResultHandler(mockHandler)
|
||||
}
|
||||
|
||||
func TestDbInserter_RemoveAll(t *testing.T) {
|
||||
inserter := &dbInserter{}
|
||||
inserter.documents = []interface{}{}
|
||||
docs := inserter.RemoveAll()
|
||||
assert.NotNil(t, docs)
|
||||
assert.Empty(t, inserter.documents)
|
||||
}
|
||||
|
||||
func Test_dbInserter_Execute(t *testing.T) {
|
||||
type fields struct {
|
||||
collection collectionInserter
|
||||
documents []any
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockcollectionInserter(ctrl)
|
||||
type args struct {
|
||||
objs any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
mock func()
|
||||
}{
|
||||
{
|
||||
name: "empty doc",
|
||||
fields: fields{
|
||||
collection: nil,
|
||||
documents: nil,
|
||||
resultHandler: nil,
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 0),
|
||||
},
|
||||
mock: func() {},
|
||||
},
|
||||
{
|
||||
name: "result handler",
|
||||
fields: fields{
|
||||
collection: mockCollection,
|
||||
resultHandler: func(result *mongo.InsertManyResult, err error) {
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 1),
|
||||
},
|
||||
mock: func() {
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, errors.New("error"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "normal error handler",
|
||||
fields: fields{
|
||||
collection: mockCollection,
|
||||
resultHandler: nil,
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 1),
|
||||
},
|
||||
mock: func() {
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, errors.New("error"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no error",
|
||||
fields: fields{
|
||||
collection: mockCollection,
|
||||
resultHandler: nil,
|
||||
},
|
||||
args: args{
|
||||
objs: make([]any, 1),
|
||||
},
|
||||
mock: func() {
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.mock()
|
||||
in := &dbInserter{
|
||||
collection: tt.fields.collection,
|
||||
documents: tt.fields.documents,
|
||||
resultHandler: tt.fields.resultHandler,
|
||||
}
|
||||
in.Execute(tt.args.objs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
var clientManager = syncx.NewResourceManager()
|
||||
@@ -29,13 +29,13 @@ func Inject(key string, client *mongo.Client) {
|
||||
|
||||
func getClient(url string, opts ...Option) (*mongo.Client, error) {
|
||||
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
|
||||
o := mopt.Client().ApplyURI(url)
|
||||
o := options.Client().ApplyURI(url)
|
||||
opts = append([]Option{defaultTimeoutOption()}, opts...)
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
cli, err := mongo.Connect(context.Background(), o)
|
||||
cli, err := mongo.Connect(o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,19 +4,13 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
)
|
||||
|
||||
func init() {
|
||||
_ = mtest.Setup()
|
||||
}
|
||||
|
||||
func TestClientManger_getClient(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
Inject(mtest.ClusterURI(), mt.Client)
|
||||
cli, err := getClient(mtest.ClusterURI())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, mt.Client, cli)
|
||||
})
|
||||
c := &mongo.Client{}
|
||||
Inject("foo", c)
|
||||
cli, err := getClient("foo")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, c, cli)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:generate mockgen -package mon -destination collection_mock.go -source collection.go Collection,monCollection
|
||||
package mon
|
||||
|
||||
import (
|
||||
@@ -8,9 +9,9 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -47,79 +48,79 @@ type (
|
||||
// Collection defines a MongoDB collection.
|
||||
Collection interface {
|
||||
// Aggregate executes an aggregation pipeline.
|
||||
Aggregate(ctx context.Context, pipeline any, opts ...*mopt.AggregateOptions) (
|
||||
Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (
|
||||
*mongo.Cursor, error)
|
||||
// BulkWrite performs a bulk write operation.
|
||||
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...*mopt.BulkWriteOptions) (
|
||||
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (
|
||||
*mongo.BulkWriteResult, error)
|
||||
// Clone creates a copy of this collection with the same settings.
|
||||
Clone(opts ...*mopt.CollectionOptions) (*mongo.Collection, error)
|
||||
Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection
|
||||
// CountDocuments returns the number of documents in the collection that match the filter.
|
||||
CountDocuments(ctx context.Context, filter any, opts ...*mopt.CountOptions) (int64, error)
|
||||
CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error)
|
||||
// Database returns the database that this collection is a part of.
|
||||
Database() *mongo.Database
|
||||
// DeleteMany deletes documents from the collection that match the filter.
|
||||
DeleteMany(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (
|
||||
DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (
|
||||
*mongo.DeleteResult, error)
|
||||
// DeleteOne deletes at most one document from the collection that matches the filter.
|
||||
DeleteOne(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (
|
||||
DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (
|
||||
*mongo.DeleteResult, error)
|
||||
// Distinct returns a list of distinct values for the given key across the collection.
|
||||
Distinct(ctx context.Context, fieldName string, filter any,
|
||||
opts ...*mopt.DistinctOptions) ([]any, error)
|
||||
opts ...options.Lister[options.DistinctOptions]) (*mongo.DistinctResult, error)
|
||||
// Drop drops this collection from database.
|
||||
Drop(ctx context.Context) error
|
||||
Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error
|
||||
// EstimatedDocumentCount returns an estimate of the count of documents in a collection
|
||||
// using collection metadata.
|
||||
EstimatedDocumentCount(ctx context.Context, opts ...*mopt.EstimatedDocumentCountOptions) (int64, error)
|
||||
EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error)
|
||||
// Find finds the documents matching the provided filter.
|
||||
Find(ctx context.Context, filter any, opts ...*mopt.FindOptions) (*mongo.Cursor, error)
|
||||
Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error)
|
||||
// FindOne returns up to one document that matches the provided filter.
|
||||
FindOne(ctx context.Context, filter any, opts ...*mopt.FindOneOptions) (
|
||||
FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) (
|
||||
*mongo.SingleResult, error)
|
||||
// FindOneAndDelete returns at most one document that matches the filter. If the filter
|
||||
// matches multiple documents, only the first document is deleted.
|
||||
FindOneAndDelete(ctx context.Context, filter any, opts ...*mopt.FindOneAndDeleteOptions) (
|
||||
FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) (
|
||||
*mongo.SingleResult, error)
|
||||
// FindOneAndReplace returns at most one document that matches the filter. If the filter
|
||||
// matches multiple documents, FindOneAndReplace returns the first document in the
|
||||
// collection that matches the filter.
|
||||
FindOneAndReplace(ctx context.Context, filter, replacement any,
|
||||
opts ...*mopt.FindOneAndReplaceOptions) (*mongo.SingleResult, error)
|
||||
opts ...options.Lister[options.FindOneAndReplaceOptions]) (*mongo.SingleResult, error)
|
||||
// FindOneAndUpdate returns at most one document that matches the filter. If the filter
|
||||
// matches multiple documents, FindOneAndUpdate returns the first document in the
|
||||
// collection that matches the filter.
|
||||
FindOneAndUpdate(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.FindOneAndUpdateOptions) (*mongo.SingleResult, error)
|
||||
opts ...options.Lister[options.FindOneAndUpdateOptions]) (*mongo.SingleResult, error)
|
||||
// Indexes returns the index view for this collection.
|
||||
Indexes() mongo.IndexView
|
||||
// InsertMany inserts the provided documents.
|
||||
InsertMany(ctx context.Context, documents []any, opts ...*mopt.InsertManyOptions) (
|
||||
InsertMany(ctx context.Context, documents []any, opts ...options.Lister[options.InsertManyOptions]) (
|
||||
*mongo.InsertManyResult, error)
|
||||
// InsertOne inserts the provided document.
|
||||
InsertOne(ctx context.Context, document any, opts ...*mopt.InsertOneOptions) (
|
||||
InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (
|
||||
*mongo.InsertOneResult, error)
|
||||
// ReplaceOne replaces at most one document that matches the filter.
|
||||
ReplaceOne(ctx context.Context, filter, replacement any,
|
||||
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error)
|
||||
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error)
|
||||
// UpdateByID updates a single document matching the provided filter.
|
||||
UpdateByID(ctx context.Context, id, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
|
||||
// UpdateMany updates the provided documents.
|
||||
UpdateMany(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
|
||||
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error)
|
||||
// UpdateOne updates a single document matching the provided filter.
|
||||
UpdateOne(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error)
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
|
||||
// Watch returns a change stream cursor used to receive notifications of changes to the collection.
|
||||
Watch(ctx context.Context, pipeline any, opts ...*mopt.ChangeStreamOptions) (
|
||||
Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
|
||||
*mongo.ChangeStream, error)
|
||||
}
|
||||
|
||||
decoratedCollection struct {
|
||||
*mongo.Collection
|
||||
name string
|
||||
brk breaker.Breaker
|
||||
Collection monCollection
|
||||
name string
|
||||
brk breaker.Breaker
|
||||
}
|
||||
|
||||
keepablePromise struct {
|
||||
@@ -137,7 +138,7 @@ func newCollection(collection *mongo.Collection, brk breaker.Breaker) Collection
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any,
|
||||
opts ...*mopt.AggregateOptions) (cur *mongo.Cursor, err error) {
|
||||
opts ...options.Lister[options.AggregateOptions]) (cur *mongo.Cursor, err error) {
|
||||
ctx, span := startSpan(ctx, aggregate)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -157,7 +158,7 @@ func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel,
|
||||
opts ...*mopt.BulkWriteOptions) (res *mongo.BulkWriteResult, err error) {
|
||||
opts ...options.Lister[options.BulkWriteOptions]) (res *mongo.BulkWriteResult, err error) {
|
||||
ctx, span := startSpan(ctx, bulkWrite)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -176,8 +177,12 @@ func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.Writ
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
|
||||
return c.Collection.Clone(opts...)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any,
|
||||
opts ...*mopt.CountOptions) (count int64, err error) {
|
||||
opts ...options.Lister[options.CountOptions]) (count int64, err error) {
|
||||
ctx, span := startSpan(ctx, countDocuments)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -196,8 +201,12 @@ func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any,
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Database() *mongo.Database {
|
||||
return c.Collection.Database()
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any,
|
||||
opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) {
|
||||
opts ...options.Lister[options.DeleteManyOptions]) (res *mongo.DeleteResult, err error) {
|
||||
ctx, span := startSpan(ctx, deleteMany)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -217,7 +226,7 @@ func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) DeleteOne(ctx context.Context, filter any,
|
||||
opts ...*mopt.DeleteOptions) (res *mongo.DeleteResult, err error) {
|
||||
opts ...options.Lister[options.DeleteOneOptions]) (res *mongo.DeleteResult, err error) {
|
||||
ctx, span := startSpan(ctx, deleteOne)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -237,7 +246,7 @@ func (c *decoratedCollection) DeleteOne(ctx context.Context, filter any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, filter any,
|
||||
opts ...*mopt.DistinctOptions) (val []any, err error) {
|
||||
opts ...options.Lister[options.DistinctOptions]) (res *mongo.DistinctResult, err error) {
|
||||
ctx, span := startSpan(ctx, distinct)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -249,15 +258,20 @@ func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, fi
|
||||
c.logDurationSimple(ctx, distinct, startTime, err)
|
||||
}()
|
||||
|
||||
val, err = c.Collection.Distinct(ctx, fieldName, filter, opts...)
|
||||
res = c.Collection.Distinct(ctx, fieldName, filter, opts...)
|
||||
err = res.Err()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
|
||||
return c.Collection.Drop(ctx, opts...)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context,
|
||||
opts ...*mopt.EstimatedDocumentCountOptions) (val int64, err error) {
|
||||
opts ...options.Lister[options.EstimatedDocumentCountOptions]) (val int64, err error) {
|
||||
ctx, span := startSpan(ctx, estimatedDocumentCount)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -277,7 +291,7 @@ func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Find(ctx context.Context, filter any,
|
||||
opts ...*mopt.FindOptions) (cur *mongo.Cursor, err error) {
|
||||
opts ...options.Lister[options.FindOptions]) (cur *mongo.Cursor, err error) {
|
||||
ctx, span := startSpan(ctx, find)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -297,7 +311,7 @@ func (c *decoratedCollection) Find(ctx context.Context, filter any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindOne(ctx context.Context, filter any,
|
||||
opts ...*mopt.FindOneOptions) (res *mongo.SingleResult, err error) {
|
||||
opts ...options.Lister[options.FindOneOptions]) (res *mongo.SingleResult, err error) {
|
||||
ctx, span := startSpan(ctx, findOne)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -318,7 +332,7 @@ func (c *decoratedCollection) FindOne(ctx context.Context, filter any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter any,
|
||||
opts ...*mopt.FindOneAndDeleteOptions) (res *mongo.SingleResult, err error) {
|
||||
opts ...options.Lister[options.FindOneAndDeleteOptions]) (res *mongo.SingleResult, err error) {
|
||||
ctx, span := startSpan(ctx, findOneAndDelete)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -339,7 +353,7 @@ func (c *decoratedCollection) FindOneAndDelete(ctx context.Context, filter any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter any,
|
||||
replacement any, opts ...*mopt.FindOneAndReplaceOptions) (
|
||||
replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) (
|
||||
res *mongo.SingleResult, err error) {
|
||||
ctx, span := startSpan(ctx, findOneAndReplace)
|
||||
defer func() {
|
||||
@@ -361,7 +375,7 @@ func (c *decoratedCollection) FindOneAndReplace(ctx context.Context, filter any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.FindOneAndUpdateOptions) (res *mongo.SingleResult, err error) {
|
||||
opts ...options.Lister[options.FindOneAndUpdateOptions]) (res *mongo.SingleResult, err error) {
|
||||
ctx, span := startSpan(ctx, findOneAndUpdate)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -381,8 +395,12 @@ func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter, upda
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Indexes() mongo.IndexView {
|
||||
return c.Collection.Indexes()
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any,
|
||||
opts ...*mopt.InsertManyOptions) (res *mongo.InsertManyResult, err error) {
|
||||
opts ...options.Lister[options.InsertManyOptions]) (res *mongo.InsertManyResult, err error) {
|
||||
ctx, span := startSpan(ctx, insertMany)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -402,7 +420,7 @@ func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) InsertOne(ctx context.Context, document any,
|
||||
opts ...*mopt.InsertOneOptions) (res *mongo.InsertOneResult, err error) {
|
||||
opts ...options.Lister[options.InsertOneOptions]) (res *mongo.InsertOneResult, err error) {
|
||||
ctx, span := startSpan(ctx, insertOne)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -422,7 +440,7 @@ func (c *decoratedCollection) InsertOne(ctx context.Context, document any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter, replacement any,
|
||||
opts ...*mopt.ReplaceOptions) (res *mongo.UpdateResult, err error) {
|
||||
opts ...options.Lister[options.ReplaceOptions]) (res *mongo.UpdateResult, err error) {
|
||||
ctx, span := startSpan(ctx, replaceOne)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -442,7 +460,7 @@ func (c *decoratedCollection) ReplaceOne(ctx context.Context, filter, replacemen
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) UpdateByID(ctx context.Context, id, update any,
|
||||
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (res *mongo.UpdateResult, err error) {
|
||||
ctx, span := startSpan(ctx, updateByID)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -462,7 +480,7 @@ func (c *decoratedCollection) UpdateByID(ctx context.Context, id, update any,
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) UpdateMany(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
|
||||
opts ...options.Lister[options.UpdateManyOptions]) (res *mongo.UpdateResult, err error) {
|
||||
ctx, span := startSpan(ctx, updateMany)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -482,7 +500,7 @@ func (c *decoratedCollection) UpdateMany(ctx context.Context, filter, update any
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) UpdateOne(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (res *mongo.UpdateResult, err error) {
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (res *mongo.UpdateResult, err error) {
|
||||
ctx, span := startSpan(ctx, updateOne)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -501,6 +519,11 @@ func (c *decoratedCollection) UpdateOne(ctx context.Context, filter, update any,
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
|
||||
*mongo.ChangeStream, error) {
|
||||
return c.Collection.Watch(ctx, pipeline, opts...)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) logDuration(ctx context.Context, method string,
|
||||
startTime time.Duration, err error, docs ...any) {
|
||||
logDurationWithDocs(ctx, c.name, method, startTime, err, docs...)
|
||||
@@ -546,3 +569,71 @@ func isDupKeyError(err error) bool {
|
||||
|
||||
return e.HasErrorCode(duplicateKeyCode)
|
||||
}
|
||||
|
||||
// monCollection defines a MongoDB collection, used for unit test
|
||||
type monCollection interface {
|
||||
// Aggregate executes an aggregation pipeline.
|
||||
Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (
|
||||
*mongo.Cursor, error)
|
||||
// BulkWrite performs a bulk write operation.
|
||||
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (
|
||||
*mongo.BulkWriteResult, error)
|
||||
// Clone creates a copy of this collection with the same settings.
|
||||
Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection
|
||||
// CountDocuments returns the number of documents in the collection that match the filter.
|
||||
CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error)
|
||||
// Database returns the database that this collection is a part of.
|
||||
Database() *mongo.Database
|
||||
// DeleteMany deletes documents from the collection that match the filter.
|
||||
DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (
|
||||
*mongo.DeleteResult, error)
|
||||
// DeleteOne deletes at most one document from the collection that matches the filter.
|
||||
DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (
|
||||
*mongo.DeleteResult, error)
|
||||
// Distinct returns a list of distinct values for the given key across the collection.
|
||||
Distinct(ctx context.Context, fieldName string, filter any,
|
||||
opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult
|
||||
// Drop drops this collection from database.
|
||||
Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error
|
||||
// EstimatedDocumentCount returns an estimate of the count of documents in a collection
|
||||
// using collection metadata.
|
||||
EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error)
|
||||
// Find finds the documents matching the provided filter.
|
||||
Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error)
|
||||
// FindOne returns up to one document that matches the provided filter.
|
||||
FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult
|
||||
// FindOneAndDelete returns at most one document that matches the filter. If the filter
|
||||
// matches multiple documents, only the first document is deleted.
|
||||
FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult
|
||||
// FindOneAndReplace returns at most one document that matches the filter. If the filter
|
||||
// matches multiple documents, FindOneAndReplace returns the first document in the
|
||||
// collection that matches the filter.
|
||||
FindOneAndReplace(ctx context.Context, filter, replacement any,
|
||||
opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult
|
||||
// FindOneAndUpdate returns at most one document that matches the filter. If the filter
|
||||
// matches multiple documents, FindOneAndUpdate returns the first document in the
|
||||
// collection that matches the filter.
|
||||
FindOneAndUpdate(ctx context.Context, filter, update any,
|
||||
opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult
|
||||
// Indexes returns the index view for this collection.
|
||||
Indexes() mongo.IndexView
|
||||
// InsertMany inserts the provided documents.
|
||||
InsertMany(ctx context.Context, documents interface{}, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error)
|
||||
// InsertOne inserts the provided document.
|
||||
InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error)
|
||||
// ReplaceOne replaces at most one document that matches the filter.
|
||||
ReplaceOne(ctx context.Context, filter, replacement any,
|
||||
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error)
|
||||
// UpdateByID updates a single document matching the provided filter.
|
||||
UpdateByID(ctx context.Context, id, update any,
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
|
||||
// UpdateMany updates the provided documents.
|
||||
UpdateMany(ctx context.Context, filter, update any,
|
||||
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error)
|
||||
// UpdateOne updates a single document matching the provided filter.
|
||||
UpdateOne(ctx context.Context, filter, update any,
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
|
||||
// Watch returns a change stream cursor used to receive notifications of changes to the collection.
|
||||
Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
|
||||
*mongo.ChangeStream, error)
|
||||
}
|
||||
|
||||
952
core/stores/mon/collection_mock.go
Normal file
952
core/stores/mon/collection_mock.go
Normal file
@@ -0,0 +1,952 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: collection.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package mon -destination collection_mock.go -source collection.go Collection,monCollection
|
||||
//
|
||||
|
||||
// Package mon is a generated GoMock package.
|
||||
package mon
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
mongo "go.mongodb.org/mongo-driver/v2/mongo"
|
||||
options "go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockCollection is a mock of Collection interface.
|
||||
type MockCollection struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCollectionMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockCollectionMockRecorder is the mock recorder for MockCollection.
|
||||
type MockCollectionMockRecorder struct {
|
||||
mock *MockCollection
|
||||
}
|
||||
|
||||
// NewMockCollection creates a new mock instance.
|
||||
func NewMockCollection(ctrl *gomock.Controller) *MockCollection {
|
||||
mock := &MockCollection{ctrl: ctrl}
|
||||
mock.recorder = &MockCollectionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockCollection) EXPECT() *MockCollectionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Aggregate mocks base method.
|
||||
func (m *MockCollection) Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (*mongo.Cursor, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, pipeline}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Aggregate", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.Cursor)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Aggregate indicates an expected call of Aggregate.
|
||||
func (mr *MockCollectionMockRecorder) Aggregate(ctx, pipeline any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, pipeline}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockCollection)(nil).Aggregate), varargs...)
|
||||
}
|
||||
|
||||
// BulkWrite mocks base method.
|
||||
func (m *MockCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (*mongo.BulkWriteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, models}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "BulkWrite", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.BulkWriteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BulkWrite indicates an expected call of BulkWrite.
|
||||
func (mr *MockCollectionMockRecorder) BulkWrite(ctx, models any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, models}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWrite", reflect.TypeOf((*MockCollection)(nil).BulkWrite), varargs...)
|
||||
}
|
||||
|
||||
// Clone mocks base method.
|
||||
func (m *MockCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Clone", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.Collection)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Clone indicates an expected call of Clone.
|
||||
func (mr *MockCollectionMockRecorder) Clone(opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockCollection)(nil).Clone), opts...)
|
||||
}
|
||||
|
||||
// CountDocuments mocks base method.
|
||||
func (m *MockCollection) CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "CountDocuments", varargs...)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountDocuments indicates an expected call of CountDocuments.
|
||||
func (mr *MockCollectionMockRecorder) CountDocuments(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountDocuments", reflect.TypeOf((*MockCollection)(nil).CountDocuments), varargs...)
|
||||
}
|
||||
|
||||
// Database mocks base method.
|
||||
func (m *MockCollection) Database() *mongo.Database {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Database")
|
||||
ret0, _ := ret[0].(*mongo.Database)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Database indicates an expected call of Database.
|
||||
func (mr *MockCollectionMockRecorder) Database() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Database", reflect.TypeOf((*MockCollection)(nil).Database))
|
||||
}
|
||||
|
||||
// DeleteMany mocks base method.
|
||||
func (m *MockCollection) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (*mongo.DeleteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "DeleteMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.DeleteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteMany indicates an expected call of DeleteMany.
|
||||
func (mr *MockCollectionMockRecorder) DeleteMany(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMany", reflect.TypeOf((*MockCollection)(nil).DeleteMany), varargs...)
|
||||
}
|
||||
|
||||
// DeleteOne mocks base method.
|
||||
func (m *MockCollection) DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (*mongo.DeleteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "DeleteOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.DeleteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOne indicates an expected call of DeleteOne.
|
||||
func (mr *MockCollectionMockRecorder) DeleteOne(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOne", reflect.TypeOf((*MockCollection)(nil).DeleteOne), varargs...)
|
||||
}
|
||||
|
||||
// Distinct mocks base method.
|
||||
func (m *MockCollection) Distinct(ctx context.Context, fieldName string, filter any, opts ...options.Lister[options.DistinctOptions]) (*mongo.DistinctResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, fieldName, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Distinct", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.DistinctResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Distinct indicates an expected call of Distinct.
|
||||
func (mr *MockCollectionMockRecorder) Distinct(ctx, fieldName, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, fieldName, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Distinct", reflect.TypeOf((*MockCollection)(nil).Distinct), varargs...)
|
||||
}
|
||||
|
||||
// Drop mocks base method.
|
||||
func (m *MockCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Drop", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Drop indicates an expected call of Drop.
|
||||
func (mr *MockCollectionMockRecorder) Drop(ctx any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Drop", reflect.TypeOf((*MockCollection)(nil).Drop), varargs...)
|
||||
}
|
||||
|
||||
// EstimatedDocumentCount mocks base method.
|
||||
func (m *MockCollection) EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "EstimatedDocumentCount", varargs...)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// EstimatedDocumentCount indicates an expected call of EstimatedDocumentCount.
|
||||
func (mr *MockCollectionMockRecorder) EstimatedDocumentCount(ctx any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimatedDocumentCount", reflect.TypeOf((*MockCollection)(nil).EstimatedDocumentCount), varargs...)
|
||||
}
|
||||
|
||||
// Find mocks base method.
|
||||
func (m *MockCollection) Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Find", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.Cursor)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find.
|
||||
func (mr *MockCollectionMockRecorder) Find(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockCollection)(nil).Find), varargs...)
|
||||
}
|
||||
|
||||
// FindOne mocks base method.
|
||||
func (m *MockCollection) FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) (*mongo.SingleResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindOne indicates an expected call of FindOne.
|
||||
func (mr *MockCollectionMockRecorder) FindOne(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOne", reflect.TypeOf((*MockCollection)(nil).FindOne), varargs...)
|
||||
}
|
||||
|
||||
// FindOneAndDelete mocks base method.
|
||||
func (m *MockCollection) FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) (*mongo.SingleResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOneAndDelete", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindOneAndDelete indicates an expected call of FindOneAndDelete.
|
||||
func (mr *MockCollectionMockRecorder) FindOneAndDelete(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndDelete", reflect.TypeOf((*MockCollection)(nil).FindOneAndDelete), varargs...)
|
||||
}
|
||||
|
||||
// FindOneAndReplace mocks base method.
|
||||
func (m *MockCollection) FindOneAndReplace(ctx context.Context, filter, replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) (*mongo.SingleResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, replacement}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOneAndReplace", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindOneAndReplace indicates an expected call of FindOneAndReplace.
|
||||
func (mr *MockCollectionMockRecorder) FindOneAndReplace(ctx, filter, replacement any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, replacement}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndReplace", reflect.TypeOf((*MockCollection)(nil).FindOneAndReplace), varargs...)
|
||||
}
|
||||
|
||||
// FindOneAndUpdate mocks base method.
|
||||
func (m *MockCollection) FindOneAndUpdate(ctx context.Context, filter, update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) (*mongo.SingleResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOneAndUpdate", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindOneAndUpdate indicates an expected call of FindOneAndUpdate.
|
||||
func (mr *MockCollectionMockRecorder) FindOneAndUpdate(ctx, filter, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndUpdate", reflect.TypeOf((*MockCollection)(nil).FindOneAndUpdate), varargs...)
|
||||
}
|
||||
|
||||
// Indexes mocks base method.
|
||||
func (m *MockCollection) Indexes() mongo.IndexView {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Indexes")
|
||||
ret0, _ := ret[0].(mongo.IndexView)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Indexes indicates an expected call of Indexes.
|
||||
func (mr *MockCollectionMockRecorder) Indexes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Indexes", reflect.TypeOf((*MockCollection)(nil).Indexes))
|
||||
}
|
||||
|
||||
// InsertMany mocks base method.
|
||||
func (m *MockCollection) InsertMany(ctx context.Context, documents []any, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, documents}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "InsertMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.InsertManyResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertMany indicates an expected call of InsertMany.
|
||||
func (mr *MockCollectionMockRecorder) InsertMany(ctx, documents any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, documents}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockCollection)(nil).InsertMany), varargs...)
|
||||
}
|
||||
|
||||
// InsertOne mocks base method.
|
||||
func (m *MockCollection) InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, document}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "InsertOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.InsertOneResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertOne indicates an expected call of InsertOne.
|
||||
func (mr *MockCollectionMockRecorder) InsertOne(ctx, document any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, document}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOne", reflect.TypeOf((*MockCollection)(nil).InsertOne), varargs...)
|
||||
}
|
||||
|
||||
// ReplaceOne mocks base method.
|
||||
func (m *MockCollection) ReplaceOne(ctx context.Context, filter, replacement any, opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, replacement}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "ReplaceOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ReplaceOne indicates an expected call of ReplaceOne.
|
||||
func (mr *MockCollectionMockRecorder) ReplaceOne(ctx, filter, replacement any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, replacement}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceOne", reflect.TypeOf((*MockCollection)(nil).ReplaceOne), varargs...)
|
||||
}
|
||||
|
||||
// UpdateByID mocks base method.
|
||||
func (m *MockCollection) UpdateByID(ctx context.Context, id, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, id, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "UpdateByID", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateByID indicates an expected call of UpdateByID.
|
||||
func (mr *MockCollectionMockRecorder) UpdateByID(ctx, id, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, id, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateByID", reflect.TypeOf((*MockCollection)(nil).UpdateByID), varargs...)
|
||||
}
|
||||
|
||||
// UpdateMany mocks base method.
|
||||
func (m *MockCollection) UpdateMany(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "UpdateMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateMany indicates an expected call of UpdateMany.
|
||||
func (mr *MockCollectionMockRecorder) UpdateMany(ctx, filter, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMany", reflect.TypeOf((*MockCollection)(nil).UpdateMany), varargs...)
|
||||
}
|
||||
|
||||
// UpdateOne mocks base method.
|
||||
func (m *MockCollection) UpdateOne(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "UpdateOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateOne indicates an expected call of UpdateOne.
|
||||
func (mr *MockCollectionMockRecorder) UpdateOne(ctx, filter, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOne", reflect.TypeOf((*MockCollection)(nil).UpdateOne), varargs...)
|
||||
}
|
||||
|
||||
// Watch mocks base method.
|
||||
func (m *MockCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (*mongo.ChangeStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, pipeline}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Watch", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.ChangeStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Watch indicates an expected call of Watch.
|
||||
func (mr *MockCollectionMockRecorder) Watch(ctx, pipeline any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, pipeline}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockCollection)(nil).Watch), varargs...)
|
||||
}
|
||||
|
||||
// MockmonCollection is a mock of monCollection interface.
|
||||
type MockmonCollection struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockmonCollectionMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockmonCollectionMockRecorder is the mock recorder for MockmonCollection.
|
||||
type MockmonCollectionMockRecorder struct {
|
||||
mock *MockmonCollection
|
||||
}
|
||||
|
||||
// NewMockmonCollection creates a new mock instance.
|
||||
func NewMockmonCollection(ctrl *gomock.Controller) *MockmonCollection {
|
||||
mock := &MockmonCollection{ctrl: ctrl}
|
||||
mock.recorder = &MockmonCollectionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockmonCollection) EXPECT() *MockmonCollectionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Aggregate mocks base method.
|
||||
func (m *MockmonCollection) Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (*mongo.Cursor, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, pipeline}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Aggregate", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.Cursor)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Aggregate indicates an expected call of Aggregate.
|
||||
func (mr *MockmonCollectionMockRecorder) Aggregate(ctx, pipeline any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, pipeline}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockmonCollection)(nil).Aggregate), varargs...)
|
||||
}
|
||||
|
||||
// BulkWrite mocks base method.
|
||||
func (m *MockmonCollection) BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (*mongo.BulkWriteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, models}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "BulkWrite", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.BulkWriteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BulkWrite indicates an expected call of BulkWrite.
|
||||
func (mr *MockmonCollectionMockRecorder) BulkWrite(ctx, models any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, models}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWrite", reflect.TypeOf((*MockmonCollection)(nil).BulkWrite), varargs...)
|
||||
}
|
||||
|
||||
// Clone mocks base method.
|
||||
func (m *MockmonCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Clone", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.Collection)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Clone indicates an expected call of Clone.
|
||||
func (mr *MockmonCollectionMockRecorder) Clone(opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockmonCollection)(nil).Clone), opts...)
|
||||
}
|
||||
|
||||
// CountDocuments mocks base method.
|
||||
func (m *MockmonCollection) CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "CountDocuments", varargs...)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountDocuments indicates an expected call of CountDocuments.
|
||||
func (mr *MockmonCollectionMockRecorder) CountDocuments(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountDocuments", reflect.TypeOf((*MockmonCollection)(nil).CountDocuments), varargs...)
|
||||
}
|
||||
|
||||
// Database mocks base method.
|
||||
func (m *MockmonCollection) Database() *mongo.Database {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Database")
|
||||
ret0, _ := ret[0].(*mongo.Database)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Database indicates an expected call of Database.
|
||||
func (mr *MockmonCollectionMockRecorder) Database() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Database", reflect.TypeOf((*MockmonCollection)(nil).Database))
|
||||
}
|
||||
|
||||
// DeleteMany mocks base method.
|
||||
func (m *MockmonCollection) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (*mongo.DeleteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "DeleteMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.DeleteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteMany indicates an expected call of DeleteMany.
|
||||
func (mr *MockmonCollectionMockRecorder) DeleteMany(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMany", reflect.TypeOf((*MockmonCollection)(nil).DeleteMany), varargs...)
|
||||
}
|
||||
|
||||
// DeleteOne mocks base method.
|
||||
func (m *MockmonCollection) DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (*mongo.DeleteResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "DeleteOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.DeleteResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOne indicates an expected call of DeleteOne.
|
||||
func (mr *MockmonCollectionMockRecorder) DeleteOne(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOne", reflect.TypeOf((*MockmonCollection)(nil).DeleteOne), varargs...)
|
||||
}
|
||||
|
||||
// Distinct mocks base method.
|
||||
func (m *MockmonCollection) Distinct(ctx context.Context, fieldName string, filter any, opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, fieldName, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Distinct", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.DistinctResult)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Distinct indicates an expected call of Distinct.
|
||||
func (mr *MockmonCollectionMockRecorder) Distinct(ctx, fieldName, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, fieldName, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Distinct", reflect.TypeOf((*MockmonCollection)(nil).Distinct), varargs...)
|
||||
}
|
||||
|
||||
// Drop mocks base method.
|
||||
func (m *MockmonCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Drop", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Drop indicates an expected call of Drop.
|
||||
func (mr *MockmonCollectionMockRecorder) Drop(ctx any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Drop", reflect.TypeOf((*MockmonCollection)(nil).Drop), varargs...)
|
||||
}
|
||||
|
||||
// EstimatedDocumentCount mocks base method.
|
||||
func (m *MockmonCollection) EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "EstimatedDocumentCount", varargs...)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// EstimatedDocumentCount indicates an expected call of EstimatedDocumentCount.
|
||||
func (mr *MockmonCollectionMockRecorder) EstimatedDocumentCount(ctx any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimatedDocumentCount", reflect.TypeOf((*MockmonCollection)(nil).EstimatedDocumentCount), varargs...)
|
||||
}
|
||||
|
||||
// Find mocks base method.
|
||||
func (m *MockmonCollection) Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Find", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.Cursor)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find.
|
||||
func (mr *MockmonCollectionMockRecorder) Find(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockmonCollection)(nil).Find), varargs...)
|
||||
}
|
||||
|
||||
// FindOne mocks base method.
|
||||
func (m *MockmonCollection) FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindOne indicates an expected call of FindOne.
|
||||
func (mr *MockmonCollectionMockRecorder) FindOne(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOne", reflect.TypeOf((*MockmonCollection)(nil).FindOne), varargs...)
|
||||
}
|
||||
|
||||
// FindOneAndDelete mocks base method.
|
||||
func (m *MockmonCollection) FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOneAndDelete", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindOneAndDelete indicates an expected call of FindOneAndDelete.
|
||||
func (mr *MockmonCollectionMockRecorder) FindOneAndDelete(ctx, filter any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndDelete", reflect.TypeOf((*MockmonCollection)(nil).FindOneAndDelete), varargs...)
|
||||
}
|
||||
|
||||
// FindOneAndReplace mocks base method.
|
||||
func (m *MockmonCollection) FindOneAndReplace(ctx context.Context, filter, replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, replacement}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOneAndReplace", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindOneAndReplace indicates an expected call of FindOneAndReplace.
|
||||
func (mr *MockmonCollectionMockRecorder) FindOneAndReplace(ctx, filter, replacement any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, replacement}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndReplace", reflect.TypeOf((*MockmonCollection)(nil).FindOneAndReplace), varargs...)
|
||||
}
|
||||
|
||||
// FindOneAndUpdate mocks base method.
|
||||
func (m *MockmonCollection) FindOneAndUpdate(ctx context.Context, filter, update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "FindOneAndUpdate", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.SingleResult)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindOneAndUpdate indicates an expected call of FindOneAndUpdate.
|
||||
func (mr *MockmonCollectionMockRecorder) FindOneAndUpdate(ctx, filter, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndUpdate", reflect.TypeOf((*MockmonCollection)(nil).FindOneAndUpdate), varargs...)
|
||||
}
|
||||
|
||||
// Indexes mocks base method.
|
||||
func (m *MockmonCollection) Indexes() mongo.IndexView {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Indexes")
|
||||
ret0, _ := ret[0].(mongo.IndexView)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Indexes indicates an expected call of Indexes.
|
||||
func (mr *MockmonCollectionMockRecorder) Indexes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Indexes", reflect.TypeOf((*MockmonCollection)(nil).Indexes))
|
||||
}
|
||||
|
||||
// InsertMany mocks base method.
|
||||
func (m *MockmonCollection) InsertMany(ctx context.Context, documents any, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, documents}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "InsertMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.InsertManyResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertMany indicates an expected call of InsertMany.
|
||||
func (mr *MockmonCollectionMockRecorder) InsertMany(ctx, documents any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, documents}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockmonCollection)(nil).InsertMany), varargs...)
|
||||
}
|
||||
|
||||
// InsertOne mocks base method.
|
||||
func (m *MockmonCollection) InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, document}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "InsertOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.InsertOneResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertOne indicates an expected call of InsertOne.
|
||||
func (mr *MockmonCollectionMockRecorder) InsertOne(ctx, document any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, document}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOne", reflect.TypeOf((*MockmonCollection)(nil).InsertOne), varargs...)
|
||||
}
|
||||
|
||||
// ReplaceOne mocks base method.
|
||||
func (m *MockmonCollection) ReplaceOne(ctx context.Context, filter, replacement any, opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, replacement}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "ReplaceOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ReplaceOne indicates an expected call of ReplaceOne.
|
||||
func (mr *MockmonCollectionMockRecorder) ReplaceOne(ctx, filter, replacement any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, replacement}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceOne", reflect.TypeOf((*MockmonCollection)(nil).ReplaceOne), varargs...)
|
||||
}
|
||||
|
||||
// UpdateByID mocks base method.
|
||||
func (m *MockmonCollection) UpdateByID(ctx context.Context, id, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, id, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "UpdateByID", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateByID indicates an expected call of UpdateByID.
|
||||
func (mr *MockmonCollectionMockRecorder) UpdateByID(ctx, id, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, id, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateByID", reflect.TypeOf((*MockmonCollection)(nil).UpdateByID), varargs...)
|
||||
}
|
||||
|
||||
// UpdateMany mocks base method.
|
||||
func (m *MockmonCollection) UpdateMany(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "UpdateMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateMany indicates an expected call of UpdateMany.
|
||||
func (mr *MockmonCollectionMockRecorder) UpdateMany(ctx, filter, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMany", reflect.TypeOf((*MockmonCollection)(nil).UpdateMany), varargs...)
|
||||
}
|
||||
|
||||
// UpdateOne mocks base method.
|
||||
func (m *MockmonCollection) UpdateOne(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, filter, update}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "UpdateOne", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.UpdateResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateOne indicates an expected call of UpdateOne.
|
||||
func (mr *MockmonCollectionMockRecorder) UpdateOne(ctx, filter, update any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, filter, update}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOne", reflect.TypeOf((*MockmonCollection)(nil).UpdateOne), varargs...)
|
||||
}
|
||||
|
||||
// Watch mocks base method.
|
||||
func (m *MockmonCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (*mongo.ChangeStream, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, pipeline}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Watch", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.ChangeStream)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Watch indicates an expected call of Watch.
|
||||
func (mr *MockmonCollectionMockRecorder) Watch(ctx, pipeline any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, pipeline}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockmonCollection)(nil).Watch), varargs...)
|
||||
}
|
||||
@@ -10,12 +10,10 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/logx/logtest"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
var errDummy = errors.New("dummy")
|
||||
@@ -68,471 +66,345 @@ func TestKeepPromise_keep(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewCollection(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
coll := mt.Coll
|
||||
assert.NotNil(t, coll)
|
||||
col := newCollection(coll, breaker.GetBreaker("localhost"))
|
||||
assert.Equal(t, t.Name()+"/test", col.(*decoratedCollection).name)
|
||||
})
|
||||
_ = newCollection(&mongo.Collection{}, breaker.GetBreaker("localhost"))
|
||||
}
|
||||
|
||||
func TestCollection_Aggregate(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
coll := mt.Coll
|
||||
assert.NotNil(t, coll)
|
||||
col := newCollection(coll, breaker.GetBreaker("localhost"))
|
||||
ns := mt.Coll.Database().Name() + "." + mt.Coll.Name()
|
||||
aggRes := mtest.CreateCursorResponse(1, ns, mtest.FirstBatch)
|
||||
mt.AddMockResponses(aggRes)
|
||||
assert.Equal(t, t.Name()+"/test", col.(*decoratedCollection).name)
|
||||
cursor, err := col.Aggregate(context.Background(), mongo.Pipeline{}, mopt.Aggregate())
|
||||
assert.Nil(t, err)
|
||||
cursor.Close(context.Background())
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.Aggregate(context.Background(), []interface{}{}, options.Aggregate())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestCollection_BulkWrite(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
res, err := c.BulkWrite(context.Background(), []mongo.WriteModel{
|
||||
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, res)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.BulkWrite(context.Background(), []mongo.WriteModel{
|
||||
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
|
||||
})
|
||||
assert.Equal(t, errDummy, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().BulkWrite(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.BulkWriteResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.BulkWrite(context.Background(), []mongo.WriteModel{
|
||||
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.BulkWrite(context.Background(), []mongo.WriteModel{
|
||||
mongo.NewInsertOneModel().SetDocument(bson.D{{Key: "foo", Value: 1}}),
|
||||
})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_CountDocuments(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "n", Value: 1},
|
||||
}))
|
||||
res, err := c.CountDocuments(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), res)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.CountDocuments(context.Background(), bson.D{{Key: "foo", Value: 1}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().CountDocuments(gomock.Any(), gomock.Any(), gomock.Any()).Return(int64(0), nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
res, err := c.CountDocuments(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), res)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.CountDocuments(context.Background(), bson.D{{Key: "foo", Value: 1}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestDecoratedCollection_DeleteMany(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
res, err := c.DeleteMany(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), res.DeletedCount)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: 1}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().DeleteMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.DeleteMany(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: 1}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_Distinct(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "values", Value: []int{1}}})
|
||||
resp, err := c.Distinct(context.Background(), "foo", bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(resp))
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.Distinct(context.Background(), "foo", bson.D{{Key: "foo", Value: 1}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Distinct(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DistinctResult{})
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.Distinct(context.Background(), "foo", bson.D{})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.Distinct(context.Background(), "foo", bson.D{{Key: "foo", Value: 1}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_EstimatedDocumentCount(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "n", Value: 1}})
|
||||
res, err := c.EstimatedDocumentCount(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), res)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.EstimatedDocumentCount(context.Background())
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().EstimatedDocumentCount(gomock.Any(), gomock.Any()).Return(int64(0), nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.EstimatedDocumentCount(context.Background())
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.EstimatedDocumentCount(context.Background())
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_Find(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
find := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
})
|
||||
getMore := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
killCursors := mtest.CreateCursorResponse(
|
||||
0,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch)
|
||||
mt.AddMockResponses(find, getMore, killCursors)
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
cursor, err := c.Find(context.Background(), filter, mopt.Find())
|
||||
assert.Nil(t, err)
|
||||
defer cursor.Close(context.Background())
|
||||
|
||||
var val []struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
assert.Nil(t, cursor.All(context.Background(), &val))
|
||||
assert.Equal(t, 2, len(val))
|
||||
assert.Equal(t, "John", val[0].Name)
|
||||
assert.Equal(t, "Mary", val[1].Name)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.Find(context.Background(), filter, mopt.Find())
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
_, err := c.Find(context.Background(), filter, options.Find())
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.Find(context.Background(), filter, options.Find())
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_FindOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
find := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
})
|
||||
getMore := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
killCursors := mtest.CreateCursorResponse(
|
||||
0,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch)
|
||||
mt.AddMockResponses(find, getMore, killCursors)
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
resp, err := c.FindOne(context.Background(), filter)
|
||||
assert.Nil(t, err)
|
||||
var val struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
assert.Nil(t, resp.Decode(&val))
|
||||
assert.Equal(t, "John", val.Name)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOne(context.Background(), filter)
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
_, err := c.FindOne(context.Background(), filter)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOne(context.Background(), filter)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_FindOneAndDelete(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
filter := bson.D{}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{}...))
|
||||
_, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete())
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
|
||||
}...))
|
||||
resp, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete())
|
||||
assert.Nil(t, err)
|
||||
var val struct {
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
assert.Nil(t, resp.Decode(&val))
|
||||
assert.Equal(t, "John", val.Name)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
filter := bson.D{}
|
||||
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
_, err := c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete())
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
_, err = c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete())
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_FindOneAndReplace(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{}...))
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
replacement := bson.D{{Key: "x", Value: 2}}
|
||||
opts := mopt.FindOneAndReplace().SetUpsert(true)
|
||||
_, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "value", Value: bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
}}})
|
||||
resp, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
|
||||
assert.Nil(t, err)
|
||||
var val struct {
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
assert.Nil(t, resp.Decode(&val))
|
||||
assert.Equal(t, "John", val.Name)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
replacement := bson.D{{Key: "x", Value: 2}}
|
||||
opts := options.FindOneAndReplace().SetUpsert(true)
|
||||
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
_, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOneAndReplace(context.Background(), filter, replacement, opts)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_FindOneAndUpdate(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}})
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
update := bson.D{{Key: "$x", Value: 2}}
|
||||
opts := mopt.FindOneAndUpdate().SetUpsert(true)
|
||||
_, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
|
||||
mt.AddMockResponses(bson.D{{Key: "ok", Value: 1}, {Key: "value", Value: bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
}}})
|
||||
resp, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
|
||||
assert.Nil(t, err)
|
||||
var val struct {
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
assert.Nil(t, resp.Decode(&val))
|
||||
assert.Equal(t, "John", val.Name)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
filter := bson.D{{Key: "x", Value: 1}}
|
||||
update := bson.D{{Key: "$x", Value: 2}}
|
||||
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
|
||||
opts := options.FindOneAndUpdate().SetUpsert(true)
|
||||
_, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
|
||||
assert.Equal(t, mongo.ErrNoDocuments, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_InsertOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
res, err := c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, res)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
res, err := c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, res)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.InsertOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_InsertMany(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
res, err := c.InsertMany(context.Background(), []any{
|
||||
bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "foo", Value: "baz"}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, res)
|
||||
assert.Equal(t, 2, len(res.InsertedIDs))
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.InsertMany(context.Background(), []any{bson.D{{Key: "foo", Value: "bar"}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().InsertMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertManyResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.InsertMany(context.Background(), []any{
|
||||
bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "foo", Value: "baz"}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.InsertMany(context.Background(), []any{bson.D{{Key: "foo", Value: "bar"}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_DeleteOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
res, err := c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), res.DeletedCount)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.DeleteOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_DeleteMany(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
res, err := c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), res.DeletedCount)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().DeleteMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.DeleteMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_ReplaceOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
res, err := c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "foo", Value: "baz"}},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), res.MatchedCount)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "foo", Value: "baz"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "foo", Value: "baz"}},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.ReplaceOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "foo", Value: "baz"}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_UpdateOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
resp, err := c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), resp.MatchedCount)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.UpdateOne(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_UpdateByID(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
resp, err := c.UpdateByID(context.Background(), primitive.NewObjectID(),
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), resp.MatchedCount)
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.UpdateByID(context.Background(), primitive.NewObjectID(),
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.UpdateByID(context.Background(), bson.NewObjectID(),
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.UpdateByID(context.Background(), bson.NewObjectID(),
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollection_UpdateMany(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
resp, err := c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), resp.MatchedCount)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Nil(t, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.UpdateMany(context.Background(), bson.D{{Key: "foo", Value: "bar"}},
|
||||
bson.D{{Key: "$set", Value: bson.D{{Key: "baz", Value: "qux"}}}})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
func TestCollection_Watch(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Watch(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.ChangeStream{}, nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
_, err := c.Watch(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestCollection_Clone(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Clone(gomock.Any()).Return(nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
cc := c.Clone()
|
||||
assert.Nil(t, cc)
|
||||
}
|
||||
|
||||
func TestCollection_Database(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Database().Return(nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
db := c.Database()
|
||||
assert.Nil(t, db)
|
||||
}
|
||||
|
||||
func TestCollection_Drop(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
mockCollection.EXPECT().Drop(gomock.Any()).Return(nil)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
err := c.Drop(context.Background())
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestCollection_Indexes(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
idx := mongo.IndexView{}
|
||||
mockCollection.EXPECT().Indexes().Return(idx)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
index := c.Indexes()
|
||||
assert.Equal(t, index, idx)
|
||||
}
|
||||
|
||||
func TestDecoratedCollection_LogDuration(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
c := decoratedCollection{
|
||||
Collection: mt.Coll,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := NewMockmonCollection(ctrl)
|
||||
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
|
||||
|
||||
buf := logtest.NewCollector(t)
|
||||
|
||||
@@ -585,14 +457,6 @@ func TestAcceptable(t *testing.T) {
|
||||
{"NilDocument", mongo.ErrNilDocument, true},
|
||||
{"NilCursor", mongo.ErrNilCursor, true},
|
||||
{"EmptySlice", mongo.ErrEmptySlice, true},
|
||||
{"SessionEnded", session.ErrSessionEnded, true},
|
||||
{"NoTransactStarted", session.ErrNoTransactStarted, true},
|
||||
{"TransactInProgress", session.ErrTransactInProgress, true},
|
||||
{"AbortAfterCommit", session.ErrAbortAfterCommit, true},
|
||||
{"AbortTwice", session.ErrAbortTwice, true},
|
||||
{"CommitAfterAbort", session.ErrCommitAfterAbort, true},
|
||||
{"UnackWCUnsupported", session.ErrUnackWCUnsupported, true},
|
||||
{"SnapshotTransaction", session.ErrSnapshotTransaction, true},
|
||||
{"DuplicateKeyError", mongo.WriteException{WriteErrors: []mongo.WriteError{{Code: duplicateKeyCode}}}, true},
|
||||
{"OtherError", errors.New("other error"), false},
|
||||
}
|
||||
@@ -623,6 +487,14 @@ func TestIsDupKeyError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newTestCollection(collection monCollection, brk breaker.Breaker) *decoratedCollection {
|
||||
return &decoratedCollection{
|
||||
Collection: collection,
|
||||
name: "test",
|
||||
brk: brk,
|
||||
}
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
accepted bool
|
||||
reason string
|
||||
|
||||
63
core/stores/mon/collectioninserter_mock.go
Normal file
63
core/stores/mon/collectioninserter_mock.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: bulkinserter.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
|
||||
//
|
||||
|
||||
// Package mon is a generated GoMock package.
|
||||
package mon
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
mongo "go.mongodb.org/mongo-driver/v2/mongo"
|
||||
options "go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockcollectionInserter is a mock of collectionInserter interface.
|
||||
type MockcollectionInserter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockcollectionInserterMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockcollectionInserterMockRecorder is the mock recorder for MockcollectionInserter.
|
||||
type MockcollectionInserterMockRecorder struct {
|
||||
mock *MockcollectionInserter
|
||||
}
|
||||
|
||||
// NewMockcollectionInserter creates a new mock instance.
|
||||
func NewMockcollectionInserter(ctrl *gomock.Controller) *MockcollectionInserter {
|
||||
mock := &MockcollectionInserter{ctrl: ctrl}
|
||||
mock.recorder = &MockcollectionInserterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockcollectionInserter) EXPECT() *MockcollectionInserterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// InsertMany mocks base method.
|
||||
func (m *MockcollectionInserter) InsertMany(ctx context.Context, documents any, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, documents}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "InsertMany", varargs...)
|
||||
ret0, _ := ret[0].(*mongo.InsertManyResult)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// InsertMany indicates an expected call of InsertMany.
|
||||
func (mr *MockcollectionInserterMockRecorder) InsertMany(ctx, documents any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, documents}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockcollectionInserter)(nil).InsertMany), varargs...)
|
||||
}
|
||||
19
core/stores/mon/migration-2.0.md
Normal file
19
core/stores/mon/migration-2.0.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Migrating from 1.x to 2.0
|
||||
|
||||
To upgrade imports of the Go Driver from v1 to v2, we recommend using [marwan-at-work/mod
|
||||
](https://github.com/marwan-at-work/mod):
|
||||
|
||||
```
|
||||
mod upgrade --mod-name=go.mongodb.org/mongo-driver
|
||||
```
|
||||
|
||||
# Notice
|
||||
After completing the mod upgrade, code changes are typically unnecessary in the vast majority of cases. However, if your project references packages including but not limited to those listed below, you'll need to manually replace them, as these libraries are no longer present in the v2 version.
|
||||
```go
|
||||
go.mongodb.org/mongo-driver/bson/bsonrw => go.mongodb.org/mongo-driver/v2/bson
|
||||
go.mongodb.org/mongo-driver/bson/bsoncodec => go.mongodb.org/mongo-driver/v2/bson
|
||||
go.mongodb.org/mongo-driver/bson/primitive => go.mongodb.org/mongo-driver/v2/bson
|
||||
```
|
||||
|
||||
See the following resources to learn more about upgrading from version 1.x to 2.0.:
|
||||
https://raw.githubusercontent.com/mongodb/mongo-go-driver/refs/heads/master/docs/migration-2.0.md
|
||||
@@ -1,3 +1,4 @@
|
||||
//go:generate mockgen -package mon -destination model_mock.go -source model.go monClient monSession
|
||||
package mon
|
||||
|
||||
import (
|
||||
@@ -7,8 +8,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,15 +25,15 @@ type (
|
||||
Model struct {
|
||||
Collection
|
||||
name string
|
||||
cli *mongo.Client
|
||||
cli monClient
|
||||
brk breaker.Breaker
|
||||
opts []Option
|
||||
}
|
||||
|
||||
wrappedSession struct {
|
||||
mongo.Session
|
||||
name string
|
||||
brk breaker.Breaker
|
||||
Session struct {
|
||||
session monSession
|
||||
name string
|
||||
brk breaker.Breaker
|
||||
}
|
||||
)
|
||||
|
||||
@@ -61,14 +62,14 @@ func newModel(name string, cli *mongo.Client, coll Collection, brk breaker.Break
|
||||
return &Model{
|
||||
name: name,
|
||||
Collection: coll,
|
||||
cli: cli,
|
||||
cli: &wrappedMonClient{c: cli},
|
||||
brk: brk,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// StartSession starts a new session.
|
||||
func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session, err error) {
|
||||
func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (sess *Session, err error) {
|
||||
starTime := timex.Now()
|
||||
defer func() {
|
||||
logDuration(context.Background(), m.name, startSession, starTime, err)
|
||||
@@ -79,15 +80,16 @@ func (m *Model) StartSession(opts ...*mopt.SessionOptions) (sess mongo.Session,
|
||||
return nil, sessionErr
|
||||
}
|
||||
|
||||
return &wrappedSession{
|
||||
Session: session,
|
||||
return &Session{
|
||||
session: session,
|
||||
name: m.name,
|
||||
brk: m.brk,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Aggregate executes an aggregation pipeline.
|
||||
func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...*mopt.AggregateOptions) error {
|
||||
func (m *Model) Aggregate(ctx context.Context, v, pipeline any,
|
||||
opts ...options.Lister[options.AggregateOptions]) error {
|
||||
cur, err := m.Collection.Aggregate(ctx, pipeline, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -98,7 +100,8 @@ func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...*mopt.Ag
|
||||
}
|
||||
|
||||
// DeleteMany deletes documents that match the filter.
|
||||
func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (int64, error) {
|
||||
func (m *Model) DeleteMany(ctx context.Context, filter any,
|
||||
opts ...options.Lister[options.DeleteManyOptions]) (int64, error) {
|
||||
res, err := m.Collection.DeleteMany(ctx, filter, opts...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -108,7 +111,8 @@ func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...*mopt.Delete
|
||||
}
|
||||
|
||||
// DeleteOne deletes the first document that matches the filter.
|
||||
func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...*mopt.DeleteOptions) (int64, error) {
|
||||
func (m *Model) DeleteOne(ctx context.Context, filter any,
|
||||
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
|
||||
res, err := m.Collection.DeleteOne(ctx, filter, opts...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -118,7 +122,8 @@ func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...*mopt.DeleteO
|
||||
}
|
||||
|
||||
// Find finds documents that match the filter.
|
||||
func (m *Model) Find(ctx context.Context, v, filter any, opts ...*mopt.FindOptions) error {
|
||||
func (m *Model) Find(ctx context.Context, v, filter any,
|
||||
opts ...options.Lister[options.FindOptions]) error {
|
||||
cur, err := m.Collection.Find(ctx, filter, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -129,7 +134,8 @@ func (m *Model) Find(ctx context.Context, v, filter any, opts ...*mopt.FindOptio
|
||||
}
|
||||
|
||||
// FindOne finds the first document that matches the filter.
|
||||
func (m *Model) FindOne(ctx context.Context, v, filter any, opts ...*mopt.FindOneOptions) error {
|
||||
func (m *Model) FindOne(ctx context.Context, v, filter any,
|
||||
opts ...options.Lister[options.FindOneOptions]) error {
|
||||
res, err := m.Collection.FindOne(ctx, filter, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -140,7 +146,7 @@ func (m *Model) FindOne(ctx context.Context, v, filter any, opts ...*mopt.FindOn
|
||||
|
||||
// FindOneAndDelete finds a single document and deletes it.
|
||||
func (m *Model) FindOneAndDelete(ctx context.Context, v, filter any,
|
||||
opts ...*mopt.FindOneAndDeleteOptions) error {
|
||||
opts ...options.Lister[options.FindOneAndDeleteOptions]) error {
|
||||
res, err := m.Collection.FindOneAndDelete(ctx, filter, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -151,7 +157,7 @@ func (m *Model) FindOneAndDelete(ctx context.Context, v, filter any,
|
||||
|
||||
// FindOneAndReplace finds a single document and replaces it.
|
||||
func (m *Model) FindOneAndReplace(ctx context.Context, v, filter, replacement any,
|
||||
opts ...*mopt.FindOneAndReplaceOptions) error {
|
||||
opts ...options.Lister[options.FindOneAndReplaceOptions]) error {
|
||||
res, err := m.Collection.FindOneAndReplace(ctx, filter, replacement, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -162,7 +168,7 @@ func (m *Model) FindOneAndReplace(ctx context.Context, v, filter, replacement an
|
||||
|
||||
// FindOneAndUpdate finds a single document and updates it.
|
||||
func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter, update any,
|
||||
opts ...*mopt.FindOneAndUpdateOptions) error {
|
||||
opts ...options.Lister[options.FindOneAndUpdateOptions]) error {
|
||||
res, err := m.Collection.FindOneAndUpdate(ctx, filter, update, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -171,8 +177,8 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter, update any,
|
||||
return res.Decode(v)
|
||||
}
|
||||
|
||||
// AbortTransaction implements the mongo.Session interface.
|
||||
func (w *wrappedSession) AbortTransaction(ctx context.Context) (err error) {
|
||||
// AbortTransaction implements the mongo.session interface.
|
||||
func (w *Session) AbortTransaction(ctx context.Context) (err error) {
|
||||
ctx, span := startSpan(ctx, abortTransaction)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -184,12 +190,12 @@ func (w *wrappedSession) AbortTransaction(ctx context.Context) (err error) {
|
||||
logDuration(ctx, w.name, abortTransaction, starTime, err)
|
||||
}()
|
||||
|
||||
return w.Session.AbortTransaction(ctx)
|
||||
return w.session.AbortTransaction(ctx)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
// CommitTransaction implements the mongo.Session interface.
|
||||
func (w *wrappedSession) CommitTransaction(ctx context.Context) (err error) {
|
||||
// CommitTransaction implements the mongo.session interface.
|
||||
func (w *Session) CommitTransaction(ctx context.Context) (err error) {
|
||||
ctx, span := startSpan(ctx, commitTransaction)
|
||||
defer func() {
|
||||
endSpan(span, err)
|
||||
@@ -201,15 +207,15 @@ func (w *wrappedSession) CommitTransaction(ctx context.Context) (err error) {
|
||||
logDuration(ctx, w.name, commitTransaction, starTime, err)
|
||||
}()
|
||||
|
||||
return w.Session.CommitTransaction(ctx)
|
||||
return w.session.CommitTransaction(ctx)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
// WithTransaction implements the mongo.Session interface.
|
||||
func (w *wrappedSession) WithTransaction(
|
||||
// WithTransaction implements the mongo.session interface.
|
||||
func (w *Session) WithTransaction(
|
||||
ctx context.Context,
|
||||
fn func(sessCtx mongo.SessionContext) (any, error),
|
||||
opts ...*mopt.TransactionOptions,
|
||||
fn func(sessCtx context.Context) (any, error),
|
||||
opts ...options.Lister[options.TransactionOptions],
|
||||
) (res any, err error) {
|
||||
ctx, span := startSpan(ctx, withTransaction)
|
||||
defer func() {
|
||||
@@ -222,15 +228,15 @@ func (w *wrappedSession) WithTransaction(
|
||||
logDuration(ctx, w.name, withTransaction, starTime, err)
|
||||
}()
|
||||
|
||||
res, err = w.Session.WithTransaction(ctx, fn, opts...)
|
||||
res, err = w.session.WithTransaction(ctx, fn, opts...)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// EndSession implements the mongo.Session interface.
|
||||
func (w *wrappedSession) EndSession(ctx context.Context) {
|
||||
// EndSession implements the mongo.session interface.
|
||||
func (w *Session) EndSession(ctx context.Context) {
|
||||
var err error
|
||||
ctx, span := startSpan(ctx, endSession)
|
||||
defer func() {
|
||||
@@ -243,7 +249,34 @@ func (w *wrappedSession) EndSession(ctx context.Context) {
|
||||
logDuration(ctx, w.name, endSession, starTime, err)
|
||||
}()
|
||||
|
||||
w.Session.EndSession(ctx)
|
||||
w.session.EndSession(ctx)
|
||||
return nil
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
type (
|
||||
// for unit test
|
||||
monClient interface {
|
||||
StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error)
|
||||
}
|
||||
|
||||
monSession interface {
|
||||
AbortTransaction(ctx context.Context) error
|
||||
CommitTransaction(ctx context.Context) error
|
||||
EndSession(ctx context.Context)
|
||||
WithTransaction(ctx context.Context, fn func(sessCtx context.Context) (any, error),
|
||||
opts ...options.Lister[options.TransactionOptions]) (any, error)
|
||||
}
|
||||
)
|
||||
|
||||
type wrappedMonClient struct {
|
||||
c *mongo.Client
|
||||
}
|
||||
|
||||
// StartSession starts a new session using the underlying *mongo.Client.
|
||||
// It implements the monClient interface.
|
||||
// This is used to allow mocking in unit tests.
|
||||
func (m *wrappedMonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (
|
||||
monSession, error) {
|
||||
return m.c.StartSession(opts...)
|
||||
}
|
||||
|
||||
145
core/stores/mon/model_mock.go
Normal file
145
core/stores/mon/model_mock.go
Normal file
@@ -0,0 +1,145 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: model.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -package mon -destination model_mock.go -source model.go monClient monSession
|
||||
//
|
||||
|
||||
// Package mon is a generated GoMock package.
|
||||
package mon
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
options "go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockmonClient is a mock of monClient interface.
|
||||
type MockmonClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockmonClientMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockmonClientMockRecorder is the mock recorder for MockmonClient.
|
||||
type MockmonClientMockRecorder struct {
|
||||
mock *MockmonClient
|
||||
}
|
||||
|
||||
// NewMockmonClient creates a new mock instance.
|
||||
func NewMockmonClient(ctrl *gomock.Controller) *MockmonClient {
|
||||
mock := &MockmonClient{ctrl: ctrl}
|
||||
mock.recorder = &MockmonClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockmonClient) EXPECT() *MockmonClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// StartSession mocks base method.
|
||||
func (m *MockmonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "StartSession", varargs...)
|
||||
ret0, _ := ret[0].(monSession)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// StartSession indicates an expected call of StartSession.
|
||||
func (mr *MockmonClientMockRecorder) StartSession(opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockmonClient)(nil).StartSession), opts...)
|
||||
}
|
||||
|
||||
// MockmonSession is a mock of monSession interface.
|
||||
type MockmonSession struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockmonSessionMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockmonSessionMockRecorder is the mock recorder for MockmonSession.
|
||||
type MockmonSessionMockRecorder struct {
|
||||
mock *MockmonSession
|
||||
}
|
||||
|
||||
// NewMockmonSession creates a new mock instance.
|
||||
func NewMockmonSession(ctrl *gomock.Controller) *MockmonSession {
|
||||
mock := &MockmonSession{ctrl: ctrl}
|
||||
mock.recorder = &MockmonSessionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockmonSession) EXPECT() *MockmonSessionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AbortTransaction mocks base method.
|
||||
func (m *MockmonSession) AbortTransaction(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AbortTransaction", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AbortTransaction indicates an expected call of AbortTransaction.
|
||||
func (mr *MockmonSessionMockRecorder) AbortTransaction(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AbortTransaction", reflect.TypeOf((*MockmonSession)(nil).AbortTransaction), ctx)
|
||||
}
|
||||
|
||||
// CommitTransaction mocks base method.
|
||||
func (m *MockmonSession) CommitTransaction(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CommitTransaction", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CommitTransaction indicates an expected call of CommitTransaction.
|
||||
func (mr *MockmonSessionMockRecorder) CommitTransaction(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CommitTransaction", reflect.TypeOf((*MockmonSession)(nil).CommitTransaction), ctx)
|
||||
}
|
||||
|
||||
// EndSession mocks base method.
|
||||
func (m *MockmonSession) EndSession(ctx context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "EndSession", ctx)
|
||||
}
|
||||
|
||||
// EndSession indicates an expected call of EndSession.
|
||||
func (mr *MockmonSessionMockRecorder) EndSession(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EndSession", reflect.TypeOf((*MockmonSession)(nil).EndSession), ctx)
|
||||
}
|
||||
|
||||
// WithTransaction mocks base method.
|
||||
func (m *MockmonSession) WithTransaction(ctx context.Context, fn func(context.Context) (any, error), opts ...options.Lister[options.TransactionOptions]) (any, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, fn}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "WithTransaction", varargs...)
|
||||
ret0, _ := ret[0].(any)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// WithTransaction indicates an expected call of WithTransaction.
|
||||
func (mr *MockmonSessionMockRecorder) WithTransaction(ctx, fn any, opts ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, fn}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTransaction", reflect.TypeOf((*MockmonSession)(nil).WithTransaction), varargs...)
|
||||
}
|
||||
@@ -2,224 +2,242 @@ package mon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestModel_StartSession(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
sess, err := m.StartSession()
|
||||
assert.Nil(t, err)
|
||||
defer sess.EndSession(context.Background())
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonSession := NewMockmonSession(ctrl)
|
||||
warpSession := &Session{
|
||||
session: mockMonSession,
|
||||
name: "",
|
||||
brk: breaker.GetBreaker("localhost"),
|
||||
}
|
||||
|
||||
_, err = sess.WithTransaction(context.Background(), func(sessCtx mongo.SessionContext) (any, error) {
|
||||
_ = sessCtx.StartTransaction()
|
||||
sessCtx.Client().Database("1")
|
||||
sessCtx.EndSession(context.Background())
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, sess.CommitTransaction(context.Background()))
|
||||
assert.Error(t, sess.AbortTransaction(context.Background()))
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
mockedMonClient.EXPECT().StartSession(gomock.Any()).Return(warpSession, errors.New("error"))
|
||||
_, err := m.StartSession()
|
||||
assert.NotNil(t, err)
|
||||
mockedMonClient.EXPECT().StartSession(gomock.Any()).Return(warpSession, nil)
|
||||
sess, err := m.StartSession()
|
||||
assert.Nil(t, err)
|
||||
defer sess.EndSession(context.Background())
|
||||
mockMonSession.EXPECT().WithTransaction(gomock.Any(), gomock.Any()).Return(nil, nil)
|
||||
mockMonSession.EXPECT().CommitTransaction(gomock.Any()).Return(nil)
|
||||
mockMonSession.EXPECT().AbortTransaction(gomock.Any()).Return(nil)
|
||||
mockMonSession.EXPECT().EndSession(gomock.Any())
|
||||
_, err = sess.WithTransaction(context.Background(), func(sessCtx context.Context) (any, error) {
|
||||
// _ = sessCtx.StartTransaction()
|
||||
// sessCtx.Client().Database("1")
|
||||
// sessCtx.EndSession(context.Background())
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, sess.CommitTransaction(context.Background()))
|
||||
assert.NoError(t, sess.AbortTransaction(context.Background()))
|
||||
}
|
||||
|
||||
func TestModel_Aggregate(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
find := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
})
|
||||
getMore := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
killCursors := mtest.CreateCursorResponse(
|
||||
0,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch)
|
||||
mt.AddMockResponses(find, getMore, killCursors)
|
||||
var result []any
|
||||
err := m.Aggregate(context.Background(), &result, mongo.Pipeline{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result))
|
||||
assert.Equal(t, "John", result[0].(bson.D).Map()["name"])
|
||||
assert.Equal(t, "Mary", result[1].(bson.D).Map()["name"])
|
||||
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.Aggregate(context.Background(), &result, mongo.Pipeline{}))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
cursor, err := mongo.NewCursorFromDocuments([]any{
|
||||
bson.M{
|
||||
"name": "John",
|
||||
},
|
||||
bson.M{
|
||||
"name": "Mary",
|
||||
},
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
mockMonCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(cursor, nil)
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
var result []bson.M
|
||||
err = m.Aggregate(context.Background(), &result, bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result))
|
||||
assert.Equal(t, "John", result[0]["name"])
|
||||
assert.Equal(t, "Mary", result[1]["name"])
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.Aggregate(context.Background(), &result, bson.D{}))
|
||||
}
|
||||
|
||||
func TestModel_DeleteMany(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
val, err := m.DeleteMany(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
|
||||
triggerBreaker(m)
|
||||
_, err = m.DeleteMany(context.Background(), bson.D{})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonCollection.EXPECT().DeleteMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
_, err := m.DeleteMany(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
triggerBreaker(m)
|
||||
_, err = m.DeleteMany(context.Background(), bson.D{})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestModel_DeleteOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
val, err := m.DeleteOne(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
|
||||
triggerBreaker(m)
|
||||
_, err = m.DeleteOne(context.Background(), bson.D{})
|
||||
assert.Equal(t, errDummy, err)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
_, err := m.DeleteOne(context.Background(), bson.D{})
|
||||
assert.Nil(t, err)
|
||||
triggerBreaker(m)
|
||||
_, err = m.DeleteOne(context.Background(), bson.D{})
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestModel_Find(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
find := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
})
|
||||
getMore := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
killCursors := mtest.CreateCursorResponse(
|
||||
0,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch)
|
||||
mt.AddMockResponses(find, getMore, killCursors)
|
||||
var result []any
|
||||
err := m.Find(context.Background(), &result, bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result))
|
||||
assert.Equal(t, "John", result[0].(bson.D).Map()["name"])
|
||||
assert.Equal(t, "Mary", result[1].(bson.D).Map()["name"])
|
||||
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.Find(context.Background(), &result, bson.D{}))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
cursor, err := mongo.NewCursorFromDocuments([]any{
|
||||
bson.M{
|
||||
"name": "John",
|
||||
},
|
||||
bson.M{
|
||||
"name": "Mary",
|
||||
},
|
||||
}, nil, nil)
|
||||
assert.NoError(t, err)
|
||||
mockMonCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(cursor, nil)
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
var result []bson.M
|
||||
err = m.Find(context.Background(), &result, bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result))
|
||||
assert.Equal(t, "John", result[0]["name"])
|
||||
assert.Equal(t, "Mary", result[1]["name"])
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.Find(context.Background(), &result, bson.D{}))
|
||||
}
|
||||
|
||||
func TestModel_FindOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
find := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "name", Value: "John"},
|
||||
})
|
||||
killCursors := mtest.CreateCursorResponse(
|
||||
0,
|
||||
"DBName.CollectionName",
|
||||
mtest.NextBatch)
|
||||
mt.AddMockResponses(find, killCursors)
|
||||
var result bson.D
|
||||
err := m.FindOne(context.Background(), &result, bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result.Map()["name"])
|
||||
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOne(context.Background(), &result, bson.D{}))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
var result bson.M
|
||||
err := m.FindOne(context.Background(), &result, bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result["name"])
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOne(context.Background(), &result, bson.D{}))
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndDelete(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
|
||||
}...))
|
||||
var result bson.D
|
||||
err := m.FindOneAndDelete(context.Background(), &result, bson.D{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result.Map()["name"])
|
||||
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOneAndDelete(context.Background(), &result, bson.D{}))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
var result bson.M
|
||||
err := m.FindOneAndDelete(context.Background(), &result, bson.M{})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result["name"])
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOneAndDelete(context.Background(), &result, bson.D{}))
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndReplace(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
|
||||
}...))
|
||||
var result bson.D
|
||||
err := m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result.Map()["name"])
|
||||
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
var result bson.M
|
||||
err := m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result["name"])
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOneAndReplace(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndUpdate(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(mt)
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "name", Value: "John"}}},
|
||||
}...))
|
||||
var result bson.D
|
||||
err := m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result.Map()["name"])
|
||||
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockMonCollection := NewMockmonCollection(ctrl)
|
||||
mockedMonClient := NewMockmonClient(ctrl)
|
||||
mockMonCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"name": "John"}, nil, nil))
|
||||
m := newTestModel("foo", mockedMonClient, mockMonCollection, breaker.GetBreaker("test"))
|
||||
var result bson.M
|
||||
err := m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
})
|
||||
}
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "John", result["name"])
|
||||
triggerBreaker(m)
|
||||
assert.Equal(t, errDummy, m.FindOneAndUpdate(context.Background(), &result, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
|
||||
func createModel(mt *mtest.T) *Model {
|
||||
Inject(mt.Name(), mt.Client)
|
||||
return MustNewModel(mt.Name(), mt.DB.Name(), mt.Coll.Name())
|
||||
}
|
||||
|
||||
func triggerBreaker(m *Model) {
|
||||
m.Collection.(*decoratedCollection).brk = new(dropBreaker)
|
||||
}
|
||||
|
||||
func TestMustNewModel(t *testing.T) {
|
||||
Inject("mongodb://localhost:27017", &mongo.Client{})
|
||||
MustNewModel("mongodb://localhost:27017", "test", "test")
|
||||
}
|
||||
|
||||
func TestNewModel(t *testing.T) {
|
||||
NewModel("mongo://localhost:27018", "test", "test")
|
||||
Inject("mongodb://localhost:27018", &mongo.Client{})
|
||||
NewModel("mongodb://localhost:27018", "test", "test")
|
||||
}
|
||||
|
||||
func Test_newModel(t *testing.T) {
|
||||
Inject("mongodb://localhost:27019", &mongo.Client{})
|
||||
newModel("mongodb://localhost:27019", nil, nil, nil)
|
||||
}
|
||||
|
||||
func Test_mockMonClient_StartSession(t *testing.T) {
|
||||
md := drivertest.NewMockDeployment()
|
||||
opts := options.Client()
|
||||
opts.Deployment = md
|
||||
client, err := mongo.Connect(opts)
|
||||
assert.Nil(t, err)
|
||||
m := wrappedMonClient{
|
||||
c: client,
|
||||
}
|
||||
_, err = m.StartSession()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func newTestModel(name string, cli monClient, coll monCollection, brk breaker.Breaker,
|
||||
opts ...Option) *Model {
|
||||
return &Model{
|
||||
name: name,
|
||||
Collection: newTestCollection(coll, breaker.GetBreaker("localhost")),
|
||||
cli: cli,
|
||||
brk: brk,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,9 +5,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second * 3
|
||||
@@ -20,16 +19,16 @@ var (
|
||||
|
||||
type (
|
||||
// Option defines the method to customize a mongo model.
|
||||
Option func(opts *options)
|
||||
Option func(opts *clientOptions)
|
||||
|
||||
// TypeCodec is a struct that stores specific type Encoder/Decoder.
|
||||
TypeCodec struct {
|
||||
ValueType reflect.Type
|
||||
Encoder bsoncodec.ValueEncoder
|
||||
Decoder bsoncodec.ValueDecoder
|
||||
Encoder bson.ValueEncoder
|
||||
Decoder bson.ValueDecoder
|
||||
}
|
||||
|
||||
options = mopt.ClientOptions
|
||||
clientOptions = options.ClientOptions
|
||||
)
|
||||
|
||||
// DisableLog disables logging of mongo commands, includes info and slow logs.
|
||||
@@ -50,14 +49,14 @@ func SetSlowThreshold(threshold time.Duration) {
|
||||
|
||||
// WithTimeout set the mon client operation timeout.
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
return func(opts *clientOptions) {
|
||||
opts.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTypeCodec registers TypeCodecs to convert custom types.
|
||||
func WithTypeCodec(typeCodecs ...TypeCodec) Option {
|
||||
return func(opts *options) {
|
||||
return func(opts *clientOptions) {
|
||||
registry := bson.NewRegistry()
|
||||
for _, v := range typeCodecs {
|
||||
registry.RegisterTypeEncoder(v.ValueType, v.Encoder)
|
||||
@@ -68,7 +67,7 @@ func WithTypeCodec(typeCodecs ...TypeCodec) Option {
|
||||
}
|
||||
|
||||
func defaultTimeoutOption() Option {
|
||||
return func(opts *options) {
|
||||
return func(opts *clientOptions) {
|
||||
opts.SetTimeout(defaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
func TestSetSlowThreshold(t *testing.T) {
|
||||
@@ -19,13 +18,13 @@ func TestSetSlowThreshold(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_defaultTimeoutOption(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
opts := options.Client()
|
||||
defaultTimeoutOption()(opts)
|
||||
assert.Equal(t, defaultTimeout, *opts.Timeout)
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
opts := options.Client()
|
||||
WithTimeout(time.Second)(opts)
|
||||
assert.Equal(t, time.Second, *opts.Timeout)
|
||||
}
|
||||
@@ -57,10 +56,11 @@ func TestDisableInfoLog(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithRegistryForTimestampRegisterType(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
opts := options.Client()
|
||||
|
||||
// mongoDateTimeEncoder allow user convert time.Time to primitive.DateTime.
|
||||
var mongoDateTimeEncoder bsoncodec.ValueEncoderFunc = func(ect bsoncodec.EncodeContext, w bsonrw.ValueWriter, value reflect.Value) error {
|
||||
var mongoDateTimeEncoder bson.ValueEncoderFunc = func(ect bson.EncodeContext,
|
||||
w bson.ValueWriter, value reflect.Value) error {
|
||||
// Use reflect, determine if it can be converted to time.Time.
|
||||
dec, ok := value.Interface().(time.Time)
|
||||
if !ok {
|
||||
@@ -70,7 +70,8 @@ func TestWithRegistryForTimestampRegisterType(t *testing.T) {
|
||||
}
|
||||
|
||||
// mongoDateTimeEncoder allow user convert primitive.DateTime to time.Time.
|
||||
var mongoDateTimeDecoder bsoncodec.ValueDecoderFunc = func(ect bsoncodec.DecodeContext, r bsonrw.ValueReader, value reflect.Value) error {
|
||||
var mongoDateTimeDecoder bson.ValueDecoderFunc = func(ect bson.DecodeContext,
|
||||
r bson.ValueReader, value reflect.Value) error {
|
||||
primTime, err := r.ReadDateTime()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading primitive.DateTime from ValueReader: %v", err)
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -78,7 +78,7 @@ func (mm *Model) DelCache(ctx context.Context, keys ...string) error {
|
||||
|
||||
// DeleteOne deletes the document with given filter, and remove it from cache.
|
||||
func (mm *Model) DeleteOne(ctx context.Context, key string, filter any,
|
||||
opts ...*mopt.DeleteOptions) (int64, error) {
|
||||
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
|
||||
val, err := mm.Model.DeleteOne(ctx, filter, opts...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -93,13 +93,13 @@ func (mm *Model) DeleteOne(ctx context.Context, key string, filter any,
|
||||
|
||||
// DeleteOneNoCache deletes the document with given filter.
|
||||
func (mm *Model) DeleteOneNoCache(ctx context.Context, filter any,
|
||||
opts ...*mopt.DeleteOptions) (int64, error) {
|
||||
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
|
||||
return mm.Model.DeleteOne(ctx, filter, opts...)
|
||||
}
|
||||
|
||||
// FindOne unmarshals a record into v with given key and query.
|
||||
func (mm *Model) FindOne(ctx context.Context, key string, v, filter any,
|
||||
opts ...*mopt.FindOneOptions) error {
|
||||
opts ...options.Lister[options.FindOneOptions]) error {
|
||||
return mm.cache.TakeCtx(ctx, v, key, func(v any) error {
|
||||
return mm.Model.FindOne(ctx, v, filter, opts...)
|
||||
})
|
||||
@@ -107,13 +107,13 @@ func (mm *Model) FindOne(ctx context.Context, key string, v, filter any,
|
||||
|
||||
// FindOneNoCache unmarshals a record into v with query, without cache.
|
||||
func (mm *Model) FindOneNoCache(ctx context.Context, v, filter any,
|
||||
opts ...*mopt.FindOneOptions) error {
|
||||
opts ...options.Lister[options.FindOneOptions]) error {
|
||||
return mm.Model.FindOne(ctx, v, filter, opts...)
|
||||
}
|
||||
|
||||
// FindOneAndDelete deletes the document with given filter, and unmarshals it into v.
|
||||
func (mm *Model) FindOneAndDelete(ctx context.Context, key string, v, filter any,
|
||||
opts ...*mopt.FindOneAndDeleteOptions) error {
|
||||
opts ...options.Lister[options.FindOneAndDeleteOptions]) error {
|
||||
if err := mm.Model.FindOneAndDelete(ctx, v, filter, opts...); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -123,13 +123,13 @@ func (mm *Model) FindOneAndDelete(ctx context.Context, key string, v, filter any
|
||||
|
||||
// FindOneAndDeleteNoCache deletes the document with given filter, and unmarshals it into v.
|
||||
func (mm *Model) FindOneAndDeleteNoCache(ctx context.Context, v, filter any,
|
||||
opts ...*mopt.FindOneAndDeleteOptions) error {
|
||||
opts ...options.Lister[options.FindOneAndDeleteOptions]) error {
|
||||
return mm.Model.FindOneAndDelete(ctx, v, filter, opts...)
|
||||
}
|
||||
|
||||
// FindOneAndReplace replaces the document with given filter with replacement, and unmarshals it into v.
|
||||
func (mm *Model) FindOneAndReplace(ctx context.Context, key string, v, filter any,
|
||||
replacement any, opts ...*mopt.FindOneAndReplaceOptions) error {
|
||||
replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) error {
|
||||
if err := mm.Model.FindOneAndReplace(ctx, v, filter, replacement, opts...); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -139,13 +139,13 @@ func (mm *Model) FindOneAndReplace(ctx context.Context, key string, v, filter an
|
||||
|
||||
// FindOneAndReplaceNoCache replaces the document with given filter with replacement, and unmarshals it into v.
|
||||
func (mm *Model) FindOneAndReplaceNoCache(ctx context.Context, v, filter any,
|
||||
replacement any, opts ...*mopt.FindOneAndReplaceOptions) error {
|
||||
replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) error {
|
||||
return mm.Model.FindOneAndReplace(ctx, v, filter, replacement, opts...)
|
||||
}
|
||||
|
||||
// FindOneAndUpdate updates the document with given filter with update, and unmarshals it into v.
|
||||
func (mm *Model) FindOneAndUpdate(ctx context.Context, key string, v, filter any,
|
||||
update any, opts ...*mopt.FindOneAndUpdateOptions) error {
|
||||
update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) error {
|
||||
if err := mm.Model.FindOneAndUpdate(ctx, v, filter, update, opts...); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -155,7 +155,7 @@ func (mm *Model) FindOneAndUpdate(ctx context.Context, key string, v, filter any
|
||||
|
||||
// FindOneAndUpdateNoCache updates the document with given filter with update, and unmarshals it into v.
|
||||
func (mm *Model) FindOneAndUpdateNoCache(ctx context.Context, v, filter any,
|
||||
update any, opts ...*mopt.FindOneAndUpdateOptions) error {
|
||||
update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) error {
|
||||
return mm.Model.FindOneAndUpdate(ctx, v, filter, update, opts...)
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ func (mm *Model) GetCache(key string, v any) error {
|
||||
|
||||
// InsertOne inserts a single document into the collection, and remove the cache placeholder.
|
||||
func (mm *Model) InsertOne(ctx context.Context, key string, document any,
|
||||
opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error) {
|
||||
opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
|
||||
res, err := mm.Model.InsertOne(ctx, document, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -181,13 +181,13 @@ func (mm *Model) InsertOne(ctx context.Context, key string, document any,
|
||||
|
||||
// InsertOneNoCache inserts a single document into the collection.
|
||||
func (mm *Model) InsertOneNoCache(ctx context.Context, document any,
|
||||
opts ...*mopt.InsertOneOptions) (*mongo.InsertOneResult, error) {
|
||||
opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) {
|
||||
return mm.Model.InsertOne(ctx, document, opts...)
|
||||
}
|
||||
|
||||
// ReplaceOne replaces a single document in the collection, and remove the cache.
|
||||
func (mm *Model) ReplaceOne(ctx context.Context, key string, filter, replacement any,
|
||||
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
|
||||
res, err := mm.Model.ReplaceOne(ctx, filter, replacement, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -202,7 +202,7 @@ func (mm *Model) ReplaceOne(ctx context.Context, key string, filter, replacement
|
||||
|
||||
// ReplaceOneNoCache replaces a single document in the collection.
|
||||
func (mm *Model) ReplaceOneNoCache(ctx context.Context, filter, replacement any,
|
||||
opts ...*mopt.ReplaceOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) {
|
||||
return mm.Model.ReplaceOne(ctx, filter, replacement, opts...)
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ func (mm *Model) SetCache(key string, v any) error {
|
||||
|
||||
// UpdateByID updates the document with given id with update, and remove the cache.
|
||||
func (mm *Model) UpdateByID(ctx context.Context, key string, id, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
res, err := mm.Model.UpdateByID(ctx, id, update, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -228,13 +228,13 @@ func (mm *Model) UpdateByID(ctx context.Context, key string, id, update any,
|
||||
|
||||
// UpdateByIDNoCache updates the document with given id with update.
|
||||
func (mm *Model) UpdateByIDNoCache(ctx context.Context, id, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
return mm.Model.UpdateByID(ctx, id, update, opts...)
|
||||
}
|
||||
|
||||
// UpdateMany updates the documents that match filter with update, and remove the cache.
|
||||
func (mm *Model) UpdateMany(ctx context.Context, keys []string, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
|
||||
res, err := mm.Model.UpdateMany(ctx, filter, update, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -249,13 +249,13 @@ func (mm *Model) UpdateMany(ctx context.Context, keys []string, filter, update a
|
||||
|
||||
// UpdateManyNoCache updates the documents that match filter with update.
|
||||
func (mm *Model) UpdateManyNoCache(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) {
|
||||
return mm.Model.UpdateMany(ctx, filter, update, opts...)
|
||||
}
|
||||
|
||||
// UpdateOne updates the first document that matches filter with update, and remove the cache.
|
||||
func (mm *Model) UpdateOne(ctx context.Context, key string, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
res, err := mm.Model.UpdateOne(ctx, filter, update, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -270,6 +270,6 @@ func (mm *Model) UpdateOne(ctx context.Context, key string, filter, update any,
|
||||
|
||||
// UpdateOneNoCache updates the first document that matches filter with update.
|
||||
func (mm *Model) UpdateOneNoCache(ctx context.Context, filter, update any,
|
||||
opts ...*mopt.UpdateOptions) (*mongo.UpdateResult, error) {
|
||||
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) {
|
||||
return mm.Model.UpdateOne(ctx, filter, update, opts...)
|
||||
}
|
||||
|
||||
@@ -8,506 +8,519 @@ import (
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func TestNewModel(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
_, err := newModel("foo", mt.DB.Name(), mt.Coll.Name(), nil)
|
||||
assert.NotNil(mt, err)
|
||||
func TestMustNewModel(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
original := logx.ExitOnFatal.True()
|
||||
logx.ExitOnFatal.Set(false)
|
||||
defer logx.ExitOnFatal.Set(original)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
MustNewModel("foo", "db", "collectino", cache.CacheConf{
|
||||
cache.NodeConf{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: s.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
}})
|
||||
})
|
||||
}
|
||||
|
||||
func TestMustNewNodeModel(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
original := logx.ExitOnFatal.True()
|
||||
logx.ExitOnFatal.Set(false)
|
||||
defer logx.ExitOnFatal.Set(original)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
MustNewNodeModel("foo", "db", "collectino", redis.New(s.Addr()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewModel(t *testing.T) {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
_, err = NewModel("foo", "db", "coll", cache.CacheConf{
|
||||
cache.NodeConf{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: s.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewNodeModel(t *testing.T) {
|
||||
_, err := NewNodeModel("foo", "db", "coll", nil)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestNewModelWithCache(t *testing.T) {
|
||||
_, err := NewModelWithCache("foo", "db", "coll", nil)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func Test_newModel(t *testing.T) {
|
||||
mon.Inject("mongodb://localhost:27018", &mongo.Client{})
|
||||
model, err := newModel("mongodb://localhost:27018", "db", "collection", nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, model)
|
||||
}
|
||||
|
||||
func TestModel_DelCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
assert.Nil(t, m.cache.Set("bar", "baz"))
|
||||
assert.Nil(t, m.DelCache(context.Background(), "foo", "bar"))
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
assert.Nil(t, m.cache.Set("bar", "baz"))
|
||||
assert.Nil(t, m.DelCache(context.Background(), "foo", "bar"))
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
|
||||
}
|
||||
|
||||
func TestModel_DeleteOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
val, err := m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.NotNil(t, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{DeletedCount: 1}, nil)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
val, err := m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, errMocked)
|
||||
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errMocked, err)
|
||||
})
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{}, nil)
|
||||
_, err = m.DeleteOne(context.Background(), "foo", bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Equal(t, errMocked, err)
|
||||
}
|
||||
|
||||
func TestModel_DeleteOneNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "n", Value: 1}}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
val, err := m.DeleteOneNoCache(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
var v string
|
||||
assert.Nil(t, m.cache.Get("foo", &v))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().DeleteOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.DeleteResult{DeletedCount: 1}, nil)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
val, err := m.DeleteOneNoCache(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), val)
|
||||
var v string
|
||||
assert.Nil(t, m.cache.Get("foo", &v))
|
||||
}
|
||||
|
||||
func TestModel_FindOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
resp := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "foo", Value: "bar"},
|
||||
})
|
||||
mt.AddMockResponses(resp)
|
||||
m := createModel(t, mt)
|
||||
var v struct {
|
||||
Foo string `bson:"foo"`
|
||||
}
|
||||
assert.Nil(t, m.FindOne(context.Background(), "foo", &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
m := createModel(t, mockCollection)
|
||||
var v struct {
|
||||
Foo string `bson:"foo"`
|
||||
}
|
||||
assert.Nil(t, m.FindOne(context.Background(), "foo", &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
}
|
||||
|
||||
func TestModel_FindOneNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
resp := mtest.CreateCursorResponse(
|
||||
1,
|
||||
"DBName.CollectionName",
|
||||
mtest.FirstBatch,
|
||||
bson.D{
|
||||
{Key: "foo", Value: "bar"},
|
||||
})
|
||||
mt.AddMockResponses(resp)
|
||||
m := createModel(t, mt)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneNoCache(context.Background(), &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOne(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
m := createModel(t, mockCollection)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneNoCache(context.Background(), &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndDelete(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.NotNil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
assert.Equal(t, errMocked, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, bson.NewRegistry()), nil)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, bson.NewRegistry()), errMocked)
|
||||
assert.NotNil(t, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, bson.NewRegistry()), nil)
|
||||
assert.Equal(t, errMocked, m.FindOneAndDelete(context.Background(), "foo", &v, bson.D{}))
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndDeleteNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndDeleteNoCache(context.Background(), &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
m := createModel(t, mockCollection)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndDeleteNoCache(context.Background(), &v, bson.D{}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndReplace(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.NotNil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
assert.Nil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"name": "Mary"}, nil, nil), errMocked)
|
||||
assert.NotNil(t, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
assert.Equal(t, errMocked, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
})
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
assert.Equal(t, errMocked, m.FindOneAndReplace(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndReplaceNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndReplaceNoCache(context.Background(), &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any()).Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
assert.Nil(t, m.FindOneAndReplaceNoCache(context.Background(), &v, bson.D{}, bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndUpdate(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.NotNil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
assert.Nil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), errMocked)
|
||||
assert.NotNil(t, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
assert.Equal(t, errMocked, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
})
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
assert.Equal(t, errMocked, m.FindOneAndUpdate(context.Background(), "foo", &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestModel_FindOneAndUpdateNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
assert.Nil(t, m.FindOneAndUpdateNoCache(context.Background(), &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
v := struct {
|
||||
Foo string `bson:"foo"`
|
||||
}{}
|
||||
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(mongo.NewSingleResultFromDocument(bson.M{"foo": "bar"}, nil, nil), nil)
|
||||
assert.Nil(t, m.FindOneAndUpdateNoCache(context.Background(), &v, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "name", Value: "Mary"}}},
|
||||
}))
|
||||
assert.Equal(t, "bar", v.Foo)
|
||||
}
|
||||
|
||||
func TestModel_GetCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(t, mt)
|
||||
assert.NotNil(t, m.cache)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
var s string
|
||||
assert.Nil(t, m.cache.Get("foo", &s))
|
||||
assert.Equal(t, "bar", s)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.NotNil(t, m.cache)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
var s string
|
||||
assert.Nil(t, m.cache.Get("foo", &s))
|
||||
assert.Equal(t, "bar", s)
|
||||
}
|
||||
|
||||
func TestModel_InsertOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
resp, err := m.InsertOne(context.Background(), "foo", bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
_, err = m.InsertOne(context.Background(), "foo", bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
_, err = m.InsertOne(context.Background(), "foo", bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
|
||||
resp, err := m.InsertOne(context.Background(), "foo", bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, errMocked)
|
||||
_, err = m.InsertOne(context.Background(), "foo", bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
|
||||
_, err = m.InsertOne(context.Background(), "foo", bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
}
|
||||
|
||||
func TestModel_InsertOneNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
resp, err := m.InsertOneNoCache(context.Background(), bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
mockCollection.EXPECT().InsertOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.InsertOneResult{}, nil)
|
||||
resp, err := m.InsertOneNoCache(context.Background(), bson.D{
|
||||
{Key: "name", Value: "Mary"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestModel_ReplaceOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
resp, err := m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
|
||||
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
_, err = m.ReplaceOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
}
|
||||
|
||||
func TestModel_ReplaceOneNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
resp, err := m.ReplaceOneNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
mockCollection.EXPECT().ReplaceOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.ReplaceOneNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "foo", Value: "baz"},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestModel_SetCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.SetCache("foo", "bar"))
|
||||
var v string
|
||||
assert.Nil(t, m.GetCache("foo", &v))
|
||||
assert.Equal(t, "bar", v)
|
||||
})
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.SetCache("foo", "bar"))
|
||||
var v string
|
||||
assert.Nil(t, m.GetCache("foo", &v))
|
||||
assert.Equal(t, "bar", v)
|
||||
}
|
||||
|
||||
func TestModel_UpdateByID(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
resp, err := m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
|
||||
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
_, err = m.UpdateByID(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
}
|
||||
|
||||
func TestModel_UpdateByIDNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
resp, err := m.UpdateByIDNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
mockCollection.EXPECT().UpdateByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.UpdateByIDNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestModel_UpdateMany(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
assert.Nil(t, m.cache.Set("bar", "baz"))
|
||||
resp, err := m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
|
||||
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
assert.Nil(t, m.cache.Set("bar", "baz"))
|
||||
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("bar", &v)))
|
||||
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
|
||||
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
m.cache = mockedCache{m.cache}
|
||||
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
_, err = m.UpdateMany(context.Background(), []string{"foo", "bar"}, bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
}
|
||||
|
||||
func TestModel_UpdateManyNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
resp, err := m.UpdateManyNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
mockCollection.EXPECT().UpdateMany(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.UpdateManyNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestModel_UpdateOne(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
resp, err := m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m.cache = mockedCache{m.cache}
|
||||
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
assert.Nil(t, m.cache.Set("foo", "bar"))
|
||||
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
var v string
|
||||
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, errMocked)
|
||||
assert.True(t, m.cache.IsNotFound(m.cache.Get("foo", &v)))
|
||||
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
m.cache = mockedCache{m.cache}
|
||||
_, err = m.UpdateOne(context.Background(), "foo", bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Equal(t, errMocked, err)
|
||||
}
|
||||
|
||||
func TestModel_UpdateOneNoCache(t *testing.T) {
|
||||
mt := mtest.New(t, mtest.NewOptions().ClientType(mtest.Mock))
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{
|
||||
{Key: "value", Value: bson.D{{Key: "foo", Value: "bar"}}},
|
||||
}...))
|
||||
m := createModel(t, mt)
|
||||
resp, err := m.UpdateOneNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockCollection := mon.NewMockCollection(ctrl)
|
||||
m := createModel(t, mockCollection)
|
||||
mockCollection.EXPECT().UpdateOne(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.UpdateResult{}, nil)
|
||||
resp, err := m.UpdateOneNoCache(context.Background(), bson.D{}, bson.D{
|
||||
{Key: "$set", Value: bson.D{{Key: "foo", Value: "baz"}}},
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func createModel(t *testing.T, mt *mtest.T) *Model {
|
||||
func createModel(t *testing.T, coll mon.Collection) *Model {
|
||||
s, err := miniredis.Run()
|
||||
assert.Nil(t, err)
|
||||
mon.Inject(mt.Name(), mt.Client)
|
||||
if atomic.AddInt32(&index, 1)%2 == 0 {
|
||||
return MustNewNodeModel(mt.Name(), mt.DB.Name(), mt.Coll.Name(), redis.New(s.Addr()))
|
||||
return mustNewTestNodeModel(coll, redis.New(s.Addr()))
|
||||
} else {
|
||||
return MustNewModel(mt.Name(), mt.DB.Name(), mt.Coll.Name(), cache.CacheConf{
|
||||
return mustNewTestModel(coll, cache.CacheConf{
|
||||
cache.NodeConf{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: s.Addr(),
|
||||
@@ -519,6 +532,27 @@ func createModel(t *testing.T, mt *mtest.T) *Model {
|
||||
}
|
||||
}
|
||||
|
||||
// mustNewTestModel returns a test Model with the given cache.
|
||||
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
|
||||
return &Model{
|
||||
Model: &mon.Model{
|
||||
Collection: collection,
|
||||
},
|
||||
cache: cache.New(c, singleFlight, stats, mongo.ErrNoDocuments, opts...),
|
||||
}
|
||||
}
|
||||
|
||||
// NewNodeModel returns a test Model with a cache node.
|
||||
func mustNewTestNodeModel(collection mon.Collection, rds *redis.Redis, opts ...cache.Option) *Model {
|
||||
c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...)
|
||||
return &Model{
|
||||
Model: &mon.Model{
|
||||
Collection: collection,
|
||||
},
|
||||
cache: c,
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
errMocked = errors.New("mocked error")
|
||||
index int32
|
||||
|
||||
19
core/stores/monc/migration-2.0.md
Normal file
19
core/stores/monc/migration-2.0.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Migrating from 1.x to 2.0
|
||||
|
||||
To upgrade imports of the Go Driver from v1 to v2, we recommend using [marwan-at-work/mod
|
||||
](https://github.com/marwan-at-work/mod):
|
||||
|
||||
```
|
||||
mod upgrade --mod-name=go.mongodb.org/mongo-driver
|
||||
```
|
||||
|
||||
# Notice
|
||||
After completing the mod upgrade, code changes are typically unnecessary in the vast majority of cases. However, if your project references packages including but not limited to those listed below, you'll need to manually replace them, as these libraries are no longer present in the v2 version.
|
||||
```go
|
||||
go.mongodb.org/mongo-driver/bson/bsonrw => go.mongodb.org/mongo-driver/v2/bson
|
||||
go.mongodb.org/mongo-driver/bson/bsoncodec => go.mongodb.org/mongo-driver/v2/bson
|
||||
go.mongodb.org/mongo-driver/bson/primitive => go.mongodb.org/mongo-driver/v2/bson
|
||||
```
|
||||
|
||||
See the following resources to learn more about upgrading from version 1.x to 2.0.:
|
||||
https://raw.githubusercontent.com/mongodb/mongo-go-driver/refs/heads/master/docs/migration-2.0.md
|
||||
@@ -65,7 +65,6 @@ type (
|
||||
// RedisNode interface represents a redis node.
|
||||
RedisNode interface {
|
||||
red.Cmdable
|
||||
red.BitMapCmdable
|
||||
}
|
||||
|
||||
// GeoLocation is used with GeoAdd to add geospatial location.
|
||||
@@ -609,6 +608,28 @@ func (s *Redis) GetBitCtx(ctx context.Context, key string, offset int64) (int, e
|
||||
return int(v), nil
|
||||
}
|
||||
|
||||
// GetDel is the implementation of redis getdel command.
|
||||
// Available since: redis version 6.2.0
|
||||
func (s *Redis) GetDel(key string) (string, error) {
|
||||
return s.GetDelCtx(context.Background(), key)
|
||||
}
|
||||
|
||||
// GetDelCtx is the implementation of redis getdel command.
|
||||
// Available since: redis version 6.2.0
|
||||
func (s *Redis) GetDelCtx(ctx context.Context, key string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
val, err := conn.GetDel(ctx, key).Result()
|
||||
if errors.Is(err, red.Nil) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
// GetSet is the implementation of redis getset command.
|
||||
func (s *Redis) GetSet(key, value string) (string, error) {
|
||||
return s.GetSetCtx(context.Background(), key, value)
|
||||
@@ -1263,10 +1284,12 @@ func (s *Redis) RpushCtx(ctx context.Context, key string, values ...any) (int, e
|
||||
return int(v), nil
|
||||
}
|
||||
|
||||
// RPopLPush atomically removes the last element from source list and prepends it to destination list.
|
||||
func (s *Redis) RPopLPush(source string, destination string) (string, error) {
|
||||
return s.RPopLPushCtx(context.Background(), source, destination)
|
||||
}
|
||||
|
||||
// RPopLPushCtx is the context-aware version of RPopLPush.
|
||||
func (s *Redis) RPopLPushCtx(ctx context.Context, source string, destination string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
@@ -1673,14 +1696,17 @@ func (s *Redis) TtlCtx(ctx context.Context, key string) (int, error) {
|
||||
return int(duration), nil
|
||||
}
|
||||
|
||||
// TxPipeline returns a Redis transaction pipeline for executing multiple commands atomically.
|
||||
func (s *Redis) TxPipeline() (pipe Pipeliner, err error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn.TxPipeline(), nil
|
||||
}
|
||||
|
||||
// Unlink is similar to Del but removes keys asynchronously in a separate thread.
|
||||
func (s *Redis) Unlink(keys ...string) (int64, error) {
|
||||
return s.UnlinkCtx(context.Background(), keys...)
|
||||
}
|
||||
@@ -1690,9 +1716,154 @@ func (s *Redis) UnlinkCtx(ctx context.Context, keys ...string) (int64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return conn.Unlink(ctx, keys...).Result()
|
||||
}
|
||||
|
||||
// XAck acknowledges one or more messages in a Redis stream consumer group.
|
||||
// It marks the specified messages as successfully processed.
|
||||
func (s *Redis) XAck(stream string, group string, ids ...string) (int64, error) {
|
||||
return s.XAckCtx(context.Background(), stream, group, ids...)
|
||||
}
|
||||
|
||||
// XAckCtx is the context-aware version of XAck.
|
||||
func (s *Redis) XAckCtx(ctx context.Context, stream string, group string, ids ...string) (int64, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return conn.XAck(ctx, stream, group, ids...).Result()
|
||||
}
|
||||
|
||||
// XAdd adds a new entry to a Redis stream with the specified ID and field-value pairs.
|
||||
// If noMkStream is true, the command will fail if the stream doesn't exist.
|
||||
func (s *Redis) XAdd(stream string, noMkStream bool, id string, values any) (string, error) {
|
||||
return s.XAddCtx(context.Background(), stream, noMkStream, id, values)
|
||||
}
|
||||
|
||||
// XAddCtx is the context-aware version of XAdd.
|
||||
func (s *Redis) XAddCtx(ctx context.Context, stream string, noMkStream bool, id string, values any) (
|
||||
string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conn.XAdd(ctx, &red.XAddArgs{
|
||||
Stream: stream,
|
||||
ID: id,
|
||||
Values: values,
|
||||
NoMkStream: noMkStream,
|
||||
}).Result()
|
||||
}
|
||||
|
||||
// XGroupCreateMkStream creates a consumer group for a Redis stream.
|
||||
// If the stream doesn't exist, it will be created automatically.
|
||||
func (s *Redis) XGroupCreateMkStream(stream string, group string, start string) (string, error) {
|
||||
return s.XGroupCreateMkStreamCtx(context.Background(), stream, group, start)
|
||||
}
|
||||
|
||||
// XGroupCreateMkStreamCtx is the context-aware version of XGroupCreateMkStream.
|
||||
func (s *Redis) XGroupCreateMkStreamCtx(ctx context.Context, stream string, group string,
|
||||
start string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conn.XGroupCreateMkStream(ctx, stream, group, start).Result()
|
||||
}
|
||||
|
||||
// XGroupCreate creates a consumer group for a Redis stream.
|
||||
// The stream must already exist, otherwise the command will fail.
|
||||
func (s *Redis) XGroupCreate(stream string, group string, start string) (string, error) {
|
||||
return s.XGroupCreateCtx(context.Background(), stream, group, start)
|
||||
}
|
||||
|
||||
// XGroupCreateCtx is the context-aware version of XGroupCreate.
|
||||
func (s *Redis) XGroupCreateCtx(ctx context.Context, stream string, group string, start string) (
|
||||
string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conn.XGroupCreate(ctx, stream, group, start).Result()
|
||||
}
|
||||
|
||||
// XInfoConsumers returns information about consumers in a Redis stream consumer group.
|
||||
func (s *Redis) XInfoConsumers(stream string, group string) ([]red.XInfoConsumer, error) {
|
||||
return s.XInfoConsumersCtx(context.Background(), stream, group)
|
||||
}
|
||||
|
||||
// XInfoConsumersCtx is the context-aware version of XInfoConsumers.
|
||||
func (s *Redis) XInfoConsumersCtx(ctx context.Context, stream string, group string) (
|
||||
[]red.XInfoConsumer, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn.XInfoConsumers(ctx, stream, group).Result()
|
||||
}
|
||||
|
||||
// XInfoGroups returns information about consumer groups for a Redis stream.
|
||||
func (s *Redis) XInfoGroups(stream string) ([]red.XInfoGroup, error) {
|
||||
return s.XInfoGroupsCtx(context.Background(), stream)
|
||||
}
|
||||
|
||||
// XInfoGroupsCtx is the context-aware version of XInfoGroups.
|
||||
func (s *Redis) XInfoGroupsCtx(ctx context.Context, stream string) ([]red.XInfoGroup, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn.XInfoGroups(ctx, stream).Result()
|
||||
}
|
||||
|
||||
// XInfoStream returns general information about a Redis stream.
|
||||
func (s *Redis) XInfoStream(stream string) (*red.XInfoStream, error) {
|
||||
return s.XInfoStreamCtx(context.Background(), stream)
|
||||
}
|
||||
|
||||
// XInfoStreamCtx is the context-aware version of XInfoStream.
|
||||
func (s *Redis) XInfoStreamCtx(ctx context.Context, stream string) (*red.XInfoStream, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn.XInfoStream(ctx, stream).Result()
|
||||
}
|
||||
|
||||
// XReadGroup reads messages from Redis streams as part of a consumer group.
|
||||
// It allows for distributed processing of stream messages with automatic message delivery semantics.
|
||||
// Doesn't benefit from pooling redis connections of blocking queries.
|
||||
func (s *Redis) XReadGroup(node RedisNode, group string, consumerId string, count int64,
|
||||
block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
||||
return s.XReadGroupCtx(context.Background(), node, group, consumerId, count, block, noAck, streams...)
|
||||
}
|
||||
|
||||
// XReadGroupCtx is the context-aware version of XReadGroup.
|
||||
// Doesn't benefit from pooling redis connections of blocking queries.
|
||||
func (s *Redis) XReadGroupCtx(ctx context.Context, node RedisNode, group string, consumerId string,
|
||||
count int64, block time.Duration, noAck bool, streams ...string) ([]red.XStream, error) {
|
||||
if node == nil {
|
||||
return nil, ErrNilNode
|
||||
}
|
||||
|
||||
return node.XReadGroup(ctx, &red.XReadGroupArgs{
|
||||
Group: group,
|
||||
Consumer: consumerId,
|
||||
Count: count,
|
||||
Block: block,
|
||||
NoAck: noAck,
|
||||
Streams: streams,
|
||||
}).Result()
|
||||
}
|
||||
|
||||
// Zadd is the implementation of redis zadd command.
|
||||
func (s *Redis) Zadd(key string, score int64, value string) (bool, error) {
|
||||
return s.ZaddCtx(context.Background(), key, score, value)
|
||||
@@ -1773,7 +1944,7 @@ func (s *Redis) ZaddsCtx(ctx context.Context, key string, ps ...Pair) (int64, er
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var zs []red.Z
|
||||
zs := make([]red.Z, 0, len(ps))
|
||||
for _, p := range ps {
|
||||
z := red.Z{Score: float64(p.Score), Member: p.Key}
|
||||
zs = append(zs, z)
|
||||
|
||||
@@ -916,6 +916,11 @@ func TestRedis_Ping(t *testing.T) {
|
||||
ok := client.Ping()
|
||||
assert.True(t, ok)
|
||||
})
|
||||
|
||||
runOnRedisWithError(t, func(client *Redis) {
|
||||
ok := client.Ping()
|
||||
assert.False(t, ok)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1071,6 +1076,34 @@ func TestRedis_Set(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetDel(t *testing.T) {
|
||||
t.Run("get_del", func(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
val, err := newRedis(client.Addr).GetDel("hello")
|
||||
assert.Equal(t, "", val)
|
||||
assert.Nil(t, err)
|
||||
err = client.Set("hello", "world")
|
||||
assert.Nil(t, err)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.GetDel("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "world", val)
|
||||
val, err = client.Get("hello")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", val)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("get_del_with_error", func(t *testing.T) {
|
||||
runOnRedisWithError(t, func(client *Redis) {
|
||||
_, err := newRedis(client.Addr, badType()).GetDel("hello")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_GetSet(t *testing.T) {
|
||||
t.Run("set_get", func(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
@@ -2001,6 +2034,16 @@ func TestRedis_WithUserPass(t *testing.T) {
|
||||
err := newRedis(client.Addr, WithUser("any"), WithPass("any")).Ping()
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
|
||||
runOnRedisWithAccount(t, "foo", "bar", func(client *Redis) {
|
||||
err := client.Set("key1", "value1")
|
||||
assert.Nil(t, err)
|
||||
_, err = newRedis(client.Addr, badType()).Keys("*")
|
||||
assert.NotNil(t, err)
|
||||
keys, err := client.Keys("*")
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []string{"key1"}, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedis_checkConnection(t *testing.T) {
|
||||
@@ -2029,6 +2072,19 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
|
||||
}))
|
||||
}
|
||||
|
||||
func runOnRedisWithAccount(t *testing.T, user, pass string, fn func(client *Redis)) {
|
||||
logx.Disable()
|
||||
|
||||
s := miniredis.RunT(t)
|
||||
s.RequireUserAuth(user, pass)
|
||||
fn(MustNewRedis(RedisConf{
|
||||
Host: s.Addr(),
|
||||
Type: NodeType,
|
||||
User: user,
|
||||
Pass: pass,
|
||||
}))
|
||||
}
|
||||
|
||||
func runOnRedisWithError(t *testing.T, fn func(client *Redis)) {
|
||||
logx.Disable()
|
||||
|
||||
@@ -2147,3 +2203,115 @@ func TestRedisTxPipeline(t *testing.T) {
|
||||
assert.Equal(t, hashValue, value)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisXGroupCreate(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := newRedis(client.Addr, badType()).XGroupCreate("Source", "Destination", "0")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
redisCli := newRedis(client.Addr)
|
||||
|
||||
_, err = redisCli.XGroupCreate("aa", "bb", "0")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
_, err = newRedis(client.Addr, badType()).XGroupCreateMkStream("Source", "Destination", "0")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
_, err = redisCli.XGroupCreateMkStream("aa", "bb", "0")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = redisCli.XGroupCreate("aa", "cc", "0")
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisXInfo(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := newRedis(client.Addr, badType()).XInfoStream("Source")
|
||||
assert.NotNil(t, err)
|
||||
_, err = newRedis(client.Addr, badType()).XInfoGroups("Source")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
redisCli := newRedis(client.Addr)
|
||||
|
||||
stream := "aa"
|
||||
group := "bb"
|
||||
|
||||
_, err = redisCli.XGroupCreateMkStream(stream, group, "$")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
|
||||
assert.Nil(t, err)
|
||||
|
||||
infoStream, err := redisCli.XInfoStream(stream)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), infoStream.Length)
|
||||
|
||||
infoGroups, err := redisCli.XInfoGroups(stream)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(1), infoGroups[0].Lag)
|
||||
assert.Equal(t, group, infoGroups[0].Name)
|
||||
|
||||
node, err := getRedis(redisCli)
|
||||
assert.NoError(t, err)
|
||||
redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
|
||||
streamRes, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, ">")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(streamRes))
|
||||
assert.Equal(t, "value1", streamRes[0].Messages[0].Values["key1"])
|
||||
|
||||
infoConsumers, err := redisCli.XInfoConsumers(stream, group)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(infoConsumers))
|
||||
|
||||
_, err = newRedis(client.Addr, badType()).XInfoConsumers(stream, group)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisXReadGroup(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
_, err := newRedis(client.Addr, badType()).XAdd("bb", true, "*", []string{"key1", "value1", "key2", "value2"})
|
||||
assert.NotNil(t, err)
|
||||
_, err = newRedis(client.Addr, badType()).XAck("bb", "aa", "123")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
redisCli := newRedis(client.Addr)
|
||||
|
||||
stream := "aa"
|
||||
group := "bb"
|
||||
|
||||
_, err = redisCli.XGroupCreateMkStream(stream, group, "$")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
|
||||
assert.Nil(t, err)
|
||||
|
||||
node, err := getRedis(redisCli)
|
||||
assert.NoError(t, err)
|
||||
redisCli.XAdd(stream, true, "*", []string{"key1", "value1", "key2", "value2"})
|
||||
_, err = redisCli.XReadGroup(nil, group, "consumer", 1, 2000, false, stream, ">")
|
||||
assert.Error(t, err)
|
||||
streamRes, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, ">")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(streamRes))
|
||||
assert.Equal(t, "value1", streamRes[0].Messages[0].Values["key1"])
|
||||
|
||||
_, err = redisCli.XReadGroup(nil, group, "consumer", 1, 2000, false, stream, "0")
|
||||
assert.Error(t, err)
|
||||
streamRes1, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, "0")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(streamRes1))
|
||||
assert.Equal(t, "value1", streamRes1[0].Messages[0].Values["key1"])
|
||||
|
||||
_, err = redisCli.XAck(stream, group, streamRes[0].Messages[0].ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = redisCli.XReadGroup(nil, group, "consumer", 1, 2000, false, stream, "0")
|
||||
assert.Error(t, err)
|
||||
streamRes2, err := redisCli.XReadGroup(node, group, "consumer", 1, 2000, false, stream, "0")
|
||||
assert.Nil(t, err)
|
||||
assert.Greater(t, len(streamRes2), 0, "streamRes2 is empty")
|
||||
assert.Equal(t, 0, len(streamRes2[0].Messages))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||
case NodeType:
|
||||
client := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
@@ -32,6 +33,7 @@ func CreateBlockingNode(r *Redis) (ClosableNode, error) {
|
||||
case ClusterType:
|
||||
client := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
PoolSize: 1,
|
||||
|
||||
@@ -31,6 +31,7 @@ func getClient(r *Redis) (*red.Client, error) {
|
||||
}
|
||||
store := red.NewClient(&red.Options{
|
||||
Addr: r.Addr,
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
DB: defaultDatabase,
|
||||
MaxRetries: maxRetries,
|
||||
|
||||
@@ -28,6 +28,7 @@ func getCluster(r *Redis) (*red.ClusterClient, error) {
|
||||
}
|
||||
store := red.NewClusterClient(&red.ClusterOptions{
|
||||
Addrs: splitClusterAddrs(r.Addr),
|
||||
Username: r.User,
|
||||
Password: r.Pass,
|
||||
MaxRetries: maxRetries,
|
||||
MinIdleConns: idleConns,
|
||||
|
||||
@@ -25,8 +25,8 @@ type (
|
||||
ResultHandler func(sql.Result, error)
|
||||
|
||||
// A BulkInserter is used to batch insert records.
|
||||
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
|
||||
// Oracle is not supported yet, because of the sql is formated with symbol `:`.
|
||||
// Postgresql is not supported yet, because of the sql is formatted with symbol `$`.
|
||||
// Oracle is not supported yet, because of the sql is formatted with symbol `:`.
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
|
||||
29
core/stores/sqlx/config.go
Normal file
29
core/stores/sqlx/config.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package sqlx
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
errEmptyDatasource = errors.New("empty datasource")
|
||||
errEmptyDriverName = errors.New("empty driver name")
|
||||
)
|
||||
|
||||
// SqlConf defines the configuration for sqlx.
|
||||
type SqlConf struct {
|
||||
DataSource string
|
||||
DriverName string `json:",default=mysql"`
|
||||
Replicas []string `json:",optional"`
|
||||
Policy string `json:",default=round-robin,options=round-robin|random"`
|
||||
}
|
||||
|
||||
// Validate validates the SqlxConf.
|
||||
func (sc SqlConf) Validate() error {
|
||||
if len(sc.DataSource) == 0 {
|
||||
return errEmptyDatasource
|
||||
}
|
||||
|
||||
if len(sc.DriverName) == 0 {
|
||||
return errEmptyDriverName
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
29
core/stores/sqlx/config_test.go
Normal file
29
core/stores/sqlx/config_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/conf"
|
||||
)
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
text := []byte(`DataSource: primary:password@tcp(127.0.0.1:3306)/primary_db
|
||||
`)
|
||||
|
||||
var sc SqlConf
|
||||
err := conf.LoadFromYamlBytes(text, &sc)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "mysql", sc.DriverName)
|
||||
assert.Equal(t, policyRoundRobin, sc.Policy)
|
||||
assert.Nil(t, sc.Validate())
|
||||
|
||||
sc = SqlConf{}
|
||||
assert.Equal(t, errEmptyDatasource, sc.Validate())
|
||||
|
||||
sc.DataSource = "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||
assert.Equal(t, errEmptyDriverName, sc.Validate())
|
||||
|
||||
sc.DriverName = "mysql"
|
||||
assert.Nil(t, sc.Validate())
|
||||
}
|
||||
@@ -9,7 +9,10 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
const tagName = "db"
|
||||
const (
|
||||
tagIgnore = "-"
|
||||
tagName = "db"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotMatchDestination is an error that indicates not matching destination to scan.
|
||||
@@ -269,13 +272,17 @@ func unwrapFields(v reflect.Value) []reflect.Value {
|
||||
continue
|
||||
}
|
||||
|
||||
childType := indirect.Type().Field(i)
|
||||
if parseTagName(childType) == tagIgnore {
|
||||
continue
|
||||
}
|
||||
|
||||
if child.Kind() == reflect.Ptr && child.IsNil() {
|
||||
baseValueType := mapping.Deref(child.Type())
|
||||
child.Set(reflect.New(baseValueType))
|
||||
}
|
||||
|
||||
child = reflect.Indirect(child)
|
||||
childType := indirect.Type().Field(i)
|
||||
if child.Kind() == reflect.Struct && childType.Anonymous {
|
||||
fields = append(fields, unwrapFields(child)...)
|
||||
} else {
|
||||
|
||||
@@ -14,7 +14,8 @@ import (
|
||||
func TestUnmarshalRowBool(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value bool
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -25,7 +26,8 @@ func TestUnmarshalRowBool(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value struct {
|
||||
Value bool `db:"value"`
|
||||
@@ -39,7 +41,8 @@ func TestUnmarshalRowBool(t *testing.T) {
|
||||
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value bool
|
||||
assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -51,7 +54,8 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
||||
func TestUnmarshalRowInt(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -64,7 +68,8 @@ func TestUnmarshalRowInt(t *testing.T) {
|
||||
func TestUnmarshalRowInt8(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int8
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -77,7 +82,8 @@ func TestUnmarshalRowInt8(t *testing.T) {
|
||||
func TestUnmarshalRowInt16(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int16
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -90,7 +96,8 @@ func TestUnmarshalRowInt16(t *testing.T) {
|
||||
func TestUnmarshalRowInt32(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -103,7 +110,8 @@ func TestUnmarshalRowInt32(t *testing.T) {
|
||||
func TestUnmarshalRowInt64(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -116,7 +124,8 @@ func TestUnmarshalRowInt64(t *testing.T) {
|
||||
func TestUnmarshalRowUint(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -129,7 +138,8 @@ func TestUnmarshalRowUint(t *testing.T) {
|
||||
func TestUnmarshalRowUint8(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint8
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -142,7 +152,8 @@ func TestUnmarshalRowUint8(t *testing.T) {
|
||||
func TestUnmarshalRowUint16(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint16
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -155,7 +166,8 @@ func TestUnmarshalRowUint16(t *testing.T) {
|
||||
func TestUnmarshalRowUint32(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -168,7 +180,8 @@ func TestUnmarshalRowUint32(t *testing.T) {
|
||||
func TestUnmarshalRowUint64(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -181,7 +194,8 @@ func TestUnmarshalRowUint64(t *testing.T) {
|
||||
func TestUnmarshalRowFloat32(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -194,7 +208,8 @@ func TestUnmarshalRowFloat32(t *testing.T) {
|
||||
func TestUnmarshalRowFloat64(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -208,7 +223,8 @@ func TestUnmarshalRowString(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
const expect = "hello"
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value string
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -226,7 +242,8 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
@@ -243,7 +260,8 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, &mockedScanner{
|
||||
@@ -260,7 +278,23 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Name string
|
||||
age int
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
@@ -269,7 +303,8 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
type myString chan int
|
||||
var value myString
|
||||
@@ -287,7 +322,8 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
@@ -303,7 +339,23 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
age int `db:"age"`
|
||||
Name string `db:"name"`
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
@@ -317,7 +369,8 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
}
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
@@ -333,7 +386,43 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, 5, value.Age)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowStructWithTagsIgnoreFields(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string
|
||||
Ignore bool
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string
|
||||
Ignore bool `db:"-"`
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
@@ -350,7 +439,8 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
@@ -362,7 +452,8 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []bool{true, false}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -373,7 +464,8 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -383,7 +475,8 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value struct {
|
||||
value []bool `db:"value"`
|
||||
@@ -395,7 +488,8 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
errAny := errors.New("any")
|
||||
@@ -412,7 +506,8 @@ func TestUnmarshalRowsInt(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []int{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -426,7 +521,8 @@ func TestUnmarshalRowsInt8(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []int8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int8
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -440,7 +536,8 @@ func TestUnmarshalRowsInt16(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []int16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int16
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -454,7 +551,8 @@ func TestUnmarshalRowsInt32(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []int32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -468,7 +566,8 @@ func TestUnmarshalRowsInt64(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []int64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -482,7 +581,8 @@ func TestUnmarshalRowsUint(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []uint{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -496,7 +596,8 @@ func TestUnmarshalRowsUint8(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []uint8{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint8
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -510,7 +611,8 @@ func TestUnmarshalRowsUint16(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []uint16{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint16
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -524,7 +626,8 @@ func TestUnmarshalRowsUint32(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []uint32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -538,7 +641,8 @@ func TestUnmarshalRowsUint64(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []uint64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -552,7 +656,8 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []float32{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -566,7 +671,8 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []float64{2, 3}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -580,7 +686,8 @@ func TestUnmarshalRowsString(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []string{"hello", "world"}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []string
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -596,7 +703,8 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*bool{&yes, &no}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*bool
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -612,7 +720,8 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*int{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -628,7 +737,8 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*int8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int8
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -644,7 +754,8 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*int16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int16
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -660,7 +771,8 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*int32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -676,7 +788,8 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*int64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -692,7 +805,8 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*uint{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -708,7 +822,8 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*uint8{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint8
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -724,7 +839,8 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*uint16{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint16
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -740,7 +856,8 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*uint32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -756,7 +873,8 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*uint64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -772,7 +890,8 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*float32{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float32
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -788,7 +907,8 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*float64{&two, &three}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float64
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -804,7 +924,8 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
expect := []*string{&hello, &world}
|
||||
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*string
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
@@ -835,7 +956,8 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
}
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
@@ -854,7 +976,8 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
colErr: errAny,
|
||||
@@ -871,7 +994,8 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
cols: []string{"name", "age"},
|
||||
@@ -886,7 +1010,8 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
|
||||
errAny := errors.New("any error")
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, &mockedScanner{
|
||||
cols: []string{"name", "age"},
|
||||
@@ -925,7 +1050,8 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
|
||||
"first", "firstnullstring").AddRow("second", nil)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
@@ -959,7 +1085,62 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
for i, each := range expect {
|
||||
assert.Equal(t, each.Name, value[i].Name)
|
||||
assert.Equal(t, each.Age, value[i].Age)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalRowsStructWithTagsIgnoreFields(t *testing.T) {
|
||||
expect := []struct {
|
||||
Name string
|
||||
Age int64
|
||||
Ignore bool
|
||||
}{
|
||||
{
|
||||
Name: "first",
|
||||
Age: 2,
|
||||
Ignore: false,
|
||||
},
|
||||
{
|
||||
Name: "second",
|
||||
Age: 3,
|
||||
Ignore: false,
|
||||
},
|
||||
}
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value []struct {
|
||||
Age int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
Ignore bool
|
||||
}
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
var value []struct {
|
||||
Age int64 `db:"age"`
|
||||
Name string `db:"name"`
|
||||
Ignore bool `db:"-"`
|
||||
}
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
@@ -1000,7 +1181,8 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
@@ -1042,7 +1224,8 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
@@ -1076,7 +1259,8 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
@@ -1109,7 +1293,8 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
@@ -1142,7 +1327,8 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
@@ -1157,7 +1343,8 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
||||
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var r struct {
|
||||
User string `db:"user"`
|
||||
@@ -1207,8 +1394,8 @@ func TestUnmarshalRowError(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
|
||||
"anyone").WillReturnRows(rs)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var r struct {
|
||||
User string `db:"user"`
|
||||
@@ -1307,25 +1494,25 @@ func TestAnonymousStructPr(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAnonymousStructPrError(t *testing.T) {
|
||||
type Score struct {
|
||||
Discipline string `db:"discipline"`
|
||||
score uint `db:"score"`
|
||||
}
|
||||
type ClassType struct {
|
||||
Grade sql.NullString `db:"grade"`
|
||||
ClassName *string `db:"class_name"`
|
||||
}
|
||||
type Class struct {
|
||||
*ClassType
|
||||
Score
|
||||
}
|
||||
var value []*struct {
|
||||
Age int64 `db:"age"`
|
||||
Class
|
||||
Name string `db:"name"`
|
||||
}
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
type Score struct {
|
||||
Discipline string `db:"discipline"`
|
||||
score uint `db:"score"`
|
||||
}
|
||||
type ClassType struct {
|
||||
Grade sql.NullString `db:"grade"`
|
||||
ClassName *string `db:"class_name"`
|
||||
}
|
||||
type Class struct {
|
||||
*ClassType
|
||||
Score
|
||||
}
|
||||
|
||||
var value []*struct {
|
||||
Age int64 `db:"age"`
|
||||
Class
|
||||
Name string `db:"name"`
|
||||
}
|
||||
rs := sqlmock.NewRows([]string{
|
||||
"name",
|
||||
"age",
|
||||
@@ -1338,14 +1525,82 @@ func TestAnonymousStructPrError(t *testing.T) {
|
||||
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||
"anyone"))
|
||||
"anyone"), ErrNotReadableValue)
|
||||
if len(value) > 0 {
|
||||
assert.Equal(t, value[0].score, 0)
|
||||
}
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
type Score struct {
|
||||
Discipline string
|
||||
score uint
|
||||
}
|
||||
type ClassType struct {
|
||||
Grade sql.NullString
|
||||
ClassName *string
|
||||
}
|
||||
type Class struct {
|
||||
*ClassType
|
||||
Score
|
||||
}
|
||||
|
||||
var value []*struct {
|
||||
Age int64
|
||||
Class
|
||||
Name string
|
||||
}
|
||||
rs := sqlmock.NewRows([]string{
|
||||
"name",
|
||||
"age",
|
||||
"grade",
|
||||
"discipline",
|
||||
"class_name",
|
||||
"score",
|
||||
}).
|
||||
AddRow("first", 2, nil, "math", "experimental class", 100).
|
||||
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, grade, discipline, class_name, score from users where user=?",
|
||||
"anyone"), ErrNotMatchDestination)
|
||||
if len(value) > 0 {
|
||||
assert.Equal(t, value[0].score, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkIgnore(b *testing.B) {
|
||||
db, mock, err := sqlmock.New()
|
||||
if err != nil {
|
||||
b.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = db.Close()
|
||||
}()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
value := new(struct {
|
||||
Age int `db:"age"`
|
||||
Name string
|
||||
Ignore bool `db:"-"`
|
||||
})
|
||||
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "ignore"}).FromCSVString("liao,5,true")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").
|
||||
WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(b, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(b, 5, value.Age)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
type mockedScanner struct {
|
||||
|
||||
65
core/stores/sqlx/rwstrategy.go
Normal file
65
core/stores/sqlx/rwstrategy.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package sqlx
|
||||
|
||||
import "context"
|
||||
|
||||
const (
|
||||
// policyRoundRobin round-robin policy for selecting replicas.
|
||||
policyRoundRobin = "round-robin"
|
||||
// policyRandom random policy for selecting replicas.
|
||||
policyRandom = "random"
|
||||
|
||||
// readPrimaryMode indicates that the operation is a read,
|
||||
// but should be performed on the primary database instance.
|
||||
//
|
||||
// This mode is used in scenarios where data freshness and consistency are critical,
|
||||
// such as immediately after writes or where replication lag may cause stale reads.
|
||||
readPrimaryMode readWriteMode = "read-primary"
|
||||
|
||||
// readReplicaMode indicates that the operation is a read from replicas.
|
||||
// This is suitable for scenarios where eventual consistency is acceptable,
|
||||
// and the goal is to offload traffic from the primary and improve read scalability.
|
||||
readReplicaMode readWriteMode = "read-replica"
|
||||
|
||||
// writeMode indicates that the operation is a write operation (to primary).
|
||||
writeMode readWriteMode = "write"
|
||||
|
||||
// notSpecifiedMode indicates that the read/write mode is not specified.
|
||||
notSpecifiedMode readWriteMode = ""
|
||||
)
|
||||
|
||||
type readWriteModeKey struct{}
|
||||
|
||||
// WithReadPrimary sets the context to read-primary mode.
|
||||
func WithReadPrimary(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, readWriteModeKey{}, readPrimaryMode)
|
||||
}
|
||||
|
||||
// WithReadReplica sets the context to read-replica mode.
|
||||
func WithReadReplica(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, readWriteModeKey{}, readReplicaMode)
|
||||
}
|
||||
|
||||
// WithWrite sets the context to write mode, indicating that the operation is a write operation.
|
||||
func WithWrite(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, readWriteModeKey{}, writeMode)
|
||||
}
|
||||
|
||||
type readWriteMode string
|
||||
|
||||
func (m readWriteMode) isValid() bool {
|
||||
return m == readPrimaryMode || m == readReplicaMode || m == writeMode
|
||||
}
|
||||
|
||||
func getReadWriteMode(ctx context.Context) readWriteMode {
|
||||
if mode := ctx.Value(readWriteModeKey{}); mode != nil {
|
||||
if v, ok := mode.(readWriteMode); ok && v.isValid() {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return notSpecifiedMode
|
||||
}
|
||||
|
||||
func usePrimary(ctx context.Context) bool {
|
||||
return getReadWriteMode(ctx) != readReplicaMode
|
||||
}
|
||||
142
core/stores/sqlx/rwstrategy_test.go
Normal file
142
core/stores/sqlx/rwstrategy_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsValid(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
mode readWriteMode
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid read-primary mode",
|
||||
mode: readPrimaryMode,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "valid read-replica mode",
|
||||
mode: readReplicaMode,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "valid write mode",
|
||||
mode: writeMode,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not specified mode (empty)",
|
||||
mode: notSpecifiedMode,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid custom string",
|
||||
mode: readWriteMode("delete"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "case sensitive check",
|
||||
mode: readWriteMode("READ"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actual := tc.mode.isValid()
|
||||
assert.Equal(t, tc.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithReadMode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
readPrimaryCtx := WithReadPrimary(ctx)
|
||||
|
||||
val := readPrimaryCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, readPrimaryMode, val)
|
||||
|
||||
readReplicaCtx := WithReadReplica(ctx)
|
||||
val = readReplicaCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, readReplicaMode, val)
|
||||
}
|
||||
|
||||
func TestWithWriteMode(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
writeCtx := WithWrite(ctx)
|
||||
|
||||
val := writeCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, writeMode, val)
|
||||
}
|
||||
|
||||
func TestGetReadWriteMode(t *testing.T) {
|
||||
t.Run("valid read-primary mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||
assert.Equal(t, readPrimaryMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("valid read-replica mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||
assert.Equal(t, readReplicaMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("valid write mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||
assert.Equal(t, writeMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("invalid mode value (wrong type)", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, "not-a-mode")
|
||||
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("invalid mode value (wrong value)", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("delete"))
|
||||
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||
})
|
||||
|
||||
t.Run("no mode set", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
assert.Equal(t, notSpecifiedMode, getReadWriteMode(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsePrimary(t *testing.T) {
|
||||
t.Run("context with read-replica mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readReplicaMode)
|
||||
assert.False(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with read-primary mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readPrimaryMode)
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with write mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, writeMode)
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with invalid mode", func(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), readWriteModeKey{}, readWriteMode("invalid"))
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
|
||||
t.Run("context with no mode set", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
assert.True(t, usePrimary(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithModeTwice(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = WithReadPrimary(ctx)
|
||||
writeCtx := WithWrite(ctx)
|
||||
|
||||
val := writeCtx.Value(readWriteModeKey{})
|
||||
assert.Equal(t, writeMode, val)
|
||||
}
|
||||
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
@@ -52,9 +55,10 @@ type (
|
||||
beginTx beginnable
|
||||
brk breaker.Breaker
|
||||
accept breaker.Acceptable
|
||||
index uint32
|
||||
}
|
||||
|
||||
connProvider func() (*sql.DB, error)
|
||||
connProvider func(ctx context.Context) (*sql.DB, error)
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
@@ -64,10 +68,41 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
// MustNewConn returns a SqlConn with the given SqlConf.
|
||||
func MustNewConn(c SqlConf, opts ...SqlOption) SqlConn {
|
||||
conn, err := NewConn(c, opts...)
|
||||
if err != nil {
|
||||
logx.Must(err)
|
||||
}
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
// NewConn returns a SqlConn with the given SqlConf.
|
||||
func NewConn(c SqlConf, opts ...SqlOption) (SqlConn, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := &commonSqlConn{
|
||||
onError: func(ctx context.Context, err error) {
|
||||
logInstanceError(ctx, c.DataSource, err)
|
||||
},
|
||||
beginTx: begin,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(conn)
|
||||
}
|
||||
conn.connProv = getConnProvider(conn, c.DriverName, c.DataSource, c.Policy, c.Replicas)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewSqlConn returns a SqlConn with given driver name and datasource.
|
||||
func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(context.Context) (*sql.DB, error) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
},
|
||||
onError: func(ctx context.Context, err error) {
|
||||
@@ -87,7 +122,7 @@ func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
|
||||
// Use it with caution; it's provided for other ORM to interact with.
|
||||
func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||
return db, nil
|
||||
},
|
||||
onError: func(ctx context.Context, err error) {
|
||||
@@ -123,7 +158,7 @@ func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...any) (
|
||||
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
conn, err = db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -151,7 +186,7 @@ func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt Stm
|
||||
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
conn, err = db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -242,7 +277,7 @@ func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v any,
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||
return db.connProv()
|
||||
return db.connProv(context.Background())
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
@@ -288,7 +323,7 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
q string, args ...any) (err error) {
|
||||
var scanFailed bool
|
||||
err = db.brk.DoWithAcceptableCtx(ctx, func() error {
|
||||
conn, err := db.connProv()
|
||||
conn, err := db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
@@ -311,6 +346,38 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
||||
return
|
||||
}
|
||||
|
||||
func getConnProvider(sc *commonSqlConn, driverName, datasource, policy string, replicas []string) connProvider {
|
||||
return func(ctx context.Context) (*sql.DB, error) {
|
||||
replicaCount := len(replicas)
|
||||
|
||||
if replicaCount == 0 || usePrimary(ctx) {
|
||||
return getSqlConn(driverName, datasource)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
|
||||
if replicaCount == 1 {
|
||||
dsn = replicas[0]
|
||||
} else {
|
||||
if len(policy) == 0 {
|
||||
policy = policyRoundRobin
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case policyRandom:
|
||||
dsn = replicas[rand.Intn(replicaCount)]
|
||||
case policyRoundRobin:
|
||||
index := atomic.AddUint32(&sc.index, 1) - 1
|
||||
dsn = replicas[index%uint32(replicaCount)]
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown policy: %s", policy)
|
||||
}
|
||||
}
|
||||
|
||||
return getSqlConn(driverName, dsn)
|
||||
}
|
||||
}
|
||||
|
||||
// WithAcceptable returns a SqlOption that setting the acceptable function.
|
||||
// acceptable is the func to check if the error can be accepted.
|
||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -98,7 +99,7 @@ func TestSqlConn_RawDB(t *testing.T) {
|
||||
func TestSqlConn_Errors(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
conn := NewSqlConnFromDB(db)
|
||||
conn.(*commonSqlConn).connProv = func() (*sql.DB, error) {
|
||||
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
_, err := conn.Prepare("any")
|
||||
@@ -138,6 +139,148 @@ func TestSqlConn_Errors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigSqlConn(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
mock.ExpectExec("any")
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}))
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||
|
||||
_, err = conn.Exec("any", "value")
|
||||
assert.NotNil(t, err)
|
||||
_, err = conn.Prepare("any")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var val string
|
||||
assert.NotNil(t, conn.QueryRow(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRowPartial(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRows(&val, "any"))
|
||||
assert.NotNil(t, conn.QueryRowsPartial(&val, "any"))
|
||||
}
|
||||
|
||||
func TestConfigSqlConnStatement(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
mock.ExpectPrepare("any")
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectPrepare("any")
|
||||
row := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(row)
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf, withMysqlAcceptable())
|
||||
stmt, err := conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := stmt.Exec()
|
||||
assert.NoError(t, err)
|
||||
lastInsertID, err := res.LastInsertId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), lastInsertID)
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), rowsAffected)
|
||||
|
||||
stmt, err = conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var val string
|
||||
err = stmt.QueryRow(&val)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bar", val)
|
||||
|
||||
mock.ExpectPrepare("any")
|
||||
rows := sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
|
||||
stmt, err = conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var vals []string
|
||||
assert.NoError(t, stmt.QueryRowsPartial(&vals))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
}
|
||||
|
||||
func TestConfigSqlConnQuery(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
t.Run("QueryRow", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var val string
|
||||
assert.NoError(t, conn.QueryRow(&val, "any"))
|
||||
assert.Equal(t, "bar", val)
|
||||
})
|
||||
|
||||
t.Run("QueryRowPartial", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"foo"}).AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var val string
|
||||
assert.NoError(t, conn.QueryRowPartial(&val, "any"))
|
||||
assert.Equal(t, "bar", val)
|
||||
})
|
||||
|
||||
t.Run("QueryRows", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var vals []string
|
||||
assert.NoError(t, conn.QueryRows(&vals, "any"))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
})
|
||||
|
||||
t.Run("QueryRowsPartial", func(t *testing.T) {
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"any"}).AddRow("foo").AddRow("bar"))
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
var vals []string
|
||||
assert.NoError(t, conn.QueryRowsPartial(&vals, "any"))
|
||||
assert.ElementsMatch(t, []string{"foo", "bar"}, vals)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfigSqlConnErr(t *testing.T) {
|
||||
t.Run("panic on empty config", func(t *testing.T) {
|
||||
original := logx.ExitOnFatal.True()
|
||||
logx.ExitOnFatal.Set(false)
|
||||
defer logx.ExitOnFatal.Set(original)
|
||||
|
||||
assert.Panics(t, func() {
|
||||
MustNewConn(SqlConf{})
|
||||
})
|
||||
})
|
||||
t.Run("on error", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
assert.NotNil(t, db)
|
||||
assert.NotNil(t, mock)
|
||||
assert.Nil(t, err)
|
||||
connManager.Inject(mockedDatasource, db)
|
||||
|
||||
conf := SqlConf{DataSource: mockedDatasource, DriverName: mysqlDriverName}
|
||||
conn := MustNewConn(conf)
|
||||
conn.(*commonSqlConn).connProv = func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("error")
|
||||
}
|
||||
_, err = conn.Prepare("any")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatement(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectPrepare("any").WillBeClosed()
|
||||
@@ -303,6 +446,93 @@ func TestWithAcceptable(t *testing.T) {
|
||||
assert.True(t, conn.accept(acceptableErr3))
|
||||
}
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
defer func() {
|
||||
_ = connManager.Close()
|
||||
}()
|
||||
|
||||
primaryDSN := "primary:password@tcp(127.0.0.1:3306)/primary_db"
|
||||
replicasDSN := []string{
|
||||
"replica_one:pwd@tcp(localhost:3306)/replica_one",
|
||||
"replica_two:pwd@tcp(localhost:3306)/replica_two",
|
||||
"replica_three:pwd@tcp(localhost:3306)/replica_three",
|
||||
}
|
||||
|
||||
primaryDB, err := connManager.GetResource(primaryDSN, func() (io.Closer, error) { return sql.Open(mysqlDriverName, primaryDSN) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, primaryDB)
|
||||
replicaOneDB, err := connManager.GetResource(replicasDSN[0], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[0]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaOneDB)
|
||||
replicaTwoDB, err := connManager.GetResource(replicasDSN[1], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[1]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaTwoDB)
|
||||
replicaThreeDB, err := connManager.GetResource(replicasDSN[2], func() (io.Closer, error) { return sql.Open(mysqlDriverName, replicasDSN[2]) })
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, replicaThreeDB)
|
||||
|
||||
sc := &commonSqlConn{}
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
db, err := sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithWrite(ctx)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithReadPrimary(ctx)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
// no mode set, should return primary
|
||||
ctx = context.Background()
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, primaryDB, db)
|
||||
|
||||
ctx = WithReadReplica(ctx)
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, []string{replicasDSN[0]})
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicaOneDB, db)
|
||||
|
||||
// default policy is round-robin
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRoundRobin, replicasDSN)
|
||||
replicas := []io.Closer{replicaOneDB, replicaTwoDB, replicaThreeDB}
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicas[i], db)
|
||||
}
|
||||
|
||||
// random policy
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, policyRandom, replicasDSN)
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Contains(t, replicas, db)
|
||||
}
|
||||
|
||||
// unknown policy
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "unknown", replicasDSN)
|
||||
_, err = sc.connProv(ctx)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// empty policy transforms to round-robin
|
||||
sc.connProv = getConnProvider(sc, mysqlDriverName, primaryDSN, "", replicasDSN)
|
||||
for i := 0; i < len(replicasDSN); i++ {
|
||||
db, err = sc.connProv(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, replicas[i], db)
|
||||
}
|
||||
}
|
||||
|
||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||
var db *sql.DB
|
||||
|
||||
@@ -27,7 +27,7 @@ func getCachedSqlConn(driverName, server string) (*sql.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if driverName != mysqlDriverName {
|
||||
if driverName == mysqlDriverName {
|
||||
if cfg, e := mysql.ParseDSN(server); e != nil {
|
||||
// if cannot parse, don't collect the metrics
|
||||
logx.Error(e)
|
||||
|
||||
@@ -156,7 +156,7 @@ func begin(db *sql.DB) (trans, error) {
|
||||
|
||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
||||
fn func(context.Context, Session) error) (err error) {
|
||||
conn, err := db.connProv()
|
||||
conn, err := db.connProv(ctx)
|
||||
if err != nil {
|
||||
db.onError(ctx, err)
|
||||
return err
|
||||
|
||||
@@ -117,7 +117,7 @@ func TestTxExceptions(t *testing.T) {
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
connProv: func(ctx context.Context) (*sql.DB, error) {
|
||||
return nil, errors.New("foo")
|
||||
},
|
||||
beginTx: begin,
|
||||
|
||||
@@ -2,6 +2,7 @@ package stringx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"slices"
|
||||
"unicode"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
@@ -15,14 +16,9 @@ var (
|
||||
)
|
||||
|
||||
// Contains checks if str is in list.
|
||||
// Deprecated: use slices.Contains instead.
|
||||
func Contains(list []string, str string) bool {
|
||||
for _, each := range list {
|
||||
if each == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return slices.Contains(list, str)
|
||||
}
|
||||
|
||||
// Filter filters chars from s with given filter function.
|
||||
@@ -123,11 +119,7 @@ func Remove(strings []string, strs ...string) []string {
|
||||
// Reverse reverses s.
|
||||
func Reverse(s string) string {
|
||||
runes := []rune(s)
|
||||
|
||||
for from, to := 0, len(runes)-1; from < to; from, to = from+1, to-1 {
|
||||
runes[from], runes[to] = runes[to], runes[from]
|
||||
}
|
||||
|
||||
slices.Reverse(runes)
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,28 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotEmpty(t *testing.T) {
|
||||
cases := []struct {
|
||||
args []string
|
||||
@@ -41,28 +63,6 @@ func TestNotEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsString(t *testing.T) {
|
||||
cases := []struct {
|
||||
slice []string
|
||||
value string
|
||||
expect bool
|
||||
}{
|
||||
{[]string{"1"}, "1", true},
|
||||
{[]string{"1"}, "2", false},
|
||||
{[]string{"1", "2"}, "1", true},
|
||||
{[]string{"1", "2"}, "3", false},
|
||||
{nil, "3", false},
|
||||
{nil, "", false},
|
||||
}
|
||||
|
||||
for _, each := range cases {
|
||||
t.Run(path.Join(each.slice...), func(t *testing.T) {
|
||||
actual := Contains(each.slice, each.value)
|
||||
assert.Equal(t, each.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
|
||||
@@ -47,30 +47,31 @@ func (ir *ImmutableResource) Get() (any, error) {
|
||||
return resource, nil
|
||||
}
|
||||
|
||||
ir.maybeRefresh(func() {
|
||||
res, err := ir.fetch()
|
||||
ir.lock.Lock()
|
||||
if err != nil {
|
||||
ir.err = err
|
||||
} else {
|
||||
ir.resource, ir.err = res, nil
|
||||
}
|
||||
ir.lock.Unlock()
|
||||
})
|
||||
ir.lock.Lock()
|
||||
defer ir.lock.Unlock()
|
||||
|
||||
ir.lock.RLock()
|
||||
resource, err := ir.resource, ir.err
|
||||
ir.lock.RUnlock()
|
||||
return resource, err
|
||||
// double check
|
||||
if ir.resource != nil {
|
||||
return ir.resource, nil
|
||||
}
|
||||
if ir.err != nil && !ir.shouldRefresh() {
|
||||
return ir.resource, ir.err
|
||||
}
|
||||
|
||||
res, err := ir.fetch()
|
||||
ir.lastTime.Set(timex.Now())
|
||||
if err != nil {
|
||||
ir.err = err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ir.resource, ir.err = res, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (ir *ImmutableResource) maybeRefresh(execute func()) {
|
||||
now := timex.Now()
|
||||
func (ir *ImmutableResource) shouldRefresh() bool {
|
||||
lastTime := ir.lastTime.Load()
|
||||
if lastTime == 0 || lastTime+ir.refreshInterval < now {
|
||||
ir.lastTime.Set(now)
|
||||
execute()
|
||||
}
|
||||
return lastTime == 0 || lastTime+ir.refreshInterval < timex.Now()
|
||||
}
|
||||
|
||||
// WithRefreshIntervalOnFailure sets refresh interval on failure.
|
||||
|
||||
@@ -2,6 +2,8 @@ package syncx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -56,6 +58,50 @@ func TestImmutableResourceError(t *testing.T) {
|
||||
assert.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
// It's hard to test more than one goroutine fetching the resource at the same time,
|
||||
// because it's difficult to make more than one goroutine to pass the first read lock
|
||||
// and wait another to pass the read lock before it gets the write lock.
|
||||
func TestImmutableResourceConcurrent(t *testing.T) {
|
||||
const message = "hello"
|
||||
var count int32
|
||||
ready := make(chan struct{})
|
||||
r := NewImmutableResource(func() (any, error) {
|
||||
atomic.AddInt32(&count, 1)
|
||||
close(ready) // signal that fetch started
|
||||
time.Sleep(10 * time.Millisecond) // simulate slow fetch
|
||||
return message, nil
|
||||
})
|
||||
|
||||
const goroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([]any, goroutines)
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
res, err := r.Get()
|
||||
results[idx] = res
|
||||
errs[idx] = err
|
||||
}(i)
|
||||
}
|
||||
|
||||
// wait for fetch to start
|
||||
<-ready
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// fetch should only be called once despite concurrent access
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&count))
|
||||
|
||||
// all goroutines should eventually get the same result
|
||||
for i := 0; i < goroutines; i++ {
|
||||
assert.Nil(t, errs[i])
|
||||
assert.Equal(t, message, results[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestImmutableResourceErrorRefreshAlways(t *testing.T) {
|
||||
var count int
|
||||
r := NewImmutableResource(func() (any, error) {
|
||||
|
||||
@@ -3,9 +3,7 @@ package syncx
|
||||
import "sync"
|
||||
|
||||
// Once returns a func that guarantees fn can only called once.
|
||||
// Deprecated: use sync.OnceFunc instead.
|
||||
func Once(fn func()) func() {
|
||||
once := new(sync.Once)
|
||||
return func() {
|
||||
once.Do(fn)
|
||||
}
|
||||
return sync.OnceFunc(fn)
|
||||
}
|
||||
|
||||
@@ -100,6 +100,34 @@ func (p *Pool) Put(x any) {
|
||||
p.cond.Signal()
|
||||
}
|
||||
|
||||
// DestroyAll destroys all resources in the pool.
|
||||
// It calls the destroy function on each resource and resets the pool state.
|
||||
// This is useful when you need to forcefully clean up all resources, for example:
|
||||
// - When removing an obsolete pool
|
||||
// - When refreshing all resources after configuration changes
|
||||
// - When avoiding resource leaks in dynamic pool scenarios
|
||||
func (p *Pool) DestroyAll() {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
// Iterate through the linked list and destroy all resources
|
||||
current := p.head
|
||||
for current != nil {
|
||||
next := current.next
|
||||
if p.destroy != nil {
|
||||
p.destroy(current.item)
|
||||
}
|
||||
current = next
|
||||
}
|
||||
|
||||
// Reset pool state
|
||||
p.head = nil
|
||||
p.created = 0
|
||||
|
||||
// Wake up all waiting goroutines since the pool is now empty
|
||||
p.cond.Broadcast()
|
||||
}
|
||||
|
||||
// WithMaxAge returns a function to customize a Pool with given max age.
|
||||
func WithMaxAge(duration time.Duration) PoolOption {
|
||||
return func(pool *Pool) {
|
||||
|
||||
@@ -107,6 +107,155 @@ func TestNewPoolPanics(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPoolDestroyAll(t *testing.T) {
|
||||
var destroyed []int
|
||||
var destroyCount int32
|
||||
|
||||
destroyFunc := func(item any) {
|
||||
destroyed = append(destroyed, item.(int))
|
||||
atomic.AddInt32(&destroyCount, 1)
|
||||
}
|
||||
|
||||
pool := NewPool(limit, create, destroyFunc)
|
||||
|
||||
// Put some resources into the pool
|
||||
pool.Put(10)
|
||||
pool.Put(20)
|
||||
pool.Put(30)
|
||||
|
||||
// Destroy all resources
|
||||
pool.DestroyAll()
|
||||
|
||||
// Verify all resources were destroyed
|
||||
assert.Equal(t, int32(3), atomic.LoadInt32(&destroyCount))
|
||||
assert.Contains(t, destroyed, 10)
|
||||
assert.Contains(t, destroyed, 20)
|
||||
assert.Contains(t, destroyed, 30)
|
||||
|
||||
// Verify pool is empty - next Get should create new resource
|
||||
val := pool.Get()
|
||||
assert.Equal(t, 1, val) // create() returns 1
|
||||
}
|
||||
|
||||
func TestPoolDestroyAllEmpty(t *testing.T) {
|
||||
var destroyCount int32
|
||||
destroyFunc := func(_ any) {
|
||||
atomic.AddInt32(&destroyCount, 1)
|
||||
}
|
||||
|
||||
pool := NewPool(limit, create, destroyFunc)
|
||||
|
||||
// DestroyAll on empty pool should not panic
|
||||
pool.DestroyAll()
|
||||
|
||||
// No resources should have been destroyed
|
||||
assert.Equal(t, int32(0), atomic.LoadInt32(&destroyCount))
|
||||
|
||||
// Pool should still work normally
|
||||
val := pool.Get()
|
||||
assert.Equal(t, 1, val)
|
||||
}
|
||||
|
||||
func TestPoolDestroyAllWithNilDestroy(t *testing.T) {
|
||||
pool := NewPool(limit, create, nil)
|
||||
|
||||
// Put some resources into the pool
|
||||
pool.Put(10)
|
||||
pool.Put(20)
|
||||
|
||||
// DestroyAll with nil destroy function should not panic
|
||||
pool.DestroyAll()
|
||||
|
||||
// Pool should be empty and work normally
|
||||
val := pool.Get()
|
||||
assert.Equal(t, 1, val)
|
||||
}
|
||||
|
||||
func TestPoolDestroyAllConcurrency(t *testing.T) {
|
||||
var destroyCount int32
|
||||
var createCount int32
|
||||
|
||||
createFunc := func() any {
|
||||
return atomic.AddInt32(&createCount, 1)
|
||||
}
|
||||
|
||||
destroyFunc := func(_ any) {
|
||||
atomic.AddInt32(&destroyCount, 1)
|
||||
}
|
||||
|
||||
pool := NewPool(limit, createFunc, destroyFunc)
|
||||
|
||||
// Add some initial resources
|
||||
for i := 0; i < 5; i++ {
|
||||
pool.Put(i + 100)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 10
|
||||
|
||||
// Concurrently perform various operations
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
switch id % 4 {
|
||||
case 0:
|
||||
// DestroyAll
|
||||
pool.DestroyAll()
|
||||
case 1:
|
||||
// Get resources
|
||||
val := pool.Get()
|
||||
pool.Put(val)
|
||||
case 2:
|
||||
// Put resources
|
||||
pool.Put(id + 1000)
|
||||
case 3:
|
||||
// Get and don't put back
|
||||
pool.Get()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Final DestroyAll to clean up
|
||||
pool.DestroyAll()
|
||||
|
||||
// Pool should work after concurrent operations
|
||||
val := pool.Get()
|
||||
assert.NotNil(t, val)
|
||||
}
|
||||
|
||||
func TestPoolDestroyAllWakesWaitingGoroutines(t *testing.T) {
|
||||
pool := NewPool(1, create, destroy) // Small pool size
|
||||
|
||||
// Fill the pool
|
||||
resource := pool.Get()
|
||||
assert.Equal(t, 1, resource)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var gotResource bool
|
||||
|
||||
// Start a goroutine that will wait for a resource
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
val := pool.Get() // This will block since pool is full
|
||||
gotResource = true
|
||||
assert.Equal(t, 1, val) // Should get a newly created resource after DestroyAll
|
||||
}()
|
||||
|
||||
// Give the goroutine time to start waiting
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// DestroyAll should wake up the waiting goroutine
|
||||
pool.DestroyAll()
|
||||
|
||||
wg.Wait()
|
||||
assert.True(t, gotResource)
|
||||
}
|
||||
|
||||
func create() any {
|
||||
return 1
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const factor = 10
|
||||
@@ -100,6 +101,6 @@ func (r *StableRunner[I, O]) Wait() {
|
||||
close(r.done)
|
||||
r.runner.Wait()
|
||||
for atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
|
||||
runtime.Gosched()
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestExtractValidTraceContext(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "invalid tracestate perserves traceparent",
|
||||
name: "invalid tracestate preserves traceparent",
|
||||
traceparent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00",
|
||||
tracestate: "invalid$@#=invalid",
|
||||
sc: trace.NewSpanContext(trace.SpanContextConfig{
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ func compare(v1, v2 string) int {
|
||||
fields1, fields2 := strings.Split(v1, "."), strings.Split(v2, ".")
|
||||
ver1, ver2 := strsToInts(fields1), strsToInts(fields2)
|
||||
ver1len, ver2len := len(ver1), len(ver2)
|
||||
shorter := mathx.MinInt(ver1len, ver2len)
|
||||
shorter := min(ver1len, ver2len)
|
||||
|
||||
for i := 0; i < shorter; i++ {
|
||||
if ver1[i] == ver2[i] {
|
||||
@@ -50,14 +50,7 @@ func compare(v1, v2 string) int {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
if ver1len < ver2len {
|
||||
return -1
|
||||
} else if ver1len == ver2len {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
return cmp.Compare(ver1len, ver2len)
|
||||
}
|
||||
|
||||
func strsToInts(strs []string) []int64 {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user