diff --git a/storage/inmemory/mutexmap/rulebasedsegment.go b/storage/inmemory/mutexmap/rulebasedsegment.go index dcdd07b5..f27acb17 100644 --- a/storage/inmemory/mutexmap/rulebasedsegment.go +++ b/storage/inmemory/mutexmap/rulebasedsegment.go @@ -159,6 +159,25 @@ func (r *RuleBasedSegmentsStorageImpl) GetRuleBasedSegmentByName(name string) (* return nil, fmt.Errorf("rule-based segment %s not found in storage", name) } +func (m *RuleBasedSegmentsStorageImpl) _get(splitName string) *dtos.RuleBasedSegmentDTO { + item, exists := m.data[splitName] + if !exists { + return nil + } + return &item +} + +// FetchMany fetches rule-based segments in the storage and returns an array of rule-based segments dtos +func (m *RuleBasedSegmentsStorageImpl) FetchMany(rbsNames []string) map[string]*dtos.RuleBasedSegmentDTO { + m.mutex.RLock() + defer m.mutex.RUnlock() + rbSegments := make(map[string]*dtos.RuleBasedSegmentDTO) + for _, rbsName := range rbsNames { + rbSegments[rbsName] = m._get(rbsName) + } + return rbSegments +} + func (r *RuleBasedSegmentsStorageImpl) ReplaceAll(toAdd []dtos.RuleBasedSegmentDTO, changeNumber int64) error { r.mutex.RLock() toRemove := make([]dtos.RuleBasedSegmentDTO, 0) diff --git a/storage/inmemory/mutexmap/rulebasedsegment_test.go b/storage/inmemory/mutexmap/rulebasedsegment_test.go index ac23fa4b..aae925f6 100644 --- a/storage/inmemory/mutexmap/rulebasedsegment_test.go +++ b/storage/inmemory/mutexmap/rulebasedsegment_test.go @@ -89,6 +89,13 @@ func TestRuleBasedSegmentsStorage(t *testing.T) { assert.Contains(t, names, "rule1") assert.Contains(t, names, "rule2") + // Test FetchMany + rbsFetchMany := storage.FetchMany([]string{"rule1", "rule2", "nonexistent"}) + assert.Len(t, rbsFetchMany, 3) + assert.Equal(t, "rule1", rbsFetchMany["rule1"].Name) + assert.Equal(t, "rule2", rbsFetchMany["rule2"].Name) + assert.Nil(t, rbsFetchMany["nonexistent"]) + // Test GetSegments segments := storage.Segments() assert.True(t, segments.Has("segment1"), "segment1 should be in segments") diff --git a/storage/interfaces.go b/storage/interfaces.go index 20a3b3a0..7d37ddb4 100644 --- a/storage/interfaces.go +++ b/storage/interfaces.go @@ -281,6 +281,7 @@ type RuleBasedSegmentStorageProducer interface { type RuleBasedSegmentStorageConsumer interface { ChangeNumber() (int64, error) All() []dtos.RuleBasedSegmentDTO + FetchMany(rbsNames []string) map[string]*dtos.RuleBasedSegmentDTO RuleBasedSegmentNames() ([]string, error) Contains(ruleBasedSegmentNames []string) bool Segments() *set.ThreadUnsafeSet diff --git a/storage/mocks/rulebasedsegment.go b/storage/mocks/rulebasedsegment.go index 236fb58a..27fae866 100644 --- a/storage/mocks/rulebasedsegment.go +++ b/storage/mocks/rulebasedsegment.go @@ -81,4 +81,9 @@ func (m *MockRuleBasedSegmentStorage) LargeSegments() *set.ThreadUnsafeSet { return args.Get(0).(*set.ThreadUnsafeSet) } +func (m *MockRuleBasedSegmentStorage) FetchMany(rbsNames []string) map[string]*dtos.RuleBasedSegmentDTO { + args := m.Called(rbsNames) + return args.Get(0).(map[string]*dtos.RuleBasedSegmentDTO) +} + var _ storage.RuleBasedSegmentsStorage = (*MockRuleBasedSegmentStorage)(nil) diff --git a/storage/redis/rulebasedsegment_test.go b/storage/redis/rulebasedsegment_test.go index 01c6de9a..16b6deb3 100644 --- a/storage/redis/rulebasedsegment_test.go +++ b/storage/redis/rulebasedsegment_test.go @@ -605,6 +605,54 @@ func TestReplaceAllRuleBased(t *testing.T) { redisClient.Del(keys...) } +func TestRBFetchMany(t *testing.T) { + t.Run("FetchMany Error", func(t *testing.T) { + expectedKey := "someprefix.SPLITIO.rbsegment.someRB1" + expectedKey2 := "someprefix.SPLITIO.rbsegment.someRB2" + + mockedRedisClient := mocks.MockClient{ + MGetCall: func(keys []string) redis.Result { + assert.ElementsMatch(t, []string{expectedKey, expectedKey2}, keys) + return &mocks.MockResultOutput{ + MultiInterfaceCall: func() ([]interface{}, error) { + return []interface{}{}, errors.New("Some Error") + }, + } + }, + } + mockPrefixedClient, _ := redis.NewPrefixedRedisClient(&mockedRedisClient, "someprefix") + rbStorage := NewRuleBasedStorage(mockPrefixedClient, logging.NewLogger(&logging.LoggerOptions{})) + rbs := rbStorage.FetchMany([]string{"someRB1", "someRB2"}) + assert.Nil(t, rbs) + }) + + t.Run("FetchMany Success", func(t *testing.T) { + expectedKey := "someprefix.SPLITIO.rbsegment.someRB1" + expectedKey2 := "someprefix.SPLITIO.rbsegment.someRB2" + + mockedRedisClient := mocks.MockClient{ + MGetCall: func(keys []string) redis.Result { + assert.ElementsMatch(t, []string{expectedKey, expectedKey2}, keys) + return &mocks.MockResultOutput{ + MultiInterfaceCall: func() ([]interface{}, error) { + return []interface{}{ + marshalRuleBasedSegment(createSampleRBSegment("someRB1")), + marshalRuleBasedSegment(createSampleRBSegment("someRB2")), + }, nil + }, + } + }, + } + mockPrefixedClient, _ := redis.NewPrefixedRedisClient(&mockedRedisClient, "someprefix") + + rbStorage := NewRuleBasedStorage(mockPrefixedClient, logging.NewLogger(&logging.LoggerOptions{})) + rbs := rbStorage.FetchMany([]string{"someRB1", "someRB2"}) + assert.Equal(t, 2, len(rbs)) + assert.NotNil(t, rbs["someRB1"]) + assert.NotNil(t, rbs["someRB2"]) + }) +} + func marshalRuleBasedSegment(rbSegment dtos.RuleBasedSegmentDTO) string { json, _ := json.Marshal(rbSegment) return string(json) diff --git a/storage/redis/rulebasedsegments.go b/storage/redis/rulebasedsegments.go index 9ba75a87..76deb5dc 100644 --- a/storage/redis/rulebasedsegments.go +++ b/storage/redis/rulebasedsegments.go @@ -368,4 +368,36 @@ func (r *RuleBasedSegmentStorage) executePipeline(pipeline redis.Pipeline, toAdd return failedToAdd, failedToRemove } +func (r *RuleBasedSegmentStorage) FetchMany(names []string) map[string]*dtos.RuleBasedSegmentDTO { + if len(names) == 0 { + return nil + } + + keysToFetch := make([]string, 0, len(names)) + for _, name := range names { + keysToFetch = append(keysToFetch, strings.Replace(KeyRuleBasedSegment, "{rbsegment}", name, 1)) + } + rawRBS, err := r.client.MGet(keysToFetch) + if err != nil { + r.logger.Error(fmt.Sprintf("Could not fetch rule-based segments from redis: %s", err.Error())) + return nil + } + + rbs := make(map[string]*dtos.RuleBasedSegmentDTO) + for idx, rb := range names { + var rbSegment *dtos.RuleBasedSegmentDTO + rawRBSegment, ok := rawRBS[idx].(string) + if ok { + err = json.Unmarshal([]byte(rawRBSegment), &rbSegment) + if err != nil { + r.logger.Error("Could not parse rule-based segment \"%s\" fetched from redis", rb) + return nil + } + } + rbs[rb] = rbSegment + } + + return rbs +} + var _ storage.RuleBasedSegmentsStorage = (*RuleBasedSegmentStorage)(nil) diff --git a/storage/redis/splits_test.go b/storage/redis/splits_test.go index f7891bb8..312c5379 100644 --- a/storage/redis/splits_test.go +++ b/storage/redis/splits_test.go @@ -562,20 +562,20 @@ func TestUpdateRedis(t *testing.T) { if len(splits) != 3 { t.Error("Unexpected amount of splits") } - set1, err := redisClient.SMembers("SPLITIO.flagSet.set1") + set1, _ := redisClient.SMembers("SPLITIO.flagSet.set1") if len(set1) != 2 { t.Error("set size should be 2") } if !slices.Contains(set1, "split1") || !slices.Contains(set1, "split2") { t.Error("Split missing in set") } - tt, err := redisClient.Get("SPLITIO.trafficType.user") - ttCount, _ := strconv.ParseFloat(tt, 10) + tt, _ := redisClient.Get("SPLITIO.trafficType.user") + ttCount, _ := strconv.ParseFloat(tt, 64) if ttCount != 3 { t.Error("Split should exist") } - till, err := redisClient.Get("SPLITIO.splits.till") - tillInt, _ := strconv.ParseFloat(till, 10) + till, _ := redisClient.Get("SPLITIO.splits.till") + tillInt, _ := strconv.ParseFloat(till, 64) if tillInt != 1 { t.Error("ChangeNumber should be 1") } @@ -587,29 +587,29 @@ func TestUpdateRedis(t *testing.T) { if len(splits) != 3 { t.Error("Unexpected size") } - set1, err = redisClient.SMembers("SPLITIO.flagSet.set1") + set1, _ = redisClient.SMembers("SPLITIO.flagSet.set1") if len(set1) != 0 { t.Error("set size should be 0") } - set3, err := redisClient.SMembers("SPLITIO.flagSet.set3") + set3, _ := redisClient.SMembers("SPLITIO.flagSet.set3") if len(set3) != 3 { t.Error("set size should be 3") } if !slices.Contains(set3, "split3") || !slices.Contains(set3, "split4") || !slices.Contains(set3, "split5") { t.Error("Split missing in set") } - tt, err = redisClient.Get("SPLITIO.trafficType.user") - ttCount, _ = strconv.ParseFloat(tt, 10) + tt, _ = redisClient.Get("SPLITIO.trafficType.user") + ttCount, _ = strconv.ParseFloat(tt, 64) if ttCount != 3 { t.Error("Unexpected trafficType occurrences") } - split1, err := redisClient.Get("SPLITIO.split.split1") + split1, _ := redisClient.Get("SPLITIO.split.split1") if split1 != "" { t.Error("Split should not exist") } - till, err = redisClient.Get("SPLITIO.splits.till") - tillInt, _ = strconv.ParseFloat(till, 10) + till, _ = redisClient.Get("SPLITIO.splits.till") + tillInt, _ = strconv.ParseFloat(till, 64) if tillInt != 2 { t.Error("ChangeNumber should be 2") } @@ -649,15 +649,15 @@ func TestUpdateWithFlagSetFiltersRedis(t *testing.T) { if len(splits) != 3 { t.Error("Unexpected amount of splits") } - set1, err := redisClient.SMembers("SPLITIO.flagSet.set1") + set1, _ := redisClient.SMembers("SPLITIO.flagSet.set1") if len(set1) != 2 { t.Error("set size should be 2") } - set2, err := redisClient.SMembers("SPLITIO.flagSet.set2") + set2, _ := redisClient.SMembers("SPLITIO.flagSet.set2") if len(set2) != 1 { t.Error("set size should be 1") } - set3, err := redisClient.SMembers("SPLITIO.flagSet.set3") + set3, _ := redisClient.SMembers("SPLITIO.flagSet.set3") if len(set3) != 0 { t.Error("set size should be 0") } diff --git a/synchronizer/worker/split/split.go b/synchronizer/worker/split/split.go index 6ff36964..c0eb389c 100644 --- a/synchronizer/worker/split/split.go +++ b/synchronizer/worker/split/split.go @@ -256,7 +256,7 @@ func (s *UpdaterImpl) attemptLatestSync() (*UpdateResult, error) { currentRBSince = splitChanges.RBTill() s.runtimeTelemetry.RecordSyncLatency(telemetry.SplitSync, time.Since(before)) s.splitStorage.ReplaceAll(splitChanges.FeatureFlags(), currentSince) - s.ruleBasedSegmentStorage.ReplaceAll(splitChanges.RuleBasedSegments(), currentSince) + s.ruleBasedSegmentStorage.ReplaceAll(splitChanges.RuleBasedSegments(), currentRBSince) segmentReferences := s.getSegmentsFromRuleBasedSegments(splitChanges.RuleBasedSegments()) segmentReferences = appendSegmentNames(segmentReferences, splitChanges.FeatureFlags()) updatedSplitNames = appendSplitNames(updatedSplitNames, splitChanges.FeatureFlags())