55 "context"
66 "errors"
77 "io"
8+ "sync"
89 "time"
910
1011 "github.com/stretchr/testify/mock"
@@ -16,10 +17,14 @@ var errObjectDoesNotExist = errors.New("object does not exist")
1617// ClientMock mocks objstore.Bucket
1718type ClientMock struct {
1819 mock.Mock
20+ uploaded sync.Map
1921}
2022
2123// Upload mocks objstore.Bucket.Upload()
2224func (m * ClientMock ) Upload (ctx context.Context , name string , r io.Reader ) error {
25+ if _ , ok := m .uploaded .Load (name ); ok {
26+ m .uploaded .Store (name , true )
27+ }
2328 args := m .Called (ctx , name , r )
2429 return args .Error (0 )
2530}
@@ -30,6 +35,8 @@ func (m *ClientMock) MockUpload(name string, err error) {
3035
3136// Delete mocks objstore.Bucket.Delete()
3237func (m * ClientMock ) Delete (ctx context.Context , name string ) error {
38+ m .uploaded .Delete (name )
39+
3340 args := m .Called (ctx , name )
3441 return args .Error (0 )
3542}
@@ -70,7 +77,20 @@ func (m *ClientMock) MockIterWithCallback(prefix string, objects []string, err e
7077
7178// Get mocks objstore.Bucket.Get()
7279func (m * ClientMock ) Get (ctx context.Context , name string ) (io.ReadCloser , error ) {
80+ if val , ok := m .uploaded .Load (name ); ok {
81+ uploaded := val .(bool )
82+ if ! uploaded {
83+ return nil , errObjectDoesNotExist
84+ }
85+ }
86+
7387 args := m .Called (ctx , name )
88+
89+ // Allow to mock the Get() with a function which is called each time.
90+ if fn , ok := args .Get (0 ).(func (ctx context.Context , name string ) (io.ReadCloser , error )); ok {
91+ return fn (ctx , name )
92+ }
93+
7494 val , err := args .Get (0 ), args .Error (1 )
7595 if val == nil {
7696 return nil , err
@@ -90,9 +110,8 @@ func (m *ClientMock) MockGet(name, content string, err error) {
90110 // Since we return an ReadCloser and it can be consumed only once,
91111 // each time the mocked Get() is called we do create a new one, so
92112 // that getting the same mocked object twice works as expected.
93- mockedGet := m .On ("Get" , mock .Anything , name )
94- mockedGet .Run (func (args mock.Arguments ) {
95- mockedGet .Return (io .NopCloser (bytes .NewReader ([]byte (content ))), err )
113+ m .On ("Get" , mock .Anything , name ).Return (func (_ context.Context , _ string ) (io.ReadCloser , error ) {
114+ return io .NopCloser (bytes .NewReader ([]byte (content ))), err
96115 })
97116 } else {
98117 m .On ("Exists" , mock .Anything , name ).Return (false , err )
@@ -101,6 +120,13 @@ func (m *ClientMock) MockGet(name, content string, err error) {
101120 }
102121}
103122
123+ // MockGetRequireUpload is a convenient method to mock Get() return resulst after upload,
124+ // otherwise return errObjectDoesNotExist
125+ func (m * ClientMock ) MockGetRequireUpload (name , content string , err error ) {
126+ m .uploaded .Store (name , false )
127+ m .MockGet (name , content , err )
128+ }
129+
104130// MockGetTimes is a convenient method to mock Get() and Exists() to run x time
105131func (m * ClientMock ) MockGetTimes (name , content string , err error , times int ) {
106132 if content != "" {
0 commit comments