Skip to content

Commit 668fd67

Browse files
authored
1.fix data race in ClientMock.MockGet() (#5248)
2.fix compactor unit test anomalous error Signed-off-by: yiyang5055 <[email protected]>
1 parent 11d5ca9 commit 668fd67

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

pkg/compactor/compactor_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,8 +1215,8 @@ func TestCompactor_ShouldCompactOnlyShardsOwnedByTheInstanceOnShardingEnabledWit
12151215
bucketClient.MockGet(userID+"/"+blockID+"/meta.json", mockBlockMetaJSONWithTime(blockID, userID, blockTimes["startTime"], blockTimes["endTime"]), nil)
12161216
bucketClient.MockGet(userID+"/"+blockID+"/deletion-mark.json", "", nil)
12171217
bucketClient.MockGet(userID+"/"+blockID+"/no-compact-mark.json", "", nil)
1218-
bucketClient.MockGetTimes(userID+"/"+blockID+"/visit-mark.json", "", nil, 1)
12191218
bucketClient.MockGet(userID+"/"+blockID+"/visit-mark.json", string(visitMarkerFileContent), nil)
1219+
bucketClient.MockGetRequireUpload(userID+"/"+blockID+"/visit-mark.json", string(visitMarkerFileContent), nil)
12201220
bucketClient.MockUpload(userID+"/"+blockID+"/visit-mark.json", nil)
12211221
blockDirectory = append(blockDirectory, userID+"/"+blockID)
12221222

@@ -1243,6 +1243,7 @@ func TestCompactor_ShouldCompactOnlyShardsOwnedByTheInstanceOnShardingEnabledWit
12431243
for i := 1; i <= 4; i++ {
12441244
cfg := prepareConfig()
12451245
cfg.ShardingEnabled = true
1246+
cfg.CompactionInterval = 15 * time.Second
12461247
cfg.ShardingStrategy = util.ShardingStrategyShuffle
12471248
cfg.ShardingRing.InstanceID = fmt.Sprintf("compactor-%d", i)
12481249
cfg.ShardingRing.InstanceAddr = fmt.Sprintf("127.0.0.%d", i)
@@ -1280,7 +1281,7 @@ func TestCompactor_ShouldCompactOnlyShardsOwnedByTheInstanceOnShardingEnabledWit
12801281

12811282
// Wait until a run has been completed on each compactor
12821283
for _, c := range compactors {
1283-
cortex_testutil.Poll(t, 60*time.Second, 1.0, func() interface{} {
1284+
cortex_testutil.Poll(t, 60*time.Second, 2.0, func() interface{} {
12841285
return prom_testutil.ToFloat64(c.compactionRunsCompleted)
12851286
})
12861287
}

pkg/storage/bucket/client_mock.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"errors"
77
"io"
8+
"sync"
89
"time"
910

1011
"github.com/stretchr/testify/mock"
@@ -16,10 +17,14 @@ var errObjectDoesNotExist = errors.New("object does not exist")
1617
// ClientMock mocks objstore.Bucket
1718
type ClientMock struct {
1819
mock.Mock
20+
uploaded sync.Map
1921
}
2022

2123
// Upload mocks objstore.Bucket.Upload()
2224
func (m *ClientMock) Upload(ctx context.Context, name string, r io.Reader) error {
25+
if _, ok := m.uploaded.Load(name); ok {
26+
m.uploaded.Store(name, true)
27+
}
2328
args := m.Called(ctx, name, r)
2429
return args.Error(0)
2530
}
@@ -30,6 +35,8 @@ func (m *ClientMock) MockUpload(name string, err error) {
3035

3136
// Delete mocks objstore.Bucket.Delete()
3237
func (m *ClientMock) Delete(ctx context.Context, name string) error {
38+
m.uploaded.Delete(name)
39+
3340
args := m.Called(ctx, name)
3441
return args.Error(0)
3542
}
@@ -70,7 +77,20 @@ func (m *ClientMock) MockIterWithCallback(prefix string, objects []string, err e
7077

7178
// Get mocks objstore.Bucket.Get()
7279
func (m *ClientMock) Get(ctx context.Context, name string) (io.ReadCloser, error) {
80+
if val, ok := m.uploaded.Load(name); ok {
81+
uploaded := val.(bool)
82+
if !uploaded {
83+
return nil, errObjectDoesNotExist
84+
}
85+
}
86+
7387
args := m.Called(ctx, name)
88+
89+
// Allow to mock the Get() with a function which is called each time.
90+
if fn, ok := args.Get(0).(func(ctx context.Context, name string) (io.ReadCloser, error)); ok {
91+
return fn(ctx, name)
92+
}
93+
7494
val, err := args.Get(0), args.Error(1)
7595
if val == nil {
7696
return nil, err
@@ -90,9 +110,8 @@ func (m *ClientMock) MockGet(name, content string, err error) {
90110
// Since we return an ReadCloser and it can be consumed only once,
91111
// each time the mocked Get() is called we do create a new one, so
92112
// that getting the same mocked object twice works as expected.
93-
mockedGet := m.On("Get", mock.Anything, name)
94-
mockedGet.Run(func(args mock.Arguments) {
95-
mockedGet.Return(io.NopCloser(bytes.NewReader([]byte(content))), err)
113+
m.On("Get", mock.Anything, name).Return(func(_ context.Context, _ string) (io.ReadCloser, error) {
114+
return io.NopCloser(bytes.NewReader([]byte(content))), err
96115
})
97116
} else {
98117
m.On("Exists", mock.Anything, name).Return(false, err)
@@ -101,6 +120,13 @@ func (m *ClientMock) MockGet(name, content string, err error) {
101120
}
102121
}
103122

123+
// MockGetRequireUpload is a convenient method to mock Get() return resulst after upload,
124+
// otherwise return errObjectDoesNotExist
125+
func (m *ClientMock) MockGetRequireUpload(name, content string, err error) {
126+
m.uploaded.Store(name, false)
127+
m.MockGet(name, content, err)
128+
}
129+
104130
// MockGetTimes is a convenient method to mock Get() and Exists() to run x time
105131
func (m *ClientMock) MockGetTimes(name, content string, err error, times int) {
106132
if content != "" {

pkg/storage/bucket/client_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package bucket
22

33
import (
44
"context"
5+
"io"
6+
"sync"
57
"testing"
68

79
"github.com/stretchr/testify/assert"
@@ -92,3 +94,30 @@ func TestNewClient(t *testing.T) {
9294
})
9395
}
9496
}
97+
98+
func TestClientMock_MockGet(t *testing.T) {
99+
expected := "body"
100+
101+
m := ClientMock{}
102+
m.MockGet("test", expected, nil)
103+
104+
// Run many goroutines all requesting the same mocked object and
105+
// ensure there's no race.
106+
wg := sync.WaitGroup{}
107+
for i := 0; i < 1000; i++ {
108+
wg.Add(1)
109+
go func() {
110+
defer wg.Done()
111+
112+
reader, err := m.Get(context.Background(), "test")
113+
require.NoError(t, err)
114+
actual, err := io.ReadAll(reader)
115+
require.NoError(t, err)
116+
require.Equal(t, []byte(expected), actual)
117+
118+
require.NoError(t, reader.Close())
119+
}()
120+
}
121+
122+
wg.Wait()
123+
}

0 commit comments

Comments
 (0)