chore: refactor mon/monc (#5073)

This commit is contained in:
Kevin Wan
2025-08-09 23:51:44 +08:00
committed by GitHub
parent b41b1b00df
commit 1ebbc6f0c7
9 changed files with 172 additions and 162 deletions

View File

@@ -1,4 +1,4 @@
//go:generate mockgen -package mon -destination collection_inserter_mock.go -source bulkinserter.go collectionInserter //go:generate mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
package mon package mon
import ( import (

View File

@@ -137,14 +137,6 @@ func newCollection(collection *mongo.Collection, brk breaker.Breaker) Collection
} }
} }
func newTestCollection(collection monCollection, brk breaker.Breaker) *decoratedCollection {
return &decoratedCollection{
Collection: collection,
name: "test",
brk: brk,
}
}
func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any, func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any,
opts ...options.Lister[options.AggregateOptions]) (cur *mongo.Cursor, err error) { opts ...options.Lister[options.AggregateOptions]) (cur *mongo.Cursor, err error) {
ctx, span := startSpan(ctx, aggregate) ctx, span := startSpan(ctx, aggregate)
@@ -185,6 +177,10 @@ func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.Writ
return return
} }
func (c *decoratedCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
return c.Collection.Clone(opts...)
}
func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any, func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any,
opts ...options.Lister[options.CountOptions]) (count int64, err error) { opts ...options.Lister[options.CountOptions]) (count int64, err error) {
ctx, span := startSpan(ctx, countDocuments) ctx, span := startSpan(ctx, countDocuments)
@@ -205,6 +201,10 @@ func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any,
return return
} }
func (c *decoratedCollection) Database() *mongo.Database {
return c.Collection.Database()
}
func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any, func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any,
opts ...options.Lister[options.DeleteManyOptions]) (res *mongo.DeleteResult, err error) { opts ...options.Lister[options.DeleteManyOptions]) (res *mongo.DeleteResult, err error) {
ctx, span := startSpan(ctx, deleteMany) ctx, span := startSpan(ctx, deleteMany)
@@ -266,6 +266,10 @@ func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, fi
return return
} }
func (c *decoratedCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
return c.Collection.Drop(ctx, opts...)
}
func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context, func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context,
opts ...options.Lister[options.EstimatedDocumentCountOptions]) (val int64, err error) { opts ...options.Lister[options.EstimatedDocumentCountOptions]) (val int64, err error) {
ctx, span := startSpan(ctx, estimatedDocumentCount) ctx, span := startSpan(ctx, estimatedDocumentCount)
@@ -391,6 +395,10 @@ func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter, upda
return return
} }
func (c *decoratedCollection) Indexes() mongo.IndexView {
return c.Collection.Indexes()
}
func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any, func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any,
opts ...options.Lister[options.InsertManyOptions]) (res *mongo.InsertManyResult, err error) { opts ...options.Lister[options.InsertManyOptions]) (res *mongo.InsertManyResult, err error) {
ctx, span := startSpan(ctx, insertMany) ctx, span := startSpan(ctx, insertMany)
@@ -511,22 +519,6 @@ func (c *decoratedCollection) UpdateOne(ctx context.Context, filter, update any,
return return
} }
func (c *decoratedCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection {
return c.Collection.Clone(opts...)
}
func (c *decoratedCollection) Database() *mongo.Database {
return c.Collection.Database()
}
func (c *decoratedCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error {
return c.Collection.Drop(ctx, opts...)
}
func (c *decoratedCollection) Indexes() mongo.IndexView {
return c.Collection.Indexes()
}
func (c *decoratedCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) ( func (c *decoratedCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
*mongo.ChangeStream, error) { *mongo.ChangeStream, error) {
return c.Collection.Watch(ctx, pipeline, opts...) return c.Collection.Watch(ctx, pipeline, opts...)
@@ -578,72 +570,70 @@ func isDupKeyError(err error) bool {
return e.HasErrorCode(duplicateKeyCode) return e.HasErrorCode(duplicateKeyCode)
} }
type ( // monCollection defines a MongoDB collection, used for unit test
// monCollection defines a MongoDB collection, used for unit test type monCollection interface {
monCollection interface { // Aggregate executes an aggregation pipeline.
// Aggregate executes an aggregation pipeline. Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (
Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) ( *mongo.Cursor, error)
*mongo.Cursor, error) // BulkWrite performs a bulk write operation.
// BulkWrite performs a bulk write operation. BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) (
BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) ( *mongo.BulkWriteResult, error)
*mongo.BulkWriteResult, error) // Clone creates a copy of this collection with the same settings.
// Clone creates a copy of this collection with the same settings. Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection
Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection // CountDocuments returns the number of documents in the collection that match the filter.
// CountDocuments returns the number of documents in the collection that match the filter. CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error)
CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) // Database returns the database that this collection is a part of.
// Database returns the database that this collection is a part of. Database() *mongo.Database
Database() *mongo.Database // DeleteMany deletes documents from the collection that match the filter.
// DeleteMany deletes documents from the collection that match the filter. DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (
DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) ( *mongo.DeleteResult, error)
*mongo.DeleteResult, error) // DeleteOne deletes at most one document from the collection that matches the filter.
// DeleteOne deletes at most one document from the collection that matches the filter. DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (
DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) ( *mongo.DeleteResult, error)
*mongo.DeleteResult, error) // Distinct returns a list of distinct values for the given key across the collection.
// Distinct returns a list of distinct values for the given key across the collection. Distinct(ctx context.Context, fieldName string, filter any,
Distinct(ctx context.Context, fieldName string, filter any, opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult
opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult // Drop drops this collection from database.
// Drop drops this collection from database. Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error
Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error // EstimatedDocumentCount returns an estimate of the count of documents in a collection
// EstimatedDocumentCount returns an estimate of the count of documents in a collection // using collection metadata.
// using collection metadata. EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error)
EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) // Find finds the documents matching the provided filter.
// Find finds the documents matching the provided filter. Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error)
Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) // FindOne returns up to one document that matches the provided filter.
// FindOne returns up to one document that matches the provided filter. FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult
FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult // FindOneAndDelete returns at most one document that matches the filter. If the filter
// FindOneAndDelete returns at most one document that matches the filter. If the filter // matches multiple documents, only the first document is deleted.
// matches multiple documents, only the first document is deleted. FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult
FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult // FindOneAndReplace returns at most one document that matches the filter. If the filter
// FindOneAndReplace returns at most one document that matches the filter. If the filter // matches multiple documents, FindOneAndReplace returns the first document in the
// matches multiple documents, FindOneAndReplace returns the first document in the // collection that matches the filter.
// collection that matches the filter. FindOneAndReplace(ctx context.Context, filter, replacement any,
FindOneAndReplace(ctx context.Context, filter, replacement any, opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult
opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult // FindOneAndUpdate returns at most one document that matches the filter. If the filter
// FindOneAndUpdate returns at most one document that matches the filter. If the filter // matches multiple documents, FindOneAndUpdate returns the first document in the
// matches multiple documents, FindOneAndUpdate returns the first document in the // collection that matches the filter.
// collection that matches the filter. FindOneAndUpdate(ctx context.Context, filter, update any,
FindOneAndUpdate(ctx context.Context, filter, update any, opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult
opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult // Indexes returns the index view for this collection.
// Indexes returns the index view for this collection. Indexes() mongo.IndexView
Indexes() mongo.IndexView // InsertMany inserts the provided documents.
// InsertMany inserts the provided documents. InsertMany(ctx context.Context, documents interface{}, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error)
InsertMany(ctx context.Context, documents interface{}, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) // InsertOne inserts the provided document.
// InsertOne inserts the provided document. InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error)
InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) // ReplaceOne replaces at most one document that matches the filter.
// ReplaceOne replaces at most one document that matches the filter. ReplaceOne(ctx context.Context, filter, replacement any,
ReplaceOne(ctx context.Context, filter, replacement any, opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error)
opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) // UpdateByID updates a single document matching the provided filter.
// UpdateByID updates a single document matching the provided filter. UpdateByID(ctx context.Context, id, update any,
UpdateByID(ctx context.Context, id, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) // UpdateMany updates the provided documents.
// UpdateMany updates the provided documents. UpdateMany(ctx context.Context, filter, update any,
UpdateMany(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error)
opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) // UpdateOne updates a single document matching the provided filter.
// UpdateOne updates a single document matching the provided filter. UpdateOne(ctx context.Context, filter, update any,
UpdateOne(ctx context.Context, filter, update any, opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error)
opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) // Watch returns a change stream cursor used to receive notifications of changes to the collection.
// Watch returns a change stream cursor used to receive notifications of changes to the collection. Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) (
Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) ( *mongo.ChangeStream, error)
*mongo.ChangeStream, error) }
}
)

View File

@@ -12,7 +12,7 @@ import (
"github.com/zeromicro/go-zero/core/timex" "github.com/zeromicro/go-zero/core/timex"
"go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo"
mopt "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/options"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
) )
@@ -75,7 +75,7 @@ func TestCollection_Aggregate(t *testing.T) {
mockCollection := NewMockmonCollection(ctrl) mockCollection := NewMockmonCollection(ctrl)
mockCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil) mockCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
_, err := c.Aggregate(context.Background(), []interface{}{}, mopt.Aggregate()) _, err := c.Aggregate(context.Background(), []interface{}{}, options.Aggregate())
assert.Nil(t, err) assert.Nil(t, err)
} }
@@ -156,10 +156,10 @@ func TestCollection_Find(t *testing.T) {
mockCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil) mockCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil)
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{{Key: "x", Value: 1}} filter := bson.D{{Key: "x", Value: 1}}
_, err := c.Find(context.Background(), filter, mopt.Find()) _, err := c.Find(context.Background(), filter, options.Find())
assert.Nil(t, err) assert.Nil(t, err)
c.brk = new(dropBreaker) c.brk = new(dropBreaker)
_, err = c.Find(context.Background(), filter, mopt.Find()) _, err = c.Find(context.Background(), filter, options.Find())
assert.Equal(t, errDummy, err) assert.Equal(t, errDummy, err)
} }
@@ -185,9 +185,9 @@ func TestCollection_FindOneAndDelete(t *testing.T) {
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{} filter := bson.D{}
mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{}) mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
_, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete()) _, err := c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete())
assert.Equal(t, mongo.ErrNoDocuments, err) assert.Equal(t, mongo.ErrNoDocuments, err)
_, err = c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete()) _, err = c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete())
assert.Equal(t, mongo.ErrNoDocuments, err) assert.Equal(t, mongo.ErrNoDocuments, err)
c.brk = new(dropBreaker) c.brk = new(dropBreaker)
_, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}}) _, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}})
@@ -202,7 +202,7 @@ func TestCollection_FindOneAndReplace(t *testing.T) {
c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) c := newTestCollection(mockCollection, breaker.GetBreaker("localhost"))
filter := bson.D{{Key: "x", Value: 1}} filter := bson.D{{Key: "x", Value: 1}}
replacement := bson.D{{Key: "x", Value: 2}} replacement := bson.D{{Key: "x", Value: 2}}
opts := mopt.FindOneAndReplace().SetUpsert(true) opts := options.FindOneAndReplace().SetUpsert(true)
mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{}) mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
_, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts) _, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts)
assert.Equal(t, mongo.ErrNoDocuments, err) assert.Equal(t, mongo.ErrNoDocuments, err)
@@ -222,7 +222,7 @@ func TestCollection_FindOneAndUpdate(t *testing.T) {
filter := bson.D{{Key: "x", Value: 1}} filter := bson.D{{Key: "x", Value: 1}}
update := bson.D{{Key: "$x", Value: 2}} update := bson.D{{Key: "$x", Value: 2}}
mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{}) mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{})
opts := mopt.FindOneAndUpdate().SetUpsert(true) opts := options.FindOneAndUpdate().SetUpsert(true)
_, err := c.FindOneAndUpdate(context.Background(), filter, update, opts) _, err := c.FindOneAndUpdate(context.Background(), filter, update, opts)
assert.Equal(t, mongo.ErrNoDocuments, err) assert.Equal(t, mongo.ErrNoDocuments, err)
_, err = c.FindOneAndUpdate(context.Background(), filter, update, opts) _, err = c.FindOneAndUpdate(context.Background(), filter, update, opts)
@@ -487,6 +487,14 @@ func TestIsDupKeyError(t *testing.T) {
} }
} }
func newTestCollection(collection monCollection, brk breaker.Breaker) *decoratedCollection {
return &decoratedCollection{
Collection: collection,
name: "test",
brk: brk,
}
}
type mockPromise struct { type mockPromise struct {
accepted bool accepted bool
reason string reason string

View File

@@ -3,7 +3,7 @@
// //
// Generated by this command: // Generated by this command:
// //
// mockgen -package mon -destination collection_inserter_mock.go -source bulkinserter.go collectionInserter // mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter
// //
// Package mon is a generated GoMock package. // Package mon is a generated GoMock package.

View File

@@ -30,7 +30,7 @@ type (
opts []Option opts []Option
} }
WrappedSession struct { Session struct {
session monSession session monSession
name string name string
brk breaker.Breaker brk breaker.Breaker
@@ -62,25 +62,14 @@ func newModel(name string, cli *mongo.Client, coll Collection, brk breaker.Break
return &Model{ return &Model{
name: name, name: name,
Collection: coll, Collection: coll,
cli: &mockMonClient{c: cli}, cli: &wrappedMonClient{c: cli},
brk: brk,
opts: opts,
}
}
func newTestModel(name string, cli monClient, coll monCollection, brk breaker.Breaker,
opts ...Option) *Model {
return &Model{
name: name,
Collection: newTestCollection(coll, breaker.GetBreaker("localhost")),
cli: cli,
brk: brk, brk: brk,
opts: opts, opts: opts,
} }
} }
// StartSession starts a new session. // StartSession starts a new session.
func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (sess *WrappedSession, err error) { func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (sess *Session, err error) {
starTime := timex.Now() starTime := timex.Now()
defer func() { defer func() {
logDuration(context.Background(), m.name, startSession, starTime, err) logDuration(context.Background(), m.name, startSession, starTime, err)
@@ -91,7 +80,7 @@ func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (se
return nil, sessionErr return nil, sessionErr
} }
return &WrappedSession{ return &Session{
session: session, session: session,
name: m.name, name: m.name,
brk: m.brk, brk: m.brk,
@@ -99,7 +88,8 @@ func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (se
} }
// Aggregate executes an aggregation pipeline. // Aggregate executes an aggregation pipeline.
func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...options.Lister[options.AggregateOptions]) error { func (m *Model) Aggregate(ctx context.Context, v, pipeline any,
opts ...options.Lister[options.AggregateOptions]) error {
cur, err := m.Collection.Aggregate(ctx, pipeline, opts...) cur, err := m.Collection.Aggregate(ctx, pipeline, opts...)
if err != nil { if err != nil {
return err return err
@@ -110,7 +100,8 @@ func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...options.
} }
// DeleteMany deletes documents that match the filter. // DeleteMany deletes documents that match the filter.
func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (int64, error) { func (m *Model) DeleteMany(ctx context.Context, filter any,
opts ...options.Lister[options.DeleteManyOptions]) (int64, error) {
res, err := m.Collection.DeleteMany(ctx, filter, opts...) res, err := m.Collection.DeleteMany(ctx, filter, opts...)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -120,7 +111,8 @@ func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...options.List
} }
// DeleteOne deletes the first document that matches the filter. // DeleteOne deletes the first document that matches the filter.
func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (int64, error) { func (m *Model) DeleteOne(ctx context.Context, filter any,
opts ...options.Lister[options.DeleteOneOptions]) (int64, error) {
res, err := m.Collection.DeleteOne(ctx, filter, opts...) res, err := m.Collection.DeleteOne(ctx, filter, opts...)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -130,7 +122,8 @@ func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...options.Liste
} }
// Find finds documents that match the filter. // Find finds documents that match the filter.
func (m *Model) Find(ctx context.Context, v, filter any, opts ...options.Lister[options.FindOptions]) error { func (m *Model) Find(ctx context.Context, v, filter any,
opts ...options.Lister[options.FindOptions]) error {
cur, err := m.Collection.Find(ctx, filter, opts...) cur, err := m.Collection.Find(ctx, filter, opts...)
if err != nil { if err != nil {
return err return err
@@ -141,7 +134,8 @@ func (m *Model) Find(ctx context.Context, v, filter any, opts ...options.Lister[
} }
// FindOne finds the first document that matches the filter. // FindOne finds the first document that matches the filter.
func (m *Model) FindOne(ctx context.Context, v, filter any, opts ...options.Lister[options.FindOneOptions]) error { func (m *Model) FindOne(ctx context.Context, v, filter any,
opts ...options.Lister[options.FindOneOptions]) error {
res, err := m.Collection.FindOne(ctx, filter, opts...) res, err := m.Collection.FindOne(ctx, filter, opts...)
if err != nil { if err != nil {
return err return err
@@ -184,7 +178,7 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter, update any,
} }
// AbortTransaction implements the mongo.session interface. // AbortTransaction implements the mongo.session interface.
func (w *WrappedSession) AbortTransaction(ctx context.Context) (err error) { func (w *Session) AbortTransaction(ctx context.Context) (err error) {
ctx, span := startSpan(ctx, abortTransaction) ctx, span := startSpan(ctx, abortTransaction)
defer func() { defer func() {
endSpan(span, err) endSpan(span, err)
@@ -201,7 +195,7 @@ func (w *WrappedSession) AbortTransaction(ctx context.Context) (err error) {
} }
// CommitTransaction implements the mongo.session interface. // CommitTransaction implements the mongo.session interface.
func (w *WrappedSession) CommitTransaction(ctx context.Context) (err error) { func (w *Session) CommitTransaction(ctx context.Context) (err error) {
ctx, span := startSpan(ctx, commitTransaction) ctx, span := startSpan(ctx, commitTransaction)
defer func() { defer func() {
endSpan(span, err) endSpan(span, err)
@@ -218,7 +212,7 @@ func (w *WrappedSession) CommitTransaction(ctx context.Context) (err error) {
} }
// WithTransaction implements the mongo.session interface. // WithTransaction implements the mongo.session interface.
func (w *WrappedSession) WithTransaction( func (w *Session) WithTransaction(
ctx context.Context, ctx context.Context,
fn func(sessCtx context.Context) (any, error), fn func(sessCtx context.Context) (any, error),
opts ...options.Lister[options.TransactionOptions], opts ...options.Lister[options.TransactionOptions],
@@ -242,7 +236,7 @@ func (w *WrappedSession) WithTransaction(
} }
// EndSession implements the mongo.session interface. // EndSession implements the mongo.session interface.
func (w *WrappedSession) EndSession(ctx context.Context) { func (w *Session) EndSession(ctx context.Context) {
var err error var err error
ctx, span := startSpan(ctx, endSession) ctx, span := startSpan(ctx, endSession)
defer func() { defer func() {
@@ -261,10 +255,11 @@ func (w *WrappedSession) EndSession(ctx context.Context) {
} }
type ( type (
//for unit test // for unit test
monClient interface { monClient interface {
StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error) StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error)
} }
monSession interface { monSession interface {
AbortTransaction(ctx context.Context) error AbortTransaction(ctx context.Context) error
CommitTransaction(ctx context.Context) error CommitTransaction(ctx context.Context) error
@@ -274,10 +269,14 @@ type (
} }
) )
type mockMonClient struct { type wrappedMonClient struct {
c *mongo.Client c *mongo.Client
} }
func (m *mockMonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error) { // StartSession starts a new session using the underlying *mongo.Client.
// It implements the monClient interface.
// This is used to allow mocking in unit tests.
func (m *wrappedMonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (
monSession, error) {
return m.c.StartSession(opts...) return m.c.StartSession(opts...)
} }

View File

@@ -20,7 +20,7 @@ func TestModel_StartSession(t *testing.T) {
mockMonCollection := NewMockmonCollection(ctrl) mockMonCollection := NewMockmonCollection(ctrl)
mockedMonClient := NewMockmonClient(ctrl) mockedMonClient := NewMockmonClient(ctrl)
mockMonSession := NewMockmonSession(ctrl) mockMonSession := NewMockmonSession(ctrl)
warpSession := &WrappedSession{ warpSession := &Session{
session: mockMonSession, session: mockMonSession,
name: "", name: "",
brk: breaker.GetBreaker("localhost"), brk: breaker.GetBreaker("localhost"),
@@ -39,9 +39,9 @@ func TestModel_StartSession(t *testing.T) {
mockMonSession.EXPECT().AbortTransaction(gomock.Any()).Return(nil) mockMonSession.EXPECT().AbortTransaction(gomock.Any()).Return(nil)
mockMonSession.EXPECT().EndSession(gomock.Any()) mockMonSession.EXPECT().EndSession(gomock.Any())
_, err = sess.WithTransaction(context.Background(), func(sessCtx context.Context) (any, error) { _, err = sess.WithTransaction(context.Background(), func(sessCtx context.Context) (any, error) {
//_ = sessCtx.StartTransaction() // _ = sessCtx.StartTransaction()
//sessCtx.Client().Database("1") // sessCtx.Client().Database("1")
//sessCtx.EndSession(context.Background()) // sessCtx.EndSession(context.Background())
return nil, nil return nil, nil
}) })
assert.Nil(t, err) assert.Nil(t, err)
@@ -224,9 +224,20 @@ func Test_mockMonClient_StartSession(t *testing.T) {
opts.Deployment = md opts.Deployment = md
client, err := mongo.Connect(opts) client, err := mongo.Connect(opts)
assert.Nil(t, err) assert.Nil(t, err)
m := mockMonClient{ m := wrappedMonClient{
c: client, c: client,
} }
_, err = m.StartSession() _, err = m.StartSession()
assert.Nil(t, err) assert.Nil(t, err)
} }
func newTestModel(name string, cli monClient, coll monCollection, brk breaker.Breaker,
opts ...Option) *Model {
return &Model{
name: name,
Collection: newTestCollection(coll, breaker.GetBreaker("localhost")),
cli: cli,
brk: brk,
opts: opts,
}
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/bson"
mopt "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/options"
) )
func TestSetSlowThreshold(t *testing.T) { func TestSetSlowThreshold(t *testing.T) {
@@ -18,13 +18,13 @@ func TestSetSlowThreshold(t *testing.T) {
} }
func Test_defaultTimeoutOption(t *testing.T) { func Test_defaultTimeoutOption(t *testing.T) {
opts := mopt.Client() opts := options.Client()
defaultTimeoutOption()(opts) defaultTimeoutOption()(opts)
assert.Equal(t, defaultTimeout, *opts.Timeout) assert.Equal(t, defaultTimeout, *opts.Timeout)
} }
func TestWithTimeout(t *testing.T) { func TestWithTimeout(t *testing.T) {
opts := mopt.Client() opts := options.Client()
WithTimeout(time.Second)(opts) WithTimeout(time.Second)(opts)
assert.Equal(t, time.Second, *opts.Timeout) assert.Equal(t, time.Second, *opts.Timeout)
} }
@@ -56,10 +56,11 @@ func TestDisableInfoLog(t *testing.T) {
} }
func TestWithRegistryForTimestampRegisterType(t *testing.T) { func TestWithRegistryForTimestampRegisterType(t *testing.T) {
opts := mopt.Client() opts := options.Client()
// mongoDateTimeEncoder allow user convert time.Time to primitive.DateTime. // mongoDateTimeEncoder allow user convert time.Time to primitive.DateTime.
var mongoDateTimeEncoder bson.ValueEncoderFunc = func(ect bson.EncodeContext, w bson.ValueWriter, value reflect.Value) error { var mongoDateTimeEncoder bson.ValueEncoderFunc = func(ect bson.EncodeContext,
w bson.ValueWriter, value reflect.Value) error {
// Use reflect, determine if it can be converted to time.Time. // Use reflect, determine if it can be converted to time.Time.
dec, ok := value.Interface().(time.Time) dec, ok := value.Interface().(time.Time)
if !ok { if !ok {
@@ -69,7 +70,8 @@ func TestWithRegistryForTimestampRegisterType(t *testing.T) {
} }
// mongoDateTimeEncoder allow user convert primitive.DateTime to time.Time. // mongoDateTimeEncoder allow user convert primitive.DateTime to time.Time.
var mongoDateTimeDecoder bson.ValueDecoderFunc = func(ect bson.DecodeContext, r bson.ValueReader, value reflect.Value) error { var mongoDateTimeDecoder bson.ValueDecoderFunc = func(ect bson.DecodeContext,
r bson.ValueReader, value reflect.Value) error {
primTime, err := r.ReadDateTime() primTime, err := r.ReadDateTime()
if err != nil { if err != nil {
return fmt.Errorf("error reading primitive.DateTime from ValueReader: %v", err) return fmt.Errorf("error reading primitive.DateTime from ValueReader: %v", err)

View File

@@ -71,27 +71,6 @@ func newModel(uri, db, collection string, c cache.Cache) (*Model, error) {
}, nil }, nil
} }
// mustNewTestModel returns a test Model with the given cache.
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
return &Model{
Model: &mon.Model{
Collection: collection,
},
cache: cache.New(c, singleFlight, stats, mongo.ErrNoDocuments, opts...),
}
}
// NewNodeModel returns a test Model with a cache node.
func mustNewTestNodeModel(collection mon.Collection, rds *redis.Redis, opts ...cache.Option) *Model {
c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...)
return &Model{
Model: &mon.Model{
Collection: collection,
},
cache: c,
}
}
// DelCache deletes the cache with given keys. // DelCache deletes the cache with given keys.
func (mm *Model) DelCache(ctx context.Context, keys ...string) error { func (mm *Model) DelCache(ctx context.Context, keys ...string) error {
return mm.cache.DelCtx(ctx, keys...) return mm.cache.DelCtx(ctx, keys...)

View File

@@ -532,6 +532,27 @@ func createModel(t *testing.T, coll mon.Collection) *Model {
} }
} }
// mustNewTestModel returns a test Model with the given cache.
func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model {
return &Model{
Model: &mon.Model{
Collection: collection,
},
cache: cache.New(c, singleFlight, stats, mongo.ErrNoDocuments, opts...),
}
}
// NewNodeModel returns a test Model with a cache node.
func mustNewTestNodeModel(collection mon.Collection, rds *redis.Redis, opts ...cache.Option) *Model {
c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...)
return &Model{
Model: &mon.Model{
Collection: collection,
},
cache: c,
}
}
var ( var (
errMocked = errors.New("mocked error") errMocked = errors.New("mocked error")
index int32 index int32