From 1ebbc6f0c7142a273bd8e473aef7a29fa3300246 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 9 Aug 2025 23:51:44 +0800 Subject: [PATCH] chore: refactor mon/monc (#5073) --- core/stores/mon/bulkinserter.go | 2 +- core/stores/mon/collection.go | 176 +++++++++--------- core/stores/mon/collection_test.go | 24 ++- ...ter_mock.go => collectioninserter_mock.go} | 2 +- core/stores/mon/model.go | 53 +++--- core/stores/mon/model_test.go | 21 ++- core/stores/mon/options_test.go | 14 +- core/stores/monc/cachedmodel.go | 21 --- core/stores/monc/cachedmodel_test.go | 21 +++ 9 files changed, 172 insertions(+), 162 deletions(-) rename core/stores/mon/{collection_inserter_mock.go => collectioninserter_mock.go} (94%) diff --git a/core/stores/mon/bulkinserter.go b/core/stores/mon/bulkinserter.go index 28372c05d..42d43c7eb 100644 --- a/core/stores/mon/bulkinserter.go +++ b/core/stores/mon/bulkinserter.go @@ -1,4 +1,4 @@ -//go:generate mockgen -package mon -destination collection_inserter_mock.go -source bulkinserter.go collectionInserter +//go:generate mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter package mon import ( diff --git a/core/stores/mon/collection.go b/core/stores/mon/collection.go index 79b4b9445..73640bfe2 100644 --- a/core/stores/mon/collection.go +++ b/core/stores/mon/collection.go @@ -137,14 +137,6 @@ func newCollection(collection *mongo.Collection, brk breaker.Breaker) Collection } } -func newTestCollection(collection monCollection, brk breaker.Breaker) *decoratedCollection { - return &decoratedCollection{ - Collection: collection, - name: "test", - brk: brk, - } -} - func (c *decoratedCollection) Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) (cur *mongo.Cursor, err error) { ctx, span := startSpan(ctx, aggregate) @@ -185,6 +177,10 @@ func (c *decoratedCollection) BulkWrite(ctx context.Context, models []mongo.Writ return } +func (c *decoratedCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection { + return c.Collection.Clone(opts...) +} + func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (count int64, err error) { ctx, span := startSpan(ctx, countDocuments) @@ -205,6 +201,10 @@ func (c *decoratedCollection) CountDocuments(ctx context.Context, filter any, return } +func (c *decoratedCollection) Database() *mongo.Database { + return c.Collection.Database() +} + func (c *decoratedCollection) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (res *mongo.DeleteResult, err error) { ctx, span := startSpan(ctx, deleteMany) @@ -266,6 +266,10 @@ func (c *decoratedCollection) Distinct(ctx context.Context, fieldName string, fi return } +func (c *decoratedCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error { + return c.Collection.Drop(ctx, opts...) +} + func (c *decoratedCollection) EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (val int64, err error) { ctx, span := startSpan(ctx, estimatedDocumentCount) @@ -391,6 +395,10 @@ func (c *decoratedCollection) FindOneAndUpdate(ctx context.Context, filter, upda return } +func (c *decoratedCollection) Indexes() mongo.IndexView { + return c.Collection.Indexes() +} + func (c *decoratedCollection) InsertMany(ctx context.Context, documents []any, opts ...options.Lister[options.InsertManyOptions]) (res *mongo.InsertManyResult, err error) { ctx, span := startSpan(ctx, insertMany) @@ -511,22 +519,6 @@ func (c *decoratedCollection) UpdateOne(ctx context.Context, filter, update any, return } -func (c *decoratedCollection) Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection { - return c.Collection.Clone(opts...) -} - -func (c *decoratedCollection) Database() *mongo.Database { - return c.Collection.Database() -} - -func (c *decoratedCollection) Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error { - return c.Collection.Drop(ctx, opts...) -} - -func (c *decoratedCollection) Indexes() mongo.IndexView { - return c.Collection.Indexes() -} - func (c *decoratedCollection) Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) ( *mongo.ChangeStream, error) { return c.Collection.Watch(ctx, pipeline, opts...) @@ -578,72 +570,70 @@ func isDupKeyError(err error) bool { return e.HasErrorCode(duplicateKeyCode) } -type ( - // monCollection defines a MongoDB collection, used for unit test - monCollection interface { - // Aggregate executes an aggregation pipeline. - Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) ( - *mongo.Cursor, error) - // BulkWrite performs a bulk write operation. - BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) ( - *mongo.BulkWriteResult, error) - // Clone creates a copy of this collection with the same settings. - Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection - // CountDocuments returns the number of documents in the collection that match the filter. - CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) - // Database returns the database that this collection is a part of. - Database() *mongo.Database - // DeleteMany deletes documents from the collection that match the filter. - DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) ( - *mongo.DeleteResult, error) - // DeleteOne deletes at most one document from the collection that matches the filter. - DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) ( - *mongo.DeleteResult, error) - // Distinct returns a list of distinct values for the given key across the collection. - Distinct(ctx context.Context, fieldName string, filter any, - opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult - // Drop drops this collection from database. - Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error - // EstimatedDocumentCount returns an estimate of the count of documents in a collection - // using collection metadata. - EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) - // Find finds the documents matching the provided filter. - Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) - // FindOne returns up to one document that matches the provided filter. - FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult - // FindOneAndDelete returns at most one document that matches the filter. If the filter - // matches multiple documents, only the first document is deleted. - FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult - // FindOneAndReplace returns at most one document that matches the filter. If the filter - // matches multiple documents, FindOneAndReplace returns the first document in the - // collection that matches the filter. - FindOneAndReplace(ctx context.Context, filter, replacement any, - opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult - // FindOneAndUpdate returns at most one document that matches the filter. If the filter - // matches multiple documents, FindOneAndUpdate returns the first document in the - // collection that matches the filter. - FindOneAndUpdate(ctx context.Context, filter, update any, - opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult - // Indexes returns the index view for this collection. - Indexes() mongo.IndexView - // InsertMany inserts the provided documents. - InsertMany(ctx context.Context, documents interface{}, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) - // InsertOne inserts the provided document. - InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) - // ReplaceOne replaces at most one document that matches the filter. - ReplaceOne(ctx context.Context, filter, replacement any, - opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) - // UpdateByID updates a single document matching the provided filter. - UpdateByID(ctx context.Context, id, update any, - opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) - // UpdateMany updates the provided documents. - UpdateMany(ctx context.Context, filter, update any, - opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) - // UpdateOne updates a single document matching the provided filter. - UpdateOne(ctx context.Context, filter, update any, - opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) - // Watch returns a change stream cursor used to receive notifications of changes to the collection. - Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) ( - *mongo.ChangeStream, error) - } -) +// monCollection defines a MongoDB collection, used for unit test +type monCollection interface { + // Aggregate executes an aggregation pipeline. + Aggregate(ctx context.Context, pipeline any, opts ...options.Lister[options.AggregateOptions]) ( + *mongo.Cursor, error) + // BulkWrite performs a bulk write operation. + BulkWrite(ctx context.Context, models []mongo.WriteModel, opts ...options.Lister[options.BulkWriteOptions]) ( + *mongo.BulkWriteResult, error) + // Clone creates a copy of this collection with the same settings. + Clone(opts ...options.Lister[options.CollectionOptions]) *mongo.Collection + // CountDocuments returns the number of documents in the collection that match the filter. + CountDocuments(ctx context.Context, filter any, opts ...options.Lister[options.CountOptions]) (int64, error) + // Database returns the database that this collection is a part of. + Database() *mongo.Database + // DeleteMany deletes documents from the collection that match the filter. + DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) ( + *mongo.DeleteResult, error) + // DeleteOne deletes at most one document from the collection that matches the filter. + DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) ( + *mongo.DeleteResult, error) + // Distinct returns a list of distinct values for the given key across the collection. + Distinct(ctx context.Context, fieldName string, filter any, + opts ...options.Lister[options.DistinctOptions]) *mongo.DistinctResult + // Drop drops this collection from database. + Drop(ctx context.Context, opts ...options.Lister[options.DropCollectionOptions]) error + // EstimatedDocumentCount returns an estimate of the count of documents in a collection + // using collection metadata. + EstimatedDocumentCount(ctx context.Context, opts ...options.Lister[options.EstimatedDocumentCountOptions]) (int64, error) + // Find finds the documents matching the provided filter. + Find(ctx context.Context, filter any, opts ...options.Lister[options.FindOptions]) (*mongo.Cursor, error) + // FindOne returns up to one document that matches the provided filter. + FindOne(ctx context.Context, filter any, opts ...options.Lister[options.FindOneOptions]) *mongo.SingleResult + // FindOneAndDelete returns at most one document that matches the filter. If the filter + // matches multiple documents, only the first document is deleted. + FindOneAndDelete(ctx context.Context, filter any, opts ...options.Lister[options.FindOneAndDeleteOptions]) *mongo.SingleResult + // FindOneAndReplace returns at most one document that matches the filter. If the filter + // matches multiple documents, FindOneAndReplace returns the first document in the + // collection that matches the filter. + FindOneAndReplace(ctx context.Context, filter, replacement any, + opts ...options.Lister[options.FindOneAndReplaceOptions]) *mongo.SingleResult + // FindOneAndUpdate returns at most one document that matches the filter. If the filter + // matches multiple documents, FindOneAndUpdate returns the first document in the + // collection that matches the filter. + FindOneAndUpdate(ctx context.Context, filter, update any, + opts ...options.Lister[options.FindOneAndUpdateOptions]) *mongo.SingleResult + // Indexes returns the index view for this collection. + Indexes() mongo.IndexView + // InsertMany inserts the provided documents. + InsertMany(ctx context.Context, documents interface{}, opts ...options.Lister[options.InsertManyOptions]) (*mongo.InsertManyResult, error) + // InsertOne inserts the provided document. + InsertOne(ctx context.Context, document any, opts ...options.Lister[options.InsertOneOptions]) (*mongo.InsertOneResult, error) + // ReplaceOne replaces at most one document that matches the filter. + ReplaceOne(ctx context.Context, filter, replacement any, + opts ...options.Lister[options.ReplaceOptions]) (*mongo.UpdateResult, error) + // UpdateByID updates a single document matching the provided filter. + UpdateByID(ctx context.Context, id, update any, + opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) + // UpdateMany updates the provided documents. + UpdateMany(ctx context.Context, filter, update any, + opts ...options.Lister[options.UpdateManyOptions]) (*mongo.UpdateResult, error) + // UpdateOne updates a single document matching the provided filter. + UpdateOne(ctx context.Context, filter, update any, + opts ...options.Lister[options.UpdateOneOptions]) (*mongo.UpdateResult, error) + // Watch returns a change stream cursor used to receive notifications of changes to the collection. + Watch(ctx context.Context, pipeline any, opts ...options.Lister[options.ChangeStreamOptions]) ( + *mongo.ChangeStream, error) +} diff --git a/core/stores/mon/collection_test.go b/core/stores/mon/collection_test.go index f044f060f..82a665e6f 100644 --- a/core/stores/mon/collection_test.go +++ b/core/stores/mon/collection_test.go @@ -12,7 +12,7 @@ import ( "github.com/zeromicro/go-zero/core/timex" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" - mopt "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/options" "go.uber.org/mock/gomock" ) @@ -75,7 +75,7 @@ func TestCollection_Aggregate(t *testing.T) { mockCollection := NewMockmonCollection(ctrl) mockCollection.EXPECT().Aggregate(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil) c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) - _, err := c.Aggregate(context.Background(), []interface{}{}, mopt.Aggregate()) + _, err := c.Aggregate(context.Background(), []interface{}{}, options.Aggregate()) assert.Nil(t, err) } @@ -156,10 +156,10 @@ func TestCollection_Find(t *testing.T) { mockCollection.EXPECT().Find(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.Cursor{}, nil) c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) filter := bson.D{{Key: "x", Value: 1}} - _, err := c.Find(context.Background(), filter, mopt.Find()) + _, err := c.Find(context.Background(), filter, options.Find()) assert.Nil(t, err) c.brk = new(dropBreaker) - _, err = c.Find(context.Background(), filter, mopt.Find()) + _, err = c.Find(context.Background(), filter, options.Find()) assert.Equal(t, errDummy, err) } @@ -185,9 +185,9 @@ func TestCollection_FindOneAndDelete(t *testing.T) { c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) filter := bson.D{} mockCollection.EXPECT().FindOneAndDelete(gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{}) - _, err := c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete()) + _, err := c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete()) assert.Equal(t, mongo.ErrNoDocuments, err) - _, err = c.FindOneAndDelete(context.Background(), filter, mopt.FindOneAndDelete()) + _, err = c.FindOneAndDelete(context.Background(), filter, options.FindOneAndDelete()) assert.Equal(t, mongo.ErrNoDocuments, err) c.brk = new(dropBreaker) _, err = c.FindOneAndDelete(context.Background(), bson.D{{Key: "foo", Value: "bar"}}) @@ -202,7 +202,7 @@ func TestCollection_FindOneAndReplace(t *testing.T) { c := newTestCollection(mockCollection, breaker.GetBreaker("localhost")) filter := bson.D{{Key: "x", Value: 1}} replacement := bson.D{{Key: "x", Value: 2}} - opts := mopt.FindOneAndReplace().SetUpsert(true) + opts := options.FindOneAndReplace().SetUpsert(true) mockCollection.EXPECT().FindOneAndReplace(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{}) _, err := c.FindOneAndReplace(context.Background(), filter, replacement, opts) assert.Equal(t, mongo.ErrNoDocuments, err) @@ -222,7 +222,7 @@ func TestCollection_FindOneAndUpdate(t *testing.T) { filter := bson.D{{Key: "x", Value: 1}} update := bson.D{{Key: "$x", Value: 2}} mockCollection.EXPECT().FindOneAndUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&mongo.SingleResult{}) - opts := mopt.FindOneAndUpdate().SetUpsert(true) + opts := options.FindOneAndUpdate().SetUpsert(true) _, err := c.FindOneAndUpdate(context.Background(), filter, update, opts) assert.Equal(t, mongo.ErrNoDocuments, err) _, err = c.FindOneAndUpdate(context.Background(), filter, update, opts) @@ -487,6 +487,14 @@ func TestIsDupKeyError(t *testing.T) { } } +func newTestCollection(collection monCollection, brk breaker.Breaker) *decoratedCollection { + return &decoratedCollection{ + Collection: collection, + name: "test", + brk: brk, + } +} + type mockPromise struct { accepted bool reason string diff --git a/core/stores/mon/collection_inserter_mock.go b/core/stores/mon/collectioninserter_mock.go similarity index 94% rename from core/stores/mon/collection_inserter_mock.go rename to core/stores/mon/collectioninserter_mock.go index de38dd5d4..1d387888c 100644 --- a/core/stores/mon/collection_inserter_mock.go +++ b/core/stores/mon/collectioninserter_mock.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -package mon -destination collection_inserter_mock.go -source bulkinserter.go collectionInserter +// mockgen -package mon -destination collectioninserter_mock.go -source bulkinserter.go collectionInserter // // Package mon is a generated GoMock package. diff --git a/core/stores/mon/model.go b/core/stores/mon/model.go index 2ccb6f0f3..1d944ef84 100644 --- a/core/stores/mon/model.go +++ b/core/stores/mon/model.go @@ -30,7 +30,7 @@ type ( opts []Option } - WrappedSession struct { + Session struct { session monSession name string brk breaker.Breaker @@ -62,25 +62,14 @@ func newModel(name string, cli *mongo.Client, coll Collection, brk breaker.Break return &Model{ name: name, Collection: coll, - cli: &mockMonClient{c: cli}, - brk: brk, - opts: opts, - } -} - -func newTestModel(name string, cli monClient, coll monCollection, brk breaker.Breaker, - opts ...Option) *Model { - return &Model{ - name: name, - Collection: newTestCollection(coll, breaker.GetBreaker("localhost")), - cli: cli, + cli: &wrappedMonClient{c: cli}, brk: brk, opts: opts, } } // StartSession starts a new session. -func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (sess *WrappedSession, err error) { +func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (sess *Session, err error) { starTime := timex.Now() defer func() { logDuration(context.Background(), m.name, startSession, starTime, err) @@ -91,7 +80,7 @@ func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (se return nil, sessionErr } - return &WrappedSession{ + return &Session{ session: session, name: m.name, brk: m.brk, @@ -99,7 +88,8 @@ func (m *Model) StartSession(opts ...options.Lister[options.SessionOptions]) (se } // Aggregate executes an aggregation pipeline. -func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...options.Lister[options.AggregateOptions]) error { +func (m *Model) Aggregate(ctx context.Context, v, pipeline any, + opts ...options.Lister[options.AggregateOptions]) error { cur, err := m.Collection.Aggregate(ctx, pipeline, opts...) if err != nil { return err @@ -110,7 +100,8 @@ func (m *Model) Aggregate(ctx context.Context, v, pipeline any, opts ...options. } // DeleteMany deletes documents that match the filter. -func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...options.Lister[options.DeleteManyOptions]) (int64, error) { +func (m *Model) DeleteMany(ctx context.Context, filter any, + opts ...options.Lister[options.DeleteManyOptions]) (int64, error) { res, err := m.Collection.DeleteMany(ctx, filter, opts...) if err != nil { return 0, err @@ -120,7 +111,8 @@ func (m *Model) DeleteMany(ctx context.Context, filter any, opts ...options.List } // DeleteOne deletes the first document that matches the filter. -func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...options.Lister[options.DeleteOneOptions]) (int64, error) { +func (m *Model) DeleteOne(ctx context.Context, filter any, + opts ...options.Lister[options.DeleteOneOptions]) (int64, error) { res, err := m.Collection.DeleteOne(ctx, filter, opts...) if err != nil { return 0, err @@ -130,7 +122,8 @@ func (m *Model) DeleteOne(ctx context.Context, filter any, opts ...options.Liste } // Find finds documents that match the filter. -func (m *Model) Find(ctx context.Context, v, filter any, opts ...options.Lister[options.FindOptions]) error { +func (m *Model) Find(ctx context.Context, v, filter any, + opts ...options.Lister[options.FindOptions]) error { cur, err := m.Collection.Find(ctx, filter, opts...) if err != nil { return err @@ -141,7 +134,8 @@ func (m *Model) Find(ctx context.Context, v, filter any, opts ...options.Lister[ } // FindOne finds the first document that matches the filter. -func (m *Model) FindOne(ctx context.Context, v, filter any, opts ...options.Lister[options.FindOneOptions]) error { +func (m *Model) FindOne(ctx context.Context, v, filter any, + opts ...options.Lister[options.FindOneOptions]) error { res, err := m.Collection.FindOne(ctx, filter, opts...) if err != nil { return err @@ -184,7 +178,7 @@ func (m *Model) FindOneAndUpdate(ctx context.Context, v, filter, update any, } // AbortTransaction implements the mongo.session interface. -func (w *WrappedSession) AbortTransaction(ctx context.Context) (err error) { +func (w *Session) AbortTransaction(ctx context.Context) (err error) { ctx, span := startSpan(ctx, abortTransaction) defer func() { endSpan(span, err) @@ -201,7 +195,7 @@ func (w *WrappedSession) AbortTransaction(ctx context.Context) (err error) { } // CommitTransaction implements the mongo.session interface. -func (w *WrappedSession) CommitTransaction(ctx context.Context) (err error) { +func (w *Session) CommitTransaction(ctx context.Context) (err error) { ctx, span := startSpan(ctx, commitTransaction) defer func() { endSpan(span, err) @@ -218,7 +212,7 @@ func (w *WrappedSession) CommitTransaction(ctx context.Context) (err error) { } // WithTransaction implements the mongo.session interface. -func (w *WrappedSession) WithTransaction( +func (w *Session) WithTransaction( ctx context.Context, fn func(sessCtx context.Context) (any, error), opts ...options.Lister[options.TransactionOptions], @@ -242,7 +236,7 @@ func (w *WrappedSession) WithTransaction( } // EndSession implements the mongo.session interface. -func (w *WrappedSession) EndSession(ctx context.Context) { +func (w *Session) EndSession(ctx context.Context) { var err error ctx, span := startSpan(ctx, endSession) defer func() { @@ -261,10 +255,11 @@ func (w *WrappedSession) EndSession(ctx context.Context) { } type ( - //for unit test + // for unit test monClient interface { StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error) } + monSession interface { AbortTransaction(ctx context.Context) error CommitTransaction(ctx context.Context) error @@ -274,10 +269,14 @@ type ( } ) -type mockMonClient struct { +type wrappedMonClient struct { c *mongo.Client } -func (m *mockMonClient) StartSession(opts ...options.Lister[options.SessionOptions]) (monSession, error) { +// StartSession starts a new session using the underlying *mongo.Client. +// It implements the monClient interface. +// This is used to allow mocking in unit tests. +func (m *wrappedMonClient) StartSession(opts ...options.Lister[options.SessionOptions]) ( + monSession, error) { return m.c.StartSession(opts...) } diff --git a/core/stores/mon/model_test.go b/core/stores/mon/model_test.go index f8861a7c9..aecb267b3 100644 --- a/core/stores/mon/model_test.go +++ b/core/stores/mon/model_test.go @@ -20,7 +20,7 @@ func TestModel_StartSession(t *testing.T) { mockMonCollection := NewMockmonCollection(ctrl) mockedMonClient := NewMockmonClient(ctrl) mockMonSession := NewMockmonSession(ctrl) - warpSession := &WrappedSession{ + warpSession := &Session{ session: mockMonSession, name: "", brk: breaker.GetBreaker("localhost"), @@ -39,9 +39,9 @@ func TestModel_StartSession(t *testing.T) { mockMonSession.EXPECT().AbortTransaction(gomock.Any()).Return(nil) mockMonSession.EXPECT().EndSession(gomock.Any()) _, err = sess.WithTransaction(context.Background(), func(sessCtx context.Context) (any, error) { - //_ = sessCtx.StartTransaction() - //sessCtx.Client().Database("1") - //sessCtx.EndSession(context.Background()) + // _ = sessCtx.StartTransaction() + // sessCtx.Client().Database("1") + // sessCtx.EndSession(context.Background()) return nil, nil }) assert.Nil(t, err) @@ -224,9 +224,20 @@ func Test_mockMonClient_StartSession(t *testing.T) { opts.Deployment = md client, err := mongo.Connect(opts) assert.Nil(t, err) - m := mockMonClient{ + m := wrappedMonClient{ c: client, } _, err = m.StartSession() assert.Nil(t, err) } + +func newTestModel(name string, cli monClient, coll monCollection, brk breaker.Breaker, + opts ...Option) *Model { + return &Model{ + name: name, + Collection: newTestCollection(coll, breaker.GetBreaker("localhost")), + cli: cli, + brk: brk, + opts: opts, + } +} diff --git a/core/stores/mon/options_test.go b/core/stores/mon/options_test.go index cebb99b5d..447856c2a 100644 --- a/core/stores/mon/options_test.go +++ b/core/stores/mon/options_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "go.mongodb.org/mongo-driver/v2/bson" - mopt "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/mongo/options" ) func TestSetSlowThreshold(t *testing.T) { @@ -18,13 +18,13 @@ func TestSetSlowThreshold(t *testing.T) { } func Test_defaultTimeoutOption(t *testing.T) { - opts := mopt.Client() + opts := options.Client() defaultTimeoutOption()(opts) assert.Equal(t, defaultTimeout, *opts.Timeout) } func TestWithTimeout(t *testing.T) { - opts := mopt.Client() + opts := options.Client() WithTimeout(time.Second)(opts) assert.Equal(t, time.Second, *opts.Timeout) } @@ -56,10 +56,11 @@ func TestDisableInfoLog(t *testing.T) { } func TestWithRegistryForTimestampRegisterType(t *testing.T) { - opts := mopt.Client() + opts := options.Client() // mongoDateTimeEncoder allow user convert time.Time to primitive.DateTime. - var mongoDateTimeEncoder bson.ValueEncoderFunc = func(ect bson.EncodeContext, w bson.ValueWriter, value reflect.Value) error { + var mongoDateTimeEncoder bson.ValueEncoderFunc = func(ect bson.EncodeContext, + w bson.ValueWriter, value reflect.Value) error { // Use reflect, determine if it can be converted to time.Time. dec, ok := value.Interface().(time.Time) if !ok { @@ -69,7 +70,8 @@ func TestWithRegistryForTimestampRegisterType(t *testing.T) { } // mongoDateTimeEncoder allow user convert primitive.DateTime to time.Time. - var mongoDateTimeDecoder bson.ValueDecoderFunc = func(ect bson.DecodeContext, r bson.ValueReader, value reflect.Value) error { + var mongoDateTimeDecoder bson.ValueDecoderFunc = func(ect bson.DecodeContext, + r bson.ValueReader, value reflect.Value) error { primTime, err := r.ReadDateTime() if err != nil { return fmt.Errorf("error reading primitive.DateTime from ValueReader: %v", err) diff --git a/core/stores/monc/cachedmodel.go b/core/stores/monc/cachedmodel.go index 1d1b0ffb9..de6cb68fa 100644 --- a/core/stores/monc/cachedmodel.go +++ b/core/stores/monc/cachedmodel.go @@ -71,27 +71,6 @@ func newModel(uri, db, collection string, c cache.Cache) (*Model, error) { }, nil } -// mustNewTestModel returns a test Model with the given cache. -func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model { - return &Model{ - Model: &mon.Model{ - Collection: collection, - }, - cache: cache.New(c, singleFlight, stats, mongo.ErrNoDocuments, opts...), - } -} - -// NewNodeModel returns a test Model with a cache node. -func mustNewTestNodeModel(collection mon.Collection, rds *redis.Redis, opts ...cache.Option) *Model { - c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...) - return &Model{ - Model: &mon.Model{ - Collection: collection, - }, - cache: c, - } -} - // DelCache deletes the cache with given keys. func (mm *Model) DelCache(ctx context.Context, keys ...string) error { return mm.cache.DelCtx(ctx, keys...) diff --git a/core/stores/monc/cachedmodel_test.go b/core/stores/monc/cachedmodel_test.go index f33840ace..de35ce071 100644 --- a/core/stores/monc/cachedmodel_test.go +++ b/core/stores/monc/cachedmodel_test.go @@ -532,6 +532,27 @@ func createModel(t *testing.T, coll mon.Collection) *Model { } } +// mustNewTestModel returns a test Model with the given cache. +func mustNewTestModel(collection mon.Collection, c cache.CacheConf, opts ...cache.Option) *Model { + return &Model{ + Model: &mon.Model{ + Collection: collection, + }, + cache: cache.New(c, singleFlight, stats, mongo.ErrNoDocuments, opts...), + } +} + +// NewNodeModel returns a test Model with a cache node. +func mustNewTestNodeModel(collection mon.Collection, rds *redis.Redis, opts ...cache.Option) *Model { + c := cache.NewNode(rds, singleFlight, stats, mongo.ErrNoDocuments, opts...) + return &Model{ + Model: &mon.Model{ + Collection: collection, + }, + cache: c, + } +} + var ( errMocked = errors.New("mocked error") index int32